In [6]:
# Cell 1: config
from __future__ import annotations
import os
import numpy as np
import pandas as pd
import joblib
from typing import Dict, Sequence, Optional, List

# --- your files ---
SC_PARQUET = "../../data/filtered_datasets/sc_overlap_genes.parquet"   # cells × genes (+ SANGER_MODEL_ID available)
BULK_LONG  = "../../data/gdsc_bulk_overlap_genes.parquet"              # long: drug, cell_line, IC50 (or LN_IC50)

# trained model bundles (optional filter)
MODEL_DIR  = "models"     # each file: {drug_id}.joblib with {'model','scaler','gene_cols'}
REQUIRE_MODEL_BUNDLE = True   # set False if you don't want to require a bundle to shortlist

# output
OUT_CSV    = "out/shortlist.csv"
# os.path.dirname() is "" when OUT_CSV is a bare filename, and os.makedirs("")
# raises FileNotFoundError, so only create the directory when there is one.
_out_dir = os.path.dirname(OUT_CSV)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)

# triage knobs
BEST_DRUGS           = [1845, 2540, 2038, 2508, 1096, 1931, 2515, 1089, 427, 1526]  # [] = use all drugs in bulk
MIN_CELLS            = 200     # min cells per line to consider
MIN_LINES_PER_DRUG   = 12      # min lines with non-null LN_IC50 per drug to compute quartiles
TOP_K                = 0       # 0 = keep all; >0 keep top-K by contrast

# seed numpy's legacy global RNG so any later random draws are repeatable
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)


In [16]:
# Cell 2 (fixed): helpers that read SANGER_MODEL_ID from the index

from typing import Dict, Sequence, Optional, List

def load_sc_df(sc_parquet: str) -> pd.DataFrame:
    """
    Read the single-cell table from a parquet file.

    The frame is returned as read. The only adjustment is that an index
    without a name is given the name 'SANGER_MODEL_ID'; an index that
    already has a name is left as it is.
    """
    cells = pd.read_parquet(sc_parquet)
    if cells.index.name is not None:
        # already labelled: nothing to adjust
        return cells
    cells.index.name = "SANGER_MODEL_ID"
    return cells

def extract_line_map_from_sc_index(sc: pd.DataFrame) -> pd.Series:
    """
    Build a Series mapping each row (cell) of `sc` to its SANGER_MODEL_ID.

    Parameters
    ----------
    sc : pd.DataFrame
        Frame whose *index name* is 'SANGER_MODEL_ID' and whose index values
        are the line IDs (one row per cell).

    Returns
    -------
    pd.Series
        Named 'SANGER_MODEL_ID', one entry per row of `sc`. Values are the
        index labels cast to str; the index is a renamed copy of `sc.index`
        called 'cell_row_index'. `sc` itself is not modified.

    Raises
    ------
    KeyError
        If `sc.index.name` is not 'SANGER_MODEL_ID'.
    """
    if sc.index.name != "SANGER_MODEL_ID":
        raise KeyError(
            f"Expected sc.index.name == 'SANGER_MODEL_ID', found {sc.index.name!r}. "
            "Rename the index or adjust this helper."
        )
    # Index.rename returns a NEW Index. Building the Series on `sc.index`
    # itself and then assigning `s.index.name` would rename the shared Index
    # object, silently changing `sc.index.name` and making a second call on
    # the same frame raise KeyError.
    row_index = sc.index.rename("cell_row_index")
    # one entry per row/cell, value = the line ID taken from the index
    return pd.Series(sc.index.astype(str).values, index=row_index, name="SANGER_MODEL_ID")

def load_bulk_long(bulk_path: str) -> pd.DataFrame:
    """
    Load a long-format bulk drug-response table and pivot it to a wide
    LN_IC50 matrix.

    Parameters
    ----------
    bulk_path : str
        Path to the table. Read with `pd.read_parquet` when the path ends in
        '.parquet', otherwise with `pd.read_csv`.

    Column lookup is case-insensitive:
      - drug:      'drug' or 'drug_id'
      - cell line: 'cell_line', 'sanger_model_id' or 'line'
      - response:  'ln_ic50' / 'lnic50', else 'ic50' (natural log is taken)

    Returns
    -------
    pd.DataFrame
        index = line IDs as str (index name 'SANGER_MODEL_ID'),
        columns = drug IDs as str (columns name '_drug'),
        values = mean LN_IC50 over duplicate (line, drug) rows.

    Raises
    ------
    KeyError
        If the drug, line, or response column cannot be found.
    """
    df = pd.read_parquet(bulk_path) if bulk_path.endswith(".parquet") else pd.read_csv(bulk_path)
    lower = {c.lower(): c for c in df.columns}
    drug_col = lower.get("drug") or lower.get("drug_id")
    line_col = lower.get("cell_line") or lower.get("sanger_model_id") or lower.get("line")
    ln_col   = lower.get("ln_ic50") or lower.get("lnic50")
    ic50_col = lower.get("ic50")
    if not drug_col or not line_col or (not ln_col and not ic50_col):
        raise KeyError(
            "Bulk long table must have columns for drug, cell_line, and IC50 or LN_IC50.\n"
            f"Found columns: {list(df.columns)}"
        )

    # Rows without a drug or line key cannot be placed in the matrix. Drop
    # them first: astype(str) would otherwise turn the missing value into a
    # literal "nan" row/column label.
    df = df.dropna(subset=[drug_col, line_col])

    if ln_col is None:
        ic50 = df[ic50_col].astype(float)
        # log is undefined for IC50 <= 0; mask those to NaN so no -inf
        # reaches the per-(line, drug) mean
        ln_values = np.log(ic50.where(ic50 > 0))
    else:
        ln_values = df[ln_col]

    drug_ids = df[drug_col]
    # An integer ID column holding missing values is read back as float, which
    # would stringify as "427.0" instead of "427". Restore ints when lossless.
    if pd.api.types.is_float_dtype(drug_ids) and (drug_ids % 1 == 0).all():
        drug_ids = drug_ids.astype("int64")

    long = pd.DataFrame({
        "_drug": drug_ids.astype(str),
        "_line": df[line_col].astype(str),
        "LN_IC50": ln_values,
    })
    wide = long.pivot_table(index="_line", columns="_drug", values="LN_IC50", aggfunc="mean")
    wide.index.name = "SANGER_MODEL_ID"
    wide.columns = wide.columns.astype(str)
    return wide

# Cell A: robust model discovery (supports elasticnet_drug{ID}.joblib)

import re

def available_models(model_dir: str, allowed_drugs: Optional[Sequence[str]] = None) -> Dict[str, str]:
    """
    Discover trained bundles in model_dir and return {drug_id_str: path}.

    The drug id is the run of digits immediately before the '.joblib'
    suffix, so all of these are recognised:
      - 427.joblib
      - drug427.joblib
      - elasticnet_427.joblib
      - elasticnet_drug427.joblib
      - anything..._drug{digits}.joblib

    Parameters
    ----------
    model_dir : str
        Directory to scan (not recursive). An empty value or a path that is
        not a directory yields an empty dict.
    allowed_drugs : optional sequence
        If given, only these drug ids are returned; entries are compared
        after str() conversion.

    Returns
    -------
    Dict[str, str]
        Drug id (as str) -> path of its bundle. When several files map to the
        same drug id, the lexicographically smallest filename wins, so the
        result does not depend on the order os.listdir happens to return.
        Only regular files are considered.
    """
    paths: Dict[str, str] = {}
    if not (model_dir and os.path.isdir(model_dir)):
        return paths

    allow = None if allowed_drugs is None else set(map(str, allowed_drugs))

    # os.listdir order is arbitrary; sort so collisions resolve the same way
    # on every run
    for fname in sorted(os.listdir(model_dir)):
        # trailing digits before the suffix are the drug id, e.g.
        #   elasticnet_drug1845.joblib  -> 1845
        #   drug427.joblib              -> 427
        #   1526.joblib                 -> 1526
        m = re.search(r'(\d+)\.joblib$', fname)
        if not m:
            continue
        drug_id = m.group(1)  # string
        if drug_id in paths:
            continue  # an earlier (smaller) filename already claimed this drug
        if allow is not None and drug_id not in allow:
            continue
        candidate = os.path.join(model_dir, fname)
        # a directory that happens to be named like a bundle is not a bundle
        if os.path.isfile(candidate):
            paths[drug_id] = candidate
    return paths


def drug_quartiles(ln_ic50_wide: pd.DataFrame) -> pd.DataFrame:
    """
    Per-column lower and upper quartiles of a wide LN_IC50 matrix.

    Returns a two-row frame with the same columns as the input; row 'q25'
    holds the 0.25 quantile and row 'q75' the 0.75 quantile, both computed
    down the rows with linear interpolation.
    """
    quartiles = ln_ic50_wide.quantile(q=[0.25, 0.75], axis=0, interpolation="linear")
    # swap the numeric quantile labels for readable row names
    return quartiles.set_axis(["q25", "q75"], axis=0)


In [17]:
# Cell 3 (fixed): load sc, derive counts from index, load bulk, harmonize

# single-cell table: one row per cell, index value = SANGER_MODEL_ID
sc = load_sc_df(SC_PARQUET)

# one entry per cell giving the line that cell belongs to
cell_to_line = extract_line_map_from_sc_index(sc)

# number of cells observed for each line
counts = (
    cell_to_line.value_counts()
    .rename_axis("SANGER_MODEL_ID")
    .rename("n_cells")
    .astype(int)
)

# lines with at least MIN_CELLS cells
eligible_lines = counts.loc[counts >= MIN_CELLS].index.astype(str)
if eligible_lines.empty:
    raise ValueError("No lines pass the min-cells threshold.")

# bulk response as a wide LN_IC50 matrix, restricted to the eligible lines
ln_wide = load_bulk_long(BULK_LONG)
shared_lines = ln_wide.index.intersection(eligible_lines)
ln_wide = ln_wide.loc[shared_lines]
if len(ln_wide) == 0:
    raise ValueError("No overlap between bulk lines and eligible lines.")


In [18]:
# Cell 4: drug filtering (best list + optional presence of model bundles)

# restrict to the configured BEST_DRUGS, when the list is non-empty
if BEST_DRUGS:
    best_ids = list(map(str, BEST_DRUGS))
    ln_wide = ln_wide.loc[:, ln_wide.columns.intersection(best_ids)]
    if len(ln_wide.columns) == 0:
        raise ValueError("No overlap between LN_IC50 drugs and BEST_DRUGS after filtering.")

# when required, keep only the drugs for which a trained bundle was found
if REQUIRE_MODEL_BUNDLE:
    model_paths = available_models(MODEL_DIR, allowed_drugs=ln_wide.columns)
    if len(model_paths) == 0:
        raise ValueError("REQUIRE_MODEL_BUNDLE=True but no {drug}.joblib found for these drugs.")
    keep = [drug for drug in ln_wide.columns if drug in model_paths]
    ln_wide = ln_wide.loc[:, keep]
    if len(ln_wide.columns) == 0:
        raise ValueError("After requiring model bundles, no drugs remain.")


In [19]:
# Cell 5: guard on per-drug sample size

# count the non-null LN_IC50 values available for each drug
lines_per_drug = ln_wide.notna().sum(axis=0)
has_enough_lines = (lines_per_drug >= MIN_LINES_PER_DRUG).to_numpy()
valid_drugs = lines_per_drug.index[has_enough_lines].tolist()

ln_wide = ln_wide[valid_drugs]
if len(ln_wide.columns) == 0:
    raise ValueError(
        "No drug has sufficient lines to compute quartiles. "
        "Lower MIN_LINES_PER_DRUG or expand BEST_DRUGS."
    )

# two rows (q25, q75), one column per remaining drug
q = drug_quartiles(ln_wide)
q.head()


_drug,1089,1096,1526,1845,1931,2038,2508,2515,2540,427
q25,3.867196,2.848809,1.660317,1.250459,5.005812,4.003188,0.983825,5.003191,3.56767,4.247605
q75,5.607646,4.469358,3.593059,3.130015,6.225575,4.996259,2.638681,5.908472,4.825648,5.133296


In [20]:
# Cell 6: choose (low_drug, high_drug) maximizing contrast per line
#
# For each line that has both a cell count and bulk data:
#   - "low" drugs are those whose LN_IC50 for this line is <= that drug's q25
#   - "high" drugs are those whose LN_IC50 for this line is >= that drug's q75
#   - the reported pair is the one maximizing high_value - low_value
# Lines lacking a low or a high drug are skipped.
#
# NOTE(review): contrast subtracts raw LN_IC50 values of two different drugs,
# so it can be negative and depends on each drug's scale. Confirm this is the
# intended ranking metric.
# NOTE(review): if a drug's q25 equals its q75, a value on that point counts
# as both low and high, so low_drug may equal high_drug. Confirm whether such
# pairs should be excluded.

rows: List[dict] = []

for line, n_cells in counts[counts.index.isin(ln_wide.index)].items():
    vals = ln_wide.loc[line].dropna()
    if vals.empty:
        continue

    # compare each measured drug against its own quartile thresholds
    lows = vals[vals <= q.loc["q25", vals.index]]
    highs = vals[vals >= q.loc["q75", vals.index]]
    if lows.empty or highs.empty:
        continue

    # max(high - low) is reached at (largest high, smallest low), so the
    # pairwise scan over every (low, high) combination is unnecessary.
    # idxmin/idxmax return the first label on ties.
    ld = lows.idxmin()
    hd = highs.idxmax()
    lv = float(lows.loc[ld])
    hv = float(highs.loc[hd])
    contrast = hv - lv

    rows.append({
        "SANGER_MODEL_ID": str(line),
        "n_cells": int(n_cells),
        "low_drug": str(ld),
        "low_value": lv,
        "high_drug": str(hd),
        "high_value": hv,
        "contrast": contrast,
    })

shortlist_df = pd.DataFrame(rows)
if shortlist_df.empty:
    raise ValueError("No (low, high) pairs found. Relax thresholds or expand drug set.")

# largest contrast first; ties broken by the larger cell count
shortlist_df = shortlist_df.sort_values(["contrast", "n_cells"], ascending=[False, False]).reset_index(drop=True)
if TOP_K and TOP_K > 0:
    shortlist_df = shortlist_df.head(TOP_K)

shortlist_df.head(10)


Unnamed: 0,SANGER_MODEL_ID,n_cells,low_drug,low_value,high_drug,high_value,contrast
0,SIDM00872,1623,1845,-2.378115,1089,5.773363,8.151478
1,SIDM01037,568,1845,-0.728895,1096,5.241832,5.970727
2,SIDM00866,1280,1526,0.733084,1931,6.634596,5.901512
3,SIDM01056,1316,1845,0.322898,2515,6.075759,5.752861
4,SIDM00928,995,1526,0.975542,427,5.855807,4.880265
5,SIDM00920,823,2508,0.819275,427,5.133296,4.314021
6,SIDM00097,825,2540,3.234354,427,5.946308,2.711954
7,SIDM00675,629,427,4.02733,1089,6.580458,2.553128
8,SIDM00148,839,427,3.924929,1931,6.289278,2.364349
9,SIDM00630,877,2038,3.2624,1096,4.617345,1.354945


In [21]:
# Cell 7: save
# write the shortlist without the positional row index, then report the row count
shortlist_df.to_csv(path_or_buf=OUT_CSV, index=False)
print(f"Wrote shortlist with {len(shortlist_df)} lines → {OUT_CSV}")


Wrote shortlist with 13 lines → out/shortlist.csv
