In [11]:
# Paths (relative paths; resolved against the notebook's working directory)
# Input: single-cell embedding table, read by load_sc_df().
SC_PARQUET       = "../../data/filtered_datasets/breast_cancer_embeddings.parquet"  # SANGER_MODEL_ID + emb_###
# Input: shortlist CSV; read below with columns SANGER_MODEL_ID, low_drug, high_drug.
SHORTLIST_CSV    = "out/shortlist.csv"
# Input: directory scanned by available_models() for *.joblib model bundles.
MODEL_DIR        = "embeddings"   # where you saved the bulk models
# Output: per-cell predictions table (cell_id, SANGER_MODEL_ID, drug_id, y_pred).
OUT_PRED_PARQUET = "out/per_cell_predictions.parquet"
# Output: one summary row per (SANGER_MODEL_ID, drug_id) pair.
OUT_SUMMARY_CSV  = "out/summaries.csv"

import os, re, joblib, numpy as np, pandas as pd
from typing import Dict, Sequence, Optional, List
# Create the parent directory of every output file, derived from the OUT_*
# constants rather than a hard-coded "out", so editing a path above cannot
# leave the write cell pointing at a folder that does not exist.
for _out_path in (OUT_PRED_PARQUET, OUT_SUMMARY_CSV):
    os.makedirs(os.path.dirname(_out_path) or ".", exist_ok=True)

# Seed NumPy's legacy global RNG so any code that draws from it is repeatable.
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)


In [12]:
def load_sc_df(sc_parquet: str) -> pd.DataFrame:
    """
    Read the single-cell embedding table and normalise its layout.

    The returned frame has:
      index   = SANGER_MODEL_ID (one entry per cell, so labels repeat),
      columns = 'cell_id' first (built as '<line>__<k>' when the file has
                none), then every column named emb_<digits>, coerced to
                numeric (values that cannot be parsed become NaN).
    All other columns are dropped.
    """
    line_key = "SANGER_MODEL_ID"
    table = pd.read_parquet(sc_parquet)

    # Promote the cell-line id to the index. When it is not a column, the
    # existing index is kept as-is and only (re)labelled with that name.
    if line_key in table.columns:
        table = table.set_index(line_key)
    elif table.index.name != line_key:
        table.index.name = line_key

    # Build a per-cell id: "<line>__<running count within that line>"
    if "cell_id" not in table.columns:
        within_line = table.groupby(level=0).cumcount().astype(str)
        table.insert(0, "cell_id", table.index.astype(str) + "__" + within_line)

    # Restrict to cell_id + embedding features, cell_id first
    feature_names = [name for name in table.columns if re.fullmatch(r"emb_\d+", str(name))]
    table = table[["cell_id", *feature_names]].copy()

    # Coerce every embedding column to a numeric dtype
    for name in feature_names:
        table[name] = pd.to_numeric(table[name], errors="coerce")

    return table


def available_models(model_dir: str, allowed_drugs: Optional[Sequence[str]] = None) -> Dict[str, str]:
    """
    Find model bundles by drug id inside `model_dir`.

    Only file names ending in ".joblib" are considered. The drug id is the
    digits following "drug" in the name or, failing that, the digits
    immediately before the ".joblib" suffix, e.g.:
      elasticnet_drug427_voomEmb.joblib
      drug427.joblib
      427.joblib

    Parameters
    ----------
    model_dir : directory to scan; an empty value or a path that is not a
        directory yields an empty result.
    allowed_drugs : optional collection of drug ids (compared as strings);
        when given, only those ids are returned. None means "no filter".

    Returns
    -------
    {drug_id_str: path}. If several files resolve to the same drug id, the
    one whose file name sorts last is kept.
    """
    paths: Dict[str, str] = {}
    if not (model_dir and os.path.isdir(model_dir)):
        return paths
    allow = None if allowed_drugs is None else {str(d) for d in allowed_drugs}

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # winner among duplicate drug ids deterministic across runs and machines.
    for fname in sorted(os.listdir(model_dir)):
        if not fname.endswith(".joblib"):
            continue
        m = (re.search(r"drug(\d+)", fname) or
             re.search(r"(\d+)(?=\.joblib$)", fname))
        if not m:
            continue
        drug_id = m.group(1)
        if allow is None or drug_id in allow:
            paths[drug_id] = os.path.join(model_dir, fname)
    return paths


def load_bundle(path: str) -> dict:
    """
    Load a trained model bundle from disk and normalize it to:
        {"pipeline": <fitted-pipeline>, "feature_cols": [..]}

    Supported on-disk formats:
      - dict with "pipeline" plus feature names under "feature_cols",
        "gene_cols" or "emb_cols" (first non-empty one wins, in that order)
      - dict with "model", "scaler", "gene_cols"   # legacy (no imputer);
        wrapped here as a two-step scale -> model Pipeline
      - a bare estimator/pipeline object exposing `feature_names_in_`

    Feature names may be stored as a list, NumPy array or pandas Index; they
    are always returned as a list of str.

    Raises
    ------
    KeyError
        If the format is not recognised or no feature names can be found.
    """
    import joblib
    from sklearn.pipeline import Pipeline

    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load bundles from a trusted location.
    b = joblib.load(path)

    # Case 1: already has a pipeline
    if isinstance(b, dict) and "pipeline" in b:
        feats = None
        for key in ("feature_cols", "gene_cols", "emb_cols"):
            candidate = b.get(key)
            # Explicit None/len checks instead of `a or b`: truth-testing a
            # NumPy array or pandas Index raises ValueError.
            if candidate is not None and len(candidate) > 0:
                feats = candidate
                break
        if feats is None:
            raise KeyError(f"Bundle {path} has 'pipeline' but no feature names "
                           "(expected 'feature_cols', 'gene_cols' or 'emb_cols').")
        return {"pipeline": b["pipeline"], "feature_cols": [str(c) for c in feats]}

    # Case 2: legacy dict with model+scaler(+genes)
    if isinstance(b, dict) and {"model", "scaler", "gene_cols"}.issubset(b.keys()):
        pipe = Pipeline([
            ("scale", b["scaler"]),
            ("model", b["model"]),
        ])
        return {"pipeline": pipe, "feature_cols": [str(c) for c in b["gene_cols"]]}

    # Case 3: someone saved the estimator/pipeline object directly
    if hasattr(b, "predict") and hasattr(b, "fit"):
        # Feature names are still needed to align columns; the only source on
        # a bare object is the `feature_names_in_` attribute, if it has one.
        feats = getattr(b, "feature_names_in_", None)
        if feats is not None:
            return {"pipeline": b, "feature_cols": [str(c) for c in feats]}
        raise KeyError(f"Bundle {path} is a model/pipeline object without saved feature names.")

    raise KeyError(f"Unrecognized bundle format at {path}.")



def align_features(frame: pd.DataFrame, needed_cols: Sequence[str], fill_missing: float = 0.0) -> pd.DataFrame:
    """
    Return a copy of 'frame' holding exactly the columns in 'needed_cols',
    in that order.

    Columns of 'frame' not listed in 'needed_cols' are dropped; listed columns
    that 'frame' lacks are created and filled with 'fill_missing'.
    The input frame is not modified.
    """
    target_cols = [str(name) for name in needed_cols]

    # Split the requested names into those the frame has and those it lacks
    found: list = []
    absent: list = []
    for name in target_cols:
        (found if name in frame.columns else absent).append(name)

    aligned = frame.reindex(columns=found).copy()
    if absent:
        filler = pd.DataFrame(fill_missing, index=frame.index, columns=absent)
        aligned = pd.concat([aligned, filler], axis=1)

    # Final selection restores the requested column order
    return aligned[target_cols]


def summarize_vector(y: np.ndarray) -> dict:
    """
    Summarise a vector of values as a dict with keys
    n, mean, median, sd, q10, q90.

    'sd' is the sample standard deviation (ddof=1) and is NaN when fewer than
    two values are given. For empty input, n is 0 and every statistic is NaN.
    """
    values = np.asarray(y, dtype=float)
    count = values.size

    if count == 0:
        return {
            "n": 0,
            "mean": np.nan,
            "median": np.nan,
            "sd": np.nan,
            "q10": np.nan,
            "q90": np.nan,
        }

    # A sample standard deviation needs at least two observations
    spread = float(np.std(values, ddof=1)) if count > 1 else np.nan
    low, high = np.quantile(values, [0.10, 0.90])

    return {
        "n": int(count),
        "mean": float(np.mean(values)),
        "median": float(np.median(values)),
        "sd": spread,
        "q10": float(low),
        "q90": float(high),
    }


In [13]:
# Single-cell embeddings (rows=cells; index=SANGER_MODEL_ID)
sc = load_sc_df(SC_PARQUET)

# Shortlist with columns: SANGER_MODEL_ID, low_drug, high_drug
shortlist = pd.read_csv(
    SHORTLIST_CSV,
    dtype={"SANGER_MODEL_ID": str, "low_drug": str, "high_drug": str}
)

# What drugs do we need models for?
# Blank CSV cells are still read as NaN even with dtype=str. Drop them so the
# set holds only real drug ids; a float NaN mixed with strings would make
# sorted() below raise TypeError and would inflate the requested-drug count.
needed_drugs = set(pd.concat([shortlist["low_drug"], shortlist["high_drug"]]).dropna())
model_paths = available_models(MODEL_DIR, allowed_drugs=sorted(needed_drugs))
print(f"Found {len(model_paths)} model(s) for {len(needed_drugs)} requested drug(s).")

# Name the drugs that have no model so a count mismatch is actionable.
missing_drugs = sorted(needed_drugs - set(model_paths))
if missing_drugs:
    print(f"[WARN] No model file found for drug(s): {', '.join(missing_drugs)}")


Found 10 model(s) for 10 requested drug(s).


In [14]:
pred_rows: List[pd.DataFrame] = []
summary_rows: List[dict] = []

# Lookup from line -> positional row indices of its cells, built in a single
# pass over the index (instead of one full scan per cell line).
line_to_idx = sc.groupby(level=0).indices

# Loop-invariant: every column except the id is an embedding feature.
emb_cols_in_sc = [c for c in sc.columns if c != "cell_id"]
emb_cols_set = set(emb_cols_in_sc)

# Each model bundle is read from disk at most once, however many shortlist
# rows reference the same drug.
bundle_cache: Dict[str, dict] = {}

for _, row in shortlist.iterrows():
    line = str(row["SANGER_MODEL_ID"])
    idx = line_to_idx.get(line)
    if idx is None or len(idx) == 0:
        print(f"[WARN] No cells for line {line}; skipping both drugs.")
        continue

    sc_block = sc.iloc[idx, :]
    cell_ids = sc_block["cell_id"].values
    emb_block = sc_block[emb_cols_in_sc]

    for drug_col in ("low_drug", "high_drug"):
        drug_id = str(row[drug_col])

        # skip if no model
        if drug_id not in model_paths:
            print(f"[WARN] Missing model for drug {drug_id}; skipping.")
            continue

        # load pipeline bundle (cached) & align features
        if drug_id not in bundle_cache:
            bundle_cache[drug_id] = load_bundle(model_paths[drug_id])
            # Missing features are zero-filled by align_features; say so once
            # per drug so a feature-name mismatch does not pass unnoticed.
            absent = [c for c in bundle_cache[drug_id]["feature_cols"] if c not in emb_cols_set]
            if absent:
                print(f"[WARN] Model for drug {drug_id} expects {len(absent)} feature(s) "
                      "absent from the single-cell data; they are filled with 0.0.")
        bundle = bundle_cache[drug_id]
        pipe = bundle["pipeline"]
        needed = bundle["feature_cols"]   # embedding feature names used in training

        X = align_features(emb_block, needed, fill_missing=0.0)

        # Predict directly with the pipeline; all preprocessing is expected to
        # live inside it, so the raw aligned matrix is passed as-is.
        y_pred = pipe.predict(X.values)

        # collect per-cell predictions
        pred_rows.append(pd.DataFrame({
            "cell_id": cell_ids,
            "SANGER_MODEL_ID": line,
            "drug_id": drug_id,
            "y_pred": y_pred,
        }))

        # (line, drug) summary
        stats = summarize_vector(y_pred)
        stats.update({"SANGER_MODEL_ID": line, "drug_id": drug_id})
        summary_rows.append(stats)

# Assemble outputs (empty frames keep the expected columns when nothing ran)
preds = pd.concat(pred_rows, axis=0, ignore_index=True) if pred_rows else pd.DataFrame(
    columns=["cell_id", "SANGER_MODEL_ID", "drug_id", "y_pred"]
)
summaries = pd.DataFrame(summary_rows) if summary_rows else pd.DataFrame(
    columns=["SANGER_MODEL_ID", "drug_id", "n", "mean", "median", "sd", "q10", "q90"]
)

preds.head(), summaries.head()


(        cell_id SANGER_MODEL_ID drug_id     y_pred
 0  SIDM00872__0       SIDM00872    1845  14.828639
 1  SIDM00872__1       SIDM00872    1845  10.983622
 2  SIDM00872__2       SIDM00872    1845   8.991213
 3  SIDM00872__3       SIDM00872    1845  17.720463
 4  SIDM00872__4       SIDM00872    1845  18.165287,
       n       mean     median        sd        q10        q90 SANGER_MODEL_ID  \
 0  1612  15.045147  14.885859  3.430244  10.956575  19.193967       SIDM00872   
 1  1612  19.455275  19.552840  1.848168  16.966412  21.716185       SIDM00872   
 2   558  14.471709  14.217988  2.752169  11.447732  17.489128       SIDM01037   
 3   558  16.862919  16.851243  2.874423  13.277205  20.609584       SIDM01037   
 4  1279  -8.559445  -8.752190  1.861862 -10.815838  -5.975373       SIDM00866   
 
   drug_id  
 0    1845  
 1    1089  
 2    1845  
 3    1096  
 4    1526  )

In [15]:
# Persist per-cell predictions (nothing is written when the table is empty)
if len(preds) == 0:
    print("No predictions generated.")
else:
    preds.to_parquet(OUT_PRED_PARQUET, index=False)
    print(f"📝 Wrote per-cell predictions → {OUT_PRED_PARQUET} ({len(preds)} rows)")

# Persist per-(line, drug) summaries with identifiers first, statistics after
if len(summaries) == 0:
    print("No summary rows generated.")
else:
    summary_columns = ["SANGER_MODEL_ID", "drug_id", "n", "mean", "median", "sd", "q10", "q90"]
    summaries = summaries[summary_columns]
    summaries.to_csv(OUT_SUMMARY_CSV, index=False)
    print(f"📝 Wrote summaries → {OUT_SUMMARY_CSV}")


📝 Wrote per-cell predictions → out/per_cell_predictions.parquet (24132 rows)
📝 Wrote summaries → out/summaries.csv
