In [1]:
# Step A (prep): inspect columns + find the best “ligand type” column

from google.colab import drive
drive.mount('/content/drive')

from pathlib import Path
import pandas as pd

BASE = Path("/content/drive/MyDrive/biolip_gnn")
CSV1200 = BASE / "out" / "biolip_subset_1200.csv"

df = pd.read_csv(CSV1200)
print("Rows:", len(df))
print("\nCOLUMNS:")
print(list(df.columns))

# Try a few common ligand-type column names
candidates = [
    "ligand", "ligand_id", "ligand_type", "het_id", "het", "ligand_name",
    "ligand_code", "chem", "compound", "comp_id", "resname"
]

present = [c for c in candidates if c in df.columns]
print("\nPresent candidate columns:", present)

# If at least one candidate exists, print value_counts for the first one.
# If none exist, auto-scan for “ligand-like” columns by checking uniqueness.
if present:
    col = present[0]
    print(f"\nVALUE_COUNTS HEAD(15) for candidate column: {col}")
    print(df[col].astype(str).value_counts().head(15))
else:
    # Auto-scan for likely categorical columns (few thousand unique max)
    likely = []
    for c in df.columns:
        nun = df[c].nunique(dropna=True)
        if 2 <= nun <= 500:  # heuristic: categorical-ish
            likely.append((c, nun))
    likely = sorted(likely, key=lambda x: x[1])
    print("\nNo obvious ligand column found. Here are likely categorical columns (name, #unique):")
    print(likely[:25])

    # Pick the first likely column and show its value counts
    if likely:
        col = likely[0][0]
        print(f"\nVALUE_COUNTS HEAD(15) for auto-picked column: {col}")
        print(df[col].astype(str).value_counts().head(15))


Mounted at /content/drive
Rows: 1200

COLUMNS:
['pdb_id', 'chain', 'resolution', 'seq_len', 'sequence', 'raw_line', 'pdb_chain_key']

Present candidate columns: []

No obvious ligand column found. Here are likely categorical columns (name, #unique):
[('chain', 105), ('resolution', 322), ('seq_len', 384)]

VALUE_COUNTS HEAD(15) for auto-picked column: chain
chain
A      659
B      161
C       72
D       41
F       23
H       22
E       16
R       15
L       14
G       13
I       10
M        7
K        6
AAA      6
Z        5
Name: count, dtype: int64


In [2]:
# Extract ligand code from raw_line + show counts

import re
import pandas as pd

# 1) Quick peek: see how raw_line looks when split by tabs
# (This helps confirm which column holds ligand code.)
for i in range(3):
    parts = str(df.loc[i, "raw_line"]).split("\t")
    print(f"\nRow {i} split length = {len(parts)}")
    print(parts[:10])  # first 10 fields only

# 2) BioLiP raw_line typically starts like:
# pdb_id, chain, resolution, BSxx, LIGAND_CODE, ...
# So ligand code is often parts[4]. We'll extract safely.
def extract_ligand_code(raw_line: str):
    if not isinstance(raw_line, str) or raw_line.strip() == "":
        return None
    parts = raw_line.split("\t")
    if len(parts) >= 5:
        lig = parts[4].strip()
        # keep only plausible ligand IDs (often 1–3 chars, sometimes longer like "peptide")
        # We'll accept alnum and underscore/dash.
        lig = re.sub(r"[^A-Za-z0-9_\-]", "", lig)
        return lig if lig else None
    return None

df["ligand_code"] = df["raw_line"].apply(extract_ligand_code)

print("\nMissing ligand_code:", df["ligand_code"].isna().sum(), "out of", len(df))

print("\nVALUE_COUNTS HEAD(15) for ligand_code:")
print(df["ligand_code"].astype(str).value_counts().head(15))



Row 0 split length = 21
['1aei', 'A', '2.8', 'BS01', 'CA', 'A', '1', 'K68 L71 E76', 'K67 L70 E75', '']

Row 1 split length = 21
['1afa', '1', '2.0', 'BS01', 'MBG', '1', '1', 'Q185 D187 W189 E198 N210', 'Q113 D115 W117 E126 N138', '']

Row 2 split length = 21
['1ah5', 'A', '2.4', 'BS01', 'DPM', 'A', '1', 'S81 K83 D84 T127 S128 S129 R131 R132 R155 L169 A170 C242', 'S67 K69 D70 T113 S114 S115 R117 R118 R141 L155 A156 C228', 'K83 D84 R131 R132 R149 R155 C242']

Missing ligand_code: 0 out of 1200

VALUE_COUNTS HEAD(15) for ligand_code:
ligand_code
rna        130
peptide    113
ZN         113
MG          68
CA          60
dna         28
HEM         26
NAP         25
NAD         24
MN          22
ADP         14
CLA         14
FE          14
FMN         13
BGC         12
Name: count, dtype: int64


In [3]:
from pathlib import Path
import pandas as pd
import numpy as np

BASE = Path("/content/drive/MyDrive/biolip_gnn")
OUT_DIR = BASE / "out"
OUT_DIR.mkdir(exist_ok=True)

# df already loaded from biolip_subset_1200.csv and has df["ligand_code"] from earlier
df2 = df.dropna(subset=["ligand_code"]).copy()

# Optional: normalize ligand names to be consistent
df2["ligand_code"] = df2["ligand_code"].astype(str)

TARGET_N = 500
SEED = 42

# 1 Group rows by ligand_code
groups = {k: g.sample(frac=1, random_state=SEED).reset_index(drop=True)  # shuffle within group
          for k, g in df2.groupby("ligand_code")}

ligands_sorted = sorted(groups.keys(), key=lambda k: len(groups[k]), reverse=True)

print("Total ligands:", len(ligands_sorted))
print("Top 10 ligand sizes:", [(k, len(groups[k])) for k in ligands_sorted[:10]])

# 2 Decide a "base quota" per ligand.
# We start with equal-ish allocation: TARGET_N / num_ligands, but cap it
# so large ligands don't eat everything and small ones can still contribute.
num_lig = len(ligands_sorted)
base_quota = max(1, TARGET_N // num_lig)
cap = max(base_quota, 25)  # cap per ligand in initial pass (tunable)
print("base_quota:", base_quota, "cap:", cap)

selected_rows = []

# 3 Initial pass: take up to min(group_size, cap) from each ligand
for lig in ligands_sorted:
    g = groups[lig]
    take = min(len(g), cap)
    if take > 0:
        selected_rows.append(g.iloc[:take])
        groups[lig] = g.iloc[take:].reset_index(drop=True)

sel = pd.concat(selected_rows, ignore_index=True) if selected_rows else pd.DataFrame()

# 4 Fill remaining slots by round-robin from remaining groups
need = TARGET_N - len(sel)
print("After capped pass:", len(sel), "need:", need)

while need > 0:
    made_progress = False
    for lig in ligands_sorted:
        g = groups[lig]
        if len(g) == 0:
            continue
        sel = pd.concat([sel, g.iloc[[0]]], ignore_index=True)
        groups[lig] = g.iloc[1:].reset_index(drop=True)
        need -= 1
        made_progress = True
        if need == 0:
            break
    if not made_progress:
        print("Ran out of rows before reaching target.")
        break

# 5 Final shuffle of the selected 500 rows (optional but nice)
sel = sel.sample(frac=1, random_state=SEED).reset_index(drop=True)

# 6 Save + report counts
subset500_path = OUT_DIR / "subset_500.csv"
sel.to_csv(subset500_path, index=False)

print("\nSaved:", subset500_path)
print("Final subset size:", len(sel))

print("\nCounts per ligand_code (top 25):")
print(sel["ligand_code"].value_counts().head(25))

# Also save a small counts table for your report
counts_path = OUT_DIR / "subset_500_ligand_counts.csv"
sel["ligand_code"].value_counts().reset_index().rename(
    columns={"index":"ligand_code", "ligand_code":"count"}
).to_csv(counts_path, index=False)

print("Saved counts table:", counts_path)


Total ligands: 345
Top 10 ligand sizes: [('rna', 130), ('ZN', 113), ('peptide', 113), ('MG', 68), ('CA', 60), ('dna', 28), ('HEM', 26), ('NAP', 25), ('NAD', 24), ('MN', 22)]
base_quota: 1 cap: 25
After capped pass: 837 need: -337

Saved: /content/drive/MyDrive/biolip_gnn/out/subset_500.csv
Final subset size: 837

Counts per ligand_code (top 25):
ligand_code
peptide    25
CA         25
ZN         25
NAP        25
HEM        25
rna        25
MG         25
dna        25
NAD        24
MN         22
FE         14
ADP        14
CLA        14
FMN        13
ATP        12
FAD        12
BGC        12
PLP        12
SAH        11
SF4        11
COA        11
GLC         9
PO4         9
NDP         8
GDP         8
Name: count, dtype: int64
Saved counts table: /content/drive/MyDrive/biolip_gnn/out/subset_500_ligand_counts.csv


In [4]:
print("Unique PDB IDs in subset_500:", sel["pdb_id"].nunique())
print("Example keys:", (sel["pdb_id"].str.upper() + "_" + sel["chain"].astype(str)).head(5).tolist())


Unique PDB IDs in subset_500: 829
Example keys: ['4B79_A', '7B1S_C', '1YDF_A', '3U5O_C', '7ZM7_E']


In [5]:
# Download missing *.cif.gz for subset_500 (with log)


import time
import urllib.request

BASE = Path("/content/drive/MyDrive/biolip_gnn")
OUT_DIR = BASE / "out"
STRUCT_DIR = BASE / "structures"
STRUCT_DIR.mkdir(exist_ok=True)

CSV500 = OUT_DIR / "subset_500.csv"
df500 = pd.read_csv(CSV500)

# Unique PDB IDs
pdb_ids = sorted({str(x).lower() for x in df500["pdb_id"].astype(str).tolist()})
print("Unique PDB IDs in subset_500:", len(pdb_ids))

def cif_path(pdb_id: str) -> Path:
    return STRUCT_DIR / f"{pdb_id.lower()}.cif.gz"

missing = [pid for pid in pdb_ids if not cif_path(pid).exists()]
present = len(pdb_ids) - len(missing)

print("Already present:", present)
print("Missing to download:", len(missing))

LOG_PATH = OUT_DIR / "download_log_500.txt"

# RCSB mmCIF gz endpoint
def rcsb_url(pdb_id: str) -> str:
    pid = pdb_id.lower()
    return f"https://files.rcsb.org/download/{pid}.cif.gz"

# Download loop
ok = 0
fail = 0

with open(LOG_PATH, "w") as log:
    log.write(f"Total unique PDB IDs: {len(pdb_ids)}\n")
    log.write(f"Already present: {present}\n")
    log.write(f"Missing initially: {len(missing)}\n\n")

    for i, pid in enumerate(missing, 1):
        url = rcsb_url(pid)
        out = cif_path(pid)

        try:
            urllib.request.urlretrieve(url, out)
            ok += 1
            log.write(f"OK\t{pid}\t{out.name}\n")
        except Exception as e:
            fail += 1
            log.write(f"FAIL\t{pid}\t{type(e).__name__}\t{str(e)[:200]}\n")

        # be nice to the server
        if i % 20 == 0:
            time.sleep(1)

print("Downloaded OK:", ok)
print("Failed:", fail)
print("Log saved to:", LOG_PATH)

# Quick final count
final_present = len(list(STRUCT_DIR.glob("*.cif.gz")))
print("Total .cif.gz now in structures/:", final_present)


Unique PDB IDs in subset_500: 829
Already present: 829
Missing to download: 0
Downloaded OK: 0
Failed: 0
Log saved to: /content/drive/MyDrive/biolip_gnn/out/download_log_500.txt
Total .cif.gz now in structures/: 893


In [6]:
# sanity check on few random structures

import random
sample = random.sample(pdb_ids, k=min(10, len(pdb_ids)))
print("Sample PDB IDs:", sample)
print("Exists flags:", [(pid, cif_path(pid).exists()) for pid in sample])


Sample PDB IDs: ['2ve3', '1m3u', '2gcd', '7lep', '9e97', '8dq5', '5zz7', '5kf2', '2qes', '6p9e']
Exists flags: [('2ve3', True), ('1m3u', True), ('2gcd', True), ('7lep', True), ('9e97', True), ('8dq5', True), ('5zz7', True), ('5kf2', True), ('2qes', True), ('6p9e', True)]


In [7]:
# Build + label graphs for the 500 subset

BASE = Path("/content/drive/MyDrive/biolip_gnn")
OUT_DIR = BASE / "out"
STRUCT_DIR = BASE / "structures"
G500_DIR = BASE / "graphs_labeled_v4_500"
G500_DIR.mkdir(exist_ok=True)

CSV500 = OUT_DIR / "subset_500.csv"
df500 = pd.read_csv(CSV500)

print("Rows in subset_500:", len(df500))
print("Unique pdb_id:", df500["pdb_id"].nunique())
print("Unique pdb_id+chain:", (df500["pdb_id"].astype(str).str.upper() + "_" + df500["chain"].astype(str)).nunique())
df500.head(3)


Rows in subset_500: 837
Unique pdb_id: 829
Unique pdb_id+chain: 837


Unnamed: 0,pdb_id,chain,resolution,seq_len,sequence,raw_line,pdb_chain_key,ligand_code
0,4B79,A,1.98,236,MVFQHDIYAGQQVLVTGGSSGIGAAIAMQFAELGAEVVALGLDADG...,4b79\tA\t1.98\tBS01\tNAD\tA\t1\tG17 S19 S20 G2...,4B79_A,NAD
1,7B1S,C,0.992,265,VYQRQFLPADDRVTKNRKKVVDPSVKLEKIRTLSDKDFLTLIGHRH...,7b1s\tC\t0.992\tBS01\tUSN\tC\t1\tY120 S121 G12...,7B1S_C,USN
2,1YDF,A,2.6,255,SLYKGYLIDLDGTIYKGKDRIPAGETFVHELQKRDIPYLFVTNNTT...,1ydf\tA\t2.6\tBS01\tMG\tA\t1\tD10 D12 D207\tD9...,1YDF_A,MG


In [8]:
# Parse mmCIF + build graph + labels

!pip -q install biopython

import gzip, re
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB.Polypeptide import is_aa
from Bio.Data.IUPACData import protein_letters_3to1_extended

parser = MMCIFParser(QUIET=True)

def resname_to_aa1(resname: str) -> str:
    return protein_letters_3to1_extended.get(resname.capitalize(), "X")

AA_TO_IDX = {a:i for i,a in enumerate(list("ACDEFGHIKLMNPQRSTVWY")+["X"])}

def aa_list_to_idx(aa_list):
    return np.array([AA_TO_IDX.get(a, 20) for a in aa_list], dtype=np.int64)

def load_structure_from_cif_gz(pdb_id: str):
    pdb_id_l = pdb_id.lower()
    path_gz = STRUCT_DIR / f"{pdb_id_l}.cif.gz"
    if not path_gz.exists():
        return None, f"missing_cif_gz:{path_gz.name}"
    try:
        with gzip.open(path_gz, "rt") as handle:
            structure = parser.get_structure(pdb_id_l, handle)
        return structure, "ok"
    except Exception as e:
        return None, f"parse_error:{type(e).__name__}"

def parse_chain_ca(structure, chain_id: str):
    model = structure[0]
    if chain_id not in model:
        return None, None, None, "chain_not_found"

    chain = model[chain_id]
    resseq, aa1, coords = [], [], []

    for res in chain:
        if not is_aa(res, standard=False):
            continue
        if "CA" not in res:
            continue

        rid = res.get_id()  # (' ', resseq, icode)
        resnum = int(rid[1])
        icode = rid[2].strip()
        if icode != "":
            continue  # ignore insertion codes for simplicity

        resname = res.get_resname().upper()
        ca = res["CA"].get_coord()

        resseq.append(resnum)
        coords.append(ca)
        aa1.append(resname_to_aa1(resname))

    if len(resseq) == 0:
        return None, None, None, "no_residues_with_ca"

    return resseq, aa1, np.array(coords, dtype=np.float32), "ok"

def build_graph_from_coords(coords, dist_thresh=8.0, add_seq_edges=True):
    N = coords.shape[0]
    diff = coords[:,None,:] - coords[None,:,:]
    dmat = np.sqrt((diff**2).sum(-1))

    edges = []
    dists = []

    # contact edges
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            if dmat[i,j] < dist_thresh:
                edges.append((i,j))
                dists.append(dmat[i,j])

    # sequence edges
    if add_seq_edges:
        for i in range(N-1):
            for (u,v) in [(i,i+1),(i+1,i)]:
                edges.append((u,v))
                dists.append(float(dmat[u,v]))

    edge_index = np.array(edges, dtype=np.int64).T
    edge_dist  = np.array(dists, dtype=np.float32)
    return edge_index, edge_dist

def binding_numbers_from_raw_line(raw_line: str):
    if not isinstance(raw_line, str) or raw_line.strip() == "":
        return set()
    toks = re.findall(r"[A-Z][0-9]+", raw_line)
    nums = set()
    for t in toks:
        try:
            nums.add(int(t[1:]))
        except:
            pass
    return nums

def make_labels(resseq_list, binding_nums):
    y = np.zeros(len(resseq_list), dtype=np.int64)
    if not binding_nums:
        return y, "no_binding_info"
    for i, rn in enumerate(resseq_list):
        if int(rn) in binding_nums:
            y[i] = 1
    return y, "pdb_resseq_forced"


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━[0m [32m1.7/3.2 MB[0m [31m51.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m56.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [9]:
# Build + save 500 graphs (with row-index keys) + failure log

from tqdm import tqdm

fail_log = []
saved = 0

for idx, r in tqdm(df500.iterrows(), total=len(df500)):
    pdb_id = str(r["pdb_id"]).upper()
    chain  = str(r["chain"])
    raw_line = r.get("raw_line", "")

    # Unique key per ROW to avoid overwriting
    key = f"{pdb_id}_{chain}_{idx:06d}"
    outpath = G500_DIR / f"{key}.npz"

    # skip if already built
    if outpath.exists():
        continue

    structure, st = load_structure_from_cif_gz(pdb_id)
    if structure is None:
        fail_log.append((key, "load", st))
        continue

    resseq_list, aa_list, coords, st2 = parse_chain_ca(structure, chain)
    if st2 != "ok":
        fail_log.append((key, "parse_chain", st2))
        continue

    edge_index, edge_dist = build_graph_from_coords(coords, dist_thresh=8.0, add_seq_edges=True)

    binding_nums = binding_numbers_from_raw_line(str(raw_line))
    y, label_mode = make_labels(resseq_list, binding_nums)

    np.savez_compressed(
        outpath,
        pdb_id=pdb_id,
        chain=chain,
        row_idx=int(idx),
        ligand_code=str(r.get("ligand_code","")),
        n_nodes=len(resseq_list),
        resseq=np.array(resseq_list, dtype=np.int64),
        x_idx=aa_list_to_idx(aa_list),
        edge_index=edge_index,
        edge_dist=edge_dist,
        y=y,
        label_mode=label_mode
    )
    saved += 1

print("New graphs saved:", saved)
print("Failures:", len(fail_log))
print("Total graphs now in graphs_labeled_v4_500:", len(list(G500_DIR.glob('*.npz'))))
fail_log[:10]


100%|██████████| 837/837 [17:14<00:00,  1.24s/it]

New graphs saved: 640
Failures: 0
Total graphs now in graphs_labeled_v4_500: 837





[]

In [10]:
!pip -q install torch_geometric scikit-learn

from pathlib import Path
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import degree

from sklearn.metrics import precision_recall_curve, average_precision_score



[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [17]:
from pathlib import Path
BASE = Path("/content/drive/MyDrive/biolip_gnn")
GDIR = BASE / "graphs_labeled_v4_500"
OUT_DIR = BASE / "out"
OUT_DIR.mkdir(exist_ok=True)

print("GDIR:", GDIR)
print("OUT_DIR:", OUT_DIR)

GDIR: /content/drive/MyDrive/biolip_gnn/graphs_labeled_v4_500
OUT_DIR: /content/drive/MyDrive/biolip_gnn/out


In [18]:
!pip -q install torch_geometric scikit-learn

import numpy as np
import pandas as pd

npz_files = sorted(GDIR.glob("*.npz"))
print("Graphs found:", len(npz_files))  # should be ~837

def load_npz(path: Path) -> dict:
    z = np.load(path, allow_pickle=True)
    return {k: z[k] for k in z.files}

raw_graphs = [load_npz(p) for p in npz_files]
print("Loaded graphs:", len(raw_graphs))


Graphs found: 837
Loaded graphs: 837


In [19]:
# Convert each graph to a PyTorch Geometric Data object (AA + degree features)

import torch
from torch_geometric.data import Data
from torch_geometric.utils import degree

def compute_degree(edge_index: torch.Tensor, n_nodes: int) -> torch.Tensor:
    deg = degree(edge_index[0], num_nodes=n_nodes).view(-1, 1).float()
    return (deg - deg.mean()) / (deg.std() + 1e-9)

def to_pyg_degree(graph_dict: dict) -> Data:
    x_idx = torch.tensor(graph_dict["x_idx"], dtype=torch.long)              # (N,)
    edge_index = torch.tensor(graph_dict["edge_index"], dtype=torch.long)    # (2,E)
    y = torch.tensor(graph_dict["y"], dtype=torch.long)                      # (N,)

    n_nodes = x_idx.numel()
    deg = compute_degree(edge_index, n_nodes)                                # (N,1)

    # x: [AA_index, normalized_degree]
    x = torch.cat([x_idx.view(-1,1), deg], dim=1)                             # (N,2)
    return Data(x=x, edge_index=edge_index, y=y)

dataset = [to_pyg_degree(g) for g in raw_graphs]
print("Dataset graphs:", len(dataset))
print("Example x shape:", dataset[0].x.shape)


Dataset graphs: 837
Example x shape: torch.Size([315, 2])


In [20]:
# Define model + helper functions (thresholds + metrics)

import torch.nn as nn
import torch.nn.functional as F
import random

from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader
from sklearn.metrics import precision_recall_curve, average_precision_score

class SAGE_NodeClassifier(nn.Module):
    def __init__(self, num_aa=21, emb_dim=32, hidden=64, extra_feats=1):
        super().__init__()
        self.emb = nn.Embedding(num_aa, emb_dim)
        in_dim = emb_dim + extra_feats
        self.conv1 = SAGEConv(in_dim, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.lin1  = nn.Linear(hidden, hidden)
        self.lin2  = nn.Linear(hidden, 1)

    def forward(self, data):
        aa_idx = data.x[:,0].long()
        h = self.emb(aa_idx)
        h = torch.cat([h, data.x[:,1:].float()], dim=1)
        h = F.relu(self.conv1(h, data.edge_index))
        h = F.relu(self.conv2(h, data.edge_index))
        h = F.relu(self.lin1(h))
        return self.lin2(h).squeeze(-1)

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    probs_all, y_all = [], []
    for batch in loader:
        batch = batch.to(device)
        probs = torch.sigmoid(model(batch)).detach().cpu().numpy()
        y = batch.y.detach().cpu().numpy()
        probs_all.append(probs); y_all.append(y)
    return np.concatenate(probs_all), np.concatenate(y_all)

def pos_weight_from_train(train_set):
    pos = sum(int(d.y.sum()) for d in train_set)
    tot = sum(int(d.y.numel()) for d in train_set)
    neg = tot - pos
    return torch.tensor([neg / max(pos, 1)], dtype=torch.float)

def split_dataset(ds, seed=42):
    ds = ds.copy()
    random.Random(seed).shuffle(ds)
    n = len(ds)
    n_train = int(0.7*n)
    n_val   = int(0.15*n)
    return ds[:n_train], ds[n_train:n_train+n_val], ds[n_train+n_val:]

def threshold_max_f1(probs, y_true):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    f1 = (2*prec[:-1]*rec[:-1]) / (prec[:-1]+rec[:-1] + 1e-9)
    return float(thr[int(np.argmax(f1))])

def threshold_precision_target(probs, y_true, target_precision=0.20):
    prec, rec, thr = precision_recall_curve(y_true, probs)
    prec2, rec2, thr2 = prec[:-1], rec[:-1], thr
    ok = np.where(prec2 >= target_precision)[0]
    if len(ok) == 0:
        return None
    best = ok[np.argmax(rec2[ok])]
    return float(thr2[best])

def metrics_at_threshold(probs, y_true, thr):
    y_pred = (probs >= thr).astype(int)
    tp = int(((y_pred==1) & (y_true==1)).sum())
    fp = int(((y_pred==1) & (y_true==0)).sum())
    fn = int(((y_pred==0) & (y_true==1)).sum())
    precision = tp / (tp+fp+1e-9)
    recall    = tp / (tp+fn+1e-9)
    f1        = (2*precision*recall) / (precision+recall+1e-9)
    return precision, recall, f1


In [21]:
# Train/evaluate ONE seed (quick test run)

def train_one(ds, seed=42, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_set, val_set, test_set = split_dataset(ds, seed=seed)

    model = SAGE_NodeClassifier(extra_feats=1).to(device)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=4, shuffle=False)
    test_loader  = DataLoader(test_set, batch_size=4, shuffle=False)

    pw = pos_weight_from_train(train_set).to(device)
    crit = nn.BCEWithLogitsLoss(pos_weight=pw)
    opt  = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    model.train()
    for _ in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            loss = crit(model(batch), batch.y.float())
            opt.zero_grad(); loss.backward(); opt.step()

    # thresholds from VAL
    val_probs, val_y = collect_probs_and_labels(model, val_loader, device)
    thr_f1  = threshold_max_f1(val_probs, val_y)
    thr_p20 = threshold_precision_target(val_probs, val_y, 0.20)
    thr_p15 = threshold_precision_target(val_probs, val_y, 0.15)

    fb20 = thr_p20 is None
    fb15 = thr_p15 is None
    if thr_p20 is None: thr_p20 = thr_f1
    if thr_p15 is None: thr_p15 = thr_f1

    # TEST metrics
    test_probs, test_y = collect_probs_and_labels(model, test_loader, device)
    auprc = float(average_precision_score(test_y, test_probs))

    p_f1, r_f1, f1_f1 = metrics_at_threshold(test_probs, test_y, thr_f1)
    p20,  r20,  f1_20 = metrics_at_threshold(test_probs, test_y, thr_p20)
    p15,  r15,  f1_15 = metrics_at_threshold(test_probs, test_y, thr_p15)

    return {
        "seed": seed,
        "n_graphs": len(ds),
        "test_auprc": auprc,
        "val_thr_maxf1": thr_f1,
        "val_thr_p20": thr_p20, "p20_fallback": fb20,
        "val_thr_p15": thr_p15, "p15_fallback": fb15,
        "test_P_maxF1": p_f1, "test_R_maxF1": r_f1, "test_F1_maxF1": f1_f1,
        "test_P_p20": p20,    "test_R_p20": r20,    "test_F1_p20": f1_20,
        "test_P_p15": p15,    "test_R_p15": r15,    "test_F1_p15": f1_15,
    }

one = train_one(dataset, seed=42, epochs=10)
one


{'seed': 42,
 'n_graphs': 837,
 'test_auprc': 0.1243319600858796,
 'val_thr_maxf1': 0.7090231776237488,
 'val_thr_p20': 0.7544869184494019,
 'p20_fallback': False,
 'val_thr_p15': 0.6882699131965637,
 'p15_fallback': False,
 'test_P_maxF1': 0.15581787521074003,
 'test_R_maxF1': 0.23146292585158745,
 'test_F1_maxF1': 0.1862527711376258,
 'test_P_p20': 0.1950745301359721,
 'test_R_p20': 0.1508016032063373,
 'test_F1_p20': 0.1701045488158102,
 'test_P_p15': 0.14620797498041538,
 'test_R_p15': 0.2810621242483562,
 'test_F1_p15': 0.19235384834110092}

In [22]:
# Run ALL 5 seeds + save report + summary

seeds = [1, 7, 42, 123, 999]
rows = [train_one(dataset, seed=s, epochs=10) for s in seeds]

report = pd.DataFrame(rows)
display(report)

summary = report[[
    "test_auprc",
    "test_P_maxF1","test_R_maxF1","test_F1_maxF1",
    "test_P_p20","test_R_p20","test_F1_p20",
    "test_P_p15","test_R_p15","test_F1_p15"
]].agg(["mean","std"])
display(summary)

PATH = OUT_DIR / "day10_report_ALL837.csv"
report.to_csv(PATH, index=False)
print("Saved:", PATH)


Unnamed: 0,seed,n_graphs,test_auprc,val_thr_maxf1,val_thr_p20,p20_fallback,val_thr_p15,p15_fallback,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
0,1,837,0.127053,0.642041,0.684995,False,0.600548,False,0.146751,0.266101,0.189175,0.175838,0.212182,0.192308,0.137409,0.33999,0.195718
1,7,837,0.1113,0.727931,0.828183,False,0.750456,False,0.121341,0.293333,0.171669,0.172257,0.097436,0.124468,0.131967,0.247692,0.172193
2,42,837,0.126929,0.638472,0.713477,False,0.624198,False,0.145307,0.250501,0.183925,0.182709,0.158818,0.169928,0.143411,0.278056,0.189226
3,123,837,0.139667,0.629267,0.720038,False,0.61804,False,0.15519,0.3143,0.207784,0.212121,0.169783,0.188605,0.146785,0.326427,0.202508
4,999,837,0.1243,0.67514,0.807173,False,0.696875,False,0.153324,0.261962,0.193433,0.22597,0.064766,0.100676,0.168726,0.231029,0.195022


Unnamed: 0,test_auprc,test_P_maxF1,test_R_maxF1,test_F1_maxF1,test_P_p20,test_R_p20,test_F1_p20,test_P_p15,test_R_p15,test_F1_p15
mean,0.12585,0.144382,0.27724,0.189197,0.193779,0.140597,0.155197,0.14566,0.284639,0.190933
std,0.010095,0.013547,0.026004,0.013214,0.023876,0.058995,0.040704,0.014088,0.047677,0.011486


Saved: /content/drive/MyDrive/biolip_gnn/out/day10_report_ALL837.csv
