In [40]:
!pip install torch-geometric



In [2]:
import re, sys, math, json
from io import StringIO
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.impute import KNNImputer
from sklearn.model_selection import StratifiedShuffleSplit

from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

In [3]:
# Download locations, keyed by a short modality tag.
# All CPTAC pan-cancer COAD tables share one S3 prefix, so only the file names are listed.
URLS = {
    key: (
        "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/"
        "data_freeze_v1.2_reorganized/COAD/" + file_name
    )
    for key, file_name in [
        # Proteomics (gene-level)
        ("prot_tumor", "COAD_proteomics_gene_abundance_log2_reference_intensity_normalized_Tumor.txt"),
        ("prot_norm", "COAD_proteomics_gene_abundance_log2_reference_intensity_normalized_Normal.txt"),
        # Phosphoproteomics (site-level)
        ("phos_tumor", "COAD_phospho_site_abundance_log2_reference_intensity_normalized_Tumor.txt"),
        ("phos_norm", "COAD_phospho_site_abundance_log2_reference_intensity_normalized_Normal.txt"),
        # Mutations (gene-level binary), CNV, RNA
        ("mut_gene_bin", "COAD_somatic_mutation_gene_level_binary.txt"),
        ("cnv_log2", "COAD_WES_CNV_gene_ratio_log2.txt"),
        # alt CNV: GISTIC discrete
        # ("cnv_gistic", "COAD_WES_CNV_gene_gistic_level.txt"),
        ("rna_gene_tumor", "COAD_RNAseq_gene_RSEM_coding_UQ_1500_log2_Tumor.txt"),
    ]
}

# CMS labels from Linkedomics clinical .tsi (different host, so added separately; stays last)
URLS["tsi"] = (
    "https://linkedomics.org/cptac-colon/"
    "Human__CPTAC_COAD__MS__Clinical__Clinical__03_01_2017__CPTAC__Clinical__BCM.tsi"
)


In [41]:
import io
import gzip
import requests
import numpy as np
import pandas as pd
from io import StringIO
from typing import Optional, Literal, Dict, List

def fetch_tsv(url: str, index_col: int | None = 0) -> pd.DataFrame:
    """
    Download a tab-separated table over HTTP and parse it into a DataFrame.

    Parameters
    ----------
    url : address of the TSV resource (requested with a 180 s timeout).
    index_col : position of the column to promote to the row index, or None
        to keep pandas' default integer index.

    Returns
    -------
    DataFrame parsed with the first line as header. When ``index_col`` is
    given, that column is cast to ``str`` and becomes the index. Value columns
    keep the dtype pandas inferred; nothing is coerced to numeric here.

    Raises
    ------
    Whatever ``requests`` raises for transport errors, plus the error from
    ``raise_for_status()`` on an unsuccessful HTTP status.
    """
    response = requests.get(url, timeout=180)
    response.raise_for_status()
    payload = response.content

    # A gzip stream is recognised by its two magic bytes; anything else is parsed as-is.
    if payload[:2] == b"\x1f\x8b":
        with gzip.GzipFile(fileobj=io.BytesIO(payload), mode="rb") as gz:
            source = StringIO(gz.read().decode("utf-8", errors="replace"))
    else:
        source = io.BytesIO(payload)
    df = pd.read_csv(source, sep="\t", header=0, low_memory=False)

    if index_col is None:
        return df

    # Cast before set_index so the index labels are always strings.
    idx_name = df.columns[index_col]
    df[idx_name] = df[idx_name].astype(str)
    return df.set_index(idx_name)

def strip_ensembl_version(x: str):
    """
    Remove the ``.version`` suffix from an Ensembl-style identifier.

    Strings that start with ``ENS``/``ens`` and contain a dot are cut at the
    first dot. Every other value (including non-strings) is returned as-is.
    """
    if not isinstance(x, str):
        return x
    if "." in x and x.startswith(("ENS", "ens")):
        # partition() splits at the first dot, same as split(".", 1)[0]
        return x.partition(".")[0]
    return x

def clean_cols(df: pd.DataFrame) -> pd.DataFrame:
    """
    Trim surrounding whitespace from column names, preserving column order.

    The frame is modified in place and also returned, so callers can either
    rebind the result or ignore it. Non-string labels (e.g. integer column
    names) are left untouched instead of raising AttributeError on ``.strip``.
    """
    df.columns = [c.strip() if isinstance(c, str) else c for c in df.columns]
    return df

def pick_valid_patients(id_list: List[str]) -> List[str]:
    """
    Return the usable patient IDs from ``id_list``, whitespace-trimmed.

    Non-string entries are skipped. Strings that are blank after trimming,
    or that spell NA / NAN / NULL in any letter case, are skipped too.
    Input order is preserved.
    """
    placeholders = {"NA", "NAN", "NULL"}
    trimmed = (p.strip() for p in id_list if isinstance(p, str))
    return [s for s in trimmed if s and s.upper() not in placeholders]

# ---------- Utilities you reuse later ----------

def collapse_duplicate_rows(
    df: pd.DataFrame,
    how: Literal["median", "mean", "max_binary"] = "median"
) -> pd.DataFrame:
    """
    Merge rows that share an index label into a single row.

    ``how`` selects the per-column reducer:
      - 'mean'       -> arithmetic mean,
      - 'max_binary' -> maximum (keeps a 0/1 table binary),
      - anything else (including the default 'median') -> median.

    A frame without duplicate labels is returned unchanged (same object).
    Otherwise the result comes from ``groupby(level=0)`` with
    ``numeric_only=True``.
    """
    if not df.index.has_duplicates:
        return df

    by_label = df.groupby(level=0)
    if how == "mean":
        return by_label.mean(numeric_only=True)
    if how == "max_binary":
        return by_label.max(numeric_only=True)
    return by_label.median(numeric_only=True)

def sanitize_numeric(df: pd.DataFrame, clip_abs: float | None = None) -> pd.DataFrame:
    """
    Make all entries numeric; non-numeric -> NaN; replace ±inf with NaN; optional clipping.
    Use before KNN impute to avoid distance explosions.
    """
    df2 = df.apply(pd.to_numeric, errors="coerce")
    df2 = df2.replace([np.inf, -np.inf], np.nan)
    if clip_abs is not None:
        df2 = df2.clip(lower=-clip_abs, upper=clip_abs)
    return df2

def knn_impute(df: pd.DataFrame, max_k: int = 5) -> pd.DataFrame:
    """
    KNN-impute missing values, treating columns as samples and rows as features.

    Parameters
    ----------
    df : features x samples table (transposed internally for scikit-learn).
    max_k : upper bound on neighbours; the effective k is
        ``min(max_k, max(1, n_samples - 1))``.

    Returns
    -------
    Float DataFrame with the same index and columns as ``df``.

    Notes
    -----
    - With fewer than 2 samples, falls back to filling from the per-row median.
    - Rows with no observed value in any sample are left as NaN. KNNImputer
      discards features that are entirely missing at fit time, so passing them
      through would return fewer rows than ``df.index`` and rebuilding the
      frame would fail. They are therefore excluded from the fit and restored
      afterwards.
    """
    n_samples = df.shape[1]
    if n_samples < 2:
        row_med = df.median(axis=1, skipna=True)
        return df.apply(lambda col: col.fillna(row_med), axis=0)

    has_observation = df.notna().any(axis=1).to_numpy()
    filled = np.full(df.shape, np.nan, dtype=float)
    if has_observation.any():
        # Imported lazily so the fallback paths do not require scikit-learn.
        from sklearn.impute import KNNImputer

        k = min(max_k, max(1, n_samples - 1))
        imp = KNNImputer(n_neighbors=k, weights="distance")
        # fit on [samples, features]; transpose back to [features, samples]
        filled[has_observation] = imp.fit_transform(df.to_numpy()[has_observation].T).T
    return pd.DataFrame(filled, index=df.index, columns=df.columns)

def baseline_z_from_normals(
    tumor_df: pd.DataFrame,
    normal_df: pd.DataFrame,
    clip: float = 5.0
) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
    """
    Z-score tumor rows against statistics taken from the normal samples only.

    For every row: ``z = (tumor - mean(normal)) / std(normal)``, using the
    population std (ddof=0). A zero std is replaced by NaN so the division
    yields NaN instead of inf. Any NaN/inf z-value is set to 0 and the result
    is clipped to ``[-clip, clip]``.

    Returns ``(Z, mu_norm, sd_norm)``; ``sd_norm`` carries NaN where the
    normal std was zero.
    """
    normal_mean = normal_df.mean(axis=1)
    normal_sd = normal_df.std(axis=1, ddof=0).replace(0, np.nan)

    centered = tumor_df.sub(normal_mean, axis=0)
    scaled = centered.div(normal_sd, axis=0)
    finite = scaled.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return finite.clip(-clip, clip), normal_mean, normal_sd

def z_by_train_only(
    df_all: pd.DataFrame,
    train_cols: List[str],
    clip: float = 5.0
) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
    """
    Standardize every row of ``df_all`` with statistics from the TRAIN columns only.

    Mean and population std (ddof=0) are computed per row over ``train_cols``
    and then applied to all columns. A zero std becomes NaN. NaN/inf z-values
    are set to 0 and the result is clipped to ``[-clip, clip]``.

    Returns ``(Z, mu_train, sd_train)``.
    """
    train_view = df_all[train_cols]
    train_mean = train_view.mean(axis=1)
    train_sd = train_view.std(axis=1, ddof=0).replace(0, np.nan)

    scaled = df_all.sub(train_mean, axis=0).div(train_sd, axis=0)
    finite = scaled.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return finite.clip(-clip, clip), train_mean, train_sd

def parse_cms_from_tsi(url: str) -> Dict[str, str]:
    """
    Fetch a Linkedomics ``.tsi`` table and return ``{patient_id: CMS_label}``.

    The table is read with its first column as the row index and the remaining
    columns as patients. The label row is the first index entry containing
    'Transcriptomic_subtype'. If none exists, the first row holding any value
    that starts with 'CMS' is used. An empty dict is returned when neither is
    found.

    Only string labels that are non-blank after trimming and not 'NA' (any
    case) are kept; patient ids are trimmed.
    """
    df = fetch_tsv(url, index_col=0)
    df.index = df.index.astype(str)

    named_rows = [idx for idx in df.index if "Transcriptomic_subtype" in str(idx)]
    if named_rows:
        row_key = named_rows[0]
    else:
        # fallback: first row in which at least one cell looks like a CMS label
        row_key = next(
            (
                idx for idx in df.index
                if any(str(v).startswith("CMS") for v in df.loc[idx].values)
            ),
            None,
        )
    if row_key is None:
        return {}

    cms_map = {}
    for pid, lab in df.loc[row_key].items():
        if not isinstance(lab, str):
            continue
        lab_s = lab.strip()
        if lab_s and lab_s.upper() != "NA":
            cms_map[pid.strip()] = lab_s
    return cms_map

In [42]:
def phospho_index_to_gene(phos_idx: pd.Index) -> pd.Series:
    """
    Derive one gene identifier per phosphosite row label.

    Resolution order for each label (after casting to ``str``):
      1. the first ``ENSG<digits>[.version]`` match anywhere in the label,
         with the version suffix removed;
      2. otherwise the label is split on ``| , ; : _`` and whitespace and the
         tokens are scanned left to right: a token starting with 'ENSG' is
         returned version-stripped, and a token of one letter followed by up
         to 20 letters/digits/hyphens is returned upper-cased;
      3. otherwise ``None``.

    Returns a Series indexed by ``phos_idx`` holding the resolved ids.
    """
    def _gene_for(site_id: str):
        hit = re.search(r"(ENSG[0-9]+(?:\.[0-9]+)?)", site_id)
        if hit:
            return strip_ensembl_version(hit.group(1))
        for token in re.split(r"[|,;:_\s]+", site_id):
            if token.startswith("ENSG"):
                return strip_ensembl_version(token)
            if re.fullmatch(r"[A-Za-z][A-Za-z0-9\-]{0,20}", token):
                return token.upper()
        return None

    resolved = [_gene_for(label) for label in phos_idx.astype(str)]
    return pd.Series(resolved, index=phos_idx)

In [43]:
# Each table is fetched, its column names are trimmed, and (for gene-level tables)
# the row ids are reduced to version-less Ensembl ids. Download order is unchanged.

# Proteomics (gene-level)
prot_tumor = clean_cols(fetch_tsv(URLS["prot_tumor"], index_col=0))
prot_norm  = clean_cols(fetch_tsv(URLS["prot_norm"],  index_col=0))
prot_tumor.index = prot_tumor.index.map(strip_ensembl_version)
prot_norm.index  = prot_norm.index.map(strip_ensembl_version)

# Phosphoproteomics (site-level) — row ids stay site-level here; aggregation happens later
phos_tumor = clean_cols(fetch_tsv(URLS["phos_tumor"], index_col=0))
phos_norm  = clean_cols(fetch_tsv(URLS["phos_norm"],  index_col=0))

# Mutations (gene-level binary)
mut_bin = clean_cols(fetch_tsv(URLS["mut_gene_bin"], index_col=0))
mut_bin.index = mut_bin.index.map(strip_ensembl_version)

# RNA (gene-level, tumor)
rna_tumor = clean_cols(fetch_tsv(URLS["rna_gene_tumor"], index_col=0))
rna_tumor.index = rna_tumor.index.map(strip_ensembl_version)

# CNV (gene-level log2 ratio, tumor)
cnv_log2 = clean_cols(fetch_tsv(URLS["cnv_log2"], index_col=0))
cnv_log2.index = cnv_log2.index.map(strip_ensembl_version)

# ---- Light, safe cleanup: collapse duplicate gene rows ----
# Duplicate counts are taken before any collapsing; the summary prints report them.
dup_prot_t = int(prot_tumor.index.duplicated().sum())
dup_prot_n = int(prot_norm.index.duplicated().sum())
dup_mut    = int(mut_bin.index.duplicated().sum())
dup_rna    = int(rna_tumor.index.duplicated().sum())
dup_cnv    = int(cnv_log2.index.duplicated().sum())

# Proteomics: tumor and normal are collapsed together if either has duplicates
if dup_prot_t > 0 or dup_prot_n > 0:
    prot_tumor = collapse_duplicate_rows(prot_tumor, how="median")
    prot_norm  = collapse_duplicate_rows(prot_norm,  how="median")

# Mutations: coerce to numeric (unparseable/missing -> 0), take the max across
# duplicate rows, then force a strict 0/1 int8 table whatever the source encoding was
mut_bin = mut_bin.apply(pd.to_numeric, errors="coerce").fillna(0.0)
if dup_mut:
    mut_bin = collapse_duplicate_rows(mut_bin, how="max_binary")
mut_bin = (mut_bin > 0).astype(np.int8)

# RNA and CNV: median across duplicate gene rows
if dup_rna:
    rna_tumor = collapse_duplicate_rows(rna_tumor, how="median")

if dup_cnv:
    cnv_log2 = collapse_duplicate_rows(cnv_log2, how="median")

# Drop obviously invalid/blank patient columns across all tables
# (keeps order; other alignment happens later).
# The columns are SELECTED first and only then renamed: assigning the filtered
# name list straight to `df.columns` raises a length-mismatch error as soon as
# a single column is actually invalid.
def _keep_valid_patient_columns(table: pd.DataFrame) -> pd.DataFrame:
    """Return `table` restricted to columns whose label passes pick_valid_patients, labels trimmed."""
    keep = [bool(pick_valid_patients([c])) for c in table.columns]
    kept = table.loc[:, keep].copy()
    kept.columns = pick_valid_patients(list(kept.columns))
    return kept

prot_tumor = _keep_valid_patient_columns(prot_tumor)
prot_norm  = _keep_valid_patient_columns(prot_norm)
phos_tumor = _keep_valid_patient_columns(phos_tumor)
phos_norm  = _keep_valid_patient_columns(phos_norm)
mut_bin    = _keep_valid_patient_columns(mut_bin)
rna_tumor  = _keep_valid_patient_columns(rna_tumor)
cnv_log2   = _keep_valid_patient_columns(cnv_log2)

print("prot_tumor", prot_tumor.shape, "| prot_norm", prot_norm.shape, f"(dup rows removed: T={dup_prot_t}, N={dup_prot_n})")
print("phos_tumor", phos_tumor.shape, "| phos_norm", phos_norm.shape, "(site-level; will aggregate to genes later)")
print("mut_bin   ", mut_bin.shape,    f"(dup rows removed: {dup_mut})")
print("rna_tumor ", rna_tumor.shape,  f"(dup rows removed: {dup_rna})")
print("cnv_log2  ", cnv_log2.shape,   f"(dup rows removed: {dup_cnv})")

prot_tumor (9151, 97) | prot_norm (9152, 100) (dup rows removed: T=0, N=0)
phos_tumor (35487, 97) | phos_norm (35485, 100) (site-level; will aggregate to genes later)
mut_bin    (14783, 96) (dup rows removed: 0)
rna_tumor  (60624, 106) (dup rows removed: 45)
cnv_log2   (60558, 105) (dup rows removed: 45)


In [44]:
# ----- Cell 7 (CMS labels from .tsi; robust parsing) -----
# The .tsi is wide: one 'attrib_name' column, then one column per patient.
tsi = fetch_tsv(URLS["tsi"], index_col=None)
assert "attrib_name" in tsi.columns, "Unexpected .tsi format: missing 'attrib_name'"
tsi = tsi.set_index("attrib_name")

# Find the row containing CMS transcriptomic subtype, trying in order:
#   exact name -> case-insensitive name match -> first row whose values look like CMS labels
if "Transcriptomic_subtype" in tsi.index:
    row_key = "Transcriptomic_subtype"
else:
    row_key = next(
        (
            ix for ix in tsi.index
            if "transcriptomic" in str(ix).lower() and "subtype" in str(ix).lower()
        ),
        None,
    )
    if row_key is None:
        row_key = next(
            (
                ix for ix in tsi.index
                if any(str(v).upper().startswith("CMS") for v in tsi.loc[ix].values)
            ),
            None,
        )

assert row_key is not None, "Could not locate CMS subtype row in the .tsi file."

# patient column name -> raw subtype cell
cms_row = tsi.loc[row_key].to_dict()

# Keep only valid patient → CMS pairs: both sides must be strings, the trimmed id
# must be non-empty, and the upper-cased label must start with "CMS"
# (which also rules out NA / NAN / blank cells).
valid_pairs = []
for pid, lab in cms_row.items():
    if isinstance(pid, str) and isinstance(lab, str):
        pid_s, lab_s = pid.strip(), lab.strip().upper()
        if pid_s and lab_s.startswith("CMS"):
            valid_pairs.append((pid_s, lab_s))

assert len(valid_pairs) > 0, "No valid CMS-labeled patients found."

kept_patients = [pair[0] for pair in valid_pairs]
labels_str    = [pair[1] for pair in valid_pairs]

# Stable class order: shorter names first, then alphabetical (CMS1,CMS2,CMS3,CMS4)
classes = sorted(set(labels_str), key=lambda x: (len(x), x))
vocab   = dict(zip(classes, range(len(classes))))              # e.g. {'CMS1':0,'CMS2':1,...}
labels_all = torch.tensor([vocab[s] for s in labels_str], dtype=torch.long)

# Handy map used later to align labels to proteomics patients
cms_pid2lab = {p: vocab[s] for p, s in valid_pairs}

# Prints
from collections import Counter
print("CMS classes (label->index):", vocab)
print("Total CMS-labeled patients:", len(kept_patients))
print("Class counts:", Counter(labels_all.numpy().tolist()))

CMS classes (label->index): {'CMS1': 0, 'CMS2': 1, 'CMS3': 2, 'CMS4': 3}
Total CMS-labeled patients: 85
Class counts: Counter({1: 33, 3: 22, 2: 16, 0: 14})


In [45]:
# ----- Cell 8 (normalize patient IDs + anchor CMS to proteomics) -----

# 8.1) Normalize patient IDs across all matrices (cast to str, strip whitespace only)
for name in ["prot_tumor", "prot_norm", "phos_tumor", "phos_norm", "mut_bin", "rna_tumor", "cnv_log2"]:
    df = locals()[name]
    df.columns = list(map(lambda c: str(c).strip(), df.columns))
    locals()[name] = df  # rebind explicitly

# 8.2) Anchor = proteomics tumor columns
prot_patients = set(prot_tumor.columns)

# 8.3) Keep CMS patients present in proteomics tumor, in the order of kept_patients
# (cms_pid2lab and kept_patients come from Cell 7)
cms_in_prot = list(filter(prot_patients.__contains__, kept_patients))
labels_in_prot = torch.tensor([cms_pid2lab[p] for p in cms_in_prot], dtype=torch.long)

print("CMS in proteomics:", len(cms_in_prot), "of", len(kept_patients))

# Optional quick diagnostics (no filtering performed here)
sets = {
    key: set(frame.columns)
    for key, frame in (
        ("prot", prot_tumor),
        ("phos", phos_tumor),
        ("rna",  rna_tumor),
        ("cnv",  cnv_log2),
        ("mut",  mut_bin),
    )
}
modality_sizes = {k: len(v) for k, v in sets.items()}
overlap_with_prot = {k: len(sets[k] & sets["prot"]) for k in ["phos", "rna", "cnv", "mut"]}
print("Patients per modality:",
      modality_sizes)
print("Overlap w/ proteomics:",
      overlap_with_prot)
print("CMS in (prot ∩ rna ∩ cnv):",
      len(set(cms_in_prot) & sets["rna"] & sets["cnv"]))

CMS in proteomics: 76 of 85
Patients per modality: {'prot': 97, 'phos': 97, 'rna': 106, 'cnv': 105, 'mut': 96}
Overlap w/ proteomics: {'phos': 97, 'rna': 96, 'cnv': 95, 'mut': 96}
CMS in (prot ∩ rna ∩ cnv): 75


In [48]:
# Tunable missingness threshold (across tumor+normal); 0.20–0.40 common in proteomics
PROT_MISS_MAX = 0.20

# Aggregate phospho sites -> gene level (median over the sites mapped to each gene).
# Sites whose label could not be mapped to a gene (None) are dropped before grouping.
phos_gene_tumor_map = phospho_index_to_gene(phos_tumor.index)
phos_gene_norm_map  = phospho_index_to_gene(phos_norm.index)

phos_tumor_gene = (
    phos_tumor
    .assign(**{"__gene__": phos_gene_tumor_map.values})
    .dropna(subset=["__gene__"])
    .groupby("__gene__")
    .median(numeric_only=True)
)

phos_norm_gene = (
    phos_norm
    .assign(**{"__gene__": phos_gene_norm_map.values})
    .dropna(subset=["__gene__"])
    .groupby("__gene__")
    .median(numeric_only=True)
)

# Align tumor vs normal within each modality (row intersection, tumor order)
prot_genes   = prot_tumor.index.intersection(prot_norm.index)
prot_tumor2  = prot_tumor.loc[prot_genes].copy()
prot_norm2   = prot_norm.loc[prot_genes].copy()

phos_genes   = phos_tumor_gene.index.intersection(phos_norm_gene.index)
phos_tumor2  = phos_tumor_gene.loc[phos_genes].copy()
phos_norm2   = phos_norm_gene.loc[phos_genes].copy()

# Filter by missingness across tumor+normal (modality-wise)
def filter_by_missingness(df_tum: pd.DataFrame, df_norm: pd.DataFrame, prot_miss_max=0.20):
    """
    Keep rows whose missing fraction, pooled over tumor AND normal columns,
    is at most ``prot_miss_max``.

    Returns ``(df_tum_kept, df_norm_kept)`` filtered with the same row mask.
    """
    pooled_missing = pd.concat([df_tum, df_norm], axis=1).isna().mean(axis=1)
    passes = pooled_missing <= prot_miss_max
    return df_tum.loc[passes], df_norm.loc[passes]

# Same threshold for both modalities; tumor and normal are filtered with one shared row mask
prot_tumor_f, prot_norm_f = filter_by_missingness(
    df_tum=prot_tumor2, df_norm=prot_norm2, prot_miss_max=PROT_MISS_MAX
)
phos_tumor_f, phos_norm_f = filter_by_missingness(
    df_tum=phos_tumor2, df_norm=phos_norm2, prot_miss_max=PROT_MISS_MAX
)

# Keep CMS-labeled proteomics tumor patients (proteomics is the patient anchor)
prot_tumor_f = prot_tumor_f.loc[:, prot_tumor_f.columns.intersection(cms_in_prot)]

# Proteomics missingness mask BEFORE imputation (tumor-only)
# (We keep this aligned to prot_tumor_f columns; final mask computed after union below)
# prot_mask_pre_raw = prot_tumor_f.isna().astype(np.float32)  # optional sanity, not used later

# Sanitize and KNN-impute each modality (adaptive k)
def sanitize_for_impute(df: pd.DataFrame, clip_abs: float | None = None) -> pd.DataFrame:
    df2 = df.apply(pd.to_numeric, errors="coerce").replace([np.inf, -np.inf], np.nan)
    if clip_abs is not None:
        df2 = df2.clip(lower=-clip_abs, upper=clip_abs)
    return df2

def knn_impute(df: pd.DataFrame, max_k: int = 5) -> pd.DataFrame:
    """
    KNN-impute missing values, treating columns as samples and rows as features.

    The effective neighbour count is ``min(max_k, max(1, n_samples - 1))``.
    With fewer than 2 samples the function falls back to filling from the
    per-row median.

    Rows with no observed value in any sample are left as NaN. KNNImputer
    discards features that are entirely missing at fit time, so passing them
    through would return fewer rows than ``df.index`` and rebuilding the frame
    would fail. They are therefore excluded from the fit and restored
    afterwards.

    Returns a float DataFrame with the same index and columns as ``df``.
    """
    n_samples = df.shape[1]
    if n_samples < 2:
        row_med = df.median(axis=1, skipna=True)
        return df.apply(lambda col: col.fillna(row_med), axis=0)

    has_observation = df.notna().any(axis=1).to_numpy()
    filled = np.full(df.shape, np.nan, dtype=float)
    if has_observation.any():
        k = min(max_k, max(1, n_samples - 1))
        imp = KNNImputer(n_neighbors=k, weights='distance')
        # fit on [samples, features]; transpose back to [features, samples]
        filled[has_observation] = imp.fit_transform(df.to_numpy()[has_observation].T).T
    return pd.DataFrame(filled, index=df.index, columns=df.columns)

# Choose patient lists for imputation.
# Phospho uses the CMS-labelled patients it shares with proteomics; if it shares none,
# every phospho tumor column is used instead.
tumor_patients_for_prot = prot_tumor_f.columns.tolist()
tumor_patients_for_phos = (
    phos_tumor_f.columns.intersection(cms_in_prot).tolist()
    or phos_tumor_f.columns.tolist()
)

# Proteomics: sanitize (numeric, no inf, |x| <= 1e6), then KNN-impute tumor and normal separately
prot_tumor_f_san = sanitize_for_impute(prot_tumor_f.loc[:, tumor_patients_for_prot], clip_abs=1e6)
prot_norm_f_san  = sanitize_for_impute(prot_norm_f, clip_abs=1e6)
prot_tumor_imp   = knn_impute(prot_tumor_f_san, max_k=5)
prot_norm_imp    = knn_impute(prot_norm_f_san,  max_k=5)

# Phospho: same treatment
phos_tumor_f_san = sanitize_for_impute(phos_tumor_f.loc[:, tumor_patients_for_phos], clip_abs=1e6)
phos_norm_f_san  = sanitize_for_impute(phos_norm_f, clip_abs=1e6)
phos_tumor_imp   = knn_impute(phos_tumor_f_san, max_k=5)
phos_norm_imp    = knn_impute(phos_norm_f_san,  max_k=5)

# Baseline-normalize with NORMALS ONLY (no leakage)
def baseline_z(tum: pd.DataFrame, norm: pd.DataFrame, clip: float = 5.0):
    """
    Z-score each row of ``tum`` against the mean and population std (ddof=0)
    of the same row in ``norm``.

    Zero std becomes NaN; NaN/inf z-values are set to 0 and the result is
    clipped to ``[-clip, clip]``. Returns ``(z, mu, sd)``.
    """
    norm_mean = norm.mean(axis=1)
    norm_sd = norm.std(axis=1, ddof=0).replace(0, np.nan)

    scaled = tum.sub(norm_mean, axis=0).div(norm_sd, axis=0)
    finite = scaled.replace([np.inf, -np.inf], np.nan).fillna(0.0)
    return finite.clip(-clip, clip), norm_mean, norm_sd

# Z-score tumor against NORMAL statistics only, clipped to ±5
prot_z, prot_mu, prot_sd = baseline_z(prot_tumor_imp, prot_norm_imp, clip=5.0)
phos_z, phos_mu, phos_sd = baseline_z(phos_tumor_imp, phos_norm_imp, clip=5.0)

# Patients → align columns to CMS proteomics set
patient_ids = list(prot_z.columns.intersection(cms_in_prot))
prot_z = prot_z.loc[:, patient_ids]
phos_z = phos_z.reindex(columns=patient_ids, fill_value=np.nan)  # patients without phospho -> NaN column

# GENES → ALIGN BY UNION (not intersection)
union_proteins = prot_z.index.union(phos_z.index)

# Reindex to union; keep NaN in phospho for presence mask, fill proteo with zeros where absent
prot_z_u = prot_z.reindex(union_proteins, fill_value=0.0)
phos_z_u = phos_z.reindex(union_proteins)  # keep NaN now

# Masks (pre-impute, tumor-only), aligned to union.
# prot_mask_pre: 1.0 where the proteomics value was missing before imputation.
prot_mask_pre = prot_tumor_f.reindex(union_proteins).loc[:, patient_ids].isna().astype(np.float32)
# phos_present: 1.0 where a phospho value was measured before imputation.
# Columns are reindexed (not .loc-selected) so a labelled patient with no phospho
# column yields an all-zero presence column instead of a KeyError, matching the
# reindex used for phos_z above.
phos_present  = (
    phos_tumor_f
    .reindex(index=union_proteins, columns=patient_ids)
    .notna()
    .astype(np.float32)
)

# Fill phospho NaNs with 0 for features (presence mask retains availability info)
phos_z_u = phos_z_u.fillna(0.0)

# Expose the union-aligned tables under the names later cells use
prot_z        = prot_z_u
phos_z        = phos_z_u
protein_ids   = list(union_proteins)

print("Proteins kept (UNION):", len(protein_ids), "| Patients kept:", len(patient_ids))

Proteins kept (UNION): 7102 | Patients kept: 76


In [49]:
# ----- Cell X: align mutations to protein_id space & build per-patient lists -----

# 1) Collapse duplicate gene rows with max (a 0/1 table stays 0/1)
if mut_bin.index.has_duplicates:
    mut_bin = mut_bin.groupby(level=0).max(numeric_only=True)

# 2) Keep only mutation columns for patients we're actually using (order = patient_ids)
mut_cols = [p for p in patient_ids if p in mut_bin.columns]
mut_sub  = mut_bin.reindex(columns=mut_cols)

# 3) Align rows to protein_ids; ids absent from the mutation table become all-zero rows
mut_aligned_df = mut_sub.reindex(index=protein_ids).fillna(0)

# 4) Strictly binary int8 ndarray, shape [len(protein_ids), len(mut_cols)]
mut_aligned = (mut_aligned_df.to_numpy() > 0).astype(np.int8)

# 5) (Optional) Map protein id -> row index (handy if you need lookups later)
prot_idx_map = dict(zip(protein_ids, range(len(protein_ids))))

# 6) Build mut_lists in EXACT patient_ids order: for each patient, the row indices of
#    mutated proteins (an empty int64 array when the patient has no mutation column)
present_cols = dict(zip(mut_cols, range(len(mut_cols))))  # patient_id -> column idx in mut_aligned
mut_lists = [
    np.nonzero(mut_aligned[:, present_cols[p]])[0].astype(np.int64)
    if p in present_cols
    else np.array([], dtype=np.int64)
    for p in patient_ids
]

# 7) Diagnostics
mut_sizes  = [arr.size for arr in mut_lists]
n_with_any = sum(size > 0 for size in mut_sizes)
avg_muts   = float(np.mean(mut_sizes)) if mut_sizes else 0.0
print(f"[mut_lists] patients={len(mut_lists)} | mut_table_cols_found={len(mut_cols)} "
      f"| with_any={n_with_any} | avg_mut_proteins/patient={avg_muts:.1f}")

# 8) Sanity checks
assert len(mut_lists) == len(patient_ids), "mut_lists must align 1:1 with patient_ids"
if not mut_cols:
    print("WARNING: No mutation columns matched current patient_ids. "
          "Downstream mutation edges will all be empty.")

[mut_lists] patients=76 | mut_table_cols_found=76 | with_any=76 | avg_mut_proteins/patient=205.4


In [51]:
# ---------- proteomics-anchored patients; compact gene space; dedup-safe ----------

# Inputs expected from earlier cells:
# - prot_z, phos_z, prot_mask_pre  (proteo/phospho union; columns already tumor patients)
# - patient_ids                    (CMS-aligned proteomics tumor patients)
# - rna_tumor, cnv_log2            (raw RNA/CNV tables with gene rows, patient columns)

# --- Helpers to collapse duplicates on rows/columns (median across dups) ---
def collapse_duplicate_rows(df: pd.DataFrame, how="median"):
    """
    Merge rows that share an index label.

    how : 'median' or 'mean' for continuous tables, 'max_binary' for 0/1
        mutation tables (takes the max across duplicates).

    This definition replaces the earlier helper of the same name once the cell
    runs, so it accepts the same three modes; without 'max_binary' any later
    call using that mode would start raising ValueError.

    A frame whose index is already unique is returned unchanged, and ``how``
    is not validated in that case. Raises ValueError for an unknown ``how``
    when duplicates are present.
    """
    if not df.index.is_unique:
        if how == "median":
            df = df.groupby(level=0).median(numeric_only=True)
        elif how == "mean":
            df = df.groupby(level=0).mean(numeric_only=True)
        elif how == "max_binary":
            df = df.groupby(level=0).max(numeric_only=True)
        else:
            raise ValueError("how must be 'median', 'mean' or 'max_binary'")
    return df

def collapse_duplicate_cols(df: pd.DataFrame, how="median"):
    """
    Merge columns that share a label, reducing across the duplicates.

    how : 'median' or 'mean'. The frame is transposed, grouped by label with
        ``numeric_only=True``, and transposed back.

    A frame whose columns are already unique is returned unchanged, and
    ``how`` is not validated in that case. Raises ValueError for any other
    ``how`` when duplicates are present.
    """
    if df.columns.is_unique:
        return df

    transposed = df.T
    if how == "median":
        return transposed.groupby(level=0).median(numeric_only=True).T
    if how == "mean":
        return transposed.groupby(level=0).mean(numeric_only=True).T
    raise ValueError("how must be 'median' or 'mean'")

def strip_index_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose row and column labels are cast to ``str``
    and stripped of surrounding whitespace. The input frame is not modified.
    """
    def _trim(label):
        return str(label).strip()

    out = df.copy()
    out.index = out.index.map(_trim)
    out.columns = out.columns.map(_trim)
    return out

# --- Proteomics-anchored patients (preserve order, de-dup) ---
patient_ids = list(dict.fromkeys(patient_ids))  # dict keys keep first occurrence, in order

# Sync proteo/phospho/mask to these patients
prot_z = prot_z.loc[:, [p for p in patient_ids if p in prot_z.columns]]
prot_patients = prot_z.columns  # the anchor every other table is aligned to

phos_z = phos_z.reindex(columns=prot_patients, fill_value=0.0)  # patients absent from phospho -> 0.0
prot_mask_pre = prot_mask_pre.reindex(columns=prot_patients)

assert prot_patients.is_unique, "Proteomics patient IDs must be unique."

# --- 9.2) Prepare RNA/CNV for these patients (allow missing) with robust de-dup ---
rna_tumor = strip_index_columns(rna_tumor)
cnv_log2  = strip_index_columns(cnv_log2)

# Keep only the columns we can match (but we will reindex to prot_patients later)
rna_cols = [p for p in prot_patients if p in rna_tumor.columns]
cnv_cols = [p for p in prot_patients if p in cnv_log2.columns]

rna_tumor_sub = rna_tumor.loc[:, rna_cols] if len(rna_cols) else pd.DataFrame(index=rna_tumor.index)
cnv_log2_sub = cnv_log2.loc[:,  cnv_cols] if len(cnv_cols) else pd.DataFrame(index=cnv_log2.index)

# Force numeric (non-numeric -> NaN) BEFORE collapsing duplicates.
# Both collapse helpers aggregate with numeric_only=True, which silently drops any
# object-dtype column; transposing a mixed-dtype frame turns every column into
# object, so the column collapse could otherwise drop all the data.
rna_tumor_sub = rna_tumor_sub.apply(pd.to_numeric, errors="coerce")
cnv_log2_sub = cnv_log2_sub.apply(pd.to_numeric, errors="coerce")

# Collapse duplicates on BOTH axes (avoid reindex errors)
rna_tumor_sub = collapse_duplicate_rows(collapse_duplicate_cols(rna_tumor_sub), how="median")
cnv_log2_sub = collapse_duplicate_rows(collapse_duplicate_cols(cnv_log2_sub),  how="median")

# --- Compact gene space = protein ids ∩ RNA genes ∩ CNV genes (protein order) ---
protein_index = pd.Index(prot_z.index).map(lambda x: str(x).strip())
rna_index = rna_tumor_sub.index
cnv_index = cnv_log2_sub.index

gene_space = protein_index.intersection(rna_index).intersection(cnv_index)

# If the 3-way intersection is empty, relax to protein ∩ (RNA or CNV), whichever
# overlaps more; RNA wins a tie.
if gene_space.empty:
    inter_pr_rna = protein_index.intersection(rna_index)
    inter_pr_cnv = protein_index.intersection(cnv_index)
    gene_space = max((inter_pr_rna, inter_pr_cnv), key=len)
    print(f"WARNING: 3-way intersection empty; falling back to 2-way size={len(gene_space)}")

# --- Reindex RNA/CNV to compact genes and proteomics patient order (dedup-safe) ---
# Patients missing from a modality come back as all-NaN columns.
rna_full, cnv_full = (
    table.reindex(index=gene_space, columns=prot_patients)
    for table in (rna_tumor_sub, cnv_log2_sub)
)

# Availability masks BEFORE any z-scaling (1.0 = value present)
rna_avl, cnv_avl = (full.notna().astype(np.float32) for full in (rna_full, cnv_full))

# Final sanity prints
print("Patients kept (proteomics-anchored):", len(prot_patients))
print("Protein features:", prot_z.shape, "| Phospho:", phos_z.shape)
print("COMPACT gene space:", len(gene_space))
print("RNA raw (compacted):", rna_full.shape, "| CNV raw (compacted):", cnv_full.shape)

# Expose for next cells
patient_ids = list(prot_patients)
gene_ids    = list(gene_space)


Patients kept (proteomics-anchored): 76
Protein features: (7102, 76) | Phospho: (7102, 76)
COMPACT gene space: 7093
RNA raw (compacted): (7093, 76) | CNV raw (compacted): (7093, 76)


In [53]:
# Alignment checks (must be True): the presence mask and the feature table must
# share row order, column order and gene count.
assert phos_present.index.equals(phos_z.index),  "Row (gene) order mismatch"
assert phos_present.columns.equals(phos_z.columns), "Column (patient) order mismatch"
assert len(protein_ids) == phos_present.shape[0] == phos_z.shape[0], "Gene dimension mismatch"

# 1) How many genes have ANY phospho measured across patients?
measured_per_gene   = phos_present.sum(axis=1)
genes_with_any_phos = int((measured_per_gene > 0).sum())
genes_with_no_phos  = int((measured_per_gene == 0).sum())
print("Genes with ≥1 phospho measurement:", genes_with_any_phos)
print("Genes with no phospho measurement:",  genes_with_no_phos)
print("Total genes (union):", len(protein_ids))

# 2) Per-patient phospho coverage (fraction of union genes measured in that patient)
per_patient_phos_cov = phos_present.sum(axis=0) / len(protein_ids)
print("Phospho coverage per patient (mean±sd):",
      float(per_patient_phos_cov.mean()), "+/-", float(per_patient_phos_cov.std()))

# 3) Strong sanity: features must be ~zero wherever mask==0
Z = phos_z.to_numpy()                          # shape [N_genes, N_patients]
M = phos_present.to_numpy(dtype=bool)          # same shape
eps = 1e-8
nonzero_feature = np.abs(Z) > eps
imputed_slots_used = int((~M & nonzero_feature).sum())
print("Phospho entries that were originally missing but now have imputed |z|>0:", imputed_slots_used)

# 4) Informational: how often present-but-numerically-zero (not an error)
present_but_zeroish = int((M & ~nonzero_feature).sum())
print("Present phospho entries that are ~0 (valid, just FYI):", present_but_zeroish)

# 5) Example patient summary (use mask, not nonzero)
p0 = 0
mask_present_p0 = int(M[:, p0].sum())
print(f"Patient {p0}: mask-present phospho features =", mask_present_p0)

# (Optional) Show patients with lowest coverage
cov_sorted = per_patient_phos_cov.sort_values()
print("Lowest-coverage patients:", list(cov_sorted.index[:5]), "->", list(cov_sorted.values[:5]))

Genes with ≥1 phospho measurement: 3144
Genes with no phospho measurement: 3958
Total genes (union): 7102
Phospho coverage per patient (mean±sd): 0.42903780937194824 +/- 0.009626339189708233
Phospho entries that were originally missing but now have imputed |z|>0: 7370
Present phospho entries that are ~0 (valid, just FYI): 0
Patient 0: mask-present phospho features = 3091
Lowest-coverage patients: ['05CO026', '11CO045', '11CO030', '21CO006', '11CO031'] -> [0.39819768, 0.39819768, 0.3984793, 0.3984793, 0.40805408]


In [54]:
# HARD MASK phospho: keep a z-value only where the pre-imputation mask says it was
# measured; every other entry (i.e. an imputed one) is reset to 0.0
phos_z = phos_z.where(phos_present.astype(bool), 0.0)

# Re-run the consistency check: no entry outside the mask may exceed eps in magnitude
Z = phos_z.to_numpy()
M = phos_present.to_numpy(dtype=bool)
eps = 1e-8
bad_nonzero_when_missing = int(np.count_nonzero(np.abs(Z[~M]) > eps))
print("Non-zero phospho values where mask==0 (should be 0):", bad_nonzero_when_missing)  # expect 0

Non-zero phospho values where mask==0 (should be 0): 0


In [55]:
# ---------- CELL 10: StratifiedKFold split (train/val/test) ----------
# Produces train_idx / val_idx / test_idx (positions into patient_ids) plus
# labels_aligned / y in patient_ids order. Classes with a single sample are dropped
# first, and every patient-indexed frame is shrunk to the surviving patients.
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from collections import Counter
import numpy as np
import torch

# 10.1) Align CMS labels to patient_ids (from Cell 8/9 proteomics anchor)
cms_pid2lab = {p: l for p, l in zip(cms_in_prot, labels_in_prot.tolist())}
labels_aligned = torch.tensor([cms_pid2lab[p] for p in patient_ids], dtype=torch.long)
y = labels_aligned.numpy()

# 10.2) Drop classes with <2 samples (Stratified splits require ≥2 per class)
counts = np.bincount(y, minlength=int(y.max() + 1))
rare_classes = np.where(counts < 2)[0].tolist()
if rare_classes:
    keep_mask = ~np.isin(y, rare_classes)
    patient_ids = [p for p, m in zip(patient_ids, keep_mask) if m]
    labels_aligned = labels_aligned[keep_mask]
    y = labels_aligned.numpy()
    # All patient-indexed frames must shrink together, otherwise the tensors built
    # from them later stop lining up with patient_ids.
    prot_z        = prot_z.loc[:, patient_ids]
    phos_z        = phos_z.loc[:, patient_ids]
    prot_mask_pre = prot_mask_pre.loc[:, patient_ids]
    rna_full      = rna_full.loc[:, patient_ids]
    cnv_full      = cnv_full.loc[:, patient_ids]
    rna_avl       = rna_avl.loc[:, patient_ids]
    cnv_avl       = cnv_avl.loc[:, patient_ids]
    # The phospho availability mask is patient-indexed too; without this it would keep
    # the dropped patients while phos_z no longer has them.
    if 'phos_present' in globals():
        phos_present = phos_present.loc[:, patient_ids]
    print("Dropped rare CMS classes:", rare_classes)

# 10.3) StratifiedKFold → pick one fold for VAL, one for TEST; rest = TRAIN
# Choose folds so each is ~N/n_splits. For ~75 pts, n_splits=5 ⇒ ~15 per fold (nice for your target 12–16).
counts = np.bincount(y, minlength=int(y.max() + 1))
min_per_class = int(counts[counts > 0].min())
n_splits = int(min(5, max(2, min_per_class)))   # cap at 5 by default; never below 2

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
folds = list(skf.split(np.zeros_like(y), y))

# Nominally use fold 0 for VAL and fold 1 for TEST (disjoint)
VAL_FOLD  = 0
TEST_FOLD = 1 if n_splits >= 3 else 0  # if only 2 folds, fallback handled below

if n_splits >= 3:
    val_idx  = folds[VAL_FOLD][1]
    test_idx = folds[TEST_FOLD][1]
    all_idx = np.arange(len(y))
    train_mask = np.ones_like(all_idx, dtype=bool)
    train_mask[val_idx] = False
    train_mask[test_idx] = False
    train_idx = all_idx[train_mask]
else:
    # n_splits == 2 → TEST is one fold's held-out half; VAL is carved from the other half.
    test_idx = folds[1][1]              # ~50% as test
    rest_idx = folds[1][0]              # complement used for train+val
    # n_splits only falls to 2 when the smallest class has 2 samples, so that class has a
    # single sample left in rest_idx. StratifiedShuffleSplit rejects classes with fewer
    # than 2 members, so those singletons are pinned to TRAIN and only classes with ≥2
    # remaining samples take part in the stratified train/val carve-out.
    rest_y = y[rest_idx]
    rest_counts = np.bincount(rest_y, minlength=counts.size)
    splittable = rest_counts[rest_y] >= 2
    pool_idx = rest_idx[splittable]
    pinned_train_idx = rest_idx[~splittable]
    # carve a small stratified val (e.g., 20% of the splittable pool)
    sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=43)
    tr_sub, va_sub = next(sss_val.split(np.zeros_like(y[pool_idx]), y[pool_idx]))
    train_idx = np.concatenate([pool_idx[tr_sub], pinned_train_idx])
    val_idx   = pool_idx[va_sub]

# Final reporting
print(f"n_splits={n_splits} | sizes → train={len(train_idx)}  val={len(val_idx)}  test={len(test_idx)}")
print("Train class counts:", Counter(y[train_idx]))
print("Val class counts:  ", Counter(y[val_idx]))
print("Test class counts: ", Counter(y[test_idx]))

n_splits=5 | sizes → train=45  val=16  test=15
Train class counts: Counter({1: 18, 3: 13, 2: 8, 0: 6})
Val class counts:   Counter({1: 6, 3: 4, 0: 3, 2: 3})
Test class counts:  Counter({1: 5, 3: 4, 0: 3, 2: 3})


In [38]:
# NOTE(review): this cell assigns the same names as the StratifiedKFold cell above
# (patient_ids, labels_aligned, y, train_idx, val_idx, test_idx), and its execution count
# (38) is lower than that cell's (55). Running the notebook top to bottom would let this
# cell overwrite the StratifiedKFold split — it looks like an earlier variant; confirm
# whether it should be removed.
# NOTE(review): it subsets rna_tumor_sub / cnv_log2_sub, whereas the StratifiedKFold cell
# subsets rna_full / cnv_full / rna_avl / cnv_avl — verify which frames downstream cells use.

# Labels aligned to patient_ids
cms_pid2lab = {p:l for p,l in zip(cms_in_prot, labels_in_prot.tolist())}
labels_aligned = torch.tensor([cms_pid2lab[p] for p in patient_ids], dtype=torch.long)
y = labels_aligned.numpy()

# Drop classes with <2 members
counts = np.bincount(y, minlength=int(y.max()+1))
rare = np.where(counts < 2)[0].tolist()
if rare:
    keep_mask = ~np.isin(y, rare)
    patient_ids = [p for p,m in zip(patient_ids, keep_mask) if m]
    labels_aligned = labels_aligned[keep_mask]
    y = labels_aligned.numpy()
    # Shrink the patient-indexed frames to the surviving patients (columns = patients).
    prot_z   = prot_z.loc[:, patient_ids]
    phos_z   = phos_z.loc[:, patient_ids]
    prot_mask_pre = prot_mask_pre.loc[:, patient_ids]
    rna_tumor_sub = rna_tumor_sub.loc[:, patient_ids]
    cnv_log2_sub  = cnv_log2_sub.loc[:, patient_ids]
    print("Dropped rare classes:", rare)

# Robust split: 1 per class in test, small val if possible
rng = np.random.default_rng(42)
classes = np.unique(y)
cls_to_idx = {c: np.where(y==c)[0].tolist() for c in classes}
# One randomly drawn sample per class forms the test set, so test size == number of classes.
test_idx = [int(rng.choice(idx)) for idx in cls_to_idx.values()]
test_mask = np.zeros_like(y, dtype=bool); test_mask[test_idx] = True
rest_idx = np.where(~test_mask)[0]
y_rest = y[rest_idx]

# Try a small stratified val, otherwise pick one per class if possible
val_idx = []
# Stratified carve-out is only attempted when every class still has ≥2 samples outside test.
if len(rest_idx) >= 4 and all([(y_rest==c).sum()>=2 for c in classes]):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=43)
    tr_sub, va_sub = next(sss.split(np.zeros_like(y_rest), y_rest))
    train_idx = rest_idx[tr_sub]; val_idx = rest_idx[va_sub]
else:
    # 1 per class for val if possible
    for c in classes:
        pool = [i for i in rest_idx if y[i]==c]
        # Only take a val sample when at least one sample of the class stays in train.
        if len(pool)>=2:
            val_idx.append(int(rng.choice(pool)))
    val_idx = np.array(sorted(set(val_idx)), dtype=int)
    # Train = everything that is neither test nor val.
    train_mask = (~test_mask).copy()
    train_mask[val_idx] = False
    train_idx = np.where(train_mask)[0]

print(f"Split sizes: train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")
print("Train class counts:", Counter(y[train_idx]))
print("Val class counts:  ", Counter(y[val_idx]))
print("Test class counts: ", Counter(y[test_idx]))

Split sizes: train=57 val=15 test=4
Train class counts: Counter({1: 22, 3: 16, 2: 10, 0: 9})
Val class counts:   Counter({1: 6, 3: 4, 2: 3, 0: 2})
Test class counts:  Counter({0: 1, 1: 1, 2: 1, 3: 1})


In [56]:
# ---------- CELL 11: Train-only normalization for RNA/CNV + tensorization ----------

# 11.1) Train-only z-scales for RNA & CNV (fit on *train* patients only, ignore NaN)
def z_by_train_only(df_full: pd.DataFrame, train_cols: list[str], clip=5.0):
    """Row-wise z-score of ``df_full`` using statistics from ``train_cols`` only.

    Args:
        df_full: frame whose rows are standardized independently; every column is
            transformed, but only ``train_cols`` contribute to the statistics.
        train_cols: column labels used to fit the per-row mean and (population,
            ``ddof=0``) standard deviation; NaNs are skipped.
        clip: symmetric bound applied to the standardized values.

    Returns:
        ``(Z, mu, sd)``. ``Z`` has the shape of ``df_full``; rows whose training
        standard deviation is 0 come back as NaN (``sd`` holds NaN for them), and
        NaN inputs stay NaN. With no training columns, ``Z`` is all zeros, ``mu``
        is all 0.0 and ``sd`` is all 1.0.
    """
    row_labels = df_full.index

    # Degenerate case: nothing to fit on, so hand back a neutral transform.
    if len(train_cols) < 1:
        neutral = pd.DataFrame(0.0, index=row_labels, columns=df_full.columns)
        return neutral, pd.Series(0.0, index=row_labels), pd.Series(1.0, index=row_labels)

    fit_view = df_full.loc[:, train_cols]
    center = fit_view.mean(axis=1, skipna=True)
    # A zero spread is swapped for NaN so the division below yields NaN instead of inf.
    spread = fit_view.std(axis=1, ddof=0, skipna=True).replace(0, np.nan)

    scaled = df_full.sub(center, axis=0).div(spread, axis=0)
    # Any remaining infinities become NaN before clipping; NaN is left for the caller to fill.
    scaled = scaled.replace([np.inf, -np.inf], np.nan)
    return scaled.clip(-clip, clip), center, spread

# Which train patients actually have RNA/CNV?
train_cols_rna = [patient_ids[i] for i in train_idx if patient_ids[i] in rna_full.columns]
train_cols_cnv = [patient_ids[i] for i in train_idx if patient_ids[i] in cnv_full.columns]

# Statistics are fit on train patients only, then applied to every patient column.
rna_z_df, rna_mu, rna_sd = z_by_train_only(rna_full, train_cols_rna, clip=5.0)
cnv_z_df, cnv_mu, cnv_sd = z_by_train_only(cnv_full, train_cols_cnv, clip=5.0)

# 11.2) Fill NaNs with 0.0 for features; keep availability masks separately
rna_z = rna_z_df.fillna(0.0)
cnv_z = cnv_z_df.fillna(0.0)

# The gene tensors below take their column order from the rows of rna_z / cnv_z, and
# gene_ids is later used to map gene names to those columns. Both frames therefore have
# to share one row index. (Index.union sorts its result whenever the two indexes differ,
# which would silently detach gene_ids from the tensor column order.)
if not rna_z.index.equals(cnv_z.index):
    raise ValueError("rna_z and cnv_z must share the same gene index (same genes, same order)")

# (Optional) If you decided to *hard-mask* phospho to zero where originally missing:
HARD_MASK_PHOS = False
if HARD_MASK_PHOS and 'phos_present' in globals():
    phos_z = phos_z.where(phos_present.astype(bool), 0.0)

# 11.3) Tensorize ALL channels (alignments assumed from prior cells)
# Frames are features × patients, so each is transposed to patients × features.
# Protein channels (proteomics union space × patient_ids)
X_prot      = torch.tensor(prot_z.T.values,        dtype=torch.float32)   # [P, N_prot]
X_phos      = torch.tensor(phos_z.T.values,        dtype=torch.float32)   # [P, N_prot]
X_mask_prot = torch.tensor(prot_mask_pre.T.values, dtype=torch.float32)   # [P, N_prot]

# If you kept a phospho availability mask, expose it too (lets the GNN know measured vs imputed)
X_phos_avl = None
if 'phos_present' in globals():
    X_phos_avl = torch.tensor(phos_present.T.values, dtype=torch.float32) # [P, N_prot]

# Gene channels (value + availability) in the compact gene space
X_rna      = torch.tensor(rna_z.T.values,    dtype=torch.float32)         # [P, N_gene]
X_cnv      = torch.tensor(cnv_z.T.values,    dtype=torch.float32)         # [P, N_gene]
X_rna_avl  = torch.tensor(rna_avl.T.values,  dtype=torch.float32)         # [P, N_gene]
X_cnv_avl  = torch.tensor(cnv_avl.T.values,  dtype=torch.float32)         # [P, N_gene]

# Final IDs
protein_ids = list(prot_z.index)   # union protein list from Cell 8; order matches X_prot columns
gene_ids    = list(rna_z.index)    # compact gene space; order matches X_rna / X_cnv columns

# Quick finite checks (will raise if something slipped through)
for name, X in [
    ("X_prot", X_prot), ("X_phos", X_phos), ("X_mask_prot", X_mask_prot),
    ("X_rna", X_rna), ("X_cnv", X_cnv), ("X_rna_avl", X_rna_avl), ("X_cnv_avl", X_cnv_avl)
]:
    if not torch.isfinite(X).all():
        raise ValueError(f"{name} contains non-finite values")

if X_phos_avl is not None and not torch.isfinite(X_phos_avl).all():
    raise ValueError("X_phos_avl contains non-finite values")

print("Protein channels:", X_prot.shape, X_phos.shape, X_mask_prot.shape,
      "| phos_avl" if X_phos_avl is not None else "| phos_avl (not provided)")
print("Gene channels:",    X_rna.shape,  X_cnv.shape,  X_rna_avl.shape, X_cnv_avl.shape)

Protein channels: torch.Size([76, 7102]) torch.Size([76, 7102]) torch.Size([76, 7102]) | phos_avl
Gene channels: torch.Size([76, 7093]) torch.Size([76, 7093]) torch.Size([76, 7093]) torch.Size([76, 7093])


In [57]:
# ---------- CELL X: Build gene <-> protein cross-edges (name-matched) ----------
def is_ensembl(g):
    """Return True when ``g`` is a string that starts with "ENSG" (case-insensitive)."""
    if not isinstance(g, str):
        # Non-string IDs (None, numbers, ...) are never Ensembl-like.
        return False
    return g.upper().startswith("ENSG")

# 1) Build index maps (ID -> node position)
prot_idx = dict(zip(protein_ids, range(len(protein_ids))))  # protein (union) IDs
gene_idx = dict(zip(gene_ids, range(len(gene_ids))))        # compact gene space

# 2) Exact-name matches → edges gene->protein
# A gene node is linked to the protein node carrying the identical ID string.
matched = [g for g in gene_ids if g in prot_idx]
src = [gene_idx[g] for g in matched]
dst = [prot_idx[g] for g in matched]
overlap = len(matched)

codes_edge_index     = torch.tensor([src, dst], dtype=torch.long)
codes_rev_edge_index = torch.tensor([dst, src], dtype=torch.long)  # same pairs, protein->gene

print(f"[codes edges] {codes_edge_index.shape[1]} (gene→protein); overlap genes={overlap}/{len(gene_ids)} "
      f"({overlap/len(gene_ids):.1%})")

# 3) Quick diagnostics: how many of your protein_ids are Ensembl vs symbols?
n_prot_ens  = sum(map(is_ensembl, protein_ids))
n_gene_ens  = sum(map(is_ensembl, gene_ids))
print(f"protein_ids: {len(protein_ids)} total | Ensembl-like={n_prot_ens} | non-Ensembl={len(protein_ids)-n_prot_ens}")
print(f"gene_ids   : {len(gene_ids)} total   | Ensembl-like={n_gene_ens} | non-Ensembl={len(gene_ids)-n_gene_ens}")

# 4) Optional: warn if overlap is low (often caused by symbol-vs-Ensembl mix)
if overlap < 0.6 * len(gene_ids):
    for warning_line in (
        "WARNING: Low gene↔protein overlap. Likely ID convention mismatch (symbols vs Ensembl).",
        "Tip: ensure phospho aggregation produced Ensembl IDs when available (your function prefers ENSG).",
        "If many protein rows are symbols only, consider restricting protein_ids to Ensembl-like to improve alignment.",
    ):
        print(warning_line)

# --- OPTIONAL TIGHTENING (commented) ---
# If you decide to restrict to Ensembl-only proteins to maximize overlap, do it BEFORE tensorization:
# ONLY_ENSG_PROTEINS = False
# if ONLY_ENSG_PROTEINS:
#     keep_mask = [is_ensembl(p) for p in protein_ids]
#     prot_keep = [p for p, k in zip(protein_ids, keep_mask) if k]
#     # Reindex proteomics matrices (and any masks) to prot_keep, then rebuild protein_ids and codes_edge_index:
#     prot_z        = prot_z.loc[prot_keep]
#     phos_z        = phos_z.loc[prot_keep]
#     prot_mask_pre = prot_mask_pre.loc[prot_keep]
#     protein_ids   = prot_keep
#     # Rebuild the maps and edges after this restriction.

[codes edges] 7093 (gene→protein); overlap genes=7093/7093 (100.0%)
protein_ids: 7102 total | Ensembl-like=7102 | non-Ensembl=0
gene_ids   : 7093 total   | Ensembl-like=7093 | non-Ensembl=0


In [72]:
from sklearn.metrics.pairwise import cosine_similarity
from torch_geometric.utils import to_undirected  # is_undirected not needed

def build_ppi_knn_from_train_robust(X_train_prot: torch.Tensor,
                                    X_train_phos: torch.Tensor | None = None,
                                    k: int = 15,
                                    k_step: int = 5,
                                    k_max: int = 40,
                                    tiny_jitter: float = 1e-8):
    """Build a symmetric cosine-kNN graph over the feature columns of the inputs.

    Each column of ``X_train_prot`` (optionally concatenated with the matching column
    of ``X_train_phos``) is one node; its profile across the rows is the vector used
    for cosine similarity. Every node proposes its ``k`` most similar nodes, and an
    edge is kept when either endpoint proposed it. If any node ends up without edges,
    ``k`` is raised by ``k_step`` (capped at ``k_max``) and the graph is rebuilt.

    Args:
        X_train_prot: 2-D tensor, rows x nodes.
        X_train_phos: optional second tensor of the same layout, or None.
        k: initial number of neighbours proposed per node.
        k_step: increment applied to ``k`` while isolated nodes remain.
        k_max: upper bound for ``k``.
        tiny_jitter: value written into column 0 of all-zero profiles so they are
            not exactly zero vectors.

    Returns:
        ``[2, E]`` long tensor holding both directions of every edge, without
        self-loops. Empty (``[2, 0]``) when there are fewer than two nodes.
    """
    Vp = X_train_prot.cpu().numpy().T  # [N_prot, P_train]
    if X_train_phos is not None:
        Vh = X_train_phos.cpu().numpy().T
        V  = np.concatenate([Vp, Vh], axis=1)
    else:
        V = Vp

    # nan_to_num returns a copy, so the jitter below never writes into the input tensors.
    V = np.nan_to_num(V, nan=0.0, posinf=0.0, neginf=0.0)
    row_norm = np.linalg.norm(V, axis=1)
    zero_rows = (row_norm == 0)
    if zero_rows.any():
        V[zero_rows, 0] = tiny_jitter

    N = V.shape[0]
    if N < 2:
        # A graph needs at least two nodes to have an edge; nothing to search.
        return torch.empty((2, 0), dtype=torch.long)

    cur_k = min(k, max(1, N - 1))

    # The similarity matrix depends only on V, so it is computed once and reused for
    # every k tried below. It is stored negated so argpartition selects the largest
    # similarities; the diagonal is pushed to the bottom so a node never picks itself.
    neg_sim = cosine_similarity(V)               # [N, N]
    np.fill_diagonal(neg_sim, -np.inf)
    np.nan_to_num(neg_sim, copy=False, nan=-1.0)
    np.negative(neg_sim, out=neg_sim)

    def make_edges(cur_k: int) -> torch.Tensor:
        """Return the symmetric edge index for ``cur_k`` proposals per node."""
        kk = min(cur_k, N - 1)
        nbr = np.argpartition(neg_sim, kth=kk, axis=1)[:, :kk]

        # Flatten (node, neighbour) proposals and store each as (low, high) so the two
        # directions of one edge collapse into a single undirected pair.
        node = np.repeat(np.arange(N), kk)
        other = nbr.reshape(-1)
        keep = node != other
        low = np.minimum(node, other)[keep]
        high = np.maximum(node, other)[keep]
        if low.size == 0:
            return torch.empty((2, 0), dtype=torch.long)

        E = np.unique(np.stack([low, high], axis=1), axis=0).astype(np.int64)  # [M, 2], sorted
        src = np.concatenate([E[:, 0], E[:, 1]], axis=0)
        dst = np.concatenate([E[:, 1], E[:, 0]], axis=0)
        edge_index = torch.tensor(np.stack([src, dst], axis=0), dtype=torch.long)

        # Always force canonical undirected (idempotent) and drop self-loops
        edge_index = to_undirected(edge_index, num_nodes=N)
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]

    while True:
        edge_index = make_edges(cur_k)
        deg = torch.bincount(edge_index[0], minlength=N)
        isolates = int((deg == 0).sum().item())
        # "directed" = both directions of each edge are counted.
        print(f"[kNN] k={cur_k} -> edges={edge_index.size(1)} (directed), isolates={isolates}")
        if isolates == 0 or cur_k >= k_max:
            break
        cur_k = min(k_max, cur_k + k_step)

    return edge_index


# ---- Build PPI from TRAIN patients (uses your existing splits) ----
# Only the rows selected by train_idx feed the similarity graph.
X_train_prot = X_prot[train_idx]   # [P_train, N_prot]
X_train_phos = X_phos[train_idx]   # [P_train, N_prot]
ppi_edge_index = build_ppi_knn_from_train_robust(X_train_prot, X_train_phos, k=15, k_step=5, k_max=40)

# Diagnostics
N = X_prot.shape[1]
# Per-node degree, counted from the source row of the edge index.
deg = torch.bincount(ppi_edge_index[0], minlength=N)
print(f"Nodes={N} | Edges={ppi_edge_index.size(1)} (directed)")
print(f"Degree: mean={deg.float().mean():.2f}, min={int(deg.min())}, max={int(deg.max())}")
iso = (deg == 0).sum().item()
print("Isolated proteins:", iso)

# Connectedness (requires networkx)
import networkx as nx
E = list(zip(ppi_edge_index[0].tolist(), ppi_edge_index[1].tolist()))
# NOTE(review): G is populated from edges only, so a node without any edge would be absent
# from both the component count and the giant-component denominator.
G = nx.Graph(); G.add_edges_from(E)
n_comp = nx.number_connected_components(G)
# Share of G's nodes that sit in its largest connected component.
giant = len(max(nx.connected_components(G), key=len)) / G.number_of_nodes()
print(f"Components={n_comp}, Giant component fraction={giant:.3f}")

[kNN] k=15 -> edges=190398 (directed), isolates=0
Nodes=7102 | Edges=190398 (directed)
Degree: mean=26.81, min=15, max=191
Isolated proteins: 0
Components=1, Giant component fraction=1.000


In [73]:
# Canonicalize each edge as (min, max), drop duplicates
# NOTE(review): after this cell every edge is stored once, smaller index first, so the
# edge index is no longer symmetric. The following cell rebuilds ppi_edge_index from
# scratch, so this result is overwritten when cells run in order — confirm which
# version is intended.
E = torch.sort(ppi_edge_index, dim=0)[0].t()      # [E, 2]
E = torch.unique(E, dim=0).t()                    # [2, E_unique]
ppi_edge_index = torch.stack([E[0], E[1]], dim=0)

In [74]:
# Rebuild the kNN graph from the training rows (same call and arguments as the build
# cell above) and summarize the degree distribution.
X_train_prot = X_prot[train_idx]
X_train_phos = X_phos[train_idx]
ppi_edge_index = build_ppi_knn_from_train_robust(X_train_prot, X_train_phos, k=15, k_step=5, k_max=40)

N = X_prot.shape[1]
# Per-node degree, counted from the source row of the edge index.
deg = torch.bincount(ppi_edge_index[0], minlength=N)
print(f"Nodes={N} | Edges={ppi_edge_index.size(1)} (directed)")
print(f"Degree: mean={deg.float().mean():.2f}, min={int(deg.min())}, max={int(deg.max())}")
print("Isolated proteins:", int((deg==0).sum()))

# Histogram: how many nodes have each observed degree.
vals, counts = torch.unique(deg, return_counts=True)
print("Degree histogram (degree:count):", dict(zip(vals.tolist(), counts.tolist())))

[kNN] k=15 -> edges=190398 (directed), isolates=0
Nodes=7102 | Edges=190398 (directed)
Degree: mean=26.81, min=15, max=191
Isolated proteins: 0
Degree histogram (degree:count): {15: 646, 16: 673, 17: 601, 18: 490, 19: 468, 20: 437, 21: 339, 22: 305, 23: 267, 24: 250, 25: 234, 26: 189, 27: 150, 28: 163, 29: 147, 30: 109, 31: 103, 32: 101, 33: 90, 34: 98, 35: 88, 36: 57, 37: 62, 38: 63, 39: 52, 40: 44, 41: 46, 42: 51, 43: 44, 44: 35, 45: 32, 46: 36, 47: 22, 48: 23, 49: 22, 50: 35, 51: 24, 52: 27, 53: 17, 54: 22, 55: 16, 56: 20, 57: 23, 58: 16, 59: 20, 60: 21, 61: 17, 62: 12, 63: 9, 64: 14, 65: 11, 66: 14, 67: 15, 68: 6, 69: 7, 70: 12, 71: 8, 72: 4, 73: 9, 74: 9, 75: 7, 76: 10, 77: 4, 78: 4, 79: 8, 80: 9, 81: 4, 82: 5, 83: 9, 84: 6, 85: 3, 86: 6, 87: 4, 88: 4, 89: 3, 90: 5, 91: 6, 92: 4, 93: 3, 94: 2, 96: 2, 97: 6, 98: 1, 99: 1, 100: 2, 101: 3, 102: 1, 103: 1, 104: 2, 105: 2, 106: 6, 107: 3, 108: 1, 109: 1, 110: 3, 111: 4, 113: 1, 115: 1, 117: 3, 118: 1, 120: 1, 122: 1, 124: 1, 128: 3, 13

In [75]:
from sklearn.metrics.pairwise import cosine_similarity

# pick the same patient subset & feature construction used to build the PPI
Vp = X_train_prot.cpu().numpy().T                    # [N_prot, P_train]
V  = Vp
if 'X_train_phos' in globals() and X_train_phos is not None:
    Vh = X_train_phos.cpu().numpy().T               # [N_prot, P_train]
    V  = np.concatenate([Vp, Vh], axis=1)           # multi-omics as in builder

# cosine sim like the builder
# NOTE(review): unlike build_ppi_knn_from_train_robust, this cell does not run nan_to_num
# on V, does not jitter all-zero rows, and does not replace NaN similarities; the two
# similarity matrices only agree when none of those cases occur in the data.
S = cosine_similarity(V)                             # [N, N]
np.fill_diagonal(S, -np.inf)  # a node must never rank as its own neighbour

def topk_neighbors(i, k):
    """Return the set of column indices holding the ``k`` largest values in row ``i`` of ``S``.

    ``k`` is capped at ``S.shape[1] - 1``. Reads the module-level matrix ``S``.
    """
    limit = min(k, S.shape[1]-1)
    # Partition on the negated row so the largest similarities land in the first slots.
    partitioned = np.argpartition(-S[i], kth=limit)
    return set(partitioned[:limit].tolist())

def sym_topk_neighbors(i, k):
    """Neighbors after symmetrizing: j in topk(i) OR i in topk(j)."""
    # Start from i's own proposals ...
    neighbours = set(topk_neighbors(i, k))
    # ... then add every other node that lists i among its own top-k.
    for j in range(S.shape[0]):
        if j != i and i in topk_neighbors(j, k):
            neighbours.add(j)
    return neighbours

# neighbors from the built graph
def graph_neighbors(i, edge_index):
    """Return the set of target nodes of all edges in ``edge_index`` whose source is ``i``.

    ``edge_index`` is a tensor whose row 0 holds sources and row 1 holds targets.
    """
    pairs = edge_index.numpy()
    leaves_i = pairs[0] == i
    return set(pairs[1][leaves_i].tolist())

# Compare one node's neighbours in the built graph against a direct recomputation of the
# symmetrized top-k rule from the similarity matrix S.
prot = 0                       # change as needed
k_used = 15                    # the k you passed to the builder

g_nbrs   = graph_neighbors(prot, ppi_edge_index)
sym_nbrs = sym_topk_neighbors(prot, k_used)

overlap  = g_nbrs & sym_nbrs
print(f"Node {prot}: graph_deg={len(g_nbrs)} sym_topk_deg={len(sym_nbrs)} overlap={len(overlap)}")
# max(1, ...) keeps the ratio defined when both neighbour sets are empty.
print("Jaccard(graph, sym_topk) =",
      len(overlap) / max(1, len(g_nbrs | sym_nbrs)))

# (Optional) show the top-10 most similar indices and their sims
top10 = np.argsort(-S[prot])[:10]
print("Top-10 by cosine:", top10, "scores:", np.round(S[prot, top10], 3))

# (Optional) map to protein IDs for readability
if 'protein_ids' in globals():
    print("Top-10 protein IDs:", [protein_ids[i] for i in top10])
    print("Graph neighbors (IDs):", [protein_ids[i] for i in sorted(g_nbrs)])

Node 0: graph_deg=38 sym_topk_deg=38 overlap=38
Jaccard(graph, sym_topk) = 1.0
Top-10 by cosine: [2250 3571 4738 3589 1301 5349 1603 1666 6539 6972] scores: [0.795 0.788 0.785 0.783 0.776 0.774 0.771 0.77  0.764 0.76 ]
Top-10 protein IDs: ['ENSG00000115541', 'ENSG00000137288', 'ENSG00000160124', 'ENSG00000137563', 'ENSG00000101346', 'ENSG00000167600', 'ENSG00000105388', 'ENSG00000106028', 'ENSG00000197142', 'ENSG00000242110']
Graph neighbors (IDs): ['ENSG00000047230', 'ENSG00000070019', 'ENSG00000099800', 'ENSG00000101019', 'ENSG00000101346', 'ENSG00000101421', 'ENSG00000105388', 'ENSG00000106028', 'ENSG00000115541', 'ENSG00000118939', 'ENSG00000126432', 'ENSG00000130234', 'ENSG00000132541', 'ENSG00000136270', 'ENSG00000137288', 'ENSG00000137563', 'ENSG00000140092', 'ENSG00000145824', 'ENSG00000147044', 'ENSG00000147202', 'ENSG00000154639', 'ENSG00000160124', 'ENSG00000162910', 'ENSG00000164182', 'ENSG00000164576', 'ENSG00000164924', 'ENSG00000165215', 'ENSG00000165280', 'ENSG000001655

In [76]:
from torch_geometric.data import HeteroData

class MultiOmicsPatientDataset(torch.utils.data.Dataset):
    """One heterogeneous graph per patient, built from shared static edges.

    Every item is a ``HeteroData`` with three node types:

    * ``protein`` — ``[N_prot, 3]`` features: ``X_prot``, ``X_phos``, ``X_mask_prot``.
    * ``gene``    — ``[N_gene, 4]`` features: ``X_rna``, ``X_cnv``, ``X_rna_avl``, ``X_cnv_avl``.
    * ``patient`` — a single node with one constant-zero feature; carries the label ``y``.

    Static edges (``ppi``, ``codes``, ``rev_codes``) are identical for all patients;
    ``mutated`` / ``rev_mutated`` edges are built per patient from ``mut_lists``.

    Args:
        patient_ids: patient identifiers, one per row of the feature tensors.
        y: per-patient labels, indexable by position (tensor, array or list of ints).
        X_prot, X_phos, X_mask_prot: ``[P, N_prot]`` tensors.
        X_rna, X_cnv, X_rna_avl, X_cnv_avl: ``[P, N_gene]`` tensors.
        mut_lists: per patient, the protein node indices to connect to the patient node.
        protein_ids, gene_ids: node identifiers, stored for reference.
        ppi_edge_index, codes_edge_index, codes_rev_edge_index: ``[2, E]`` long tensors.
        use_masks: when False, the three mask/availability channels are replaced by zeros.

    Raises:
        AssertionError: on any shape, length, dtype or index-range mismatch.
    """

    def __init__(self, patient_ids, y,
                 X_prot, X_phos, X_mask_prot,         # [P, N_prot]
                 X_rna,  X_cnv,  X_rna_avl, X_cnv_avl, # [P, N_gene]
                 mut_lists, protein_ids, gene_ids,
                 ppi_edge_index, codes_edge_index, codes_rev_edge_index,
                 use_masks: bool = True):
        super().__init__()
        self.pids = list(patient_ids)
        self.y = y

        # ---- basic shape guards
        Pp, Np = X_prot.shape
        assert X_phos.shape == (Pp, Np), "X_phos must match X_prot shape"
        assert X_mask_prot.shape == (Pp, Np), "X_mask_prot must match X_prot shape"

        Pg, Ng = X_rna.shape
        assert Pg == Pp, "X_rna must have same #patients as X_prot"
        assert X_cnv.shape      == (Pg, Ng), "X_cnv must match X_rna shape"
        assert X_rna_avl.shape  == (Pg, Ng), "X_rna_avl must match X_rna shape"
        assert X_cnv_avl.shape  == (Pg, Ng), "X_cnv_avl must match X_rna shape"
        assert len(mut_lists)   == Pp,       "mut_lists must align with patients"
        # __len__ is driven by patient_ids and __getitem__ indexes y, so both must
        # cover exactly the rows of the feature tensors.
        assert len(self.pids)   == Pp,       "patient_ids must align with patients"
        assert len(y)           == Pp,       "y must align with patients"

        # Mutation targets become protein node indices in per-patient edges; an
        # out-of-range value would only fail later, inside message passing.
        for row, hits in enumerate(mut_lists):
            hit_arr = np.asarray(hits, dtype=np.int64).reshape(-1)
            if hit_arr.size > 0:
                assert int(hit_arr.min()) >= 0 and int(hit_arr.max()) < Np, \
                    f"mut_lists[{row}] contains a protein index out of range"

        # ---- store (make contiguous for speed)
        self.X_prot = X_prot.contiguous()
        self.X_phos = X_phos.contiguous()
        self.X_mask_prot = (X_mask_prot if use_masks else torch.zeros_like(X_prot)).contiguous()

        self.X_rna = X_rna.contiguous()
        self.X_cnv = X_cnv.contiguous()
        self.X_rna_avl = (X_rna_avl if use_masks else torch.zeros_like(X_rna)).contiguous()
        self.X_cnv_avl = (X_cnv_avl if use_masks else torch.zeros_like(X_cnv)).contiguous()

        self.mut = mut_lists
        self.prot_ids = list(protein_ids)
        self.gene_ids = list(gene_ids)

        # ---- static edges
        self.ppi       = ppi_edge_index
        self.codes     = codes_edge_index
        self.codes_rev = codes_rev_edge_index

        # edge dtype/range checks: row 0 indexes the source type, row 1 the target type
        for name, ei, n0, n1 in [
            ("ppi",       self.ppi,       Np, Np),
            ("codes",     self.codes,     Ng, Np),
            ("rev_codes", self.codes_rev, Np, Ng),
        ]:
            assert ei.dtype == torch.long and ei.dim() == 2 and ei.size(0) == 2, f"{name} must be [2,E] long"
            if ei.numel() > 0:
                assert int(ei.min()) >= 0, f"{name} edge index out of range"
                max0 = int(ei[0].max())
                max1 = int(ei[1].max())
                assert max0 < n0 and max1 < n1, f"{name} edge index out of range"

        self.nP = Np  # #protein nodes
        self.nG = Ng  # #gene nodes

        # ---- static template graph (copied per __getitem__)
        self.template = HeteroData()
        self.template['protein'].x = torch.zeros(self.nP, 3)  # prot_z, phos_z, prot_missing_mask
        self.template['gene'].x    = torch.zeros(self.nG, 4)  # rna_z, cnv_z, rna_avl, cnv_avl
        self.template['patient'].x = torch.zeros(1, 1)

        self.template[('protein','ppi','protein')].edge_index   = self.ppi
        self.template[('gene','codes','protein')].edge_index    = self.codes
        self.template[('protein','rev_codes','gene')].edge_index= self.codes_rev

    def __len__(self):
        """Number of patients in this dataset."""
        return len(self.pids)

    def __getitem__(self, i):
        """Return the graph of patient ``i``: template copy filled with that patient's data."""
        # Work on a copy so the shared template is never modified.
        g = self.template.clone()

        # protein channels
        g['protein'].x[:, 0] = self.X_prot[i]
        g['protein'].x[:, 1] = self.X_phos[i]
        g['protein'].x[:, 2] = self.X_mask_prot[i]

        # gene channels
        g['gene'].x[:, 0] = self.X_rna[i]
        g['gene'].x[:, 1] = self.X_cnv[i]
        g['gene'].x[:, 2] = self.X_rna_avl[i]
        g['gene'].x[:, 3] = self.X_cnv_avl[i]

        # patient↔protein mutation edges (sparse, per-patient)
        mi = self.mut[i]
        if isinstance(mi, np.ndarray):
            mi = mi.tolist()
        if len(mi) > 0:
            src = torch.zeros(len(mi), dtype=torch.long)     # single patient node index 0
            dst = torch.tensor(mi, dtype=torch.long)         # protein indices
            g[('patient','mutated','protein')].edge_index   = torch.stack([src, dst], dim=0)
            g[('protein','rev_mutated','patient')].edge_index = torch.stack([dst, src], dim=0)
        else:
            # Keep both edge types present (with zero edges) so every graph has the same schema.
            g[('patient','mutated','protein')].edge_index     = torch.empty((2,0), dtype=torch.long)
            g[('protein','rev_mutated','patient')].edge_index = torch.empty((2,0), dtype=torch.long)

        # int() accepts 0-d tensors, numpy scalars and plain ints alike.
        g['patient'].y = torch.tensor([int(self.y[i])], dtype=torch.long)
        return g

In [77]:
# Smoke test on the first training graph: feature widths, edge counts and label.
# NOTE(review): in the visible part of the notebook, train_ds is first assigned in the
# later "Loaders" cell, so running cells top to bottom on a fresh kernel would hit a
# NameError here — confirm, and move this cell below the dataset construction.
g0 = train_ds[0]
print("protein feat dims:", g0['protein'].x.shape)
print("gene feat dims:",    g0['gene'].x.shape)
print("ppi edges:",         g0[('protein','ppi','protein')].edge_index.shape[1])
print("codes edges:",       g0[('gene','codes','protein')].edge_index.shape[1])
print("mut edges:",         g0[('patient','mutated','protein')].edge_index.shape[1])
print("label:",             g0['patient'].y.item())

protein feat dims: torch.Size([6259, 3])
gene feat dims: torch.Size([6250, 4])
ppi edges: 159842
codes edges: 6250
mut edges: 19
label: 1


In [78]:
class MultiOmicsKGNN(nn.Module):
    """Two-layer heterogeneous GraphSAGE classifier over protein / gene / patient nodes.

    Node features are projected to a shared width, passed through two ``HeteroConv``
    layers (one ``SAGEConv`` per relation, summed per target type), mean-pooled per
    node type and graph, concatenated, fused and classified.
    """

    def __init__(self, in_protein=3, in_gene=4, in_patient=1, hidden=128, n_classes=4, dropout=0.2):
        super().__init__()
        # Per-node-type input projections to the shared hidden width.
        self.lin_prot = nn.Linear(in_protein, hidden)
        self.lin_gene = nn.Linear(in_gene, hidden)
        self.lin_pat  = nn.Linear(in_patient, hidden)

        # Relations handled by every message-passing layer.
        relations = (
            ('protein','ppi','protein'),
            ('gene','codes','protein'),
            ('protein','rev_codes','gene'),
            ('patient','mutated','protein'),
            ('protein','rev_mutated','patient'),
        )

        def hetero_sage_layer():
            # (-1, -1): input widths are inferred lazily on the first forward pass.
            return HeteroConv({rel: SAGEConv((-1, -1), hidden) for rel in relations}, aggr='sum')

        self.conv1 = hetero_sage_layer()
        self.conv2 = hetero_sage_layer()

        self.dropout = nn.Dropout(dropout)
        self.lin_fuse = nn.Linear(3 * hidden, hidden)  # concat patient + protein + gene
        self.cls = nn.Linear(hidden, n_classes)

    def forward(self, g: HeteroData):
        """Return class logits of shape ``[B, n_classes]`` for a (batched) hetero graph."""
        # Project raw features to hidden
        projections = (
            ('protein', self.lin_prot),
            ('gene',    self.lin_gene),
            ('patient', self.lin_pat),
        )
        x = {ntype: F.relu(proj(g[ntype].x)) for ntype, proj in projections}

        # Two hetero SAGE layers, each followed by ReLU on every node type
        for layer in (self.conv1, self.conv2):
            x = layer(x, g.edge_index_dict)
            x = {ntype: F.relu(h) for ntype, h in x.items()}

        # Batch-aware pooling
        def pooled(ntype):
            store = g[ntype]
            if 'batch' in store:
                assignment = store.batch
            else:
                # No batch vector: treat every node as belonging to graph 0.
                assignment = torch.zeros(
                    x[ntype].size(0), dtype=torch.long, device=x[ntype].device)
            return global_mean_pool(x[ntype], assignment)   # [B, H]

        z_prot = pooled('protein')
        z_gene = pooled('gene')
        z_pat  = pooled('patient')

        fused = torch.cat([z_pat, z_prot, z_gene], dim=-1)  # [B, 3H]
        fused = self.dropout(F.relu(self.lin_fuse(fused)))
        return self.cls(fused)                              # [B, C]


In [79]:
## ---------- Build mut_lists from gene-level binary (robust to Ensembl versions) ----------

def ensg_base(s: str) -> str:
    """Strip an identifier down to its ``ENSG<digits>`` prefix.

    The input is stringified and trimmed first. Anything after the digits (such as
    a ``.version`` suffix) is dropped. Identifiers that do not start with
    ``ENSG`` followed by digits (e.g. ``TP53``) are returned trimmed but otherwise
    unchanged.
    """
    text = str(s).strip()
    hit = re.match(r'^(ENSG[0-9]+)', text)
    if hit is None:
        return text
    return hit.group(1)

# Builds mut_lists: for each entry of patient_ids (same order), an int64 array of protein
# node indices whose gene is flagged as mutated for that patient.

# 1) Load
mut_url = URLS["mut_gene_bin"]
mut_df = pd.read_csv(mut_url, sep="\t", header=0)

# 2) Identify gene column (first col) and normalize to ENSG base
gene_col = mut_df.columns[0]
mut_df = mut_df.rename(columns={gene_col: "gene"})
mut_df["gene_base"] = mut_df["gene"].map(ensg_base)
mut_df = mut_df.drop(columns=["gene"]).set_index("gene_base")

# 2b) Coerce mutation values to numeric and collapse duplicate patient columns (if any)
mut_df = mut_df.apply(pd.to_numeric, errors="coerce")      # non-numeric -> NaN
mut_df = mut_df.fillna(0)                                  # treat NaN as 0 (no mutation)
if mut_df.columns.has_duplicates:
    mut_df = mut_df.T.groupby(level=0).max(numeric_only=True).T  # binary OR across dup cols

# 3) Collapse duplicate Ensembl rows (binary OR across dup rows)
if mut_df.index.has_duplicates:
    mut_df = mut_df.groupby(level=0).max(numeric_only=True)

# 4) Normalize current protein node IDs to the same base
protein_ids_base = [ensg_base(g) for g in protein_ids]
prot_base_to_idx = {}
for i, b in enumerate(protein_ids_base):
    # keep first occurrence so union order -> node index is stable
    if b not in prot_base_to_idx:
        prot_base_to_idx[b] = i

# 5) Keep only mutation columns for our current patient_ids (preserve order)
mut_cols = [p for p in patient_ids if p in mut_df.columns]
mut_df = mut_df.reindex(columns=mut_cols)

# 6) Align mutation rows to *base* protein IDs; missing -> 0
# After this step the row order follows the insertion order of prot_base_to_idx.
mut_df = mut_df.reindex(index=list(prot_base_to_idx.keys())).fillna(0)

# 7) Ensure strictly binary int8
mut_df = (mut_df > 0).astype(np.int8)

# 8) Build mut_lists in EXACT patient_ids order (empty if patient not present in file)
present_cols = set(mut_df.columns)
col_pos = {c:i for i, c in enumerate(mut_cols)}  # avoid O(N) .index() calls in loop

mut_lists = []
for p in patient_ids:
    if p in present_cols:
        j = col_pos[p]
        # Row positions of the genes flagged for this patient.
        prot_indices = np.nonzero(mut_df.iloc[:, j].to_numpy(dtype=bool))[0].astype(np.int64)
        # map row order (base ID) -> protein node index
        base_hits = mut_df.index[prot_indices].tolist()
        idx_hits = np.array([prot_base_to_idx[b] for b in base_hits if b in prot_base_to_idx], dtype=np.int64)
        mut_lists.append(idx_hits)
    else:
        # Patient has no column in the mutation table: no mutation edges.
        mut_lists.append(np.array([], dtype=np.int64))

# 9) Diagnostics + sanity checks
n_with_any = sum(arr.size > 0 for arr in mut_lists)
avg_muts   = float(np.mean([arr.size for arr in mut_lists])) if mut_lists else 0.0

print(
    f"[mut_lists] patients={len(mut_lists)} | with_any={n_with_any} | "
    f"avg_muted_proteins_per_patient={avg_muts:.1f}"
)
print(
    "Coverage:",
    f"mutation columns matched = {len(mut_cols)}/{len(patient_ids)} patients | "
    f"row overlap (mut→protein base ids with any 1s) = "
    f"{int((mut_df.sum(axis=1)>0).sum())}/{len(prot_base_to_idx)}"
)

assert len(mut_lists) == len(patient_ids), "mut_lists must align 1:1 with patient_ids"


[mut_lists] patients=76 | with_any=76 | avg_muted_proteins_per_patient=205.4
Coverage: mutation columns matched = 76/76 patients | row overlap (mut→protein base ids with any 1s) = 5407/7102


In [80]:
# ---------- Loaders ----------
from torch_geometric.loader import DataLoader

def take_rows(X, idx):
    """Return ``X`` indexed by ``idx`` (i.e. ``X[idx]``)."""
    selected = X[idx]
    return selected

# Slice every per-patient tensor by the precomputed split indices
_SPLIT_INDICES = (train_idx, val_idx, test_idx)


def _split3(X):
    """Return the (train, val, test) row subsets of ``X``."""
    return tuple(take_rows(X, idx) for idx in _SPLIT_INDICES)


X_prot_tr, X_prot_va, X_prot_te = _split3(X_prot)
X_phos_tr, X_phos_va, X_phos_te = _split3(X_phos)
X_mprt_tr, X_mprt_va, X_mprt_te = _split3(X_mask_prot)

X_rna_tr,  X_rna_va,  X_rna_te  = _split3(X_rna)
X_cnv_tr,  X_cnv_va,  X_cnv_te  = _split3(X_cnv)
X_ravl_tr, X_ravl_va, X_ravl_te = _split3(X_rna_avl)
X_cavl_tr, X_cavl_va, X_cavl_te = _split3(X_cnv_avl)

y_tr, y_va, y_te = (labels_aligned[idx] for idx in _SPLIT_INDICES)
pids_tr, pids_va, pids_te = ([patient_ids[i] for i in idx] for idx in _SPLIT_INDICES)


def _make_dataset(pids, y, prot, phos, mprt, rna, cnv, ravl, cavl, idx):
    """Build one split's MultiOmicsPatientDataset.

    Per-split tensors are passed in; the mutation lists are picked by ``idx`` and
    the node IDs / edge indices shared by all splits are taken from the notebook.
    """
    return MultiOmicsPatientDataset(
        pids, y,
        prot, phos, mprt,
        rna, cnv, ravl, cavl,
        [mut_lists[i] for i in idx],
        protein_ids, gene_ids,
        ppi_edge_index, codes_edge_index, codes_rev_edge_index
    )


train_ds = _make_dataset(pids_tr, y_tr, X_prot_tr, X_phos_tr, X_mprt_tr,
                         X_rna_tr, X_cnv_tr, X_ravl_tr, X_cavl_tr, train_idx)
val_ds   = _make_dataset(pids_va, y_va, X_prot_va, X_phos_va, X_mprt_va,
                         X_rna_va, X_cnv_va, X_ravl_va, X_cavl_va, val_idx)
test_ds  = _make_dataset(pids_te, y_te, X_prot_te, X_phos_te, X_mprt_te,
                         X_rna_te, X_cnv_te, X_ravl_te, X_cavl_te, test_idx)

# Small batches; num_workers=0 keeps data loading in the notebook process
_loader_kw = dict(batch_size=8, num_workers=0, pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_ds, shuffle=True,  **_loader_kw)
val_loader   = DataLoader(val_ds,   shuffle=False, **_loader_kw)
test_loader  = DataLoader(test_ds,  shuffle=False, **_loader_kw)

# Quick sanity: inspect the first training graph
g0 = train_ds[0]
_ppi_edges   = g0[('protein','ppi','protein')].edge_index
_codes_edges = g0[('gene','codes','protein')].edge_index
print("protein feat dims:", g0['protein'].x.shape)  # (N_prot, 3)
print("gene feat dims:",    g0['gene'].x.shape)     # (N_gene, 4)
print("ppi edges:",         _ppi_edges.shape[1])
print("codes edges:",       _codes_edges.shape[1])


protein feat dims: torch.Size([7102, 3])
gene feat dims: torch.Size([7093, 4])
ppi edges: 190398
codes edges: 7093


In [85]:
from torch.utils.data import WeightedRandomSampler

# Class-balanced sampling: each training sample is drawn with probability
# proportional to 1 / (size of its class). y_tr is a 1D tensor of class ids.
_y_tr_np = y_tr.cpu().numpy()
cls_counts = np.bincount(_y_tr_np, minlength=int(y_tr.max().item())+1)
cls_weights = {}
for _cls_id, _n_cls in enumerate(cls_counts):
    if _n_cls > 0:  # classes absent from the training split get no entry
        cls_weights[_cls_id] = 1.0 / _n_cls
sample_weights = torch.tensor([cls_weights[int(lbl)] for lbl in _y_tr_np], dtype=torch.double)

# Dedicated, seeded generator so the sampling sequence is reproducible
g = torch.Generator().manual_seed(42)

train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),   # one epoch draws as many samples as the dataset holds
    replacement=True,
    generator=g
)

# DataLoaders (shuffle must stay unset when a sampler is given)
_eval_kw = dict(batch_size=8, shuffle=False, num_workers=0, pin_memory=torch.cuda.is_available())

train_loader = DataLoader(
    train_ds, batch_size=8, sampler=train_sampler,
    num_workers=0, pin_memory=torch.cuda.is_available(), drop_last=False
)
val_loader = DataLoader(val_ds, **_eval_kw)
test_loader = DataLoader(test_ds, **_eval_kw)

In [86]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_classes = int(labels_aligned.max().item() + 1)
# NOTE(review): in this part of the notebook MultiOmicsKGNN is defined in the cell
# *after* this one; on a fresh kernel that cell must have run first — confirm cell order.
model = MultiOmicsKGNN(in_protein=3, in_gene=4, in_patient=1, hidden=128, n_classes=n_classes).to(device)

# Smoke-test one forward pass.
# - eval(): torch.no_grad() does not switch off dropout, eval() does.
# - val_loader: iterating train_loader would draw from the seeded
#   WeightedRandomSampler generator and shift the training sample sequence.
model.eval()
g_test = next(iter(val_loader)).to(device)
with torch.no_grad():
    out = model(g_test)
print("Logits shape:", tuple(out.shape))  # should be (batch_size, n_classes)

# One logit row per label in the batch, one column per class
_n_graphs = int(g_test['patient'].y.view(-1).size(0))
assert tuple(out.shape) == (_n_graphs, n_classes), (
    f"unexpected logits shape {tuple(out.shape)}; expected {(_n_graphs, n_classes)}"
)

Logits shape: (8, 4)


In [87]:
# ---------- Model ----------
#import torch.nn.functional as F
#from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

class MultiOmicsKGNN(nn.Module):
    """Two-layer heterogeneous GraphSAGE classifier over protein, gene and patient nodes.

    Pipeline: a per-node-type linear projection to ``hidden`` channels, two
    ``HeteroConv`` layers (one ``SAGEConv`` per relation, summed per destination
    node type), mean pooling per graph and node type, a fusion layer over the
    concatenated [patient, protein, gene] vectors, and a linear classifier.

    Args:
        in_protein: number of input features per protein node.
        in_gene: number of input features per gene node.
        in_patient: number of input features per patient node.
        hidden: width of every hidden representation.
        n_classes: number of output logits.
        p_drop: dropout probability used after each conv layer and after fusion.
    """

    # Relations handled by each HeteroConv layer; the order here is the
    # registration order of the per-relation SAGEConv modules.
    _RELATIONS = (
        ('protein', 'ppi', 'protein'),
        ('gene', 'codes', 'protein'),
        ('protein', 'rev_codes', 'gene'),
        ('patient', 'mutated', 'protein'),
        ('protein', 'rev_mutated', 'patient'),
    )

    def __init__(self, in_protein=3, in_gene=4, in_patient=1, hidden=128, n_classes=4, p_drop=0.2):
        super().__init__()
        # Input projections, one per node type
        self.lin_prot = nn.Linear(in_protein, hidden)
        self.lin_gene = nn.Linear(in_gene, hidden)
        self.lin_pat  = nn.Linear(in_patient, hidden)

        # Message passing
        self.conv1 = self._make_hetero_layer(hidden)
        self.conv2 = self._make_hetero_layer(hidden)

        # Readout head
        self.dropout = nn.Dropout(p_drop)
        self.lin_fuse = nn.Linear(3*hidden, hidden)
        self.cls = nn.Linear(hidden, n_classes)

    @staticmethod
    def _make_hetero_layer(hidden):
        """Build one HeteroConv with a SAGEConv((-1, -1), hidden) per relation."""
        per_relation = {rel: SAGEConv((-1, -1), hidden) for rel in MultiOmicsKGNN._RELATIONS}
        return HeteroConv(per_relation, aggr='sum')

    @staticmethod
    def _graph_ids(g, feats, ntype):
        """Return the node->graph assignment for ``ntype``.

        Uses the store's ``batch`` attribute when present; otherwise every node is
        assigned to graph 0 (single, un-batched graph).
        """
        store = g[ntype]
        if 'batch' in store:
            return store.batch
        return torch.zeros(feats.size(0), dtype=torch.long, device=feats.device)

    def forward(self, g: HeteroData):
        """Return logits with one row per graph in ``g`` and ``n_classes`` columns."""
        h = {
            'protein': F.relu(self.lin_prot(g['protein'].x)),
            'gene':    F.relu(self.lin_gene(g['gene'].x)),
            'patient': F.relu(self.lin_pat(g['patient'].x)),
        }

        for conv in (self.conv1, self.conv2):
            h = conv(h, g.edge_index_dict)
            h = {ntype: self.dropout(F.relu(feat)) for ntype, feat in h.items()}

        pooled = {
            ntype: global_mean_pool(h[ntype], self._graph_ids(g, h[ntype], ntype))
            for ntype in ('protein', 'gene', 'patient')
        }

        # Concatenation order (patient, protein, gene) fixes the lin_fuse input layout
        fused = torch.cat([pooled['patient'], pooled['protein'], pooled['gene']], dim=-1)
        fused = self.dropout(F.relu(self.lin_fuse(fused)))
        return self.cls(fused)

In [None]:
# ---------- Training with per-epoch classification report ----------
from collections import Counter
from sklearn.metrics import classification_report, balanced_accuracy_score, f1_score, confusion_matrix
import torch.nn as nn

# Class names, index-aligned with the integer labels. Uses `classes` when it is
# defined and has exactly n_classes entries (e.g. ["CMS1","CMS2","CMS3","CMS4"]);
# otherwise falls back to generic c0, c1, ...
n_classes = int(labels_aligned.max().item() + 1)
try:
    class_names = list(classes)
    if len(class_names) != n_classes:
        class_names = [f"c{i}" for i in range(n_classes)]
except NameError:
    class_names = [f"c{i}" for i in range(n_classes)]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fresh model for this training run
model = MultiOmicsKGNN(
    in_protein=3,   # prot_z, phos_z, prot_missing_mask
    in_gene=4,      # rna_z, cnv_z, rna_avl, cnv_avl
    in_patient=1,
    hidden=128,
    n_classes=n_classes
).to(device)

# The SAGEConv((-1, -1), hidden) layers infer their input sizes on the first
# forward pass. Do one no-grad dry run so every parameter has its final shape
# before the optimizer is constructed. val_loader is used so the seeded training
# sampler is not advanced; the step is skipped if the validation split is empty.
_warmup = next(iter(val_loader), None)
if _warmup is not None:
    model.eval()
    with torch.no_grad():
        model(_warmup.to(device))
    model.train()
del _warmup

# Inverse-frequency class weights from TRAIN ONLY (reuses y_tr from the split).
# Classes absent from the training split get weight 1.0.
# NOTE(review): the training loader built earlier in this notebook already uses a
# class-balanced WeightedRandomSampler; together with these loss weights the class
# imbalance is corrected twice — confirm this is intended.
cnt = Counter(y_tr.cpu().numpy().tolist())
weights = torch.tensor([1.0 / max(cnt.get(i, 1), 1) for i in range(n_classes)],
                       dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=weights)

opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

@torch.no_grad()
def evaluate(loader, *, net=None, loss_fn=None, dev=None):
    """Evaluate a model on ``loader`` without gradients.

    The reported loss is the batch-size-weighted mean of the per-batch losses, so
    a short final batch does not count as much as a full one.

    Args:
        loader: iterable of batches; each batch must support ``.to(dev)`` and expose
            integer class targets at ``batch['patient'].y``.
        net: model to evaluate; defaults to the notebook-level ``model``.
        loss_fn: callable ``(logits, y) -> scalar loss``; defaults to the
            notebook-level ``criterion``.
        dev: target device; defaults to the notebook-level ``device``.

    Returns:
        ``(avg_loss, acc, y_true, y_pred)`` where ``y_true`` / ``y_pred`` are 1D
        numpy arrays. For a loader that yields no samples the result is
        ``(0.0, 0.0, empty array, empty array)``.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    net.eval()
    all_y, all_p = [], []
    loss_sum, n_seen = 0.0, 0
    for g in loader:
        g = g.to(dev)
        logits = net(g)
        y = g['patient'].y.view(-1).to(dev)
        n_batch = int(y.size(0))
        # loss_fn returns a per-batch mean; re-weight it by the batch size
        loss_sum += float(loss_fn(logits, y).item()) * n_batch
        n_seen += n_batch
        pred = logits.argmax(dim=-1)
        all_y.extend(y.detach().cpu().numpy().tolist())
        all_p.extend(pred.detach().cpu().numpy().tolist())
    if n_seen == 0:
        return 0.0, 0.0, np.array([]), np.array([])
    avg_loss = loss_sum / n_seen
    all_y, all_p = np.array(all_y), np.array(all_p)
    acc = (all_y == all_p).mean()
    return avg_loss, acc, all_y, all_p

def train_epoch(loader, *, net=None, loss_fn=None, optimizer=None, dev=None):
    """Train for one pass over ``loader`` and return ``(avg_loss, acc)``.

    The reported loss is the batch-size-weighted mean of the per-batch losses
    (each measured before that batch's optimizer step), so a short final batch
    does not count as much as a full one.

    Args:
        loader: iterable of batches; each batch must support ``.to(dev)`` and expose
            integer class targets at ``batch['patient'].y``.
        net: model to train; defaults to the notebook-level ``model``.
        loss_fn: callable ``(logits, y) -> scalar loss``; defaults to the
            notebook-level ``criterion``.
        optimizer: optimizer stepping ``net``'s parameters; defaults to the
            notebook-level ``opt``.
        dev: target device; defaults to the notebook-level ``device``.

    Returns:
        ``(avg_loss, acc)`` as Python floats; ``(0.0, 0.0)`` for an empty loader.
    """
    net = model if net is None else net
    loss_fn = criterion if loss_fn is None else loss_fn
    optimizer = opt if optimizer is None else optimizer
    dev = device if dev is None else dev

    net.train()
    loss_sum, correct, total = 0.0, 0, 0
    for g in loader:
        g = g.to(dev)
        logits = net(g)
        y = g['patient'].y.view(-1).to(dev)
        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_batch = int(y.size(0))
        # loss is a per-batch mean; re-weight it by the batch size
        loss_sum += float(loss.item()) * n_batch
        correct += int((logits.argmax(dim=-1) == y).sum().item())
        total += n_batch
    denom = max(total, 1)
    return loss_sum / denom, correct / denom

best_val_macro_f1, best_state = -1.0, None
EPOCHS = 30
REPORT_EVERY = 5  # print report every N epochs

for epoch in range(1, EPOCHS + 1):
    tr_loss, tr_acc = train_epoch(train_loader)
    va_loss, va_acc, y_true, y_pred = evaluate(val_loader)

    # Metrics robust to small val sets
    if y_true.size > 0:
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        # balanced accuracy needs at least two classes in y_true; otherwise reuse plain accuracy
        bal_acc  = balanced_accuracy_score(y_true, y_pred) if len(np.unique(y_true)) > 1 else va_acc
    else:
        macro_f1, bal_acc = 0.0, 0.0

    # Track best by macro-F1 (strict '>' keeps the earliest epoch on ties)
    if macro_f1 > best_val_macro_f1:
        best_val_macro_f1 = macro_f1
        # clone() is required: state_dict() tensors share storage with the live
        # parameters and .cpu() is a no-op for tensors already on the CPU, so
        # without a copy the snapshot would keep changing as training continues.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    print(f"Epoch {epoch:02d} | "
          f"train {tr_loss:.3f}/{tr_acc:.3f} | "
          f"val {va_loss:.3f}/{va_acc:.3f} | "
          f"macroF1 {macro_f1:.3f} balAcc {bal_acc:.3f}")

    # Periodic per-class breakdown on the validation split
    if (epoch % REPORT_EVERY == 0) and (y_true.size > 0):
        print(classification_report(
            y_true, y_pred,
            labels=list(range(n_classes)),
            target_names=class_names,
            digits=3,
            zero_division=0
        ))
        cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
        print("Confusion matrix (val):\n", cm)

# Restore the best-validation snapshot, then evaluate once on the test split
if best_state is not None:
    model.load_state_dict(best_state)

te_loss, te_acc, y_true_te, y_pred_te = evaluate(test_loader)
print(f"TEST  | loss {te_loss:.3f} acc {te_acc:.3f}")
if y_true_te.size > 0:
    print(classification_report(
        y_true_te, y_pred_te,
        labels=list(range(n_classes)),
        target_names=class_names,
        digits=3,
        zero_division=0
    ))
    cm = confusion_matrix(y_true_te, y_pred_te, labels=list(range(n_classes)))
    print("Confusion matrix (test):\n", cm)

Epoch 01 | train 1.338/0.200 | val 1.471/0.188 | macroF1 0.079 balAcc 0.250
Epoch 02 | train 1.231/0.311 | val 1.592/0.188 | macroF1 0.079 balAcc 0.250
Epoch 03 | train 1.301/0.267 | val 1.454/0.188 | macroF1 0.079 balAcc 0.250
Epoch 04 | train 1.368/0.200 | val 1.381/0.188 | macroF1 0.079 balAcc 0.250
Epoch 05 | train 1.312/0.267 | val 1.393/0.188 | macroF1 0.079 balAcc 0.250
              precision    recall  f1-score   support

        CMS1      0.188     1.000     0.316         3
        CMS2      0.000     0.000     0.000         6
        CMS3      0.000     0.000     0.000         3
        CMS4      0.000     0.000     0.000         4

    accuracy                          0.188        16
   macro avg      0.047     0.250     0.079        16
weighted avg      0.035     0.188     0.059        16

Confusion matrix (val):
 [[3 0 0 0]
 [6 0 0 0]
 [3 0 0 0]
 [4 0 0 0]]
Epoch 06 | train 1.361/0.156 | val 1.405/0.375 | macroF1 0.275 balAcc 0.500
Epoch 07 | train 1.349/0.267 | val 1.41

In [32]:
from sklearn.metrics import classification_report, confusion_matrix

# Standalone test-set report for the current `model`.
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    # `batch` rather than `g`: at notebook level `g` names the sampler's torch.Generator
    for batch in test_loader:
        batch = batch.to(device)
        logits = model(batch)
        y_true.extend(batch['patient'].y.view(-1).cpu().numpy().tolist())
        y_pred.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

# Fixed label set keeps the report and matrix at n_classes rows even if a class is
# missing from this split; zero_division=0 reports 0.0 for never-predicted classes
# without emitting an UndefinedMetricWarning per metric.
_all_labels = list(range(n_classes))
print(classification_report(y_true, y_pred, labels=_all_labels, digits=3, zero_division=0))
print("Confusion matrix:\n", confusion_matrix(y_true, y_pred, labels=_all_labels))

              precision    recall  f1-score   support

           0      1.000     1.000     1.000         1
           1      0.500     1.000     0.667         1
           2      1.000     1.000     1.000         1
           3      0.000     0.000     0.000         1

    accuracy                          0.750         4
   macro avg      0.625     0.750     0.667         4
weighted avg      0.625     0.750     0.667         4

Confusion matrix:
 [[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 1 0 0]]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
