In [2]:
import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Subset
from torch_geometric.utils.convert import to_networkx
from networkx import all_pairs_shortest_path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!pip install rdkit



In [3]:
# Load the ESOL dataset from MoleculeNet, caching its files under the
# current working directory.
dataset = MoleculeNet(name="ESOL", root="./")
# Bare last expression so the notebook shows the dataset's repr.
dataset

ESOL(1128)

In [4]:
from torch import nn
import torch_geometric.nn as tgnn
from graphormer.model import Graphormer


# Graphormer hyperparameters, collected in one place so they are easy to tune.
# Input feature sizes are read from the dataset; the output size is the second
# axis of the first sample's target tensor.
graphormer_kwargs = dict(
    num_layers=3,
    n_heads=4,
    # node / edge feature sizes (raw input -> hidden)
    input_node_dim=dataset.num_node_features,
    node_dim=128,
    input_edge_dim=dataset.num_edge_features,
    edge_dim=128,
    output_dim=dataset[0].y.shape[1],
    # caps passed to the model for degree and shortest-path distance
    max_in_degree=5,
    max_out_degree=5,
    max_path_distance=5,
)

model = Graphormer(**graphormer_kwargs)

In [5]:
from sklearn.model_selection import train_test_split

# 80/20 train/test split over dataset indices.
# train_test_split returns (train, test) in that order; the previous code
# unpacked them swapped and compensated with test_size=0.8, which gave the
# intended proportions only by cancelling two inversions.
train_ids, test_ids = train_test_split(
    list(range(len(dataset))), test_size=0.2, random_state=42
)

# Reshuffle the training batches every epoch; the dedicated seeded generator
# keeps the batch order reproducible across runs without touching the global RNG.
train_loader = DataLoader(
    Subset(dataset, train_ids),
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Evaluation order does not affect the summed loss, so no shuffling here.
test_loader = DataLoader(Subset(dataset, test_ids), batch_size=64, shuffle=False)

In [6]:
# L1 loss summed over the batch (not averaged): the training cell accumulates
# these sums and divides by the number of samples to report a per-sample value.
loss_functin = nn.L1Loss(reduction="sum")

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)

In [None]:
from tqdm import tqdm
from torch_geometric.nn.pool import global_mean_pool

DEVICE = "cpu"


def _run_epoch(loader, train):
    """Run one full pass over ``loader`` and return the per-sample loss.

    Args:
        loader: DataLoader yielding batched graphs with ``y`` and ``batch``.
        train: If True, the model is put in training mode and the optimizer
            is stepped after every batch; if False, the model is put in eval
            mode and no gradients are tracked.

    Returns:
        The summed loss over all batches divided by the number of samples in
        ``loader.dataset`` (``loss_functin`` uses ``reduction="sum"``).
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch in tqdm(loader):
            # Rebind the result instead of relying on `.to()` mutating in place.
            batch = batch.to(DEVICE)
            if train:
                optimizer.zero_grad()
            # Node-level outputs are averaged per graph to get one prediction
            # per molecule.
            output = global_mean_pool(model(batch), batch.batch)
            loss = loss_functin(output, batch.y)
            total_loss += loss.item()
            if train:
                loss.backward()
                # Clip the global gradient norm to 1.0 before the update.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
    return total_loss / len(loader.dataset)


model.to(DEVICE)
for epoch in range(10):
    print("TRAIN_LOSS", _run_epoch(train_loader, train=True))
    print("EVAL LOSS", _run_epoch(test_loader, train=False))


  0%|                                                                                                                                                | 0/15 [00:00<?, ?it/s]

Timing notes (15 epochs each):

- new version: train 15:33, eval inference 00:13
- old version: train 14:33, eval inference 00:13
- new_1 version: train 15:24, eval inference 00:13
- torch.degree version: train 15:08, eval inference 00:13

In [4]:
# Scratch check of broadcasting: `b` (10, 3) is broadcast against the last two
# axes of `a` (50, 50, 10, 3); the product is then summed over those two axes,
# one axis at a time, leaving a (50, 50) result.
a = torch.randn(50, 50, 10, 3)
b = torch.randn(10, 3)
torch.mul(a, b).sum(dim=-1).sum(dim=-1)

tensor([[  8.6705,  11.7491,   1.5928,  ...,   9.9827,   0.3962,  -4.8716],
        [  5.0475,  -0.7071,   3.1063,  ..., -16.8568,   0.6816,  -4.5517],
        [ -5.1184,  -0.6247,  -0.8803,  ...,   6.2039,  -6.5872,  -0.8567],
        ...,
        [  3.4272,  10.8034,   1.1754,  ...,  11.9679,  -1.4390,  -4.1685],
        [ -4.2619,  -9.7124,  -2.3889,  ...,   6.1664,  -0.9408,   1.7503],
        [  1.1239,   0.2552,  -5.8109,  ...,  -2.3696,  -2.5762,  -2.1762]])