In [4]:
# Cell 1: imports & config
from __future__ import annotations

import os, re
import numpy as np
import pandas as pd
import joblib

# your paths (tweak if needed)
PRED_PARQUET = "out/per_cell_predictions.parquet"            # from File 2
BULK_LONG    = "../../data/gdsc_bulk_overlap_genes.parquet"  # long table with drug, cell_line, IC50 or LN_IC50
SC_PARQUET   = "../../data/filtered_datasets/sc_overlap_genes.parquet"  # cells × genes (aligned)

np.set_printoptions(suppress=True, precision=4)
pd.set_option("display.width", 160)
pd.set_option("display.max_columns", 20)


In [5]:
# Cell 2: helpers

def load_bulk_long_to_wide(bulk_path: str) -> pd.DataFrame:
    df = pd.read_parquet(bulk_path) if bulk_path.endswith(".parquet") else pd.read_csv(bulk_path)
    lower = {c.lower(): c for c in df.columns}
    drug_col = lower.get("drug") or lower.get("drug_id")
    line_col = lower.get("cell_line") or lower.get("sanger_model_id") or lower.get("line")
    ln_col   = lower.get("ln_ic50") or lower.get("lnic50")
    ic50_col = lower.get("ic50")
    if not drug_col or not line_col or (not ln_col and not ic50_col):
        raise KeyError(f"Bulk must have drug, cell_line, and IC50 or LN_IC50. Columns: {list(df.columns)}")
    if ln_col is None:
        df["LN_IC50"] = np.log(df[ic50_col].astype(float))
        ln_col = "LN_IC50"
    df["_drug"] = df[drug_col].astype(str)
    df["_line"] = df[line_col].astype(str)
    wide = df.pivot_table(index="_line", columns="_drug", values=ln_col, aggfunc="mean")
    wide.index.name = "SANGER_MODEL_ID"
    wide.columns = wide.columns.astype(str)
    return wide

def check_bulk_units(bulk_path: str) -> None:
    df = pd.read_parquet(bulk_path) if bulk_path.endswith(".parquet") else pd.read_csv(bulk_path)
    lower = {c.lower(): c for c in df.columns}
    ln_col   = lower.get("ln_ic50") or lower.get("lnic50")
    ic50_col = lower.get("ic50")
    print("=== Bulk unit check ===")
    print("Columns:", list(df.columns))
    if ic50_col:
        print("\nIC50 summary:")
        display(df[ic50_col].describe())
    if ln_col:
        print("\nLN_IC50 summary:")
        display(df[ln_col].describe())
    if ic50_col and ln_col:
        comp = (np.log(df[ic50_col].astype(float)) - df[ln_col].astype(float)).dropna()
        print(f"\nmedian |ln(IC50) - LN_IC50| = {comp.abs().median():.4g} (n={comp.size})")
        print(f"mean   ln(IC50) - LN_IC50   = {comp.mean():.4g}")

def available_models(model_dir: str, allowed_drugs=None):
    paths = {}
    if not (model_dir and os.path.isdir(model_dir)):
        return paths
    allow = None if allowed_drugs is None else set(map(str, allowed_drugs))
    for fname in os.listdir(model_dir):
        if not fname.endswith(".joblib"):
            continue
        m = re.search(r"(\d+)\.joblib$", fname)  # grabs trailing digits
        if not m:
            continue
        drug_id = m.group(1)
        if (allow is None) or (drug_id in allow):
            paths[drug_id] = os.path.join(model_dir, fname)
    return paths

def simple_linfit(x, y):
    x = np.asarray(x, float); y = np.asarray(y, float)
    if x.size < 2 or np.var(x, ddof=1) == 0:
        return np.nan, np.nan
    a = np.cov(x, y, ddof=1)[0,1] / np.var(x, ddof=1)
    b = y.mean() - a * x.mean()
    return float(a), float(b)


In [6]:
# Cell 3: bulk unit diagnostics
check_bulk_units(BULK_LONG)


=== Bulk unit check ===
Columns: ['SANGER_MODEL_ID', 'DRUG_ID', 'LN_IC50', 'SIDG00004', 'SIDG00036', 'SIDG00100', 'SIDG00101', 'SIDG00147', 'SIDG00152', 'SIDG00178', 'SIDG00181', 'SIDG00202', 'SIDG00218', 'SIDG00229', 'SIDG00244', 'SIDG00245', 'SIDG00262', 'SIDG00279', 'SIDG00313', 'SIDG00317', 'SIDG00323', 'SIDG00324', 'SIDG00372', 'SIDG00373', 'SIDG00377', 'SIDG00396', 'SIDG00429', 'SIDG00455', 'SIDG00474', 'SIDG00481', 'SIDG00498', 'SIDG00513', 'SIDG00520', 'SIDG00567', 'SIDG00574', 'SIDG00577', 'SIDG00629', 'SIDG00637', 'SIDG00640', 'SIDG00654', 'SIDG00655', 'SIDG00667', 'SIDG00668', 'SIDG00671', 'SIDG00703', 'SIDG00706', 'SIDG00737', 'SIDG00741', 'SIDG00747', 'SIDG00756', 'SIDG00765', 'SIDG00769', 'SIDG00770', 'SIDG00771', 'SIDG00794', 'SIDG00795', 'SIDG00797', 'SIDG00805', 'SIDG00811', 'SIDG00882', 'SIDG00911', 'SIDG00960', 'SIDG01000', 'SIDG01003', 'SIDG01012', 'SIDG01014', 'SIDG01015', 'SIDG01021', 'SIDG01113', 'SIDG01129', 'SIDG01132', 'SIDG01138', 'SIDG01139', 'SIDG01146', 'S

count    8238.000000
mean        3.487623
std         2.180995
min        -5.015997
25%         1.938438
50%         3.830331
75%         5.254912
max         9.159444
Name: LN_IC50, dtype: float64

In [7]:
# Cell 4: global bias diagnostics

# per-cell preds
preds = pd.read_parquet(PRED_PARQUET) if PRED_PARQUET.endswith(".parquet") else pd.read_csv(PRED_PARQUET)
preds["SANGER_MODEL_ID"] = preds["SANGER_MODEL_ID"].astype(str)
preds["drug_id"] = preds["drug_id"].astype(str)

# per-(line,drug) SC mean
sc_means = preds.groupby(["SANGER_MODEL_ID","drug_id"])["y_pred"].mean().rename("sc_mean").reset_index()

# bulk wide
ln_wide = load_bulk_long_to_wide(BULK_LONG)
bulk_long = ln_wide.stack().rename("bulk").reset_index().rename(columns={"_line":"SANGER_MODEL_ID","_drug":"drug_id"})

merged = sc_means.merge(bulk_long, on=["SANGER_MODEL_ID","drug_id"], how="inner")
print("=== Global SC mean vs bulk ===")
print("Merged pairs:", len(merged))
if len(merged) >= 2:
    corr = merged["sc_mean"].corr(merged["bulk"])
    a, b = simple_linfit(merged["sc_mean"].values, merged["bulk"].values)
    print(f"corr(sc_mean, bulk) = {corr:.3f}")
    print(f"global slope (bulk ≈ a*sc_mean + b) = {a:.3f}")
    print(f"global intercept = {b:.3f}")
    display(merged.sample(min(10, len(merged)), random_state=0))
else:
    print("Not enough pairs to compute correlation.")


=== Global SC mean vs bulk ===
Merged pairs: 26
corr(sc_mean, bulk) = 0.816
global slope (bulk ≈ a*sc_mean + b) = 1.258
global intercept = -1.191


Unnamed: 0,SANGER_MODEL_ID,drug_id,sc_mean,bulk
2,SIDM00148,1931,5.517494,6.289278
20,SIDM00928,1526,1.853074,0.975542
14,SIDM00885,1096,2.781226,2.72312
17,SIDM00893,2508,2.001284,3.037925
5,SIDM00630,2038,4.930231,3.2624
11,SIDM00872,1845,1.542293,-2.378115
22,SIDM01037,1096,2.968024,5.241832
13,SIDM00879,427,5.217323,5.450255
18,SIDM00920,2508,1.911496,0.819275
19,SIDM00920,427,5.265457,5.133296


In [8]:
# Cell 5: per-drug calibration (bulk ≈ a*sc_mean + b)

rows = []
for d, sub in merged.groupby("drug_id"):
    if len(sub) >= 3:
        a, b = simple_linfit(sub["sc_mean"].values, sub["bulk"].values)
        corr = sub["sc_mean"].corr(sub["bulk"])
        rows.append({"drug_id": d, "n_lines": int(len(sub)), "slope": a, "intercept": b, "corr": corr})

calib_df = pd.DataFrame(rows).sort_values("corr", ascending=False)
print("=== Per-drug calibration estimates ===")
if calib_df.empty:
    print("No drug had ≥3 lines to fit calibration.")
else:
    print("Top 10 by correlation:")
    display(calib_df.head(10))
    print("Bottom 10 by correlation:")
    display(calib_df.tail(10))


=== Per-drug calibration estimates ===
Top 10 by correlation:


Unnamed: 0,drug_id,n_lines,slope,intercept,corr
2,1931,3,63.932185,-346.61601,0.986918
0,1096,3,6.478394,-14.913033,0.788429
1,1845,3,6.523471,-11.622634,0.504447
3,2508,3,6.379974,-10.118284,0.355459
4,427,6,2.740826,-9.409872,0.138194


Bottom 10 by correlation:


Unnamed: 0,drug_id,n_lines,slope,intercept,corr
2,1931,3,63.932185,-346.61601,0.986918
0,1096,3,6.478394,-14.913033,0.788429
1,1845,3,6.523471,-11.622634,0.504447
3,2508,3,6.379974,-10.118284,0.355459
4,427,6,2.740826,-9.409872,0.138194


In [9]:
# Cell 6: missing genes per model (defensive check)

# load sc just to get gene columns; tolerate 'cell_id' if present
sc = pd.read_parquet(SC_PARQUET) if SC_PARQUET.endswith(".parquet") else pd.read_csv(SC_PARQUET, index_col=0)
sc_cols = [c for c in sc.columns if c != "cell_id"]
sc_set = set(map(str, sc_cols))

# discover bundles for drugs present in preds
needed = sorted(preds["drug_id"].unique())
model_paths = available_models("models", allowed_drugs=needed)

miss_rows = []
for d, p in model_paths.items():
    try:
        b = joblib.load(p)
        want = set(map(str, b["gene_cols"]))
        miss = want - sc_set
        miss_rows.append({
            "drug_id": d,
            "n_model_features": len(want),
            "n_missing": len(miss),
            "frac_missing": (len(miss) / len(want)) if len(want) else np.nan
        })
    except Exception as e:
        miss_rows.append({"drug_id": d, "n_model_features": np.nan, "n_missing": np.nan, "frac_missing": np.nan})
        print(f"[warn] skip {d}: {e}")

miss_df = pd.DataFrame(miss_rows).sort_values("frac_missing", ascending=False)
print("=== Missing genes per model ===")
display(miss_df.head(10))
if not miss_df.empty:
    print("Max frac_missing:", miss_df["frac_missing"].max())


=== Missing genes per model ===


Unnamed: 0,drug_id,n_model_features,n_missing,frac_missing
0,2508,2044,0,0.0
1,2540,2044,0,0.0
2,1931,2044,0,0.0
3,2515,2044,0,0.0
4,1096,2044,0,0.0
5,1845,2044,0,0.0
6,1089,2044,0,0.0
7,427,2044,0,0.0
8,1526,2044,0,0.0
9,2038,2044,0,0.0


Max frac_missing: 0.0


In [10]:
# Cell 7 (optional): preview agreement improvement after calibration

def percentile_of_value(arr: np.ndarray, value: float) -> float:
    arr = np.asarray(arr, float)
    if arr.size == 0 or np.isnan(value):
        return np.nan
    lt = np.sum(arr < value)
    eq = np.sum(arr == value)
    return 100.0 * (lt + 0.5*eq) / arr.size

def label_agreement(pct, agree=(25,75), disc=(10,90)):
    if np.isnan(pct): return "NA"
    if agree[0] <= pct <= agree[1]: return "agree"
    if pct < disc[0] or pct > disc[1]: return "discordant"
    return "borderline"

# before
rows0 = []
for (line, drug), sub in preds.groupby(["SANGER_MODEL_ID","drug_id"]):
    if (line in ln_wide.index) and (drug in ln_wide.columns):
        pct = percentile_of_value(sub["y_pred"].values, ln_wide.loc[line, drug])
        rows0.append(label_agreement(pct))
before_agree = (np.array(rows0) == "agree").mean() if rows0 else np.nan

# after (only for drugs with fitted calibration)
preds_cal = preds.copy()
if not calib_df.empty:
    preds_cal["y_pred_cal"] = preds_cal["y_pred"].astype(float)
    for d, row in calib_df.set_index("drug_id").iterrows():
        m = preds_cal["drug_id"].eq(d)
        preds_cal.loc[m, "y_pred_cal"] = row["slope"] * preds_cal.loc[m, "y_pred_cal"] + row["intercept"]

    rows1 = []
    for (line, drug), sub in preds_cal.groupby(["SANGER_MODEL_ID","drug_id"]):
        if (line in ln_wide.index) and (drug in ln_wide.columns):
            val = ln_wide.loc[line, drug]
            y = sub["y_pred_cal"].values if "y_pred_cal" in sub else sub["y_pred"].values
            pct = percentile_of_value(y, val)
            rows1.append(label_agreement(pct))
    after_agree = (np.array(rows1) == "agree").mean() if rows1 else np.nan
else:
    after_agree = np.nan

print("=== Agreement preview (bulk within per-cell [25,75]) ===")
print(f"before calibration: {before_agree:.3f}")
print(f"after  calibration: {after_agree:.3f}   (only calibrated drugs included)")


=== Agreement preview (bulk within per-cell [25,75]) ===
before calibration: 0.000
after  calibration: 0.154   (only calibrated drugs included)
