In [2]:
import os
import time
import torch
import torch_sparse
import torch_geometric
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Flickr
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv
from torch.nn import Linear, Sequential, ReLU
from torch_geometric.data import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.datasets import TUDataset, ZINC
from torch_geometric.transforms import NormalizeFeatures, ToUndirected
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.utils import train_test_split_edges, negative_sampling, remove_self_loops, add_self_loops, to_undirected

In [7]:
print("torch:", torch.__version__)
print("torch-sparse:", torch_sparse.__version__)
print("torch-geometric:", torch_geometric.__version__)
print("✅ NeighborLoader 正常可用")

torch: 2.1.0
torch-sparse: 0.6.18+pt21cpu
torch-geometric: 2.6.1
✅ NeighborLoader 正常可用


# 节点分类
## 定义模型

In [26]:
'''标准图卷积层'''
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(GCNConv(hidden_channels, hidden_channels))
        self.layers.append(GCNConv(hidden_channels, out_channels))
    def forward(self, x, edge_index):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x, edge_index))
            x = F.dropout(x, p=0.5, training=self.training)
        return self.layers[-1](x, edge_index)
'''使用注意力机制计算邻居节点的加权平均'''
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)
    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, 0.5, training=self.training)
        return self.conv2(x, edge_index)
'''可采样邻居'''
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, 0.5, training=self.training)
        return self.conv2(x, edge_index)
'''强调结构区分能力，更复杂的 MLP 替代单线性变换'''
class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        nn1 = Sequential(Linear(in_channels, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
        nn2 = Sequential(Linear(hidden_channels, hidden_channels), ReLU(), Linear(hidden_channels, out_channels))
        self.conv1 = GINConv(nn1)
        self.conv2 = GINConv(nn2)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, 0.5, training=self.training)
        return self.conv2(x, edge_index)

## 加载数据

In [24]:
class LocalOnlyPlanetoid(Planetoid):
    @property
    def raw_file_names(self):
        if self.name.lower() == 'cora':
            return ['cora.content', 'cora.cites']
        elif self.name.lower() == 'citeseer':
            return ['citeseer.content', 'citeseer.cites']
        elif self.name.lower() == 'pubmed':
            return ['Pubmed-Diabetes.NODE.paper.tab', 'Pubmed-Diabetes.DIRECTED.cites.tab']
        else:
            raise ValueError(f"Unsupported dataset: {self.name}")
    def download(self):
        print(f"\nChecking dataset: {self.name}")
        print(f"Root directory: {self.root}")
        print(f"Raw directory: {self.raw_dir}")
        print(f"Expected raw files: {self.raw_file_names}")
        if os.path.exists(self.raw_dir):
            print(f"Files in raw directory: {os.listdir(self.raw_dir)}")
        else:
            print("Raw directory does not exist!")
        missing_files = [f for f in self.raw_paths if not os.path.exists(f)]
        if missing_files:
            raise RuntimeError(
                f"Missing {len(missing_files)} files in {self.raw_dir}:\n"
                f"- Missing: {missing_files}\n"
                f"- Required: {self.raw_file_names}\n"
                "Please download them from:\n"
                "- Cora/Citeseer: https://linqs-data.soe.ucsc.edu/public/lbc/\n"
                "- Pubmed: https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz"
            )
        else:
            print("All raw files found!")

def load_data(name):
    name = name.lower()
    base_path = r'D:\Data\master\Graph Machine Learning\GNN\standard benchmark datasets'
    transform = T.NormalizeFeatures()
    # 直接使用数据集父目录作为root，并小写化name
    if name == 'cora':
        return LocalOnlyPlanetoid(root=os.path.join(base_path, ''), name='cora', transform=transform)
    elif name == 'citeseer':
        return LocalOnlyPlanetoid(root=os.path.join(base_path, ''), name='citeseer', transform=transform)
    elif name == 'flickr':
        return Flickr(root=os.path.join(base_path, ''), transform=transform)
    else:
        raise ValueError("Only 'cora', 'citeseer', and 'flickr' are supported.")

## 图训练

In [16]:
'''全图训练'''
def train_full(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test(model, data):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        acc = (pred[mask] == data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

'''子图采样训练'''
def train_sample(model, loader, optimizer):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

## 主程序入口

In [27]:
def run(dataset_name, model_name, use_sampler=False, epochs=100, lr=0.005, layers=2):
    # 加载数据集和初始化模型
    dataset = load_data(dataset_name)
    data = dataset[0].to(device)  # 将图数据转移到设备（GPU或CPU）
    # 获取输入维度（节点特征维度）和输出维度（类别数）
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes
    # 用一个字典映射模型名称到对应类，简洁灵活
    model_cls = {'GCN': GCN, 'GAT': GAT, 'GraphSAGE': GraphSAGE, 'GIN': GIN}[model_name]
    # 特殊处理：GCN 支持自定义层数（num_layers），其他模型没有这个参数
    if model_name == 'GCN':
        model = model_cls(in_dim, 64, out_dim, num_layers=layers).to(device)
    else:
        model = model_cls(in_dim, 64, out_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
     # 打印当前实验配置信息
    print(f"--- [{dataset_name}] {model_name} | Sampler = {use_sampler} | LR={lr}, Layers={layers} ---")

    '''模型训练部分'''
    # 启动计时器
    start = time.time()
    # 子图采样训练模式
    if use_sampler:
        loader = NeighborLoader(
            data,   # 完整图数据
            input_nodes=data.train_mask,    # 指定训练节点
            num_neighbors=[15, 10],    # 每层采样邻居数
            batch_size=1024
        )
        for epoch in range(epochs):
            loss = train_sample(model, loader, optimizer)
            if epoch % 10 == 0:    # 每10轮打印一次精度
                accs = test(model, data)
                print(f"Epoch {epoch}: Loss {loss:.4f}, Acc {accs}")
    # 全图训练模式
    else:
        for epoch in range(epochs):
            loss = train_full(model, data, optimizer)
            if epoch % 10 == 0:
                accs = test(model, data)
                print(f"Epoch {epoch}: Loss {loss:.4f}, Acc {accs}")
    print(f"Total time: {time.time() - start:.2f}s\n")
    
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    datasets = ['Cora', 'Citeseer', 'Flickr']
    models = ['GCN', 'GAT', 'GraphSAGE', 'GIN']
    for d in datasets:
        for m in models:
            run(d, m, use_sampler=False)
            run(d, m, use_sampler=True)
            run(d, m, use_sampler=True, lr=0.01)
            run(d, m, use_sampler=True, layers=3)

--- [Cora] GCN | Sampler = False | LR=0.005, Layers=2 ---
Epoch 0: Loss 1.9463, Acc [0.3, 0.184, 0.203]
Epoch 10: Loss 1.8533, Acc [0.9142857142857143, 0.666, 0.692]
Epoch 20: Loss 1.6965, Acc [0.95, 0.756, 0.77]
Epoch 30: Loss 1.4786, Acc [0.9642857142857143, 0.776, 0.783]
Epoch 40: Loss 1.2311, Acc [0.9714285714285714, 0.79, 0.796]
Epoch 50: Loss 1.0170, Acc [0.9785714285714285, 0.8, 0.814]
Epoch 60: Loss 0.7971, Acc [0.9785714285714285, 0.802, 0.815]
Epoch 70: Loss 0.6713, Acc [0.9857142857142858, 0.802, 0.821]
Epoch 80: Loss 0.5500, Acc [0.9857142857142858, 0.802, 0.823]
Epoch 90: Loss 0.4961, Acc [0.9857142857142858, 0.802, 0.814]
Total time: 2.22s

--- [Cora] GCN | Sampler = True | LR=0.005, Layers=2 ---
Epoch 0: Loss 1.9453, Acc [0.6142857142857143, 0.454, 0.482]
Epoch 10: Loss 1.8116, Acc [0.9428571428571428, 0.768, 0.789]
Epoch 20: Loss 1.5875, Acc [0.95, 0.768, 0.789]
Epoch 30: Loss 1.3347, Acc [0.9714285714285714, 0.778, 0.801]
Epoch 40: Loss 1.0345, Acc [0.9642857142857143,

Epoch 40: Loss 0.0087, Acc [0.9928571428571429, 0.758, 0.737]
Epoch 50: Loss 0.0053, Acc [0.9857142857142858, 0.754, 0.771]
Epoch 60: Loss 0.0085, Acc [0.9928571428571429, 0.76, 0.731]
Epoch 70: Loss 0.0657, Acc [0.9857142857142858, 0.766, 0.784]
Epoch 80: Loss 0.0070, Acc [0.9857142857142858, 0.73, 0.743]
Epoch 90: Loss 0.0103, Acc [0.9928571428571429, 0.758, 0.765]
Total time: 2.79s

--- [Cora] GIN | Sampler = True | LR=0.01, Layers=2 ---
Epoch 0: Loss 1.9552, Acc [0.15714285714285714, 0.178, 0.152]
Epoch 10: Loss 0.4367, Acc [0.9571428571428572, 0.768, 0.754]
Epoch 20: Loss 0.1049, Acc [0.9928571428571429, 0.736, 0.721]
Epoch 30: Loss 0.0174, Acc [0.9857142857142858, 0.726, 0.731]
Epoch 40: Loss 0.0047, Acc [0.9857142857142858, 0.706, 0.699]
Epoch 50: Loss 0.0026, Acc [0.9928571428571429, 0.722, 0.723]
Epoch 60: Loss 0.0477, Acc [0.9857142857142858, 0.764, 0.768]
Epoch 70: Loss 0.0220, Acc [0.9857142857142858, 0.688, 0.697]
Epoch 80: Loss 0.0193, Acc [0.9642857142857143, 0.702, 0.66

Epoch 10: Loss 1.3291, Acc [1.0, 0.632, 0.645]
Epoch 20: Loss 0.5994, Acc [1.0, 0.688, 0.68]
Epoch 30: Loss 0.2674, Acc [1.0, 0.674, 0.687]
Epoch 40: Loss 0.2070, Acc [1.0, 0.676, 0.685]
Epoch 50: Loss 0.1704, Acc [1.0, 0.686, 0.694]
Epoch 60: Loss 0.1606, Acc [1.0, 0.692, 0.687]
Epoch 70: Loss 0.1252, Acc [1.0, 0.702, 0.695]
Epoch 80: Loss 0.1243, Acc [1.0, 0.694, 0.684]
Epoch 90: Loss 0.1356, Acc [1.0, 0.692, 0.694]
Total time: 5.33s

--- [Citeseer] GraphSAGE | Sampler = True | LR=0.005, Layers=3 ---
Epoch 0: Loss 1.7950, Acc [0.475, 0.22, 0.213]
Epoch 10: Loss 1.6227, Acc [1.0, 0.586, 0.589]
Epoch 20: Loss 1.2804, Acc [1.0, 0.668, 0.654]
Epoch 30: Loss 0.8529, Acc [1.0, 0.708, 0.682]
Epoch 40: Loss 0.4956, Acc [1.0, 0.702, 0.692]
Epoch 50: Loss 0.3284, Acc [1.0, 0.686, 0.687]
Epoch 60: Loss 0.2524, Acc [1.0, 0.688, 0.695]
Epoch 70: Loss 0.2382, Acc [1.0, 0.688, 0.701]
Epoch 80: Loss 0.2301, Acc [1.0, 0.696, 0.702]
Epoch 90: Loss 0.2162, Acc [1.0, 0.694, 0.703]
Total time: 5.11s

---

Downloading https://drive.usercontent.google.com/download?id=1crmsTbd1-2sEXsGwa2IKnIB7Zd3TmUsy&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1join-XdvX3anJU_MLVtick7MgeAQiWIZ&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1uxIkbtg5drHTsKt-PAsZZ4_yJmgFmle9&confirm=t
Downloading https://drive.usercontent.google.com/download?id=1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7&confirm=t
Processing...
Done!


--- [Flickr] GCN | Sampler = False | LR=0.005, Layers=2 ---
Epoch 0: Loss 1.9454, Acc [0.2614453781512605, 0.2603531731803514, 0.2578317572715457]
Epoch 10: Loss 1.7412, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 20: Loss 1.6368, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 30: Loss 1.6318, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 40: Loss 1.6308, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 50: Loss 1.6273, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 60: Loss 1.6229, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 70: Loss 1.6218, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 80: Loss 1.6180, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 90: Loss 1.6141, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Total time: 76.47s

--- [Flickr] GCN | Sampler = Tru

Epoch 50: Loss 1.5800, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 60: Loss 1.5753, Acc [0.42160224089635856, 0.42394227321620653, 0.4234302872764756]
Epoch 70: Loss 1.5708, Acc [0.4230812324929972, 0.4253764790247401, 0.42432662573387714]
Epoch 80: Loss 1.5645, Acc [0.42738375350140057, 0.43093402653280743, 0.43001837493837675]
Epoch 90: Loss 1.5603, Acc [0.43318767507002803, 0.4355951954105414, 0.4353964056827858]
Total time: 161.44s

--- [Flickr] GraphSAGE | Sampler = True | LR=0.005, Layers=2 ---
Epoch 0: Loss 1.7008, Acc [0.42160224089635856, 0.4238526353531732, 0.4234302872764756]
Epoch 10: Loss 1.5714, Acc [0.46384313725490195, 0.4606489781283614, 0.4618383901761305]
Epoch 20: Loss 1.5607, Acc [0.46525490196078434, 0.4617246324847616, 0.46304844709362253]
Epoch 30: Loss 1.5598, Acc [0.46639775910364145, 0.46401039799211186, 0.4659167301573074]
Epoch 40: Loss 1.5566, Acc [0.4666666666666667, 0.46360702760846184, 0.465110025545646]
Epoch 50: Loss 1.5551

# 图上的链路预测
## 定义模型

In [17]:
'''标准图卷积：聚合邻居节点特征'''
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, layers=2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)
'''图注意力卷积：适合处理图中连接权重不均的情况'''
class GATEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, layers=2):
        super().__init__()
        self.layers = layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads=4, concat=True))
        for _ in range(layers - 2):
            self.convs.append(GATConv(hidden_channels * 4, hidden_channels, heads=4, concat=True))
        self.convs.append(GATConv(hidden_channels * 4, hidden_channels, heads=1, concat=True))
    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.elu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)
'''邻居聚合 + 自身连接：支持大图的采样训练，适用于归纳学习'''
class SAGEEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, layers=2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)
'''图同构网络'''
class GINEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, layers=2):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        for i in range(layers):
            in_dim = in_channels if i == 0 else hidden_channels
            mlp = Sequential(Linear(in_dim, hidden_channels), ReLU(), Linear(hidden_channels, hidden_channels))
            self.convs.append(GINConv(mlp))
    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)

## 链路预测（解码器）

In [18]:
'''解码器'''
def decode(z, edge_index):
    # Inner product decoder
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
'''构造训练目标标签'''
def get_link_labels(pos_edge_index, neg_edge_index):
    # 正样本边（图中真实存在的边）
    pos_labels = torch.ones(pos_edge_index.size(1))
    # 负样本边（图中不存在的边，通过负采样得到）
    neg_labels = torch.zeros(neg_edge_index.size(1))
    return torch.cat([pos_labels, neg_labels], dim=0)
'''评估函数'''
def evaluate(model, data, pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model(data.x, data.edge_index)
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)
        pred = decode(z, edge_index).sigmoid()
        labels = get_link_labels(pos_edge_index, neg_edge_index).to(pred.device)
        auc = roc_auc_score(labels.cpu(), pred.cpu())
        ap = average_precision_score(labels.cpu(), pred.cpu())
        return auc, ap

## 数据加载

In [19]:
class LocalOnlyPlanetoid(Planetoid):
    @property
    def raw_file_names(self):
        if self.name.lower() == 'cora':
            return ['cora.content', 'cora.cites']
        elif self.name.lower() == 'citeseer':
            return ['citeseer.content', 'citeseer.cites']
        else:
            raise ValueError(f"Unsupported Planetoid dataset: {self.name}")
    def download(self):
        print(f"[INFO] 使用本地文件加载数据集：{self.name}")
        print(f"检查目录：{self.raw_dir}")
        print(f"期望文件：{self.raw_file_names}")
        missing = [f for f in self.raw_paths if not os.path.exists(f)]
        if missing:
            raise RuntimeError(
                f"[ERROR] 缺少文件: {missing}\n"
                f"请手动下载并放置到目录：{self.raw_dir}\n"
                f"Cora/Citeseer 下载地址：https://linqs-data.soe.ucsc.edu/public/lbc/"
            )
        else:
            print("[OK] 所有原始文件已找到，开始加载。")

# ========== 通用加载函数 ==========
def load_data(name):
    name = name.lower()
    base_path = r'D:\Data\master\Graph Machine Learning\GNN\standard benchmark datasets'
    if name in ['cora', 'citeseer']:
        dataset = LocalOnlyPlanetoid(
            root=os.path.join(base_path),
            name=name,
            transform=NormalizeFeatures()
        )
        data = dataset[0]
        data.edge_index = ToUndirected()(data).edge_index
        data = train_test_split_edges(data)  
        return data, dataset.num_node_features
    elif name == 'flickr':
        dataset = Flickr(
            root=os.path.join(base_path, 'Flickr'),
            transform=NormalizeFeatures()
        )
        data = dataset[0]
        data.edge_index = ToUndirected()(data).edge_index
        # 对Flickr用采样器训练，不做 train_test_split_edges
        return data, dataset.num_node_features
    else:
        raise ValueError("Only 'cora', 'citeseer', and 'flickr' datasets are supported.")

## 训练与测试

In [20]:
def train(model, data, optimizer, device):
    model.train()
    optimizer.zero_grad()
    if hasattr(data, 'train_pos_edge_index'):
        # 链路预测任务
        z = model(data.x, data.train_pos_edge_index)
        neg_edge_index = negative_sampling(
            edge_index=data.train_pos_edge_index,
            num_nodes=data.num_nodes,
            num_neg_samples=data.train_pos_edge_index.size(1)
        )
        edge_index = torch.cat([data.train_pos_edge_index, neg_edge_index], dim=-1)
        labels = get_link_labels(data.train_pos_edge_index, neg_edge_index).to(device)
        pred = decode(z, edge_index).sigmoid()
        loss = F.binary_cross_entropy(pred, labels)
    else:
        # 非链路预测任务
        z = model(data.x, data.edge_index)
        loss = z.norm(p=2).mean()  
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
@torch.no_grad()
def test(model, data, device):
    model.eval()
    if not hasattr(data, 'val_pos_edge_index'):
        print("No link prediction evaluation for this dataset.")
        return None, None
    z = model(data.x, data.train_pos_edge_index)
    edge_index = torch.cat([data.val_pos_edge_index, data.val_neg_edge_index], dim=-1)
    labels = get_link_labels(data.val_pos_edge_index, data.val_neg_edge_index).to(device)
    pred = decode(z, edge_index).sigmoid()
    auc = roc_auc_score(labels.cpu(), pred.cpu())
    ap = average_precision_score(labels.cpu(), pred.cpu())
    return auc, ap

## 主程序运行

In [21]:
def run(dataset_name, model_name, use_sampler=False, lr=0.01, layers=2, hidden_dim=64, epochs=100):
    print(f"\n--- Running: {dataset_name} | {model_name} | Sampler={use_sampler} | LR={lr} | Layers={layers} ---")
    data, in_dim = load_data(dataset_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 模型选择
    if model_name == 'GCN': model = GCNEncoder(in_dim, hidden_dim, layers)
    elif model_name == 'GAT': model = GATEncoder(in_dim, hidden_dim, layers)
    elif model_name == 'GraphSAGE': model = SAGEEncoder(in_dim, hidden_dim, layers)
    elif model_name == 'GIN': model = GINEncoder(in_dim, hidden_dim, layers)
    else: raise ValueError("Unsupported model")
    model = model.to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    start = time.time()
    for epoch in range(epochs):
        loss = train(model, data, optimizer, device)
        if epoch % 10 == 0:
            auc, ap = test(model, data, device)
            if auc is not None:
                print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | AUC: {auc:.4f} | AP: {ap:.4f}")
            else:
                print(f"Epoch {epoch:03d} | Loss: {loss:.4f} (No link prediction)")
    print(f"Done. Total time: {time.time() - start:.2f}s")
    
# ======== 启动入口 ========
if __name__ == "__main__":
    datasets = ['Cora', 'Citeseer', 'Flickr']
    models = ['GCN', 'GAT', 'GraphSAGE', 'GIN']
    for d in datasets:
        for m in models:
            run(d, m, use_sampler=False, lr=0.005, layers=2)
            run(d, m, use_sampler=False, lr=0.01, layers=3)


--- Running: Cora | GCN | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.6930 | AUC: 0.6910 | AP: 0.7316
Epoch 010 | Loss: 0.6696 | AUC: 0.6906 | AP: 0.7305
Epoch 020 | Loss: 0.6376 | AUC: 0.7233 | AP: 0.7544
Epoch 030 | Loss: 0.5657 | AUC: 0.7806 | AP: 0.7699
Epoch 040 | Loss: 0.5192 | AUC: 0.8305 | AP: 0.8217
Epoch 050 | Loss: 0.4861 | AUC: 0.8684 | AP: 0.8561
Epoch 060 | Loss: 0.4779 | AUC: 0.8800 | AP: 0.8666
Epoch 070 | Loss: 0.4672 | AUC: 0.8900 | AP: 0.8820
Epoch 080 | Loss: 0.4619 | AUC: 0.8947 | AP: 0.8873
Epoch 090 | Loss: 0.4599 | AUC: 0.8980 | AP: 0.8910
Done. Total time: 3.54s

--- Running: Cora | GCN | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.6931 | AUC: 0.7285 | AP: 0.7439
Epoch 010 | Loss: 0.6886 | AUC: 0.7076 | AP: 0.7320
Epoch 020 | Loss: 0.6668 | AUC: 0.6879 | AP: 0.7318
Epoch 030 | Loss: 0.5942 | AUC: 0.7977 | AP: 0.8095
Epoch 040 | Loss: 0.5626 | AUC: 0.8056 | AP: 0.8221
Epoch 050 | Loss: 0.5505 | AUC: 0.8145 | AP: 0.8356
Epoch 060 | Loss: 0.5469 | AUC: 0.8164 | AP: 0.8404
Epoch 070 | Loss: 0.5371 | AUC: 0.8266 | AP: 0.8462
Epoch 080 | Loss: 0.5010 | AUC: 0.8623 | AP: 0.8620
Epoch 090 | Loss: 0.4963 | AUC: 0.8605 | AP: 0.8635
Done. Total time: 4.24s

--- Running: Cora | GAT | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.6929 | AUC: 0.6633 | AP: 0.6651
Epoch 010 | Loss: 0.5741 | AUC: 0.7745 | AP: 0.7094
Epoch 020 | Loss: 0.5279 | AUC: 0.8375 | AP: 0.8083
Epoch 030 | Loss: 0.4823 | AUC: 0.8774 | AP: 0.8599
Epoch 040 | Loss: 0.4721 | AUC: 0.8890 | AP: 0.8788
Epoch 050 | Loss: 0.4629 | AUC: 0.8938 | AP: 0.8819
Epoch 060 | Loss: 0.4489 | AUC: 0.8972 | AP: 0.8819
Epoch 070 | Loss: 0.4469 | AUC: 0.9030 | AP: 0.8980
Epoch 080 | Loss: 0.4431 | AUC: 0.8999 | AP: 0.9041
Epoch 090 | Loss: 0.4415 | AUC: 0.8998 | AP: 0.9036
Done. Total time: 9.44s

--- Running: Cora | GAT | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.6929 | AUC: 0.6144 | AP: 0.6607
Epoch 010 | Loss: 0.6899 | AUC: 0.7096 | AP: 0.7099
Epoch 020 | Loss: 0.5582 | AUC: 0.8108 | AP: 0.7729
Epoch 030 | Loss: 0.5119 | AUC: 0.8483 | AP: 0.8027
Epoch 040 | Loss: 0.4863 | AUC: 0.8541 | AP: 0.8077
Epoch 050 | Loss: 0.4779 | AUC: 0.8632 | AP: 0.8227
Epoch 060 | Loss: 0.4587 | AUC: 0.8749 | AP: 0.8563
Epoch 070 | Loss: 0.4534 | AUC: 0.8889 | AP: 0.8802
Epoch 080 | Loss: 0.4500 | AUC: 0.8892 | AP: 0.8889
Epoch 090 | Loss: 0.4488 | AUC: 0.8896 | AP: 0.8859
Done. Total time: 12.95s

--- Running: Cora | GraphSAGE | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.7136 | AUC: 0.6076 | AP: 0.6195
Epoch 010 | Loss: 0.6916 | AUC: 0.4979 | AP: 0.5343
Epoch 020 | Loss: 0.6631 | AUC: 0.6440 | AP: 0.6194
Epoch 030 | Loss: 0.5864 | AUC: 0.7294 | AP: 0.6841
Epoch 040 | Loss: 0.5521 | AUC: 0.7831 | AP: 0.7533
Epoch 050 | Loss: 0.5265 | AUC: 0.8062 | AP: 0.7778
Epoch 060 | Loss: 0.5173 | AUC: 0.8219 | AP: 0.8060
Epoch 070 | Loss: 0.4987 | AUC: 0.8290 | AP: 0.8234
Epoch 080 | Loss: 0.4905 | AUC: 0.8414 | AP: 0.8409
Epoch 090 | Loss: 0.4824 | AUC: 0.8533 | AP: 0.8515
Done. Total time: 6.21s

--- Running: Cora | GraphSAGE | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.7230 | AUC: 0.5694 | AP: 0.6072
Epoch 010 | Loss: 0.6480 | AUC: 0.7549 | AP: 0.7241
Epoch 020 | Loss: 0.5877 | AUC: 0.7447 | AP: 0.7281
Epoch 030 | Loss: 0.5673 | AUC: 0.7704 | AP: 0.7512
Epoch 040 | Loss: 0.5552 | AUC: 0.7794 | AP: 0.7593
Epoch 050 | Loss: 0.5495 | AUC: 0.7802 | AP: 0.7653
Epoch 060 | Loss: 0.5476 | AUC: 0.7805 | AP: 0.7620
Epoch 070 | Loss: 0.5430 | AUC: 0.7833 | AP: 0.7689
Epoch 080 | Loss: 0.5239 | AUC: 0.8046 | AP: 0.7992
Epoch 090 | Loss: 0.5088 | AUC: 0.8138 | AP: 0.8191
Done. Total time: 6.75s

--- Running: Cora | GIN | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.6910 | AUC: 0.5525 | AP: 0.6420
Epoch 010 | Loss: 0.6459 | AUC: 0.6249 | AP: 0.6770
Epoch 020 | Loss: 0.6235 | AUC: 0.5839 | AP: 0.6616
Epoch 030 | Loss: 0.5986 | AUC: 0.5831 | AP: 0.6557
Epoch 040 | Loss: 0.5615 | AUC: 0.6610 | AP: 0.7079
Epoch 050 | Loss: 0.5136 | AUC: 0.7592 | AP: 0.7809
Epoch 060 | Loss: 0.5020 | AUC: 0.8032 | AP: 0.8112
Epoch 070 | Loss: 0.4767 | AUC: 0.8255 | AP: 0.8387
Epoch 080 | Loss: 0.4599 | AUC: 0.8423 | AP: 0.8554
Epoch 090 | Loss: 0.4607 | AUC: 0.8514 | AP: 0.8641
Done. Total time: 5.69s

--- Running: Cora | GIN | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.8082 | AUC: 0.5504 | AP: 0.5762
Epoch 010 | Loss: 0.6625 | AUC: 0.6169 | AP: 0.6584
Epoch 020 | Loss: 0.6139 | AUC: 0.6195 | AP: 0.6689
Epoch 030 | Loss: 0.5787 | AUC: 0.6627 | AP: 0.6985
Epoch 040 | Loss: 0.5750 | AUC: 0.6764 | AP: 0.7086
Epoch 050 | Loss: 0.5379 | AUC: 0.6923 | AP: 0.7190
Epoch 060 | Loss: 0.5347 | AUC: 0.6929 | AP: 0.7155
Epoch 070 | Loss: 0.4941 | AUC: 0.7169 | AP: 0.7399
Epoch 080 | Loss: 0.4830 | AUC: 0.7165 | AP: 0.7430
Epoch 090 | Loss: 0.4888 | AUC: 0.7357 | AP: 0.7634
Done. Total time: 5.99s

--- Running: Citeseer | GCN | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.6931 | AUC: 0.6902 | AP: 0.7023
Epoch 010 | Loss: 0.6571 | AUC: 0.6409 | AP: 0.7050
Epoch 020 | Loss: 0.5639 | AUC: 0.7882 | AP: 0.7884
Epoch 030 | Loss: 0.5423 | AUC: 0.8004 | AP: 0.7940
Epoch 040 | Loss: 0.5258 | AUC: 0.8131 | AP: 0.8046
Epoch 050 | Loss: 0.5070 | AUC: 0.8371 | AP: 0.8343
Epoch 060 | Loss: 0.4880 | AUC: 0.8675 | AP: 0.8653
Epoch 070 | Loss: 0.4846 | AUC: 0.8606 | AP: 0.8617
Epoch 080 | Loss: 0.4794 | AUC: 0.8581 | AP: 0.8574
Epoch 090 | Loss: 0.4801 | AUC: 0.8549 | AP: 0.8524
Done. Total time: 6.09s

--- Running: Citeseer | GCN | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.6931 | AUC: 0.6439 | AP: 0.6706
Epoch 010 | Loss: 0.6452 | AUC: 0.7394 | AP: 0.7632
Epoch 020 | Loss: 0.5413 | AUC: 0.7780 | AP: 0.7941
Epoch 030 | Loss: 0.5016 | AUC: 0.8546 | AP: 0.8587
Epoch 040 | Loss: 0.5029 | AUC: 0.8497 | AP: 0.8525
Epoch 050 | Loss: 0.4773 | AUC: 0.8618 | AP: 0.8659
Epoch 060 | Loss: 0.4835 | AUC: 0.8583 | AP: 0.8651
Epoch 070 | Loss: 0.4678 | AUC: 0.8574 | AP: 0.8622
Epoch 080 | Loss: 0.4668 | AUC: 0.8544 | AP: 0.8575
Epoch 090 | Loss: 0.4600 | AUC: 0.8622 | AP: 0.8648
Done. Total time: 6.37s

--- Running: Citeseer | GAT | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.6931 | AUC: 0.6685 | AP: 0.7156
Epoch 010 | Loss: 0.5832 | AUC: 0.7987 | AP: 0.7989
Epoch 020 | Loss: 0.5172 | AUC: 0.8602 | AP: 0.8578
Epoch 030 | Loss: 0.4892 | AUC: 0.8748 | AP: 0.8775
Epoch 040 | Loss: 0.4672 | AUC: 0.9060 | AP: 0.9113
Epoch 050 | Loss: 0.4567 | AUC: 0.9110 | AP: 0.9125
Epoch 060 | Loss: 0.4526 | AUC: 0.9048 | AP: 0.9099
Epoch 070 | Loss: 0.4430 | AUC: 0.9083 | AP: 0.9154
Epoch 080 | Loss: 0.4423 | AUC: 0.9033 | AP: 0.9135
Epoch 090 | Loss: 0.4440 | AUC: 0.8975 | AP: 0.9083
Done. Total time: 17.39s

--- Running: Citeseer | GAT | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.6931 | AUC: 0.5892 | AP: 0.6685
Epoch 010 | Loss: 0.8700 | AUC: 0.6611 | AP: 0.7014
Epoch 020 | Loss: 0.5886 | AUC: 0.8148 | AP: 0.8172
Epoch 030 | Loss: 0.5088 | AUC: 0.8624 | AP: 0.8594
Epoch 040 | Loss: 0.4697 | AUC: 0.8999 | AP: 0.8972
Epoch 050 | Loss: 0.4512 | AUC: 0.9016 | AP: 0.8968
Epoch 060 | Loss: 0.4499 | AUC: 0.8996 | AP: 0.8912
Epoch 070 | Loss: 0.4421 | AUC: 0.8998 | AP: 0.9025
Epoch 080 | Loss: 0.4483 | AUC: 0.8918 | AP: 0.8982
Epoch 090 | Loss: 0.4411 | AUC: 0.8916 | AP: 0.8970
Done. Total time: 21.35s

--- Running: Citeseer | GraphSAGE | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.7092 | AUC: 0.6720 | AP: 0.6665
Epoch 010 | Loss: 0.6772 | AUC: 0.6144 | AP: 0.6135
Epoch 020 | Loss: 0.5848 | AUC: 0.7627 | AP: 0.7277
Epoch 030 | Loss: 0.5587 | AUC: 0.7793 | AP: 0.7434
Epoch 040 | Loss: 0.5400 | AUC: 0.7847 | AP: 0.7865
Epoch 050 | Loss: 0.5420 | AUC: 0.7907 | AP: 0.8008
Epoch 060 | Loss: 0.5313 | AUC: 0.7966 | AP: 0.8072
Epoch 070 | Loss: 0.5235 | AUC: 0.8057 | AP: 0.8174
Epoch 080 | Loss: 0.5200 | AUC: 0.8151 | AP: 0.8296
Epoch 090 | Loss: 0.4984 | AUC: 0.8187 | AP: 0.8367
Done. Total time: 13.49s

--- Running: Citeseer | GraphSAGE | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.7184 | AUC: 0.5813 | AP: 0.6281
Epoch 010 | Loss: 0.5990 | AUC: 0.6826 | AP: 0.6951
Epoch 020 | Loss: 0.5644 | AUC: 0.7197 | AP: 0.7325
Epoch 030 | Loss: 0.5510 | AUC: 0.7425 | AP: 0.7524
Epoch 040 | Loss: 0.5363 | AUC: 0.7564 | AP: 0.7604
Epoch 050 | Loss: 0.5267 | AUC: 0.7739 | AP: 0.7788
Epoch 060 | Loss: 0.5085 | AUC: 0.7945 | AP: 0.7963
Epoch 070 | Loss: 0.5036 | AUC: 0.7962 | AP: 0.7997
Epoch 080 | Loss: 0.4981 | AUC: 0.8014 | AP: 0.8125
Epoch 090 | Loss: 0.4888 | AUC: 0.8097 | AP: 0.8221
Done. Total time: 14.04s

--- Running: Citeseer | GIN | Sampler=False | LR=0.005 | Layers=2 ---




Epoch 000 | Loss: 0.7129 | AUC: 0.5951 | AP: 0.6427
Epoch 010 | Loss: 0.6303 | AUC: 0.6001 | AP: 0.6407
Epoch 020 | Loss: 0.5813 | AUC: 0.6382 | AP: 0.6610
Epoch 030 | Loss: 0.5541 | AUC: 0.6796 | AP: 0.6917
Epoch 040 | Loss: 0.5062 | AUC: 0.7507 | AP: 0.7591
Epoch 050 | Loss: 0.4908 | AUC: 0.7838 | AP: 0.7864
Epoch 060 | Loss: 0.4830 | AUC: 0.8006 | AP: 0.8037
Epoch 070 | Loss: 0.4663 | AUC: 0.8173 | AP: 0.8166
Epoch 080 | Loss: 0.4652 | AUC: 0.8210 | AP: 0.8273
Epoch 090 | Loss: 0.4574 | AUC: 0.8291 | AP: 0.8330
Done. Total time: 11.86s

--- Running: Citeseer | GIN | Sampler=False | LR=0.01 | Layers=3 ---




Epoch 000 | Loss: 0.6590 | AUC: 0.6201 | AP: 0.6795
Epoch 010 | Loss: 0.6248 | AUC: 0.6206 | AP: 0.6895
Epoch 020 | Loss: 0.6300 | AUC: 0.6520 | AP: 0.7079
Epoch 030 | Loss: 0.5659 | AUC: 0.7131 | AP: 0.7529
Epoch 040 | Loss: 0.5223 | AUC: 0.7157 | AP: 0.7501
Epoch 050 | Loss: 0.4945 | AUC: 0.7219 | AP: 0.7474
Epoch 060 | Loss: 0.4681 | AUC: 0.7537 | AP: 0.7700
Epoch 070 | Loss: 0.4623 | AUC: 0.7591 | AP: 0.7760
Epoch 080 | Loss: 0.4699 | AUC: 0.7583 | AP: 0.7736
Epoch 090 | Loss: 0.4601 | AUC: 0.7687 | AP: 0.7824
Done. Total time: 12.26s

--- Running: Flickr | GCN | Sampler=False | LR=0.005 | Layers=2 ---
No link prediction evaluation for this dataset.
Epoch 000 | Loss: 3.6633 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 010 | Loss: 3.6620 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 020 | Loss: 1.9483 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 030 | Loss: 0.9863 (No link prediction)
No link pr

No link prediction evaluation for this dataset.
Epoch 040 | Loss: 10.8807 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 050 | Loss: 3.4849 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 060 | Loss: 2.6293 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 070 | Loss: 1.7085 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 080 | Loss: 1.2998 (No link prediction)
No link prediction evaluation for this dataset.
Epoch 090 | Loss: 0.8838 (No link prediction)
Done. Total time: 161.58s


# 图分类
## 数据集

In [38]:
# 选择数据集，支持 TUDataset 或 ZINC 数据集
def load_data(name):
    if name == 'TUDataset':
        dataset = TUDataset(root='D:\Data\master\Graph Machine Learning\GNN\standard benchmark datasets\TUDataset', name='PROTEINS')
    elif name == 'ZINC':
        dataset = ZINC(root='D:\Data\master\Graph Machine Learning\GNN\standard benchmark datasets\ZINC')
    else:
        raise ValueError("Only 'TUDataset' and 'ZINC' are supported.")
    return dataset

## 定义模型和方法

In [44]:
'''池化方法'''
class GraphPooling(nn.Module):
    def __init__(self, pooling_type='mean'):
        super(GraphPooling, self).__init__()
        if pooling_type == 'mean':
            self.pool = global_mean_pool
        elif pooling_type == 'max':
            self.pool = global_max_pool
        elif pooling_type == 'min':
            self.pool = global_add_pool  # Replace with MinPooling logic if needed
        else:
            raise ValueError("Invalid pooling type. Choose from ['mean', 'max', 'min'].")
    def forward(self, x, batch):
        return self.pool(x, batch)
    
'''GCN模型'''
class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, pooling, layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)
        self.pool = pooling
    def forward(self, x, edge_index, batch):
        x = x.float()  
        edge_weight = x.new_ones(edge_index.size(1))
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight=edge_weight))
        x = self.pool(x, batch)
        return self.lin(x)
    
'''GAT模型'''
class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, pooling, layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels))
        for _ in range(layers - 2):
            self.convs.append(GATConv(hidden_channels, hidden_channels))
        self.convs.append(GATConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)
        self.pool = pooling
    def forward(self, x, edge_index, batch):
        x = x.float()
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.pool(x, batch)
        return self.lin(x)
    
'''GraphSAGE模型'''
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, pooling, layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)
        self.pool = pooling
    def forward(self, x, edge_index, batch):
        x = x.float()
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.pool(x, batch)
        return self.lin(x)

'''GIN模型'''
class GIN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, pooling, layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        for i in range(layers):
            mlp = Sequential(Linear(in_channels if i == 0 else hidden_channels, hidden_channels),
                             ReLU(),
                             Linear(hidden_channels, hidden_channels))
            self.convs.append(GINConv(mlp))
        self.lin = Linear(hidden_channels, out_channels)
        self.pool = pooling

    def forward(self, x, edge_index, batch):
        x = x.float()
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.pool(x, batch)
        return self.lin(x)

## 训练与检验

In [40]:
'''定义训练过程'''
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = out.argmax(dim=1) if out.shape[1] > 1 else out.view(-1)
        if out.shape[1] > 1:
            correct += pred.eq(data.y).sum().item()
    acc = correct / len(loader.dataset) if out.shape[1] > 1 else 0.0
    return total_loss / len(loader), acc

'''定义测试过程'''
def test(model, loader, device):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1) if out.shape[1] > 1 else out.view(-1)
        if out.shape[1] > 1:
            correct += pred.eq(data.y).sum().item()
    acc = correct / len(loader.dataset) if out.shape[1] > 1 else 0.0
    return acc

'''设置训练和测试过程'''
def run_all():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    datasets = ['TUDataset', 'ZINC']
    models = ['GCN', 'GAT', 'GraphSAGE', 'GIN']
    poolings = ['mean', 'max', 'min']
    lrs = [0.001, 0.005]
    layers_list = [2, 3]

    for dataset_name in datasets:
        dataset = load_data(dataset_name)
        dataset = dataset.shuffle()
        in_dim = dataset.num_node_features
        out_dim = dataset.num_classes if dataset_name == 'TUDataset' else 1

        for pooling_type in poolings:
            pool = GraphPooling(pooling_type)
            for model_name in models:
                for lr in lrs:
                    for layers in layers_list:
                        train_loader = DataLoader(dataset[:int(0.8*len(dataset))], batch_size=32, shuffle=True)
                        test_loader = DataLoader(dataset[int(0.8*len(dataset)):], batch_size=32)

                        if model_name == 'GCN':
                            model = GCN(in_dim, 64, out_dim, pool, layers).to(device)
                        elif model_name == 'GAT':
                            model = GAT(in_dim, 64, out_dim, pool, layers).to(device)
                        elif model_name == 'GraphSAGE':
                            model = GraphSAGE(in_dim, 64, out_dim, pool, layers).to(device)
                        elif model_name == 'GIN':
                            model = GIN(in_dim, 64, out_dim, pool, layers).to(device)

                        optimizer = optim.Adam(model.parameters(), lr=lr)
                        criterion = nn.CrossEntropyLoss() if dataset_name == 'TUDataset' else nn.L1Loss()

                        start = time.time()
                        for epoch in range(1, 51):
                            train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
                            test_acc = test(model, test_loader, device)
                            if epoch % 10 == 0:
                                print(f"[{dataset_name}] {model_name} | Pool={pooling_type} | LR={lr} | Layers={layers} | Epoch={epoch} | Loss={train_loss:.4f} | TrainAcc={train_acc:.4f} | TestAcc={test_acc:.4f}")
                        print(f"Total Time: {time.time() - start:.2f}s\n")

## 主程序

In [45]:
if __name__ == '__main__':
    run_all()

[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=2 | Epoch=10 | Loss=0.6388 | TrainAcc=0.6528 | TestAcc=0.7534
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=2 | Epoch=20 | Loss=0.6229 | TrainAcc=0.7000 | TestAcc=0.7444
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=2 | Epoch=30 | Loss=0.6135 | TrainAcc=0.7034 | TestAcc=0.7444
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=2 | Epoch=40 | Loss=0.6086 | TrainAcc=0.7101 | TestAcc=0.7354
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=2 | Epoch=50 | Loss=0.6090 | TrainAcc=0.7067 | TestAcc=0.7534
Total Time: 13.86s

[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=3 | Epoch=10 | Loss=0.6272 | TrainAcc=0.6775 | TestAcc=0.7309
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=3 | Epoch=20 | Loss=0.6144 | TrainAcc=0.7011 | TestAcc=0.7220
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=3 | Epoch=30 | Loss=0.6066 | TrainAcc=0.7000 | TestAcc=0.7578
[TUDataset] GCN | Pool=mean | LR=0.001 | Layers=3 | Epoch=40 | Loss=0.6065 | TrainAcc=0.7011 | TestA

[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=2 | Epoch=20 | Loss=0.6061 | TrainAcc=0.7056 | TestAcc=0.7354
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=2 | Epoch=30 | Loss=0.6035 | TrainAcc=0.7101 | TestAcc=0.7399
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=2 | Epoch=40 | Loss=0.6090 | TrainAcc=0.6944 | TestAcc=0.7399
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=2 | Epoch=50 | Loss=0.5914 | TrainAcc=0.7236 | TestAcc=0.7444
Total Time: 13.52s

[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=3 | Epoch=10 | Loss=0.6267 | TrainAcc=0.6697 | TestAcc=0.7399
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=3 | Epoch=20 | Loss=0.6048 | TrainAcc=0.7022 | TestAcc=0.7399
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=3 | Epoch=30 | Loss=0.6000 | TrainAcc=0.6989 | TestAcc=0.7399
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=3 | Epoch=40 | Loss=0.5948 | TrainAcc=0.7101 | TestAcc=0.7175
[TUDataset] GIN | Pool=mean | LR=0.005 | Layers=3 | Epoch=50 | Loss=0.6014 | TrainAcc=0.7090 | TestA

[TUDataset] GIN | Pool=max | LR=0.001 | Layers=2 | Epoch=40 | Loss=0.5027 | TrainAcc=0.7674 | TestAcc=0.7713
[TUDataset] GIN | Pool=max | LR=0.001 | Layers=2 | Epoch=50 | Loss=0.4965 | TrainAcc=0.7708 | TestAcc=0.7848
Total Time: 14.82s

[TUDataset] GIN | Pool=max | LR=0.001 | Layers=3 | Epoch=10 | Loss=0.5413 | TrainAcc=0.7472 | TestAcc=0.7803
[TUDataset] GIN | Pool=max | LR=0.001 | Layers=3 | Epoch=20 | Loss=0.5161 | TrainAcc=0.7697 | TestAcc=0.7758
[TUDataset] GIN | Pool=max | LR=0.001 | Layers=3 | Epoch=30 | Loss=0.4967 | TrainAcc=0.7640 | TestAcc=0.7848
[TUDataset] GIN | Pool=max | LR=0.001 | Layers=3 | Epoch=40 | Loss=0.4813 | TrainAcc=0.7854 | TestAcc=0.7982
[TUDataset] GIN | Pool=max | LR=0.001 | Layers=3 | Epoch=50 | Loss=0.4665 | TrainAcc=0.7820 | TestAcc=0.7758
Total Time: 19.18s

[TUDataset] GIN | Pool=max | LR=0.005 | Layers=2 | Epoch=10 | Loss=0.5706 | TrainAcc=0.7101 | TestAcc=0.7309
[TUDataset] GIN | Pool=max | LR=0.005 | Layers=2 | Epoch=20 | Loss=0.5354 | TrainAcc=0.7

[TUDataset] GraphSAGE | Pool=min | LR=0.005 | Layers=3 | Epoch=10 | Loss=0.5491 | TrainAcc=0.7483 | TestAcc=0.7085
[TUDataset] GraphSAGE | Pool=min | LR=0.005 | Layers=3 | Epoch=20 | Loss=0.5402 | TrainAcc=0.7326 | TestAcc=0.7489
[TUDataset] GraphSAGE | Pool=min | LR=0.005 | Layers=3 | Epoch=30 | Loss=0.5382 | TrainAcc=0.7348 | TestAcc=0.7175
[TUDataset] GraphSAGE | Pool=min | LR=0.005 | Layers=3 | Epoch=40 | Loss=0.5204 | TrainAcc=0.7461 | TestAcc=0.7309
[TUDataset] GraphSAGE | Pool=min | LR=0.005 | Layers=3 | Epoch=50 | Loss=0.5268 | TrainAcc=0.7438 | TestAcc=0.7444
Total Time: 15.68s

[TUDataset] GIN | Pool=min | LR=0.001 | Layers=2 | Epoch=10 | Loss=0.5649 | TrainAcc=0.6888 | TestAcc=0.7444
[TUDataset] GIN | Pool=min | LR=0.001 | Layers=2 | Epoch=20 | Loss=0.5408 | TrainAcc=0.7337 | TestAcc=0.7354
[TUDataset] GIN | Pool=min | LR=0.001 | Layers=2 | Epoch=30 | Loss=0.5406 | TrainAcc=0.7371 | TestAcc=0.7534
[TUDataset] GIN | Pool=min | LR=0.001 | Layers=2 | Epoch=40 | Loss=0.5264 | Tr

[ZINC] GraphSAGE | Pool=mean | LR=0.001 | Layers=3 | Epoch=50 | Loss=1.4948 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 2847.07s

[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=2 | Epoch=10 | Loss=1.4954 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=2 | Epoch=20 | Loss=1.4954 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=2 | Epoch=30 | Loss=1.4953 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=2 | Epoch=40 | Loss=1.4954 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=2 | Epoch=50 | Loss=1.4954 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 2207.68s

[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=3 | Epoch=10 | Loss=1.4955 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=3 | Epoch=20 | Loss=1.4953 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=mean | LR=0.005 | Layers=3 | Epoch=30 | Loss

[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=2 | Epoch=50 | Loss=1.4949 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 2108.99s

[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=3 | Epoch=10 | Loss=1.4953 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=3 | Epoch=20 | Loss=1.4952 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=3 | Epoch=30 | Loss=1.4951 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=3 | Epoch=40 | Loss=1.4951 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.001 | Layers=3 | Epoch=50 | Loss=1.4949 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 2603.77s

[ZINC] GraphSAGE | Pool=max | LR=0.005 | Layers=2 | Epoch=10 | Loss=1.4958 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.005 | Layers=2 | Epoch=20 | Loss=1.4954 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=max | LR=0.005 | Layers=2 | Epoch=30 | Loss=1.4954 |

[ZINC] GAT | Pool=min | LR=0.005 | Layers=3 | Epoch=50 | Loss=1.4953 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 4070.02s

[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=2 | Epoch=10 | Loss=1.4952 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=2 | Epoch=20 | Loss=1.4951 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=2 | Epoch=30 | Loss=1.4952 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=2 | Epoch=40 | Loss=1.4953 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=2 | Epoch=50 | Loss=1.4952 | TrainAcc=0.0000 | TestAcc=0.0000
Total Time: 1985.13s

[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=3 | Epoch=10 | Loss=1.4951 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=3 | Epoch=20 | Loss=1.4950 | TrainAcc=0.0000 | TestAcc=0.0000
[ZINC] GraphSAGE | Pool=min | LR=0.001 | Layers=3 | Epoch=30 | Loss=1.4950 | Train

KeyboardInterrupt: 