In [52]:
# Use the raw label values directly as feature vectors instead of embedding them.
# NOTE(review): NodeEncoder further down still embeds atom_type/ind1/inda with
# nn.Embedding, so this note may describe a different variant — confirm.

In [53]:
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [54]:
# All Mutagenesis-42 tables live in one directory and use ';' as the separator.
DATA_DIR = "dataset/Mutagenesis-42"
CSV_SEP = ";"

atom = pd.read_csv(f"{DATA_DIR}/atoms.csv", sep=CSV_SEP)  # one row per atom
bond = pd.read_csv(f"{DATA_DIR}/bonds.csv", sep=CSV_SEP)  # one row per bond
molecule = pd.read_csv(f"{DATA_DIR}/drugs.csv", sep=CSV_SEP)  # one row per molecule
atom.head()

Unnamed: 0,id,drug_id,element,atom_type,charge
0,0,0,c,22,-0.11
1,1,0,c,22,-0.11
2,2,0,c,22,-0.11
3,3,0,c,22,-0.11
4,4,0,h,3,0.15


In [55]:
# Preview the first five molecules (global features and activity label).
molecule.head(n=5)

Unnamed: 0,id,ind1,inda,act,logp,lumo,active
0,0,0.0,0.0,-0.7,2.29,-3.025,0
1,1,0.0,0.0,0.57,2.13,-0.798,1
2,2,1.0,0.0,0.77,4.35,-2.138,1
3,3,1.0,0.0,-0.22,5.41,-1.429,0
4,4,1.0,0.0,-0.22,5.41,-1.478,0


In [56]:
# First 40 bond rows (equivalent to bond.head(40)).
bond.iloc[:40]

Unnamed: 0,id,drug_id,atom1_id,atom2_id,bond_type
0,0,0,0,11,7
1,1,0,11,19,7
2,2,0,19,20,7
3,3,0,20,21,7
4,4,0,21,22,7
5,5,0,22,0,7
6,6,0,11,23,1
7,7,0,22,24,1
8,8,0,19,25,7
9,9,0,25,1,7


In [57]:
# Atoms bonded to atom 11 where atom 11 is listed as the first endpoint.
bond.loc[bond["atom1_id"] == 11, "atom2_id"]

1    19
6    23
Name: atom2_id, dtype: int64

In [58]:
# Use the real molecule ids instead of assuming they are 0..n-1 in row order.
molecule_ids = molecule["id"].tolist()

In [59]:
# Map (atom1_id, atom2_id) -> bond_type. Columns are selected by name so the
# result does not depend on the CSV's column order.
# NOTE(review): edges_dict is not used by any later cell in this notebook.
edges_dict = dict(
    zip(
        zip(bond["atom1_id"].tolist(), bond["atom2_id"].tolist()),
        bond["bond_type"].tolist(),
    )
)

In [60]:
def get_molecule_data(
    atoms=None,
    bonds=None,
    molecules=None,
    ids=None,
    show_progress=True,
):
    """Convert the atom/bond/molecule tables into one graph sample per molecule.

    Args:
        atoms: atom table with columns id, drug_id, atom_type, charge.
            Defaults to the notebook-level ``atom`` frame.
        bonds: bond table with columns drug_id, atom1_id, atom2_id, bond_type.
            Defaults to the notebook-level ``bond`` frame.
        molecules: molecule table with columns id, ind1, inda, logp, lumo,
            active. Defaults to the notebook-level ``molecule`` frame.
        ids: molecule ids to convert, in output order. Defaults to the
            notebook-level ``molecule_ids``.
        show_progress: wrap the loop in a tqdm progress bar.

    Returns:
        list[dict]: one dict per molecule with keys
            drug_id, nodes (local indices 0..n-1, atoms sorted by id),
            edge_list ([sources, targets]; every bond appears in both
            directions), edge_attr (bond types aligned with edge_list),
            global_features (ind1, inda, logp, lumo), node_features
            (list of {"type", "charge"} in node order) and label (``active``).

    Raises:
        ValueError: if a bond references an atom that is not part of the same
            molecule.
    """
    atoms = atom if atoms is None else atoms
    bonds = bond if bonds is None else bonds
    molecules = molecule if molecules is None else molecules
    ids = molecule_ids if ids is None else ids

    # Group once up front instead of re-filtering the full tables for every
    # molecule and every atom (per-atom lookups were quadratic).
    atoms_by_drug = {
        drug_id: group
        for drug_id, group in atoms.sort_values("id", kind="stable").groupby("drug_id")
    }
    bonds_by_drug = {drug_id: group for drug_id, group in bonds.groupby("drug_id")}
    empty_atoms = atoms.iloc[0:0]
    empty_bonds = bonds.iloc[0:0]

    data = []
    for sample_id in tqdm(ids) if show_progress else ids:
        mol_atoms = atoms_by_drug.get(sample_id, empty_atoms)
        mol_bonds = bonds_by_drug.get(sample_id, empty_bonds)

        # Map global atom ids to local 0..n-1 indices through an explicit lookup,
        # so non-contiguous atom ids within a molecule are still indexed correctly.
        atom_ids = mol_atoms["id"].tolist()
        local_index = {atom_id: i for i, atom_id in enumerate(atom_ids)}
        try:
            source_nodes = [local_index[a] for a in mol_bonds["atom1_id"].tolist()]
            target_nodes = [local_index[a] for a in mol_bonds["atom2_id"].tolist()]
        except KeyError as exc:
            raise ValueError(
                f"molecule {sample_id}: bond references atom {exc.args[0]} "
                "that is not part of this molecule"
            ) from None

        # Each bond is stored once; add the reverse direction so messages flow both ways.
        edge_list = [source_nodes + target_nodes, target_nodes + source_nodes]
        bond_types = mol_bonds["bond_type"].tolist()
        edge_attr = bond_types + bond_types

        # Read each column separately with .tolist() so values are native Python
        # scalars and the integer label is not upcast to float by a mixed-dtype row.
        mol_rows = molecules.loc[molecules["id"] == sample_id]
        global_feature = {
            column: mol_rows[column].tolist()[0]
            for column in ("ind1", "inda", "logp", "lumo")
        }
        node_features = [
            {"type": atom_type, "charge": charge}
            for atom_type, charge in zip(
                mol_atoms["atom_type"].tolist(), mol_atoms["charge"].tolist()
            )
        ]

        data.append(
            {
                "drug_id": sample_id,
                "nodes": list(range(len(atom_ids))),
                "edge_list": edge_list,
                "edge_attr": edge_attr,
                "global_features": global_feature,
                "node_features": node_features,
                "label": mol_rows["active"].tolist()[0],
            }
        )
    return data

In [61]:
# Build one graph-sample dict per molecule id in molecule_ids.
data = get_molecule_data()

100%|██████████| 42/42 [00:00<00:00, 150.07it/s]


In [62]:
class MyDataset(Dataset):
    """Map-style dataset that turns molecule dicts into model-ready tensors.

    Each element of ``data`` is a dict with the keys produced by
    ``get_molecule_data`` (nodes, edge_list, edge_attr, global_features,
    node_features, label). Every returned tensor is moved to ``device``.
    """

    def __init__(self, data, device):
        super(MyDataset, self).__init__()
        self.data = data  # list of per-molecule sample dicts
        self.device = device  # device every returned tensor is moved to

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one molecule as a dict of tensors.

        Molecule-level features (ind1, inda, logp, lumo) are repeated once per
        node so they line up with the per-atom features.
        """
        sample = self.data[idx]
        num_nodes = len(sample["nodes"])
        node_features = sample["node_features"]
        global_features = sample["global_features"]

        def per_node(key):
            # Broadcast one molecule-level value to every node.
            return [global_features[key]] * num_nodes

        l_n = {
            "atom_type": torch.tensor(
                [f["type"] for f in node_features], dtype=torch.long
            ),
            "charge": torch.tensor([f["charge"] for f in node_features]),
            # ind1/inda are stored as 0.0/1.0 floats but index nn.Embedding
            # tables, so cast them to long.
            "ind1": torch.tensor(per_node("ind1"), dtype=torch.long),
            "inda": torch.tensor(per_node("inda"), dtype=torch.long),
            "logp": torch.tensor(per_node("logp")),
            "lumo": torch.tensor(per_node("lumo")),
        }
        return {
            "num_nodes": num_nodes,
            "node_labels": {key: value.to(self.device) for key, value in l_n.items()},
            "edge_list": torch.tensor(sample["edge_list"], dtype=torch.long).to(
                self.device
            ),
            "edge_attr": torch.tensor(sample["edge_attr"], dtype=torch.long).to(
                self.device
            ),
            "label": torch.tensor([sample["label"]]).to(self.device),
        }

    @staticmethod
    def collate_fn(batch):
        """Collate samples into parallel lists (graphs differ in size, so nothing is stacked).

        A staticmethod, so it works both as ``MyDataset.collate_fn`` and when
        reached through an instance.

        Returns:
            tuple: (num_nodes, node_labels, edge_lists, edge_attrs, labels), each
            a list with one entry per sample in batch order.
        """
        return (
            [sample["num_nodes"] for sample in batch],
            [sample["node_labels"] for sample in batch],
            [sample["edge_list"] for sample in batch],
            [sample["edge_attr"] for sample in batch],
            [sample["label"] for sample in batch],
        )

In [63]:
def train(model, optimizer, train_dataset, test_dataset, epochs, verbose=False):
    """Train ``model`` on single-graph batches and track the best held-out accuracy.

    Args:
        model: module whose ``forward(num_nodes, node_labels, edge_list,
            edge_attr, label)`` returns ``(logits, loss)``.
        optimizer: optimizer over ``model``'s parameters.
        train_dataset: MyDataset iterated with batch size 1, shuffled each epoch.
        test_dataset: MyDataset scored with ``eval`` after every epoch.
        epochs (int): number of passes over ``train_dataset``.
        verbose (bool): print the summed training loss after each epoch.

    Returns:
        The highest accuracy reported by ``eval`` across epochs (a 0-dim
        tensor), or -1 when ``epochs`` is 0.

    NOTE(review): choosing the best epoch by accuracy on ``test_dataset`` uses
    the test fold for model selection, so the returned score is optimistic.
    """
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=MyDataset.collate_fn,
    )
    # Decay the learning rate linearly from 1x to 0.1x over every optimizer step of the run.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.1,
        total_iters=epochs * len(train_dataloader),
    )
    max_acc = -1
    for epoch in range(epochs):
        # eval() below switches the model to eval mode, so training mode has to
        # be re-enabled at the start of every epoch, not only the first one.
        model.train()
        total_loss = 0.0
        for batch in train_dataloader:
            # batch_size is 1: take the single sample out of each collated list.
            inputs = {
                "num_nodes": batch[0][0],
                "node_labels": batch[1][0],
                "edge_list": batch[2][0],
                "edge_attr": batch[3][0],
                "label": batch[4][0],
            }
            optimizer.zero_grad()
            _, loss = model(**inputs)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            scheduler.step()
        if verbose:
            print(f"Epoch:{epoch+1}, Loss:{total_loss}")
        acc = eval(model, test_dataset)[0]
        if acc > max_acc:
            max_acc = acc
    return max_acc


def eval(model, test_dataset):
    """Compute the binary accuracy of ``model`` on ``test_dataset``.

    The name shadows Python's builtin ``eval`` inside this notebook; it is kept
    because ``train`` calls it by this name. The model's train/eval mode is
    restored before returning.

    Args:
        model: module whose forward returns ``(logits, loss)``; sigmoid(logits)
            is thresholded at 0.5.
        test_dataset: MyDataset scored in order with batch size 1.

    Returns:
        tuple: ``(acc, preds, labels)`` where ``acc`` is a 0-dim float tensor,
        ``preds`` a 1-D bool tensor and ``labels`` the 1-D label tensor.

    Raises:
        ValueError: if ``test_dataset`` is empty (accuracy is undefined).
    """
    if len(test_dataset) == 0:
        raise ValueError("test_dataset is empty; accuracy is undefined")
    was_training = model.training
    model.eval()
    test_loader = DataLoader(
        test_dataset, batch_size=1, shuffle=False, collate_fn=MyDataset.collate_fn
    )
    labels = []
    probs = []
    try:
        with torch.no_grad():
            for batch in test_loader:
                # batch_size is 1: take the single sample out of each collated list.
                inputs = {
                    "num_nodes": batch[0][0],
                    "node_labels": batch[1][0],
                    "edge_list": batch[2][0],
                    "edge_attr": batch[3][0],
                    "label": batch[4][0],
                }
                logits, _ = model(**inputs)
                labels.append(inputs["label"])
                probs.append(torch.sigmoid(logits))
    finally:
        # Restore the caller's mode so a following training step is not run in eval mode.
        model.train(was_training)
    labels = torch.stack(labels).view(-1)
    preds = torch.stack(probs).view(-1) > 0.5
    acc = (preds == labels).sum() / len(labels)
    return acc, preds, labels

In [64]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single hold-out split for quick experiments (the K-fold loop below builds its own datasets).
train_data, test_data = train_test_split(data, test_size=0.1, random_state=100)
train_dataset = MyDataset(train_data, device)
test_dataset = MyDataset(test_data, device)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=MyDataset.collate_fn,
)

# Vocabulary sizes for the embedding tables in NodeEncoder / EdgeEncoder.
meta_data = {
    "num_atom_type": 233,
    "num_ind1_type": 2,
    "num_inda_type": 2,
    "num_edge_type": 8,
}
# Fail early (instead of with an opaque index error inside nn.Embedding) when the
# data holds ids outside the tables. EdgeEncoder allocates num_edge_type + 1 rows,
# so bond types up to and including num_edge_type are valid.
assert atom["atom_type"].max() < meta_data["num_atom_type"], "num_atom_type too small"
assert molecule["ind1"].max() < meta_data["num_ind1_type"], "num_ind1_type too small"
assert molecule["inda"].max() < meta_data["num_inda_type"], "num_inda_type too small"
assert bond["bond_type"].max() <= meta_data["num_edge_type"], "num_edge_type too small"

In [74]:
class NodeEncoder(nn.Module):
    """Encode every atom's labels into a single ``embed_dim`` feature vector.

    Categorical labels (atom_type, ind1, inda) are looked up in ``nn.Embedding``
    tables. Scalar labels (logp, lumo, charge) each scale a learned
    ``[1, embed_dim]`` direction vector. The six ``[n, embed_dim]`` parts are
    concatenated and projected back to ``embed_dim`` by a Linear layer followed
    by tanh.
    """

    def __init__(
        self,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        embed_dim: int,
    ):
        super(NodeEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.atom_embedding = nn.Embedding(num_atom_type, embed_dim)
        self.ind1_embedding = nn.Embedding(num_ind1_type, embed_dim)
        self.inda_embedding = nn.Embedding(num_inda_type, embed_dim)
        # Learned directions that the real-valued labels scale.
        self.logp_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.lumo_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.charge_embedding = nn.Parameter(torch.randn(1, embed_dim))
        self.trans = nn.Linear(embed_dim * 6, embed_dim)

    def init(self):
        """Re-initialise the embedding tables and direction vectors with Xavier-uniform.

        NOTE(review): nothing in this notebook calls init(), so by default these
        parameters keep nn.Embedding's default init and the torch.randn values.
        """
        nn.init.xavier_uniform_(self.atom_embedding.weight)
        nn.init.xavier_uniform_(self.ind1_embedding.weight)
        nn.init.xavier_uniform_(self.inda_embedding.weight)
        nn.init.xavier_uniform_(self.logp_embedding)
        nn.init.xavier_uniform_(self.lumo_embedding)
        nn.init.xavier_uniform_(self.charge_embedding)

    def forward(self, node_labels):
        """Encode the nodes of one molecule.

        Args:
            node_labels (dict[str, torch.Tensor]): 1-D tensors of length n.
                atom_type/ind1/inda must be integer indices into the embedding
                tables; logp/lumo/charge are real-valued.

        Returns:
            torch.Tensor: [n, embed_dim], tanh-activated.
        """
        atom_embed = self.atom_embedding(node_labels["atom_type"])  # [n,e]
        ind1_embed = self.ind1_embedding(node_labels["ind1"])  # [n,e]
        inda_embed = self.inda_embedding(node_labels["inda"])  # [n,e]
        # Scalar labels: [1,e] * [n,1] broadcasts to [n,e].
        logp_embed = self.logp_embedding * node_labels["logp"].unsqueeze(dim=-1)
        lumo_embed = self.lumo_embedding * node_labels["lumo"].unsqueeze(dim=-1)
        charge_embed = self.charge_embedding * node_labels["charge"].unsqueeze(dim=-1)
        node_features = torch.cat(
            [
                atom_embed,
                ind1_embed,
                inda_embed,
                logp_embed,
                lumo_embed,
                charge_embed,
            ],
            dim=-1,
        )  # [n, 6e]
        node_features = torch.tanh(self.trans(node_features))
        return node_features  # [n,e]

    def get_parameters(self):
        """Return all learnable parameters as a list."""
        return list(self.parameters())


class EdgeEncoder(nn.Module):
    """Look up a learned vector for every edge from its integer bond type.

    The table has ``num_edge_type + 1`` rows, so ids 0..num_edge_type are all
    valid. Weights are Xavier-uniform initialised.
    """

    def __init__(self, num_edge_type: int, embed_dim: int):
        super(EdgeEncoder, self).__init__()
        table_rows = num_edge_type + 1
        self.edge_embedding = nn.Embedding(table_rows, embed_dim)
        nn.init.xavier_uniform_(self.edge_embedding.weight)

    def forward(self, edge_attr):
        """Embed bond-type ids: output has ``edge_attr``'s shape plus a trailing embed_dim axis."""
        return self.edge_embedding(edge_attr)

    def get_parameters(self):
        """Return all learnable parameters as a list."""
        return [param for param in self.parameters()]


class HwNonLinear(nn.Module):
    """Non-linear transition function h_w applied to every message of the GNN.

    The node features, edge features, neighbour state and neighbour features
    are concatenated and mapped through a three-layer MLP
    (Linear-Tanh-Linear-Tanh-Linear) to a new ``embed_dim`` vector.
    """

    def __init__(
        self,
        embed_dim: int,
    ):
        """
        Args:
            embed_dim (int): size of each of the four inputs and of the output.
        """
        super(HwNonLinear, self).__init__()
        self.embed_dim = embed_dim
        self.trans = nn.Sequential(
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )
        # Xavier-initialise the first two Linear layers; the output layer keeps PyTorch's default.
        for layer_idx in (0, 2):
            nn.init.xavier_uniform_(self.trans[layer_idx].weight)

    def forward(
        self,
        node_features,
        edge_features,
        neighbor_state,
        neighbor_features,
    ):
        """Compute one h_w output per input row.

        Args:
            node_features (torch.Tensor): l_n, [..., embed_dim]
            edge_features (torch.Tensor): l_nu, [..., embed_dim]
            neighbor_state (torch.Tensor): x_u, [..., embed_dim]
            neighbor_features (torch.Tensor): l_u, [..., embed_dim]
            All four must share their leading dimensions; they are concatenated
            on the last axis.

        Returns:
            torch.Tensor: [..., embed_dim] MLP output (no final activation).
        """
        stacked = torch.cat(
            (node_features, edge_features, neighbor_state, neighbor_features), dim=-1
        )
        return self.trans(stacked)

    def get_parameters(self):
        """Return all learnable parameters as a list."""
        return [*self.parameters()]


class Aggr(nn.Module):
    """Mean-aggregate per-edge messages onto the nodes that receive them."""

    def __init__(self, embed_dim: int):
        super(Aggr, self).__init__()
        self.embed_dim = embed_dim

    def forward(
        self,
        x,
        aggregate_map,
    ):
        """Average the incoming messages of every node.

        Args:
            x (torch.Tensor): messages, [m, e].
            aggregate_map (torch.Tensor): [m, n]; row i marks the node(s) that
                message i is delivered to.

        Returns:
            torch.Tensor: [n, e]. Each node gets the (weighted) mean of its
            incoming messages. A node with no incoming message gets a zero
            vector instead of 0/0 = NaN.
        """
        summed = torch.einsum("me,mn->ne", x, aggregate_map)  # [n,e]
        counts = aggregate_map.sum(dim=0).unsqueeze(-1)  # [n,1] in-degree
        # Avoid dividing by zero for isolated nodes; their summed row is already 0.
        counts = counts.masked_fill(counts == 0, 1)
        return summed / counts


class GNN(nn.Module):
    """Graph classifier that iterates a learned transition function towards a fixed point.

    Each forward pass does the following:
    - encodes atoms and bonds;
    - repeatedly applies ``HwNonLinear`` to every directed edge and
      mean-aggregates the results onto the receiving nodes, at most ``t``
      times, stopping early once every node state moves by less than
      ``thresh``;
    - classifies the molecule from the state of node 0.
    """

    def __init__(
        self,
        embed_dim,
        num_atom_type: int,
        num_ind1_type: int,
        num_inda_type: int,
        num_edge_type: int,
        t,
    ):
        super(GNN, self).__init__()
        self.embed_dim = embed_dim
        self.t = t  # maximum number of state-update iterations
        self.node_encoder = NodeEncoder(
            num_atom_type, num_ind1_type, num_inda_type, embed_dim
        )
        self.edge_encoder = EdgeEncoder(num_edge_type, embed_dim)
        self.hw = HwNonLinear(embed_dim)
        self.output_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, 1),
        )
        self.aggr = Aggr(embed_dim)
        self.criterion = nn.BCEWithLogitsLoss()
        self.l2_reg = nn.MSELoss()  # NOTE(review): not used by forward()
        self.thresh = 1e-7  # convergence tolerance on the per-node state change

    def contraction_penalty(self, params, threshold=0.9):
        """Soft penalty that pushes parameter norms below ``threshold``.

        For each tensor p, adds ``relu(||p||_2 - threshold) ** 2``, where the
        norm is taken over all elements of p. Tensors already within the
        threshold cost nothing. Returns 0 (a Python int) when ``params`` is
        empty.
        """
        penalty = 0
        for param in params:
            norm = torch.norm(param, p=2)
            penalty += torch.pow(torch.relu(norm - threshold), 2)
        return penalty

    def get_aggregate_map(self, edge_list, num_nodes):
        """Build the 0/1 matrix that delivers message i to node ``edge_list[0, i]``.

        The matrix is created on ``edge_list``'s device so it can be combined
        with model tensors on GPU as well as CPU.

        Args:
            edge_list (torch.Tensor): [2, m] integer node indices.
            num_nodes (int): number of nodes n.

        Returns:
            torch.Tensor: [m, n] with a single 1 per row.
        """
        num_edges = edge_list.shape[1]
        aggregate_map = torch.zeros(num_edges, num_nodes, device=edge_list.device)
        rows = torch.arange(num_edges, device=edge_list.device)
        aggregate_map[rows, edge_list[0]] = 1
        return aggregate_map

    def forward(self, num_nodes, node_labels, edge_list, edge_attr, label):
        """Run the state iteration for one molecule and compute its loss.

        Args:
            num_nodes (int): number of atoms n.
            node_labels (dict[str, torch.Tensor]): per-node labels consumed by
                NodeEncoder (atom_type, ind1, inda, logp, lumo, charge).
            edge_list (torch.Tensor): [2, m] integer indices. Messages flow from
                edge_list[1] (sender) to edge_list[0] (receiver).
            edge_attr (torch.Tensor): [m] bond type of each directed edge.
            label (torch.Tensor): [1] binary target.

        Returns:
            tuple: ``(logits, loss)``.
            - logits has shape [1].
            - loss is the BCE-with-logits loss plus 0.01 times the contraction
              penalty over the h_w, node-encoder and edge-encoder parameters.
        """
        aggregate_map = self.get_aggregate_map(edge_list, num_nodes)  # [m,n]
        node_features = self.node_encoder(node_labels)  # [n,e]
        # The iteration starts from the encoded node features.
        node_states = node_features
        edge_features = self.edge_encoder(edge_attr)  # [m,e]
        l_n = torch.index_select(node_features, 0, edge_list[0])  # receiver features [m,e]
        l_u = torch.index_select(node_features, 0, edge_list[1])  # sender features [m,e]
        l_nu = edge_features
        x_u = torch.index_select(node_states, 0, edge_list[1])  # sender states [m,e]
        for _ in range(self.t):
            x = self.hw(l_n, l_nu, x_u, l_u)
            new_state = self.aggr(x, aggregate_map)
            with torch.no_grad():
                distance = torch.norm(new_state - node_states, p=2, dim=-1)
                converged = bool((distance < self.thresh).all())
            if converged:
                break
            node_states = new_state
            x_u = torch.index_select(node_states, 0, edge_list[1])
        # Graph-level readout uses only node 0's state.
        logits = self.output_layer(node_states[0])
        hw_params = self.hw.get_parameters()
        node_encoder_params = self.node_encoder.get_parameters()
        edge_encoder_params = self.edge_encoder.get_parameters()
        penalty = (
            self.contraction_penalty(hw_params, threshold=1)
            + self.contraction_penalty(node_encoder_params, threshold=1)
            + self.contraction_penalty(edge_encoder_params, threshold=1)
        )
        loss = self.criterion(logits, label.float())
        return logits, loss + penalty * 0.01

In [75]:
# 10-fold cross-validation: every fold trains a fresh model from scratch and
# records the best held-out accuracy returned by train().
N_FOLDS = 10
SEED = 42

kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
acc_scores = []
for fold, (train_index, test_index) in enumerate(tqdm(kf.split(data), total=N_FOLDS)):
    # Seed per fold so weight initialisation and DataLoader shuffling are reproducible.
    torch.manual_seed(SEED + fold)
    # MyDataset moves every tensor to `device`, so the model has to live there too.
    model = GNN(embed_dim=100, t=1000, **meta_data).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.001,
    )
    data_train = [data[i] for i in train_index]
    data_test = [data[i] for i in test_index]
    train_dataset = MyDataset(data_train, device)
    test_dataset = MyDataset(data_test, device)
    acc = train(model, optimizer, train_dataset, test_dataset, epochs=100)
    print(acc)
    acc_scores.append(acc.item())

1it [04:41, 281.55s/it]

tensor(1.)


2it [10:20, 315.57s/it]

tensor(1.)


3it [13:58, 270.71s/it]

tensor(0.7500)


3it [15:37, 312.38s/it]


KeyboardInterrupt: 

In [72]:
# Mean accuracy over the folds recorded in acc_scores (fewer than 10 if the loop was interrupted).
torch.mean(torch.tensor(acc_scores))

tensor(0.8750)