In [1]:
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
import torch
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
import numpy as np

from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

%reload_ext autoreload
%autoreload 
#pytorch==2.0.0 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 pyg=2.4.0 ogb=1.3.6
#pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu117.html

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cell output: the installed torch build and the device selected above.
torch.__version__, device

('2.0.0+cu117', device(type='cuda'))

In [2]:
# Download (on first use) and load the ogbg-molpcba graph property prediction dataset.
# NOTE(review): with OGB's default root this presumably lands under './dataset/ogbg_molpcba/' — confirm.
dataset = PygGraphPropPredDataset(name='ogbg-molpcba',)

In [3]:
# Take the first graph of the dataset and display it to inspect its fields and shapes.
data = dataset[0]
data

Data(edge_index=[2, 44], edge_attr=[44, 3], x=[20, 9], y=[1, 128], num_nodes=20)

In [4]:
# Experiment configuration: every value a reader is likely to tune lives here.

# Use the first GPU when one is available and fall back to the CPU otherwise.
# (A bare `0` only works on CUDA machines and overrides the device detected
# in the setup cell.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
drop_ratio = 0.4                 # dropout ratio handed to the model
num_layer = 4                    # number of stacked EGT layers
emb_dim = 64                     # node embedding width
batch_size = 64                  # graphs per mini-batch in every DataLoader
epochs = 10                      # number of training epochs
num_workers = 0                  # DataLoader worker processes
dataset_name = "ogbg-molpcba"    # name passed to the OGB Evaluator
filename = ""                    # where to save the final scores; "" disables saving

In [5]:
# Loss used by train() when the task type contains "classification" (applied to raw logits).
cls_criterion = torch.nn.BCEWithLogitsLoss()
# Loss used by train() for every other task type.
reg_criterion = torch.nn.MSELoss()

def train(model, device, loader, optimizer, task_type):
    """Run one optimisation pass over every batch yielded by ``loader``.

    Batches whose node-feature tensor has a single row, or whose last
    ``batch`` index is 0, are skipped. Targets equal to NaN (unlabeled) are
    excluded from the loss. The module-level ``cls_criterion`` is used when
    ``task_type`` contains "classification"; ``reg_criterion`` otherwise.

    Args:
        model: module called as ``model(batch)`` to produce predictions.
        device: target passed to ``batch.to``.
        loader: iterable of batches exposing ``x``, ``batch`` and ``y``.
        optimizer: optimizer stepping ``model``'s parameters.
        task_type: string describing the task.
    """
    model.train()

    for _, graphs in enumerate(tqdm(loader, desc="Iteration")):
        graphs = graphs.to(device)

        # Degenerate batch: skip it without touching the optimizer.
        if graphs.x.shape[0] == 1 or graphs.batch[-1] == 0:
            continue

        logits = model(graphs)
        optimizer.zero_grad()

        # NaN never equals itself, so this mask is False exactly on missing labels.
        labeled = graphs.y == graphs.y
        criterion = cls_criterion if "classification" in task_type else reg_criterion
        loss = criterion(
            logits.to(torch.float32)[labeled],
            graphs.y.to(torch.float32)[labeled],
        )

        loss.backward()
        optimizer.step()

def eval(model, device, loader, evaluator):
    """Score ``model`` on ``loader`` with ``evaluator``.

    Predictions are computed without gradient tracking, moved to the CPU and
    concatenated over all batches (batches whose node-feature tensor has a
    single row are skipped). Targets are reshaped to the prediction shape.

    Args:
        model: module called as ``model(batch)``.
        device: target passed to ``batch.to``.
        loader: iterable of batches exposing ``x`` and ``y``.
        evaluator: object whose ``eval`` method receives a dict with the
            NumPy arrays ``"y_true"`` and ``"y_pred"``.

    Returns:
        Whatever ``evaluator.eval`` returns.
    """
    model.eval()
    targets, predictions = [], []

    for _, graphs in enumerate(tqdm(loader, desc="Iteration")):
        graphs = graphs.to(device)

        if graphs.x.shape[0] == 1:
            continue

        with torch.no_grad():
            scores = model(graphs)

        targets.append(graphs.y.view(scores.shape).detach().cpu())
        predictions.append(scores.detach().cpu())

    input_dict = {
        "y_true": torch.cat(targets, dim = 0).numpy(),
        "y_pred": torch.cat(predictions, dim = 0).numpy(),
    }
    return evaluator.eval(input_dict)

In [6]:
import sys
sys.path.append("../scPRINT/scprint/model")

from EGT import EGTLayer
from collections import Counter

%reload_ext autoreload
%autoreload 2

In [7]:
class EGT(torch.nn.Module):
    """Graph-level predictor: an ``EGT_node`` encoder, a mean over dim 1, then a linear head.

    Args:
        num_tasks: size of the output produced by the linear head.
        num_layer: number of encoder layers; must be at least 2.
        emb_dim: node embedding width, also the input width of the head.
        drop_ratio: dropout ratio forwarded to the encoder.

    Raises:
        ValueError: if ``num_layer`` is smaller than 2.
    """

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                drop_ratio = 0.5):
        super().__init__()

        self.num_tasks = num_tasks
        self.emb_dim = emb_dim
        self.drop_ratio = drop_ratio
        self.num_layer = num_layer

        if num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.model = EGT_node(num_layer, emb_dim, drop_ratio = drop_ratio)
        self.graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks)

    def forward(self, batched_data):
        """Encode the batch, average the encoder output over dim 1 and apply the head."""
        node_repr = self.model(batched_data)
        # NOTE(review): the mean runs over every position along dim 1, which
        # presumably includes padded slots — confirm this is intended.
        graph_repr = node_repr.mean(dim=1)
        return self.graph_pred_linear(graph_repr)

class EGT_node(torch.nn.Module):
    """Stack of ``EGTLayer`` blocks applied to a densely padded batch of graphs.

    Args:
        num_layer: number of ``EGTLayer`` blocks.
        emb_dim: node embedding width (``feat_size`` and ``inner_size`` of each layer).
        num_heads: number of attention heads per layer.
        edge_feat_size: edge embedding width.
        drop_ratio: dropout ratio passed to each layer.

    Output:
        node representations returned by the last layer. The input handed to
        the first layer has shape ``(num_graphs, max_nodes, emb_dim)``.
    """
    def __init__(self, num_layer, emb_dim, num_heads=4, edge_feat_size=12, drop_ratio = 0.5):
        super(EGT_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.edge_feat_size = edge_feat_size
        self.emb_dim = emb_dim

        # Embeddings for the categorical bond and atom features.
        self.bond_encoder = BondEncoder(emb_dim = edge_feat_size)

        self.atom_encoder = AtomEncoder(emb_dim)
        self.convs = torch.nn.ModuleList()
        # Never populated or used in forward(); kept so the attribute still exists.
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            self.convs.append(EGTLayer(feat_size=emb_dim, inner_size=emb_dim, edge_feat_size=edge_feat_size, num_heads=num_heads, num_virtual_nodes=1, dropout=drop_ratio, use_flash=True))

    def forward(self, batched_data):
        """Pad the batch to dense tensors and run it through every layer.

        Args:
            batched_data: batch exposing ``x``, ``edge_index``, ``edge_attr``,
                ``batch`` (graph id of every node) and ``ptr`` (cumulative
                node offsets, one more entry than there are graphs).

        Returns:
            The node tensor produced by the last layer.
        """
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
        ptr = batched_data.ptr

        # Per-graph node counts come straight from the offsets; this avoids
        # round-tripping `batch` through Python and keeps the graph count
        # consistent with `ptr` even if a graph has no nodes.
        nodes_per_graph = ptr[1:] - ptr[:-1]
        num_graphs = nodes_per_graph.numel()
        max_nodes = int(nodes_per_graph.max())

        ### computing input node embedding
        node_emb = self.atom_encoder(x)
        # Position of every node inside its own graph.
        local_idx = torch.arange(x.size(0), device=ptr.device) - ptr[batch]
        # Zero padding shares dtype and device with the embeddings, so no
        # notebook-global `device` is needed.
        h = node_emb.new_zeros((num_graphs, max_nodes, self.emb_dim))
        h[batch, local_idx] = node_emb

        # Dense per-graph edge features, written directly at
        # (graph, local source, local target) instead of going through a
        # (total_nodes x total_nodes) scratch tensor.
        bond_emb = self.bond_encoder(edge_attr)
        edge_graph = batch[edge_index[0]]
        offset = ptr[edge_graph]
        edge_attr = bond_emb.new_zeros((num_graphs, max_nodes, max_nodes, self.edge_feat_size))
        edge_attr[edge_graph, edge_index[0] - offset, edge_index[1] - offset] = bond_emb

        # NOTE(review): no attention mask is passed, so padded positions
        # presumably take part in attention — see the masking TODO.
        for conv in self.convs:
            h, edge_attr = conv(h, edge_attr)

        return h

In [8]:
# TODO: make masking work
# TODO: give sparse tensor as input
# TODO: implement flash attention
# TODO: add distance encoding
# TODO: add svd encoding (put it outside the model, in scprint)


In [11]:
# Train the EGT model on the official OGB split and track the metric per epoch.
split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator(dataset_name)

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True, num_workers = num_workers)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=batch_size, shuffle=False, num_workers = num_workers)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=batch_size, shuffle=False, num_workers = num_workers)


model = EGT(num_tasks = dataset.num_tasks, num_layer = num_layer, emb_dim = emb_dim, drop_ratio = drop_ratio).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# One entry per epoch, holding the value of dataset.eval_metric on each split.
valid_curve = []
test_curve = []
train_curve = []

for epoch in range(1, epochs + 1):
    print(f"=====Epoch {epoch}")
    print('Training...')
    train(model, device, train_loader, optimizer, dataset.task_type)

    print('Evaluating...')
    train_perf = eval(model, device, train_loader, evaluator)
    valid_perf = eval(model, device, valid_loader, evaluator)
    test_perf = eval(model, device, test_loader, evaluator)

    print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

    train_curve.append(train_perf[dataset.eval_metric])
    valid_curve.append(valid_perf[dataset.eval_metric])
    test_curve.append(test_perf[dataset.eval_metric])

# Model selection uses the validation curve: the highest value wins for
# classification tasks, the lowest for everything else.
if 'classification' in dataset.task_type:
    best_val_epoch = np.argmax(np.array(valid_curve))
    best_train = max(train_curve)
else:
    best_val_epoch = np.argmin(np.array(valid_curve))
    best_train = min(train_curve)

print('Finished training!')
print(f'Best validation score: {valid_curve[best_val_epoch]}')
print(f'Test score: {test_curve[best_val_epoch]}')

# Persist the scores only when an output path was configured.
if filename != '':
    # 'BestTrain' replaces the misspelled 'BestTrnain' key.
    torch.save({'Val': valid_curve[best_val_epoch], 'Test': test_curve[best_val_epoch], 'Train': train_curve[best_val_epoch], 'BestTrain': best_train}, filename)
    

=====Epoch 1
Training...


Iteration:   0%|          | 0/5475 [00:00<?, ?it/s]

> [0;32m/home/ml4ig1/Documents code/scPRINT/scprint/model/EGT.py[0m(297)[0;36mforward[0;34m()[0m
[0;32m    295 [0;31m[0;34m[0m[0m
[0m[0;32m    296 [0;31m        [0;31m# Scale the aggregated values by degree.[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 297 [0;31m        [0mdegrees[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mgates[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m2[0m[0;34m,[0m [0mkeepdim[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    298 [0;31m        [0mdegree_scalers[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0;36m1[0m [0;34m+[0m [0mdegrees[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    299 [0;31m        [0mdegree_scalers[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m [0mself[0m[0;34m.[0m[0mnum_virtual_nodes[0m[0;34m][0m [0;34m=[0m [0;36m1.0[0m[0;34m[0m[0;34m[0m[0m
[0m
torch.Size([64, 37, 4, 16])
tensor([[[[ 6.3574e-01, -2.56

Iteration:   0%|          | 0/5475 [14:57<?, ?it/s]


RuntimeError: The size of tensor a (16) must match the size of tensor b (4) at non-singleton dimension 3

In [None]:
from einops import rearrange

# Scratch check: swap the last two axes of `qkv` and display the resulting shape.
# NOTE(review): `qkv` is not defined in any visible cell, so this relies on leftover kernel state.
rearrange(qkv, "b s1 e d -> b s1 d e").shape


torch.Size([6, 4, 2, 4])

In [None]:
# I reached 0.11 with the previous model in the same number of epochs

In [None]:
# Scratch cell: rebuild the padded node tensor of EGT_node.forward by hand on one CPU batch.
# Take the first batch from the training loader.
for batch in train_loader:
    break
batch = batch.to("cpu")
print(batch)
# Largest number of nodes among the graphs of this batch.
N = max(Counter(batch.batch.tolist()).values())

# Split node features per graph, embed each graph and pad it with zeros up to N rows.
batch_x_list = batch.x.split(list(Counter(batch.batch.tolist()).values()))
# NOTE(review): a new AtomEncoder is constructed for every graph here, so each graph
# presumably gets differently initialised embeddings — fine for a shape check only.
h = torch.stack([torch.cat([AtomEncoder(emb_dim)(x), torch.zeros((N-x.size(0), emb_dim), dtype=torch.int64)]) for x in batch_x_list])

# NOTE(review): h[0] already holds embeddings, not raw atom features; feeding it to
# another AtomEncoder looks unintended — confirm.
AtomEncoder(64)(h[0])
