In [2]:
from google.colab import drive
drive.mount('/content/drive')

import torch
print("Torch:", torch.__version__, "CUDA:", torch.cuda.is_available())

#Install pytorch geometric (Colab-friendly)
!pip -q install torch_geometric

import torch_geometric
print("PyG:", torch_geometric.__version__)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Torch: 2.9.0+cpu CUDA: False
PyG: 2.7.0


In [5]:
# load labeled .npz graph and convert to PyG Data
from pathlib import Path
import numpy as np
from torch_geometric.data import Data

BASE = Path("/content/drive/MyDrive/biolip_gnn")
LABELED_DIR = BASE / "graphs_labeled_v3"

npz_files = sorted(LABELED_DIR.glob("*.npz"))
print("Labeled graphs found:", len(npz_files))

def load_npz_as_pyg(path: Path) -> Data:
  z = np.load(path, allow_pickle = True)

  x_idx = torch.tensor(z["x_idx"], dtype=torch.long)
  edge_index = torch.tensor(z["edge_index"], dtype=torch.long)
  y = torch.tensor(z["y"], dtype=torch.long)

  #Optional edge feature -> distance
  edge_attr = None
  if "edge_dist" in z.files:
    edge_attr = torch.tensor(z["edge_dist"], dtype=torch.float).view(-1, 1)

  data = Data(
      x = x_idx.view(-1, 1),
      edge_index = edge_index,
      edge_attr = edge_attr,
      y = y
  )
  data.pdb_id = str(z["pdb_id"])
  data.chain = str(z["chain"])
  return data

dataset = [load_npz_as_pyg(p) for p in npz_files]
print("Loaded dataset size:", len(dataset))
print("Example:", dataset[0], dataset[0].pdb_id, dataset[0].chain)




Labeled graphs found: 50
Loaded dataset size: 50
Example: Data(x=[387, 1], edge_index=[2, 4420], edge_attr=[4420, 1], y=[387], pdb_id='1KMM', chain='C') 1KMM C


In [7]:
# train/val/test split + class imbalance stats

import random
random.seed(42)

random.shuffle(dataset)

n = len(dataset)
n_train = int(0.7 * n)
n_val = int(0.15 * n)
train_set = dataset[:n_train]
val_set = dataset[n_train:n_train+n_val]
test_set = dataset[n_train+n_val:]

print("Split:", len(train_set), len(val_set), len(test_set))

def count_labels(ds):
  pos = 0
  tot = 0
  for d in ds:
    pos += int(d.y.sum().item())
    tot += int(d.y.numel())
  return pos, tot

pos_train, tot_train = count_labels(train_set)
neg_train = tot_train - pos_train
print("Train positives:", pos_train, "Train total:", tot_train, "Pos rate:", pos_train/tot_train)

# Weight for BCEWithLogitsLoss: pos_weight = neg/pos
pos_weight = torch.tensor([neg_train / max(pos_train, 1)], dtype = torch.float)
pos_weight



Split: 35 7 8
Train positives: 457 Train total: 8721 Pos rate: 0.05240224744868708


tensor([18.0832])

In [10]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

print("Batches:", len(train_loader), len(val_loader), len(test_loader))


Batches: 9 2 2


In [14]:
# defining a tiny baseline GNN for node classification

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class BaselineSAGE(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)

        self.conv1 = SAGEConv(emb_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)

        self.lin1 = nn.Linear(hidden, hidden)
        self.lin2 = nn.Linear(hidden, 1)

    def forward(self, data):
        x = data.x.squeeze(-1)
        x = self.emb(x)

        x = self.conv1(x, data.edge_index)
        x = F.relu(x)
        x = self.conv2(x, data.edge_index)
        x = F.relu(x)

        x = F.relu(self.lin1(x))
        logits = self.lin2(x).squeeze(-1)
        return logits




In [15]:
# train loop + metric
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BaselineSAGE().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

def run_epoch(loader, train=False):
    model.train() if train else model.eval()

    total_loss = 0.0
    tp = fp = fn = 0

    for batch in loader:
        batch = batch.to(device)

        with torch.set_grad_enabled(train):
            logits = model(batch)
            y = batch.y.float()

            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        total_loss += float(loss.item())

        probs = torch.sigmoid(logits)
        preds = (probs >= 0.5).long()

        tp += int(((preds == 1) & (batch.y == 1)).sum().item())
        fp += int(((preds == 1) & (batch.y == 0)).sum().item())
        fn += int(((preds == 0) & (batch.y == 1)).sum().item())

    precision = tp / (tp + fp + 1e-9)
    recall    = tp / (tp + fn + 1e-9)
    avg_loss  = total_loss / max(len(loader), 1)

    return avg_loss, precision, recall

EPOCHS = 10
for epoch in range(1, EPOCHS+1):
    tr = run_epoch(train_loader, train=True)
    va = run_epoch(val_loader, train=False)
    print(f"Epoch {epoch:02d} | train loss {tr[0]:.4f} P {tr[1]:.3f} R {tr[2]:.3f} "
          f"| val loss {va[0]:.4f} P {va[1]:.3f} R {va[2]:.3f}")


Epoch 01 | train loss 1.3236 P 0.000 R 0.000 | val loss 1.4011 P 0.000 R 0.000
Epoch 02 | train loss 1.3252 P 0.088 R 0.098 | val loss 1.3929 P 0.092 R 0.238
Epoch 03 | train loss 1.3018 P 0.083 R 0.333 | val loss 1.3875 P 0.091 R 0.343
Epoch 04 | train loss 1.2964 P 0.079 R 0.495 | val loss 1.3758 P 0.081 R 0.552
Epoch 05 | train loss 1.3113 P 0.084 R 0.455 | val loss 1.3685 P 0.091 R 0.505
Epoch 06 | train loss 1.2970 P 0.080 R 0.650 | val loss 1.3526 P 0.073 R 0.771
Epoch 07 | train loss 1.2313 P 0.081 R 0.630 | val loss 1.3534 P 0.090 R 0.543
Epoch 08 | train loss 1.2331 P 0.088 R 0.600 | val loss 1.3358 P 0.087 R 0.638
Epoch 09 | train loss 1.1822 P 0.085 R 0.678 | val loss 1.3378 P 0.092 R 0.581
Epoch 10 | train loss 1.2057 P 0.094 R 0.648 | val loss 1.3262 P 0.087 R 0.581


In [16]:
test_loss, test_p, test_r = run_epoch(test_loader, train=False)
print("TEST | loss:", round(test_loss,4), "precision:", round(test_p,3), "recall:", round(test_r,3))

SAVE_PATH = BASE / "out" / "day6_baseline_sage.pt"
torch.save(model.state_dict(), SAVE_PATH)
print("Saved model to:", SAVE_PATH)

TEST | loss: 1.1691 precision: 0.079 recall: 0.604
Saved model to: /content/drive/MyDrive/biolip_gnn/out/day6_baseline_sage.pt
