In [2]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
pip install ogb

Collecting ogb
  Using cached ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting outdated>=0.2.0 (from ogb)
  Using cached outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)
Collecting littleutils (from outdated>=0.2.0->ogb)
  Using cached littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)
Using cached ogb-1.3.6-py3-none-any.whl (78 kB)
Using cached outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Using cached littleutils-0.2.4-py3-none-any.whl (8.1 kB)
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.4 ogb-1.3.6 outdated-0.2.2

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


### 1. Dataset

We will use GIN dataset: $1,113$ graphs (in 2 classes), each ranges between $[10,500]$ nodes


In [9]:
dataset=dgl.data.GINDataset('PROTEINS',self_loop=True)

print(f"num_classes: {dataset.gclasses} | feature_dim: {dataset.dim_nfeats}")
print(dataset)

print("---- First graph -----")
g=dataset[0]
print(g)

num_classes: 2 | feature_dim: 3
Dataset("PROTEINS", num_graphs=1113, save_path=/Users/doductai/.dgl/PROTEINS_0c2c49a1)
---- First graph -----
(Graph(num_nodes=42, num_edges=204,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={}), tensor(0))


In [10]:
# prepare train and test set
import numpy as np

# data loader
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_train=int(0.8*len(dataset))

# randomly permute train indices between [0,num_train]
train_sampler=SubsetRandomSampler(torch.arange(num_train))
# randomly permute test indices 
test_sampler=SubsetRandomSampler(torch.arange(num_train,len(dataset)))

train_loader=GraphDataLoader(dataset,sampler=train_sampler,batch_size=4,drop_last=False)
test_loader=GraphDataLoader(dataset,sampler=test_sampler,batch_size=4,drop_last=False)

# print out a batch in train_loader
batched_graph, labels=next(iter(train_loader))
print(batched_graph)
print(labels)




Graph(num_nodes=109, num_edges=521,
      ndata_schemes={'attr': Scheme(shape=(3,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})
tensor([0, 0, 1, 1])


In [11]:
# random prediction on test set
labels=dataset.labels
labels_test=labels[num_train:]

prob_random=sum(labels_test)/len(labels)
print(f"Random prediction: {prob_random}")

Random prediction: 0.20035938918590546


### 2. GNN with GCN
input -> AtomEncoder -> sequence[GraphConv + bn +  relu] -> mean_nodes

In [12]:
from dgl.nn import GraphConv
from ogb.graphproppred.mol_encoder import AtomEncoder # encoder for atoms in molecular graph

class GNN_GCN(nn.Module):
    def __init__(self,in_dim,hidden_dim,out_dim,h_layers=8):
        super().__init__()
        self.node_encoder=AtomEncoder(hidden_dim)
        self.h_layers=h_layers

        # stack of GCNs
        self.convs=nn.ModuleList([GraphConv(hidden_dim,hidden_dim) for _ in range(h_layers-1)])
        self.convs.append(GraphConv(hidden_dim,out_dim)) # last layer

        # stack of bn
        self.bns=nn.ModuleList([torch.nn.BatchNorm1d(hidden_dim) for _ in range(h_layers-1)])
    
    def forward(self,g,x):
        # node encoding
        h=self.node_encoder(x)

        # sequence of conv+bn+relu
        for i in range(self.h_layers-1):
            h=self.convs[i](g,h)
            h=F.relu(self.bns[i](h))

        # last conv
        h=self.convs[-1](g,h)

        # compute mean of all node features
        g.ndata['h']=h
        # out logits= mean_nodes
        mean_feat=dgl.mean_nodes(g,'h') # [out_dim,]

        return mean_feat



In [13]:
in_dim=dataset.dim_nfeats
hidden_dim=64
out_dim=dataset.gclasses

model=GNN_GCN(in_dim,hidden_dim,out_dim)
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")

# check 1 forward pass
with torch.no_grad():
    model.eval()
    pred_logits=model(batched_graph,batched_graph.ndata['attr'].long())
print(pred_logits)

0.041282 million parameters
tensor([[-0.0256,  0.0489],
        [-0.0254,  0.0486],
        [-0.0242,  0.0550],
        [-0.0252,  0.0497]])


In [14]:
torch.manual_seed(1337)

model=GNN_GCN(in_dim,hidden_dim,out_dim)
optimizer=torch.optim.AdamW(model.parameters(),lr=4e-3)
num_epochs=100

model.train()
for epoch in range(num_epochs):
    loss_total=0
    for batched_graph,labels in train_loader:
        logits=model(batched_graph,batched_graph.ndata['attr'].long())
        loss=F.cross_entropy(logits,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update loss
        loss_total+=loss.item()
    loss_total/=len(train_loader)
    if epoch%5==0 or epoch==num_epochs-1:
        print(f"Epoch: {epoch} | Loss: {loss_total:.4f}")

Epoch: 0 | Loss: 0.5978
Epoch: 5 | Loss: 0.5446
Epoch: 10 | Loss: 0.5172
Epoch: 15 | Loss: 0.5318
Epoch: 20 | Loss: 0.5436
Epoch: 25 | Loss: 0.5356
Epoch: 30 | Loss: 0.5281
Epoch: 35 | Loss: 0.5270
Epoch: 40 | Loss: 0.5197
Epoch: 45 | Loss: 0.5241
Epoch: 50 | Loss: 0.5184
Epoch: 55 | Loss: 0.5180
Epoch: 60 | Loss: 0.5282
Epoch: 65 | Loss: 0.5206
Epoch: 70 | Loss: 0.5130
Epoch: 75 | Loss: 0.5217
Epoch: 80 | Loss: 0.5042
Epoch: 85 | Loss: 0.5215
Epoch: 90 | Loss: 0.5225
Epoch: 95 | Loss: 0.5152
Epoch: 99 | Loss: 0.5158


In [15]:
# evaluation
num_correct, num_data, loss_total =0 ,0,0

with torch.no_grad():
    model.eval()
    for batched_graph,labels in test_loader:
        logits=model(batched_graph,batched_graph.ndata['attr'].long())
        loss=F.cross_entropy(logits,labels)
        loss_total+=loss.item()
        num_correct += (logits.argmax(1)== labels).sum().item()
        num_data+=len(labels)
loss_total/=len(test_loader)
test_acc=num_correct/num_data

print(f"GNN-GCN | Test loss : {loss_total:.4f} | test accuracy={test_acc:.4f}")

GNN-GCN | Test loss : 1.3148 | test accuracy=0.2287


### 3. GNN with SageConv
input -> AtomEncoder -> sequence[SageConv + bn +  relu] -> mean_nodes -> classifier

In [16]:
from dgl.nn import SAGEConv

class GNN_Sage(nn.Module):
    def __init__(self,in_dim,hidden_dim,out_dim,h_layers=8):
        super().__init__()
        self.node_encoder=AtomEncoder(hidden_dim)
        self.h_layers=h_layers

        # stack of GCNs
        self.convs=nn.ModuleList([SAGEConv(hidden_dim,hidden_dim, aggregator_type="mean") for _ in range(h_layers-1)])
        self.convs.append(SAGEConv(hidden_dim,out_dim, aggregator_type="mean")) # last layer

        # stack of bn
        self.bns=nn.ModuleList([torch.nn.BatchNorm1d(hidden_dim) for _ in range(h_layers-1)])
    
    def forward(self,g,x):
        # node encoding
        h=self.node_encoder(x)

        # sequence of conv+bn+relu
        for i in range(self.h_layers-1):
            h=self.convs[i](g,h)
            h=F.relu(self.bns[i](h))

        # last conv
        h=self.convs[-1](g,h)

        # compute mean of all node features
        g.ndata['h']=h
        # out logits= mean_nodes
        mean_feat=dgl.mean_nodes(g,'h') # [out_dim,]

        return mean_feat



In [17]:
torch.manual_seed(1442)

model=GNN_Sage(in_dim,hidden_dim,out_dim)
optimizer=torch.optim.AdamW(model.parameters(),lr=4e-3)
num_epochs=100

model.train()
for epoch in range(num_epochs):
    loss_total=0
    for batched_graph,labels in train_loader:
        logits=model(batched_graph,batched_graph.ndata['attr'].long())
        loss=F.cross_entropy(logits,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update loss
        loss_total+=loss.item()
    loss_total/=len(train_loader)
    if epoch%5==0 or epoch==num_epochs-1:
        print(f"Epoch: {epoch} | Loss: {loss_total:.4f}")

Epoch: 0 | Loss: 0.6501
Epoch: 5 | Loss: 0.5423
Epoch: 10 | Loss: 0.5502
Epoch: 15 | Loss: 0.5447
Epoch: 20 | Loss: 0.5384
Epoch: 25 | Loss: 0.5287
Epoch: 30 | Loss: 0.5129
Epoch: 35 | Loss: 0.5198
Epoch: 40 | Loss: 0.5339
Epoch: 45 | Loss: 0.5172
Epoch: 50 | Loss: 0.5145
Epoch: 55 | Loss: 0.5185
Epoch: 60 | Loss: 0.5137
Epoch: 65 | Loss: 0.5190
Epoch: 70 | Loss: 0.5160
Epoch: 75 | Loss: 0.5129
Epoch: 80 | Loss: 0.5090
Epoch: 85 | Loss: 0.5073
Epoch: 90 | Loss: 0.5004
Epoch: 95 | Loss: 0.5012
Epoch: 99 | Loss: 0.5093


In [18]:
# evaluation
num_correct, num_data, loss_total =0 ,0,0

with torch.no_grad():
    model.eval()
    for batched_graph,labels in test_loader:
        logits=model(batched_graph,batched_graph.ndata['attr'].long())
        loss=F.cross_entropy(logits,labels)
        loss_total+=loss.item()
        num_correct += (logits.argmax(1)== labels).sum().item()
        num_data+=len(labels)
loss_total/=len(test_loader)
test_acc=num_correct/num_data

print(f"GNN-Sage | Test loss : {loss_total:.4f} | test accuracy={test_acc:.4f}")

GNN-Sage | Test loss : 1.1333 | test accuracy=0.3318
