In [1]:
import pandas as pd
import torch
import numpy as np
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import torch.nn as nn
from torch.nn import functional as F
from sklearn.preprocessing import StandardScaler
from torch_geometric.datasets import TUDataset
from torch_geometric.data import Data, Batch

In [2]:
# Target device for the notebook: train()/eval() move every batch tensor here,
# and the run cell moves the model here via `model.to(device)`.
device = torch.device("cpu")

In [3]:
# Load the MUTAG graph dataset, using ./dataset/temp as the dataset root directory.
dataset = TUDataset(root="./dataset/temp", name="MUTAG")
# data = dataset.shuffle()
# Hold out 10% of the graphs for testing; the fixed random_state makes the split repeatable.
train_data, test_data = train_test_split(dataset, test_size=0.1, random_state=100)

In [6]:
# Inspect which class the loaded dataset object is an instance of.
type(dataset)

torch_geometric.datasets.tu_dataset.TUDataset

In [5]:
# Graph-level label of the first graph (a one-element tensor, see the output below).
dataset[0].y

tensor([1])

In [50]:
# Merge the first four graphs into a single Batch object.
batch = Batch.from_data_list(dataset[0:4])
# `ptr` has one more entry than there are graphs; presumably the cumulative node
# offsets at which each graph starts inside the batch — TODO confirm against the PyG docs.
batch.ptr

tensor([ 0, 17, 30, 43, 62])

In [24]:
class MutagDataset(Dataset):
    """Thin ``torch.utils.data.Dataset`` wrapper around an indexable collection of graphs.

    The wrapped object only needs to support ``len()`` and integer indexing.
    ``collate_fn`` is meant to be passed to a ``DataLoader`` so that the sampled
    graphs are merged with ``Batch.from_data_list``.
    """

    def __init__(self, dataset):
        """Store the underlying collection.

        Args:
            dataset: indexable, sized collection of graph objects.
        """
        super(MutagDataset, self).__init__()
        self.dataset = dataset

    def __len__(self):
        """Return the number of graphs in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.dataset[idx]

    @staticmethod
    def collate_fn(batch):
        """Merge a list of graphs into one ``Batch``.

        Declared as a static method so it behaves the same whether it is looked
        up on the class (``MutagDataset.collate_fn``) or on an instance; without
        the decorator an instance lookup would bind the instance as ``batch``.

        Args:
            batch: list of graph objects sampled by the ``DataLoader``.

        Returns:
            The result of ``Batch.from_data_list(batch)``.
        """
        return Batch.from_data_list(batch)

In [25]:
def train(model, optimizer, train_dataset, test_dataset, epochs, verbose=True):
    """Train ``model`` one graph at a time and track the best test accuracy.

    Args:
        model: module called as ``model(x=..., edge_index=..., edge_attr=...,
            label=..., ptr=...)`` and returning ``(logits, loss)``.
        optimizer: optimizer over the model's parameters.
        train_dataset: dataset of graphs used for the gradient updates.
        test_dataset: dataset passed to ``eval`` after every epoch.
        epochs (int): number of passes over ``train_dataset``.
        verbose (bool): print loss and accuracy after each epoch.

    Returns:
        The highest accuracy returned by ``eval`` over all epochs
        (``-1`` if ``epochs`` is 0).
    """
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=MutagDataset.collate_fn,
    )
    # The decay horizon is expressed in optimizer steps (epochs * batches), so the
    # scheduler has to be stepped once per batch below, not once per epoch.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1,
        end_factor=0.01,
        total_iters=epochs * len(train_dataloader),
    )
    max_acc = -1
    for epoch in range(epochs):
        # eval() at the end of the previous epoch switches the model to eval mode,
        # so training mode is re-enabled at the start of every epoch.
        model.train()
        total_loss = 0
        for batch in train_dataloader:
            inputs = {
                "x": batch.x.to(device),
                "edge_index": batch.edge_index.to(device),
                "edge_attr": batch.edge_attr.to(device),
                "label": batch.y.to(device),
                "ptr": batch.ptr.to(device),
            }
            optimizer.zero_grad()
            logits, loss = model(**inputs)
            loss.backward()
            # Cap the global gradient norm before the parameter update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
        acc = eval(model, test_dataset)[0]
        if verbose:
            print(f"Epoch:{epoch+1}, Loss:{total_loss}, Accuracy:{acc}")
        if acc > max_acc:
            max_acc = acc
    return max_acc


def eval(model, test_dataset):
    """Compute the classification accuracy of ``model`` on ``test_dataset``.

    Note: this function shadows Python's built-in ``eval`` inside the notebook.

    The model is put in eval mode for the duration of the call and its previous
    training/eval mode is restored afterwards, so calling this in the middle of
    a training loop does not silently leave the model in eval mode.

    Args:
        model: module called as ``model(x=..., edge_index=..., edge_attr=...,
            label=..., ptr=...)`` and returning ``(logits, loss)``.
        test_dataset: dataset of graphs to score.

    Returns:
        tuple ``(acc, preds, labels)``: ``acc`` is a scalar tensor with the
        fraction of correct predictions, ``preds`` a flat boolean tensor
        (sigmoid(logit) > 0.5) and ``labels`` the flat tensor of labels.
    """
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=MutagDataset.collate_fn,
    )
    labels = []
    preds = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                inputs = {
                    "x": batch.x.to(device),
                    "edge_index": batch.edge_index.to(device),
                    "edge_attr": batch.edge_attr.to(device),
                    "label": batch.y.to(device),
                    "ptr": batch.ptr.to(device),
                }
                logits, _ = model(**inputs)
                # Labels are read after the forward pass, as before, so any
                # remapping the model applies to them is reflected here.
                labels.append(inputs["label"].view(-1))
                preds.append(torch.sigmoid(logits).view(-1))
    finally:
        model.train(was_training)
    # Concatenate flat per-batch results; unlike stacking, this does not require
    # every batch to contain the same number of graphs.
    labels = torch.cat(labels)
    preds = torch.cat(preds) > 0.5
    acc = (preds == labels).sum() / len(labels)
    return acc, preds, labels

In [46]:
class NodeEncoder(nn.Module):
    """Embed raw node feature vectors with a single linear layer.

    An additional one-layer ``mlp`` is constructed and Xavier-initialised, but
    ``forward`` does not apply it; its weights are still registered on the module
    and are therefore part of what ``get_parameters`` returns.
    """

    def __init__(self, input_dim: int, embed_dim: int):
        """
        Args:
            input_dim (int): size of each raw node feature vector.
            embed_dim (int): size of the produced embedding.
        """
        super().__init__()
        self.node_embedding = nn.Linear(input_dim, embed_dim)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim))
        # Xavier-uniform initialisation for both weight matrices (biases keep their defaults).
        for layer in (self.node_embedding, self.mlp[0]):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Map node features ``[n, input_dim]`` to embeddings ``[n, embed_dim]``."""
        return self.node_embedding(x)

    def get_parameters(self):
        """Return all learnable parameters of this module as a list."""
        return [param for param in self.parameters()]


class EdgeEncoder(nn.Module):
    """Embed raw edge attribute vectors with a single linear layer."""

    def __init__(self, input_dim: int, embed_dim: int):
        """
        Args:
            input_dim (int): size of each raw edge attribute vector.
            embed_dim (int): size of the produced embedding.
        """
        super().__init__()
        linear = nn.Linear(input_dim, embed_dim)
        # Xavier-uniform initialisation of the weight matrix (bias keeps its default).
        nn.init.xavier_uniform_(linear.weight)
        self.edge_embedding = linear

    def forward(self, edge_attr):
        """Project ``edge_attr`` (last dimension ``input_dim``) to ``embed_dim``."""
        return self.edge_embedding(edge_attr)

    def get_parameters(self):
        """Return all learnable parameters of this module as a list."""
        return [param for param in self.parameters()]


class HwNonLinear(nn.Module):
    """Non-linear transition function h_w of the GNN.

    Two feed-forward stacks are created:

    * ``mlp``  - Linear(4 * embed_dim -> embed_dim), Tanh, Linear(embed_dim -> embed_dim)
    * ``mlp2`` - Linear(embed_dim -> embed_dim), Tanh, Linear(embed_dim -> embed_dim)

    Only ``mlp2`` is applied in ``forward``. ``mlp`` is still registered on the
    module, so its weights are included in ``get_parameters``.
    """

    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim (int): width of the embeddings and of the node state.
        """
        super().__init__()
        self.embed_dim = embed_dim
        width = embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(4 * width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        )

    def forward(
        self,
        node_features,
        edge_features,
        neighbor_state,
        neighbor_features,
    ):
        """Compute one message per edge.

        The edge embedding, the neighbour's state and the neighbour's embedding
        are averaged element-wise and passed through ``mlp2``.

        Args:
            node_features (torch.Tensor): l_n, embedding of the receiving node;
                accepted for interface compatibility but not used.
            edge_features (torch.Tensor): l_nu, embedding of the edge.
            neighbor_state (torch.Tensor): x_u, current state of the neighbour.
            neighbor_features (torch.Tensor): l_u, embedding of the neighbour.

        Returns:
            torch.Tensor: message tensor with the same shape as ``edge_features``.
        """
        summed = edge_features + neighbor_state + neighbor_features
        return self.mlp2(summed / 3)

    def get_parameters(self):
        """Return all learnable parameters of this module as a list."""
        return [param for param in self.parameters()]


class Aggr(nn.Module):
    """Sum per-edge messages onto nodes using a dense edge-to-node map."""

    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim (int): width of the message vectors (stored, not otherwise used).
        """
        super().__init__()
        self.embed_dim = embed_dim

    def forward(
        self,
        x,
        aggregate_map,
    ):
        """Aggregate messages into new node states.

        Args:
            x (torch.Tensor): messages to deliver, shape [m, e].
            aggregate_map (torch.Tensor): edge-to-node map, shape [m, n].

        Returns:
            torch.Tensor: shape [n, e]; row ``j`` is the sum over edges ``i`` of
            ``aggregate_map[i, j] * x[i]``.
        """
        aggregated = torch.einsum("me,mn->ne", x, aggregate_map)
        return aggregated


class GNN(nn.Module):
    """Iterative (fixed-point style) graph neural network for binary graph classification.

    Node states start at zero. In every iteration ``HwNonLinear`` produces one
    message per edge and ``Aggr`` sums the messages onto the node in the first row
    of ``edge_index``. Iteration stops after ``t`` rounds or as soon as no node
    state moves by more than ``self.thresh``. The state of the first node of each
    graph (located through ``ptr``) is fed to a small MLP that emits one logit.
    """

    def __init__(
        self,
        embed_dim,
        node_input_dim=7,
        edge_input_dim=4,
        t=10,
    ):
        """
        Args:
            embed_dim (int): width of node/edge embeddings and of the node states.
            node_input_dim (int): size of each raw node feature vector.
            edge_input_dim (int): size of each raw edge attribute vector.
            t (int): maximum number of state-update iterations.
        """
        super(GNN, self).__init__()
        self.embed_dim = embed_dim
        self.t = t
        self.node_encoder = NodeEncoder(node_input_dim, embed_dim)
        self.edge_encoder = EdgeEncoder(edge_input_dim, embed_dim)
        self.hw = HwNonLinear(embed_dim)
        self.output_layer = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, 1),
        )
        self.aggr = Aggr(embed_dim)
        self.criterion = nn.BCEWithLogitsLoss()
        self.l2_reg = nn.MSELoss()  # not used by forward()
        self.thresh = 1e-5  # convergence tolerance on the per-node state change
        nn.init.xavier_uniform_(self.output_layer[0].weight)

    def contraction_penalty(self, params, threshold=0.9):
        """Penalty that discourages large parameter norms (contraction-map regulariser).

        Args:
            params: iterable of tensors.
            threshold (float): norm above which a tensor starts to be penalised.

        Returns:
            Sum over ``params`` of ``relu(||p||_2 - threshold) ** 2`` (``0`` when
            ``params`` is empty).
        """
        penalty = 0
        for param in params:
            # L2 norm of the whole tensor; only the excess over the threshold is penalised.
            norm = torch.norm(param, p=2)
            penalty += torch.pow(torch.relu(norm - threshold), 2)
        return penalty

    def get_aggregate_map(self, edge_index, num_nodes):
        """Build the dense edge-to-node map used for aggregation.

        Row ``i`` has a single 1 in column ``edge_index[0, i]``, so the message of
        edge ``i`` (computed from the node in the second row of ``edge_index``)
        is added to the node in the first row.

        Args:
            edge_index (torch.Tensor): [2, m] edge list.
            num_nodes (int): number of nodes.

        Returns:
            torch.Tensor: [m, n] float tensor on the same device as ``edge_index``.
        """
        num_edges = edge_index.shape[1]
        # Allocate on the index tensor's device: indexing a CPU tensor with
        # indices that live on another device raises an error.
        aggregate_map = torch.zeros(num_edges, num_nodes, device=edge_index.device)
        rows = torch.arange(num_edges, device=edge_index.device)
        aggregate_map[rows, edge_index[0]] = 1
        return aggregate_map

    def forward(self, x, edge_index, edge_attr, label, ptr):
        """Run the state iteration and compute logits and loss.

        Args:
            x (torch.Tensor): raw node features, one row per node.
            edge_index (torch.Tensor): [2, m] edge list.
            edge_attr (torch.Tensor): raw edge attributes, one row per edge.
            label (torch.Tensor): one label per graph. Entries equal to -1 are
                rewritten to 0 **in place**, so the caller's tensor is modified.
            ptr (torch.Tensor): node offsets of the graphs in the batch; every
                entry except the last is used as the index of the node whose
                state is read out.

        Returns:
            tuple ``(logits, loss)``: ``logits`` has one row per graph, ``loss``
            is the BCE-with-logits loss plus 0.01 times the contraction penalty
            over the h_w, node-encoder and edge-encoder parameters.
        """
        num_nodes = x.shape[0]
        # Zero initial state with the dtype/device of the input features
        # (training and evaluation start from the same state).
        node_states = x.new_zeros(num_nodes, self.embed_dim)  # [n,e]
        aggregate_map = self.get_aggregate_map(edge_index, num_nodes).to(x)
        node_features = self.node_encoder(x)  # [n,e]
        edge_features = self.edge_encoder(edge_attr)  # [m,e]
        l_n = torch.index_select(node_features, 0, edge_index[0])  # [m,e]
        l_u = torch.index_select(node_features, 0, edge_index[1])
        l_nu = edge_features
        x_u = torch.index_select(node_states, 0, edge_index[1])
        for _ in range(self.t):
            messages = self.hw(l_n, l_nu, x_u, l_u)
            new_state = self.aggr(messages, aggregate_map)
            # The convergence test is bookkeeping only and must not enter the autograd graph.
            with torch.no_grad():
                distance = torch.norm(new_state - node_states, p=2, dim=-1)
                converged = distance < self.thresh
            if converged.all():
                # Keep the previous state: it differs from new_state by less than thresh.
                break
            node_states = new_state
            x_u = torch.index_select(node_states, 0, edge_index[1])
        # Read out the state of the first node of every graph in the batch.
        logits = self.output_layer(
            torch.index_select(node_states, dim=0, index=ptr[:-1])
        )  # [b,1]
        hw_params = self.hw.get_parameters()
        node_encoder_params = self.node_encoder.get_parameters()
        edge_encoder_params = self.edge_encoder.get_parameters()
        penalty = (
            self.contraction_penalty(hw_params, threshold=0.1)
            + self.contraction_penalty(node_encoder_params, threshold=0.1)
            + self.contraction_penalty(edge_encoder_params, threshold=0.1)
        )
        # In-place remap of {-1, 1} style labels to {0, 1}; kept in place because
        # callers read the label tensor back after the forward pass.
        label[label == -1] = 0
        loss = self.criterion(logits.view(-1), label.float())
        return logits, loss + 0.01 * penalty

In [45]:
# Single hold-out experiment: 90/10 split of MUTAG, one model, 100 epochs.
dataset = TUDataset(root="./dataset/temp", name="MUTAG")
train_data, test_data = train_test_split(dataset, test_size=0.1, random_state=2)
train_dataset = MutagDataset(train_data)
test_dataset = MutagDataset(test_data)
# No DataLoader is built here: train() and eval() create their own loaders with
# MutagDataset.collate_fn, so loaders created in this cell would never be iterated.
model = GNN(embed_dim=50, t=1000, edge_input_dim=4, node_input_dim=7)
# Keep the parameters on the same device that train()/eval() move the batches to.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
acc = train(model, optimizer, train_dataset, test_dataset, epochs=100)

Epoch:1, Loss:134.00599282979965, Accuracy:0.5263158082962036
Epoch:2, Loss:121.7857601493597, Accuracy:0.5263158082962036
Epoch:3, Loss:137.7623077481985, Accuracy:0.5263158082962036
Epoch:4, Loss:145.52199064195156, Accuracy:0.5263158082962036


KeyboardInterrupt: 

<module 'torch.mps' from '/Users/lijinliang/opt/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/mps/__init__.py'>

In [109]:
from sklearn.model_selection import KFold

# 10-fold cross-validation over the MUTAG dataset: a fresh model and optimizer
# are trained for every fold and the best test accuracy of each fold is recorded.
kf = KFold(n_splits=10, shuffle=True, random_state=42)
all_acc = []
for train_index, test_index in kf.split(dataset):
    train_data = dataset[train_index]
    test_data = dataset[test_index]
    train_dataset = MutagDataset(train_data)
    test_dataset = MutagDataset(test_data)
    model = GNN(embed_dim=50, t=1000, edge_input_dim=4, node_input_dim=7)
    # train()/eval() move every batch to `device`; the parameters must live there
    # too, otherwise the forward pass fails with a device-mismatch error.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    acc = train(
        model, optimizer, train_dataset, test_dataset, epochs=100, verbose=False
    )
    all_acc.append(acc)
    print(acc)

RuntimeError: Placeholder storage has not been allocated on MPS device!

In [45]:
# NOTE(review): this cell appears to belong to an earlier revision of the notebook.
# `data`, `meta_data` and `MyDataset` are not defined in any cell shown here, and
# `train` is called without the `test_dataset` argument that the `train` defined
# in this notebook requires. TODO confirm whether it should be ported to the
# current API or removed; it cannot run as-is on a fresh kernel.
from sklearn.model_selection import KFold

# 10-fold cross-validation; one accuracy value is collected per fold.
kf = KFold(n_splits=10, shuffle=True, random_state=42)
acc_scores = []
for train_index, test_index in tqdm(kf.split(data)):
    # A fresh model and optimizer per fold.
    model = GNN(embed_dim=100, t=30, **meta_data)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.001,
    )
    # Materialise the fold's samples from the index arrays.
    data_train = [data[i] for i in train_index]
    data_test = [data[i] for i in test_index]
    train_dataset = MyDataset(data_train, device)
    train(model, optimizer, train_dataset, epochs=30)
    test_dataset = MyDataset(data_test, device)
    # eval returns (accuracy, predictions, labels); only the accuracy is kept.
    acc, _, _ = eval(model, test_dataset)
    print(acc)
    acc_scores.append(acc.item())

1it [00:12, 12.49s/it]

tensor(0.8000)


2it [00:24, 12.31s/it]

tensor(0.8000)


3it [00:36, 12.15s/it]

tensor(0.7500)


4it [00:48, 12.17s/it]

tensor(0.7500)


5it [01:01, 12.21s/it]

tensor(0.5000)


6it [01:13, 12.14s/it]

tensor(0.5000)


7it [01:25, 12.25s/it]

tensor(0.5000)


8it [01:37, 12.15s/it]

tensor(0.7500)


9it [01:49, 12.24s/it]

tensor(0.7500)


10it [02:02, 12.22s/it]

tensor(0.5000)





In [46]:
# Mean of the per-fold accuracies collected in `acc_scores`.
torch.tensor(acc_scores).mean()

tensor(0.6600)