In [3]:
!pip install pandas pyarrow torch --quiet

In [4]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from io import StringIO
import requests

from sklearn.model_selection import StratifiedShuffleSplit
from collections import Counter

from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

In [5]:
# CPTAC COAD URLs (from your earlier messages)
URLS = {
    "tumor_cases": "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_Tumor_CaseList.txt",
    "normal_cases": "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_Normal_CaseList.txt",
    "proteomics_tumor": "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_proteomics_gene_abundance_log2_reference_intensity_normalized_Tumor.txt",
    "proteomics_normal": "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_proteomics_gene_abundance_log2_reference_intensity_normalized_Normal.txt",
    "maf_gene_binary": "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_somatic_mutation_gene_level_binary.txt",
    # clinical .tsi with CMS / proteomic subtypes
    "linkedomics_tsi": "https://linkedomics.org/cptac-colon/Human__CPTAC_COAD__MS__Clinical__Clinical__03_01_2017__CPTAC__Clinical__BCM.tsi",
}

In [6]:
def fetch_tsv(url, index_col=None):
    r = requests.get(url); r.raise_for_status()
    return pd.read_csv(StringIO(r.text), sep="\t", header=0, index_col=index_col)

def strip_version(x):
    # ENSG00000000003.15 -> ENSG00000000003
    if isinstance(x, str) and x.startswith("ENSG") and "." in x:
        return x.split(".", 1)[0]
    return x

In [7]:
prot_tumor = fetch_tsv(URLS["proteomics_tumor"], index_col=0)
prot_norm  = fetch_tsv(URLS["proteomics_normal"], index_col=0)

prot_tumor.index = prot_tumor.index.map(strip_version)
prot_norm.index  = prot_norm.index.map(strip_version)

# Keep intersection of proteins for consistent processing
common_prots = prot_tumor.index.intersection(prot_norm.index)
prot_tumor = prot_tumor.loc[common_prots].copy()
prot_norm  = prot_norm.loc[common_prots].copy()

tumor_ids  = prot_tumor.columns.tolist()
normal_ids = prot_norm.columns.tolist()

print(prot_tumor.shape, prot_norm.shape, len(common_prots))

(9151, 97) (9151, 100) 9151


In [9]:
# --- 4) Load clinical .tsi and build CMS-only labels (drop NA / NaN) ---
tsi = fetch_tsv(URLS["linkedomics_tsi"])
assert 'attrib_name' in tsi.columns, "Unexpected .tsi format"
tsi = tsi.set_index('attrib_name')

# Extract CMS row
cms_labels = tsi.loc['Transcriptomic_subtype'].to_dict()

# Keep only tumor patients that have valid CMS subtype (string, not NA)
valid_pairs = [(pid, lab) for pid, lab in cms_labels.items() if isinstance(lab, str) and lab != "NA"]

# Separate patient IDs and label strings
kept_patients, labels_str = zip(*valid_pairs)
kept_patients = list(kept_patients)
labels_str = list(labels_str)

# Build vocabulary
classes = sorted(set(labels_str))   # e.g. ['CMS1','CMS2','CMS3','CMS4']
vocab = {c:i for i,c in enumerate(classes)}

# Tensor of integer labels
labels_y = torch.tensor([vocab[lab] for lab in labels_str], dtype=torch.long)

print(f"Kept patients with CMS labels: {len(kept_patients)}")
print("CMS classes ->", vocab)

Kept patients with CMS labels: 85
CMS classes -> {'CMS1': 0, 'CMS2': 1, 'CMS3': 2, 'CMS4': 3}


In [13]:
# --- 5) Clean + baseline-normalize tumors against normals (and define patient_ids/protein_ids) ---

import pandas as pd
import numpy as np
import torch

# 5.0) Ensure inputs from previous cells exist:
assert 'prot_tumor' in globals() and 'prot_norm' in globals(), "Load proteomics first (Cell 3)."
assert 'kept_patients' in globals() and 'labels_y' in globals(), "Build CMS labels first (Cell 4)."

# 5.1) Align CMS-labeled patients to proteomics tumor columns (avoid KeyError)
prot_tumor.columns = [c.strip() for c in prot_tumor.columns]
prot_norm.columns  = [c.strip() for c in prot_norm.columns]
kp_clean = [p.strip() for p in kept_patients]

pid_to_lab = {pid: lab for pid, lab in zip(kp_clean, labels_y.tolist())}
kept_patients = [pid for pid in kp_clean if pid in prot_tumor.columns]
labels_y      = torch.tensor([pid_to_lab[pid] for pid in kept_patients], dtype=torch.long)

print(f"After alignment: {len(kept_patients)} tumor patients with CMS labels present in proteomics.")

# 5.2) Subset tumor to labeled patients; keep all normals
df_tumor = prot_tumor.loc[:, kept_patients].copy()
df_norm_ = prot_norm.copy()

# 5.3) Work on common proteins only
common_prots = df_tumor.index.intersection(df_norm_.index)
df_tumor = df_tumor.loc[common_prots].copy()
df_norm_ = df_norm_.loc[common_prots].copy()

# 5.4) Concatenate for consistent filtering across tumor+normal
df_all = pd.concat([df_tumor, df_norm_], axis=1)

# Missingness-based filtering (adjust thresholds as needed)
PROT_MISS_MAX = 0.40   # drop proteins with >40% missing across ALL samples
SAMP_MISS_MAX = 0.20   # drop samples with >20% missing proteins

prot_miss = df_all.isna().mean(axis=1)
samp_miss = df_all.isna().mean(axis=0)

keep_prot = prot_miss <= PROT_MISS_MAX
keep_samp = samp_miss <= SAMP_MISS_MAX
df_all = df_all.loc[keep_prot, keep_samp]

# Re-split after filtering
tumor_ids_clean  = [c for c in df_all.columns if c in set(kept_patients)]
normal_ids_clean = [c for c in df_all.columns if c in set(prot_norm.columns)]
df_tumor = df_all.loc[:, tumor_ids_clean]
df_norm  = df_all.loc[:, normal_ids_clean]

# 5.5) Missingness mask BEFORE imputation (for tumor only)
miss_mask_tumor = df_tumor.isna().astype(np.float32)

# 5.6) Impute remaining NAs with per-protein median (robust after filtering)
row_median_all = df_all.median(axis=1, skipna=True)
df_tumor_imp = df_tumor.apply(lambda col: col.fillna(row_median_all), axis=0)
df_norm_imp  = df_norm.apply(lambda col: col.fillna(row_median_all), axis=0)

# 5.7) Baseline stats from NORMALS ONLY (no leakage)
mean_norm = df_norm_imp.mean(axis=1)
std_norm  = df_norm_imp.std(axis=1, ddof=0).replace(0, np.nan)

# 5.8) Baseline-normalized tumor z (clip extremes)
Z_CLIP = 5.0
z_baseline = (df_tumor_imp.sub(mean_norm, axis=0)).div(std_norm, axis=0).fillna(0.0).clip(-Z_CLIP, Z_CLIP)

# 5.9) Define IDs and build tensors
protein_ids = z_baseline.index.tolist()         # <--- proteins after filtering
patient_ids = z_baseline.columns.tolist()       # <--- tumors after filtering

X_prot = torch.tensor(z_baseline.T.values, dtype=torch.float32)  # [P, N]
X_mask = torch.tensor(miss_mask_tumor.loc[protein_ids, patient_ids].T.values,
                      dtype=torch.float32)                        # [P, N]

print("Shapes -> X_prot:", X_prot.shape, "| X_mask:", X_mask.shape)
print("#proteins:", len(protein_ids), "| #patients:", len(patient_ids))

After alignment: 76 tumor patients with CMS labels present in proteomics.
Shapes -> X_prot: torch.Size([7, 6572]) | X_mask: torch.Size([7, 6572])
#proteins: 6572 | #patients: 7


In [15]:
from sklearn.model_selection import StratifiedKFold

# 6.1) Align labels to patient_ids (from Cell 5)
if 'patient_ids' not in globals():
    raise RuntimeError("patient_ids not defined — re-run Cell 5 first.")

pid_to_lab = {pid: lab for pid, lab in zip(kept_patients, labels_y.tolist())}
labels_aligned = torch.tensor([pid_to_lab[pid] for pid in patient_ids if pid in pid_to_lab],
                              dtype=torch.long)

# ensure X_prot/X_mask are in the same order as patient_ids already (Cell 5 did that)
assert X_prot.shape[0] == len(patient_ids), "X_prot rows must match patient_ids length."
assert X_mask.shape[0] == len(patient_ids), "X_mask rows must match patient_ids length."

# 6.2) Drop classes with <2 samples (Stratified splitting requires >=2 per class)
y_np = labels_aligned.numpy()
counts = np.bincount(y_np, minlength=int(y_np.max()+1))
rare_classes = np.where(counts < 2)[0].tolist()

if len(rare_classes) > 0:
    mask = ~np.isin(y_np, rare_classes)
    kept_before = len(y_np)
    # filter all aligned arrays
    patient_ids = [p for p, m in zip(patient_ids, mask) if m]
    X_prot      = X_prot[mask]
    X_mask      = X_mask[mask]
    labels_aligned = labels_aligned[mask]
    y_np = labels_aligned.numpy()
    print(f"Dropped {kept_before - len(y_np)} samples from classes with <2 members:", rare_classes)

# 6.3) Remap labels to 0..C'-1 after dropping
unique = sorted(np.unique(y_np))
remap = {old:i for i, old in enumerate(unique)}
labels_aligned = torch.tensor([remap[int(v)] for v in y_np], dtype=torch.long)
y_np = labels_aligned.numpy()
n_classes = len(unique)
print("Class counts after filtering:", np.bincount(y_np))

# 6.4) Use StratifiedKFold to create test and val folds robustly
min_count = np.bincount(y_np).min()
if min_count < 2:
    raise RuntimeError("Still have a class with <2 samples after filtering.")

# choose number of folds based on the rarest class
n_splits = min(5, int(min_count))  # at least 2
if n_splits < 2:
    n_splits = 2

skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
folds = list(skf.split(np.zeros_like(y_np), y_np))

# Take the first fold's test indices; rest is train+val
rest_idx, test_idx = folds[0][0], folds[0][1]

# Now split rest into train/val using another StratifiedKFold
y_rest = y_np[rest_idx]
min_count_rest = np.bincount(y_rest).min()
n_splits_val = min(5, int(min_count_rest))
if n_splits_val < 2:
    n_splits_val = 2
skf2 = StratifiedKFold(n_splits=n_splits_val, shuffle=True, random_state=43)
rest_folds = list(skf2.split(np.zeros_like(y_rest), y_rest))
train_sub, val_sub = rest_folds[0][0], rest_folds[0][1]
train_idx = rest_idx[train_sub]
val_idx   = rest_idx[val_sub]

print(f"Split sizes -> train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")

# Build the patient ID lists (useful for logging/debug)
train_patients = [patient_ids[i] for i in train_idx]
val_patients   = [patient_ids[i] for i in val_idx]
test_patients  = [patient_ids[i] for i in test_idx]

# 6.5) Slice tensors and labels for each split
def take_rows(X, idx): return X[idx]

X_train, X_val, X_test = take_rows(X_prot, train_idx), take_rows(X_prot, val_idx), take_rows(X_prot, test_idx)
M_train, M_val, M_test = take_rows(X_mask, train_idx), take_rows(X_mask, val_idx), take_rows(X_mask, test_idx)
y_train, y_val, y_test = labels_aligned[train_idx], labels_aligned[val_idx], labels_aligned[test_idx]

print("Shapes ->",
      "X_train", X_train.shape, "X_val", X_val.shape, "X_test", X_test.shape,
      "| y_train", y_train.shape, "y_val", y_val.shape, "y_test", y_test.shape)

Dropped 2 samples from classes with <2 members: [0, 3]
Class counts after filtering: [3 2]


ValueError: n_splits=2 cannot be greater than the number of members in each class.

In [83]:
# Save key artifacts
pd.Series(patient_ids_tumor).to_csv("COAD_patient_ids_tumor.tsv", sep="\t", index=False, header=False)
pd.Series(protein_ids).to_csv("COAD_protein_ids.tsv", sep="\t", index=False, header=False)
torch.save(X_prot_tumor_z, "COAD_X_proteomics_tumor_z.pt")
if mut_edge_index is not None:
    torch.save(mut_edge_index, "COAD_mut_edge_index.pt")

print("Saved: COAD_X_proteomics_tumor_z.pt, IDs, norm stats, and mutation edges (if available).")

Saved: COAD_X_proteomics_tumor_z.pt, IDs, norm stats, and mutation edges (if available).


In [84]:
# ---- paths from your earlier saving step
X_path = Path("COAD_X_proteomics_tumor_z.pt")
patients_path = Path("COAD_patient_ids_tumor.tsv")
proteins_path = Path("COAD_protein_ids.tsv")

assert X_path.exists() and patients_path.exists() and proteins_path.exists(), "Run the earlier ingest cell first."

# Proteomics (z-scored) — shape [n_patients, n_proteins]
X_prot = torch.load(X_path)  # FloatTensor
patient_ids = pd.read_csv(patients_path, sep="\t", header=None)[0].astype(str).tolist()
protein_ids = pd.read_csv(proteins_path,  sep="\t", header=None)[0].astype(str).tolist()

n_patients, n_proteins = X_prot.shape
print(f"Loaded proteomics: {n_patients} patients × {n_proteins} proteins")


Loaded proteomics: 97 patients × 9151 proteins


In [92]:
# Align mutation table
mut = pd.read_csv(URLS["maf_gene_binary"], sep="\t", header=0, index_col=0)

# Normalize identifiers to match our protein_ids (strip Ensembl version if present)
def strip_version(x):
    if isinstance(x, str) and x.startswith("ENSG") and "." in x:
        return x.split(".")[0]
    return x

mut.index = mut.index.map(strip_version)

# Keep only our cohort & protein list
mut = mut.reindex(index=protein_ids)               # rows aligned to our proteins
mut = mut[[c for c in mut.columns if c in set(patient_ids)]]  # cols aligned to our patients
mut = mut.fillna(0).astype(int)

# Build per-patient mutated protein index lists
patient_idx_map = {pid:i for i, pid in enumerate(patient_ids)}
mutated_indices_by_patient = [
    np.where(mut.iloc[:, j].values == 1)[0].astype(np.int64)
    for j in range(mut.shape[1])
]
#print(mutated_indices_by_patient)
print("Mutation table aligned:", mut.shape)

Mutation table aligned: (9151, 96)


In [86]:
tsi_url = 
tsi = pd.read_csv(tsi_url, sep="\t")  # wide 'attrib_name' table
assert 'attrib_name' in tsi.columns, "Unexpected .tsi format"
tsi = tsi.set_index('attrib_name')

# two rows we care about
cms_row  = tsi.loc['Transcriptomic_subtype']      # values: CMS1..CMS4 or 'NA'
prot_row = tsi.loc['Proteomic_subtype']           # values: A..E or 'NA'

# turn into dicts: {sample_id -> label_str}
cms_labels  = cms_row.to_dict()
prot_labels = prot_row.to_dict()
print(len(cms_labels))

110


In [103]:
# Preferred: CMS (Transcriptomic_subtype). If missing for a sample, we can fall back to proteomic subtype.
use_fallback_to_proteomic = False

label_strs = []
for pid in patient_ids:
    lab = cms_labels.get(pid, 'NA')
    if lab == 'NA' and use_fallback_to_proteomic:
        lab = prot_labels.get(pid, 'NA')
    label_strs.append(lab)

In [104]:
# --- 3) Filter to labeled patients & build integer labels ---
keep_mask = np.array([lab != 'NA' and isinstance(lab, str) for lab in label_strs], dtype=bool)
kept_patients = [p for p, k in zip(patient_ids, keep_mask) if k]
print(f"Labeled patients kept: {keep_mask.sum()} / {len(patient_ids)}")

Labeled patients kept: 76 / 97


In [105]:
# choose vocabulary over actually present labels
present_labels = sorted({lab for lab, k in zip(label_strs, keep_mask) if k})
label_vocab = {lab:i for i, lab in enumerate(present_labels)}
print("Label vocabulary:", label_vocab)

Label vocabulary: {'CMS1': 0, 'CMS2': 1, 'CMS3': 2, 'CMS4': 3}


In [106]:
# indices of kept patients
kept_idx = np.where(keep_mask)[0]
# tensors aligned to kept patients
X_prot_kept = X_prot[kept_idx]  # [P_kept, N]
labels_y    = torch.tensor([label_vocab[label_strs[i]] for i in kept_idx], dtype=torch.long)

In [107]:
def strip_version(x):
    if isinstance(x, str) and x.startswith("ENSG") and "." in x:
        return x.split(".", 1)[0]
    return x

# 1) make row IDs (genes/proteins) comparable and align them to your protein node order
mut.index = mut.index.map(strip_version)
mut = mut.reindex(index=protein_ids)  # rows now in same order as your protein nodes

# 2) build patient->indices dict, only for the kept patients (labels available)
pid_to_mutidx = {}
for pid in kept_patients:
    if pid in mut.columns:
        col = mut[pid].fillna(0).astype(int).values  # length = len(protein_ids)
        pid_to_mutidx[pid] = np.where(col == 1)[0].astype(np.int64)
    else:
        pid_to_mutidx[pid] = np.array([], dtype=np.int64)

# 3) assemble the final list in exact kept order (matches X_prot_kept / labels_y)
mut_idx_by_patient_kept = [pid_to_mutidx[pid] for pid in kept_patients]

print("Sanity:", len(mut_idx_by_patient_kept), "lists; first 5 sizes ->",
      [len(a) for a in mut_idx_by_patient_kept[:5]])

Sanity: 76 lists; first 5 sizes -> [63, 477, 40, 799, 31]


In [118]:
from torch_geometric.nn import HeteroConv, SAGEConv, global_mean_pool

class KGNNWithProteinReadout(nn.Module):
    def __init__(self, in_protein=1, in_patient=1, hidden=128, n_classes=4):
        super().__init__()
        self.lin_prot_in  = nn.Linear(in_protein, hidden)
        self.lin_pat_in   = nn.Linear(in_patient, hidden)

        self.conv1 = HeteroConv({
            ('patient','mutated','protein'):     SAGEConv((-1, -1), hidden),
            ('protein','rev_mutated','patient'): SAGEConv((-1, -1), hidden),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('patient','mutated','protein'):     SAGEConv((-1, -1), hidden),
            ('protein','rev_mutated','patient'): SAGEConv((-1, -1), hidden),
        }, aggr='sum')

        self.lin_fuse   = nn.Linear(2*hidden, hidden)
        self.classifier = nn.Linear(hidden, n_classes)

    def forward(self, g):
        x = {
            'protein': F.relu(self.lin_prot_in(g['protein'].x)),
            'patient': F.relu(self.lin_pat_in(g['patient'].x)),
        }
        x = self.conv1(x, g.edge_index_dict); x = {k: F.relu(v) for k,v in x.items()}
        x = self.conv2(x, g.edge_index_dict); x = {k: F.relu(v) for k,v in x.items()}

        # --- pool PER GRAPH using batch vectors ---
        prot_batch = g['protein'].batch if 'batch' in g['protein'] else torch.zeros(
            x['protein'].size(0), dtype=torch.long, device=x['protein'].device
        )
        pat_batch  = g['patient'].batch if 'batch' in g['patient'] else torch.zeros(
            x['patient'].size(0), dtype=torch.long, device=x['patient'].device
        )

        z_prot = global_mean_pool(x['protein'], prot_batch)   # [B, H]
        z_pat  = global_mean_pool(x['patient'], pat_batch)    # [B, H] (safe even if 1 node/graph)

        z = torch.cat([z_pat, z_prot], dim=-1)                # [B, 2H]
        z = F.relu(self.lin_fuse(z))
        logits = self.classifier(z)                           # [B, C]
        return logits

In [None]:
from collections import Counter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_classes = int(labels_y.max().item() + 1)

model = KGNNWithProteinReadout(in_protein=1, in_patient=1, hidden=128, n_classes=n_classes).to(device)

# optional: class-weighted loss for imbalance
cnt = Counter(labels_y.numpy().tolist())
weights = torch.tensor([1.0 / cnt[i] for i in range(n_classes)], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)

opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    total, correct, total_loss = 0, 0, 0.0
    for g in loader:
        g = g.to(device)
        logits = model(g)                # [B_graphs?, n_classes] (here 1 per g)
        y = g['patient'].y.view(-1).to(device)
        loss = criterion(logits, y)
        if train:
            opt.zero_grad(); loss.backward(); opt.step()
        total_loss += float(loss.item())
        pred = logits.argmax(dim=-1)
        correct += int((pred == y).sum().item())
        total += y.size(0)
    acc = correct / max(total, 1)
    return total_loss / max(len(loader), 1), acc

best_val, best_state = 0.0, None
for epoch in range(1, 31):  # 30 epochs to start
    tr_loss, tr_acc = run_epoch(train_loader, True)
    va_loss, va_acc = run_epoch(val_loader,   False)
    if va_acc > best_val:
        best_val = va_acc
        best_state = {k: v.cpu() for k, v in model.state_dict().items()}
    print(f"Epoch {epoch:02d} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")

# Load best weights and evaluate on test
model.load_state_dict(best_state)
te_loss, te_acc = run_epoch(test_loader, False)
print(f"TEST acc: {te_acc:.3f}")

Epoch 01 | train 1.398/0.151 | val 1.388/0.273


In [None]:
k = 0  # pick an index in 0..len(dataset)-1 within the kept subset
g = dataset[k].to(device)
model.eval()
with torch.no_grad():
    logits = model(g)
probs = logits.softmax(dim=-1).cpu().numpy()[0]
pred_class = int(probs.argmax())
print("Pred class:", pred_class, "probs:", probs)

In [95]:
# --- 4) Update the Dataset to take externally provided labels & patient subset ---
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv

class PatientGraphDatasetLabeled(torch.utils.data.Dataset):
    def __init__(self, X_prot, mutated_idx_by_patient, labels):
        self.X = X_prot                              # [P, N]
        self.mut_idx = mutated_idx_by_patient        # list of np arrays
        self.labels = labels                         # torch.long [P]
        self.n_patients, self.n_proteins = X_prot.shape

        self.template = HeteroData()
        self.template['protein'].x = torch.zeros(self.n_proteins, 1)   # abundance channel
        self.template['patient'].x = torch.zeros(1, 1)                 # (optional covars slot)

    def __len__(self):
        return self.n_patients

    def __getitem__(self, i):
        g = self.template.clone()
        g['protein'].x[:, 0] = self.X[i]

        mut_idx = self.mut_idx[i]
        if mut_idx.size > 0:
            src = torch.zeros(len(mut_idx), dtype=torch.long)              # patient node index 0
            dst = torch.tensor(mut_idx, dtype=torch.long)                  # protein indices
            g[('patient','mutated','protein')].edge_index = torch.stack([src, dst], dim=0)
            # add reverse edges so messages can flow back to patient
            r_src = dst.clone()
            r_dst = torch.zeros_like(src)
            g[('protein','rev_mutated','patient')].edge_index = torch.stack([r_src, r_dst], dim=0)
        else:
            g[('patient','mutated','protein')].edge_index = torch.empty((2,0), dtype=torch.long)
            g[('protein','rev_mutated','patient')].edge_index = torch.empty((2,0), dtype=torch.long)

        # supervision on patient node
        g['patient'].y = self.labels[i:i+1]
        return g

dataset = PatientGraphDatasetLabeled(X_prot_kept, mut_idx_by_patient_kept, labels_y)

# --- 5) Stratified split (by label) ---
from sklearn.model_selection import StratifiedShuffleSplit
y_np = labels_y.numpy()
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
train_idx, temp_idx = next(sss1.split(np.zeros_like(y_np), y_np))
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
val_idx, test_idx = next(sss2.split(np.zeros_like(y_np[temp_idx]), y_np[temp_idx]))
val_idx  = temp_idx[val_idx]
test_idx = temp_idx[test_idx]

train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx.tolist()), batch_size=16, shuffle=True)
val_loader   = DataLoader(torch.utils.data.Subset(dataset, val_idx.tolist()),   batch_size=16)
test_loader  = DataLoader(torch.utils.data.Subset(dataset, test_idx.tolist()),  batch_size=16)
print(f"Splits -> train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}")

Splits -> train=53 val=11 test=12


In [51]:
# ---- load phenotype for labels (if available)
pheno_url = "https://cptac-pancancer-data.s3.us-west-2.amazonaws.com/data_freeze_v1.2_reorganized/COAD/COAD_phenotype.txt"
ph = pd.read_csv(pheno_url, sep="\t", header=0)
print(ph)

# Heuristic: find a plausible subtype column (adjust to your truth column)
candidate_cols = [c for c in ph.columns if "CMS" in c.upper() or "SUBTYPE" in c.upper()]
label_col = candidate_cols[0] if candidate_cols else None

         idx  CIBERSORT_B_cell_naive  CIBERSORT_B_cell_memory  \
0    01CO001                0.170660                 0.025228   
1    01CO005                0.105408                 0.000000   
2    01CO006                0.400195                 0.485658   
3    01CO008                0.008900                 0.023195   
4    01CO013                0.122969                 0.000000   
..       ...                     ...                      ...   
105  21CO007                0.000000                 0.043618   
106  22CO004                0.000000                 0.099148   
107  22CO006                0.076124                 0.129644   
108  24CO005                0.050307                 0.073352   
109  27CO004                0.116184                 0.059798   

     CIBERSORT_B_cell_plasma  CIBERSORT_T_cell_CD8+  \
0                   0.149748               0.154231   
1                   0.047860               0.169397   
2                   0.000000               0.867519   