In [1]:
# Paths already used in your notebook
SC_PARQUET   = "../../data/filtered_datasets/sc_overlap_genes.parquet"   # cells × genes, index=SANGER_MODEL_ID
BULK_LONG    = "../../data/gdsc_bulk_overlap_genes.parquet"              # long bulk with LN_IC50 + SIDG genes as columns
MODEL_DIR    = "models"

# Reuse helpers from before, or keep these minimal ones:
import numpy as np, pandas as pd, os, re, joblib
pd.set_option("display.width", 160)
pd.set_option("display.max_columns", 30)

def load_sc_df(path):
    df = pd.read_parquet(path)
    if df.index.name != "SANGER_MODEL_ID":
        df.index.name = "SANGER_MODEL_ID"
    # add synthetic cell_id
    if "cell_id" not in df.columns:
        seq = df.groupby(level=0).cumcount().astype(str)
        df.insert(0, "cell_id", df.index.astype(str) + "__" + seq)
    return df

def load_bulk_wide_with_genes(path):
    df = pd.read_parquet(path) if path.endswith(".parquet") else pd.read_csv(path)
    # pull out LN_IC50 long view (unused here) and a features matrix (samples × genes)
    # Assume columns: SANGER_MODEL_ID, DRUG_ID, LN_IC50, SIDG...
    # Keep the expression columns (SIDG*) by pattern:
    feature_cols = [c for c in df.columns if c.startswith("SIDG")]
    expr = df[["SANGER_MODEL_ID"] + feature_cols].drop_duplicates("SANGER_MODEL_ID").set_index("SANGER_MODEL_ID")
    expr.index = expr.index.astype(str)
    return expr  # voom-like bulk features used to train


In [2]:
# Build pseudo-bulk by averaging cells per line (use mean of log1p CP10k; median also works)
sc = load_sc_df(SC_PARQUET)
gene_cols = [c for c in sc.columns if c.startswith("SIDG")]
sc_g = sc[gene_cols].astype(float)

sc_pseudobulk = sc_g.groupby(sc.index).mean()  # index=SANGER_MODEL_ID, cols=genes
print("pseudo-bulk shape:", sc_pseudobulk.shape)


pseudo-bulk shape: (28, 2044)


In [3]:
bulk_voom = load_bulk_wide_with_genes(BULK_LONG)
# intersect lines and genes
common_lines = sc_pseudobulk.index.intersection(bulk_voom.index)
common_genes = sc_pseudobulk.columns.intersection(bulk_voom.columns)
Xpb = sc_pseudobulk.loc[common_lines, common_genes]
Ybk = bulk_voom.loc[common_lines, common_genes]

# closed-form per-gene regression (OLS): a = cov(X,Y)/var(X), b = mean(Y) - a*mean(X)
mu_x = Xpb.mean(axis=0)
mu_y = Ybk.mean(axis=0)
var_x = Xpb.var(axis=0, ddof=1)
cov_xy = ((Xpb - mu_x) * (Ybk - mu_y)).mean(axis=0) * (len(common_lines)/(len(common_lines)-1))  # unbiased

a = (cov_xy / var_x).replace([np.inf, -np.inf], np.nan).fillna(0.0)
b = (mu_y - a * mu_x).fillna(mu_y)  # if a==0 fallback to mean alignment

calib_gene = pd.DataFrame({"slope": a, "intercept": b})
print("calibration per-gene (head):")
display(calib_gene.head())


calibration per-gene (head):


Unnamed: 0,slope,intercept
SIDG00004,267.503082,-2.873278
SIDG00036,2.591734,7.461721
SIDG00100,1.501939,6.148303
SIDG00101,0.785775,6.585975
SIDG00147,4.922744,4.582056


In [5]:
# Cell B3 (fixed): apply gene-wise map to every cell and predict (helpers included)

import os, re, joblib
import numpy as np
import pandas as pd

# helpers (self-contained)
def available_models(model_dir, allowed_drugs=None):
    paths = {}
    if not (model_dir and os.path.isdir(model_dir)):
        return paths
    allow = None if allowed_drugs is None else set(map(str, allowed_drugs))
    for fname in os.listdir(model_dir):
        if not fname.endswith(".joblib"):
            continue
        m = re.search(r"(\d+)\.joblib$", fname)  # capture trailing digits
        if not m:
            continue
        d = m.group(1)
        if allow is None or d in allow:
            paths[d] = os.path.join(model_dir, fname)
    return paths

def load_bundle(path):
    b = joblib.load(path)
    for k in ("model", "scaler", "gene_cols"):
        if k not in b:
            raise KeyError(f"Bundle {os.path.basename(path)} missing key: {k}")
    b["gene_cols"] = [str(g) for g in b["gene_cols"]]
    return b

# --- apply the per-gene affine mapping learned in B2 to ALL cells ---
# expects: SC_PARQUET, calib_gene (with 'slope','intercept'), common_genes from B2

# reload SC to be safe (uses same synthetic cell_id logic as before)
sc = pd.read_parquet(SC_PARQUET)
if sc.index.name != "SANGER_MODEL_ID":
    sc.index.name = "SANGER_MODEL_ID"
if "cell_id" not in sc.columns:
    seq = sc.groupby(level=0).cumcount().astype(str)
    sc.insert(0, "cell_id", sc.index.astype(str) + "__" + seq)

# keep only genes we have calibration for
genes_use = list(set(common_genes).intersection(calib_gene.index))
sc_g = sc[genes_use].astype(float)

A = calib_gene.loc[genes_use, "slope"]
B = calib_gene.loc[genes_use, "intercept"]

# x' = a_g * x + b_g
sc_bulkified = sc_g.mul(A, axis=1).add(B, axis=1)

# --- predict with existing bundles on bulkified features ---
model_paths = available_models(MODEL_DIR, allowed_drugs=None)
if not model_paths:
    raise ValueError(f"No model bundles found in {MODEL_DIR}")

pred_rows = []
skipped = []

for d, p in model_paths.items():
    b = load_bundle(p)
    want = [g for g in b["gene_cols"] if g in sc_bulkified.columns]
    miss = [g for g in b["gene_cols"] if g not in sc_bulkified.columns]
    if miss:
        skipped.append((d, len(miss)))
        continue
    X = sc_bulkified[b["gene_cols"]].values
    Xs = b["scaler"].transform(X)
    y = b["model"].predict(Xs)
    pred_rows.append(pd.DataFrame({
        "cell_id": sc["cell_id"].values,
        "SANGER_MODEL_ID": sc.index.astype(str).values,
        "drug_id": d,
        "y_pred": y
    }))

preds_bulkified = (pd.concat(pred_rows, ignore_index=True)
                   if pred_rows else pd.DataFrame(columns=["cell_id","SANGER_MODEL_ID","drug_id","y_pred"]))

print(f"bulkified preds shape: {preds_bulkified.shape}")
if skipped:
    print("[info] skipped bundles due to missing genes:", skipped[:10], 
          ("... (truncated)" if len(skipped) > 10 else ""))


bulkified preds shape: (308270, 4)


In [7]:
# Safe agreement comparison for whichever prediction tables exist

import numpy as np
import pandas as pd

# Build bulk LN_IC50 wide from your long table (robust to parquet/csv)
def build_bulk_LNwide(path):
    df = pd.read_parquet(path) if path.endswith(".parquet") else pd.read_csv(path)
    # expect columns like: SANGER_MODEL_ID, DRUG_ID, LN_IC50
    # fall back to lower/variant names if needed
    cols = {c.lower(): c for c in df.columns}
    line_col = cols.get("sanger_model_id") or cols.get("cell_line") or "SANGER_MODEL_ID"
    drug_col = cols.get("drug_id") or cols.get("drug") or "DRUG_ID"
    ln_col   = cols.get("ln_ic50") or cols.get("lnic50") or "LN_IC50"
    if not {line_col, drug_col, ln_col} <= set(df.columns):
        raise KeyError(f"Bulk long is missing required columns. Found: {list(df.columns)}")
    w = df.pivot_table(index=line_col, columns=drug_col, values=ln_col, aggfunc="mean")
    w.index = w.index.astype(str); w.columns = w.columns.astype(str)
    return w

def percentile_of_value(arr, value):
    arr = np.asarray(arr, float)
    if arr.size == 0 or np.isnan(value): return np.nan
    lt = np.sum(arr < value); eq = np.sum(arr == value)
    return 100.0 * (lt + 0.5*eq) / arr.size

def agreement_rate(preds_df, lnwide, agree_band=(25,75)):
    if preds_df is None or len(preds_df)==0: return np.nan
    ok = 0; total = 0
    for (line, drug), sub in preds_df.groupby(["SANGER_MODEL_ID","drug_id"]):
        if (line in lnwide.index) and (drug in lnwide.columns):
            pct = percentile_of_value(sub["y_pred"].values, float(lnwide.loc[line, drug]))
            if not np.isnan(pct):
                total += 1
                ok += (agree_band[0] <= pct <= agree_band[1])
    return (ok / total) if total else np.nan

LNwide = build_bulk_LNwide(BULK_LONG)

# Probe which prediction DataFrames exist in the namespace
candidates = []
candidates.append(("original preds",    globals().get("preds", None)))
candidates.append(("moment matched",    globals().get("preds_mm", None)))
candidates.append(("bulkified (gene)",  globals().get("preds_bulkified", None)))

print("=== Agreement rate (bulk within per-cell [25,75]) ===")
for name, df in candidates:
    try:
        rate = agreement_rate(df, LNwide)
    except Exception as e:
        rate = f"error: {e}"
    print(f"{name:>22}: {rate}")


=== Agreement rate (bulk within per-cell [25,75]) ===
        original preds: nan
        moment matched: nan
      bulkified (gene): 0.12043795620437957


In [8]:
# Cell 1: choose the best baseline predictions automatically

import numpy as np, pandas as pd, os, matplotlib.pyplot as plt

# helper to build bulk LN_IC50 wide
def build_bulk_LNwide(path):
    df = pd.read_parquet(path) if path.endswith(".parquet") else pd.read_csv(path)
    cols = {c.lower(): c for c in df.columns}
    line_col = cols.get("sanger_model_id") or cols.get("cell_line") or "SANGER_MODEL_ID"
    drug_col = cols.get("drug_id") or cols.get("drug") or "DRUG_ID"
    ln_col   = cols.get("ln_ic50") or cols.get("lnic50") or "LN_IC50"
    w = df.pivot_table(index=line_col, columns=drug_col, values=ln_col, aggfunc="mean")
    w.index = w.index.astype(str); w.columns = w.columns.astype(str)
    return w

def percentile_of_value(arr, value):
    arr = np.asarray(arr, float)
    if arr.size == 0 or np.isnan(value): return np.nan
    lt = np.sum(arr < value); eq = np.sum(arr == value)
    return 100.0 * (lt + 0.5*eq) / arr.size

def agreement_rate(preds_df, lnwide, band=(25,75)):
    if preds_df is None or len(preds_df)==0: return np.nan
    ok = 0; total = 0
    for (line, drug), sub in preds_df.groupby(["SANGER_MODEL_ID","drug_id"]):
        if (line in lnwide.index) and (drug in lnwide.columns):
            pct = percentile_of_value(sub["y_pred"].values, float(lnwide.loc[line, drug]))
            if not np.isnan(pct):
                total += 1
                ok += (band[0] <= pct <= band[1])
    return (ok / total) if total else np.nan

LNwide = build_bulk_LNwide(BULK_LONG)

# candidates from your session (some may not exist)
candidates = {
    "original":       globals().get("preds", None),
    "moment_matched": globals().get("preds_mm", None),
    "bulkified_gene": globals().get("preds_bulkified", None),
}

rates = {name: agreement_rate(df, LNwide) for name, df in candidates.items() if df is not None and len(df)}
best_name, best_df = max(rates.items(), key=lambda kv: (kv[1] if kv[1]==kv[1] else -1)) if rates else (None, None)

print("Baseline agreement rates (bulk within [25,75]%):")
for k,v in rates.items(): print(f"  {k:>15}: {v:.3f}")
print(f"\nChosen baseline: {best_name}")
BASE_PREDS = candidates.get(best_name, None)
assert BASE_PREDS is not None and len(BASE_PREDS), "No predictions available to proceed."


Baseline agreement rates (bulk within [25,75]%):
   bulkified_gene: 0.120

Chosen baseline: bulkified_gene


In [11]:
# Cell 2 (fixed): compare calibration modes and pick the best — robust to column names

# Build global linear calibration on line-level means
Xg = (BASE_PREDS
      .assign(SANGER_MODEL_ID=lambda d: d["SANGER_MODEL_ID"].astype(str),
              drug_id=lambda d: d["drug_id"].astype(str))
      .groupby(["SANGER_MODEL_ID","drug_id"])["y_pred"].mean())

# Stack LNwide and rename its two key columns whatever their names are
bulk_pairs = LNwide.stack()

bp = bulk_pairs.rename("bulk").reset_index()
# Figure out the current names
line_col = LNwide.index.name or "SANGER_MODEL_ID"
drug_col = LNwide.columns.name or "DRUG_ID"
# Rename to our canonical names
bp = bp.rename(columns={line_col:"SANGER_MODEL_ID", drug_col:"drug_id"})
bp["SANGER_MODEL_ID"] = bp["SANGER_MODEL_ID"].astype(str)
bp["drug_id"] = bp["drug_id"].astype(str)

merged_g = (
    Xg.rename("sc_mean").reset_index()
      .merge(bp, on=["SANGER_MODEL_ID","drug_id"], how="inner")
)

# Global fit bulk ≈ a_g * sc_mean + b_g
if len(merged_g) >= 2 and np.var(merged_g["sc_mean"].values, ddof=1) > 0:
    a_g = np.cov(merged_g["sc_mean"], merged_g["bulk"], ddof=1)[0,1] / np.var(merged_g["sc_mean"], ddof=1)
else:
    a_g = 1.0
b_g = merged_g["bulk"].mean() - a_g * merged_g["sc_mean"].mean()

# Per-drug calibration (fit only if >=3 lines with variance)
rows = []
for d, sub in merged_g.groupby("drug_id"):
    x = sub["sc_mean"].values
    y = sub["bulk"].values
    if len(sub) >= 3 and np.var(x, ddof=1) > 0:
        a = np.cov(x, y, ddof=1)[0,1] / np.var(x, ddof=1)
        b = y.mean() - a * x.mean()
        rows.append({"drug_id": str(d), "n_lines": int(len(sub)), "slope": float(a), "intercept": float(b)})
calib_df = pd.DataFrame(rows).set_index("drug_id") if rows else pd.DataFrame(columns=["slope","intercept","n_lines"])

# Generate calibrated columns
preds_cal = BASE_PREDS.copy()
preds_cal["SANGER_MODEL_ID"] = preds_cal["SANGER_MODEL_ID"].astype(str)
preds_cal["drug_id"] = preds_cal["drug_id"].astype(str)
preds_cal["y_none"]   = preds_cal["y_pred"].astype(float)
preds_cal["y_global"] = a_g * preds_cal["y_pred"].astype(float) + b_g

preds_cal["y_perdrug"] = preds_cal["y_pred"].astype(float)
preds_cal["y_shrunk"]  = preds_cal["y_pred"].astype(float)
LAMBDA = 5.0  # shrinkage toward global

if len(calib_df):
    for d, row in calib_df.iterrows():
        m = preds_cal["drug_id"].eq(d)
        a_d, b_d, n = float(row["slope"]), float(row["intercept"]), int(row["n_lines"])
        preds_cal.loc[m, "y_perdrug"] = a_d * preds_cal.loc[m, "y_pred"].astype(float) + b_d
        w = n / (n + LAMBDA)
        a_s = w*a_d + (1-w)*a_g
        b_s = w*b_d + (1-w)*b_g
        preds_cal.loc[m, "y_shrunk"] = a_s * preds_cal.loc[m, "y_pred"].astype(float) + b_s

def percentile_of_value(arr, value):
    arr = np.asarray(arr, float)
    if arr.size == 0 or np.isnan(value): return np.nan
    lt = np.sum(arr < value); eq = np.sum(arr == value)
    return 100.0 * (lt + 0.5*eq) / arr.size

def agreement_rate_col(col):
    ok = 0; total = 0
    for (line, drug), sub in preds_cal.groupby(["SANGER_MODEL_ID","drug_id"]):
        if (line in LNwide.index) and (drug in LNwide.columns):
            pct = percentile_of_value(sub[col].values, float(LNwide.loc[line, drug]))
            if not np.isnan(pct):
                total += 1
                ok += (25 <= pct <= 75)
    return (ok/total) if total else np.nan

modes = ["y_none","y_global","y_perdrug","y_shrunk"]
mode_rates = {m: agreement_rate_col(m) for m in modes if m in preds_cal.columns}
best_mode = max(mode_rates.items(), key=lambda kv: (kv[1] if kv[1]==kv[1] else -1))[0]

print("Calibration mode agreement rates:")
for m,r in mode_rates.items(): print(f"  {m:>9}: {r:.3f}")
print(f"\nChosen calibration mode: {best_mode}")


Calibration mode agreement rates:
     y_none: 0.120
   y_global: 0.117
  y_perdrug: 0.190
   y_shrunk: 0.175

Chosen calibration mode: y_perdrug


In [12]:
# Cell 3: finalize outputs based on the chosen mode

OUT_DIR = "out"
os.makedirs(OUT_DIR, exist_ok=True)
PLOTS_DIR = os.path.join(OUT_DIR, "plots_final")
os.makedirs(PLOTS_DIR, exist_ok=True)

# 3a) write calibration table you can feed back into File 2
calib_out = None
if best_mode == "y_global":
    calib_out = pd.DataFrame({"drug_id": sorted(preds_cal["drug_id"].unique()),
                              "slope": a_g, "intercept": b_g})
elif best_mode == "y_perdrug":
    calib_out = calib_df.reset_index()[["drug_id","slope","intercept"]]
elif best_mode == "y_shrunk":
    rows = []
    if len(calib_df):
        for d, row in calib_df.iterrows():
            n = int(row["n_lines"]); w = n/(n+5.0)
            a_s = w*float(row["slope"]) + (1-w)*a_g
            b_s = w*float(row["intercept"]) + (1-w)*b_g
            rows.append({"drug_id": str(d), "slope": a_s, "intercept": b_s, "n_lines": n})
        # For drugs without per-drug fit, fall back to global
        missing = set(preds_cal["drug_id"].unique()) - set(calib_df.index)
        for d in sorted(missing):
            rows.append({"drug_id": str(d), "slope": a_g, "intercept": b_g, "n_lines": 0})
        calib_out = pd.DataFrame(rows)
    else:
        calib_out = pd.DataFrame({"drug_id": sorted(preds_cal["drug_id"].unique()),
                                  "slope": a_g, "intercept": b_g, "n_lines": 0})
else:
    # no calibration: identity
    calib_out = pd.DataFrame({"drug_id": sorted(preds_cal["drug_id"].unique()),
                              "slope": 1.0, "intercept": 0.0})

calib_path = os.path.join(OUT_DIR, "calibration_coeffs.csv")
calib_out.to_csv(calib_path, index=False)
print(f"Saved calibration table → {calib_path}")

# 3b) write calibrated per-cell predictions
col_map = {"y_none":"per_cell_predictions_uncal.parquet",
           "y_global":"per_cell_predictions_global.parquet",
           "y_perdrug":"per_cell_predictions_perdrug.parquet",
           "y_shrunk":"per_cell_predictions_shrunk.parquet"}
preds_save = preds_cal[["cell_id","SANGER_MODEL_ID","drug_id", best_mode]].rename(columns={best_mode:"y_pred"})
preds_path = os.path.join(OUT_DIR, col_map.get(best_mode, "per_cell_predictions_final.parquet"))
preds_save.to_parquet(preds_path, index=False)
print(f"Saved calibrated per-cell predictions → {preds_path}  ({len(preds_save)} rows)")

# 3c) evaluation table + quick plots
rows = []
for (line, drug), sub in preds_save.groupby(["SANGER_MODEL_ID","drug_id"]):
    if (line not in LNwide.index) or (drug not in LNwide.columns): continue
    y = sub["y_pred"].values
    bulk_val = float(LNwide.loc[line, drug])
    lt = percentile_of_value(y, bulk_val)
    rows.append({
        "SANGER_MODEL_ID": line,
        "drug_id": drug,
        "n_cells": int(len(y)),
        "pred_mean": float(np.mean(y)),
        "pred_q10": float(np.quantile(y, 0.10)),
        "pred_q90": float(np.quantile(y, 0.90)),
        "bulk_LN_IC50": bulk_val,
        "bulk_percentile_in_pred": float(lt),
        "agreement_label": "agree" if 25<=lt<=75 else ("discordant" if lt<10 or lt>90 else "borderline"),
    })
eval_final = pd.DataFrame(rows).sort_values(["SANGER_MODEL_ID","drug_id"]).reset_index(drop=True)
eval_path = os.path.join(OUT_DIR, "eval_metrics_final.csv")
eval_final.to_csv(eval_path, index=False)
print(f"Saved evaluation metrics → {eval_path}")
print("Final agreement rate:", (eval_final["agreement_label"]=="agree").mean())

# few plots
def plot_dist(line, drug, df_mode):
    sub = df_mode[(df_mode["SANGER_MODEL_ID"]==line) & (df_mode["drug_id"]==drug)]
    if sub.empty: return
    y = sub["y_pred"].values
    bulk_val = float(LNwide.loc[line, drug])
    plt.figure(figsize=(6,4))
    plt.hist(y, bins=40, alpha=0.7, density=True)
    plt.axvline(bulk_val, linestyle="--", linewidth=2, label=f"Bulk LN_IC50 = {bulk_val:.2f}")
    plt.title(f"{line} — drug {drug} ({best_mode})")
    plt.xlabel("Per-cell predicted LN_IC50")
    plt.ylabel("Density")
    plt.legend(loc="best")
    plt.tight_layout()
    fname = os.path.join(PLOTS_DIR, f"dist_{best_mode}_{line}_{drug}.png")
    plt.savefig(fname, dpi=150); plt.close()

made = 0
for (line, drug), _ in preds_save.groupby(["SANGER_MODEL_ID","drug_id"]):
    plot_dist(line, drug, preds_save)
    made += 1
    if made >= 12: break
print(f"Saved {made} calibrated plots → {PLOTS_DIR}")

print("\nNext: set CALIBRATION_CSV = 'out/calibration_coeffs.csv' in your File 2 notebook and re-run predictions for future runs.")


Saved calibration table → out/calibration_coeffs.csv
Saved calibrated per-cell predictions → out/per_cell_predictions_perdrug.parquet  (308270 rows)
Saved evaluation metrics → out/eval_metrics_final.csv
Final agreement rate: 0.18571428571428572
Saved 12 calibrated plots → out/plots_final

Next: set CALIBRATION_CSV = 'out/calibration_coeffs.csv' in your File 2 notebook and re-run predictions for future runs.
