In [1]:
import torch 

from lgi_gt import LGI_GT 
import gnnt, parallel 

# All three architectures were trained with identical hyper-parameters; only the
# model class and the checkpoint file differ, so the kwargs are defined once here.
MODEL_KWARGS = dict(
    out_dim=128,
    gconv_dim=384,
    tlayer_dim=384,
    num_layers=5,
    num_heads=8,
    local_attn_dropout=0.0,
    global_attn_dropout=0.3,
    local_ffn_dropout=0.3,
    global_ffn_dropout=0.3,
    clustering=False,
    masked_attention=True,
    norm='ln',
    skip_connection='none',
    readout='cls',
)


def load_model(model_cls, checkpoint_path, map_location="cpu", **overrides):
    """Instantiate ``model_cls`` with the shared hyper-parameters and load its weights.

    Args:
        model_cls: model class; called as ``model_cls(**MODEL_KWARGS, **overrides)``.
        checkpoint_path: path to a saved ``state_dict``.
        map_location: device the checkpoint tensors are loaded onto. Defaults to
            CPU so checkpoints saved from a GPU also load on CPU-only machines
            (the graph data used later in this notebook is built on CPU).
        **overrides: keyword arguments that replace entries of ``MODEL_KWARGS``.

    Returns:
        The model with weights loaded, switched to evaluation mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint's keys or shapes
            do not match the constructed model (strict loading).
    """
    model = model_cls(**{**MODEL_KWARGS, **overrides})
    # torch.load unpickles the file -- only load checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for deterministic inspection
    return model


model_lgi = load_model(LGI_GT, "state/LGI.pt")                                # LGI
model_gnnt = load_model(gnnt.GraphTransformer, "state/GNNT.pt")              # GNN+Transformer
model_parallel = load_model(parallel.GraphTransformer, "state/Parallel.pt")  # parallel

"""models loading done"""

'models loading done'

In [2]:
# Load the OGB molpcba benchmark stored under root='.' and keep only its official
# validation split; single graphs are indexed from `val_dataset` in later cells.
from ogb.graphproppred import PygGraphPropPredDataset

dataset = PygGraphPropPredDataset(name='ogbg-molpcba', root='.')
split_idx = dataset.get_idx_split()
val_dataset = dataset[split_idx['valid']]
len(val_dataset)

43793

In [3]:
from visualize_draw import draw 

In [11]:
# Pick one validation molecule and compare where each model's classification token
# attends in its final layer.
data_idx = 6
# Other indices that were inspected: 2023, 2022, 10000, 996, 5555, 22222, 777

data = val_dataset[data_idx]
# The single graph is fed to the models directly (no DataLoader), so supply the
# batch bookkeeping by hand: every node belongs to graph 0 and there is one graph.
data.batch = torch.zeros(data.num_nodes, dtype=torch.int64)
data.num_graphs = 1

# get_clf_attn(data, layer) uses 1-based layers, e.g. get_clf_attn(data, 1)[0] is
# the 1st layer's 1st head; layer 5 is the last of the num_layers=5 layers.
LAST_LAYER = 5
# Attention is only inspected, never back-propagated: no_grad avoids building an
# autograd graph through all three models and hands `draw` plain tensors
# (requires_grad=False). `.mean(dim=0)` averages over attention heads.
with torch.no_grad():
    clf_attn_lgi = model_lgi.get_clf_attn(data, LAST_LAYER).mean(dim=0)
    clf_attn_gnnt = model_gnnt.get_clf_attn(data, LAST_LAYER).mean(dim=0)
    clf_attn_parallel = model_parallel.get_clf_attn(data, LAST_LAYER).mean(dim=0)

# The first atom feature is used as the node tag shown in the drawing.
data.tag = data.x[:, 0]
data.attn_lgi = clf_attn_lgi
data.attn_gnnt = clf_attn_gnnt
data.attn_parallel = clf_attn_parallel

draw(data)

set(data.tag.tolist()), data.tag.tolist()  # 5: C, 6: N, 7: O, 15: S

({5, 6, 7, 15},
 [5,
  7,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  7,
  5,
  7,
  6,
  7,
  5,
  5,
  5,
  5,
  15,
  5,
  5,
  6])