In [1]:
# mount drive + load graphs + build AA+degree dataset

from google.colab import drive
drive.mount('/content/drive')

from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

!pip -q install torch_geometric scikit-learn

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import degree

from sklearn.metrics import precision_recall_curve, average_precision_score

BASE = Path("/content/drive/MyDrive/biolip_gnn")
LABELED_DIR = BASE / "graphs_labeled_v3"
OUT_DIR = BASE / "out"
OUT_DIR.mkdir(exist_ok=True)

npz_files = sorted(LABELED_DIR.glob("*.npz"))
print("Graphs found in graphs_labeled_v3:", len(npz_files))

def load_npz(path: Path) -> dict:
    z = np.load(path, allow_pickle=True)
    return {k: z[k] for k in z.files}

raw_graphs = [load_npz(p) for p in npz_files]

def compute_degree(edge_index: torch.Tensor, n_nodes: int) -> torch.Tensor:
    deg = degree(edge_index[0], num_nodes=n_nodes).view(-1, 1).float()
    return (deg - deg.mean()) / (deg.std() + 1e-9)

def to_pyg_degree(graph_dict: dict) -> Data:
    x_idx = torch.tensor(graph_dict["x_idx"], dtype=torch.long)              # (N,)
    edge_index = torch.tensor(graph_dict["edge_index"], dtype=torch.long)    # (2,E)
    y = torch.tensor(graph_dict["y"], dtype=torch.long)                      # (N,)

    n_nodes = x_idx.numel()
    deg = compute_degree(edge_index, n_nodes)                                # (N,1)

    # x: col0 AA index, col1 normalized degree
    x = torch.cat([x_idx.view(-1,1), deg], dim=1)                             # (N,2)

    edge_attr = None
    if "edge_dist" in graph_dict:
        edge_attr = torch.tensor(graph_dict["edge_dist"], dtype=torch.float).view(-1,1)

    d = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    return d

dataset = [to_pyg_degree(g) for g in raw_graphs]
print("Dataset size:", len(dataset), "Example x shape:", dataset[0].x.shape)


Mounted at /content/drive
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[?25hGraphs found in graphs_labeled_v3: 50
Dataset size: 50 Example x shape: torch.Size([387, 2])


In [2]:
# define model + helpers to collect probabilites/labels

class SAGE_NodeClassifier(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)
        in_dim = emb_dim + extra_feats

        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1  = nn.Linear(hidden, hidden)
        self.lin2  = nn.Linear(hidden, 1)

    def forward(self, data):
        aa_idx = data.x[:,0].long()
        h = self.emb(aa_idx)
        extras = data.x[:,1:].float()
        h = torch.cat([h, extras], dim=1)

        h = F.relu(self.conv1(h, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        h = F.relu(self.lin1(h))
        return self.lin2(h).squeeze(-1)

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    all_probs, all_y = [], []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        y = batch.y.detach().cpu().numpy()
        all_probs.append(probs)
        all_y.append(y)
    return np.concatenate(all_probs), np.concatenate(all_y)

def pos_weight_from_train(train_set):
    pos = sum(int(d.y.sum()) for d in train_set)
    tot = sum(int(d.y.numel()) for d in train_set)
    neg = tot - pos
    return torch.tensor([neg / max(pos, 1)], dtype=torch.float)


In [3]:
# two threshold rules: max F1 and precision >= 0.20

def threshold_max_f1(probs, y_true):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    f1 = (2*prec[:-1]*rec[:-1]) / (prec[:-1]+rec[:-1] + 1e-9)
    i = int(np.argmax(f1))
    return float(thr[i])

def threshold_precision_target(probs, y_true, target_precision=0.20):
    """
    Choose the *highest recall* threshold that achieves precision >= target_precision.
    If none achieve target, fall back to max-F1 threshold.
    """
    prec, rec, thr = precision_recall_curve(y_true, probs)
    # thr aligns with prec[:-1], rec[:-1]
    prec2, rec2, thr2 = prec[:-1], rec[:-1], thr

    ok = np.where(prec2 >= target_precision)[0]
    if len(ok) == 0:
        return None  # caller will fallback

    # among those thresholds, pick the one with max recall
    best = ok[np.argmax(rec2[ok])]
    return float(thr2[best])

def metrics_at_threshold(probs, y_true, thr):
    y_pred = (probs >= thr).astype(int)
    tp = int(((y_pred==1) & (y_true==1)).sum())
    fp = int(((y_pred==1) & (y_true==0)).sum())
    fn = int(((y_pred==0) & (y_true==1)).sum())
    precision = tp / (tp+fp+1e-9)
    recall    = tp / (tp+fn+1e-9)
    f1        = (2*precision*recall) / (precision+recall+1e-9)
    auprc     = float(average_precision_score(y_true, probs))
    return auprc, precision, recall, f1


In [4]:
# train + evaluate one run with both threshold rules

from torch_geometric.loader import DataLoader

def split_dataset(ds, seed=42):
    ds = ds.copy()
    random.Random(seed).shuffle(ds)
    n = len(ds)
    n_train = int(0.7*n)
    n_val   = int(0.15*n)
    return ds[:n_train], ds[n_train:n_train+n_val], ds[n_train+n_val:]

def train_one(ds, seed=42, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_set, val_set, test_set = split_dataset(ds, seed=seed)

    model = SAGE_NodeClassifier(extra_feats=1).to(device)

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    pw = pos_weight_from_train(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pw)
    opt  = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            logits = model(batch)
            loss = crit(logits, batch.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    # validation probs for threshold selection
    val_probs, val_y = collect_probs_and_labels(model, val_loader, device)
    thr_f1 = threshold_max_f1(val_probs, val_y)

    thr_p = threshold_precision_target(val_probs, val_y, target_precision=0.20)
    if thr_p is None:
        thr_p = thr_f1
        used_fallback = True
    else:
        used_fallback = False

    # test evaluation
    test_probs, test_y = collect_probs_and_labels(model, test_loader, device)

    auprc_f1, p_f1, r_f1, f1_f1 = metrics_at_threshold(test_probs, test_y, thr_f1)
    auprc_p,  p_p,  r_p,  f1_p  = metrics_at_threshold(test_probs, test_y, thr_p)

    return {
        "seed": seed,
        "n_graphs": len(ds),
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p,
        "p20_fallback_to_f1": used_fallback,
        "test_auprc": float(average_precision_score(test_y, test_probs)),
        "test_P_maxF1": p_f1, "test_R_maxF1": r_f1, "test_F1_maxF1": f1_f1,
        "test_P_p20": p_p,    "test_R_p20": r_p,    "test_F1_p20": f1_p
    }

print(train_one(dataset, seed=42, epochs=10))


{'seed': 42, 'n_graphs': 50, 'val_thr_maxf1': 0.46328476071357727, 'val_thr_p20': 0.7005248665809631, 'p20_fallback_to_f1': False, 'test_auprc': 0.166468318577282, 'test_P_maxF1': 0.11736334405125826, 'test_R_maxF1': 0.6517857142798947, 'test_F1_maxF1': 0.19891008148471664, 'test_P_p20': 0.2272727272701446, 'test_R_p20': 0.17857142856983418, 'test_F1_p20': 0.19999999950519998}


In [5]:
# run three seeds and save "50 graphs" repot

seeds = [1, 42, 123]
rows_50 = [train_one(dataset, seed=s, epochs=10) for s in seeds]
report_50 = pd.DataFrame(rows_50)
display(report_50)

summary_50 = report_50[["test_P_maxF1","test_R_maxF1","test_F1_maxF1",
                        "test_P_p20","test_R_p20","test_F1_p20",
                        "test_auprc"]].agg(["mean","std"])
display(summary_50)

PATH_50 = OUT_DIR / "day9_report_50.csv"
report_50.to_csv(PATH_50, index=False)
print("Saved:", PATH_50)

Unnamed: 0,seed,n_graphs,val_thr_maxf1,val_thr_p20,p20_fallback_to_f1,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20
0,1,50,0.542755,0.743259,False,0.072835,0.089239,0.309091,0.138493,0.068966,0.036364,0.047619
1,42,50,0.448618,0.767681,False,0.158242,0.102941,0.75,0.181034,0.257143,0.080357,0.122449
2,123,50,0.641167,0.833923,False,0.108684,0.067766,0.41573,0.116535,0.277778,0.05618,0.093458


Unnamed: 0,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_auprc
mean,0.086649,0.491607,0.145354,0.201295,0.057634,0.087842,0.113254
std,0.01773,0.230039,0.032792,0.115065,0.022033,0.03773,0.042886


Saved: /content/drive/MyDrive/biolip_gnn/out/day9_report_50.csv


In [6]:
from pathlib import Path
BASE = Path("/content/drive/MyDrive/biolip_gnn")
LABELED_DIR = BASE / "graphs_labeled_v3"
print("Graphs in labeled dir:", len(list(LABELED_DIR.glob("*.npz"))))


Graphs in labeled dir: 50


In [7]:
# expand from 50 to 200 labeled graphs

In [8]:
from pathlib import Path
import pandas as pd

BASE = Path("/content/drive/MyDrive/biolip_gnn")
OUT_DIR = BASE / "out"
STRUCT_DIR = BASE / "structures"
LABELED_DIR = BASE / "graphs_labeled_v3"
LABELED_DIR.mkdir(exist_ok=True)

CSV200 = OUT_DIR / "subset_200.csv"

print("CSV200 exists:", CSV200.exists(), CSV200)
print("STRUCT_DIR exists:", STRUCT_DIR.exists(), STRUCT_DIR)
print("LABELED_DIR:", LABELED_DIR)

df = pd.read_csv(CSV200)
print("subset_200 rows:", len(df))
print(df.columns)
df.head()


CSV200 exists: True /content/drive/MyDrive/biolip_gnn/out/subset_200.csv
STRUCT_DIR exists: True /content/drive/MyDrive/biolip_gnn/structures
LABELED_DIR: /content/drive/MyDrive/biolip_gnn/graphs_labeled_v3
subset_200 rows: 200
Index(['pdb_id', 'chain', 'resolution', 'seq_len', 'sequence', 'raw_line',
       'pdb_chain_key'],
      dtype='object')


Unnamed: 0,pdb_id,chain,resolution,seq_len,sequence,raw_line,pdb_chain_key
0,9HIJ,A,1.6,97,EIKGYEYQLYVYASDKLFRADISEDYKTRGRKLLRFNGPVPPPGGS...,9hij\tA\t1.6\tBS01\tMG\tA\t1\tE53 E55\tE50 E52...,9HIJ_A
1,7M3Y,B,1.69,109,SEVEYRAEVGQNAYLPCFYTPAAPGNLVPVCWGKGACPVFECGNVV...,7m3y\tB\t1.69\tBS01\tYQ7\tB\t1\tV54 W57 S59 Y6...,7M3Y_B
2,1W58,1,2.5,337,SPEDKELLEYLQQTKAKITVVGCGGAGNNTITRLKMEGIEGAKTVA...,1w58\t1\t2.5\tBS01\tG2P\t1\t1\tG46 G47 A48 A97...,1W58_1
3,4PDD,A,1.7,303,QTILKIGYTPPKDSHYGVGATTFCDEVEKGTQERYKCQHFPSSALG...,4pdd\tA\t1.7\tBS01\tEAX\tA\t1\tT40 Y47 E79 N15...,4PDD_A
4,1M3U,A,1.8,262,PTTISLLQKYKQEKKRFATITAYDYSFAKLFADEGLNVMLVGDSLG...,1m3u\tA\t1.8\tBS01\tMG\tA\t1\tD45 D84\tD43 D82...,1M3U_A


In [9]:
import re

def key_from_row(r):
    # expecting columns like pdb_id and chain in your CSV
    return f"{str(r['pdb_id']).upper()}_{str(r['chain'])}"

existing = {p.stem for p in LABELED_DIR.glob("*.npz")}  # stems like 1KMM_C
df["graph_key"] = df.apply(key_from_row, axis=1)

missing = df[~df["graph_key"].isin(existing)].copy()
print("Already labeled:", len(existing))
print("Missing to label:", len(missing))
missing[["graph_key","pdb_id","chain"]].head(10)


Already labeled: 50
Missing to label: 150


Unnamed: 0,graph_key,pdb_id,chain
0,9HIJ_A,9HIJ,A
1,7M3Y_B,7M3Y,B
2,1W58_1,1W58,1
3,4PDD_A,4PDD,A
4,1M3U_A,1M3U,A
5,8ZNJ_A,8ZNJ,A
6,3QNM_A,3QNM,A
7,8ARU_A,8ARU,A
8,4K70_A,4K70,A
10,7BKB_E,7BKB,E


In [10]:
!pip -q install biopython

import gzip
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB.Polypeptide import is_aa
import numpy as np

parser = MMCIFParser(QUIET=True)

def load_structure_from_cif_gz(pdb_id: str):
    pdb_id = pdb_id.lower()
    path_gz = STRUCT_DIR / f"{pdb_id}.cif.gz"
    if not path_gz.exists():
        return None, f"missing_cif_gz:{path_gz.name}"
    try:
        with gzip.open(path_gz, "rt") as handle:
            structure = parser.get_structure(pdb_id, handle)
        return structure, "ok"
    except Exception as e:
        return None, f"parse_error:{type(e).__name__}"


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m2.9/3.2 MB[0m [31m112.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [16]:
from Bio.Data.IUPACData import protein_letters_3to1_extended

def resname_to_aa1(resname: str) -> str:
    return protein_letters_3to1_extended.get(resname.capitalize(), "X")


In [17]:
def parse_chain_ca(structure, chain_id: str):
    """
    Returns:
      resseq_list: list of PDB residue numbers (int)
      aa_list: list of 1-letter codes if possible, else 'X'
      coords: (N,3) numpy array of CA coords
    """
    model = structure[0]
    if chain_id not in model:
        return None, None, None, "chain_not_found"

    chain = model[chain_id]

    resseq, aa1, coords = [], [], []
    for res in chain:
        if not is_aa(res, standard=False):
            continue
        # CA atom
        if "CA" not in res:
            continue

        rid = res.get_id()
        # rid = (' ', resseq, icode)
        resnum = int(rid[1])
        icode = rid[2].strip()
        # ignore insertion codes for now; you can improve later
        if icode != "":
            continue

        resname = res.get_resname().upper()
        ca = res["CA"].get_coord()

        resseq.append(resnum)
        coords.append(ca)

        # crude mapping to 1-letter (fallback X)
        from Bio.PDB.Polypeptide import three_to_one
        try:
            aa1.append(three_to_one(resname))
        except:
            aa1.append("X")

    if len(resseq) == 0:
        return None, None, None, "no_residues_with_ca"

    return resseq, aa1, np.array(coords, dtype=np.float32), "ok"


In [18]:
import re

def binding_numbers_from_raw_line(raw_line: str):
    """
    Extract residue tokens like E83, R259, I206... and return set of ints {83,259,206}
    """
    if not isinstance(raw_line, str) or len(raw_line) == 0:
        return set()
    toks = re.findall(r"[A-Z][0-9]+", raw_line)
    nums = set()
    for t in toks:
        try:
            nums.add(int(t[1:]))
        except:
            pass
    return nums

def make_labels(resseq_list, binding_nums):
    y = np.zeros(len(resseq_list), dtype=np.int64)
    if not binding_nums:
        return y, "no_binding_info"
    for i, rn in enumerate(resseq_list):
        if int(rn) in binding_nums:
            y[i] = 1
    return y, "pdb_resseq_forced"


In [19]:
from tqdm import tqdm

MAX_NEW = 150
fail_log = []
saved = 0

AA_TO_IDX = {a:i for i,a in enumerate(list("ACDEFGHIKLMNPQRSTVWY")+["X"])}  # 21 incl X

def aa_list_to_idx(aa_list):
    return np.array([AA_TO_IDX.get(a, 20) for a in aa_list], dtype=np.int64)

for _, r in tqdm(missing.head(MAX_NEW).iterrows(), total=min(MAX_NEW, len(missing))):
    pdb_id = str(r["pdb_id"]).upper()
    chain  = str(r["chain"])
    key    = str(r["graph_key"])

    structure, status = load_structure_from_cif_gz(pdb_id)
    if structure is None:
        fail_log.append((key, status))
        continue

    resseq_list, aa_list, coords, st2 = parse_chain_ca(structure, chain)
    if st2 != "ok":
        fail_log.append((key, st2))
        continue

    edge_index, edge_dist = build_graph_from_coords(coords, dist_thresh=8.0, add_seq_edges=True)

    # labels
    binding_nums = binding_numbers_from_raw_line(r.get("raw_line",""))
    y, label_mode = make_labels(resseq_list, binding_nums)

    # save
    outpath = LABELED_DIR / f"{key}.npz"
    np.savez_compressed(
        outpath,
        pdb_id=pdb_id,
        chain=chain,
        n_nodes=len(resseq_list),
        resseq=np.array(resseq_list, dtype=np.int64),
        x_idx=aa_list_to_idx(aa_list),
        edge_index=edge_index,
        edge_dist=edge_dist,
        y=y,
        label_mode=label_mode
    )
    saved += 1

print("New labeled graphs saved:", saved)
print("Failures:", len(fail_log))
fail_log[:10]


  0%|          | 0/150 [00:00<?, ?it/s]


ImportError: cannot import name 'three_to_one' from 'Bio.PDB.Polypeptide' (/usr/local/lib/python3.12/dist-packages/Bio/PDB/Polypeptide.py)

In [20]:
from Bio.Data.IUPACData import protein_letters_3to1_extended

def resname_to_aa1(resname: str) -> str:
    return protein_letters_3to1_extended.get(resname.capitalize(), "X")


In [21]:
from Bio.PDB.Polypeptide import is_aa
import numpy as np

def parse_chain_ca(structure, chain_id: str):
    model = structure[0]
    if chain_id not in model:
        return None, None, None, "chain_not_found"

    chain = model[chain_id]
    resseq, aa1, coords = [], [], []

    for res in chain:
        if not is_aa(res, standard=False):
            continue
        if "CA" not in res:
            continue

        rid = res.get_id()                 # (' ', resseq, icode)
        resnum = int(rid[1])
        icode = rid[2].strip()
        if icode != "":
            continue

        resname = res.get_resname().upper()
        ca = res["CA"].get_coord()

        resseq.append(resnum)
        coords.append(ca)
        aa1.append(resname_to_aa1(resname))  # FIXED LINE ✅

    if len(resseq) == 0:
        return None, None, None, "no_residues_with_ca"

    return resseq, aa1, np.array(coords, dtype=np.float32), "ok"


In [22]:
print(resname_to_aa1("GLY"), resname_to_aa1("LYS"), resname_to_aa1("UNK"))


G K X


In [23]:
from tqdm import tqdm

MAX_NEW = 150
fail_log = []
saved = 0

AA_TO_IDX = {a:i for i,a in enumerate(list("ACDEFGHIKLMNPQRSTVWY")+["X"])}  # 21 incl X

def aa_list_to_idx(aa_list):
    return np.array([AA_TO_IDX.get(a, 20) for a in aa_list], dtype=np.int64)

for _, r in tqdm(missing.head(MAX_NEW).iterrows(), total=min(MAX_NEW, len(missing))):
    pdb_id = str(r["pdb_id"]).upper()
    chain  = str(r["chain"])
    key    = str(r["graph_key"])

    structure, status = load_structure_from_cif_gz(pdb_id)
    if structure is None:
        fail_log.append((key, status))
        continue

    resseq_list, aa_list, coords, st2 = parse_chain_ca(structure, chain)
    if st2 != "ok":
        fail_log.append((key, st2))
        continue

    edge_index, edge_dist = build_graph_from_coords(coords, dist_thresh=8.0, add_seq_edges=True)

    # labels
    binding_nums = binding_numbers_from_raw_line(r.get("raw_line",""))
    y, label_mode = make_labels(resseq_list, binding_nums)

    # save
    outpath = LABELED_DIR / f"{key}.npz"
    np.savez_compressed(
        outpath,
        pdb_id=pdb_id,
        chain=chain,
        n_nodes=len(resseq_list),
        resseq=np.array(resseq_list, dtype=np.int64),
        x_idx=aa_list_to_idx(aa_list),
        edge_index=edge_index,
        edge_dist=edge_dist,
        y=y,
        label_mode=label_mode
    )
    saved += 1

print("New labeled graphs saved:", saved)
print("Failures:", len(fail_log))
fail_log[:10]


100%|██████████| 150/150 [06:14<00:00,  2.49s/it]

New labeled graphs saved: 150
Failures: 0





[]

In [24]:
import pandas as pd

LOG_PATH = OUT_DIR / "day9_expand_to_200_failures.csv"
pd.DataFrame(fail_log, columns=["graph_key","reason"]).to_csv(LOG_PATH, index=False)

count_now = len(list(LABELED_DIR.glob("*.npz")))
print("Graphs now in graphs_labeled_v3:", count_now)
print("Saved failure log:", LOG_PATH)


Graphs now in graphs_labeled_v3: 200
Saved failure log: /content/drive/MyDrive/biolip_gnn/out/day9_expand_to_200_failures.csv


In [26]:
# load 200 graphs

npz_files = sorted(LABELED_DIR.glob("*.npz"))
print("Graphs found:", len(npz_files))

def load_npz(path: Path) -> dict:
    z = np.load(path, allow_pickle=True)
    return {k: z[k] for k in z.files}

raw_graphs = [load_npz(p) for p in npz_files]

def compute_degree(edge_index: torch.Tensor, n_nodes: int) -> torch.Tensor:
    deg = degree(edge_index[0], num_nodes=n_nodes).view(-1, 1).float()
    return (deg - deg.mean()) / (deg.std() + 1e-9)

def to_pyg_degree(graph_dict: dict) -> Data:
    x_idx = torch.tensor(graph_dict["x_idx"], dtype=torch.long)
    edge_index = torch.tensor(graph_dict["edge_index"], dtype=torch.long)
    y = torch.tensor(graph_dict["y"], dtype=torch.long)

    n_nodes = x_idx.numel()
    deg = compute_degree(edge_index, n_nodes)

    x = torch.cat([x_idx.view(-1,1), deg], dim=1)  # (N,2)

    d = Data(x=x, edge_index=edge_index, y=y)
    return d

dataset200 = [to_pyg_degree(g) for g in raw_graphs]
print("dataset200 size:", len(dataset200), "example x:", dataset200[0].x.shape)

Graphs found: 200
dataset200 size: 200 example x: torch.Size([211, 2])


In [28]:
# model + helpers

class SAGE_NodeClassifier(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)
        in_dim = emb_dim + extra_feats
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1  = nn.Linear(hidden, hidden)
        self.lin2  = nn.Linear(hidden, 1)

    def forward(self, data):
        aa_idx = data.x[:,0].long()
        h = self.emb(aa_idx)
        h = torch.cat([h, data.x[:,1:].float()], dim=1)
        h = F.relu(self.conv1(h, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        h = F.relu(self.lin1(h))
        return self.lin2(h).squeeze(-1)

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    all_probs, all_y = [], []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        y = batch.y.detach().cpu().numpy()
        all_probs.append(probs); all_y.append(y)
    return np.concatenate(all_probs), np.concatenate(all_y)

def pos_weight_from_train(train_set):
    pos = sum(int(d.y.sum()) for d in train_set)
    tot = sum(int(d.y.numel()) for d in train_set)
    neg = tot - pos
    return torch.tensor([neg / max(pos, 1)], dtype=torch.float)

def split_dataset(ds, seed=42):
    ds = ds.copy()
    random.Random(seed).shuffle(ds)
    n = len(ds)
    n_train = int(0.7*n)
    n_val   = int(0.15*n)
    return ds[:n_train], ds[n_train:n_train+n_val], ds[n_train+n_val:]

def threshold_max_f1(probs, y_true):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    f1 = (2*prec[:-1]*rec[:-1]) / (prec[:-1]+rec[:-1] + 1e-9)
    i = int(np.argmax(f1))
    return float(thr[i])

def threshold_precision_target(probs, y_true, target_precision=0.20):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    prec2, rec2, thr2 = prec[:-1], rec[:-1], thr
    ok = np.where(prec2 >= target_precision)[0]
    if len(ok) == 0:
        return None
    best = ok[np.argmax(rec2[ok])]
    return float(thr2[best])

def metrics_at_threshold(probs, y_true, thr):
    y_pred = (probs >= thr).astype(int)
    tp = int(((y_pred==1) & (y_true==1)).sum())
    fp = int(((y_pred==1) & (y_true==0)).sum())
    fn = int(((y_pred==0) & (y_true==1)).sum())
    precision = tp / (tp+fp+1e-9)
    recall    = tp / (tp+fn+1e-9)
    f1        = (2*precision*recall) / (precision+recall+1e-9)
    auprc     = float(average_precision_score(y_true, probs))
    return auprc, precision, recall, f1


In [29]:
# train + evaluate once with 3 threshold modes (F1, p >=0.20, p>= 0.15)

def train_one(ds, seed=42, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_set, val_set, test_set = split_dataset(ds, seed=seed)

    model = SAGE_NodeClassifier(extra_feats=1).to(device)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    pw = pos_weight_from_train(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pw)
    opt  = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            loss = crit(model(batch), batch.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    # validation
    val_probs, val_y = collect_probs_and_labels(model, val_loader, device)
    thr_f1 = threshold_max_f1(val_probs, val_y)

    thr_p20 = threshold_precision_target(val_probs, val_y, 0.20)
    thr_p15 = threshold_precision_target(val_probs, val_y, 0.15)

    fb20 = (thr_p20 is None)
    fb15 = (thr_p15 is None)
    if thr_p20 is None: thr_p20 = thr_f1
    if thr_p15 is None: thr_p15 = thr_f1

    # test
    test_probs, test_y = collect_probs_and_labels(model, test_loader, device)
    test_auprc = float(average_precision_score(test_y, test_probs))

    _, p_f1,  r_f1,  f1_f1  = metrics_at_threshold(test_probs, test_y, thr_f1)
    _, p20,   r20,   f1_20  = metrics_at_threshold(test_probs, test_y, thr_p20)
    _, p15,   r15,   f1_15  = metrics_at_threshold(test_probs, test_y, thr_p15)

    return {
        "seed": seed,
        "n_graphs": len(ds),
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p20,
        "val_thr_p15": thr_p15,
        "p20_fallback": fb20,
        "p15_fallback": fb15,
        "test_auprc": test_auprc,
        "test_P_maxF1": p_f1, "test_R_maxF1": r_f1, "test_F1_maxF1": f1_f1,
        "test_P_p20": p20,    "test_R_p20": r20,    "test_F1_p20": f1_20,
        "test_P_p15": p15,    "test_R_p15": r15,    "test_F1_p15": f1_15,
    }


In [31]:
# sun 3 seeds on 200 graphs + save report

seeds = [1, 42, 123]
rows_200 = [train_one(dataset200, seed=s, epochs=10) for s in seeds]
report_200 = pd.DataFrame(rows_200)
display(report_200)

summary_200 = report_200[[
    "test_auprc",
    "test_P_maxF1","test_R_maxF1","test_F1_maxF1",
    "test_P_p20","test_R_p20","test_F1_p20",
    "test_P_p15","test_R_p15","test_F1_p15",
]].agg(["mean","std"])
display(summary_200)

PATH_200 = OUT_DIR / "day9_report_200.csv"
report_200.to_csv(PATH_200, index=False)
print("Saved:", PATH_200)


Unnamed: 0,seed,n_graphs,val_thr_maxf1,val_thr_p20,val_thr_p15,p20_fallback,p15_fallback,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
0,1,200,0.656901,0.741166,0.678641,False,False,0.09444,0.092279,0.387863,0.149087,0.129956,0.155673,0.141657,0.095618,0.316623,0.146879
1,42,200,0.648237,0.734363,0.6465,False,False,0.11489,0.110965,0.330078,0.166093,0.142857,0.173828,0.156828,0.110823,0.333984,0.166423
2,123,200,0.479807,0.724261,0.644375,False,False,0.178412,0.126023,0.469512,0.19871,0.366667,0.067073,0.113402,0.262108,0.186992,0.218268


Unnamed: 0,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
mean,0.129247,0.109756,0.395818,0.171297,0.21316,0.132191,0.137296,0.156183,0.2792,0.17719
std,0.043788,0.016905,0.070057,0.025217,0.133097,0.05712,0.022039,0.092049,0.080325,0.036892


Saved: /content/drive/MyDrive/biolip_gnn/out/day9_report_200.csv


In [32]:
# a 50 vs 200 comparison report

PATH_50 = OUT_DIR / "day9_report_50.csv"

r50 = pd.read_csv(PATH_50)
r200 = pd.read_csv(PATH_200)

combined = pd.concat([r50, r200], ignore_index=True)
COMBINED_PATH = OUT_DIR / "day9_report_50_vs_200.csv"
combined.to_csv(COMBINED_PATH, index=False)

print("Saved combined report:", COMBINED_PATH)
combined


Saved combined report: /content/drive/MyDrive/biolip_gnn/out/day9_report_50_vs_200.csv


Unnamed: 0,seed,n_graphs,val_thr_maxf1,val_thr_p20,p20_fallback_to_f1,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,val_thr_p15,p20_fallback,p15_fallback,test_P_p15,test_R_p15,test_F1_p15
0,1,50,0.542755,0.743259,False,0.072835,0.089239,0.309091,0.138493,0.068966,0.036364,0.047619,,,,,,
1,42,50,0.448618,0.767681,False,0.158242,0.102941,0.75,0.181034,0.257143,0.080357,0.122449,,,,,,
2,123,50,0.641167,0.833923,False,0.108684,0.067766,0.41573,0.116535,0.277778,0.05618,0.093458,,,,,,
3,1,200,0.656901,0.741166,,0.09444,0.092279,0.387863,0.149087,0.129956,0.155673,0.141657,0.678641,False,False,0.095618,0.316623,0.146879
4,42,200,0.648237,0.734363,,0.11489,0.110965,0.330078,0.166093,0.142857,0.173828,0.156828,0.6465,False,False,0.110823,0.333984,0.166423
5,123,200,0.479807,0.724261,,0.178412,0.126023,0.469512,0.19871,0.366667,0.067073,0.113402,0.644375,False,False,0.262108,0.186992,0.218268
