# 02. Cluster-GCN应用（PubMed-半监督）

## 1. 数据集

In [None]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures


# Load the PubMed Planetoid dataset rooted at ./dataset, applying the
# NormalizeFeatures transform to every graph that is accessed.
feature_transform = NormalizeFeatures()
dataset = Planetoid(root='./dataset', name='PubMed', transform=feature_transform)

# Report the dataset object, the number of graphs it holds, and its
# feature / class counts.
graph_total = len(dataset)
feature_total = dataset.num_features
class_total = dataset.num_classes

print(f'数据集：{dataset}')
print(f'图数量：{graph_total}')
print(f'节点特征数量：{feature_total}')
print(f'节点标签数量：{class_total}')


In [None]:
# Take the first graph object out of the dataset.
data = dataset[0]
print(f'图对象：{data}')

# Collect the raw counts once, then print the statistics of this graph.
node_total = data.num_nodes
edge_total = data.num_edges
train_total = int(data.train_mask.sum())

print(f'节点数量：{node_total}')
print(f'边数量：{edge_total}')
# Ratio of edge count to node count, as reported by the Data object.
print(f'节点平均度数：{edge_total / node_total:.2f}')
# Fraction of nodes selected by the training mask.
print(f'训练节点率：{train_total / node_total:.3f}')
print(f'是否存在孤立结点：{data.has_isolated_nodes()}')
print(f'是否自环图：{data.has_self_loops()}')
print(f'是否无向图：{data.is_undirected()}')


In [None]:
# Print how many nodes each split mask selects. The mask sums are 0-dim
# tensors, so they are converted with int() to print a plain number
# (e.g. "60") rather than the tensor repr ("tensor(60)"), matching the
# int(...) conversion already used for the training-node ratio.
print(f'训练集节点数：{int(data.train_mask.sum())}')
print(f'验证集节点数：{int(data.val_mask.sum())}')
print(f'测试集节点数：{int(data.test_mask.sum())}')


In [None]:
import matplotlib.pyplot as plt

# Count how many nodes carry each class label.
# The labels are moved to the CPU first: later cells move `data` to the
# training device, and converting a CUDA tensor for matplotlib would raise
# if this cell were re-run afterwards.
unique_labels, counts = torch.unique(data.y.cpu(), return_counts=True)
class_ids = unique_labels.tolist()
class_sizes = counts.tolist()

# Draw one bar per class.
plt.bar(class_ids, class_sizes)
plt.xlabel('labels')
plt.ylabel('counts')
plt.title('labels distribution')

# Print each class size and annotate its bar with the same number.
for class_id, class_size in zip(class_ids, class_sizes):
    print(f'类别 {class_id}: {class_size} 个样本')
    plt.text(class_id, class_size, str(class_size), ha='center', va='bottom')

plt.show()


In [5]:
# networkx 可视化
# import networkx as nx
# from torch_geometric.utils import to_networkx

# # 创建无向图
# G = to_networkx(data, to_undirected=True)

# pos = nx.spring_layout(G, seed=42)

# # 创建图对象
# fig = plt.figure(figsize=(10, 8))

# # 可视化
# nx.draw_networkx_nodes(G, pos, node_size=30, node_color=data.y, cmap='Set2')
# plt.title('PubMed Dataset Visualization')
# plt.show()


## 2. Full-batch 加载训练

### 2.1 采用GCN算子构造模型

In [None]:
# 采用GCN算子设计模型
import torch.nn as nn
from torch_geometric.nn import GCNConv
import torch.nn.functional as F


class GCN(nn.Module):
    """Three-layer node classifier built from ``GCNConv`` layers.

    Each of the first two convolutions is followed by ReLU and dropout; the
    third convolution produces the per-node output scores.

    Args:
        out_channels: Output width of the last layer (number of classes).
        in_channels: Input feature width. When ``None`` (the default) it is
            read from the notebook-level ``dataset.num_features``, which is
            what the layer sizes were previously hard-wired to.
        hidden_channels: Pair ``(first, second)`` giving the output widths of
            the first and second convolutions.
        dropout: Dropout probability used after each hidden activation.
    """

    def __init__(self, out_channels=3, in_channels=None,
                 hidden_channels=(128, 32), dropout=0.5):
        super().__init__()
        if in_channels is None:
            # Backward-compatible fallback to the global dataset.
            in_channels = dataset.num_features
        first_hidden, second_hidden = hidden_channels

        # Attribute names are kept (conv1/conv2/conv3/dp) so printed
        # summaries and saved parameter names stay the same.
        self.conv1 = GCNConv(in_channels, first_hidden)
        self.conv2 = GCNConv(first_hidden, second_hidden)
        self.conv3 = GCNConv(second_hidden, out_channels)

        self.dp = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the last layer's per-node scores (no softmax is applied)."""
        x = self.dp(F.relu(self.conv1(x, edge_index)))
        x = self.dp(F.relu(self.conv2(x, edge_index)))
        return self.conv3(x, edge_index)


gcn_model = GCN(out_channels=dataset.num_classes, in_channels=dataset.num_features)
print(gcn_model)

In [None]:
from torchinfo import summary

# Use the GPU only when one is present; a hard-coded 'cuda' fails on
# CPU-only machines, while the training cell already falls back to the CPU.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(gcn_model, input_data=(data.x, data.edge_index), device=summary_device)

In [None]:
# 训练和评估
import os
import time
import torch.optim as optim
import torch.nn.functional as F

# Prefer the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Put both the graph and the model on the chosen device.
data = data.to(device)
gcn_model = gcn_model.to(device)

# Cross-entropy loss, optimised with Adam (lr 0.01, weight decay 1e-4).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gcn_model.parameters(), lr=0.01, weight_decay=1e-4)

def train(model):
    """Run one full-batch optimisation step on ``model`` and return the loss.

    Relies on the notebook-level ``optimizer``, ``criterion`` and ``data``.
    The forward pass covers the whole graph, but the loss is computed only
    on the nodes selected by ``data.train_mask``.
    """
    model.train()
    optimizer.zero_grad()

    train_nodes = data.train_mask
    logits = model(data.x, data.edge_index)
    loss = criterion(logits[train_nodes], data.y[train_nodes])

    loss.backward()
    optimizer.step()
    return loss

def test(model):
    """Evaluate ``model`` on the full graph and return split accuracies.

    Uses the notebook-level ``data`` object.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` where each entry is the
        fraction of nodes in the corresponding mask whose arg-max prediction
        equals the label.
    """
    model.eval()
    # Evaluation needs no gradients; without no_grad() every call would build
    # an autograd graph for the whole forward pass.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    preds = out.argmax(dim=1)

    accs = []
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        correct = preds[mask] == data.y[mask]
        accs.append(int(correct.sum()) / int(mask.sum()))
    return accs

# Make sure the checkpoint directory exists before training starts, so the
# save at the end cannot fail on a missing folder after all epochs have run.
os.makedirs('weights', exist_ok=True)

start_time = time.time()
for epoch in range(1000):
    loss = train(gcn_model)
    train_acc, val_acc, test_acc = test(gcn_model)

    # Log every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1}/1000, Loss: {loss:.4f} Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
# Stop the clock before writing the checkpoint so the reported time covers
# training/evaluation only, not disk I/O.
end_time = time.time()

# NOTE(review): this saves the whole module object rather than its
# state_dict; kept as-is because the code that loads the file is not visible.
torch.save(gcn_model, os.path.join('weights', 'gcn_pubmed_model.pth'))

print(f'训练时间：{end_time - start_time:.2f}秒')

### 2.2 采用Cluster-GCN算子构造模型

In [None]:
from torch_geometric.nn import ClusterGCNConv

class ClusterGCN(nn.Module):
    """Three-layer node classifier built from ``ClusterGCNConv`` layers.

    Each of the first two convolutions is followed by ReLU and dropout; the
    third convolution produces the per-node output scores.

    Args:
        out_channels: Output width of the last layer (number of classes).
        in_channels: Input feature width. When ``None`` (the default) it is
            read from the notebook-level ``dataset.num_features``, which is
            what the layer sizes were previously hard-wired to.
        hidden_channels: Pair ``(first, second)`` giving the output widths of
            the first and second convolutions.
        dropout: Dropout probability used after each hidden activation.
    """

    def __init__(self, out_channels=3, in_channels=None,
                 hidden_channels=(128, 32), dropout=0.5):
        super().__init__()
        if in_channels is None:
            # Backward-compatible fallback to the global dataset.
            in_channels = dataset.num_features
        first_hidden, second_hidden = hidden_channels

        # Attribute names are kept (conv1/conv2/conv3/dp) so printed
        # summaries and saved parameter names stay the same.
        self.conv1 = ClusterGCNConv(in_channels, first_hidden)
        self.conv2 = ClusterGCNConv(first_hidden, second_hidden)
        self.conv3 = ClusterGCNConv(second_hidden, out_channels)

        self.dp = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the last layer's per-node scores (no softmax is applied)."""
        x = self.dp(F.relu(self.conv1(x, edge_index)))
        x = self.dp(F.relu(self.conv2(x, edge_index)))
        return self.conv3(x, edge_index)
    
# Build the Cluster-GCN variant with one output per class and show its layers.
cluster_gcn_model = ClusterGCN(out_channels=dataset.num_classes)
print(cluster_gcn_model)

# Use the GPU only when one is present; a hard-coded 'cuda' fails on
# CPU-only machines, while the training cells already fall back to the CPU.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
summary(cluster_gcn_model, input_data=(data.x, data.edge_index), device=summary_device)

In [None]:
import os
# Training and evaluation of the Cluster-GCN model (full-batch).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cluster_gcn_model = cluster_gcn_model.to(device)
# Rebinding the notebook-level optimizer/criterion is what makes the shared
# train() helper update this model instead of the earlier one.
optimizer = optim.Adam(cluster_gcn_model.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
data = data.to(device)

# Make sure the checkpoint directory exists before training starts, so the
# save at the end cannot fail on a missing folder after all epochs have run.
os.makedirs('weights', exist_ok=True)

start_time = time.time()
for epoch in range(1000):
    loss = train(cluster_gcn_model)
    train_acc, val_acc, test_acc = test(cluster_gcn_model)

    # Log every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1}/1000, Loss: {loss:.4f} Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
# Stop the clock before writing the checkpoint so the reported time covers
# training/evaluation only, not disk I/O.
end_time = time.time()

# NOTE(review): this saves the whole module object rather than its
# state_dict; kept as-is because the code that loads the file is not visible.
torch.save(cluster_gcn_model, os.path.join('weights', 'cluster_gcn_pubmed_model.pth'))

print(f'训练时间：{end_time - start_time:.2f}秒')


## 3. ClusterLoader 加载训练

### 3.1 数据加载

In [None]:
from torch_geometric.loader import ClusterLoader, ClusterData

# Fix the RNG seed before building the shuffled loader.
torch.manual_seed(42)

# Bring the graph back to the CPU before partitioning it.
data = data.cpu()

# Split the graph into 128 parts.
cluster_data = ClusterData(data, num_parts=128)

# Loader over the partitioned graph with batch_size=32, reshuffled per pass.
train_loader = ClusterLoader(cluster_data, batch_size=32, shuffle=True)

# Walk through one pass of the loader, describing every batch and adding up
# the node counts to compare against the size of the full graph.
total_num_nodes = 0
divider = '=' * 20

for step, sub_data in enumerate(train_loader):
    batch_nodes = sub_data.num_nodes
    print(f'Step: {step + 1}')
    print(divider)
    print(f'Number of nodes in the current batch: {batch_nodes}')
    print(sub_data)
    print()
    total_num_nodes += batch_nodes

print(f'Iterated over {total_num_nodes}  of {data.num_nodes} nodes！')
