## Multilabel classification on Protein-Protein Interaction with GraphSAGE

In [2]:
import torch
!pip install -qU torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html

torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
import torch
from sklearn.metrics import f1_score

from torch_geometric.datasets import PPI
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Load training, validation, and test sets
train_dataset = PPI(root=".", split='train')
val_dataset = PPI(root=".", split='val')
test_dataset = PPI(root=".", split='test')

Downloading https://data.dgl.ai/dataset/ppi.zip
Extracting ./ppi.zip
Processing...
Done!


In [12]:
# Unify the training graphs and apply neighbor sampling
train_data = Batch.from_data_list(train_dataset)
train_loader = NeighborLoader(train_data, batch_size=2048,
                              shuffle=True, num_neighbors=[20, 10],
                              num_workers=2, persistent_workers=True)


In [13]:
# Validation and test loaders (one datapoint corresponds to a graph)
val_loader = DataLoader(val_dataset, batch_size=2)
test_loader = DataLoader(test_dataset, batch_size=2)

In [14]:
# Setup the GraphSAGE model
model = GraphSAGE(
    in_channels=train_dataset.num_features,
    hidden_channels=512,
    num_layers=2,
    out_channels=train_dataset.num_classes).to(device)

model

GraphSAGE(50, 121, num_layers=2)

In [15]:
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [16]:
def fit(loader):
    model.train()

    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = loss_fn(out, data.y)
        total_loss += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    return total_loss / len(loader.data)

@torch.no_grad()
def test(loader):
    model.eval()

    data = next(iter(loader))
    out = model(data.x.to(device), data.edge_index.to(device))
    preds = (out > 0).float().cpu()

    y, pred = data.y.numpy(), preds.numpy()
    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0

In [17]:
for epoch in range(301):
    loss = fit(train_loader)
    val_f1 = test(val_loader)
    if epoch % 50 == 0:
        print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Val F1-score: {val_f1:.4f}')

print(f'Test F1-score: {test(test_loader):.4f}')

Epoch   0 | Train Loss: 12.700 | Val F1-score: 0.4963
Epoch  50 | Train Loss: 8.742 | Val F1-score: 0.7970
Epoch 100 | Train Loss: 8.604 | Val F1-score: 0.8145
Epoch 150 | Train Loss: 8.544 | Val F1-score: 0.8197
Epoch 200 | Train Loss: 8.502 | Val F1-score: 0.8234
Epoch 250 | Train Loss: 8.464 | Val F1-score: 0.8277
Epoch 300 | Train Loss: 8.440 | Val F1-score: 0.8269
Test F1-score: 0.8505
