In [1]:
# Cell 1: config
from __future__ import annotations
import os
import numpy as np
import pandas as pd
import joblib
import re
from typing import Dict, Sequence, Optional, List

# paths
SC_PARQUET     = "../../data/filtered_datasets/sc_overlap_genes.parquet"
SHORTLIST_CSV  = "out/shortlist.csv"   # <- update if you saved it elsewhere
MODEL_DIR      = "models"

# outputs
OUT_PRED_PARQUET = "out/per_cell_predictions.parquet"
OUT_SUMMARY_CSV  = "out/summaries.csv"
os.makedirs("out", exist_ok=True)

# optional calibration table (columns: drug_id,slope,intercept) — set to None if not used
CALIBRATION_CSV = None  # e.g., "data/calibration_coeffs.csv"

RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)

# optional calibration table (columns: drug_id,slope,intercept)
CALIBRATION_CSV = "out/calibration_coeffs.csv"



In [2]:
# Cell 2: helpers

def load_sc_df(sc_parquet: str) -> pd.DataFrame:
    """
    Load single-cell matrix. Rows = cells, index = SANGER_MODEL_ID (repeated per cell).
    Returns df with numeric gene columns; adds a synthetic 'cell_id' column.
    """
    df = pd.read_parquet(sc_parquet)
    if df.index.name != "SANGER_MODEL_ID":
        # if needed, rename; but your printout shows it's already SANGER_MODEL_ID
        df.index.name = "SANGER_MODEL_ID"

    # generate a unique cell_id since the index repeats
    seq = df.groupby(level=0).cumcount().astype(str)
    df.insert(0, "cell_id", df.index.astype(str) + "__" + seq)

    # ensure gene columns are numeric
    for c in df.columns:
        if c == "cell_id":
            continue
        df[c] = pd.to_numeric(df[c], errors="coerce")
    return df

def available_models(model_dir: str, allowed_drugs: Optional[Sequence[str]] = None) -> Dict[str, str]:
    """
    Find bundles like elasticnet_drug1845.joblib, drug427.joblib, 1526.joblib, etc.
    Returns {drug_id_str: path}
    """
    paths: Dict[str, str] = {}
    if not (model_dir and os.path.isdir(model_dir)):
        return paths
    allow = None if allowed_drugs is None else set(map(str, allowed_drugs))
    for fname in os.listdir(model_dir):
        if not fname.endswith(".joblib"):
            continue
        m = re.search(r'(\d+)\.joblib$', fname)
        if not m:
            continue
        drug_id = m.group(1)
        if (allow is None) or (drug_id in allow):
            paths[drug_id] = os.path.join(model_dir, fname)
    return paths

def load_bundle(path: str) -> dict:
    bundle = joblib.load(path)
    for k in ("model", "scaler", "gene_cols"):
        if k not in bundle:
            raise KeyError(f"Bundle at {path} missing key: {k}")
    bundle["gene_cols"] = [str(g) for g in bundle["gene_cols"]]
    return bundle

def align_features(sc_block: pd.DataFrame, gene_cols: Sequence[str], fill_missing: float = 0.0) -> pd.DataFrame:
    """
    Return sc_block with columns exactly in gene_cols order.
    Missing genes are added with constant fill; extras are dropped.
    """
    gene_cols = [str(g) for g in gene_cols]
    present = [g for g in gene_cols if g in sc_block.columns]
    missing = [g for g in gene_cols if g not in sc_block.columns]
    X = sc_block.reindex(columns=present).copy()
    if missing:
        X = pd.concat([X, pd.DataFrame(fill_missing, index=sc_block.index, columns=missing)], axis=1)
    X = X.loc[:, gene_cols]
    return X

def summarize_vector(y: np.ndarray) -> dict:
    y = np.asarray(y, dtype=float)
    if y.size == 0:
        return dict(n=0, mean=np.nan, median=np.nan, sd=np.nan, q10=np.nan, q90=np.nan)
    return {
        "n": int(y.size),
        "mean": float(np.mean(y)),
        "median": float(np.median(y)),
        "sd": float(np.std(y, ddof=1)) if y.size > 1 else np.nan,
        "q10": float(np.quantile(y, 0.10)),
        "q90": float(np.quantile(y, 0.90)),
    }


In [3]:
# Cell 3: load sc + shortlist + (optional) calibration

sc = load_sc_df(SC_PARQUET)  # has columns: cell_id, genes..., index=SANGER_MODEL_ID
shortlist = pd.read_csv(SHORTLIST_CSV, dtype={
    "SANGER_MODEL_ID": str,
    "low_drug": str,
    "high_drug": str
})

calib = None
if CALIBRATION_CSV and os.path.exists(CALIBRATION_CSV):
    calib = pd.read_csv(CALIBRATION_CSV, dtype={"drug_id": str}).set_index("drug_id")
    # sanity
    for col in ("slope", "intercept"):
        if col not in calib.columns:
            raise KeyError(f"Calibration CSV missing column: {col}")


In [4]:
# %% 
# Cell 4: prediction loop (with optional per-drug calibration)

# discover available models for the drugs in the shortlist
needed_drugs = set(shortlist["low_drug"].astype(str)).union(set(shortlist["high_drug"].astype(str)))
model_paths = available_models(MODEL_DIR, allowed_drugs=sorted(needed_drugs))
if not model_paths:
    raise ValueError("No matching model bundles found for drugs in shortlist.")

pred_rows: List[pd.DataFrame] = []
summary_rows: List[dict] = []

# speed: precompute line -> row indices (sc.index is SANGER_MODEL_ID)
line_to_idx = {line: np.where(sc.index.values == line)[0] for line in sc.index.unique()}

# tiny helper for calibration lookup
def apply_calibration(drug_id: str, y: np.ndarray) -> np.ndarray:
    if calib is None: 
        return y
    if drug_id in calib.index:
        slope = float(calib.loc[drug_id, "slope"])
        intercept = float(calib.loc[drug_id, "intercept"])
        return slope * y + intercept
    return y  # fallback to uncalibrated if missing

print(f"Running predictions for {len(shortlist)} lines × 2 drugs (low/high)…")
for _, row in shortlist.iterrows():
    line = str(row["SANGER_MODEL_ID"])
    for drug_col in ("low_drug", "high_drug"):
        drug_id = str(row[drug_col])

        # skip if model missing
        if drug_id not in model_paths:
            print(f"[WARN] Missing model for drug {drug_id}; skipping ({line}).")
            continue

        # subset cells for this line
        idx = line_to_idx.get(line, None)
        if idx is None or len(idx) == 0:
            print(f"[WARN] No cells found for line {line}; skipping.")
            continue
        sc_block = sc.iloc[idx, :]

        # gene-only matrix (drop metadata like 'cell_id')
        gene_cols_in_sc = [c for c in sc_block.columns if c != "cell_id"]

        # load model bundle & align features
        bundle = load_bundle(model_paths[drug_id])
        X_df = align_features(sc_block[gene_cols_in_sc], bundle["gene_cols"], fill_missing=0.0)
        X = X_df.values

        # transform & predict
        Xs = bundle["scaler"].transform(X)
        y_pred = bundle["model"].predict(Xs)

        # optional per-drug linear calibration
        y_pred = apply_calibration(drug_id, y_pred)

        # collect per-cell predictions
        pred_rows.append(pd.DataFrame({
            "cell_id": sc_block["cell_id"].values,
            "SANGER_MODEL_ID": line,
            "drug_id": drug_id,
            "y_pred": y_pred.astype(float),
        }))

        # per (line, drug) summary
        stats = summarize_vector(y_pred)
        stats.update({"SANGER_MODEL_ID": line, "drug_id": drug_id})
        summary_rows.append(stats)

# concatenate outputs
preds = (pd.concat(pred_rows, axis=0, ignore_index=True)
         if pred_rows else pd.DataFrame(columns=["cell_id","SANGER_MODEL_ID","drug_id","y_pred"]))
summaries = (pd.DataFrame(summary_rows)
             if summary_rows else pd.DataFrame(columns=["SANGER_MODEL_ID","drug_id","n","mean","median","sd","q10","q90"]))

print(f"Made {len(preds):,} per-cell predictions across {preds[['SANGER_MODEL_ID','drug_id']].drop_duplicates().shape[0]} (line,drug) pairs.")
preds.head(), summaries.head()


Running predictions for 13 lines × 2 drugs (low/high)…
Made 24,350 per-cell predictions across 26 (line,drug) pairs.


(        cell_id SANGER_MODEL_ID drug_id    y_pred
 0  SIDM00872__0       SIDM00872    1845  1.023117
 1  SIDM00872__1       SIDM00872    1845  0.769557
 2  SIDM00872__2       SIDM00872    1845  0.277932
 3  SIDM00872__3       SIDM00872    1845  0.719935
 4  SIDM00872__4       SIDM00872    1845  0.826301,
       n      mean    median        sd       q10       q90 SANGER_MODEL_ID  \
 0  1623  0.882667  0.884405  0.302596  0.497564  1.263792       SIDM00872   
 1  1623  4.158027  4.152576  0.229049  3.885130  4.438769       SIDM00872   
 2   568  1.402018  1.370122  0.363885  0.972862  1.888719       SIDM01037   
 3   568  3.256128  3.247564  0.259673  2.927773  3.594750       SIDM01037   
 4  1280  1.506137  1.501434  0.173338  1.288101  1.724522       SIDM00866   
 
   drug_id  
 0    1845  
 1    1089  
 2    1845  
 3    1096  
 4    1526  )

In [5]:
# Cell 5: save
if len(preds):
    preds.to_parquet(OUT_PRED_PARQUET, index=False)
    print(f"Wrote per-cell predictions → {OUT_PRED_PARQUET} ({len(preds)} rows)")
else:
    print("No predictions generated.")

if len(summaries):
    summaries = summaries[["SANGER_MODEL_ID", "drug_id", "n", "mean", "median", "sd", "q10", "q90"]]
    summaries.to_csv(OUT_SUMMARY_CSV, index=False)
    print(f"Wrote summaries → {OUT_SUMMARY_CSV}")
else:
    print("No summary rows generated.")


Wrote per-cell predictions → out/per_cell_predictions.parquet (24350 rows)
Wrote summaries → out/summaries.csv
