In [1]:
# Mount Google Drive so the project folder below is reachable from this runtime.
from google.colab import drive
drive.mount("/content/drive")

from pathlib import Path
import numpy as np, pandas as pd, random
import torch, torch.nn as nn
import torch.nn.functional as F

# Project layout on Drive: input graphs (.npz) and the folder that receives reports/plots.
BASE = Path("/content/drive/MyDrive/biolip_gnn")
GRAPH_DIR = BASE / "graphs_labeled_v6_feat837"
OUT_DIR = BASE / "out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Sorted so the file order (and therefore every seeded shuffle of it) is stable across runs.
npz_files = sorted(GRAPH_DIR.glob("*.npz"))
print("Graphs:", len(npz_files))
# Fail early if the glob matched nothing, instead of failing later inside training.
assert len(npz_files) > 0, "No graphs found. Fix GRAPH_DIR."
print("Drive + paths ready.")


Mounted at /content/drive
Graphs: 837
Drive + paths ready.


In [3]:
# define splits + load one graph utility
def split_paths(paths, seed=42):
    """Shuffle a copy of ``paths`` with a dedicated RNG and cut it 70/15/15.

    Args:
      paths: list of items to split; it is copied, never modified.
      seed: seed of the private ``random.Random`` used for the shuffle.

    Returns:
      (train, val, test) lists. Train and val sizes are truncated to
      integers; test receives whatever remains.
    """
    shuffled = paths.copy()
    shuffler = random.Random(seed)
    shuffler.shuffle(shuffled)

    total = len(shuffled)
    train_end = int(0.70 * total)
    val_end = train_end + int(0.15 * total)

    train_part = shuffled[:train_end]
    val_part = shuffled[train_end:val_end]
    test_part = shuffled[val_end:]
    return train_part, val_part, test_part

def load_npz(p):
    """Read every array stored in the ``.npz`` archive at ``p`` into a plain dict.

    Args:
      p: path (or file-like object) accepted by ``np.load``.

    Returns:
      dict mapping each archive key to its fully loaded array.
    """
    # NOTE(review): allow_pickle=True will unpickle object arrays, which can run
    # arbitrary code; only use this loader on archives you produced yourself.
    # The archive is opened as a context manager so its file handle is released
    # as soon as the arrays have been materialised (this is called once per
    # graph, repeatedly, so leaked handles add up).
    with np.load(p, allow_pickle=True) as z:
        return {k: z[k] for k in z.files}

# Seeds for the repeated experiments: each one is passed to split_paths to
# produce a different train/val/test partition of the graph files.
seeds = [1, 7, 42, 123, 999]
print("split protocol + loader ready.")


split protocol + loader ready.


In [4]:
# hard-negative sampler

def hard_negative_sample(g, near_cutoff=12.0, neg_per_pos=4, rand_neg_frac=0.25, max_total=800, rng=None):
    """
    Choose the nodes of one graph to train on: all positives plus a mix of
    "hard" negatives (negatives reachable from a positive through a spatial
    edge) and uniformly random negatives.

    Args:
      g (dict): must hold "y" (per-node labels, 1 = positive, 0 = negative).
        When the graph has at least one positive it must also hold
        "edge_index" (row 0 = source nodes, row 1 = destination nodes) and
        "edge_dist" (one distance per edge).
      near_cutoff (float): an edge counts as a spatial contact when
        0 < dist <= near_cutoff. Edges with dist == 0 are ignored (the data
        uses 0 for sequence edges).
      neg_per_pos (int): number of negatives requested per positive.
      rand_neg_frac (float): share of the requested negatives reserved for
        random (non-hard) negatives.
      max_total (int): upper bound on the number of returned nodes. When the
        bound is hit the subset is drawn uniformly, so positives can be dropped.
      rng: optional random source exposing ``choice`` and ``shuffle``
        (e.g. ``np.random.default_rng(seed)``). Defaults to the global
        ``np.random`` state, which is what this function always used.

    Returns:
      idx (np.array): node indices used for training this graph
      y_sub (np.array): labels for those nodes
    """
    rng = np.random if rng is None else rng

    y = g["y"].astype(int)
    pos = np.where(y == 1)[0]
    neg = np.where(y == 0)[0]

    if len(pos) == 0:
        # fallback: if no positives, just sample some negatives
        k = min(len(neg), 256)
        idx = rng.choice(neg, size=k, replace=False)
        return idx, y[idx]

    edge_index = g["edge_index"]
    src = np.asarray(edge_index[0]).reshape(-1)
    dst = np.asarray(edge_index[1]).reshape(-1)
    edge_dist = np.asarray(g["edge_dist"], dtype=float).reshape(-1)

    # Only edges that leave a positive node matter, so select them with masks
    # instead of building a neighbour set for every node in a Python loop.
    spatial = (edge_dist > 0) & (edge_dist <= near_cutoff)
    leaves_positive = np.isin(src, pos)
    candidates = dst[spatial & leaves_positive]

    # Keep true negatives only (y == 0). For binary labels this is the same as
    # "not a positive", but it also keeps any other label value out of the
    # negative pool.
    near_neg = np.unique(candidates[y[candidates] == 0]).astype(int)

    # how many negatives?
    target_neg = min(len(neg), neg_per_pos * len(pos))
    # take most from near-neg, remainder random
    take_near = min(len(near_neg), int((1 - rand_neg_frac) * target_neg))
    take_rand = target_neg - take_near

    chosen = []
    if take_near > 0:
        chosen.append(rng.choice(near_neg, size=take_near, replace=False))

    # random negatives from all negatives excluding the hard ones already chosen
    if take_rand > 0:
        pool = np.setdiff1d(neg, chosen[0] if chosen else np.array([], dtype=int))
        if len(pool) > 0:
            take_rand = min(take_rand, len(pool))
            chosen.append(rng.choice(pool, size=take_rand, replace=False))

    neg_chosen = np.concatenate(chosen) if chosen else np.array([], dtype=int)

    idx = np.concatenate([pos, neg_chosen])
    rng.shuffle(idx)

    if len(idx) > max_total:
        idx = rng.choice(idx, size=max_total, replace=False)

    return idx, y[idx]

print("hard-negative sampler ready.")


hard-negative sampler ready.


In [13]:
# Build SAGE model (baseline) + training loop that uses sampled nodes

!pip -q install torch_geometric -U
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv

def npz_to_data_baseline(g):
    """Convert one loaded graph dict into a PyG ``Data`` object for the baseline model.

    Args:
      g (dict): must hold "edge_index" (row 0 is used as the source node of
        each edge), "y" (per-node labels) and "x_idx" (one integer index per
        node, consumed by the model's embedding layer).

    Returns:
      Data with ``x`` of shape (N, 2): column 0 is the integer index stored as
      float, column 1 is the z-scored count of edges per source node.
    """
    edge_index = torch.tensor(g["edge_index"], dtype=torch.long)
    y = torch.tensor(g["y"], dtype=torch.long)
    aa = torch.tensor(g["x_idx"], dtype=torch.long)

    N = aa.numel()
    # Count how many edges start at each node.
    deg = torch.zeros(N, dtype=torch.float)
    deg.scatter_add_(0, edge_index[0], torch.ones(edge_index.shape[1]))

    if N > 1:
        deg = (deg - deg.mean()) / (deg.std() + 1e-9)
    else:
        # Tensor.std() uses the unbiased estimator, which is NaN for a single
        # value; a lone node has no spread, so its standardized degree is 0.
        deg = torch.zeros_like(deg)

    x = torch.cat([aa.view(-1,1).float(), deg.view(-1,1)], dim=1)
    return Data(x=x, edge_index=edge_index, y=y)

class SAGE_NodeClassifier(nn.Module):
    """Per-node binary classifier: embedding -> 2x SAGEConv -> 2-layer MLP.

    ``data.x`` is expected to carry an integer index in column 0 (looked up in
    an embedding table of ``num_aa`` rows) followed by ``extra_feats`` float
    columns. The forward pass returns one raw logit per node.
    """

    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
        super().__init__()
        # Layers are created in a fixed order so parameter initialisation
        # draws from the RNG in the same sequence on every construction.
        self.emb = nn.Embedding(num_aa, emb_dim)
        self.conv1 = SAGEConv(emb_dim + extra_feats, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1 = nn.Linear(hidden, hidden)
        self.lin2 = nn.Linear(hidden, 1)

    def forward(self, data):
        features = data.x
        edges = data.edge_index

        # Column 0 -> embedding lookup; remaining columns are used as-is.
        token_ids = features[:, 0].long()
        dense_feats = features[:, 1:].float()
        h = torch.cat([self.emb(token_ids), dense_feats], dim=1)

        # Two rounds of neighbourhood aggregation.
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edges))

        # MLP head producing a single logit per node.
        h = F.relu(self.lin1(h))
        logits = self.lin2(h)
        return logits.squeeze(-1)

# Use the GPU when CUDA is available, otherwise run everything on the CPU.
# The training and evaluation helpers below read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
print("baseline GraphSAGE model ready.")

Device: cpu
baseline GraphSAGE model ready.


In [15]:
# NOTE(review): looks like a leftover debugging cell — it re-imports F and
# prints it, presumably to confirm the name refers to torch.nn.functional
# before training. Safe to delete once the notebook runs top to bottom.
import torch.nn.functional as F
print("F is now:", F, "type:", type(F))


F is now: <module 'torch.nn.functional' from '/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py'> type: <class 'module'>


In [16]:
# NOTE(review): leftover debugging print of type(F); candidate for removal.
print("Type(F) right before training:", type(F))

Type(F) right before training: <class 'module'>


In [17]:
# Train/eval on 5 seeds: baseline vs +hardneg (same evaluation as Day 13/14)

from sklearn.metrics import precision_recall_curve, average_precision_score

@torch.no_grad()
def collect_probs(model, loader, device):
    """Run ``model`` over every batch of ``loader`` and gather node predictions.

    The model is switched to eval mode; its outputs are treated as logits and
    passed through a sigmoid.

    Returns:
      (probs, labels): two 1-D numpy arrays, concatenated over all batches in
      loader order.
    """
    model.eval()
    prob_chunks = []
    label_chunks = []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        prob_chunks.append(logits.sigmoid().cpu().numpy())
        label_chunks.append(batch.y.cpu().numpy())
    all_probs = np.concatenate(prob_chunks)
    all_labels = np.concatenate(label_chunks)
    return all_probs, all_labels

def thr_maxf1(probs, y):
    """Return the score threshold whose precision/recall pair has the highest F1."""
    precision, recall, thresholds = precision_recall_curve(y, probs)
    # Drop the last precision/recall entry so both arrays line up with `thresholds`.
    p = precision[:-1]
    r = recall[:-1]
    f1_scores = (2 * p * r) / (p + r + 1e-9)  # epsilon guards 0/0
    best = int(np.argmax(f1_scores))
    return float(thresholds[best])

def thr_precision_target(probs, y, target=0.20):
    """Pick the threshold with the highest recall among those reaching ``target`` precision.

    Returns:
      The threshold as a float, or None when no threshold on the curve
      reaches the requested precision.
    """
    precision, recall, thresholds = precision_recall_curve(y, probs)
    # Drop the last precision/recall entry so both arrays line up with `thresholds`.
    p = precision[:-1]
    r = recall[:-1]

    eligible = np.flatnonzero(p >= target)
    if eligible.size == 0:
        return None

    winner = eligible[np.argmax(r[eligible])]
    return float(thresholds[winner])

def prf_at_thr(probs, y, thr):
    """Precision, recall and F1 of the rule ``probs >= thr`` against labels ``y``.

    A small epsilon in each denominator makes empty cases evaluate to 0.0
    instead of raising a division error.

    Returns:
      (precision, recall, f1) as floats.
    """
    predicted_pos = probs >= thr
    actual_pos = y == 1
    actual_neg = y == 0

    true_pos = int(np.count_nonzero(predicted_pos & actual_pos))
    false_pos = int(np.count_nonzero(predicted_pos & actual_neg))
    false_neg = int(np.count_nonzero(~predicted_pos & actual_pos))

    precision = true_pos / (true_pos + false_pos + 1e-9)
    recall = true_pos / (true_pos + false_neg + 1e-9)
    f1 = 2 * precision * recall / (precision + recall + 1e-9)
    return precision, recall, f1

def compute_pos_weight(graphs):
    """Negative-to-positive label ratio over ``graphs`` as a 1-element float tensor.

    Each graph must expose a label tensor ``y``. With no positives at all the
    denominator is clamped to 1, so the result is simply the negative count.
    """
    n_pos = 0
    for graph in graphs:
        n_pos += int(graph.y.sum())

    n_total = 0
    for graph in graphs:
        n_total += int(graph.y.numel())

    n_neg = n_total - n_pos
    ratio = n_neg / max(n_pos, 1)
    return torch.tensor([ratio], dtype=torch.float)

def train_baseline(train_graphs, epochs=6):
    """Train a fresh SAGE_NodeClassifier on every node of ``train_graphs``.

    Graphs are batched four at a time and reshuffled each epoch. The loss is
    BCE-with-logits weighted by the negative/positive ratio of the training
    labels, optimised with Adam (lr 1e-3, weight decay 1e-4).

    Returns:
      The trained model, living on the module-level ``device``.
    """
    batches = DataLoader(train_graphs, batch_size=4, shuffle=True)
    net = SAGE_NodeClassifier(extra_feats=1).to(device)
    class_weight = compute_pos_weight(train_graphs).to(device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weight)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

    for _epoch in range(epochs):
        net.train()
        for batch in batches:
            batch = batch.to(device)
            logits = net(batch)
            targets = batch.y.float()
            loss = loss_fn(logits, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return net

def train_hardneg(train_paths, epochs=6):
    """Train a fresh SAGE_NodeClassifier graph-by-graph on hard-negative samples.

    For every graph the model runs on the full graph, but the loss is taken
    only on the nodes returned by ``hard_negative_sample``. One optimiser step
    is performed per graph.

    Args:
      train_paths: paths of the training ``.npz`` graphs. The list is not
        modified; the visiting order is shuffled internally each epoch.
      epochs: number of passes over the training graphs.

    Returns:
      The trained model, living on the module-level ``device``.
    """
    model = SAGE_NodeClassifier(extra_feats=1).to(device)

    # Read and convert every graph exactly once. Re-reading each archive from
    # Drive in every epoch only repeats identical work.
    sampler_inputs = []     # the arrays hard_negative_sample needs, per graph
    train_graphs_full = []  # matching PyG Data objects
    for p in train_paths:
        g = load_npz(p)
        train_graphs_full.append(npz_to_data_baseline(g))
        sampler_inputs.append(
            {k: g[k] for k in ("y", "edge_index", "edge_dist") if k in g}
        )

    # NOTE(review): pos_weight is the negative/positive ratio of the FULL label
    # set, while each loss is computed on a sample holding at most
    # `neg_per_pos` negatives per positive. If the full ratio is much larger
    # than that, positives are over-weighted relative to the sampled batches —
    # confirm this is intended before comparing against the baseline.
    pos_w = compute_pos_weight(train_graphs_full).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pos_w)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # Shuffle indices instead of `train_paths` so the caller's list keeps its order.
    order = list(range(len(train_graphs_full)))

    for ep in range(epochs):
        random.shuffle(order)
        model.train()
        losses = []
        for i in order:
            idx, y_sub = hard_negative_sample(sampler_inputs[i])
            if len(idx) == 0:
                # Nothing to train on; the mean of an empty loss would be NaN.
                continue

            d = train_graphs_full[i].to(device)
            idx_t = torch.tensor(idx, dtype=torch.long, device=device)
            y_t   = torch.tensor(y_sub, dtype=torch.float, device=device)

            logits = model(d)[idx_t]
            loss = crit(logits, y_t)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(float(loss.item()))

        # small progress
        avg_loss = np.mean(losses) if losses else float("nan")
        print(f"epoch {ep+1:02d} | hardneg avg loss {avg_loss:.4f}")

    return model

def _threshold_metrics(tprob, ty, thr, tag):
    """Test P/R/F1 at a validation-chosen threshold, as report columns for ``tag``.

    When ``thr`` is None (no validation threshold reached the precision
    target) the three metrics are NaN and ``<tag>_fallback`` is True.
    """
    if thr is None:
        return {f"test_P_{tag}": np.nan, f"test_R_{tag}": np.nan,
                f"test_F1_{tag}": np.nan, f"{tag}_fallback": True}
    P, R, f1_score = prf_at_thr(tprob, ty, thr)
    return {f"test_P_{tag}": P, f"test_R_{tag}": R,
            f"test_F1_{tag}": f1_score, f"{tag}_fallback": False}


def _evaluate_on_test(model, model_name, seed, val_loader, test_loader):
    """Build one report row: thresholds are picked on validation, metrics measured on test."""
    vprob, vy = collect_probs(model, val_loader, device)
    tprob, ty = collect_probs(model, test_loader, device)

    thr_f1 = thr_maxf1(vprob, vy)
    thr_p20 = thr_precision_target(vprob, vy, 0.20)
    thr_p15 = thr_precision_target(vprob, vy, 0.15)

    row = {"seed": seed, "model": model_name, "n_graphs": len(npz_files),
           "test_auprc": float(average_precision_score(ty, tprob))}
    P, R, f1_score = prf_at_thr(tprob, ty, thr_f1)
    row.update({"test_P_maxF1": P, "test_R_maxF1": R, "test_F1_maxF1": f1_score})
    row.update(_threshold_metrics(tprob, ty, thr_p20, "p20"))
    row.update(_threshold_metrics(tprob, ty, thr_p15, "p15"))
    return row


rows = []
for sd in seeds:
    tr, va, te = split_paths(npz_files, sd)

    # build Data objects for val/test (full graphs)
    val_graphs  = [npz_to_data_baseline(load_npz(p)) for p in va]
    test_graphs = [npz_to_data_baseline(load_npz(p)) for p in te]
    val_loader  = DataLoader(val_graphs, batch_size=4, shuffle=False)
    test_loader = DataLoader(test_graphs, batch_size=4, shuffle=False)

    # baseline: trained on every node of every training graph
    train_graphs = [npz_to_data_baseline(load_npz(p)) for p in tr]
    m_base = train_baseline(train_graphs, epochs=6)
    rows.append(_evaluate_on_test(m_base, "baseline", sd, val_loader, test_loader))

    # +hard negatives: same architecture, trained on sampled nodes only
    m_hn = train_hardneg(tr, epochs=4)  # start with 4 epochs to keep it fast
    rows.append(_evaluate_on_test(m_hn, "hardneg", sd, val_loader, test_loader))

# One row per (seed, model) pair collected by the loop above.
df = pd.DataFrame(rows)
display(df)

# mean ± std by model
# NaN cells (rows where a precision-target threshold fell back) are skipped by
# pandas' mean/std, so those columns may aggregate fewer runs than len(seeds).
metrics = ["test_auprc","test_P_maxF1","test_R_maxF1","test_F1_maxF1","test_P_p20","test_R_p20","test_F1_p20","test_P_p15","test_R_p15","test_F1_p15"]
summary = df.groupby("model")[metrics].agg(["mean","std"])
display(summary)

# Persist the per-seed table (not the summary) next to the other outputs.
save_path = OUT_DIR / "day15_hardneg_report.csv"
df.to_csv(save_path, index=False)
print("Saved:", save_path)
print("baseline vs hardneg evaluation completed + saved report.")


epoch 01 | hardneg avg loss 2.4183
epoch 02 | hardneg avg loss 2.3269
epoch 03 | hardneg avg loss 2.3070
epoch 04 | hardneg avg loss 2.2921
epoch 01 | hardneg avg loss 2.4442
epoch 02 | hardneg avg loss 2.3217
epoch 03 | hardneg avg loss 2.2880
epoch 04 | hardneg avg loss 2.2803
epoch 01 | hardneg avg loss 2.4407
epoch 02 | hardneg avg loss 2.3011
epoch 03 | hardneg avg loss 2.2922
epoch 04 | hardneg avg loss 2.2713
epoch 01 | hardneg avg loss 2.4472
epoch 02 | hardneg avg loss 2.3507
epoch 03 | hardneg avg loss 2.3147
epoch 04 | hardneg avg loss 2.3052
epoch 01 | hardneg avg loss 2.3835
epoch 02 | hardneg avg loss 2.3055
epoch 03 | hardneg avg loss 2.2783
epoch 04 | hardneg avg loss 2.2653


Unnamed: 0,seed,model,n_graphs,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,p20_fallback,test_P_p15,test_R_p15,test_F1_p15,p15_fallback
0,1,baseline,837,0.084438,0.111842,0.231969,0.150919,0.148515,0.087719,0.110294,False,0.116815,0.217349,0.151959,False
1,1,hardneg,837,0.068136,0.090868,0.197856,0.12454,0.12,0.032164,0.05073,False,0.106814,0.11306,0.109848,False
2,7,baseline,837,0.082763,0.101455,0.25831,0.145688,0.234234,0.024691,0.044674,False,0.223684,0.032289,0.056432,False
3,7,hardneg,837,0.062582,0.079285,0.193732,0.112521,,,,True,0.101093,0.035138,0.052149,False
4,42,baseline,837,0.094045,0.137048,0.175168,0.153781,0.223256,0.046198,0.076555,False,0.127004,0.190568,0.152425,False
5,42,hardneg,837,0.071115,0.090909,0.202117,0.125411,0.0,0.0,0.0,False,0.168889,0.036574,0.060127,False
6,123,baseline,837,0.107957,0.105914,0.330198,0.160384,0.20935,0.096896,0.132476,False,0.139031,0.229539,0.173172,False
7,123,hardneg,837,0.072247,0.077829,0.335842,0.126372,0.185185,0.009407,0.017905,False,0.159091,0.032926,0.05456,False
8,999,baseline,837,0.095434,0.121805,0.212971,0.154974,,,,True,0.295082,0.015776,0.02995,False
9,999,hardneg,837,0.077176,0.097131,0.246275,0.139316,,,,True,,,,True


Unnamed: 0_level_0,test_auprc,test_auprc,test_P_maxF1,test_P_maxF1,test_R_maxF1,test_R_maxF1,test_F1_maxF1,test_F1_maxF1,test_P_p20,test_P_p20,test_R_p20,test_R_p20,test_F1_p20,test_F1_p20,test_P_p15,test_P_p15,test_R_p15,test_R_p15,test_F1_p15,test_F1_p15
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
baseline,0.092928,0.01011,0.115613,0.0142,0.241723,0.057996,0.153149,0.005402,0.203839,0.038262,0.063876,0.034189,0.091,0.038502,0.180323,0.076862,0.137104,0.104342,0.112788,0.064788
hardneg,0.070251,0.005384,0.087204,0.008311,0.235165,0.060125,0.125632,0.009499,0.101728,0.093935,0.013857,0.016537,0.022878,0.025728,0.133972,0.03497,0.054424,0.039119,0.069171,0.027323


Saved: /content/drive/MyDrive/biolip_gnn/out/day15_hardneg_report.csv
baseline vs hardneg evaluation completed + saved report.


In [21]:
# Visualization (save-only) for 2 random test proteins: baseline vs hardneg

!pip install biopython -U
import matplotlib.pyplot as plt
import gzip
from PIL import Image
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB.Polypeptide import is_aa

# Folder holding the gzipped mmCIF structures and the folder that receives the PNG plots.
STRUCT_DIR = BASE / "structures"
PLOT_DIR = OUT_DIR / "day15_pred_plots"
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# Single shared mmCIF parser, reused by load_structure_from_cif_gz.
parser = MMCIFParser(QUIET=True)

def load_structure_from_cif_gz(pdb_id: str):
    """Parse ``STRUCT_DIR/<pdb_id lower-cased>.cif.gz`` with the shared mmCIF parser.

    Returns:
      The parsed structure, or None when the file does not exist.
    """
    entry = pdb_id.lower()
    cif_path = STRUCT_DIR / f"{entry}.cif.gz"
    if cif_path.exists():
        # Open in text mode: the parser reads from the decompressed handle.
        with gzip.open(cif_path, "rt") as handle:
            return parser.get_structure(entry, handle)
    return None

def chain_ca_map(structure, chain_id: str):
    """Map residue number -> CA coordinates (float array) for one chain of model 0.

    Residues are skipped when they are not amino acids, have no "CA" atom, or
    carry an insertion code. Returns an empty dict when ``structure`` is None
    or the chain is absent.
    """
    ca_by_resnum = {}
    if structure is None:
        return ca_by_resnum

    first_model = structure[0]
    if chain_id not in first_model:
        return ca_by_resnum

    for residue in first_model[chain_id]:
        if not (is_aa(residue, standard=False) and "CA" in residue):
            continue
        res_id = residue.get_id()   # (' ', resseq, icode)
        number = int(res_id[1])
        if res_id[2].strip():
            # non-blank insertion code: residue is not addressable by number alone
            continue
        ca_by_resnum[number] = residue["CA"].get_coord().astype(float)
    return ca_by_resnum

def save_plot_for_model(npz_path, model, tag, topk=20):
    """Save a 3-D scatter of one protein: all residues, true positives, and the model's top-k.

    Args:
      npz_path: path of the graph archive; it must hold "pdb_id", "chain",
        "resseq" and "y" next to the arrays needed by npz_to_data_baseline.
      model: trained node classifier returning one logit per node.
      tag: label used in the plot title and in the output file name.
      topk: number of highest-probability residues to highlight.

    Returns:
      Path of the saved PNG inside PLOT_DIR, or None when no CA coordinates
      could be matched to the graph's residues.
    """
    # Read the archive once (it used to be opened twice, and the first handle
    # was never closed).
    g = load_npz(npz_path)
    pdb_id = str(g["pdb_id"])
    chain  = str(g["chain"])
    resseq = g["resseq"].astype(int)
    y_true = g["y"].astype(int)

    structure = load_structure_from_cif_gz(pdb_id)
    cmap = chain_ca_map(structure, chain)
    if len(cmap)==0:
        print("No coords:", npz_path.name)
        return None

    # Keep only graph nodes whose residue number has a CA coordinate.
    keep = np.array([i for i, r in enumerate(resseq) if int(r) in cmap], dtype=int)
    if keep.size == 0:
        # Structure was parsed but shares no residue numbers with the graph;
        # np.vstack on an empty list would raise.
        print("No coords:", npz_path.name)
        return None
    coords = np.vstack([cmap[int(resseq[i])] for i in keep])

    # model probs
    d = npz_to_data_baseline(g).to(device)
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(d)).cpu().numpy()

    yk = y_true[keep]
    pk = probs[keep]
    k = min(topk, len(pk))
    top_idx = np.argsort(-pk)[:k]   # indices into the kept arrays, best first
    hits = int(yk[top_idx].sum())

    outp = PLOT_DIR / f"{npz_path.stem}_{tag}_top{topk}.png"

    fig = plt.figure(figsize=(7,6))
    try:
        ax = fig.add_subplot(111, projection="3d")

        # faint backbone of every kept residue
        ax.scatter(coords[:,0], coords[:,1], coords[:,2], s=6, alpha=0.12)
        # true binding residues
        if (yk==1).any():
            c = coords[yk==1]
            ax.scatter(c[:,0], c[:,1], c[:,2], s=30, alpha=0.95)

        # model's top-k predictions
        c = coords[top_idx]
        ax.scatter(c[:,0], c[:,1], c[:,2], s=16, alpha=0.95)

        ax.set_title(f"{npz_path.stem} | {tag} | hits_in_top{topk}={hits}/{k}")
        ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z")

        fig.savefig(outp, dpi=200, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print("Saved:", outp.name, "| hits:", hits, "/", k)
    return outp

# pick 2 random test proteins for a chosen seed
seed_vis = 42
tr, va, te = split_paths(npz_files, seed_vis)
# Draw from a private RNG seeded with seed_vis so the same proteins are chosen
# on every run; cap the sample size in case the test split is tiny.
two = random.Random(seed_vis).sample(te, min(2, len(te)))
print("Chosen test proteins:", [p.name for p in two])

# The models trained in the evaluation loop are overwritten on every seed, so
# re-train both variants on the seed_vis split before plotting.
m_base = train_baseline([npz_to_data_baseline(load_npz(p)) for p in tr], epochs=6)
m_hn   = train_hardneg(tr, epochs=4)

saved = []
for p in two:
    for vis_model, vis_tag in ((m_base, "baseline"), (m_hn, "hardneg")):
        outp = save_plot_for_model(p, vis_model, vis_tag, topk=20)
        if outp is not None:   # None means no coordinates were available
            saved.append(outp)

# sanity check
pngs = sorted(PLOT_DIR.glob("*.png"))
print("PNG count:", len(pngs))
print("Last 8:", [x.name for x in pngs[-8:]])

print("saved baseline vs hardneg plots for 2 test proteins.")


Collecting biopython
  Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.86
Chosen test proteins: ['1Q4S_A_000025.npz', '4FHA_A_000045.npz']
epoch 01 | hardneg avg loss 2.4140
epoch 02 | hardneg avg loss 2.3180
epoch 03 | hardneg avg loss 2.2877
epoch 04 | hardneg avg loss 2.2719
Saved: 1Q4S_A_000025_baseline_top20.png | hits: 1 / 20
Saved: 1Q4S_A_000025_hardneg_top20.png | hits: 1 / 20
Saved: 4FHA_A_000045_baseline_top20.png | hits: 0 / 20
Saved: 4FHA_A_000045_hardneg_top20.png | hits: 1 / 20
PNG count: 4
Last 8: ['1Q4S_A_000025_baseline_top20.png', '1Q4S_A_000025_hardneg_top20.png', '4FHA_A_000045_ba