In [4]:
# !pip install pydantic
# !pip install PyYAML 
# !pip install numpy==1.26.4

# !pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2

# !pip install dgl -f https://data.dgl.ai/wheels/repo.html
# !pip install ogb



Collecting pydantic
  Using cached pydantic-2.8.2-py3-none-any.whl.metadata (125 kB)
Collecting annotated-types>=0.4.0 (from pydantic)
  Using cached annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)
Collecting pydantic-core==2.20.1 (from pydantic)
  Using cached pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.6 kB)
Using cached pydantic-2.8.2-py3-none-any.whl (423 kB)
Using cached pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl (1.8 MB)
Using cached annotated_types-0.7.0-py3-none-any.whl (13 kB)
Installing collected packages: pydantic-core, annotated-types, pydantic
Successfully installed annotated-types-0.7.0 pydantic-2.8.2 pydantic-core-2.20.1
Collecting PyYAML
  Using cached PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.1 kB)
Using cached PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl (167 kB)
Installing collected packages: PyYAML
Successfully installed PyYAML-6.0.1
Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp311-cp311-macosx_

In [1]:
import torch
import dgl
import dgl.nn as dglnn
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# set dgl backend to pytorch
import os
os.environ['DGLBACKEND'] = 'pytorch'

from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Architecture
Original Paper: https://arxiv.org/pdf/2012.09699
<img src="graph_tsfm.png" width="300" height='500'>

## 1. Sparse attention

### (a) Traditional Attention: $Q, K, V$
1. Attention coefficients:
$$\text{attn\_coeffs}(Q,K)=softmax\left(\dfrac{QK^T}{\sqrt{d_k}}\right) \ \text{with} \ d_k = \text{dimensionality of} \ Q, K$$
2. Combine attention coefficients and values to give output
$$\text{attention}(Q,K,V)=\text{attn\_coeffs(Q,K)}V$$

### (b) Sparse Attention
Inputs: (1) hidden representation of node features $H = [N,\text{hidden\_dim}]$, (2) adjacency matrix $A$
1. Project the input into query, key, value: Q, K, V = $[N,\text{hidden\_dim}]$
2. Reshape Q,K,V according to number of heads $n_h$: $\text{hidden\_dim}=d_hn_h$ -> $Q,K,V=[N,d_h,n_h]$
3. Compute sparse attention: * means pointwise multiplication
$$\text{attention}(Q,K,A)=\text{softmax}\left(\dfrac{QK^T*A}{\sqrt{d_h}}\right),$$ 
    where the matrix $A$ controls the attention to be computed: $Q_iK_j^T$ is computed $\Leftrightarrow A_{ij}=1 \Leftrightarrow j\in N(i)$.

Remark: $K^T=K.\text{tranpose}(1,0)=[d_h,N,n_h]$, $QK^T=[N,N,n_h]$, the pointwise multiplication $QK^T*A$ is applied on all $n_h$ heads.
4. Combine with value $V$:
$$\text{SparseAttention(Q,K,V)}=\text{attention}(Q,K,A)V $$
with $\text{attention}(Q,K,A)=[N,N,n_h], V=[N,d_h,n_h]$ and $\text{SparseAttention(Q,K,V)}=[N,d_h,n_h]$
5. (Optional) Reshape that output into $[N,\text{hidden\_dim}]$ and project it:
$$\text{SparseAttention(Q,K,V)}=W\cdot\text{SparseAttention(Q,K,V)}.reshape(N,-1)$$

In [10]:
# sparse attention module
class SparseMHA(nn.Module):
    def __init__(self,hidden_dim=80,num_heads=8):
        super().__init__()
        self.hidden_dim=hidden_dim
        self.num_heads=num_heads

        self.linear_q=nn.Linear(hidden_dim,hidden_dim)
        self.linear_k=nn.Linear(hidden_dim,hidden_dim)
        self.linear_v=nn.Linear(hidden_dim,hidden_dim)

        # projection of the output
        self.out_proj=nn.Linear(hidden_dim,hidden_dim)

    def forward(self,A,h):
        # A: [N,N], h: [N,hidden_dim]
        N=len(h)
        nh=self.num_heads
        dh=self.hidden_dim//nh

        # compute query,key, value
        q=self.linear_q(h).reshape(N,dh,nh)
        k=self.linear_k(h).reshape(N,dh,nh)
        v=self.linear_k(h).reshape(N,dh,nh)

        # compute attention scores by sparse matrix API: dglsp.bsddmm(A,X1,X2)
        #                   compute (X1@X2)*A with X1,X2: dense matrices [N,dh,nh], [dh,N,nh]
        #                   the pointwise multiplication applied along the last dim (batch dim = last dim)
        attention_scores=dglsp.bsddmm(A,q,k.transpose(1,0)) # sparse [N,N,nh]

        # sparse softmax: apply on the last dim by default
        attention_scores=attention_scores.softmax()         # (sparse) [N,N,nh]

        # apply value V: dglsp.bspmm(A,V) multiplies sparse matrix by dense matrix by batches
        #                A=[N,N,nh], V=[N,dh,nh] -> output = [N,dh,nh]
        out=dglsp.bspmm(attention_scores,v) # [N,dh,nh]

        # concatentate the heads
        out=out.reshape(N,-1) # [N,hidden_dim]

        # project the output
        return self.out_proj(out)


## 2. Graph Transformer Layer
<center>
<img src="graph_tsfm.png" width="300" height='500'>

In [11]:
class GTLayer(nn.Module):
    def __init__(self,hidden_dim=80,num_heads=8):
        super().__init__()

        self.attention=SparseMHA(hidden_dim,num_heads)
        self.hidden_dim=hidden_dim
        self.num_heads=num_heads

        self.bn1=nn.BatchNorm1d(hidden_dim)
        self.bn2=nn.BatchNorm1d(hidden_dim)

        self.ffn=nn.Sequential(nn.Linear(hidden_dim,2*hidden_dim),
                              nn.ReLU(),
                              nn.Linear(2*hidden_dim,hidden_dim))
    def forward(self,A,h):
        # A: [N,N], h: [N,hidden_dim]

        # First add and norm
        h1=self.attention(A,h) # [N,hidden_dim]
        h=self.bn1(h+h1)

        # Second add and norm
        h2=self.ffn(h)
        h=self.bn2(h+h2)

        return h

## 3. Graph Transformer model
Inputs -> GTLayers -> SumPooling -> Classifier

1. SumPooling: extra pooler stacked on top of GT layers to aggregate node features of the same graph.
2. Classifier = linear(d,d/2)+relu+linear(d/2,d/4)+relu+linear(d/4,out_size).

Inputs: (1) Graph g, (2) Node features X, (3) pos_enc (Laplacian encoding): shape [N,k]

First, we pass original data (dim d) into hidden_dim (that will be passed to the model):
1. Project the nodes into hidden_dim: h=nn.Linear(d,hidden_dim)(x)
2. Add position encoding to all nodes: h=h+pos_enc(h)

Now iterate the output through layers:
1. Compute adjaceny matrix A by g.edges(): indices=torch.stack(g.edges()), A=dglsp.spmatrix(indices,shape=(N,N))
2. Pass h through the layers: h=self.layer(A,h)

In [12]:
class GTModel(nn.Module):
    def __init__(self, out_size,hidden_dim=80,num_heads=8,pos_enc_dim=2,num_layers=8):
        super().__init__()

        # use atom encoder to project x into hidden representation h
        self.atom_encoder=AtomEncoder(hidden_dim)

        # map laplacian position encoding pos_enc into hidden_dim
        self.pos_linear=nn.Linear(pos_enc_dim,hidden_dim)

        # stack graph transformer layers
        self.layers=nn.ModuleList(
            [GTLayer(hidden_dim,num_heads) for _ in range(num_layers)]
        )

        # pooling layer
        self.pooler=dglnn.SumPooling()

        # classifier layer
        d=hidden_dim
        self.classifier=nn.Sequential(nn.Linear(d,d//2),
                                      nn.ReLU(),
                                      nn.Linear(d//2,d//4),
                                      nn.ReLU(),
                                      nn.Linear(d//4,out_size))

    def forward(self,g,X,pos_enc):
        # g = nx graph
        indices=torch.stack(g.edges())
        N=g.num_nodes()
        A=dglsp.spmatrix(indices,shape=(N,N))
        h=self.atom_encoder(X)+self.pos_linear(pos_enc)

        for layer in self.layers:
            h=layer(A,h)

        # pooler aggregates node features of nodes in g
        h=self.pooler(g,h)
        # classify based on the aggregated nodes
        h=self.classifier(h)

        return h


In [None]:
# create model
# out_size=1
# model=GTModel(out_size=out_size,pos_enc_dim=9)

## 4. Dataset ogbg-molhiv


In [5]:
# Load dataset as graph prediction data
dataset = AsGraphPredDataset(
    DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)
evaluator = Evaluator("ogbg-molhiv")
dataset

Dataset("ogbg-molhiv-as-graphpred", num_graphs=41127, save_path=/Users/doductai/.dgl/ogbg-molhiv-as-graphpred)

In [13]:
labels=torch.tensor([dataset[i][1] for i in range(len(dataset))])
labels=labels.view(len(labels),)
print("Labels: ", torch.unique(labels))

g=dataset[0][0]
label=dataset[0][1]
print("----- First Graph ------")
print(f"Number of nodes : {g.num_nodes()} | Number of edges: {g.num_edges()} | Label: {label}")

Labels:  tensor([0, 1])
----- First Graph ------
Number of nodes : 19 | Number of edges: 40 | Label: tensor([0])


In [14]:
batch_size=256
# split data into train/validation/test
train_loader=GraphDataLoader(dataset[dataset.train_idx],
                             batch_size=batch_size,shuffle=True,collate_fn=collate_dgl)
val_loader=GraphDataLoader(dataset[dataset.val_idx],
                             batch_size=batch_size,shuffle=False,collate_fn=collate_dgl)
test_loader=GraphDataLoader(dataset[dataset.test_idx],
                             batch_size=batch_size,shuffle=False,collate_fn=collate_dgl)
print("--------- Train loader ---------")
print(f"Number of graphs: {len(train_loader.dataset)} | Number of batches: {len(train_loader)}")
print("--- First batch ---")
for batch,labels in train_loader:
    print(batch)
    print(labels[:20].view(20,))
    break

--------- Train loader ---------
Number of graphs: 32901 | Number of batches: 129
--- First batch ---
Graph(num_nodes=6510, num_edges=13934,
      ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])


In [15]:
# laplacian positional encoding
pos_enc_dim=2*batch.ndata['feat'].shape[-1]
indices=torch.cat([dataset.train_idx,dataset.val_idx,dataset.test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
    g,_=dataset[idx]
    g.ndata["PE"]=dgl.lap_pe(g,k=pos_enc_dim,padding=True)


Computing Laplacian PE: 100%|██████████| 41127/41127 [01:09<00:00, 593.43it/s]


## 5. Train and Test

In [16]:
from tqdm.notebook import tqdm

def train(model,loader,loss_fn,optimizer,device):
    total_loss=0
    model.train()
    for batch,labels in loader:
        batch,labels=batch.to(device),labels.to(device)

        out_logits=model(batch,batch.ndata["feat"],batch.ndata["PE"])
        loss=loss_fn(out_logits,labels.float())
        total_loss+=loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return total_loss/len(loader)

def evaluation(model,loader,evaluator,device):
    model.eval()
    y_true, y_pred=[],[]
    for batch,labels in loader:
        batch,labels=batch.to(device), labels.to(device)
        y_hat=model(batch,batch.ndata["feat"],batch.ndata["PE"])

        y_true.append(labels.view(y_hat.shape,).detach().cpu())
        y_pred.append(y_hat.detach().cpu())

    y_true=torch.cat(y_true,dim=0).numpy()
    y_pred=torch.cat(y_pred,dim=0).numpy()

    input_dict={"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)



In [17]:
import copy


def train_and_test(model,train_loader,val_loader,test_loader,num_epochs,loss_fn,evaluator,optimizer,device):

    best_model=None
    best_val_acc=0.0

    for epoch in range(num_epochs):
        train_loss=train(model,train_loader,loss_fn,optimizer,device)

        train_result=evaluation(model,train_loader,evaluator,device)
        val_result=evaluation(model,val_loader,evaluator,device)
        test_result=evaluation(model,test_loader,evaluator,device)

        train_acc,val_acc,test_acc=train_result['rocauc'], val_result['rocauc'], test_result['rocauc']

        # save the best model
        if val_acc>best_val_acc:
            best_val_acc=val_acc
            best_model=copy.deepcopy(model)

        print(f'Epoch: {epoch} | Train loss: {train_loss:.4f} | train_roc: {train_acc*100:.2f}% | '
         f'val_roc: {val_acc*100:.2f}% | test_roc: {test_acc*100:.2f}%')
    return best_model

In [18]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# create model
model=GTModel(out_size=1,pos_enc_dim=pos_enc_dim).to(device)

num_epochs=10

# BCEWithLogitsLoss() = sigmoid + BCE (more stable than plain BCE applied on sigmoid)
loss_fn=torch.nn.BCEWithLogitsLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

best_model=train_and_test(model,train_loader,val_loader,test_loader,num_epochs,loss_fn,evaluator,optimizer,device)

Epoch: 0 | Train loss: 0.1811 | train_roc: 68.10% | val_roc: 64.28% | test_roc: 55.74%
Epoch: 1 | Train loss: 0.1485 | train_roc: 78.95% | val_roc: 67.25% | test_roc: 67.43%
Epoch: 2 | Train loss: 0.1349 | train_roc: 80.52% | val_roc: 66.73% | test_roc: 69.19%
Epoch: 3 | Train loss: 0.1299 | train_roc: 78.26% | val_roc: 62.48% | test_roc: 71.47%
Epoch: 4 | Train loss: 0.1197 | train_roc: 87.78% | val_roc: 64.13% | test_roc: 67.45%
Epoch: 5 | Train loss: 0.1101 | train_roc: 92.43% | val_roc: 77.60% | test_roc: 73.67%
Epoch: 6 | Train loss: 0.1046 | train_roc: 92.67% | val_roc: 76.53% | test_roc: 72.93%
Epoch: 7 | Train loss: 0.1019 | train_roc: 93.68% | val_roc: 75.16% | test_roc: 71.64%
Epoch: 8 | Train loss: 0.0912 | train_roc: 95.07% | val_roc: 76.04% | test_roc: 73.37%
Epoch: 9 | Train loss: 0.0888 | train_roc: 97.28% | val_roc: 78.46% | test_roc: 71.54%


In [19]:
# Evaluate the best model on test set
best_model_acc=evaluation(best_model,test_loader,evaluator,device)['rocauc']
print(f"Best model ROCAUC on test set: {best_model_acc:.4f}")

Best model ROCAUC on test set: 0.7154
