In [4]:
import pandas as pd
import torch
import numpy as np
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

In [5]:
atom = pd.read_csv(
    "dataset/Mutagenesis-42/atoms.csv",
    delimiter=";",
)
bond = pd.read_csv("dataset/Mutagenesis-42/bonds.csv", delimiter=";")
molecule = pd.read_csv("dataset/Mutagenesis-42/drugs.csv", delimiter=";")
atom.head()

Unnamed: 0,id,drug_id,element,atom_type,charge
0,0,0,c,22,-0.11
1,1,0,c,22,-0.11
2,2,0,c,22,-0.11
3,3,0,c,22,-0.11
4,4,0,h,3,0.15


In [6]:
atom["atom_type"].max()

195

In [7]:
molecule.head()

Unnamed: 0,id,ind1,inda,act,logp,lumo,active
0,0,0.0,0.0,-0.7,2.29,-3.025,0
1,1,0.0,0.0,0.57,2.13,-0.798,1
2,2,1.0,0.0,0.77,4.35,-2.138,1
3,3,1.0,0.0,-0.22,5.41,-1.429,0
4,4,1.0,0.0,-0.22,5.41,-1.478,0


In [8]:
molecule_ids = molecule["id"].unique().tolist()
molecule_ids[:3]

[0, 1, 2]

In [9]:
edges_dict = bond.values.tolist()
edges_dict = {(s[2], s[3]): s[4] for s in edges_dict}

In [10]:
edges_dict

{(0, 11): 7,
 (11, 19): 7,
 (19, 20): 7,
 (20, 21): 7,
 (21, 22): 7,
 (22, 0): 7,
 (11, 23): 1,
 (22, 24): 1,
 (19, 25): 7,
 (25, 1): 7,
 (1, 2): 7,
 (2, 3): 7,
 (3, 20): 7,
 (25, 4): 1,
 (2, 5): 1,
 (3, 6): 1,
 (1, 7): 1,
 (21, 8): 1,
 (0, 9): 1,
 (8, 10): 2,
 (8, 12): 2,
 (6, 13): 2,
 (6, 14): 2,
 (7, 15): 2,
 (7, 16): 2,
 (9, 17): 2,
 (9, 18): 2,
 (26, 36): 7,
 (36, 37): 7,
 (37, 38): 7,
 (38, 39): 7,
 (39, 40): 7,
 (40, 26): 7,
 (37, 41): 7,
 (41, 42): 7,
 (38, 43): 7,
 (43, 42): 7,
 (40, 27): 1,
 (27, 28): 2,
 (27, 29): 2,
 (41, 30): 1,
 (26, 31): 1,
 (36, 32): 1,
 (39, 33): 1,
 (42, 34): 1,
 (43, 35): 1,
 (44, 55): 7,
 (55, 66): 7,
 (66, 68): 7,
 (68, 69): 7,
 (69, 70): 7,
 (70, 44): 7,
 (68, 71): 7,
 (71, 72): 7,
 (72, 73): 7,
 (73, 45): 7,
 (45, 69): 7,
 (71, 46): 7,
 (46, 47): 7,
 (47, 48): 7,
 (48, 49): 7,
 (49, 72): 7,
 (66, 50): 7,
 (50, 51): 7,
 (51, 46): 7,
 (47, 52): 7,
 (51, 53): 7,
 (53, 52): 7,
 (44, 54): 1,
 (55, 56): 1,
 (70, 57): 1,
 (73, 58): 1,
 (45, 59): 1,
 (48

In [95]:
def get_molecule_data():
    data = []
    for sample_id in tqdm(molecule_ids):
        # 首先获取节点信息
        nodes = atom["id"].loc[atom["drug_id"] == sample_id].tolist()
        nodes = sorted(nodes)

        # 构建邻接矩阵
        adj_matrix = np.zeros((len(nodes), len(nodes)))
        # adj_matrix[-1, :-1] = 1  # 构造一个超级节点
        for i in range(len(nodes)):
            for j in range(len(nodes)):
                if i == j:
                    continue
                if (nodes[i], nodes[j]) in edges_dict:
                    adj_matrix[i][j] = edges_dict[nodes[i], nodes[j]]
        # 获取前全局特征
        ind1 = molecule["ind1"].loc[molecule["id"] == sample_id].tolist()[0]
        inda = molecule["inda"].loc[molecule["id"] == sample_id].tolist()[0]
        logp = molecule["logp"].loc[molecule["id"] == sample_id].tolist()[0]
        lumo = molecule["lumo"].loc[molecule["id"] == sample_id].tolist()[0]
        global_feature = {"ind1": ind1, "inda": inda, "logp": logp, "lumo": lumo}
        # 获取每个节点的特征
        node_features = []
        mean_charge = 0
        for node in nodes:
            node_type = atom["atom_type"].loc[atom["id"] == node].tolist()[0]
            node_charge = atom["charge"].loc[atom["id"] == node].tolist()[0]
            mean_charge += node_charge
            node_feature = {"type": node_type, "charge": node_charge}
            node_features.append(node_feature)
        # mean_charge /= len(nodes)
        # node_features.append({"type": 1, "charge": mean_charge})  # 超级节点
        # node
        sample = dict()
        sample["drug_id"] = sample_id
        nodes = np.array(nodes) - nodes[0]
        nodes = nodes.astype(np.int64).tolist()
        # nodes.append(len(nodes) - 1)
        sample["nodes"] = nodes
        sample["adj_matrix"] = adj_matrix
        sample["global_features"] = global_feature
        sample["node_features"] = node_features
        sample["label"] = (
            molecule["active"].loc[molecule["id"] == sample_id].tolist()[0]
        )
        data.append(sample)
    return data

In [96]:
data = get_molecule_data()

100%|██████████| 42/42 [00:00<00:00, 44.31it/s]


In [97]:
data[0]["adj_matrix"]

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 7., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 7., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 7., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 7., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2., 2., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 2.,
        2., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 

In [98]:
import torch.nn as nn
from torch.nn import functional as F

In [99]:
class Phi(nn.Module):
    def __init__(
        self,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        embed_dim: int,
    ):
        super(Phi, self).__init__()
        self.embed_dim = embed_dim
        self.atom_embedding = nn.Embedding(num_atom_type, embed_dim)
        self.ind1_embedding = nn.Embedding(num_ind1_type, embed_dim)
        self.inda_embedding = nn.Embedding(num_inda_type, embed_dim)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, l_n, l_n_u, l_u):
        """transition network，用于生成转换矩阵 ${A}_{n,u}$
            转换矩阵的作用是
        Args:
            l_n (_type_): 节点的特征 包括：
                1. atom_type class
                2. charge float
                3. ind1 class
                4. inda class
                5. logp float
                6. lumo float
            l_n_u (): 边的特征
            l_u (_type_): 邻居节点的特征
        """
        pass


class HwNonLinear(nn.Module):
    def __init__(
        self,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        num_edge_type: int,
        embed_dim: int,
    ):
        """非线性的GNN的H_w函数实现，在利用节点、边、邻居节点的特征构造相关向量，
        然后拼接/加和节点状态向量，使用三层的FNN网络计算新的节点状态。
        Args:
            num_atom_type (int): 原子的种类数量
            num_ind1_type (int): ind1的种类数量
            num_inda_type (int): inda的种类数量
            embed_dim (int): 嵌入的维度
        """
        super(HwNonLinear, self).__init__()
        self.embed_dim = embed_dim
        self.atom_embedding = nn.Embedding(num_atom_type, embed_dim)
        self.ind1_embedding = nn.Embedding(num_ind1_type, embed_dim)
        self.inda_embedding = nn.Embedding(num_inda_type, embed_dim)
        self.edge_embedding = nn.Embedding(num_edge_type + 1, embed_dim)
        self.logp_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.lumo_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.charge_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.fnn_1 = nn.Linear(embed_dim, embed_dim)
        self.fnn_2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, l_n, adj_matrix):
        """前向传播，根据节点原来的状态向量、节点的特征l_n、邻居节点的特征l_u、边的特征l_{n,u}
        计算新的节点状态向量。为了实现高效的计算，使用批量计算的方式，并且使用平均的方式进行融合。
        传入的l_n是一张图的所有节点的特征，邻居节点的特征也可以从中获取，邻居可以有邻接矩阵得到。
        1. 首先对边进行嵌入，直接使用邻接矩阵进行嵌入，将0也看作一个类别，得到一个 n x n x e的张量 \n
        2. 然后对节点进行嵌入，对每个都进行嵌入，然后加和得到一个 n x e的张量 \n
        3. 将节点特征和边特征进行加和，得到一个 n x n x e 的张量 \n
        4. 根据邻接矩阵得到mask矩阵，根据mask矩阵在第二个维度上求和，得到 n x e的张量 \n
        5. 将上述得到的向量和x拼接，得到一个 n x (2e)的张量 \n
        6. 经过一个三层的FNN网络，得到一个 n x e的张量，即新的节点状态向量 \n
        Args:
            l_n (_type_): 节点的特征 包括：
                dict:
                1. atom_type class [n]
                2. charge float [n]
                3. ind1 class [n]
                4. inda class [n]
                5. logp float [n]
                6. lumo float [n]
            adj_matrix (): 邻接矩阵

        Returns:
            _type_: 新的节点状态向量
        """
        atom_embed = self.atom_embedding(l_n["atom_type"])  # [n,e]
        ind1_embed = self.ind1_embedding(l_n["ind1"])
        inda_embed = self.inda_embedding(l_n["inda"])
        logp_embed = self.logp_embedding * l_n["logp"].unsqueeze(dim=-1)
        lumo_embed = self.lumo_embedding * l_n["lumo"].unsqueeze(dim=-1)  # [n,e]
        charge_embed = self.charge_embedding * l_n["charge"].unsqueeze(dim=-1)
        # 计算节点特征
        node_embed = (
            atom_embed
            + ind1_embed
            + inda_embed
            + logp_embed
            + lumo_embed
            + charge_embed
        ) / 6  # [n,e]
        # 计算边特征
        edge_embed = self.edge_embedding(adj_matrix)  # [n,n,e]
        # 计算节点特征和边特征的加和
        # 计算mask矩阵
        mask = adj_matrix.unsqueeze(dim=-1) != 0  # [n,n,1]
        mask = mask.float()
        # 计算mask矩阵在第二个维度上的和
        edge_embed = torch.sum(edge_embed * mask, dim=1) / (
            mask.squeeze().sum(dim=1, keepdim=True) + 1e-10
        )  # [n,e]
        # 拼接节点特征和边特征
        # x = torch.cat([x, node_edge_embed], dim=-1)  # [n,2e]
        node_agg = (mask.squeeze() @ (x + node_embed)) / (
            mask.squeeze().sum(dim=1, keepdim=True) + 1e-10
        )
        x = (node_agg + edge_embed) / 2
        # 经过一个三层的FNN网络
        x = torch.tanh(self.fnn_1(x))  # [n,e]
        # x = torch.tanh(self.fnn_2(x))  # [n,e]
        return x

    def get_parameters(self):
        """获取所有待学习参数"""
        return list(self.parameters())


class GNN(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        num_edge_type: int,
        t,
    ):
        super(GNN, self).__init__()
        self.embed_dim = embed_dim
        self.t = t
        self.hw = HwNonLinear(
            num_atom_type, num_ind1_type, num_inda_type, num_edge_type, embed_dim
        )
        self.output_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, 1),
        )
        self.criterion = nn.BCEWithLogitsLoss()
        self.l2_reg = nn.MSELoss()

    def contraction_penalty(self, params, threshold=0.9):
        """计算收缩映射的惩罚项"""
        penalty = 0
        for param in params:
            # 计算参数的范数
            norm = torch.norm(param, p=2)
            # 如果范数大于阈值，则添加惩罚项
            penalty += torch.pow(torch.relu(norm - threshold), 2)
        return penalty

    def forward(self, l_n, adj_matrix, label):
        x = torch.zeros(adj_matrix.shape[0], self.embed_dim, requires_grad=False)
        constraint_loss = 0
        for i in range(self.t):
            x2 = self.hw(x, l_n, adj_matrix)
            constraint_loss += torch.norm(x2 - x, p=1)
            x = x2
        constraint_loss = 0
        # x = x.sum(dim=0)
        index = adj_matrix.sum(dim=1).argmax()
        logits = self.output_layer(x[index])
        label[label == 0] = -1
        # loss = self.criterion(logits, label.float())
        hw_params = self.hw.get_parameters()
        penalty = self.contraction_penalty(hw_params, threshold=1)
        loss = (
            F.mse_loss(logits, label.float())
            + penalty * 0.1
            + constraint_loss / self.t
            + 0
        )
        return logits, loss
        # return logits, loss

In [100]:
class MyDataset(Dataset):
    def __init__(self, data, device):
        super(MyDataset, self).__init__()
        self.data = data
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        nodes = sample["nodes"]
        adj_matrix = sample["adj_matrix"]
        label = sample["label"]
        atom_type = [t["type"] for t in sample["node_features"]]
        charge = [t["charge"] for t in sample["node_features"]]
        ind1 = [sample["global_features"]["ind1"]] * len(nodes)
        inda = [sample["global_features"]["inda"]] * len(nodes)
        logp = [sample["global_features"]["logp"]] * len(nodes)
        lumo = [sample["global_features"]["lumo"]] * len(nodes)
        l_n = {
            "atom_type": torch.LongTensor(atom_type).to(self.device),
            "charge": torch.tensor(charge).to(self.device),
            "ind1": torch.LongTensor(ind1).to(self.device),
            "inda": torch.LongTensor(inda).to(self.device),
            "logp": torch.tensor(logp).to(self.device),
            "lumo": torch.tensor(lumo).to(self.device),
        }
        adj_matrix = torch.LongTensor(adj_matrix).to(self.device)
        label = torch.tensor([label]).to(self.device)
        return {"l_n": l_n, "adj_matrix": adj_matrix, "label": label}

    def collate_fn(batch):
        l_n = [f["l_n"] for f in batch]
        adj_matrix = [f["adj_matrix"] for f in batch]
        label = [f["label"] for f in batch]
        return l_n, adj_matrix, label

In [101]:
train_data, test_data = train_test_split(data, test_size=0.1, random_state=100)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [102]:
train_dataset = MyDataset(train_data, device)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=MyDataset.collate_fn,
)
meta_data = {
    "num_atom_type": 233,
    "num_ind1_type": 2,
    "num_inda_type": 2,
    "num_edge_type": 8,
}

In [103]:
model = GNN(embed_dim=100, t=10, **meta_data)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

In [104]:
def train(model, optimizer, train_dataset, epochs):

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=MyDataset.collate_fn,
    )
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=epochs, gamma=0.7)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_dataloader:
            inputs = {
                "l_n": batch[0][0],
                "adj_matrix": batch[1][0],
                "label": batch[2][0],
            }
            optimizer.zero_grad()
            logits, loss = model(**inputs)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # scheduler.step()
        # print(f"Epoch:{epoch+1}, Loss:{total_loss}")

In [105]:
test_dataset = MyDataset(test_data, device)

In [106]:
from torchmetrics import Accuracy

# acc_cal = Accuracy(task="binary", threshold=0)

In [107]:
def eval(model, test_dataset):
    model.eval()
    test_loader = DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=MyDataset.collate_fn
    )
    labels = []
    preds = []
    for batch in test_loader:
        inputs = {
            "l_n": batch[0][0],
            "adj_matrix": batch[1][0],
            "label": batch[2][0],
        }
        with torch.no_grad():
            logits, _ = model(**inputs)
        labels.append(inputs["label"])
        preds.append(logits)
    labels = torch.stack(labels)
    preds = torch.stack(preds)
    labels = labels.view(-1)
    preds = preds.view(-1)
    # acc = acc_cal(preds, labels)
    labels[labels == -1] = 0
    preds = preds > 0
    acc = (preds == labels).sum() / len(labels)
    return acc, preds, labels

In [108]:
def eval(model, test_dataset):
    model.eval()
    test_loader = DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=MyDataset.collate_fn
    )
    labels = []
    preds = []
    for batch in test_loader:
        inputs = {
            "l_n": batch[0][0],
            "adj_matrix": batch[1][0],
            "label": batch[2][0],
        }
        with torch.no_grad():
            logits, _ = model(**inputs)
        labels.append(inputs["label"])
        preds.append(torch.sigmoid(logits))
    labels = torch.stack(labels)
    preds = torch.stack(preds)
    labels = labels.view(-1)
    preds = preds.view(-1)
    # acc = acc_cal(preds, labels)
    # labels[labels==-1] = 0
    preds = preds > 0.5
    acc = (preds == labels).sum() / len(labels)
    return acc, preds, labels

In [109]:
train(model, optimizer, train_dataset, epochs=10)
acc, preds, labels = eval(model, test_dataset)

In [91]:
acc

tensor(0.7500)

In [92]:
preds, labels

(tensor([False, False, False, False, False]), tensor([-1,  1, -1,  1, -1]))

In [115]:
class Phi(nn.Module):
    def __init__(
        self,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        embed_dim: int,
    ):
        super(Phi, self).__init__()
        self.embed_dim = embed_dim
        self.atom_embedding = nn.Embedding(num_atom_type, embed_dim)
        self.ind1_embedding = nn.Embedding(num_ind1_type, embed_dim)
        self.inda_embedding = nn.Embedding(num_inda_type, embed_dim)
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, l_n, l_n_u, l_u):
        """transition network，用于生成转换矩阵 ${A}_{n,u}$
            转换矩阵的作用是
        Args:
            l_n (_type_): 节点的特征 包括：
                1. atom_type class
                2. charge float
                3. ind1 class
                4. inda class
                5. logp float
                6. lumo float
            l_n_u (): 边的特征
            l_u (_type_): 邻居节点的特征
        """
        pass


class HwNonLinear(nn.Module):
    def __init__(
        self,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        num_edge_type: int,
        embed_dim: int,
    ):
        """非线性的GNN的H_w函数实现，在利用节点、边、邻居节点的特征构造相关向量，
        然后拼接/加和节点状态向量，使用三层的FNN网络计算新的节点状态。
        Args:
            num_atom_type (int): 原子的种类数量
            num_ind1_type (int): ind1的种类数量
            num_inda_type (int): inda的种类数量
            embed_dim (int): 嵌入的维度
        """
        super(HwNonLinear, self).__init__()
        self.embed_dim = embed_dim
        self.atom_embedding = nn.Embedding(num_atom_type, embed_dim)
        self.ind1_embedding = nn.Embedding(num_ind1_type, embed_dim)
        self.inda_embedding = nn.Embedding(num_inda_type, embed_dim)
        self.edge_embedding = nn.Embedding(num_edge_type + 1, embed_dim)
        self.logp_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.lumo_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.charge_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.trans = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
        )
        self._init()

    def _init(self):
        """初始化参数"""
        nn.init.xavier_uniform_(self.atom_embedding.weight)
        nn.init.xavier_uniform_(self.ind1_embedding.weight)
        nn.init.xavier_uniform_(self.inda_embedding.weight)
        nn.init.xavier_uniform_(self.edge_embedding.weight)

    def forward(self, x, l_n, adj_matrix):
        """前向传播，根据节点原来的状态向量、节点的特征l_n、邻居节点的特征l_u、边的特征l_{n,u}
        计算新的节点状态向量。为了实现高效的计算，使用批量计算的方式，并且使用平均的方式进行融合。
        传入的l_n是一张图的所有节点的特征，邻居节点的特征也可以从中获取，邻居可以有邻接矩阵得到。
        1. 首先对边进行嵌入，直接使用邻接矩阵进行嵌入，将0也看作一个类别，得到一个 n x n x e的张量 \n
        2. 然后对节点进行嵌入，对每个都进行嵌入，然后加和得到一个 n x e的张量 \n
        3. 将节点特征和边特征进行加和，得到一个 n x n x e 的张量 \n
        4. 根据邻接矩阵得到mask矩阵，根据mask矩阵在第二个维度上求和，得到 n x e的张量 \n
        5. 将上述得到的向量和x拼接，得到一个 n x (2e)的张量 \n
        6. 经过一个三层的FNN网络，得到一个 n x e的张量，即新的节点状态向量 \n
        Args:
            l_n (_type_): 节点的特征 包括：
                dict:
                1. atom_type class [n]
                2. charge float [n]
                3. ind1 class [n]
                4. inda class [n]
                5. logp float [n]
                6. lumo float [n]
            adj_matrix (): 邻接矩阵

        Returns:
            _type_: 新的节点状态向量
        """
        atom_embed = self.atom_embedding(l_n["atom_type"])  # [n,e]
        ind1_embed = self.ind1_embedding(l_n["ind1"])
        inda_embed = self.inda_embedding(l_n["inda"])
        logp_embed = self.logp_embedding * l_n["logp"].unsqueeze(dim=-1)
        lumo_embed = self.lumo_embedding * l_n["lumo"].unsqueeze(dim=-1)  # [n,e]
        charge_embed = self.charge_embedding * l_n["charge"].unsqueeze(dim=-1)
        l_n = (
            atom_embed
            + ind1_embed
            + inda_embed
            + logp_embed
            + lumo_embed
            + charge_embed
        )
        l_n_u = self.edge_embedding(adj_matrix)  # [n,n,e]
        x_u = x
        mask = adj_matrix.unsqueeze(dim=-1) != 0  # [n,n,1]
        mask = mask.float()
        x_list = []
        for i in range(x.shape[0]):
            l_n_i = l_n[i].unsqueeze(dim=0).repeat(x.shape[0], 1)  # [n,e]
            s = l_n_i + l_n_u[i] + x_u + l_n  # [n,e]
            s = self.trans(s)
            s = s * mask[i]  # [n,e]
            s = torch.sum(s, dim=0)
            x_list.append(s)
        x = torch.stack(x_list, dim=0)
        return x

    def get_parameters(self):
        """获取所有待学习参数"""
        return list(self.parameters())


class GNN(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        num_edge_type: int,
        t,
    ):
        super(GNN, self).__init__()
        self.embed_dim = embed_dim
        self.t = t
        self.hw = HwNonLinear(
            num_atom_type, num_ind1_type, num_inda_type, num_edge_type, embed_dim
        )
        self.output_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, 1),
        )
        self.criterion = nn.BCEWithLogitsLoss()
        self.l2_reg = nn.MSELoss()

    def contraction_penalty(self, params, threshold=0.9):
        """计算收缩映射的惩罚项"""
        penalty = 0
        for param in params:
            # 计算参数的范数
            norm = torch.norm(param, p=2)
            # 如果范数大于阈值，则添加惩罚项
            penalty += torch.pow(torch.relu(norm - threshold), 2)
        return penalty

    def forward(self, l_n, adj_matrix, label):
        x = torch.randn(adj_matrix.shape[0], self.embed_dim, requires_grad=False)
        constraint_loss = 0
        # x2 = self.hw(x, l_n, adj_matrix)
        # while torch.norm(x2 - x, p=2) > 1e-5:
        #     x = x2
        for i in range(self.t):
            x2 = self.hw(x, l_n, adj_matrix)
            constraint_loss += torch.norm(x2 - x, p=1)
            x = x2
        constraint_loss = 0
        x = x2.sum(dim=0)
        logits = self.output_layer(x)
        loss = self.criterion(logits, label.float())
        hw_params = self.hw.get_parameters()
        penalty = self.contraction_penalty(hw_params, threshold=1)
        loss = self.criterion(logits, label.float())
        return logits, loss

In [116]:
from sklearn.model_selection import KFold

kf = KFold(n_splits=10, shuffle=True, random_state=100)
acc_scores = []
for train_index, test_index in tqdm(kf.split(data)):
    model = GNN(embed_dim=100, t=10, **meta_data)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.001,
    )
    data_train = [data[i] for i in train_index]
    data_test = [data[i] for i in test_index]
    train_dataset = MyDataset(data_train, device)
    train(model, optimizer, train_dataset, epochs=10)
    test_dataset = MyDataset(data_test, device)
    acc, _, _ = eval(model, test_dataset)
    print(acc)
    acc_scores.append(acc.item())

1it [00:37, 37.96s/it]

tensor(0.8000)


2it [01:12, 36.24s/it]

tensor(0.6000)


3it [01:54, 38.42s/it]

tensor(1.)


4it [02:36, 39.92s/it]

tensor(0.7500)


5it [03:10, 37.97s/it]

tensor(0.7500)


6it [03:40, 35.34s/it]

tensor(0.7500)


7it [04:10, 33.39s/it]

tensor(1.)


8it [04:38, 31.66s/it]

tensor(0.7500)


9it [05:06, 30.61s/it]

tensor(1.)


10it [05:33, 33.30s/it]

tensor(0.7500)





In [114]:
np.array(acc_scores).mean()

0.7000000029802322