In [15]:
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric.datasets import KarateClub

dataset = KarateClub()

# The dataset is indexable; element 0 is the graph object used by the cells below.
data = dataset[0]

# Show the node-feature matrix shape, then the full Data summary.
for item in (data.x.shape, data):
    print(item)

torch.Size([34, 34])
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])


This code defines a Graph Convolutional Network (GCN) using PyTorch Geometric. It consists of a graph convolution layer (GCNConv), which uses the graph structure to map each node's input features to a 3-dimensional embedding. A ReLU follows, and then a fully connected (Linear) layer that produces a score for each class.

In [16]:
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    """One GCNConv layer followed by a linear classification head.

    ``forward(x, edge_index)`` returns ``(h, z)``:
    ``h`` holds the ReLU-activated node embeddings from the graph convolution.
    ``z`` holds the per-node class scores (unnormalized logits) from the
    linear layer.

    Args:
        in_channels: Size of each input node feature vector. Defaults to
            ``dataset.num_features`` of the notebook-level ``dataset``.
        num_classes: Number of output classes. Defaults to
            ``dataset.num_classes`` of the notebook-level ``dataset``.
        hidden_channels: Width of the embedding produced by the GCN layer
            (defaults to 3, as in the original model).
    """

    def __init__(self, in_channels=None, num_classes=None, hidden_channels=3):
        super().__init__()
        # Fall back to the global `dataset` so that `GCN()` keeps working,
        # while allowing the model to be built without that hidden dependency.
        if in_channels is None:
            in_channels = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes
        self.gcn = GCNConv(in_channels, hidden_channels)
        self.out = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        # Call the submodules themselves (not `.forward`) so that any
        # registered forward hooks are run.
        h = self.gcn(x, edge_index).relu()
        z = self.out(h)
        return h, z

This code trains the Graph Convolutional Network (GCN) with cross-entropy loss and the Adam optimizer. In each epoch the model performs a forward pass, computes the loss and accuracy, and updates its parameters by backpropagation. The node embeddings, raw outputs (logits), accuracies, and losses are stored for later analysis. Note that the loss and accuracy are computed over all nodes, because `data.train_mask` is not used, so the reported accuracy is training accuracy.

In [19]:
SEED = 0             # fixes weight initialization so Restart & Run All reproduces results
LEARNING_RATE = 0.01

# Seed before building the model: GCN's layers draw their initial weights
# from torch's global RNG.
torch.manual_seed(SEED)
model = GCN()
print(model)

# CrossEntropyLoss takes unnormalized logits, which is what GCN returns as `z`.
critereon = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch history, appended to by the training loop.
losses = []
embeddings = []
accuracies = []
outputs = []

def accuracy(pred_y, y):
    """Return the fraction of positions where ``pred_y`` equals ``y``.

    Assumes both arguments are equal-length label tensors (the training loop
    passes argmax predictions and ``data.y``). The result is a 0-dim
    floating-point tensor, e.g. ``tensor(0.7500)``.
    """
    num_correct = (pred_y == y).sum()
    return num_correct / len(y)

NUM_EPOCHS = 300  # number of full-batch training steps
LOG_EVERY = 10    # print the loss every LOG_EVERY epochs

for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    h, z = model(data.x, data.edge_index)
    # Call the loss module itself (not `.forward`) so any registered hooks run.
    loss = critereon(z, data.y)
    loss.backward()
    optimizer.step()

    # Detach before storing. Keeping graph-attached tensors from every epoch
    # holds on to autograd history, and calling `.numpy()` on them later
    # raises because they require grad.
    embeddings.append(h.detach())
    outputs.append(z.detach())
    accuracies.append(accuracy(z.argmax(dim=1), data.y))
    loss_value = loss.item()
    losses.append(loss_value)

    if epoch % LOG_EVERY == 0:
        print(f'Epoch {epoch}, Loss {loss_value}')

GCN(
  (gcn): GCNConv(34, 3)
  (out): Linear(in_features=3, out_features=4, bias=True)
)
Epoch 0, Loss 1.3743200302124023
Epoch 10, Loss 1.3137054443359375
Epoch 20, Loss 1.255401611328125
Epoch 30, Loss 1.1938754320144653
Epoch 40, Loss 1.1278448104858398
Epoch 50, Loss 1.0522974729537964
Epoch 60, Loss 0.9466291069984436
Epoch 70, Loss 0.8156130313873291
Epoch 80, Loss 0.691918671131134
Epoch 90, Loss 0.588395893573761
Epoch 100, Loss 0.5068764090538025
Epoch 110, Loss 0.4448983669281006
Epoch 120, Loss 0.3979796767234802
Epoch 130, Loss 0.36184149980545044
Epoch 140, Loss 0.3326040208339691
Epoch 150, Loss 0.30815961956977844
Epoch 160, Loss 0.28695911169052124
Epoch 170, Loss 0.2680508494377136
Epoch 180, Loss 0.2510685324668884
Epoch 190, Loss 0.23542404174804688
Epoch 200, Loss 0.22079217433929443
Epoch 210, Loss 0.2068587839603424
Epoch 220, Loss 0.19322995841503143
Epoch 230, Loss 0.1796860247850418
Epoch 240, Loss 0.1662464290857315
Epoch 250, Loss 0.1527620553970337
Epoch 260

In [None]:
# Show the last epoch's node embeddings (h) and class scores (z), each
# preceded by a line giving its shape.
for heading, tensor in (
    (f'Final embedding shape: {h.shape}', h),
    (f'Final output shape: {z.shape}', z),
):
    print(heading)
    print(tensor)

Final embedding shape: torch.Size([34, 3])
tensor([[-2.5098,  3.8000, -5.5853, -1.6394],
        [-3.0547,  5.8089, -8.4287, -3.4921],
        [-0.9497,  3.4924, -5.8904, -4.7216],
        [-2.6197,  5.1946, -7.6938, -3.5373],
        [-1.3203, -0.5473,  0.5628,  2.3463],
        [-1.2873, -0.6679,  0.7333,  2.4569],
        [-1.2711, -0.7269,  0.8168,  2.5110],
        [-1.4883,  3.7743, -6.0576, -3.9281],
        [ 3.0579, -1.0406, -0.8672, -6.8731],
        [-0.5075,  3.0681, -5.4539, -5.0759],
        [-1.3194, -0.5505,  0.5673,  2.3493],
        [-1.4268,  3.3707, -5.4622, -3.4465],
        [-1.7390,  3.5268, -5.5479, -2.9754],
        [-0.9555,  3.2057, -5.4427, -4.2667],
        [ 3.5526, -1.7742,  0.0228, -6.8704],
        [ 3.6602, -1.8224,  0.0437, -7.0414],
        [-1.3038, -0.6074,  0.6478,  2.4014],
        [-1.6642,  3.6079, -5.7113, -3.2710],
        [ 3.6431, -1.8146,  0.0400, -7.0146],
        [-0.8354,  2.8712, -4.9840, -4.0251],
        [ 3.3390, -1.6543, -0.0561, -