In [4]:
from google.colab import drive
drive.mount('/content/drive')
print("Drive mounted")

from pathlib import Path
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

print("Imports done")
print("Torch:", torch.__version__)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted
Imports done
Torch: 2.9.0+cpu


In [5]:
BASE = Path("/content/drive/MyDrive/biolip_gnn")

GRAPH_DIR = BASE / "graphs_labeled_v6_feat837"
OUT_DIR = BASE / "out"
OUT_DIR.mkdir(exist_ok=True, parents=True)

npz_files = sorted(GRAPH_DIR.glob("*.npz"))
print("Graphs found:", len(npz_files))
assert len(npz_files) > 0, "No graphs found. Fix GRAPH_DIR path."

def load_npz(p):
    z = np.load(p, allow_pickle=True)
    return {k: z[k] for k in z.files}

# quick schema check
z0 = load_npz(npz_files[0])
print("Keys:", list(z0.keys()))
print("Has edge_dist:", "edge_dist" in z0)
print("Has x_feat:", "x_feat" in z0)


Graphs found: 837
Keys: ['pdb_id', 'chain', 'row_idx', 'ligand_code', 'n_nodes', 'resseq', 'x_idx', 'x_feat', 'edge_index', 'edge_dist', 'y', 'label_mode']
Has edge_dist: True
Has x_feat: True


In [6]:
!pip -q install torch_geometric -U

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [10]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from sklearn.metrics import precision_recall_curve, average_precision_score

def npz_to_data_baseline(g):
    # required
    edge_index = torch.tensor(g["edge_index"], dtype=torch.long)
    y = torch.tensor(g["y"], dtype=torch.long)
    aa = torch.tensor(g["x_idx"], dtype=torch.long)   # (N,)

    # degree (cheap)
    N = aa.numel()
    deg = torch.zeros(N, dtype=torch.float)
    deg.scatter_add_(0, edge_index[0], torch.ones(edge_index.shape[1]))
    deg = (deg - deg.mean()) / (deg.std() + 1e-9)

    # x = [aa_idx, deg]
    x = torch.cat([aa.view(-1,1).float(), deg.view(-1,1)], dim=1)  # (N,2)
    return Data(x=x, edge_index=edge_index, y=y)

dataset = [npz_to_data_baseline(load_npz(p)) for p in npz_files]
print("Built dataset:", len(dataset))


Built dataset: 837


In [17]:
# Model + helpers (training + threshold)

class SAGE_NodeClassifier(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)
        in_dim = emb_dim + extra_feats
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1  = nn.Linear(hidden, hidden)
        self.lin2  = nn.Linear(hidden, 1)

    def forward(self, data):
        aa_idx = data.x[:,0].long()
        xf = data.x[:,1:].float()
        h = self.emb(aa_idx)
        h = torch.cat([h, xf], dim=1)
        h = F.relu(self.conv1(h, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        h = F.relu(self.lin1(h))
        return self.lin2(h).squeeze(-1)

def split_dataset(ds, seed):
    ds = ds.copy()
    random.Random(seed).shuffle(ds)
    n = len(ds)
    ntr = int(0.70*n)
    nva = int(0.15*n)
    return ds[:ntr], ds[ntr:ntr+nva], ds[ntr+nva:]

def compute_pos_weight(graphs):
    pos = sum(int(d.y.sum()) for d in graphs)
    tot = sum(int(d.y.numel()) for d in graphs)
    neg = tot - pos
    return torch.tensor([neg/max(pos,1)], dtype=torch.float)

@torch.no_grad()
def collect_probs(model, loader, device):
    model.eval()
    P, Y = [], []
    for b in loader:
        b = b.to(device)
        p = torch.sigmoid(model(b)).cpu().numpy()
        y = b.y.cpu().numpy()
        P.append(p); Y.append(y)
    return np.concatenate(P), np.concatenate(Y)

def thr_maxf1(probs, y):
    prec, rec, thr = precision_recall_curve(y, probs)
    f1 = (2*prec[:-1]*rec[:-1])/(prec[:-1]+rec[:-1]+1e-9)
    return float(thr[int(np.argmax(f1))])

def thr_precision_target(probs, y, target=0.20):
    prec, rec, thr = precision_recall_curve(y, probs)
    prec2, rec2, thr2 = prec[:-1], rec[:-1], thr
    ok = np.where(prec2 >= target)[0]
    if len(ok)==0:
        return None
    # choose threshold among those that maximizes recall
    best = ok[np.argmax(rec2[ok])]
    return float(thr2[best])

def prf_at_thr(probs, y, thr):
    pred = (probs >= thr).astype(int)
    tp = int(((pred==1) & (y==1)).sum())
    fp = int(((pred==1) & (y==0)).sum())
    fn = int(((pred==0) & (y==1)).sum())
    prec = tp/(tp+fp+1e-9)
    rec  = tp/(tp+fn+1e-9)
    f1   = 2*prec*rec/(prec+rec+1e-9)
    return prec, rec, f1





In [19]:
# train baseline across seeds and save report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seeds = [1, 7, 42, 123, 999]

rows = []
for sd in seeds:
    train_set, val_set, test_set = split_dataset(dataset, sd)

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    model = SAGE_NodeClassifier(extra_feats=1).to(device)
    pos_w = compute_pos_weight(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pos_w)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # short train (keep lightweight)
    for epoch in range(6):
        model.train()
        for b in train_loader:
            b = b.to(device)
            loss = crit(model(b), b.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    # thresholds from validation
    val_probs, val_y = collect_probs(model, val_loader, device)
    thr_f1 = thr_maxf1(val_probs, val_y)

    thr_p20 = thr_precision_target(val_probs, val_y, 0.20)
    thr_p15 = thr_precision_target(val_probs, val_y, 0.15)

    # evaluate on test
    test_probs, test_y = collect_probs(model, test_loader, device)
    test_auprc = float(average_precision_score(test_y, test_probs))

    P1,R1,F1 = prf_at_thr(test_probs, test_y, thr_f1)

    if thr_p20 is None:
        P20=R20=F20=np.nan
        p20_fallback=True
    else:
        P20,R20,F20 = prf_at_thr(test_probs, test_y, thr_p20)
        p20_fallback=False

    if thr_p15 is None:
        P15=R15=F15=np.nan
        p15_fallback=True
    else:
        P15,R15,F15 = prf_at_thr(test_probs, test_y, thr_p15)
        p15_fallback=False

    rows.append({
        "seed": sd,
        "n_graphs": len(dataset),
        "test_auprc": test_auprc,
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p20,
        "p20_fallback": p20_fallback,
        "val_thr_p15": thr_p15,
        "p15_fallback": p15_fallback,
        "test_P_maxF1": P1, "test_R_maxF1": R1, "test_F1_maxF1": F1,
        "test_P_p20": P20, "test_R_p20": R20, "test_F1_p20": F20,
        "test_P_p15": P15, "test_R_p15": R15, "test_F1_p15": F15,
    })

df_base = pd.DataFrame(rows)
display(df_base)

summary = df_base[["test_auprc","test_P_maxF1","test_R_maxF1","test_F1_maxF1","test_P_p20","test_R_p20","test_F1_p20","test_P_p15","test_R_p15","test_F1_p15"]].agg(["mean","std"])
display(summary)

save_path = OUT_DIR / "day14_baseline_repeat.csv"
df_base.to_csv(save_path, index=False)
print("Saved:", save_path)


Unnamed: 0,seed,n_graphs,test_auprc,val_thr_maxf1,val_thr_p20,p20_fallback,val_thr_p15,p15_fallback,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
0,1,837,0.084099,0.72273,0.819596,False,0.717184,False,0.114198,0.216374,0.149495,0.140909,0.060429,0.084584,0.112023,0.226121,0.149822
1,7,837,0.083421,0.710454,0.895505,False,0.830943,False,0.095843,0.31529,0.147,0.186813,0.016144,0.02972,0.16094,0.08452,0.110834
2,42,837,0.088944,0.78146,0.841317,False,0.775453,False,0.123097,0.179018,0.145882,0.158416,0.076997,0.103627,0.121359,0.192493,0.148865
3,123,837,0.102945,0.742148,0.807495,False,0.754247,False,0.12796,0.238946,0.166667,0.182085,0.116651,0.142202,0.138279,0.219191,0.169578
4,999,837,0.098392,0.737348,,True,0.87556,False,0.123035,0.274321,0.169878,,,,0.182741,0.031551,0.053812


Unnamed: 0,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
mean,0.09156,0.116826,0.24479,0.155784,0.167056,0.067555,0.090033,0.143069,0.150775,0.126582
std,0.008733,0.012738,0.052464,0.011531,0.021405,0.041607,0.046811,0.028945,0.08761,0.045904


Saved: /content/drive/MyDrive/biolip_gnn/out/day14_baseline_repeat.csv


In [20]:
# distance aware rdges and distance buckets

#Distance-weighted GCN version
from torch_geometric.nn import GCNConv
def npz_to_data_gcn_weighted(g):
  edge_index = torch.tensor (g["edge_index"], dtype=torch.long)
  y = torch.tensor (g["y"], dtype=torch.long)
  aa = torch.tensor (g["x_idx"], dtype=torch.long)

  # degree as node feature
  N = aa.numel()
  deg = torch.zeros(N, dtype=torch.float)
  deg.scatter_add_(0, edge_index[0], torch.ones(edge_index.shape[1]))
  deg = (deg - deg.mean()) / (deg.std() + 1e+9)

  x = torch.cat([aa.view(-1, 1).float(), deg.view(-1, 1)], dim=1)

  # edge weights: dist-aware
  dist = torch.tensor(g["edge_dist"], dtype=torch.float)
  w = torch.where(dist > 0, 1.0/(dist +1e-3), torch.ones_like(dist))
  return Data(x=x, edge_index=edge_index, edge_weight=w, y=y)

dataset_gcn = [npz_to_data_gcn_weighted(load_npz(p)) for p in npz_files]
print("BUILT DISTANCE_WEIGHTED DATASET:", len(dataset_gcn))

class GCN_NodeClassifier(nn.Module):
  def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
    super().__init__()
    self.emb = nn.Embedding(num_aa, emb_dim)
    in_dim = emb_dim + extra_feats
    self.conv1 = GCNConv(in_dim, hidden)
    self.conv2 = GCNConv(hidden, hidden)
    self.lin1 = nn.Linear(hidden, hidden)
    self.lin2 = nn.Linear(hidden, 1)

  def forward(self, data):
    aa_idx = data.x[:,0].long()
    xf = data.x[:,1:].float()
    h = self.emb(aa_idx)
    h = torch.cat([h, xf], dim=1)
    h = F.relu(self.conv1(h, data.edge_index, edge_weight=data.edge_weight))
    h = F.relu(self.conv2(h, data.edge_index, edge_weight=data.edge_weight))
    h = F.relu(self.lin1(h))
    return self.lin2(h).squeeze(-1)


BUILT DISTANCE_WEIGHTED DATASET: 837


In [21]:
# train/eval the weighted GCN across seeds

rows = []
for sd in seeds:
    train_set, val_set, test_set = split_dataset(dataset_gcn, sd)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    model = GCN_NodeClassifier(extra_feats=1).to(device)
    pos_w = compute_pos_weight(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pos_w)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    for epoch in range(6):
        model.train()
        for b in train_loader:
            b = b.to(device)
            loss = crit(model(b), b.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    val_probs, val_y = collect_probs(model, val_loader, device)
    thr_f1 = thr_maxf1(val_probs, val_y)
    thr_p20 = thr_precision_target(val_probs, val_y, 0.20)
    thr_p15 = thr_precision_target(val_probs, val_y, 0.15)

    test_probs, test_y = collect_probs(model, test_loader, device)
    test_auprc = float(average_precision_score(test_y, test_probs))

    P1,R1,F1 = prf_at_thr(test_probs, test_y, thr_f1)

    if thr_p20 is None:
        P20=R20=F20=np.nan; p20_fallback=True
    else:
        P20,R20,F20 = prf_at_thr(test_probs, test_y, thr_p20); p20_fallback=False

    if thr_p15 is None:
        P15=R15=F15=np.nan; p15_fallback=True
    else:
        P15,R15,F15 = prf_at_thr(test_probs, test_y, thr_p15); p15_fallback=False

    rows.append({
        "seed": sd,
        "n_graphs": len(dataset_gcn),
        "test_auprc": test_auprc,
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p20,
        "p20_fallback": p20_fallback,
        "val_thr_p15": thr_p15,
        "p15_fallback": p15_fallback,
        "test_P_maxF1": P1, "test_R_maxF1": R1, "test_F1_maxF1": F1,
        "test_P_p20": P20, "test_R_p20": R20, "test_F1_p20": F20,
        "test_P_p15": P15, "test_R_p15": R15, "test_F1_p15": F15,
    })

df_gcn = pd.DataFrame(rows)
display(df_gcn)

summary_gcn = df_gcn[["test_auprc","test_P_maxF1","test_R_maxF1","test_F1_maxF1","test_P_p20","test_R_p20","test_F1_p20","test_P_p15","test_R_p15","test_F1_p15"]].agg(["mean","std"])
display(summary_gcn)

save_path = OUT_DIR / "day14_distanceaware_report.csv"
df_gcn.to_csv(save_path, index=False)
print("Saved:", save_path)


Unnamed: 0,seed,n_graphs,test_auprc,val_thr_maxf1,val_thr_p20,p20_fallback,val_thr_p15,p15_fallback,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
0,1,837,0.0522,0.748725,0.931197,False,0.854477,False,0.06384,0.133528,0.086381,0.083333,0.000975,0.001927,0.087248,0.012671,0.022128
1,7,837,0.05847,0.648251,0.779424,False,0.730111,False,0.08513,0.146249,0.107617,0.066667,0.013295,0.022169,0.08229,0.043685,0.057072
2,42,837,0.062781,0.680085,0.817168,False,0.782904,False,0.075949,0.236766,0.115007,0.136054,0.019249,0.033727,0.120614,0.052936,0.073579
3,123,837,0.070161,0.685242,0.911745,False,0.775311,False,0.089758,0.233302,0.129639,0.0,0.0,0.0,0.129666,0.062088,0.083969
4,999,837,0.066601,0.714234,,True,,True,0.077515,0.255916,0.118989,,,,,,


Unnamed: 0,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
mean,0.062043,0.078438,0.201152,0.111527,0.071514,0.00838,0.014456,0.104955,0.042845,0.059187
std,0.007015,0.009913,0.056764,0.016151,0.056104,0.00944,0.016297,0.023686,0.021473,0.027075


Saved: /content/drive/MyDrive/biolip_gnn/out/day14_distanceaware_report.csv


In [23]:
# Top-K metrics on test proteins

def precision_at_k(probs, y, k):
    k = min(k, len(probs))
    idx = np.argsort(-probs)[:k]
    return float(y[idx].mean())

def topk_report_per_graph(model, test_set, device, Ks=(5,10,20)):
    model.eval()
    rows = []
    for d in test_set:
        d = d.to(device)
        p = torch.sigmoid(model(d)).detach().cpu().numpy()
        y = d.y.detach().cpu().numpy().astype(int)
        L = len(y)
        row = {"L": L, "pos_rate": float(y.mean())}
        for k in Ks:
            row[f"P@{k}"] = precision_at_k(p, y, k)
        row["P@L20"] = precision_at_k(p, y, max(1, L//20))
        rows.append(row)
    return pd.DataFrame(rows)

# Example: compute Top-K for ONE seed for baseline vs distance-aware
sd = 42

# baseline model re-train quickly for seed 42
train_set, val_set, test_set = split_dataset(dataset, sd)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
model_b = SAGE_NodeClassifier(extra_feats=1).to(device)
crit = nn.BCEWithLogitsLoss(pos_weight=compute_pos_weight(train_set).to(device))
opt = torch.optim.Adam(model_b.parameters(), lr=1e-3, weight_decay=1e-4)
for epoch in range(6):
    model_b.train()
    for b in train_loader:
        b = b.to(device)
        loss = crit(model_b(b), b.y.float())
        opt.zero_grad(); loss.backward(); opt.step()

df_topk_base = topk_report_per_graph(model_b, test_set, device)
display(df_topk_base.describe())

# distance-aware model re-train quickly for seed 42
train_set_g, val_set_g, test_set_g = split_dataset(dataset_gcn, sd)
train_loader_g = DataLoader(train_set_g, batch_size=4, shuffle=True)
model_g = GCN_NodeClassifier(extra_feats=1).to(device)
crit = nn.BCEWithLogitsLoss(pos_weight=compute_pos_weight(train_set_g).to(device))
opt = torch.optim.Adam(model_g.parameters(), lr=1e-3, weight_decay=1e-4)
for epoch in range(6):
    model_g.train()
    for b in train_loader_g:
        b = b.to(device)
        loss = crit(model_g(b), b.y.float())
        opt.zero_grad(); loss.backward(); opt.step()

df_topk_gcn = topk_report_per_graph(model_g, test_set_g, device)
display(df_topk_gcn.describe())

# save combined summary
topk_summary = pd.DataFrame({
    "metric": ["P@5","P@10","P@20","P@L20"],
    "baseline_mean": [df_topk_base["P@5"].mean(), df_topk_base["P@10"].mean(), df_topk_base["P@20"].mean(), df_topk_base["P@L20"].mean()],
    "gcn_mean":      [df_topk_gcn["P@5"].mean(), df_topk_gcn["P@10"].mean(), df_topk_gcn["P@20"].mean(), df_topk_gcn["P@L20"].mean()],
})
display(topk_summary)

save_path = OUT_DIR / "day14_topk_metrics.csv"
topk_summary.to_csv(save_path, index=False)
print("Saved:", save_path)


Unnamed: 0,L,pos_rate,P@5,P@10,P@20,P@L20
count,127.0,127.0,127.0,127.0,127.0,127.0
mean,285.685039,0.032678,0.152756,0.146457,0.120079,0.127916
std,97.973571,0.03557,0.220327,0.181176,0.129905,0.15457
min,85.0,0.004264,0.0,0.0,0.0,0.0
25%,221.0,0.013269,0.0,0.0,0.0,0.0
50%,288.0,0.023904,0.0,0.1,0.1,0.076923
75%,366.0,0.041317,0.2,0.2,0.2,0.2
max,499.0,0.314516,1.0,0.8,0.65,0.75


Unnamed: 0,L,pos_rate,P@5,P@10,P@20,P@L20
count,127.0,127.0,127.0,127.0,127.0,127.0
mean,285.685039,0.032678,0.108661,0.103937,0.088189,0.093231
std,97.973571,0.03557,0.18983,0.151396,0.118933,0.137428
min,85.0,0.004264,0.0,0.0,0.0,0.0
25%,221.0,0.013269,0.0,0.0,0.0,0.0
50%,288.0,0.023904,0.0,0.0,0.05,0.0
75%,366.0,0.041317,0.2,0.2,0.15,0.157895
max,499.0,0.314516,0.8,0.8,0.75,0.8


Unnamed: 0,metric,baseline_mean,gcn_mean
0,P@5,0.152756,0.108661
1,P@10,0.146457,0.103937
2,P@20,0.120079,0.088189
3,P@L20,0.127916,0.093231


Saved: /content/drive/MyDrive/biolip_gnn/out/day14_topk_metrics.csv


In [24]:
# visualization

from pathlib import Path
import random

BASE = Path("/content/drive/MyDrive/biolip_gnn")
GRAPH_DIR = BASE / "graphs_labeled_v6_feat837"
STRUCT_DIR = BASE / "structures"
PLOT_DIR = BASE / "out" / "day14_pred_plots"
PLOT_DIR.mkdir(parents=True, exist_ok=True)

npz_files = sorted(GRAPH_DIR.glob("*.npz"))
print("Graphs:", len(npz_files))
print("Structures (.cif.gz):", len(list(STRUCT_DIR.glob("*.cif.gz"))))
print("Saving plots to:", PLOT_DIR)

# Use the SAME split protocol (seed 42, 70/15/15)
seed_split = 42
paths = npz_files.copy()
random.Random(seed_split).shuffle(paths)
n = len(paths)
ntr = int(0.70*n)
nva = int(0.15*n)
test_paths = paths[ntr+nva:]

print("Test graphs:", len(test_paths))

# pick 2 random from test split
two_paths = random.sample(test_paths, 2)
print("Chosen 2 test graphs:")
for p in two_paths:
    print(" -", p.name)

print("picked 2 random proteins from the TEST split (seed=42).")


Graphs: 837
Structures (.cif.gz): 893
Saving plots to: /content/drive/MyDrive/biolip_gnn/out/day14_pred_plots
Test graphs: 127
Chosen 2 test graphs:
 - 7BKB_E_000424.npz
 - 3ABI_A_000316.npz
picked 2 random proteins from the TEST split (seed=42).


In [25]:
import numpy as np
import gzip
import matplotlib.pyplot as plt

!pip -q install biopython
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB.Polypeptide import is_aa

parser = MMCIFParser(QUIET=True)

def as_str(x):
    try:
        if isinstance(x, np.ndarray) and x.shape == ():
            x = x.item()
        elif isinstance(x, np.ndarray) and x.size == 1:
            x = x.reshape(()).item()
    except:
        pass
    if isinstance(x, (bytes, bytearray)):
        return x.decode("utf-8").strip()
    return str(x).strip()

def load_structure_from_cif_gz(pdb_id: str):
    path = STRUCT_DIR / f"{pdb_id.lower()}.cif.gz"
    if not path.exists():
        return None
    with gzip.open(path, "rt") as handle:
        return parser.get_structure(pdb_id.lower(), handle)

def chain_ca_map(structure, chain_id: str):
    if structure is None:
        return {}
    model = structure[0]
    if chain_id not in model:
        return {}
    m = {}
    for res in model[chain_id]:
        if not is_aa(res, standard=False):
            continue
        if "CA" not in res:
            continue
        rid = res.get_id()          # (' ', resseq, icode)
        resnum = int(rid[1])
        icode = rid[2].strip()
        if icode != "":
            continue
        m[resnum] = res["CA"].get_coord().astype(float)
    return m

def save_true_vs_topk(npz_path, probs, topk=20):
    z = np.load(npz_path, allow_pickle=True)
    pdb_id = as_str(z["pdb_id"])
    chain  = as_str(z["chain"])
    resseq = z["resseq"].astype(int)
    y_true = z["y"].astype(int)

    # Load coords
    structure = load_structure_from_cif_gz(pdb_id)
    cmap = chain_ca_map(structure, chain)
    if len(cmap) == 0:
        print("No coords for", npz_path.name, "| pdb:", pdb_id, "chain:", chain)
        return None

    # match coords in the SAME residue order as graph nodes
    coords = []
    keep_idx = []
    for i, r in enumerate(resseq):
        rr = int(r)
        if rr in cmap:
            coords.append(cmap[rr])
            keep_idx.append(i)

    if len(coords) == 0:
        print("No matched coords for", npz_path.name)
        return None

    coords = np.vstack(coords)
    keep_idx = np.array(keep_idx, dtype=int)

    # restrict y/probs to those nodes that actually had coords
    yk = y_true[keep_idx]
    pk = probs[keep_idx]

    # top-k predicted indices (within the kept list)
    k = min(topk, len(pk))
    top_idx = np.argsort(-pk)[:k]

    # save-only plot
    fig = plt.figure(figsize=(7,6))
    ax = fig.add_subplot(111, projection="3d")

    # all residues (faint)
    ax.scatter(coords[:,0], coords[:,1], coords[:,2], s=6, alpha=0.15)

    # true binding residues (orange big)
    if (yk == 1).any():
        c = coords[yk == 1]
        ax.scatter(c[:,0], c[:,1], c[:,2], s=30, alpha=0.95)

    # predicted top-k residues (green medium)
    if len(top_idx) > 0:
        c = coords[top_idx]
        ax.scatter(c[:,0], c[:,1], c[:,2], s=16, alpha=0.95)

    ax.set_title(f"{npz_path.stem} | true(y=1)=orange | pred(top{topk})=green")
    ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

    out_path = PLOT_DIR / f"{npz_path.stem}_top{topk}.png"
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close(fig)

    ok = out_path.exists() and out_path.stat().st_size > 0
    print("Saved:", out_path.name, "| bytes:", out_path.stat().st_size if out_path.exists() else 0, "| ok:", ok)

    # quick tiny summary
    true_pos = int(yk.sum())
    hit_in_topk = int(yk[top_idx].sum()) if len(top_idx) else 0
    print(f"   True positives in this protein: {true_pos} / {len(yk)}")
    print(f"   Hits inside top{topk}: {hit_in_topk} (i.e., {hit_in_topk}/{k} of our guesses were truly binding)")

    return out_path

print("plotting functions ready (save-only).")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[?25hplotting functions ready (save-only).


In [27]:
from PIL import Image
import torch

# Sanity: make sure model_g exists
print("Has model_g:", "model_g" in globals())
print("Has device:", "device" in globals())
print("Has npz_to_data_gcn_weighted:", "npz_to_data_gcn_weighted" in globals())

assert "model_g" in globals(), "model_g not found. Use the trained distance-aware model from Step 3."
assert "device" in globals(), "device not found."
assert "npz_to_data_gcn_weighted" in globals(), "npz_to_data_gcn_weighted not found (from Step 2A)."

model_g.eval()

saved = []
for p in two_paths:
    g = load_npz(p)
    d = npz_to_data_gcn_weighted(g).to(device)

    with torch.no_grad():
        probs = torch.sigmoid(model_g(d)).detach().cpu().numpy()

    outp = save_true_vs_topk(p, probs, topk=20)
    if outp is not None:
        saved.append(outp)

# Combine into one image (side-by-side)
if len(saved) == 2:
    img1 = Image.open(saved[0]).convert("RGB")
    img2 = Image.open(saved[1]).convert("RGB")

    H = max(img1.height, img2.height)
    W = img1.width + img2.width
    canvas = Image.new("RGB", (W, H), (255,255,255))
    canvas.paste(img1, (0, 0))
    canvas.paste(img2, (img1.width, 0))

    combo_path = PLOT_DIR / "day14_two_random_top20.png"
    canvas.save(combo_path)
    print("Saved combined image:", combo_path.name, "| bytes:", combo_path.stat().st_size)

# List latest saved pngs
pngs = sorted(PLOT_DIR.glob("*.png"))
print("\nPNG count in folder:", len(pngs))
print("Last 10 PNGs:", [x.name for x in pngs[-10:]])

print("generated & saved top-20 prediction visualizations for 2 random TEST proteins.")


Has model_g: True
Has device: True
Has npz_to_data_gcn_weighted: True
Saved: 7BKB_E_000424_top20.png | bytes: 290146 | ok: True
   True positives in this protein: 5 / 411
   Hits inside top20: 2 (i.e., 2/20 of our guesses were truly binding)
Saved: 3ABI_A_000316_top20.png | bytes: 245278 | ok: True
   True positives in this protein: 6 / 349
   Hits inside top20: 1 (i.e., 1/20 of our guesses were truly binding)
Saved combined image: day14_two_random_top20.png | bytes: 469434

PNG count in folder: 4
Last 10 PNGs: ['3ABI_A_000316_top20.png', '7BKB_E_000424_top20.png', 'TEST_WRITE.png', 'day14_two_random_top20.png']
generated & saved top-20 prediction visualizations for 2 random TEST proteins.
