In [57]:
# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import DataLoader

# OGB
from ogb.graphproppred import PygGraphPropPredDataset

# Utils
import tqdm

In [None]:
# Useful references: a Colab notebook and the OGB graph property prediction (ogbg-mol*) docs
# https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb#scrollTo=HvhgQoO8Svw4
# https://ogb.stanford.edu/docs/graphprop/#ogbg-mol

### Data loading

In [58]:
# Mini-batch size shared by the train / validation / test loaders.
BATCH_SIZE = 32

# OGB graph property prediction dataset "ogbg-molhiv".
dataset = PygGraphPropPredDataset(name="ogbg-molhiv")

# Index split shipped with the dataset, keyed by "train" / "valid" / "test".
split_idx = dataset.get_idx_split()

# Only the training loader shuffles; the evaluation loaders keep a fixed order.
# All three use BATCH_SIZE so that changing the constant actually takes effect.
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=BATCH_SIZE, shuffle=False)



In [71]:
# Quick look at the data: the first graph, the node-feature width and the
# label tensor, one per line, followed by two summary figures.
print(dataset[0], dataset.num_node_features, dataset.y, sep="\n")
print(f"Number of graphs: {len(dataset)}")
# The value shown under "Class balance" is the sum over the label tensor.
print(f"Class balance: {dataset.y.sum()}")

Data(edge_index=[2, 40], edge_attr=[40, 3], x=[19, 9], y=[1, 1], num_nodes=19)
9
tensor([[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]])
Number of graphs: 41127
Class balance: 1443


### GCN model

In [64]:
class GCN(torch.nn.Module):
    """Two graph-convolution layers, mean pooling and a linear output layer.

    Args:
        hidden_channels: Output width of both ``GCNConv`` layers.
        in_channels: Input width of the first ``GCNConv`` layer. When ``None``
            (default) it is read from the notebook-global
            ``dataset.num_node_features``.
        out_channels: Output width of the final linear layer. When ``None``
            (default) it is read from the notebook-global
            ``dataset.num_classes``.
    """

    def __init__(self, hidden_channels, in_channels=None, out_channels=None):
        super().__init__()
        # Only fall back to the global dataset when a size is not supplied, so
        # the model can also be constructed without that global being defined.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Compute the model output for a batch of graphs.

        Args:
            x: Node features, passed to the first convolution.
            edge_index: Graph connectivity, passed to both convolutions.
            batch: Batch vector handed to ``global_mean_pool`` for the readout.

        Returns:
            The output of the final linear layer applied to the pooled features.
        """
        # Conv layers (ReLU only between them, none after the second one)
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        # Readout layer: pool node embeddings, then project to the outputs
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return x

In [65]:
def train(model, optimizer, loss_fun, loader=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fun: Callable ``loss_fun(output, target)`` returning a scalar loss.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``. When ``None`` (default) the notebook-global
            ``train_loader`` is used, as before.

    Returns:
        The mean of the per-batch loss values as a float, or ``0.0`` when the
        loader yields no batches.
    """
    if loader is None:
        loader = train_loader

    # validate() switches the model to eval mode, so set train mode here rather
    # than relying on the caller to have done it before every epoch.
    model.train()

    total_loss = 0.0
    num_batches = 0
    for data in loader:
        # Clear gradients first so nothing stale leaks into this step.
        optimizer.zero_grad()
        out = model(data.x.float(), data.edge_index, data.batch)
        # Targets are flattened to 1-D before being passed to the loss.
        loss = loss_fun(out, torch.reshape(data.y, (-1,)))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

In [66]:
def validate(loader, model):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and is left in eval mode afterwards.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``; it must also expose ``dataset`` so that
            ``len(loader.dataset)`` gives the number of samples.
        model: Module called as ``model(x, edge_index, batch)``; the predicted
            class is the argmax of its output along dimension 1.

    Returns:
        Number of correct predictions divided by ``len(loader.dataset)``.

    Raises:
        ZeroDivisionError: If ``loader.dataset`` is empty.
    """
    model.eval()
    correct = 0
    # Evaluation needs no gradients; disabling autograd avoids building a graph
    # for every batch.
    with torch.no_grad():
        for data in loader:
            out = model(data.x.float(), data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            # Targets are flattened to 1-D so they compare element-wise with pred.
            correct += int((pred == torch.reshape(data.y, (-1,))).sum())
    acc = correct / len(loader.dataset)
    return acc

In [67]:
# Settings for this run, named so they are easy to find and change.
SEED = 0
HIDDEN_CHANNELS = 4
LEARNING_RATE = 0.01
NUM_EPOCHS = 20

# Seed torch's global RNG so that weight initialisation (and anything else
# drawing from that RNG) is repeatable between runs of this cell.
torch.manual_seed(SEED)

model = GCN(HIDDEN_CHANNELS)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fun = nn.CrossEntropyLoss()

model.train()
for epoch in tqdm.tqdm(range(NUM_EPOCHS)):
    train(model, optimizer, loss_fun)

# Accuracy is reported once, after the final epoch; `epoch` here is the last
# loop index. validate() leaves the model in eval mode.
print(f"Epoch: {epoch}, Train ACC: {validate(train_loader, model)}, Val ACC: {validate(valid_loader, model)}")
    

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [05:50<00:00, 17.53s/it]


Epoch: 19, Train ACC: 0.9625239354426917, Val ACC: 0.9803063457330415


In [68]:
# Accuracy on the test split. The full dataset has 1443 positive labels out of
# 41127 graphs, so read this figure with that imbalance in mind.
test_acc = validate(test_loader, model)
print(f"Test ACC: {test_acc}")

Test ACC: 0.9683929005592026


In [None]:
# The evaluation metric for this dataset is also described here (or simply use ROC-AUC):
# https://ogb.stanford.edu/docs/graphprop/#ogbg-mol