In [1]:
# Cell 1: config
from __future__ import annotations
import os
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple, Optional

# inputs from earlier steps
PRED_PARQUET   = "out/per_cell_predictions.parquet"   # from File 2
SHORTLIST_CSV  = "out/shortlist.csv"                  # from File 1
BULK_LONG      = "../../data/gdsc_bulk_overlap_genes.parquet"  # long table with drug, cell_line, IC50/LN_IC50

# outputs
OUT_EVAL_CSV   = "out/eval_metrics.csv"
OUT_RANK_CSV   = "out/per_line_rank_checks.csv"       # optional Spearman results
PLOTS_DIR      = "out/plots"
os.makedirs("out", exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

# thresholds for agreement labeling
AGREE_BAND      = (25, 75)   # in percentile
DISCORD_BAND    = (10, 90)

# plotting control
MAX_PLOTS = 100   # cap total plots if you want; set None to plot all
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)


In [2]:
# Cell 2: helpers

import warnings
from sklearn.mixture import GaussianMixture
from scipy.stats import spearmanr
import matplotlib.pyplot as plt

def load_bulk_long_to_wide(bulk_path: str) -> pd.DataFrame:
    """
    Load long-format bulk table (must include: drug, cell_line, IC50 or LN_IC50),
    return wide matrix of LN_IC50 with index=SANGER_MODEL_ID, columns=drug_id (str).
    """
    df = pd.read_parquet(bulk_path) if bulk_path.endswith(".parquet") else pd.read_csv(bulk_path)
    lower = {c.lower(): c for c in df.columns}
    drug_col = lower.get("drug") or lower.get("drug_id")
    line_col = lower.get("cell_line") or lower.get("sanger_model_id") or lower.get("line")
    ln_col   = lower.get("ln_ic50") or lower.get("lnic50")
    ic50_col = lower.get("ic50")
    if not drug_col or not line_col or (not ln_col and not ic50_col):
        raise KeyError(
            "Bulk long table must have columns for drug, cell_line, and IC50 or LN_IC50.\n"
            f"Found columns: {list(df.columns)}"
        )
    if ln_col is None:
        df["LN_IC50"] = np.log(df[ic50_col].astype(float))
        ln_col = "LN_IC50"
    df["_drug"] = df[drug_col].astype(str)
    df["_line"] = df[line_col].astype(str)
    wide = df.pivot_table(index="_line", columns="_drug", values=ln_col, aggfunc="mean")
    wide.index.name = "SANGER_MODEL_ID"
    wide.columns = wide.columns.astype(str)
    return wide

def percentile_of_value(arr: np.ndarray, value: float) -> float:
    y = np.asarray(arr, dtype=float)
    y = y[~np.isnan(y)]
    if y.size == 0 or np.isnan(value):
        return np.nan
    lt = np.sum(y < value)
    eq = np.sum(y == value)
    return 100.0 * (lt + 0.5*eq) / y.size



def fit_gmm_1to3(y: np.ndarray, random_state: int = 42) -> Dict[str, object]:
    y = np.asarray(y, dtype=float)
    y = y[~np.isnan(y)]
    if y.size == 0:
        return {"k": np.nan, "bic": np.nan, "means": np.array([]), "hi_frac": np.nan, "delta_means": np.nan}
    y = y.reshape(-1, 1)
    if y.shape[0] == 1:
        return {"k": 1, "bic": np.nan, "means": y.flatten(), "hi_frac": 1.0, "delta_means": 0.0}
    best = None
    for k in (1, 2, 3):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            gmm = GaussianMixture(n_components=k, covariance_type="full", random_state=random_state)
            gmm.fit(y)
            bic = gmm.bic(y)
        if best is None or bic < best["bic"]:
            best = {"k": k, "model": gmm, "bic": bic}
    gmm = best["model"]
    means = gmm.means_.flatten()
    resp = gmm.predict_proba(y)
    hi_idx = int(np.argmax(means))
    hi_frac = float(resp[:, hi_idx].mean())
    delta_means = float((np.max(means) - np.min(means)) if best["k"] > 1 else 0.0)
    return {"k": best["k"], "bic": best["bic"], "means": means, "hi_frac": hi_frac, "delta_means": delta_means}

def label_agreement(pct: float, agree_band=(25,75), discord_band=(10,90)) -> str:
    if np.isnan(pct):
        return "NA"
    if agree_band[0] <= pct <= agree_band[1]:
        return "agree"
    if pct < discord_band[0] or pct > discord_band[1]:
        return "discordant"
    return "borderline"


In [3]:
# Cell 3: load everything

# per-cell predictions
preds = pd.read_parquet(PRED_PARQUET)  # columns: cell_id, SANGER_MODEL_ID, drug_id, y_pred
preds["SANGER_MODEL_ID"] = preds["SANGER_MODEL_ID"].astype(str)
preds["drug_id"] = preds["drug_id"].astype(str)

# shortlist
shortlist = pd.read_csv(SHORTLIST_CSV, dtype={"SANGER_MODEL_ID": str, "low_drug": str, "high_drug": str})

# bulk wide matrix (LN_IC50)
ln_wide = load_bulk_long_to_wide(BULK_LONG)

# reduce ln_wide to the lines & drugs that appear in preds (robust to mismatches)
lines_in_preds = preds["SANGER_MODEL_ID"].unique().astype(str)
drugs_in_preds = preds["drug_id"].unique().astype(str)
ln_wide = ln_wide.loc[ln_wide.index.intersection(lines_in_preds), ln_wide.columns.intersection(drugs_in_preds)]

ln_wide.head()


_drug,1089,1096,1526,1845,1931,2038,2508,2515,2540,427
SANGER_MODEL_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
SIDM00097,4.729441,3.803311,4.253706,2.401597,5.846368,4.10101,1.376546,5.716241,3.234354,5.946308
SIDM00148,5.027403,3.393763,3.679521,3.595526,6.289278,5.356963,1.316832,6.211602,4.008466,3.924929
SIDM00630,4.125169,4.617345,2.671341,2.518441,4.100623,3.2624,1.92268,5.231693,3.755694,4.725384
SIDM00675,6.580458,4.071376,3.593059,3.876409,6.161872,5.047485,2.586296,5.741185,5.334885,4.02733
SIDM00866,5.882219,3.207134,0.733084,3.137863,6.634596,,1.643312,5.041295,4.762933,


In [4]:
# Cell 4: metrics per (line, drug)

rows: List[dict] = []

grouped = preds.groupby(["SANGER_MODEL_ID", "drug_id"])
for (line, drug), df in grouped:
    # Cell 4 — inside the for (line, drug) loop
    y = df["y_pred"].to_numpy(dtype=float)
    y = y[~np.isnan(y)]
    if y.size == 0:
        bulk_val = np.nan
        pct = np.nan
        gmm = {"k": np.nan, "delta_means": np.nan, "hi_frac": np.nan, "means": np.array([])}
    else:
        bulk_val = ln_wide.loc[line, drug] if (line in ln_wide.index and drug in ln_wide.columns) else np.nan
        pct = percentile_of_value(y, bulk_val)
        gmm = fit_gmm_1to3(y, random_state=RANDOM_STATE)

    rows.append({
        "SANGER_MODEL_ID": line,
        "drug_id": drug,
        "n_cells": int(len(y)),
        "pred_mean": float(np.mean(y)) if len(y) else np.nan,
        "pred_median": float(np.median(y)) if len(y) else np.nan,
        "pred_q10": float(np.quantile(y, 0.10)) if len(y) else np.nan,
        "pred_q90": float(np.quantile(y, 0.90)) if len(y) else np.nan,
        "bulk_LN_IC50": float(bulk_val) if not pd.isna(bulk_val) else np.nan,
        "bulk_percentile_in_pred": float(pct) if not pd.isna(pct) else np.nan,
        "agreement_label": label_agreement(pct, AGREE_BAND, DISCORD_BAND),
        "gmm_k": gmm["k"],
        "gmm_delta_means": gmm["delta_means"],
        "gmm_hi_frac": gmm["hi_frac"],
    })

eval_df = pd.DataFrame(rows).sort_values(["SANGER_MODEL_ID", "drug_id"]).reset_index(drop=True)
eval_df.head(10)


Unnamed: 0,SANGER_MODEL_ID,drug_id,n_cells,pred_mean,pred_median,pred_q10,pred_q90,bulk_LN_IC50,bulk_percentile_in_pred,agreement_label,gmm_k,gmm_delta_means,gmm_hi_frac
0,SIDM00097,2540,824,10.587073,10.621323,8.476162,12.720927,3.234354,0.121359,discordant,2,0.33687,0.622304
1,SIDM00097,427,824,10.504237,10.308697,8.932986,12.224777,5.946308,0.0,discordant,2,1.656656,0.290479
2,SIDM00148,1931,818,21.27248,21.177653,18.980331,23.670029,6.289278,0.0,discordant,1,0.0,1.0
3,SIDM00148,427,818,9.950962,9.849729,8.293121,11.72878,3.924929,0.0,discordant,2,1.658749,0.368867
4,SIDM00630,1096,871,19.27333,18.7756,15.491406,23.93388,4.617345,0.0,discordant,2,5.189946,0.313066
5,SIDM00630,2038,871,9.787884,9.679008,7.5189,12.254015,3.2624,0.0,discordant,1,0.0,1.0
6,SIDM00675,1089,629,15.141143,15.134066,12.254394,18.165031,6.580458,0.0,discordant,1,0.0,1.0
7,SIDM00675,427,629,17.118556,17.122559,15.440026,18.72507,4.02733,0.0,discordant,1,0.0,1.0
8,SIDM00866,1526,1279,-8.559445,-8.75219,-10.815838,-5.975373,0.733084,100.0,discordant,2,2.405403,0.393665
9,SIDM00866,1931,1279,22.62421,22.745712,19.474729,25.595569,6.634596,0.0,discordant,1,0.0,1.0


In [5]:
# Cell 5 (fixed): per-line rank check (Spearman across drugs), robust to empty case

from scipy.stats import spearmanr

rank_rows: List[dict] = []

for line, sub in eval_df.groupby("SANGER_MODEL_ID"):
    # need >= 2 drugs with both mean and bulk value
    mask = sub["pred_mean"].notna() & sub["bulk_LN_IC50"].notna()
    sub2 = sub.loc[mask, ["drug_id", "pred_mean", "bulk_LN_IC50"]]
    if len(sub2) >= 2:
        rho, p = spearmanr(sub2["pred_mean"].values, sub2["bulk_LN_IC50"].values)
        rank_rows.append({
            "SANGER_MODEL_ID": line,
            "n_drugs": int(len(sub2)),
            "spearman_rho": float(rho),
            "spearman_p": float(p),
        })

if rank_rows:
    rank_df = pd.DataFrame(rank_rows).sort_values("spearman_rho", ascending=False).reset_index(drop=True)
else:
    # Cell 5 — else branch
    rank_df = pd.DataFrame(columns=["SANGER_MODEL_ID","n_drugs","spearman_rho","spearman_p"])
    print("No lines had ≥2 drugs with both SC mean and bulk LN_IC50; skipping rank checks.")


rank_df.head(10)


Unnamed: 0,SANGER_MODEL_ID,n_drugs,spearman_rho,spearman_p
0,SIDM00148,2,1.0,
1,SIDM00866,2,1.0,
2,SIDM00630,2,1.0,
3,SIDM00885,2,1.0,
4,SIDM00872,2,1.0,
5,SIDM01037,2,1.0,
6,SIDM00928,2,1.0,
7,SIDM00893,2,1.0,
8,SIDM00920,2,1.0,
9,SIDM00675,2,-1.0,


In [6]:
# Cell 6: save results
eval_df.to_csv(OUT_EVAL_CSV, index=False)
print(f"Wrote evaluation metrics → {OUT_EVAL_CSV} ({len(eval_df)} rows)")

if len(rank_df):
    rank_df.to_csv(OUT_RANK_CSV, index=False)
    print(f"Wrote per-line rank checks → {OUT_RANK_CSV}")
else:
    print("No per-line rank checks (need ≥3 drugs with data per line).")


Wrote evaluation metrics → out/eval_metrics.csv (26 rows)
Wrote per-line rank checks → out/per_line_rank_checks.csv


In [7]:
# Cell 7: distribution plots

def plot_distribution_with_bulk(line: str, drug: str, df: pd.DataFrame, bulk_wide: pd.DataFrame, out_dir: str) -> Optional[str]:
    sub = df[(df["SANGER_MODEL_ID"] == line) & (df["drug_id"] == drug)]
    if sub.empty:
        return None
    # Cell 7 — early return if no finite preds
    y = sub["y_pred"].to_numpy(dtype=float)
    y = y[~np.isnan(y)]
    if y.size == 0:
        return None
    bulk_val = bulk_wide.loc[line, drug] if (line in bulk_wide.index and drug in bulk_wide.columns) else np.nan

    # fit GMM for viz
    #gmm = fit_gmm_1to3(y, random_state=RANDOM_STATE)
    #means = np.sort(gmm["means"]) if isinstance(gmm["means"], np.ndarray) else []

    plt.figure(figsize=(6,4))
    plt.hist(y, bins=40, alpha=0.7, density=True)
    if not np.isnan(bulk_val):
        plt.axvline(bulk_val, linestyle="--", linewidth=2, label=f"Bulk LN_IC50 = {bulk_val:.2f}")
    #for m in means:
     #   plt.axvline(m, linestyle=":", linewidth=1)
    plt.title(f"{line} — drug {drug}")
    plt.xlabel("Per-cell predicted LN_IC50")
    plt.ylabel("Density")
    if not np.isnan(bulk_val):
        plt.legend(loc="best")
    fname = os.path.join(out_dir, f"dist_{line}_{drug}.png")
    plt.tight_layout()
    plt.savefig(fname, dpi=150)
    plt.close()
    return fname

made = 0
for (line, drug), _ in preds.groupby(["SANGER_MODEL_ID", "drug_id"]):
    if MAX_PLOTS and made >= MAX_PLOTS:
        break
    path = plot_distribution_with_bulk(line, drug, preds, ln_wide, PLOTS_DIR)
    if path:
        made += 1

print(f"Saved {made} distribution plots to {PLOTS_DIR}")


Saved 26 distribution plots to out/plots


In [8]:
# Cell 8: error-bar plots per line for the drugs listed in the shortlist row

def plot_line_errorbars(line: str, shortlist_row: pd.Series, eval_table: pd.DataFrame, out_dir: str) -> Optional[str]:
    drugs = [str(shortlist_row["low_drug"]), str(shortlist_row["high_drug"])]
    sub = eval_table[(eval_table["SANGER_MODEL_ID"] == line) & (eval_table["drug_id"].isin(drugs))]
    if sub.empty:
        return None
    # order: low then high
    sub = sub.set_index("drug_id").loc[drugs].reset_index()

    x = np.arange(len(sub))
    means = sub["pred_mean"].values
    q10 = sub["pred_q10"].values
    q90 = sub["pred_q90"].values
    bulk = sub["bulk_LN_IC50"].values
    yerr = np.vstack([means - q10, q90 - means])

    plt.figure(figsize=(6,4))
    plt.errorbar(x, means, yerr=yerr, fmt="o", capsize=3, label="SC mean ± q10–q90")
    plt.scatter(x, bulk, marker="x", s=60, label="Bulk LN_IC50")
    plt.xticks(x, [f"drug {d}" for d in sub["drug_id"]])
    plt.title(f"{line}: low/high drugs")
    plt.ylabel("LN_IC50")
    plt.legend(loc="best")
    fname = os.path.join(out_dir, f"errorbar_{line}.png")
    plt.tight_layout()
    plt.savefig(fname, dpi=150)
    plt.close()
    return fname

made = 0
for _, row in shortlist.iterrows():
    line = str(row["SANGER_MODEL_ID"])
    p = plot_line_errorbars(line, row, eval_df, PLOTS_DIR)
    if p:
        made += 1
    if MAX_PLOTS and made >= MAX_PLOTS:
        break

print(f"Saved {made} error-bar plots to {PLOTS_DIR}")


Saved 13 error-bar plots to out/plots
