In [2]:
from google.colab import drive
drive.mount('/content/drive')
from pathlib import Path
BASE = Path("/content/drive/MyDrive/biolip_gnn")
V5_DIR = BASE / "graphs_labeled_v5_fixed837"
V6_DIR = BASE / "graphs_labeled_v6_feat837"     # new output folder
OUT    = BASE / "out"
CSV500 = OUT / "subset_500.csv"
STRUCT = BASE / "structures"

V6_DIR.mkdir(exist_ok=True)
OUT.mkdir(exist_ok=True)

npz_files = sorted(V5_DIR.glob("*.npz"))
print("v5 fixed graphs:", len(npz_files))
print("Example:", npz_files[0].name if npz_files else "NONE")

Mounted at /content/drive
v5 fixed graphs: 837
Example: 1AEI_A_000434.npz


In [3]:
import pandas as pd

df = pd.read_csv(CSV500)
rowidx_to_seq = dict(zip(df.index.astype(int), df["sequence"].astype(str)))

print("subset_500 rows:", len(df))
print("Example seq len:", len(rowidx_to_seq.get(0, "")))


subset_500 rows: 837
Example seq len: 236


In [5]:
import numpy as np
from collections import defaultdict, Counter

def infer_idx_to_aa(npz_paths, rowidx_to_seq, max_graphs=400):
    votes = defaultdict(Counter)
    used = 0

    for p in npz_paths[:max_graphs]:
        z = np.load(p, allow_pickle=True)
        row_idx = int(z["row_idx"])
        seq = rowidx_to_seq.get(row_idx, "")
        x_idx = z["x_idx"].astype(int)

        if seq and len(seq) == len(x_idx):
            used += 1
            for i, idxv in enumerate(x_idx):
                aa = seq[i]
                votes[int(idxv)][aa] += 1

    # choose most common letter for each idx
    idx_to_aa = {}
    for idxv, c in votes.items():
        idx_to_aa[idxv] = c.most_common(1)[0][0]

    return idx_to_aa, used

idx_to_aa, used = infer_idx_to_aa(npz_files, rowidx_to_seq)
print("Graphs used for mapping:", used)
print("Inferred idx_to_aa size:", len(idx_to_aa))
print("Sample mapping:", dict(list(idx_to_aa.items())[:10]))

missing = [i for i in range(max(idx_to_aa.keys())+1) if i not in idx_to_aa]
print("Missing indices (if any):", missing[:20])


Graphs used for mapping: 365
Inferred idx_to_aa size: 21
Sample mapping: {17: 'V', 13: 'Q', 5: 'G', 16: 'T', 8: 'K', 12: 'P', 6: 'H', 0: 'A', 15: 'S', 4: 'F'}
Missing indices (if any): []


In [6]:
# define cheap chemistry groups + AA property vector

#residue groups
HYDRO = set(list("AVLIMFWY"))
POLAR = set(list("STNQCY"))
POS = set(list("KRH"))
NEG = set(list("DE"))
AROM = set(list("FWY"))

# very small AA property vector
def aa_props(aa: str):
  return np.array([
      1.0 if aa in HYDRO else 0.0,
      1.0 if aa in POLAR else 0.0,
      1.0 if aa in POS else 0.0,
      1.0 if aa in NEG else 0.0,
      1.0 if aa in AROM else 0.0,
  ], dtype=np.float32)



In [7]:
# feature builder from existing graph data
# added = 15 float features per node

def build_features_for_graph(x_idx, edge_index, edge_dist, idx_to_aa):
  """
  x_idx: (N,)
  edge_index: (2,E) directed
  edge_dist: (E,) float; sequence edges often have 0 or very small dist
  """
  x_idx = x_idx.astype(int)
  N = len(x_idx)
  E = edge_index.shape[1]

  # map each nodes aa
  aa_list = np.array([idx_to_aa.get(int(i), "X") for i in x_idx])

  # contact edges only (ignore seq edges if dist <= 0)
  d = edge_dist.astype(float)
  mask = d > 0.0
  src = edge_index[0, mask].astype(int)
  dst = edge_index[1, mask].astype(int)
  dist = d[mask]

  # multi-radius degrees (counts)
  deg6 = np.zeros(N, dtype=np.float32)
  deg8 = np.zeros(N, dtype = np.float32)
  deg10 = np.zeros(N, dtype = np.float32)

  # distamce stats
  sumd = np.zeros(N, dtype=np.float32)
  mind = np.zeros(N, dtype=np.float32)
  cnt = np.zeros(N, dtype=np.float32)

  # neighborhood chemistry at 8Å

  # counts of neighbour groups within 8Å
  nh = np.zeros(N, dtype=np.float32)
  npol = np.zeros(N, dtype=np.float32)
  npos = np.zeros(N, dtype=np.float32)
  nneg = np.zeros(N, dtype=np.float32)
  naro = np.zeros(N, dtype=np.float32)
  n8cnt = np.zeros(N, dtype=np.float32)

  for s, t, dd in zip(src, dst, dist):
    # density bins
    if dd <= 6.0: deg6[s] += 1.0
    if dd <= 8.0: deg8[s] += 1.0
    if dd <= 10.0: deg10[s] += 1.0

    # dist stats
    sumd[s] += dd
    cnt[s] += 1.0
    if dd < mind[s]:
      mind[s] = dd

    # neighbor chemistry at 8Å
    if dd <= 8.0:
      aat = aa_list[t]
      n8cnt[s] += 1.0
      if aat in HYDRO: nh[s] += 1.0
      if aat in POLAR: npol[s] += 1.0
      if aat in POS: npos[s] += 1.0
      if aat in NEG: nneg[s] += 1.0
      if aat in AROM: naro[s]

  mean_d = sumd / (cnt + 1e-9)
  min_d = np.where(np.isfinite(mind), mind, 0.0).astype(np.float32)

  # neigbor fractions
  frac_h = nh / (n8cnt + 1e-9)
  frac_p = npol / (n8cnt + 1e-9)
  frac_pos = npos / (n8cnt + 1e-9)
  frac_neg = nneg / (n8cnt + 1e-9)
  frac_ar = naro / (n8cnt + 1e-9)

  # AA property verctor node
  aa_feat = np.vstack([aa_props(a) for a in aa_list])

  # pack float features
  float_feat = np.column_stack([
      aa_feat,
      deg6, deg8, deg10,
      mean_d, min_d,
      frac_h, frac_p, frac_pos, frac_neg, frac_ar
  ]).astype(np.float32)

  # normalize per-graph (z-score) for stability
  mu = float_feat.mean(axis = 0, keepdims = True)
  sd = float_feat.std(axis = 0, keepdims = True) + 1e-9
  float_feat = (float_feat - mu) / sd

  return float_feat


In [8]:
# write v6 graphs to a new folder

from tqdm import tqdm

fail = []
written = 0

for p in tqdm(npz_files):
    z = np.load(p, allow_pickle=True)
    x_idx = z["x_idx"]
    edge_index = z["edge_index"]
    edge_dist = z["edge_dist"]

    try:
        x_feat = build_features_for_graph(x_idx, edge_index, edge_dist, idx_to_aa)

        outpath = V6_DIR / p.name
        np.savez_compressed(
            outpath,
            pdb_id=z["pdb_id"], chain=z["chain"], row_idx=z["row_idx"],
            ligand_code=z["ligand_code"] if "ligand_code" in z.files else "",
            n_nodes=z["n_nodes"], resseq=z["resseq"],
            x_idx=z["x_idx"], x_feat=x_feat,              # NEW
            edge_index=z["edge_index"], edge_dist=z["edge_dist"],
            y=z["y"], label_mode=z["label_mode"]
        )
        written += 1
    except Exception as e:
        fail.append({"graph": p.stem, "err": str(e)})

print("Written v6 graphs:", written)
print("Failures:", len(fail))
if fail:
    import pandas as pd
    pd.DataFrame(fail).to_csv(OUT / "day14_v6_failures.csv", index=False)
    print("Saved failures:", OUT / "day14_v6_failures.csv")



100%|██████████| 837/837 [00:43<00:00, 19.21it/s]

Written v6 graphs: 837
Failures: 0





In [9]:
!pip -q install torch_geometric scikit-learn

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[?25h

SyntaxError: invalid syntax (ipython-input-146727132.py, line 11)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from sklearn.metrics import precision_recall_curve, average_precision_score

In [14]:
v6_files = sorted(V6_DIR.glob("*.npz"))
print("v6 graphs:", len(v6_files))

def load_v6_npz(p):
  z = np.load(p, allow_pickle = True)
  return {k: z[k] for k in z.files}

def to_pyg_v6(gdict):
  aa = torch.tensor(gdict["x_idx"], dtype = torch.long)
  xf = torch.tensor(gdict["x_feat"], dtype = torch.float)
  x = torch.cat([aa.view(-1, 1).float(), xf], dim = 1)
  edge_index = torch.tensor(gdict["edge_index"], dtype = torch.long)
  y = torch.tensor(gdict["y"], dtype = torch.long)
  return Data(x=x, edge_index=edge_index, y=y)

raw = [load_v6_npz(p) for p in v6_files]
dataset_v6 = [to_pyg_v6(g) for g in raw]
print("example x shape:", dataset_v6[0].x.shape)

v6 graphs: 837
example x shape: torch.Size([315, 16])


In [15]:
# model + helpers (same thresholding)

class SAGE_NodeClassifier(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=15):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)
        in_dim = emb_dim + extra_feats
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1  = nn.Linear(hidden, hidden)
        self.lin2  = nn.Linear(hidden, 1)

    def forward(self, data):
        aa_idx = data.x[:,0].long()
        xf = data.x[:,1:].float()
        h = self.emb(aa_idx)
        h = torch.cat([h, xf], dim=1)
        h = F.relu(self.conv1(h, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        h = F.relu(self.lin1(h))
        return self.lin2(h).squeeze(-1)

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    probs_all, y_all = [], []
    for batch in loader:
        batch = batch.to(device)
        probs = torch.sigmoid(model(batch)).detach().cpu().numpy()
        y = batch.y.detach().cpu().numpy()
        probs_all.append(probs); y_all.append(y)
    return np.concatenate(probs_all), np.concatenate(y_all)

def split_dataset(ds, seed=42):
    ds = ds.copy()
    random.Random(seed).shuffle(ds)
    n = len(ds)
    n_train = int(0.7*n)
    n_val   = int(0.15*n)
    return ds[:n_train], ds[n_train:n_train+n_val], ds[n_train+n_val:]

def pos_weight_from_train(train_set):
    pos = sum(int(d.y.sum()) for d in train_set)
    tot = sum(int(d.y.numel()) for d in train_set)
    neg = tot - pos
    return torch.tensor([neg / max(pos, 1)], dtype=torch.float)

def threshold_max_f1(probs, y_true):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    f1 = (2*prec[:-1]*rec[:-1]) / (prec[:-1]+rec[:-1] + 1e-9)
    return float(thr[int(np.argmax(f1))])

def threshold_precision_target(probs, y_true, target_precision=0.20):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    prec2, rec2, thr2 = prec[:-1], rec[:-1], thr
    ok = np.where(prec2 >= target_precision)[0]
    if len(ok) == 0:
        return None
    best = ok[np.argmax(rec2[ok])]
    return float(thr2[best])

def metrics_at_threshold(probs, y_true, thr):
    y_pred = (probs >= thr).astype(int)
    tp = int(((y_pred==1) & (y_true==1)).sum())
    fp = int(((y_pred==1) & (y_true==0)).sum())
    fn = int(((y_pred==0) & (y_true==1)).sum())
    precision = tp / (tp+fp+1e-9)
    recall    = tp / (tp+fn+1e-9)
    f1        = (2*precision*recall) / (precision+recall+1e-9)
    return precision, recall, f1


In [16]:
# Train/eval on 5 seeds + save report

def train_one(ds, seed=42, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_set, val_set, test_set = split_dataset(ds, seed=seed)

    model = SAGE_NodeClassifier(extra_feats=15).to(device)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    pw = pos_weight_from_train(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pw)
    opt  = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            loss = crit(model(batch), batch.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    # VAL thresholds
    val_probs, val_y = collect_probs_and_labels(model, val_loader, device)
    thr_f1  = threshold_max_f1(val_probs, val_y)
    thr_p20 = threshold_precision_target(val_probs, val_y, 0.20)
    thr_p15 = threshold_precision_target(val_probs, val_y, 0.15)

    fb20 = thr_p20 is None
    fb15 = thr_p15 is None
    if thr_p20 is None: thr_p20 = thr_f1
    if thr_p15 is None: thr_p15 = thr_f1

    # TEST
    test_probs, test_y = collect_probs_and_labels(model, test_loader, device)
    auprc = float(average_precision_score(test_y, test_probs))

    p_f1, r_f1, f1_f1 = metrics_at_threshold(test_probs, test_y, thr_f1)
    p20,  r20,  f1_20 = metrics_at_threshold(test_probs, test_y, thr_p20)
    p15,  r15,  f1_15 = metrics_at_threshold(test_probs, test_y, thr_p15)

    return {
        "seed": seed, "n_graphs": len(ds), "test_auprc": auprc,
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p20, "p20_fallback": fb20,
        "val_thr_p15": thr_p15, "p15_fallback": fb15,
        "test_P_maxF1": p_f1, "test_R_maxF1": r_f1, "test_F1_maxF1": f1_f1,
        "test_P_p20": p20,    "test_R_p20": r20,    "test_F1_p20": f1_20,
        "test_P_p15": p15,    "test_R_p15": r15,    "test_F1_p15": f1_15,
    }

seeds = [1, 7, 42, 123, 999]
rows = [train_one(dataset_v6, seed=s, epochs=10) for s in seeds]
report = pd.DataFrame(rows)
display(report)

summary = report[[
    "test_auprc",
    "test_P_maxF1","test_R_maxF1","test_F1_maxF1",
    "test_P_p20","test_R_p20","test_F1_p20",
    "test_P_p15","test_R_p15","test_F1_p15"
]].agg(["mean","std"])
display(summary)

path = OUT / "day14_report_v6_feat837.csv"
report.to_csv(path, index=False)
print("Saved:", path)


Unnamed: 0,seed,n_graphs,test_auprc,val_thr_maxf1,val_thr_p20,p20_fallback,val_thr_p15,p15_fallback,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
0,1,837,0.121844,0.7551,0.823984,False,0.738575,False,0.136975,0.327485,0.193159,0.181467,0.183236,0.182347,0.125043,0.353801,0.18478
1,7,837,0.103801,0.867221,0.928664,False,0.86375,False,0.140706,0.25736,0.18194,0.202559,0.090218,0.124836,0.13856,0.266857,0.182408
2,42,837,0.112955,0.744329,0.800555,False,0.719098,False,0.143073,0.27334,0.187831,0.181909,0.176131,0.178973,0.13099,0.315688,0.185154
3,123,837,0.132977,0.807826,0.874528,False,0.793703,False,0.149159,0.308561,0.201104,0.206776,0.16651,0.184471,0.140275,0.335842,0.197894
4,999,837,0.142762,0.735526,0.826129,False,0.766888,False,0.165678,0.293602,0.211824,0.260674,0.101665,0.14628,0.184989,0.213848,0.198374


Unnamed: 0,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
mean,0.122868,0.147118,0.29207,0.195172,0.206677,0.143552,0.163381,0.143971,0.297207,0.189722
std,0.015499,0.01128,0.027768,0.011678,0.032333,0.044052,0.026579,0.023728,0.056805,0.007753


Saved: /content/drive/MyDrive/biolip_gnn/out/day14_report_v6_feat837.csv
