In [77]:
import os
import math
import time
import json
import torch
import random
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import DataLoader
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.nn import GATConv, GraphNorm, global_mean_pool
from sklearn.metrics import r2_score

In [40]:
class PDBBindDataset(InMemoryDataset):
    def __init__(
        self,
        metadata_csv = "data/processed/refined_dataset_metadata.csv",
        complex_dir = "data/processed/complex_graphs",
        split_dir = "data/processed/splits",
        root = "data/processed/full_dataset",
        transform=None,
        pre_transform=None,
        force_rebuild=False,
        split=None
    ):
        self.metadata_csv = Path(metadata_csv)
        self.complex_dir = Path(complex_dir)
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)
        self.force_rebuild = force_rebuild
        self.split_dir=Path(split_dir)
        self.split=split
        
        split_suffix = f"_{split}" if split else ""
        self.cache_file = self.root / f"complex_dataset{split_suffix}.pt"

        super().__init__(self.root, transform, pre_transform)

        if self.cache_file.exists() and not self.force_rebuild:
            print(f"Loading cached dataset from {self.cache_file}")
            self.data, self.slices = torch.load(self.cache_file, weights_only=False)
        else:
            print("Processing complex dataset...")
            self.process()
            self.data, self.slices = torch.load(self.cache_file)

    @property
    def raw_file_names(self):
        return [self.metadata_csv.name]

    @property
    def processed_file_names(self):
        return [self.cache_file.name]

    def process(self):
        df = pd.read_csv(self.metadata_csv)
        if self.split:
            print(self.split_dir, self.split)
            split_file = self.split_dir / f"{self.split}_ids.txt"
            if split_file.exists():
                split_ids = set(open(split_file).read().splitlines())
                df = df[df["complex_id"].astype(str).isin(split_ids)]
            else:
                print(f"split not found: {split_file}")

        data_list = []
        skipped = 0


        for _, row in df.iterrows():
            cid = row["complex_id"]
            affinity = float(row["affinity"])
            graph_file = self.complex_dir / f"{cid}.pt"

            if not graph_file.exists():
                print(f"Complex graph missing for {cid}")
                skipped += 1
                continue

            try:
                graph_data = torch.load(graph_file)
                graph_data.y = torch.tensor([affinity], dtype=torch.float32)
                graph_data.complex_id = cid
                data_list.append(graph_data)
            except Exception as e:
                print(f"Failed to load complex {cid}: {e}")
                skipped += 1
                continue

            if len(data_list) % 200 == 0:
                print(f"Loaded {len(data_list)} complex graphs...")

        print(f"Finished. Loaded {len(data_list)} complexes. Skipped: {skipped}")

        if len(data_list) == 0:
            raise RuntimeError("No valid complex graphs loaded. Check paths and files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.cache_file)
        print(f"Saved processed dataset to {self.cache_file}")

    def len(self):
        return super().len()

    def get(self, idx):
        return super().get(idx)

In [57]:
class CoreGNN(nn.Module):
    def __init__(
        self,
        node_dim: int,
        edge_dim: int,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.2,
    ):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            GATConv(
                in_channels=node_dim if i == 0 else hidden_dim,
                out_channels=hidden_dim,
                edge_dim=edge_dim,
                heads=4,
                concat=False
            )
            for i in range(num_layers)
        ])
        self.encoder_norms = nn.ModuleList([GraphNorm(hidden_dim) for _ in range(num_layers)])

        self.cross_layers = nn.ModuleList([
            GATConv(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=edge_dim,
                heads=4,
                concat=False
            )
            for _ in range(num_layers)
        ])
        self.cross_norms = nn.ModuleList([GraphNorm(hidden_dim) for _ in range(num_layers)])

        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.dropout = dropout
        self.hidden_dim = hidden_dim
        
    def forward(self, data):
        x, edge_index, edge_attr, batch, node_type = (
            data.x, data.edge_index, data.edge_attr, data.batch, data.node_type
        )
        src, dst = edge_index
#         print(batch)
#         print('src', src)
#         print('dst', dst)
#         print('node_type', node_type)
#         print('len(node_type)', len(node_type))
#         print('max(edge_index)', max(edge_index.flatten()))
#         print(node_type[src])
#         print(node_type[dst])
        same_mask = (node_type[src] == node_type[dst])
#         print(edge_attr)
        cross_mask = (node_type[src] != node_type[dst])

        intra_edge_index = edge_index[:, same_mask]
        intra_edge_attr = edge_attr[same_mask]
        cross_edge_index = edge_index[:, cross_mask]
        cross_edge_attr = edge_attr[cross_mask]

        # ligand + protein separately
        for conv, norm in zip(self.encoder_layers, self.encoder_norms):
            x = conv(x, intra_edge_index, intra_edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # ligand <-> protein)
        for conv, norm in zip(self.cross_layers, self.cross_norms):
            x = conv(x, cross_edge_index, cross_edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # pooling
        out = global_mean_pool(x, batch)
        
        # regression
        y_hat = self.readout(out)
#         print(y_hat)
        return y_hat.squeeze(1)


In [12]:
dataset = PDBBindDataset()

Processing...


Complex graph missing for 2r58
Complex graph missing for 3c2f
Complex graph missing for 3g2y
Complex graph missing for 3pce
Complex graph missing for 4qsu
Complex graph missing for 4qsv
Complex graph missing for 4u54
Complex graph missing for 3ao4
Complex graph missing for 4cs9
Complex graph missing for 2w8w
Complex graph missing for 3gv9
Complex graph missing for 6r9u
Complex graph missing for 6abx
Complex graph missing for 4q90
Complex graph missing for 5cs3
Complex graph missing for 4tim
Complex graph missing for 5fe6
Complex graph missing for 6ghj
Complex graph missing for 3gqz
Complex graph missing for 4y3j
Complex graph missing for 5oxk
Complex graph missing for 5z5f
Complex graph missing for 4ahr
Complex graph missing for 4ahs
Complex graph missing for 4mre
Complex graph missing for 1x8d
Complex graph missing for 4g0z
Complex graph missing for 1m0n
Complex graph missing for 2aac
Complex graph missing for 4aci
Complex graph missing for 4ury
Complex graph missing for 3ao5
Complex 

Done!


AttributeError: module 'torch.serialization' has no attribute 'safe_globals'

In [217]:
len(dataset)

24

In [218]:
node_dim = 22
edge_dim = 5

model = CoreGNN(node_dim=node_dim, edge_dim=edge_dim, hidden_dim=128)

In [219]:
dataloader = DataLoader(dataset, batch_size=4)



In [220]:
for batch in dataloader:
    model(batch)
    break

tensor([0, 0, 0,  ..., 3, 3, 3])
src tensor([   0,   11,    0,  ..., 2512, 2367, 2643])
dst tensor([  11,    0,   10,  ..., 2367, 2643, 2367])
node_type tensor([0, 0, 0,  ..., 1, 1, 1])
len(node_type) 3132
max(edge_index) tensor(3131)
tensor([0, 0, 0,  ..., 1, 0, 1])
tensor([0, 0, 0,  ..., 0, 1, 0])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [4.0560, 0.0000, 0.0000, 0.0000, 0.0000],
        [4.3776, 0.0000, 0.0000, 0.0000, 0.0000],
        [4.3776, 0.0000, 0.0000, 0.0000, 0.0000]])
tensor([[-0.0053],
        [ 0.1294],
        [ 0.0473],
        [ 0.1194]], grad_fn=<AddmmBackward0>)


In [78]:
SEED = 42

def set_seed(seed=SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((np.array(y_true) - np.array(y_pred)) ** 2)))

def pearson_r(y_true, y_pred):
    if len(y_true) < 2:
        return 0.0
    c = np.corrcoef(y_true, y_pred)
    if np.isnan(c).any():
        return 0.0
    return float(c[0,1])

In [64]:
def ensure_splits(metadata_csv="data/processed/refined_dataset_metadata.csv",
                  splits_dir="data/processed/splits",
                  seed=SEED,
                  ratios=(0.8, 0.1, 0.1)):
    """
    Ensure split files exist. If not, create deterministic random scaffold split.
    Currently: simple random split reproducible by seed.
    Writes three files: train_ids.txt, val_ids.txt, test_ids.txt
    """
    splits_dir = Path(splits_dir)
    splits_dir.mkdir(parents=True, exist_ok=True)

    train_file = splits_dir / "train_ids.txt"
    val_file = splits_dir / "val_ids.txt"
    test_file = splits_dir / "test_ids.txt"

    if train_file.exists() and val_file.exists() and test_file.exists():
        print("splits already available")
        return str(train_file), str(val_file), str(test_file)

    print("no split files, creating new one")
    df_meta = pd.read_csv(metadata_csv)
    ids = list(df_meta["complex_id"].astype(str).tolist())

    random.Random(seed).shuffle(ids)
    n = len(ids)
    n_train = int(ratios[0] * n)
    n_val = int(ratios[1] * n)
    train_ids = ids[:n_train]
    val_ids = ids[n_train:n_train + n_val]
    test_ids = ids[n_train + n_val:]

    train_file.write_text("\n".join(train_ids))
    val_file.write_text("\n".join(val_ids))
    test_file.write_text("\n".join(test_ids))

    print(f"new splits: train->{len(train_ids)}, val->{len(val_ids)}, test->{len(test_ids)}")
    return str(train_file), str(val_file), str(test_file)

In [73]:
def evaluate(model, loader):
    model.eval()
    ys, preds = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch)
            out = out.numpy().reshape(-1)
            y = batch.y.numpy().reshape(-1)
            preds.append(out)
            ys.append(y)
    if len(preds) == 0:
        return {}
    preds = np.concatenate(preds)
    ys = np.concatenate(ys)
    metrics = {
        "rmse": rmse(ys, preds),
        "r2": float(r2_score(ys, preds)) if len(ys) > 1 else 0.0,
        "pearson": pearson_r(ys, preds)
    }
    return metrics

In [74]:
def train_one_epoch(model, loader, optim):
    model.train()
    total_loss = 0.0
    total_graphs = 0
    
    for batch in loader:
        optim.zero_grad()
        out = model(batch)
        target = batch.y.view(-1).to(out.dtype)
        loss = torch.nn.functional.mse_loss(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optim.step()

        total_loss += loss.item() * batch.num_graphs
        total_graphs += batch.num_graphs
    return total_loss / (total_graphs + 1e-12)

In [79]:
def main(
    metadata_csv="data/processed/refined_dataset_metadata.csv",
    complex_dir="data/processed/complex_graphs",
    split_dir="data/processed/splits",
    output_dir="experiments/runs/coregnn",
    epochs=50,
    batch_size=8,
    lr=1e-4,
    weight_decay=1e-5,
    hidden_dim=128,
    encoder_layers=3,
    seed=SEED,
    patience=10
):
    set_seed(seed)
    writer = SummaryWriter(log_dir=output_dir)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    ensure_splits(metadata_csv=metadata_csv, splits_dir=split_dir, seed=seed)

    train_ds = PDBBindDataset(metadata_csv=metadata_csv,
                                     complex_dir=complex_dir,
                                     split_dir=split_dir,
                                     root="data/processed/complex_dataset",
                                     split="train",
                                     force_rebuild=False)
    val_ds = PDBBindDataset(metadata_csv=metadata_csv,
                                   complex_dir=complex_dir,
                                   split_dir=split_dir,
                                   root="data/processed/complex_dataset",
                                   split="val",
                                   force_rebuild=False)
    test_ds = PDBBindDataset(metadata_csv=metadata_csv,
                                    complex_dir=complex_dir,
                                    split_dir=split_dir,
                                    root="data/processed/complex_dataset",
                                    split="test",
                                    force_rebuild=False)

    print(f"datasets: train-> {len(train_ds)}, val-> {len(val_ds)}, test-> {len(test_ds)}")

    # dataloaders
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # infer dims from a sample
    sample = train_ds.get(0)
    in_dim = sample.x.shape[1]
    # infer edge_dim if present
    edge_dim = sample.edge_attr.shape[1] if hasattr(sample, "edge_attr") and sample.edge_attr is not None else 0

    model = CoreGNN(in_dim, edge_dim, hidden_dim=hidden_dim, num_layers=encoder_layers)
    optim = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optim, mode="min", factor=0.5, patience=3, verbose=True)

    best_val = float("inf")
    best_epoch = -1
    early_stop_counter = 0

    for epoch in range(1, epochs + 1):
        t0 = time.time()
        train_loss = train_one_epoch(model, train_loader, optim)
        val_metrics = evaluate(model, val_loader)
        test_metrics = evaluate(model, test_loader)

        print(f"epoch {epoch:03d} | train_loss {train_loss:.4f} | val_rmse {val_metrics.get('rmse',float('nan')):.4f} | test_rmse {test_metrics.get('rmse',float('nan')):.4f} | time {time.time()-t0:.1f}s")
        writer.add_scalar("loss/train", train_loss, epoch)
        writer.add_scalar("rmse/val", val_metrics.get("rmse", np.nan), epoch)
        writer.add_scalar("rmse/test", test_metrics.get("rmse", np.nan), epoch)
        writer.add_scalar("r2/val", val_metrics.get("r2", np.nan), epoch)
        writer.add_scalar("pearson/val", val_metrics.get("pearson", np.nan), epoch)
        scheduler.step(val_metrics.get("rmse", math.nan))

        val_rmse = val_metrics.get("rmse", float("inf"))
        if val_rmse < best_val:
            best_val = val_rmse
            best_epoch = epoch
            early_stop_counter = 0
            ckpt_path = Path(output_dir) / "best_model.pt"
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optim.state_dict(),
                "val_rmse": val_rmse
            }, ckpt_path)
            print(f"best model -> {ckpt_path} (val_rmse={val_rmse:.4f})")
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            break

    ckpt_path = Path(output_dir) / "best_model.pt"
    if ckpt_path.exists():
        print("loading model...")
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt["model_state"])

    final_test_metrics = evaluate(model, test_loader)
    print(f"test metrics: RMSE->{final_test_metrics.get('rmse',np.nan):.4f}, r2->{final_test_metrics.get('r2',np.nan):.4f}, Pearson->{final_test_metrics.get('pearson',np.nan):.4f}")

    Path(output_dir).joinpath("results.json").write_text(json.dumps({
        "final_test": final_test_metrics,
        "best_val_rmse": best_val,
        "best_epoch": best_epoch
    }, indent=2))

    writer.close()
    print("done")

In [80]:
main(epochs=50, batch_size=4, lr=1e-4, hidden_dim=128, encoder_layers=3, patience=10)

Split files found; using existing splits.
data/processed/splits train
Using 4252 complexes from split 'train'.
Complex graph missing for 2r58
Complex graph missing for 3c2f
Complex graph missing for 3g2y
Complex graph missing for 4qsv
Complex graph missing for 4u54
Complex graph missing for 3ao4
Complex graph missing for 4cs9
Complex graph missing for 3gv9
Complex graph missing for 6r9u
Complex graph missing for 6abx
Complex graph missing for 5cs3
Complex graph missing for 4tim
Complex graph missing for 5fe6
Complex graph missing for 3gqz
Complex graph missing for 4y3j
Complex graph missing for 5oxk
Complex graph missing for 4ahr
Complex graph missing for 4ahs
Complex graph missing for 4mre
Complex graph missing for 1x8d
Complex graph missing for 1m0n
Complex graph missing for 2aac
Complex graph missing for 4ury
Complex graph missing for 3ao5
Complex graph missing for 3i3b
Complex graph missing for 5m9w
Complex graph missing for 1ew8
Complex graph missing for 4np2
Complex graph missing

Processing...
Done!
Processing...


Complex graph missing for 4o6w
Complex graph missing for 3r5t
Complex graph missing for 5upf
Complex graph missing for 6gl8
Complex graph missing for 3ryz
Complex graph missing for 1t31
Complex graph missing for 2wgj
Complex graph missing for 3eqr
Complex graph missing for 5nk4
Complex graph missing for 6dj7
Complex graph missing for 2wtv
Complex graph missing for 4ei4
Complex graph missing for 5wal
Complex graph missing for 1w5w
Complex graph missing for 3fvk
Complex graph missing for 6dz2
Complex graph missing for 3egt
Complex graph missing for 1ec1
Complex graph missing for 4heg
Complex graph missing for 5hcy
Complex graph missing for 4bam
Complex graph missing for 2p4j
Complex graph missing for 3i9g
Complex graph missing for 5een
Complex graph missing for 5wqc
Complex graph missing for 2j4i
Complex graph missing for 2wyj
Complex graph missing for 3uri
Complex graph missing for 4h3f
Complex graph missing for 1qkt
Complex graph missing for 1bnw
Complex graph missing for 3ibu
Complex 

Done!
Processing...
Done!


epoch 001 | train_loss 35.6550 | val_rmse 6.7900 | test_rmse 5.9686 | time 1.4s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.7900)
epoch 002 | train_loss 35.0367 | val_rmse 6.7088 | test_rmse 5.9253 | time 1.3s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.7088)
epoch 003 | train_loss 34.3232 | val_rmse 6.6380 | test_rmse 5.8678 | time 1.3s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.6380)
epoch 004 | train_loss 33.7979 | val_rmse 6.5820 | test_rmse 5.8136 | time 1.2s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.5820)
epoch 005 | train_loss 33.3790 | val_rmse 6.5320 | test_rmse 5.7572 | time 1.3s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.5320)
epoch 006 | train_loss 32.5149 | val_rmse 6.4858 | test_rmse 5.7054 | time 1.5s
best model -> experiments/runs/coregnn/best_model.pt (val_rmse=6.4858)
epoch 007 | train_loss 32.1128 | val_rmse 6.4387 | test_rmse 5.6614 | time 1.3s
best model -> 