In [1]:
import dgl
import dgl.nn as dglnn
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## 1. Sparse attention
<img src="graph_edge.png" width="500" height="300">

### (a) Traditional Attention: $Q, K, V$
1. Attention coefficients:
<center>
$\text{attn_coeffs(Q,K)}=\text{softmax}\left(\dfrac{QK^T}{\sqrt{d_k}}\right)$ with $d_k =$ dimensionality of $Q$ and $K$
<center>
2. Combine attention coefficients and values to give output
<center>
$\text{attention}(Q,K,V)=\text{attn_coeffs(Q,K)}V$
<center>

### (b) Sparse Attention: 
Input: $Q=[N,d], K=[N,d], V=[N,d], A=[N,N], E=[N,N,d]$
1. Attention coefficients
<center>
$\text{attn_coeffs(Q,K,A,E)}=\text{softmax}\left((\dfrac{QK^T}{\sqrt{d}}*A)*E\right)$ 
<center>
2. Aggregate over edge features to get coefficients as scalars:
<center>
$\text{attn_coeffs(Q,K,A,E)}=\text{attn_coeffs(Q,K,A,E)}.sum(dim=-1) = [N,N]$
<center>
3. Combine attention coefficients and values to give output
<center>
$\text{attention}(Q,K,A,E,V)=\text{attn_coeffs(Q,K,A,E)}V$
<center>
    
### (c) Attention module structure
(1) A = adjacency matrix= [N,N], \
        (2) h_x = hidden representation of node features $ = [N,\text{hidden_dim}]$,\
        (3) h_e = hidden representation of edge features $ = [N,N,\text{hidden_dim}]$
    
1. Project the input into Q, K, V, E: 
    $$Q, K, V = [N,\text{hidden_dim}], \ E=[N,N,\text{hidden_dim}]$$
2. Reshape Q,K,V,E according to num_heads $n_h$: 
    $$\text{hidden_dim}=d_hn_h \rightarrow Q,K,V=[N,d_h,n_h], E=[N,N,d_h,n_h]$$
3. Compute attention coefficients: * means pointwise multiplication
$$\text{attn_coeffs(Q,K,A,E)}=\text{softmax}\left((\dfrac{QK^T}{\sqrt{d_k}}*A)*E\right)=[N,N,d_h,n_h]$$
4. Aggregate along edge dimension to get the coefficients as scalars
    $$\text{attn_coeffs(Q,K,A,E)}=\text{attn_coeffs(Q,K,E)}.sum(dim=-2)=[N,N,nh]$$
5. Combine with value $V$:
$$\text{ouput(Q,K,E,A,V)}=\text{attn_coeffs}(Q,K,A,E)V=[N,d_h,n_h]$$


In [2]:
import numpy as np
class SparseMultiHeadAttention(nn.Module):
    def __init__(self,hidden_dim=64,num_heads=4):
        super().__init__()
        
        self.num_heads=num_heads
        self.hidden_dim=hidden_dim
        
        self.linear_q=nn.Linear(hidden_dim,hidden_dim)
        self.linear_k=nn.Linear(hidden_dim,hidden_dim)
        self.linear_v=nn.Linear(hidden_dim,hidden_dim)
        self.linear_e=nn.Linear(hidden_dim,hidden_dim)
        
        # projection of output O_h, O_e
        self.proj_h=nn.Linear(hidden_dim,hidden_dim)
        self.proj_e=nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self,A,h_x,h_e):
        
        # A   = adjacency matrix = [N,N] 
        # h_x = hidden representation of node features = [N,hidden_dim]
        # h_e = hidden representation of edge features = [N,N,hidden_dim]
        
        # Extract N, nh, dh
        N=len(h_x)                # number of nodes
        nh=self.num_heads
        dh=self.hidden_dim//nh
        
        # compute q,k,v,e and reshape into heads
        q=self.linear_q(h_x).reshape(N,dh,nh)
        k=self.linear_k(h_x).reshape(N,dh,nh)
        v=self.linear_v(h_x).reshape(N,dh,nh)
        
        e=self.linear_e(h_e).reshape(N,N,dh,nh)

        # implicit attention coefficients = (qk^T/sqrt(dh)*A)*e
        
        attn_coeff_implicit=dglsp.bsddmm(A,q,k.transpose(1,0))/np.sqrt(dh) # [N,N,nh]
        
        attn_coeff_implicit=dgl.sparse.SparseMatrix.to_dense(attn_coeff_implicit) # [N,N,nh]
        
        # multiply attn_coeff_implicit into e: multiply elementwise 
        attn_coeffs=attn_coeff_implicit.reshape(N,N,1,nh)*e # [N,N,dh,nh]*[N,N,1,nh]=[N,N,dh,nh]        
        
        # keep a copy of e for FFN_e
        e_out=attn_coeffs.view(N,N,-1)                      # [N,N,hidden_dim]
        
        # take sum along dimension d_h
        attn_coeffs=torch.exp(torch.sum(attn_coeffs,dim=-2).clamp(-5,5)) # [N,N,nh]
        
        
        # apply softmax
        attn_coeffs=attn_coeffs.softmax(dim=-2)                     # [N,N,nh]
        
        
        # combine with v to give output for FFN_v
        #              attn_coeffs=[N,N,nh], v=[N,dh,nh] 
        h_out=torch.matmul(attn_coeffs.transpose(-1,-2).transpose(-2,-3),
                           v.transpose(-1,-2).transpose(-2,-3))     # [nh,N,dh]
        h_out=(h_out.transpose(-2,-3).transpose(-2,-1)).reshape(N,-1) # [N,hidden_dim]


        # projection of output O_h, O_e
        h_out=self.proj_h(h_out)
        e_out=self.proj_e(e_out)
        
        return h_out,e_out 
        

In [3]:
# sannity check
hidden_dim=64
num_heads=4
N,nh,dh=19,4,8

# create a sparse adjacency matrix
src=torch.randint(high=19,size=(40,))
dst=torch.randint(high=19,size=(40,))
A=dglsp.spmatrix(torch.stack([src,dst]),shape=(N,N))

# node and edge features
h_x=torch.rand((N,hidden_dim))
h_e=torch.rand((N,N,hidden_dim))

# model
SparseMHA=SparseMultiHeadAttention(hidden_dim,num_heads)
h_out, e_out=SparseMHA(A,h_x,h_e)
print(h_out.shape, e_out.shape)

torch.Size([19, 64]) torch.Size([19, 19, 64])


## 2. Graph Transformer Layer
<img src="graph_edge.png" width="400" height='400'>

In [4]:
class GTLayer(nn.Module):
    def __init__(self,hidden_dim=64,num_heads=4):
        super().__init__()
        
        self.hidden_dim=hidden_dim
        self.num_heads=num_heads
        
        # attention module
        self.attention=SparseMultiHeadAttention(hidden_dim,num_heads)
        
        
        # batch normalization
        self.bnh_1=nn.BatchNorm1d(hidden_dim)
        self.bnh_2=nn.BatchNorm1d(hidden_dim)
        
        self.bne_1=nn.BatchNorm1d(hidden_dim)
        self.bne_2=nn.BatchNorm1d(hidden_dim)
        
        # ffn_h and ffn_e
        self.ffn_h=nn.Sequential(
            nn.Linear(hidden_dim,hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2,hidden_dim)
        )
        self.ffn_e=nn.Sequential(
            nn.Linear(hidden_dim,hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2,hidden_dim)
        )
        
    def forward(self,A,h_x,h_e):
        
        N=h_e.shape[0]
        
        # retain for 1st residual connection
        h1,e1=h_x,h_e 
        
        # attention module
        h,e=self.attention(A,h_x,h_e)
        
        
        # First add and norm
        h=self.bnh_1(h+h1)
        
        e=self.bne_1((e+e1).reshape(-1,self.hidden_dim))
        e=e.reshape(N,N,hidden_dim)
        
        # retain for 2nd residual connection
        h2,e2=h,e
        
        # ffn layers
        h=self.ffn_h(h)
        e=self.ffn_e(e)
        
        # Second add and norm
        h=self.bnh_2(h+h2)
        
        e=self.bne_2((e+e2).reshape(-1,self.hidden_dim))
        e=e.reshape(N,N,hidden_dim)
        
        return h,e

In [5]:
# sanity check
layer=GTLayer(hidden_dim=64,num_heads=4)
h,e=layer(A,h_x,h_e)
h.shape, e.shape

(torch.Size([19, 64]), torch.Size([19, 19, 64]))

## 3. Graph Transformer Model 
We will implemement: Inputs -> GTLayers -> SumPooling -> Classifier 
        
1. SumPooling: extra pooler stacked on top of GT layers to aggregate node features of the same graph.
2. Classifier = linear(d,d/2)+relu+linear(d/2,d/4)+relu+linear(d/4,out_size).

Inputs: (1) Graph g, (2) Node features X, (3) pos_enc (Laplacian encoding): shape [N,2]

First, we pass original data (dim d) into hidden_dim (that will be passed to the model):
1. Project the nodes into hidden_dim: h=nn.Linear(d,hidden_dim)(x)
2. Add position encoding to all nodes: h=h+pos_enc(h)

Now iterate the output through layers:
1. Compute adjaceny matrix A by g.edges(): indices=torch.stack(g.edges()), A=dglsp.spmatrix(indices,shape=(N,N))
2. Pass h through the layers: h=self.layer(A,h)

In [6]:
class GTModel(nn.Module):
    def __init__(self,out_size=1,hidden_dim=64,num_heads=4,pos_enc_dim=10,
                 num_layers=4,edge_dim=3):
        super().__init__()
        
        # atom encoder to project x into hidden representation h
        self.atom_encoder=AtomEncoder(hidden_dim)
        
        # map laplacian position encoding pos_enc into hidden_dim
        self.pos_linear=torch.nn.Linear(pos_enc_dim,hidden_dim)
        
        # map edge_dim into hidden_dim
        self.linear_e=torch.nn.Linear(edge_dim,hidden_dim)
        
        # stack of graph transformer layers
        self.layers=nn.ModuleList(
            [GTLayer(hidden_dim,num_heads) for _ in range(num_layers)]
        )
        
        # pooling layer
        self.pooler=dglnn.SumPooling()
        
        # classifier layer
        self.classifier=nn.Sequential(nn.Linear(hidden_dim,hidden_dim//2),
                                      nn.ReLU(),
                                      nn.Linear(hidden_dim//2,hidden_dim//4),
                                      nn.ReLU(),
                                      nn.Linear(hidden_dim//4,out_size))
        
        
    def forward(self,g,x,pos_enc):      # g = nx graph
        
        indices=torch.stack(g.edges())
        N=g.num_nodes()
        A=dglsp.spmatrix(indices,shape=(N,N))
        
        h=self.atom_encoder(x)+self.pos_linear(pos_enc)
        
        # create matrix e=[N,N,hidden_dim] with e[i,j] = feature of edge e_ij
        e=self.linear_e(g.edata['feat'].float())              # [Ne,hidden_dim]
        e=dglsp.spmatrix(indices,e,shape=(N,N))               # sparse matrix [N,N,hidden_dim]
        e=e.to_dense()                                        # [N,N,hidden_dim]
        
        for layer in self.layers:
            h,e=layer(A,h,e)
        
        # pooler aggregates node features of nodes in g
        h=self.pooler(g,h)
        
        
        # classify based on the aggregated node features
        h=self.classifier(h)
        
        return h
        

## 4. Data and Train 

In [7]:
import random
# Load dataset as graph prediction data
dataset = AsGraphPredDataset(
    DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)

# downsample the dataset for faster training time
train_size, val_size, test_size = len(dataset.train_idx), len(dataset.val_idx), len(dataset.test_idx)

torch.manual_seed(42)
train_idx = dataset.train_idx[torch.LongTensor(random.sample(range(train_size), 2000))]
val_idx = dataset.val_idx[torch.LongTensor(random.sample(range(val_size), 1000))]
test_idx = dataset.test_idx[torch.LongTensor(random.sample(range(test_size), 1000))]


# split data into train/validation/test
batch_size=32
train_loader=GraphDataLoader(dataset[train_idx],
                             batch_size=batch_size,shuffle=True,collate_fn=collate_dgl)
val_loader=GraphDataLoader(dataset[val_idx],
                             batch_size=batch_size,shuffle=False,collate_fn=collate_dgl)
test_loader=GraphDataLoader(dataset[test_idx],
                             batch_size=batch_size,shuffle=False,collate_fn=collate_dgl)

In [8]:
# laplacian positional encoding
pos_enc_dim=10
indices=torch.cat([train_idx,val_idx,test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
    g,_=dataset[idx]
    g.ndata["PE"]=dgl.lap_pe(g,k=pos_enc_dim,padding=True)


Computing Laplacian PE: 100%|██████████| 4000/4000 [00:07<00:00, 559.58it/s]


In [12]:
import copy
from tqdm.notebook import tqdm

def train(model,loader,loss_fn,optimizer,device):
    total_loss=0
    model.train()
    for i,(batch,labels) in enumerate(loader):
        batch,labels=batch.to(device),labels.to(device)
        
        out_logits=model(batch,batch.ndata["feat"],batch.ndata["PE"])
        
        loss=loss_fn(out_logits,labels.float())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss+=loss.item()
        
        if i%20==0 or i==len(loader)-1:
            print(f"Iteration: {i} | Loss: {loss:.4f}")
    
    return total_loss/len(loader)

def evaluation(model,loader,evaluator,device):
    # we use API evaluate for ogbg-molhiv which computes ROCAUC
    with torch.no_grad():
        model.eval()
        y_true, y_pred=[],[]

        for batch,labels in loader:

            batch,labels=batch.to(device), labels.to(device)
            y_hat=model(batch,batch.ndata["feat"],batch.ndata["PE"]) # logits

            y_true.append(labels.view(y_hat.shape,).detach().cpu())
            y_pred.append(y_hat.detach().cpu())

        y_true=torch.cat(y_true,dim=0).numpy()
        y_pred=torch.cat(y_pred,dim=0).numpy()

        input_dict={"y_true": y_true, "y_pred": y_pred}
        score=evaluator.eval(input_dict)["rocauc"]
    
    return score

def train_and_test(model,train_loader,val_loader,test_loader,
                   num_epochs,loss_fn,optimizer,evaluator,device):
    best_val_rocauc=0.0
    best_model=None
    
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch}")
        print("----- Training ------")
        train_loss=train(model,train_loader,loss_fn,optimizer,device)
        
        print("----- Evaluating ----")
        train_rocauc=evaluation(model,train_loader,evaluator,device)
        val_rocauc=evaluation(model,val_loader,evaluator,device)
        test_rocauc=evaluation(model,test_loader,evaluator,device)
                
        # save the best model
        if val_rocauc>best_val_rocauc:
            best_val_rocauc=val_rocauc
            best_model=copy.deepcopy(model)
            
        print(f'Train loss: {train_loss:.4f} | train_roc: {train_rocauc:.4f} | '
         f'val_rocauc: {val_rocauc:.4f} | test_rocauc: {test_rocauc:.4f}')
    return best_model


In [10]:
evaluator = Evaluator("ogbg-molhiv")

model=GTModel()
device=torch.device("cpu")
loss_fn=torch.nn.BCEWithLogitsLoss() 
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

# print pre-trained performance of the model
train_rocauc=evaluation(model,train_loader,evaluator,device)
val_rocauc=evaluation(model,val_loader,evaluator,device)
test_rocauc=evaluation(model,test_loader,evaluator,device)
print("Pretrained performance")
print(f"train_rocauc : {train_rocauc:.4f} | val_rocauc : {train_rocauc:.4f} | test_rocauc : {test_rocauc:.4f}")


Pretrained performance
train_rocauc : 0.3960 | val_rocauc : 0.3960 | test_rocauc : 0.3593


In [13]:
# train the model and print out performance
num_epochs=5
best_model=train_and_test(model,train_loader,val_loader,test_loader,
                   num_epochs,loss_fn,optimizer,evaluator,device)

Epoch: 0
----- Training ------
Iteration: 0 | Loss: 0.4708
Iteration: 20 | Loss: 0.2386
Iteration: 40 | Loss: 0.3987
Iteration: 60 | Loss: 0.0679
Iteration: 62 | Loss: 0.0725
----- Evaluating ----
Train loss: 0.2238 | train_roc: 0.6089 | val_rocauc: 0.5358 | test_rocauc: 0.4209
Epoch: 1
----- Training ------
Iteration: 0 | Loss: 0.1227
Iteration: 20 | Loss: 0.4018
Iteration: 40 | Loss: 0.2443
Iteration: 60 | Loss: 0.1868
Iteration: 62 | Loss: 0.0694
----- Evaluating ----
Train loss: 0.1894 | train_roc: 0.6726 | val_rocauc: 0.6092 | test_rocauc: 0.5316
Epoch: 2
----- Training ------
Iteration: 0 | Loss: 0.0429
Iteration: 20 | Loss: 0.5700
Iteration: 40 | Loss: 0.0285
Iteration: 60 | Loss: 0.2035
Iteration: 62 | Loss: 0.2587
----- Evaluating ----
Train loss: 0.1725 | train_roc: 0.6861 | val_rocauc: 0.5915 | test_rocauc: 0.5178
Epoch: 3
----- Training ------
Iteration: 0 | Loss: 0.2680
Iteration: 20 | Loss: 0.0467
Iteration: 40 | Loss: 0.1686
Iteration: 60 | Loss: 0.0521
Iteration: 62 | L

In [None]:
# Evaluate the best model on test set
# best_model_rocauc=evaluation(best_model,test_loader,evaluator,device)
# print(f"Best model rocauc on test set: {best_model_rocauc*100:.2f}")