In [15]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid

In [16]:
# 1.加载Cora数据集
dataset = Planetoid(root='./data/Cora', name='Cora')

In [17]:
# 2.定义GCNConv网络
class GCN(nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = pyg_nn.GCNConv(num_node_features, 16)
        self.conv2 = pyg_nn.GCNConv(16, 10)
        self.conv3 = pyg_nn.GCNConv(10, num_classes)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        x = self.conv1(x, edge_index)
        x = F.leaky_relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        
        return F.softmax(x, dim=1)

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 100 # 学习轮数
lr = 0.003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图

In [19]:
# 3.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数

In [31]:
help(torch.eq)

Help on built-in function eq in module torch:

eq(...)
    eq(input, other, *, out=None) -> Tensor
    
    Computes element-wise equality
    
    The second argument can be a number or a tensor whose shape is
    :ref:`broadcastable <broadcasting-semantics>` with the first argument.
    
    Args:
        input (Tensor): the tensor to compare
        other (Tensor or float): the tensor or value to compare
    
    Keyword args:
        out (Tensor, optional): the output tensor.
    
    Returns:
        A boolean tensor that is True where :attr:`input` is equal to :attr:`other` and False elsewhere
    
    Example::
    
        >>> torch.eq(torch.tensor([[1, 2], [3, 4]]), torch.tensor([[1, 1], [4, 4]]))
        tensor([[ True, False],
                [False, True]])



In [29]:
a = torch.randn(4, 4)
print(a.argmax(axis=0))

<built-in method argmax of Tensor object at 0x000001B567F23C20>


In [21]:
# 训练模式
model.train()

for epoch in range(epochs):
    optimizer.zero_grad()
    pred = model(data)
    
    loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失
    correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目
    acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度
    
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print("【EPOCH: 】%s" % str(epoch + 1))
        print('训练损失为：{:.4f}'.format(loss.item()), '训练精度为：{:.4f}'.format(acc_train))

print('【Finished Training！】')

【EPOCH: 】10
训练损失为：-0.1799 训练精度为：0.5214
【EPOCH: 】20
训练损失为：-0.2918 训练精度为：0.6643
【EPOCH: 】30
训练损失为：-0.4045 训练精度为：0.6857
【EPOCH: 】40
训练损失为：-0.5380 训练精度为：0.7571
【EPOCH: 】50
训练损失为：-0.6459 训练精度为：0.8214
【EPOCH: 】60
训练损失为：-0.7067 训练精度为：0.8286
【EPOCH: 】70
训练损失为：-0.7622 训练精度为：0.8357
【EPOCH: 】80
训练损失为：-0.7893 训练精度为：0.8857
【EPOCH: 】90
训练损失为：-0.8537 训练精度为：0.9071
【EPOCH: 】100
训练损失为：-0.8472 训练精度为：0.8857
【Finished Training！】


In [22]:
# 模型验证
model.eval()
pred = model(data)

In [23]:
# 训练集（使用了掩码）
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()

In [24]:
# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()

In [25]:
print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test  Accuracy: {:.4f}'.format(acc_test), 'Test  Loss: {:.4f}'.format(loss_test))

Train Accuracy: 0.9786 Train Loss: -0.9654
Test  Accuracy: 0.7700 Test  Loss: -0.7208
