In [1]:
# simple GNN from scratch using PyTorch

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from functools import reduce
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import torch
import torch.nn as nn
import torch.nn.functional as F

# protein classification dataset from Hugging Face
df = pd.read_json("hf://datasets/graphs-datasets/PROTEINS/full.jsonl", lines=True)

# exclude one row where num of nodes doesn't match
df = df[df.edge_index.map(lambda a: len(np.unique(a[1]))) == df.node_feat.map(len)]

In [2]:
# helper
def block_diag_pad(arrays):

    total_size = sum(array.shape[0] for array in arrays)
    result = np.zeros((total_size, total_size))
    row_start = 0
    col_start = 0
    
    for array in arrays:
        size = array.shape[0]
        result[row_start:row_start + size, col_start:col_start + size] = array
        row_start += size
        col_start += size
    
    return result

def combine_list(ls):
    return reduce(lambda a,b: a+b, ls)
    
def random_index_generator(max_len, batch_size):
    shuffled_idxs = np.random.permutation(max_len)
    for i in range(len(shuffled_idxs))[::batch_size]:
        yield shuffled_idxs[i:i+batch_size]

def batch_generator(adj_mat_list, feature_list, labels_vec, batch_size):
    
    for idxs in random_index_generator(len(labels_vec), batch_size):
        
        actual_batch_size = torch.tensor(len(idxs))

        A_batch = block_diag_pad([adj_mat_list[i] for i in idxs])
        X_batch = combine_list([feature_list[i] for i in idxs])
        batch_index = [[i]*len(adj_mat_list[idx]) for i, idx in zip(range(batch_size), idxs)]

        # cast into tensor
        A_batch = torch.tensor(A_batch).float()
        X_batch = torch.tensor(X_batch).float()
        Y_batch = torch.tensor(labels_vec[idxs]).float()
        batch_index = torch.tensor(combine_list(batch_index))
        
        yield A_batch, X_batch, Y_batch, batch_index, actual_batch_size

In [3]:
adj_mat_list = []
feature_list = []

for i in range(len(df)):

    edge_pairs = list(zip(*df.edge_index.iloc[i]))
    graph = nx.Graph(edge_pairs)

    # normalized adjacency matrix
    A = nx.adjacency_matrix(graph).toarray()
    D = np.diag(np.sum(A, axis=1))
    D_ = np.diag(1.0 / np.sqrt(np.diag(D)))
    D_[np.isinf(D_)] = 0
    A_ = D_ @ A @ D_

    adj_mat_list.append(A_)
    feature_list.append(df.node_feat.iloc[i])

labels_vec = np.eye(2)[[i[0] for i in df.y]]

In [4]:
class GNNLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GNNLayer, self).__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, A, X):
        out = A @ X
        out = self.linear(out)
        return F.relu(out)

class BasicGNN(nn.Module):

    def __init__(self,):
        super().__init__()
        self.GNNlayer1 = GNNLayer(3, 32)
        self.GNNlayer2 = GNNLayer(32, 64)
        self.GNNlayer3 = GNNLayer(64, 64)
        self.readoutLayer = nn.Linear(64,2)

    def forward(self, A, X, batch_size, batch_index):
        
        # message propogation
        H = self.GNNlayer1(A, X)
        H = self.GNNlayer2(A, H)
        H = self.GNNlayer3(A, H)

        # graph level mean
        out = torch.zeros((batch_size, H.size(1)))
        for graph_id in range(batch_size):
            mask = batch_index == graph_id
            out[graph_id] = H[mask].mean(dim=0)

        out = self.readoutLayer(out)
        out = F.relu(out)

        return out

model = BasicGNN()
model = model

In [8]:
# Loss function & optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


# training loop
epochs = 80
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for batch in batch_generator(adj_mat_list, feature_list, labels_vec, 8):
        A_batch, X_batch, Y_batch, batch_index, batch_size = batch
        
        # call model
        Y_hat = model(A_batch, X_batch, batch_size, batch_index)

        # update
        loss = criterion(Y_hat, Y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(df):.4f}")

Epoch [1/80], Loss: 0.0809
Epoch [2/80], Loss: 0.0805
Epoch [3/80], Loss: 0.0800
Epoch [4/80], Loss: 0.0794
Epoch [5/80], Loss: 0.0788
Epoch [6/80], Loss: 0.0783
Epoch [7/80], Loss: 0.0777
Epoch [8/80], Loss: 0.0769
Epoch [9/80], Loss: 0.0763
Epoch [10/80], Loss: 0.0756
Epoch [11/80], Loss: 0.0753
Epoch [12/80], Loss: 0.0748
Epoch [13/80], Loss: 0.0742
Epoch [14/80], Loss: 0.0740
Epoch [15/80], Loss: 0.0738
Epoch [16/80], Loss: 0.0735
Epoch [17/80], Loss: 0.0730
Epoch [18/80], Loss: 0.0727
Epoch [19/80], Loss: 0.0728
Epoch [20/80], Loss: 0.0726
Epoch [21/80], Loss: 0.0724
Epoch [22/80], Loss: 0.0724
Epoch [23/80], Loss: 0.0720
Epoch [24/80], Loss: 0.0718
Epoch [25/80], Loss: 0.0718
Epoch [26/80], Loss: 0.0716
Epoch [27/80], Loss: 0.0715
Epoch [28/80], Loss: 0.0712
Epoch [29/80], Loss: 0.0713
Epoch [30/80], Loss: 0.0710
Epoch [31/80], Loss: 0.0709
Epoch [32/80], Loss: 0.0709
Epoch [33/80], Loss: 0.0712
Epoch [34/80], Loss: 0.0712
Epoch [35/80], Loss: 0.0710
Epoch [36/80], Loss: 0.0710
E

In [9]:
with torch.no_grad():
    model.eval()

    y_true = []
    y_pred = []

    for batch in batch_generator(adj_mat_list, feature_list, labels_vec, 16):
        A_batch, X_batch, Y_batch, batch_index, batch_size = batch
    
        # call model
        Y_hat = model(A_batch, X_batch, batch_size, batch_index)
        y_true += list(torch.argmax(Y_batch, axis=1).numpy())
        y_pred += list(torch.argmax(Y_hat, axis=1).numpy())

from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.65      0.95      0.77       661
           1       0.78      0.24      0.37       447

    accuracy                           0.67      1108
   macro avg       0.71      0.60      0.57      1108
weighted avg       0.70      0.67      0.61      1108

