In [1]:
import torch
from torch_geometric.datasets import MoleculeNet
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Subset
from torch_geometric.utils.convert import to_networkx
from networkx import all_pairs_shortest_path

In [2]:
!pip install rdkit


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
dataset = MoleculeNet(root="./", name="ESOL")
dataset

ESOL(1128)

In [4]:
from torch import nn
import torch_geometric.nn as tgnn
from graphormer.model import Graphormer


model = Graphormer(
    num_layers=3,
    input_node_dim=dataset.num_node_features,
    node_dim=128,
    input_edge_dim=dataset.num_edge_features,
    edge_dim=128,
    output_dim=dataset[0].y.shape[1],
    n_heads=4,
    max_in_degree=5,
    max_out_degree=5,
    max_path_distance=5,
)

In [5]:
from sklearn.model_selection import train_test_split

test_ids, train_ids = train_test_split([i for i in range(len(dataset))], test_size=0.8, random_state=42)
train_loader = DataLoader(Subset(dataset, train_ids), batch_size=64)
test_loader = DataLoader(Subset(dataset, test_ids), batch_size=64)

In [6]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_functin = nn.L1Loss(reduction="sum")

In [None]:
from tqdm import tqdm
from torch_geometric.nn.pool import global_mean_pool

DEVICE = "cuda"

model.to(DEVICE)
for epoch in range(10):
    model.train()
    batch_loss = 0.0
    for batch in tqdm(train_loader):
        batch.to(DEVICE)
        y = batch.y
        optimizer.zero_grad()
        output = global_mean_pool(model(batch), batch.batch)
        loss = loss_functin(output, y)
        batch_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    print("TRAIN_LOSS", batch_loss / len(train_ids))

    model.eval()
    batch_loss = 0.0
    for batch in tqdm(test_loader):
        batch.to(DEVICE)
        y = batch.y
        with torch.no_grad():
            output = global_mean_pool(model(batch), batch.batch)
            loss = loss_functin(output, y)
            
        batch_loss += loss.item()

    print("EVAL LOSS", batch_loss / len(test_ids))

    

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [11:36<00:00, 46.43s/it]


TRAIN_LOSS 1.6868955225643525


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:45<00:00, 11.41s/it]


EVAL LOSS 1.498822530110677


 20%|████████████████████████████████████▍                                                                                                                                                 | 3/15 [02:18<09:01, 45.14s/it]