In [1]:
# Train the model with mini-batch (batched) training.

In [1]:
import pandas as pd
import torch
import numpy as np
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch.nn as nn
from torch.nn import functional as F
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data, Batch

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Load the three semicolon-delimited Mutagenesis tables: atoms, bonds, molecules.
atom = pd.read_csv("dataset/Mutagenesis-42/atoms.csv", delimiter=";")
bond = pd.read_csv("dataset/Mutagenesis-42/bonds.csv", delimiter=";")
molecule = pd.read_csv("dataset/Mutagenesis-42/drugs.csv", delimiter=";")

# Give every distinct element symbol an integer index and store it on the atom
# table as ``element_type`` (this adds a column to ``atom`` in place).
elements = atom["element"].unique()
element2id = {symbol: index for index, symbol in enumerate(elements)}
atom["element_type"] = atom["element"].map(element2id)

# Preview the atom table, now including the new ``element_type`` column.
atom.head()

Unnamed: 0,id,drug_id,element,atom_type,charge,element_type
0,0,0,c,22,-0.11,0
1,1,0,c,22,-0.11,0
2,2,0,c,22,-0.11,0
3,3,0,c,22,-0.11,0
4,4,0,h,3,0.15,1


In [7]:
# Number of distinct atom_type codes, shown as the shape of the unique-values array.
atom["atom_type"].unique().shape

(26,)

In [4]:
# Display the element symbol -> integer index mapping.
element2id

{'c': 0, 'h': 1, 'n': 2, 'o': 3, 'cl': 4, 'f': 5, 's': 6}

In [6]:
# Preview the first rows of the molecule table.
molecule.head()

Unnamed: 0,id,ind1,inda,act,logp,lumo,active
0,0,0.0,0.0,-0.7,2.29,-3.025,0
1,1,0.0,0.0,0.57,2.13,-0.798,1
2,2,1.0,0.0,0.77,4.35,-2.138,1
3,3,1.0,0.0,-0.22,5.41,-1.429,0
4,4,1.0,0.0,-0.22,5.41,-1.478,0


In [7]:
# Preview the first two rows of the bond table.
bond.head(2)

Unnamed: 0,id,drug_id,atom1_id,atom2_id,bond_type
0,0,0,0,11,7
1,1,0,11,19,7


In [8]:
# Spot check: for bonds whose atom1_id is 11, list the atom2_id at the other end.
bond["atom2_id"].loc[bond["atom1_id"] == 11]

1    19
6    23
Name: atom2_id, dtype: int64

In [9]:
# Use the ids that actually occur in the molecule table instead of assuming
# they are exactly 0..N-1; get_molecule_data() filters on molecule["id"] and
# would fail on any id that does not exist. Sorted for a deterministic order.
molecule_ids = sorted(molecule["id"].unique().tolist())

In [10]:
# Map each pair built from row positions 2 and 3 to the value at row position 4.
# NOTE(review): presumably these positions are atom1_id, atom2_id and bond_type —
# confirm against the column order of ``bond``.
edges_dict = {(row[2], row[3]): row[4] for row in bond.values.tolist()}

In [11]:
def get_molecule_data():
    """Convert every molecule into a ``torch_geometric.data.Data`` graph.

    Reads the module-level ``atom``, ``bond`` and ``molecule`` DataFrames and
    builds one graph per entry of ``molecule_ids``:

    * ``x``          -- ``[num_atoms, 12]`` float tensor: a 7-way one-hot of the
      atom's ``element_type``, followed by the atom ``charge`` and the
      molecule-level ``ind1``, ``inda``, ``logp`` and ``lumo`` values (the
      molecule-level values are repeated on every atom).
    * ``edge_index`` -- ``[2, 2 * num_bonds]`` long tensor; every bond is stored
      in both directions (all forward edges first, then all reversed edges).
    * ``edge_attr``  -- ``[2 * num_bonds, 7]`` float one-hot of ``bond_type - 1``.
    * ``y``          -- ``[1]`` long tensor holding the molecule's ``active`` flag.

    Returns:
        list: One ``Data`` object per molecule, in ``molecule_ids`` order.

    Raises:
        IndexError: If a molecule id has no row in ``molecule``.
        KeyError: If a bond refers to an atom id that is not part of the same
            molecule.
    """
    num_element_types = 7  # width of the element one-hot block of ``x``
    num_bond_types = 7  # bond_type codes are shifted by -1 into 0..6

    data = []
    for sample_id in tqdm(molecule_ids):
        # Slice each table once per molecule instead of re-scanning the whole
        # DataFrame for every single atom and every single attribute.
        mol_atoms = atom.loc[atom["drug_id"] == sample_id].sort_values("id")
        mol_bonds = bond.loc[bond["drug_id"] == sample_id]
        mol_rows = molecule.loc[molecule["id"] == sample_id]

        # Map global atom ids to 0-based positions inside this molecule. An
        # explicit lookup stays correct even if the ids are not contiguous,
        # unlike subtracting the smallest id.
        node_ids = mol_atoms["id"].tolist()
        local_index = {node_id: pos for pos, node_id in enumerate(node_ids)}

        # Edge attributes: one-hot bond type, duplicated for the reverse edges.
        bond_types = (
            torch.tensor(mol_bonds["bond_type"].tolist(), dtype=torch.long) - 1
        )
        bond_types = torch.cat([bond_types, bond_types])
        edge_attr = F.one_hot(bond_types, num_classes=num_bond_types).float()

        # Edge index: forward direction followed by the reversed direction, in
        # the same order as ``edge_attr``.
        source_nodes = [local_index[a] for a in mol_bonds["atom1_id"].tolist()]
        target_nodes = [local_index[a] for a in mol_bonds["atom2_id"].tolist()]
        edge_index = torch.tensor(
            [source_nodes + target_nodes, target_nodes + source_nodes],
            dtype=torch.long,
        )

        # Node features: element one-hot | charge | molecule-level descriptors.
        element_one_hot = F.one_hot(
            torch.tensor(mol_atoms["element_type"].tolist(), dtype=torch.long),
            num_classes=num_element_types,
        ).float()
        charge = torch.tensor(
            mol_atoms["charge"].tolist(), dtype=torch.float
        ).unsqueeze(1)
        mol_descriptors = torch.tensor(
            [float(mol_rows[col].iloc[0]) for col in ("ind1", "inda", "logp", "lumo")],
            dtype=torch.float,
        ).repeat(len(node_ids), 1)
        x = torch.cat([element_one_hot, charge, mol_descriptors], dim=1)

        sample = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.tensor([int(mol_rows["active"].iloc[0])], dtype=torch.long),
        )
        data.append(sample)
    return data

In [12]:
# Build the list of per-molecule graph objects used by all later cells.
dataset = get_molecule_data()

100%|██████████| 42/42 [00:00<00:00, 222.82it/s]


In [13]:
# Inspect the first graph to check the tensor shapes.
dataset[0]

Data(x=[26, 12], edge_index=[2, 54], edge_attr=[54, 7], y=[1])

In [14]:
class MutagDataset(Dataset):
    """Thin ``Dataset`` wrapper around a sequence of pre-built graph samples.

    Args:
        dataset: Indexable, sized collection of samples; stored as ``self.data``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.data = dataset

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

    @staticmethod
    def collate_fn(batch):
        """Merge a list of graph samples into one ``Batch``.

        Declared as a static method because it takes no ``self``: it now works
        both as ``MutagDataset.collate_fn`` and when looked up on an instance.

        Args:
            batch: Iterable of samples, as handed over by a ``DataLoader``.

        Returns:
            The result of ``Batch.from_data_list`` on those samples.
        """
        return Batch.from_data_list(list(batch))

In [20]:
def train(
    model, optimizer, batch_size, train_dataset, test_dataset, epochs, verbose=True
):
    """Train ``model`` and report the best accuracy seen on ``test_dataset``.

    After every epoch the model is evaluated with ``eval`` on ``test_dataset``.
    The returned value is the maximum of those per-epoch accuracies, so it is
    selected on the test data and is an optimistic estimate.

    Args:
        model: Module called as ``model(x=..., edge_index=..., edge_attr=...,
            label=..., ptr=...)`` and returning ``(logits, loss)``.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size: Number of graphs per training batch.
        train_dataset: Dataset of training graphs.
        test_dataset: Dataset evaluated after every epoch.
        epochs: Number of passes over ``train_dataset``.
        verbose: When true, print loss and accuracy after every epoch.

    Returns:
        The best per-epoch accuracy as returned by ``eval``, or ``-1`` when
        ``epochs`` is 0.
    """
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=MutagDataset.collate_fn,
    )
    # The decay horizon is expressed in optimizer steps, so the scheduler has
    # to be stepped once per batch (see the inner loop).
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.01,
        total_iters=epochs * len(train_dataloader),
    )
    max_acc = -1
    for epoch in range(epochs):
        # The end-of-epoch evaluation may leave the model in eval mode, so
        # switch back to training mode at the start of every epoch.
        model.train()
        total_loss = 0
        for batch in train_dataloader:
            inputs = {
                "x": batch.x.to(device),
                "edge_index": batch.edge_index.to(device),
                "edge_attr": batch.edge_attr.to(device),
                "label": batch.y.to(device),
                "ptr": batch.ptr.to(device),
            }
            optimizer.zero_grad()
            logits, loss = model(**inputs)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # One scheduler step per optimizer step, matching total_iters.
            scheduler.step()
            total_loss += loss.item()
        acc = eval(model, test_dataset)[0]
        if verbose:
            print(f"Epoch:{epoch+1}, Loss:{total_loss}, Accuracy:{acc}")
        if acc > max_acc:
            max_acc = acc
    return max_acc


def eval(model, test_dataset):
    """Compute classification accuracy of ``model`` on ``test_dataset``.

    NOTE: this function shadows Python's built-in ``eval``; the name is kept
    because other cells call it.

    The model is put into eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this from a
    training loop no longer leaves the model stuck in eval mode.

    Args:
        model: Module called as ``model(x=..., edge_index=..., edge_attr=...,
            label=..., ptr=...)`` and returning ``(logits, loss)``.
        test_dataset: Dataset of graphs to evaluate; processed one at a time.

    Returns:
        tuple: ``(acc, preds, labels)`` where ``acc`` is a 0-dim tensor with the
        fraction of correct predictions, ``preds`` is a flat boolean tensor
        (``sigmoid(logit) > 0.5``) and ``labels`` is the flat label tensor.
    """
    was_training = model.training
    model.eval()
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=MutagDataset.collate_fn,
    )
    labels = []
    preds = []
    try:
        with torch.no_grad():
            for batch in test_loader:
                inputs = {
                    "x": batch.x.to(device),
                    "edge_index": batch.edge_index.to(device),
                    "edge_attr": batch.edge_attr.to(device),
                    "label": batch.y.to(device),
                    "ptr": batch.ptr.to(device),
                }
                logits, _ = model(**inputs)
                # Collected after the forward pass on purpose, so any label
                # normalisation the model applies in place is reflected here.
                labels.append(inputs["label"].view(-1))
                preds.append(torch.sigmoid(logits).view(-1))
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)
    # Flatten-and-concatenate works for any batch size, not only equal shapes.
    labels = torch.cat(labels)
    preds = torch.cat(preds)
    preds = preds > 0.5
    acc = (preds == labels).sum() / len(labels)
    return acc, preds, labels

In [16]:
# Same inspection as an earlier cell: show the first graph's tensor shapes.
dataset[0]

Data(x=[26, 12], edge_index=[2, 54], edge_attr=[54, 7], y=[1])

In [26]:
class NodeEncoder(nn.Module):
    """Linear projection of raw node feature vectors into the embedding space.

    Args:
        input_dim: Size of each raw node feature vector.
        embed_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim: int = 12, embed_dim: int = 50):
        super().__init__()
        projection = nn.Linear(input_dim, embed_dim)
        # Re-initialise the weight matrix with Xavier/Glorot uniform values.
        nn.init.xavier_uniform_(projection.weight)
        self.node_embedding = projection

    def forward(self, x):
        """Project ``x`` (last dimension ``input_dim``) to ``embed_dim``."""
        return self.node_embedding(x)

    def get_parameters(self):
        """Return all learnable parameters of this module as a list."""
        return [*self.parameters()]


class EdgeEncoder(nn.Module):
    """Linear projection of raw edge attribute vectors into the embedding space.

    Args:
        input_dim: Size of each raw edge attribute vector.
        embed_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim: int = 7, embed_dim: int = 50):
        super().__init__()
        projection = nn.Linear(input_dim, embed_dim)
        # Re-initialise the weight matrix with Xavier/Glorot uniform values.
        nn.init.xavier_uniform_(projection.weight)
        self.edge_embedding = projection

    def forward(self, edge_attr):
        """Project ``edge_attr`` (last dimension ``input_dim``) to ``embed_dim``."""
        return self.edge_embedding(edge_attr)

    def get_parameters(self):
        """Return all learnable parameters of this module as a list."""
        return [*self.parameters()]


class HwNonLinear(nn.Module):
    """Non-linear local transition function h_w of the GNN.

    The four per-edge inputs are averaged element-wise and passed through a
    two-layer feed-forward network (``mlp2``).

    ``mlp`` (the variant that would consume the concatenation of the four
    inputs) is still constructed, so its weights are part of
    ``get_parameters()``, but ``forward`` does not use it.

    Args:
        embed_dim: Size of every input vector and of the output vector.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        # Concatenation variant: 4 * embed_dim -> embed_dim (unused in forward).
        self.mlp = nn.Sequential(
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )
        # Averaging variant: embed_dim -> embed_dim (the one forward uses).
        self.mlp2 = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(
        self,
        node_features,
        edge_features,
        neighbor_state,
        neighbor_features,
    ):
        """Compute one message from node, edge and neighbour information.

        Args:
            node_features (torch.Tensor): l_n, features of the receiving node.
            edge_features (torch.Tensor): l_nu, features of the connecting edge.
            neighbor_state (torch.Tensor): x_u, current state of the neighbour.
            neighbor_features (torch.Tensor): l_u, features of the neighbour.

        All four tensors are added element-wise, so they must share a shape
        whose last dimension is ``embed_dim``.

        Returns:
            torch.Tensor: ``mlp2`` applied to the element-wise mean of the inputs.
        """
        mean_input = (
            node_features + edge_features + neighbor_state + neighbor_features
        ) / 4
        return self.mlp2(mean_input)

    def get_parameters(self):
        """Return all learnable parameters (both ``mlp`` and ``mlp2``) as a list."""
        return [*self.parameters()]


class Aggr(nn.Module):
    """Sum-aggregation of per-edge messages onto nodes.

    Args:
        embed_dim: Stored as ``self.embed_dim``; not used by ``forward``.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x, aggregate_map):
        """Aggregate messages into per-node vectors.

        Args:
            x (torch.Tensor): Messages to pass on, shape [m, e].
            aggregate_map (torch.Tensor): Aggregation matrix, shape [m, n].

        Returns:
            torch.Tensor: Shape [n, e]; row ``j`` is the sum over ``i`` of
            ``aggregate_map[i, j] * x[i]``.
        """
        return torch.einsum("me,mn->ne", x, aggregate_map)


class GNN(nn.Module):
    """Iterative state-propagation graph network for graph-level binary labels.

    Node states start at zero and are updated up to ``t`` times by computing a
    message per edge (``HwNonLinear``) and summing the messages per node
    (``Aggr``). Iteration stops early once every node state moves by less than
    ``self.thresh``. The state of the first node of every graph in the batch
    is fed to an output MLP that produces one logit per graph.

    Args:
        embed_dim: Size of node/edge embeddings and of node states.
        node_input_dim: Size of the raw node feature vectors.
        edge_input_dim: Size of the raw edge attribute vectors.
        t: Maximum number of state-update iterations.
    """

    def __init__(
        self,
        embed_dim,
        node_input_dim=12,
        edge_input_dim=7,
        t=10,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.t = t
        self.node_encoder = NodeEncoder(node_input_dim, embed_dim)
        self.edge_encoder = EdgeEncoder(edge_input_dim, embed_dim)
        self.hw = HwNonLinear(embed_dim)
        self.output_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, 1),
        )
        self.aggr = Aggr(embed_dim)
        self.criterion = nn.BCEWithLogitsLoss()
        self.l2_reg = nn.MSELoss()  # not used by forward
        self.thresh = 1e-5  # convergence tolerance on the per-node state change
        nn.init.xavier_uniform_(self.output_layer[0].weight)

    def contraction_penalty(self, params, threshold=0.9):
        """Penalty that grows when a parameter tensor's L2 norm exceeds ``threshold``.

        Args:
            params: Iterable of tensors.
            threshold: Norm above which a tensor is penalised.

        Returns:
            Sum over ``params`` of ``relu(norm - threshold) ** 2`` (the int
            ``0`` when ``params`` is empty).
        """
        penalty = 0
        for param in params:
            norm = torch.norm(param, p=2)
            # Only the part of the norm above the threshold is penalised.
            penalty += torch.pow(torch.relu(norm - threshold), 2)
        return penalty

    def get_aggregate_map(self, edge_index, num_nodes):
        """Build the edge-to-node aggregation matrix.

        Row ``i`` has a single 1 in column ``edge_index[0][i]``, so the message
        of edge ``i`` is accumulated on the node in the first row of
        ``edge_index``.

        The matrix is created on ``edge_index``'s device; indexing a CPU tensor
        with index tensors that live on the GPU is not supported.

        Args:
            edge_index (torch.Tensor): Shape [2, m].
            num_nodes (int): Number of nodes n.

        Returns:
            torch.Tensor: Shape [m, n], float, on ``edge_index.device``.
        """
        num_edges = edge_index.shape[1]
        aggregate_map = torch.zeros(num_edges, num_nodes, device=edge_index.device)
        edge_ids = torch.arange(num_edges, device=edge_index.device)
        aggregate_map[edge_ids, edge_index[0]] = 1
        return aggregate_map

    def forward(self, x, edge_index, edge_attr, label, ptr):
        """Run state propagation on a batch of graphs and compute the loss.

        Args:
            x (torch.Tensor): Raw node features, shape [n, node_input_dim].
            edge_index (torch.Tensor): Edge list, shape [2, m]; messages flow
                from ``edge_index[1]`` to ``edge_index[0]``.
            edge_attr (torch.Tensor): Raw edge attributes, shape
                [m, edge_input_dim].
            label (torch.Tensor): One label per graph, shape [b]. NOTE: entries
                equal to -1 are overwritten with 0 **in place**; callers that
                read ``label`` afterwards see the rewritten values.
            ptr (torch.Tensor): Shape [b + 1]; ``ptr[:-1]`` holds the index of
                the first node of every graph in the batch.

        Returns:
            tuple: ``(logits, loss)`` -- logits of shape [b, 1] and a scalar
            equal to the BCE-with-logits loss plus 0.01 times the contraction
            penalty over the ``hw``, node-encoder and edge-encoder parameters.
        """
        num_nodes = x.shape[0]
        # Initial node states: zeros with the dtype/device of the input.
        node_states = x.new_zeros(num_nodes, self.embed_dim)  # [n,e]
        aggregate_map = self.get_aggregate_map(edge_index, num_nodes).to(x)
        node_features = self.node_encoder(x)  # [n,e]
        edge_features = self.edge_encoder(edge_attr)  # [m,e]

        # Per-edge views: receiving node, neighbour, edge, neighbour state.
        l_n = torch.index_select(node_features, 0, edge_index[0])  # [m,e]
        l_u = torch.index_select(node_features, 0, edge_index[1])  # [m,e]
        l_nu = edge_features
        x_u = torch.index_select(node_states, 0, edge_index[1])  # [m,e]

        for _ in range(self.t):
            messages = self.hw(l_n, l_nu, x_u, l_u)  # [m,e]
            new_state = self.aggr(messages, aggregate_map)  # [n,e]
            # The convergence test must not take part in back-propagation.
            with torch.no_grad():
                distance = torch.norm(new_state - node_states, p=2, dim=-1)
                converged = distance < self.thresh
            if converged.all():
                # Keep the previous states; they differ by less than thresh.
                break
            node_states = new_state
            x_u = torch.index_select(node_states, 0, edge_index[1])

        # Read-out: the state of the first node of each graph.
        logits = self.output_layer(
            torch.index_select(node_states, dim=0, index=ptr[:-1])
        )  # [b,1]

        penalty = (
            self.contraction_penalty(self.hw.get_parameters(), threshold=1)
            + self.contraction_penalty(self.node_encoder.get_parameters(), threshold=1)
            + self.contraction_penalty(self.edge_encoder.get_parameters(), threshold=1)
        )
        # Map {-1, 1} style labels onto {0, 1}; done in place (see docstring).
        label[label == -1] = 0
        loss = self.criterion(logits.view(-1), label.float())
        return logits, loss + 0.01 * penalty

In [29]:
# Single hold-out run: 90% of the graphs for training, 10% for testing.
batch_size = 1
train_data, test_data = train_test_split(dataset, test_size=0.1, random_state=3)
train_dataset = MutagDataset(train_data)
test_dataset = MutagDataset(test_data)

# Stand-alone loaders for manual inspection; train() builds its own loader from
# ``batch_size``. They need the graph-aware collate function, because the
# default collate cannot merge graph samples into a batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=MutagDataset.collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=MutagDataset.collate_fn,
)

model = GNN(embed_dim=50, t=1000, edge_input_dim=7, node_input_dim=12)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
acc = train(model, optimizer, batch_size, train_dataset, test_dataset, epochs=100)

Epoch:1, Loss:39.95088902115822, Accuracy:0.800000011920929
Epoch:2, Loss:36.88047909736633, Accuracy:0.800000011920929
Epoch:3, Loss:32.263562858104706, Accuracy:1.0
Epoch:4, Loss:26.761616557836533, Accuracy:1.0
Epoch:5, Loss:25.750604957342148, Accuracy:1.0
Epoch:6, Loss:25.568319715559483, Accuracy:1.0
Epoch:7, Loss:24.370466731488705, Accuracy:1.0
Epoch:8, Loss:25.288087122142315, Accuracy:1.0
Epoch:9, Loss:23.510294195264578, Accuracy:1.0


KeyboardInterrupt: 

In [40]:
from sklearn.model_selection import KFold

# 10-fold cross-validation: train a fresh model on every fold and record the
# best accuracy train() observed on that fold's held-out graphs.
batch_size = 1
kf = KFold(n_splits=10, shuffle=True, random_state=42)
all_acc = []
for train_index, test_index in kf.split(dataset):
    train_data = [dataset[i] for i in train_index]
    test_data = [dataset[i] for i in test_index]
    train_dataset = MutagDataset(train_data)
    test_dataset = MutagDataset(test_data)
    model = GNN(embed_dim=50, t=1000, edge_input_dim=7, node_input_dim=12)
    # train() moves every batch to ``device``, so the model must live there too.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    acc = train(
        model,
        optimizer,
        batch_size,
        train_dataset,
        test_dataset,
        epochs=100,
        verbose=False,
    )
    all_acc.append(acc)
    print(acc)

tensor(0.8000)
tensor(1.)
tensor(0.7500)
tensor(1.)
tensor(1.)
tensor(0.7500)
tensor(0.7500)


KeyboardInterrupt: 