In [5]:
import os
import numpy as np
import pandas as pd
import joblib

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, root_mean_squared_error
from scipy.stats import pearsonr
from sklearn.mixture import GaussianMixture


In [6]:
# Paths (adjust if needed)
MODELS_DIR       = "../trained_models"                                # elasticnet_drug{ID}.joblib
SC_EXPR_PARQUET  = "../../data/filtered_datasets/sc_overlap_genes.parquet"    # cells × genes (SIDG columns)
SC_META_CSV      = "../metadata/sc_cell_to_sanger.csv"                # columns: cell_id, SANGER_MODEL_ID (optional, see fallback)
GDSC_PARQUET     = "../../../../bulk_state_of_the_art/data/processed/gdsc_final_cleaned.parquet"
CALIB_DIR        = "../singlecell_validation_outputs"                 # where calibration_drug{ID}.txt might be
OUTDIR           = "../singlecell_per_line_results"
os.makedirs(OUTDIR, exist_ok=True)

BEST_DRUGS = [1845, 2540, 2038, 2508, 1096, 1931, 2515, 1089, 427, 1526]  # or list(models present)


In [7]:
# <<< SET THIS >>>
SANGER_ID = "XXXX"   # e.g., "SIDM00438" or your SANGER_MODEL_ID string


In [8]:
# Load sc expression (cells × SIDG genes)
sc_expr = pd.read_parquet(SC_EXPR_PARQUET)
print("sc_expr:", sc_expr.shape)

def load_sc_to_sanger(index_like, meta_csv=SC_META_CSV):
    if meta_csv and os.path.exists(meta_csv):
        meta = pd.read_csv(meta_csv).set_index("cell_id")
        assert "SANGER_MODEL_ID" in meta.columns
        series = meta.loc[meta.index.intersection(index_like), "SANGER_MODEL_ID"].astype(str)
        # Reindex to sc_expr in case order differs
        return series.reindex(index_like)
    else:
        # Fallback: pull from AnnData obs if available
        import scanpy as sc
        H5AD = "breast_cancer_raw_annotated.h5ad"  # adjust if needed
        adata = sc.read_h5ad(H5AD)
        obs = adata.obs
        assert "SANGER_MODEL_ID" in obs.columns, "SANGER_MODEL_ID not found in obs; provide SC_META_CSV"
        m = obs["SANGER_MODEL_ID"].astype(str)
        m = m.reindex(sc_expr.index)
        return m

sc_to_sanger = load_sc_to_sanger(sc_expr.index)
print("Mapped cells:", sc_to_sanger.notna().sum(), "/", len(sc_to_sanger))


sc_expr: (35276, 545)


FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'breast_cancer_raw_annotated.h5ad', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
cells_of_line = sc_to_sanger[sc_to_sanger == SANGER_ID].index
assert len(cells_of_line) > 0, f"No cells found for SANGER_MODEL_ID={SANGER_ID}"
sc_line = sc_expr.loc[cells_of_line]
print(f"Cells for {SANGER_ID}:", sc_line.shape)


In [None]:
gdsc = pd.read_parquet(GDSC_PARQUET)
gdsc = gdsc.dropna(subset=["SANGER_MODEL_ID","DRUG_ID","LN_IC50"]).copy()
gdsc["SANGER_MODEL_ID"] = gdsc["SANGER_MODEL_ID"].astype(str)

gdsc_line = gdsc[(gdsc["SANGER_MODEL_ID"] == SANGER_ID) & (gdsc["DRUG_ID"].isin(BEST_DRUGS))][["DRUG_ID","LN_IC50"]]
gdsc_line = gdsc_line.drop_duplicates(subset=["DRUG_ID"])  # safety
print("GDSC rows for this line:", gdsc_line.shape[0])


In [None]:
def load_calibration(drug_id, calib_dir=CALIB_DIR):
    """
    Returns slope, intercept if a calibration file exists; else identity mapping.
    File format: two lines 'slope\t<...>' and 'intercept\t<...>'.
    """
    path = os.path.join(calib_dir, f"calibration_drug{drug_id}.txt")
    if not os.path.exists(path):
        return 1.0, 0.0
    slope, intercept = 1.0, 0.0
    with open(path, "r") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 2:
                k, v = parts
                if k == "slope": slope = float(v)
                if k == "intercept": intercept = float(v)
    return slope, intercept

def heterogeneity_metrics(y_cell, random_state=42):
    """
    Fit GMM with k=1..3 to per-cell predictions; return chosen k by BIC,
    fraction of highest-mean component, and separation (max_mean - min_mean).
    """
    y = np.asarray(y_cell).reshape(-1,1)
    best_k, best_bic, best_gmm = 1, np.inf, None
    for k in (1,2,3):
        g = GaussianMixture(n_components=k, random_state=random_state).fit(y)
        bic = g.bic(y)
        if bic < best_bic:
            best_k, best_bic, best_gmm = k, bic, g
    means = best_gmm.means_.flatten()
    weights = best_gmm.weights_.flatten()
    hi = np.argmax(means)
    frac_resistant = float(weights[hi])
    delta_means = float(np.max(means) - np.min(means)) if best_k > 1 else 0.0
    return best_k, frac_resistant, delta_means


In [None]:
summ_rows = []
percell_store = []  # optional: store per-cell predictions (can be large)

for drug_id in BEST_DRUGS:
    bundle_path = os.path.join(MODELS_DIR, f"elasticnet_drug{drug_id}.joblib")
    if not os.path.exists(bundle_path):
        print(f"Missing model for drug {drug_id}, skipping.")
        continue

    # Load model bundle
    bundle = joblib.load(bundle_path)
    model  = bundle["model"]
    scaler = bundle["scaler"]
    gene_cols = bundle["gene_cols"]  # expected SIDG column order

    # Align sc features to model's gene order; fill missing genes with zeros
    present = [g for g in gene_cols if g in sc_line.columns]
    if len(present) == 0:
        print(f"Drug {drug_id}: no overlapping genes — skipping.")
        continue

    X = np.zeros((sc_line.shape[0], len(gene_cols)), dtype=float)
    col_index = {g:i for i,g in enumerate(sc_line.columns)}
    for j, g in enumerate(gene_cols):
        if g in col_index:
            X[:, j] = sc_line.iloc[:, col_index[g]].values

    # Scale and predict
    X_scaled = scaler.transform(X)
    y_pred_cell = model.predict(X_scaled)

    # Calibrate to bulk scale (if calibration exists)
    slope, intercept = load_calibration(drug_id)
    y_pred_cell_cal = y_pred_cell * slope + intercept

    # Summaries for this cell line & drug
    n_cells = len(y_pred_cell_cal)
    mean_ = float(np.mean(y_pred_cell_cal))
    med_  = float(np.median(y_pred_cell_cal))
    q10_  = float(np.quantile(y_pred_cell_cal, 0.10))
    q90_  = float(np.quantile(y_pred_cell_cal, 0.90))
    k_, frac_res_, dmu_ = heterogeneity_metrics(y_pred_cell_cal)

    # Bulk truth for this line (if present)
    ln_ic50_bulk = gdsc_line.loc[gdsc_line["DRUG_ID"] == drug_id, "LN_IC50"]
    ln_ic50_bulk = float(ln_ic50_bulk.values[0]) if len(ln_ic50_bulk) else np.nan
    delta = mean_ - ln_ic50_bulk if not np.isnan(ln_ic50_bulk) else np.nan

    summ_rows.append({
        "SANGER_MODEL_ID": SANGER_ID,
        "DRUG_ID": drug_id,
        "n_cells": n_cells,
        "pred_mean_cal": mean_,
        "pred_median_cal": med_,
        "pred_q10_cal": q10_,
        "pred_q90_cal": q90_,
        "hetero_k": k_,
        "hetero_frac_resistant": frac_res_,
        "hetero_delta_means": dmu_,
        "bulk_LN_IC50": ln_ic50_bulk,
        "mean_minus_bulk": delta,
        "calib_slope": slope,
        "calib_intercept": intercept
    })

    # (Optional) keep per-cell values (comment out if large)
    percell_store.append(pd.DataFrame({
        "cell_id": sc_line.index,
        "SANGER_MODEL_ID": SANGER_ID,
        "DRUG_ID": drug_id,
        "pred_cell_cal": y_pred_cell_cal
    }))


In [None]:
summary = pd.DataFrame(summ_rows)

# Lower LN_IC50 = more sensitive
summary_sorted = summary.sort_values("pred_mean_cal", ascending=True).reset_index(drop=True)

print("Top predicted sensitive drugs for", SANGER_ID)
display(summary_sorted[["DRUG_ID","pred_mean_cal","pred_q10_cal","pred_q90_cal","bulk_LN_IC50","mean_minus_bulk"]])

# Save
sum_path = os.path.join(OUTDIR, f"{SANGER_ID}_sc_predictions_vs_bulk.csv")
summary_sorted.to_csv(sum_path, index=False)
print("Saved summary:", sum_path)

# (Optional) save per-cell predictions
if len(percell_store):
    percell_df = pd.concat(percell_store, ignore_index=True)
    pc_path = os.path.join(OUTDIR, f"{SANGER_ID}_percell_predictions.csv")
    percell_df.to_csv(pc_path, index=False)
    print("Saved per-cell predictions:", pc_path)


In [None]:
import matplotlib.pyplot as plt

top_n = 10  # show all selected drugs
plot_df = summary_sorted.head(top_n).copy()
plt.figure(figsize=(8, 4 + 0.25*top_n))

# error bars from q10 to q90 around mean
ypos = np.arange(len(plot_df))
plt.errorbar(plot_df["pred_mean_cal"], ypos,
             xerr=[plot_df["pred_mean_cal"] - plot_df["pred_q10_cal"],
                   plot_df["pred_q90_cal"] - plot_df["pred_mean_cal"]],
             fmt='o', capsize=3, label="Predicted (mean ± q10–q90)")

# overlay bulk truth if available
if plot_df["bulk_LN_IC50"].notna().any():
    plt.scatter(plot_df["bulk_LN_IC50"], ypos, marker='x', label="Bulk LN_IC50")

plt.yticks(ypos, plot_df["DRUG_ID"])
plt.gca().invert_yaxis()
plt.xlabel("LN_IC50 (lower = more sensitive)")
plt.title(f"{SANGER_ID}: predicted single-cell drug response vs bulk")
plt.legend()
plt.tight_layout()
plt.show()
