# Performance Evaluation in Breast Cancer PDxOs

## Contents

- [Data Loading](#data-loading)
- [Performance across all PDxO-drug pairs](#performance-across-all-pdxo-drug-pairs)
    - [High confidence drugs](#high-confidence-drugs)
    - [Performance boxplots](#performance-boxplots)
    - [Performance stratified by ID vs OOD drugs](#performance-stratified-by-id-vs-ood-drugs)
    - [Performance stratified by drug mechanism](#performance-stratified-by-drug-mechanism)
    - [auROC analysis](#auroc-analysis)
    - [Response rate analysis](#response-rate-analysis)
- [ScreenAhead with vs without domain-specific fine-tuning](#screenahead-with-vs-without-domain-specific-fine-tuning)

In [None]:
from __future__ import annotations

import json
import itertools

import altair as alt
import altair_forge as af
import pandas as pd
import numpy as np
import sklearn.metrics as skm

from io import StringIO
from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf

from cdrpy.datasets import Dataset

from screendl import model as screendl
from screendl.utils import evaluation as eval_utils
from screendl.utils.drug_selectors import get_response_matrix

from utils.plot import (
    MODEL_COLORS,
    MODEL_SHAPES,
    DEFAULT_BOXPLOT_CONFIG,
    configure_chart,
)
from utils.const import FIXED_DRUG_NAMES, DRUG_TO_PATHWAY_EXT
from utils.eval import ResponseRateEvaluator, select_best_therapy, auroc

In [None]:
MIN_OBS = 10  # minimum number of tumors for a drug to be considered in evaluation

## Data Loading

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun."""
    if isinstance(multirun_dir, str):
        multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        fold_id = file_path.parent.stem.split("_")[-1]
        fold_pred_df = pd.read_csv(file_path)
        fold_pred_df["fold"] = int(fold_id)
        return fold_pred_df

    file_list = multirun_dir.glob(regex)
    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
root = Path("../../../datastore")

In [None]:
drug_types_path = root / "processed/DrugAnnotations/drug_types.json"
with open(drug_types_path, "r") as fh:
    drug_to_type = json.load(fh)

In [None]:
fixed_drug_types = {"chemo": "Chemo", "targeted": "Targeted", "other": "Other"}
drug_to_type = {k: fixed_drug_types[v] for k, v in drug_to_type.items()}

In [None]:
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2-HCI-v1.0.0"

drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
drug_meta["type"] = drug_meta.index.map(drug_to_type)
drug_encoders = screendl.load_drug_features(
    dataset_dir / "ScreenDL/FeatureMorganFingerprints.csv"
)

cell_meta = pd.read_csv(dataset_dir / "MetaSampleAnnotations.csv", index_col=0)
cell_encoders = screendl.load_cell_features(
    dataset_dir / "ScreenDL/FeatureGeneExpression.csv"
)

D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_encoders=cell_encoders,
    drug_encoders=drug_encoders,
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name="CellModelPassports-GDSC-HCI",
)

cell_ids = D.cell_meta[D.cell_meta["domain"] == "CELL"].index
pdmc_ids = D.cell_meta[D.cell_meta["domain"] == "PDMC"].index

cell_ds = D.select_cells(cell_ids, name="cell_ds")
pdmc_ds = D.select_cells(pdmc_ids, name="pdmc_ds")

print(cell_ds)
print(pdmc_ds)

In [None]:
drug_meta_ext = pd.read_csv(
    root / "inputs/CellModelPassports-GDSCv1v2/MetaDrugAnnotations.csv",
    index_col=0,
    usecols=["drug_id", "targets", "target_pathway"],
)

drug_to_pathway = drug_meta_ext["target_pathway"].to_dict()
fixed_pathways = {"EGFR signaling": "EGFR/HER2 signaling"}
drug_to_pathway = {k: fixed_pathways.get(v, v) for k, v in drug_to_pathway.items()}
drug_to_pathway.update(DRUG_TO_PATHWAY_EXT)

In [None]:
model_results = {}
output_dir = root / "outputs"
def rescale(df: pd.DataFrame, col: str) -> pd.Series:
    """Z-score ``df[col]`` separately within each (fold, drug_id) group.

    Returns a Series aligned to ``df``'s index, so it can be assigned straight
    back as a column of ``df``.
    """
    return df.groupby(["fold", "drug_id"])[col].transform(stats.zscore)

In [None]:
# DeepCDR results

path_fmt = "experiments/pdx_benchmarking/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI-Mutations"
model = "DeepCDR-legacy"
date = "2025-06-15_09-28-48"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

_temp = load_multirun_predictions(run_dir, run_regex, splits=["test"]).assign(
    y_true=lambda df: rescale(df, "y_true"),
    y_pred=lambda df: rescale(df, "y_pred"),
    model=model.split("-")[0],
    was_screened=False,
)
model_results[model.split("-")[0]] = _temp

In [None]:
# HiDRA results

path_fmt = "experiments/pdx_benchmarking/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "HiDRA-legacy"
date = "2025-06-15_14-15-22"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

_temp = load_multirun_predictions(run_dir, run_regex, splits=["test"]).assign(
    y_true=lambda df: rescale(df, "y_true"),
    y_pred=lambda df: rescale(df, "y_pred"),
    model=model.split("-")[0],
    was_screened=False,
)
model_results[model.split("-")[0]] = _temp

In [None]:
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2025-06-23_09-59-12"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

fixed_model_names = {
    "base": "ScreenDL-PT",
    "xfer": f"{model}-FT",
    "screen": f"{model}-SA (ALL)",
}

# NOTE: predictions are already rescaled for ScreenDL
model_results[model] = load_multirun_predictions(run_dir, run_regex, splits=None).assign(
    model=lambda df: df["model"].map(fixed_model_names),
    was_screened=lambda df: df["was_screened"].fillna(False),
)

In [None]:
conf = OmegaConf.load(run_dir / "multirun.yaml")
assert conf.screenahead.opt.exclude_drugs is None
print(OmegaConf.to_yaml(conf.xfer))

In [None]:
# load baselines
path_fmt = "experiments/pdxo_baselines/{0}/runs/{1}/predictions.csv"
dataset = "CellModelPassports-GDSCv1v2-HCI"
run_dates = [
    ("2025-06-28_21-28-10", "Ridge (C)", "Ridge (P)"),
    ("2025-06-28_22-10-42", "Random Forest (C)", "Random Forest (P)"),
]

# Load one predictions.csv per baseline run and relabel its `model` column with the
# display names given in `run_dates`.
baseline_results = []
for run_date, name, name0 in run_dates:
    file_path = output_dir / path_fmt.format(dataset, run_date)
    # NOTE(review): `cfg` is not used below; the load still fails fast if the run's
    # hydra config is missing -- confirm whether that check is intended.
    cfg = OmegaConf.load(file_path.parent / ".hydra/config.yaml")
    # raw model ids ending in "-0" get `name0`, every other id gets `name`
    get_name = lambda x: name0 if x.endswith("-0") else name
    # `assign` evaluates the mapping immediately, so the lambda sees this iteration's
    # `name` / `name0` values
    baseline_results.append(
        pd.read_csv(file_path).assign(model=lambda df: df["model"].map(get_name))
    )

# stack all baseline runs into one frame
baseline_results = pd.concat(baseline_results)
baseline_results.head()

In [None]:
y_true_df = pdmc_ds.obs.copy()
y_true_df["y_true"] = y_true_df.groupby("drug_id")["label"].transform(stats.zscore)
y_true_df = y_true_df[["cell_id", "drug_id", "y_true"]]
y_true_df.sort_values(["cell_id", "drug_id"]).head()

In [None]:
baseline_results = (
    baseline_results.drop(columns="y_true")
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
    .dropna(subset="y_true")
)

In [None]:
grouped = baseline_results.groupby(["model", "drug_id"])
baseline_corrs = grouped.apply(eval_utils.pcorr, min_obs=MIN_OBS)
baseline_corrs.groupby("model").describe()

In [None]:
baseline_results_no_ood = baseline_results.query("drug_id in @cell_ds.drug_ids")
grouped = baseline_results_no_ood.groupby(["model", "drug_id"])
grouped.apply(eval_utils.pcorr, min_obs=MIN_OBS).groupby(["model"]).describe()

In [None]:
model_results_df = (
    pd.concat(model_results.values())
    .drop(columns="y_true")
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
    .dropna(subset="y_true")
)
model_results_df.head()

In [None]:
DROP_TUMORS = [
    # NOTE: sample swap identified by Somalier
    "BCM3561",
    "BCM5471",
]

model_results_df = model_results_df.query("cell_id not in @DROP_TUMORS")
baseline_results = baseline_results.query("cell_id not in @DROP_TUMORS")

In [None]:
# look at how many times a drug was selected for screening
drug_num_times_screened = (
    model_results_df.query("model == 'ScreenDL-SA (ALL)'")
    .drop_duplicates(["cell_id", "drug_id"])
    .groupby(["drug_id", "was_screened"])["cell_id"]
    .nunique()
    .unstack()
    .fillna(0)
    .sort_values(False)
)

drug_num_times_screened.describe()

In [None]:
screened_counts = (
    model_results_df.query("model == 'ScreenDL-SA (ALL)'")
    .query("was_screened == True")
    .groupby("cell_id")["drug_id"]
    .nunique()
    .value_counts()
)

# confirm that the same number of drugs were screened for each PDXO
assert screened_counts.shape[0] == 1

In [None]:
# get ScreenAhead NBS results
_temp_nbs = (
    model_results_df.query("model == 'ScreenDL-SA (ALL)'")
    .query("was_screened == False")
    .assign(model="ScreenDL-SA (NBS)")
    .copy()
)

model_results_df = pd.concat([model_results_df, _temp_nbs])

In [None]:
MODELS = [
    "HiDRA",
    "DeepCDR",
    "ScreenDL-PT",
    "ScreenDL-FT",
    "ScreenDL-SA (NBS)",
    "ScreenDL-SA (ALL)",
]

MODELS_EXT = [
    "HiDRA",
    "DeepCDR",
    "Ridge (C)",
    "Ridge (P)",
    "Random Forest (C)",
    "Random Forest (P)",
    "ScreenDL-PT",
    "ScreenDL-FT",
    "ScreenDL-SA (NBS)",
    "ScreenDL-SA (ALL)",
]

In [None]:
model_to_color = {m: MODEL_COLORS[m] for m in MODELS_EXT}
model_to_shape = {m: MODEL_SHAPES[m] for m in MODELS_EXT}

MODEL_COLOR_SCALE = alt.Scale(
    domain=list(model_to_color.keys()), range=list(model_to_color.values())
)
MODEL_SHAPE_SCALE = alt.Scale(
    domain=list(model_to_shape.keys()), range=list(model_to_shape.values())
)

In [None]:
BOXPLOT_CONFIG = DEFAULT_BOXPLOT_CONFIG.copy()
BOXPLOT_CONFIG["size"] = 34

## Performance across all PDxO-drug pairs

In [None]:
pdxo_fold_metrics = model_results_df.groupby(["model", "fold", "drug_id"]).apply(
    eval_utils.pcorr, min_obs=MIN_OBS
)

pdxo_agg_metrics = (
    pdxo_fold_metrics.groupby(["model", "drug_id"])
    .mean()
    .to_frame("pcc")
    .reset_index()
    .assign(
        drug_type=lambda df: df["drug_id"].map(drug_to_type),
        pathway=lambda df: df["drug_id"].map(drug_to_pathway),
    )
)

pdxo_agg_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
pdxo_fold_metrics_no_ood = (
    pdxo_fold_metrics.to_frame("pcc")
    .query("drug_id in @cell_ds.drug_ids")["pcc"]
    .groupby(["model", "fold"])
    .median()
    .unstack()
)

pdxo_fold_metrics_no_ood.loc[MODELS]

In [None]:
# confirm that every tumor had the same number of distinct drugs flagged as screened
# and that this number equals the configured ScreenAhead budget
# (conf.screenahead.opt.n_drugs); this checks the count only, not drug identity
num_screened_drugs_per_tumor = (
    model_results_df.query("was_screened == True")
    .groupby("cell_id")["drug_id"]
    .nunique()
    .value_counts()
)

# the index of the value_counts result holds the distinct per-tumor counts
assert num_screened_drugs_per_tumor.index.nunique() == 1
assert num_screened_drugs_per_tumor.index[0] == conf.screenahead.opt.n_drugs

In [None]:
# here, we consider ScreenDL as an ensemble model trained on different samples of the
# cell line data -> the predictions for each (model, drug, tumor) are combined with a
# trimmed mean (stats.trim_mean cutting 20% from each tail), not a plain mean;
# y_true is taken from the first row of each group
ensemble_results_df = (
    model_results_df.groupby(["model", "drug_id", "cell_id"])
    .agg({"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2)})
    .reset_index()
)

In [None]:
ensemble_tumor_metrics = (
    ensemble_results_df.dropna(subset=["y_true", "y_pred"])
    .groupby(["model", "cell_id"])
    .apply(eval_utils.pcorr, min_obs=10)
    .to_frame("pcc")
    .reset_index()
)

ensemble_tumor_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
ensemble_metrics = (
    ensemble_results_df.groupby(["model", "drug_id"])
    .apply(eval_utils.pcorr, min_obs=MIN_OBS)
    .to_frame("pcc")
    .reset_index()
    .assign(
        drug_type=lambda df: df["drug_id"].map(drug_to_type),
        pathway=lambda df: df["drug_id"].map(drug_to_pathway),
        is_ood_drug=lambda df: ~df["drug_id"].isin(cell_ds.drug_ids),
    )
)

ensemble_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
top_pt_ood_drugs = ensemble_metrics.query(
    "model == 'ScreenDL-PT' and is_ood_drug == True"
).sort_values("pcc", ascending=False)

top_pt_ood_drugs.head()

In [None]:
ensemble_metrics_no_ood_drugs = ensemble_metrics.query("drug_id in @cell_ds.drug_ids")
ensemble_metrics_no_ood_drugs.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# per-drug change in ensemble PCC between ScreenDL-PT and ScreenDL-FT (FT minus PT)
temp = ensemble_metrics.set_index(["drug_id", "model"])["pcc"].unstack()

m1 = "ScreenDL-PT"
m2 = "ScreenDL-FT"
# only drugs with a defined PCC for both models enter the paired comparison
X = temp[[m1, m2]].dropna()

x1 = X[m1]
x2 = X[m2]

deltas = x2 - x1
n_improved = int((deltas > 0).sum())
print(f"Mean Delta: {deltas.mean()}")
# the number of improved drugs is a count, so print it as an integer
print(f"No. Improved: {n_improved:d}")
print(f"Pct. Improved: {n_improved / len(deltas) * 100:.2f}")

In [None]:
baseline_metrics = (
    baseline_results.groupby(["model", "drug_id"])
    .apply(eval_utils.pcorr, min_obs=MIN_OBS)
    .to_frame("pcc")
    .reset_index()
)
baseline_metrics.groupby("model")["pcc"].describe()

In [None]:
baseline_metrics_no_ood = baseline_metrics.query("drug_id in @cell_ds.drug_ids")
baseline_metrics_no_ood.groupby("model")["pcc"].describe()

### High confidence drugs

In [None]:
def summarize_high_conf_drugs(df: pd.DataFrame, threshold: float = 0.5) -> pd.Series:
    """Count the drugs in ``df`` and those whose ``pcc`` reaches ``threshold``.

    Parameters
    ----------
    df
        Frame with one row per drug and a ``pcc`` column.
    threshold
        Minimum ``pcc`` (inclusive) for a drug to count as high confidence.

    Returns
    -------
    pd.Series
        ``n_drugs`` (number of rows) and ``n_high_conf_drugs`` (rows with
        ``pcc >= threshold``; missing ``pcc`` values never qualify).
    """
    n_high_conf = int((df["pcc"] >= threshold).sum())
    return pd.Series({"n_drugs": df.shape[0], "n_high_conf_drugs": n_high_conf})

In [None]:
sa_ensemble_metrics = ensemble_metrics.query("model.str.contains('ScreenDL-SA')")
hc_drug_summary = (
    sa_ensemble_metrics.dropna(subset="pcc")
    .groupby("model")
    .apply(summarize_high_conf_drugs)
    .dropna()
    .assign(
        pct_high_conf=lambda df: round(100 * df["n_high_conf_drugs"] / df["n_drugs"], 2)
    )
)
hc_drug_summary

In [None]:
source = (
    sa_ensemble_metrics.query("model == 'ScreenDL-SA (ALL)'")
    .query("pcc >= 0.5")
    .assign(drug_name=lambda df: df["drug_id"].map(lambda x: FIXED_DRUG_NAMES.get(x, x)))
)

color_domain = list(source.sort_values("pcc", ascending=False)["pathway"].unique())

high_conf_drugs_chart_all = (
    alt.Chart(source)
    .mark_bar(stroke="black", size=11.5, strokeWidth=1, opacity=1)
    .encode(
        alt.X("drug_name:N", sort="-y")
        .axis(domainColor="black", labelAngle=-70)
        .scale(paddingOuter=0.15)
        .title(None),
        alt.Y("pcc:Q")
        .axis(grid=False, tickCount=5, domainColor="black", titlePadding=10)
        .scale(domain=(0.0, 1.0))
        .title("Pearson Correlation"),
        alt.Color("pathway:N")
        .scale(domain=color_domain, scheme="tableau20")
        .legend(
            columns=4,
            orient="top",
            offset=15,
            symbolStrokeWidth=1,
            direction="vertical",
        )
        .title(None),
    )
    .properties(height=220, width=13 * source.shape[0] + 30)
)

configure_chart(high_conf_drugs_chart_all)

### Performance boxplots

In [None]:
temp_ = ensemble_metrics.set_index(["drug_id", "model"])["pcc"].unstack()[MODELS]

m1 = "ScreenDL-FT"
m2 = "ScreenDL-PT"

X = temp_[[m1, m2]].dropna()
print(stats.wilcoxon(X[m1], X[m2], alternative="greater"))
(temp_[m1] - temp_[m2]).dropna().describe()

In [None]:
baseline_source = baseline_metrics_no_ood.copy()
baseline_source["drug_type"] = baseline_source["drug_id"].map(drug_to_type)

source = pd.concat([ensemble_metrics_no_ood_drugs, baseline_source])

In [None]:
ensemble_metrics.drop_duplicates("drug_id")["drug_type"].value_counts(dropna=False)

In [None]:
ensemble_metrics_no_ood_drugs.groupby(["model"])["pcc"].describe()

In [None]:
ensemble_metrics_no_ood_drugs.groupby(["model", "drug_type"])["pcc"].describe()

In [None]:
boxes_all = (
    alt.Chart(source, width=40 * len(MODELS_EXT), height=300)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("model:N")
        .axis(labelAngle=-70, labelPadding=5)
        .sort(MODELS_EXT)
        .title(None),
        alt.Y("pcc:Q")
        .axis(titlePadding=10, tickCount=4, grid=False)
        .scale(domain=[-0.8, 1])
        .title("Pearson Correlation"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
)

# per-model boxplots split by drug type (one facet column per model); rows without
# a drug type annotation are dropped
boxes_types = (
    alt.Chart(source.dropna(subset=["drug_type"]), width=40 * 2, height=300)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("drug_type:N")
        .axis(labelAngle=-70, labelPadding=5, orient="bottom")
        .sort(["Targeted", "Chemo"])
        .title(None),
        # no y axis is drawn here; the domain matches the one used by `boxes_all`
        alt.Y("pcc:Q").axis(None).scale(domain=[-0.8, 1]).title(None),
        # use the shared model palette so a model has the same color as in
        # `boxes_all`, and the two concatenated panels do not define conflicting
        # color scales for the same field
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
        alt.Column("model:N")
        .header(orient="top")
        .spacing(10)
        .sort(MODELS_EXT)
        .title(None),
    )
)

ensemble_boxes = alt.hconcat(boxes_all, boxes_types, spacing=5)
configure_chart(ensemble_boxes)

### Performance stratified by drug mechanism

In [None]:
source["pathway"] = source["drug_id"].map(drug_to_pathway)

In [None]:
pathway_counts = (
    source.groupby(["model", "pathway"])["drug_id"].nunique().unstack(0).min(axis=1)
)

keep_pathways = pathway_counts[pathway_counts >= 2].index

median_pathway_metrics = (
    source.assign(
        pathway=lambda df: df["pathway"].map(
            lambda x: x if x in keep_pathways else "Other"
        )
    )
    .groupby(["model", "pathway"])["pcc"]
    .agg(["median", "size"])
    .reset_index()
)

pathway_order = (
    median_pathway_metrics.query("model == 'ScreenDL-SA (ALL)'")
    .sort_values("median", ascending=False)["pathway"]
    .to_list()
)

In [None]:
circles = (
    alt.Chart(median_pathway_metrics, width=300, height=220)
    .mark_point(size=90, opacity=0.8, stroke="black", strokeWidth=0.5, filled=True)
    .encode(
        alt.X(
            "median:Q",
            axis=alt.Axis(
                titlePadding=10,
                values=[-0.2, 0.0, 0.2, 0.4, 0.6, 0.8],
                grid=False,
            ),
            scale=alt.Scale(domain=(-0.3, 0.9)),
            title="Median Pearson Correlation Per Drug",
        ),
        alt.Y("pathway:N", sort=pathway_order, title=None),
        alt.Color(
            "model:N",
            scale=MODEL_COLOR_SCALE,
            legend=alt.Legend(
                orient="top",
                title=None,
                symbolStrokeWidth=1,
                columns=3,
                direction="horizontal",
            ),
        ),
        alt.Shape(
            "model:N",
            scale=MODEL_SHAPE_SCALE,
            legend=alt.Legend(
                orient="top",
                title=None,
                symbolStrokeWidth=1,
                columns=3,
                direction="horizontal",
            ),
        ),
        tooltip=["median:Q", "pathway:N", "model:N"],
    )
)

counts = median_pathway_metrics.query("model == 'ScreenDL-PT'")
bars = alt.Chart(counts, width=60, height=220).encode(
    alt.X(
        "size:Q",
        axis=alt.Axis(grid=False, values=[0, 15], titlePadding=10),
        scale=alt.Scale(domain=(0, 15)),
        title="No. Drugs",
    ),
    alt.Y(
        "pathway:N",
        axis=alt.Axis(ticks=False, labels=False, offset=0, domainOpacity=0),
        sort=pathway_order,
        title=None,
    ),
    text="size",
)
bars = bars.mark_bar(stroke="black", strokeWidth=0.5, size=15, color="#999999")
bars += bars.mark_text(align="left", dx=4, fontSize=10)

pathway_performance_chart = alt.hconcat(circles, bars, spacing=5)
configure_chart(pathway_performance_chart)

In [None]:
n_pathways = median_pathway_metrics["pathway"].nunique()
best_model_by_pathway = (
    median_pathway_metrics.dropna(subset="median")
    .query("~model.str.contains('ScreenDL-SA')")
    .groupby("pathway")
    .apply(lambda g: g["model"].loc[g["median"].idxmax()])
    .value_counts()
)
best_model_by_pathway / n_pathways * 100

In [None]:
# fraction of pathways for which fine-tuning improved performance (FT vs PT)
temp = median_pathway_metrics.set_index(["pathway", "model"])["median"].unstack()
((temp["ScreenDL-FT"] - temp["ScreenDL-PT"]) > 0).sum() / temp.shape[0]

### auROC analysis

In [None]:
# Build binary "responder" labels for the auROC / response-rate analyses.
# The raw labels from pdmc_ds.obs are joined onto every (model, tumor, drug) row of
# the ensemble predictions and z-scored within each (model, drug) group.
y_true_df = (
    ensemble_results_df[["model", "cell_id", "drug_id"]]
    .merge(pdmc_ds.obs, on=["cell_id", "drug_id"])
    .assign(
        label=lambda df: df.groupby(["model", "drug_id"])["label"].transform(stats.zscore)
    )
)

# per-drug class: 1 if the label is below the 30th percentile of that (model, drug)
# group, else 0
y_true_df["y_true_class"] = (
    y_true_df.groupby(["model", "drug_id"])["label"]
    .transform(lambda x: x < x.quantile(0.3))
    .astype(int)
)
# per-tumor class: same rule, but the percentile is taken within each (model, tumor)
# group
y_true_df["y_true_class_t"] = (
    y_true_df.groupby(["model", "cell_id"])["label"]
    .transform(lambda x: x < x.quantile(0.3))
    .astype(int)
)

# attach both class columns to the ensemble predictions
# NOTE(review): this rebinds ensemble_results_df, so re-running the cell merges the
# class columns a second time -- run it once per kernel session
ensemble_results_df = ensemble_results_df.merge(
    y_true_df.drop(columns=["label", "id"]), on=["model", "cell_id", "drug_id"]
)

In [None]:
ensemble_auroc_metrics = (
    ensemble_results_df.groupby(["model", "drug_id"])
    .apply(auroc, col1="y_true_class", col2="y_pred", min_obs=MIN_OBS)
    .to_frame(name="auROC")
    .reset_index()
)

ensemble_auroc_metrics.groupby("model")["auROC"].describe().loc[MODELS]

In [None]:
ensembl_auroc_metrics_no_ood_drugs = (
    ensemble_results_df.query("drug_id in @cell_ds.drug_ids")
    .groupby(["model", "drug_id"])
    .apply(auroc, col1="y_true_class", col2="y_pred", min_obs=MIN_OBS)
    .to_frame(name="auROC")
    .reset_index()
)

ensembl_auroc_metrics_no_ood_drugs.groupby("model")["auROC"].describe().loc[MODELS]

In [None]:
y_true_df_b = (
    baseline_results[["model", "cell_id", "drug_id"]]
    .merge(pdmc_ds.obs, on=["cell_id", "drug_id"])
    .assign(
        label=lambda df: df.groupby(["model", "drug_id"])["label"].transform(stats.zscore)
    )
)

y_true_df_b["y_true_class"] = (
    y_true_df_b.groupby(["model", "drug_id"])["label"]
    .transform(lambda x: x < x.quantile(0.3))
    .astype(int)
)

baseline_results = baseline_results.merge(
    y_true_df_b.drop(columns=["label", "id"]),
    on=["model", "cell_id", "drug_id"],
)

baseline_auroc_metrics = (
    baseline_results.fillna({"was_screened": False})
    .groupby(["model", "drug_id"])
    .apply(auroc, col1="y_true_class", col2="y_pred", min_obs=MIN_OBS)
    .to_frame(name="auROC")
    .reset_index()
)

baseline_auroc_metrics.groupby(["model"])["auROC"].describe()

In [None]:
baseline_auroc_metrics_no_ood_drugs = baseline_auroc_metrics.query(
    "drug_id in @cell_ds.drug_ids"
)
baseline_auroc_metrics_no_ood_drugs.groupby(["model"])["auROC"].describe()

In [None]:
X = ensemble_auroc_metrics.set_index(["model", "drug_id"])["auROC"].unstack(0).dropna()

res = []
for m1, m2 in itertools.combinations(MODELS, 2):
    x1 = X[m1]
    x2 = X[m2]
    s, p = stats.wilcoxon(x1, x2)
    res.append([m1, m2, s, p])

cols = ["m1", "m2", "statistic", "pvalue"]
res = pd.DataFrame(res, columns=cols).set_index(["m1", "m2"])
res.round(3)

In [None]:
source = pd.concat([ensemble_auroc_metrics, baseline_auroc_metrics])
sorted_models = source.groupby("model")["auROC"].median().sort_values().index.to_list()

bars = (
    alt.Chart(source)
    .mark_bar(stroke="black", strokeWidth=1)
    .encode(
        alt.X("median(auROC):Q")
        .axis(grid=False, tickCount=5, domainColor="black", titlePadding=10)
        .scale(domain=(0.4, 1.0))
        .title("auROC"),
        alt.Y("model:N")
        .axis(domainColor="black")
        .scale(domain=list(reversed(sorted_models)), paddingOuter=0.15)
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(width=275, height=220)
)


rule = (
    alt.Chart(pd.DataFrame({"x": [0.5]}))
    .mark_rule(stroke="black", strokeDash=[4, 3], strokeWidth=1.25)
    .encode(x="x:Q")
)

error_bars = (
    alt.Chart(source)
    .mark_errorbar(
        extent="iqr", ticks=alt.MarkConfig(size=5, color="black", strokeWidth=1)
    )
    .encode(alt.Y("model:N"), alt.X("auROC:Q"))
)

auroc_metrics_chart = alt.layer(bars, error_bars, rule)
configure_chart(auroc_metrics_chart)

### Response rate analysis

In [None]:
# Monte Carlo estimate of the response rate obtained when one row of `y_true_df` is
# drawn at random for every tumor.
# NOTE(review): `y_true_df` may hold several rows per (tumor, drug) -- one per model --
# so draws are taken across models; confirm this is intended.

# A single seeded generator is shared by all iterations so the estimate is
# reproducible. Passing an integer seed to `sample` instead would reseed every
# iteration and return the same draw 1000 times.
rng = np.random.RandomState(1771)

random_response_rates = []
for _ in range(1000):
    selected_random = y_true_df.groupby("cell_id").sample(1, random_state=rng)
    response_rate = selected_random["y_true_class"].sum() / len(selected_random)
    random_response_rates.append(response_rate)

random_response_rate = np.mean(random_response_rates)
print(f"Random Response Rate {random_response_rate:.2f}")

In [None]:
rre = ResponseRateEvaluator(
    y_pred_var="y_pred",
    y_true_var="y_true_class",
    n_iter=10000,
)

In [None]:
ensemble_response_rates = (
    ensemble_results_df.query("model != 'ScreenDL-SA (NBS)'")
    .groupby("model")
    .apply(rre.eval)
)

ensemble_response_rates.loc[MODELS[:-2] + [MODELS[-1]]]

In [None]:
baseline_response_rates = baseline_results.groupby("model").apply(rre.eval)
baseline_response_rates

In [None]:
sa_selected_screened_counts = (
    ensemble_results_df.query("model == 'ScreenDL-SA (ALL)'")
    .groupby("cell_id", as_index=False)
    .apply(select_best_therapy, y_pred_var="y_pred")
    .merge(model_results_df, on=["model", "cell_id", "drug_id"])
    .drop_duplicates(subset="cell_id")["was_screened"]
    .value_counts()
)

sa_selected_screened_counts

In [None]:
source = (
    pd.concat([ensemble_response_rates, baseline_response_rates])
    .to_frame(name="response_rate")
    .reset_index()
)

In [None]:
y_order = source.groupby("model")["response_rate"].mean().sort_values().index.to_list()

base = alt.Chart(source)
bars = (
    base.mark_bar(stroke="black", strokeWidth=1)
    .encode(
        alt.X("response_rate:Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 1))
        .title("Response Rate (%)"),
        alt.Y("model:N")
        .axis(domainColor="black")
        .scale(paddingOuter=0.15, domain=list(reversed(y_order)))
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(width=275, height=220)
)

rule = (
    alt.Chart(pd.DataFrame({"x": [random_response_rate]}))
    .mark_rule(stroke="black", strokeDash=[4, 3], strokeWidth=1.25)
    .encode(x="x:Q")
)

text = base.mark_text(align="left", dx=6, fontSize=10).encode(
    alt.X("response_rate:Q").title("Response Rate (%)"),
    alt.Y("model:N").title(None),
    alt.Text("mean(response_rate):Q", format=".1%"),
)


response_rate_chart = alt.layer(bars, rule, text)
configure_chart(response_rate_chart)

In [None]:
top = ensemble_boxes
mid2 = alt.hconcat(high_conf_drugs_chart_all, pathway_performance_chart, spacing=60)
mid2 = mid2.resolve_scale(color="independent", shape="independent")
bot = alt.hconcat(auroc_metrics_chart, response_rate_chart, spacing=60)
bot = bot.resolve_scale(color="independent", shape="independent")

chart = configure_chart(
    alt.vconcat(top, mid2, bot, spacing=50).resolve_scale(
        color="independent", shape="independent"
    )
)
chart.configure_legend(titleFontStyle="italic", titleFont="arial", labelFont="arial")

## PDxO Screening Heatmap

In [None]:
hr_status = pd.read_csv(
    root / "raw/Welm/receptor_status_combined.pdx.final.csv",
    index_col=0,
)
hr_status.head()

In [None]:
M = get_response_matrix(pdmc_ds, impute=False).T

tumors_per_drug = (~M.isna()).sum()
keep_drugs = tumors_per_drug[tumors_per_drug >= M.shape[0] * 0.7].index

tumor_na_counts = M[keep_drugs].isna().sum(axis=1)
keep_tumors = tumor_na_counts[tumor_na_counts == 0].index

M = M.loc[keep_tumors, keep_drugs].transform(stats.zscore).T
M.index = M.index.map(lambda x: FIXED_DRUG_NAMES.get(x, x))
print(M.shape)

In [None]:
# long-format receptor status (one row per tumor x receptor) for the column margin of
# the heatmap; tumors with any missing receptor call are dropped
col_margin_data = (
    hr_status.loc[M.columns]
    .dropna()
    .melt(ignore_index=False, var_name="receptor", value_name="status")
    .reset_index()
    # upper-case every (string) column; column-wise `.str.upper()` replaces
    # `DataFrame.applymap`, which is deprecated in recent pandas releases
    .apply(lambda col: col.str.upper())
)
# positional rename to the x / y / value names passed on to the heatmap margin
col_margin_data.columns = ["x", "y", "value"]

col_margin_y_scale = alt.Scale(domain=("ER", "PR", "HER2"))
col_margin_z_scale = alt.Scale(domain=("NEG", "POS"), range=("white", "black"))

In [None]:
row_margin_data = pd.DataFrame(
    {"x": 1, "y": M.index, "value": M.index.map(drug_to_pathway)}
)
row_margin_z_scale = alt.Scale(scheme="tableau20")

In [None]:
hmap = af.cluster_heatmap(
    M,
    height=400,
    width=890,
    row_dendro_size=40,
    row_margin_data=row_margin_data,
    row_margin_z_scale=row_margin_z_scale,
    row_margin_legend_title="Target Pathway",
    row_margin_legend_config=alt.LegendConfig(symbolStrokeColor="black"),
    col_dendro_size=40,
    col_margin_data=col_margin_data,
    col_margin_z_scale=col_margin_z_scale,
    col_margin_y_scale=col_margin_y_scale,
    col_margin_legend_title="Receptor Status",
    col_margin_legend_config=alt.LegendConfig(columns=2, symbolStrokeColor="black"),
    legend_title="Z-Score ln(IC50)",
    legend_config=alt.LegendConfig(
        gradientLength=100, gradientThickness=15, direction="horizontal", tickCount=5
    ),
    legend_spacing=15,
)

(
    hmap.configure_axis(
        labelFont="arial", tickColor="black", tickSize=3
    ).configure_legend(
        titleFont="arial",
        titleFontStyle="italic",
        titleFontWeight="bold",
        labelFont="arial",
        labelFontStyle="regular",
    )
)

In [None]:
# test for enhanced sensitivity to PI3K/MTOR targeting drugs in ER+ lines
M = get_response_matrix(pdmc_ds, impute=False).T

pathway_drugs = drug_meta_ext.query("target_pathway == 'PI3K/MTOR signaling'")
pathway_drugs = pathway_drugs.index.to_list()

pathway_data = (
    M.filter(items=pathway_drugs, axis=1)
    .melt(ignore_index=False, value_name="Zd")
    .dropna()
    .reset_index()
    .assign(
        Zd=lambda df: df.groupby("drug_id")["Zd"].transform(stats.zscore),
        ER=lambda df: df["cell_id"].map(hr_status["ER"]),
    )
)

grouped = pathway_data.groupby("ER")["Zd"].agg(list).to_dict()
print(stats.mannwhitneyu(grouped["neg"], grouped["pos"]))
pathway_data.groupby("ER")["Zd"].describe()

In [None]:
# test for increased platinum sensitivity in TNBC lines
temp_ = (
    M["Carboplatin"]
    .dropna()
    .to_frame(name="Zd")
    .transform(stats.zscore)
    .join(((hr_status == "neg").sum(axis=1) == 3).to_frame("is_TNBC"))
)
grouped = temp_.groupby("is_TNBC")["Zd"].agg(list).to_dict()
print(stats.mannwhitneyu(grouped[True], grouped[False]))
temp_.groupby("is_TNBC")["Zd"].describe()

In [None]:
def dd_plot(d1: str, d2: str, M: pd.DataFrame) -> alt.Chart:
    """Scatter the z-scored values of two columns of ``M`` against each other.

    Rows with a missing value in either column are dropped and both columns are
    z-scored. The Pearson correlation and its p-value are printed before the
    chart is returned.
    """
    paired = M[[d1, d2]].dropna().transform(stats.zscore)
    corr, pval = stats.pearsonr(paired[d1], paired[d2])
    print(f"Pearson correlation between {d1} and {d2}: {corr:.3f}, p-value: {pval:.3e}")

    # both axes share one fixed domain so the panels of different pairs line up
    axis_domain = (-5, 4)
    x_enc = alt.X(f"{d1}:Q").axis(grid=False).scale(domain=axis_domain).title(f"{d1} (Zd)")
    y_enc = alt.Y(f"{d2}:Q").axis(grid=False).scale(domain=axis_domain).title(f"{d2} (Zd)")

    points = alt.Chart(paired).mark_circle(size=40, opacity=0.8, color="darkgray")
    return points.encode(x_enc, y_enc).properties(width=200, height=200)

In [None]:
pairs = [
    ("Olaparib", "Talazoparib"),
    ("Birinapant", "Tolinapant"),
    ("Trametinib", "Selumetinib"),
]
plots = [dd_plot(d1, d2, M) for d1, d2 in pairs]
configure_chart(alt.hconcat(*plots, spacing=20))

In [None]:
# test for enhanced sensitivity to PI3K/MTOR targeting drugs in ER+ lines
# NOTE(review): this cell repeats the ER vs PI3K/MTOR test from an earlier cell of
# this notebook; the only addition is the print of the selected drug ids.
M = get_response_matrix(pdmc_ds, impute=False).T

# drugs annotated with the PI3K/MTOR target pathway in the extended drug metadata
pathway_drugs = drug_meta_ext.query("target_pathway == 'PI3K/MTOR signaling'")
pathway_drugs = pathway_drugs.index.to_list()
print(pathway_drugs)

# long format: one row per (tumor, drug) with the response z-scored per drug and the
# tumor's ER status attached
pathway_data = (
    M.filter(items=pathway_drugs, axis=1)
    .melt(ignore_index=False, value_name="Zd")
    .dropna()
    .reset_index()
    .assign(
        Zd=lambda df: df.groupby("drug_id")["Zd"].transform(stats.zscore),
        ER=lambda df: df["cell_id"].map(hr_status["ER"]),
    )
)

# no `alternative` is passed, so scipy's default for mannwhitneyu applies
grouped = pathway_data.groupby("ER")["Zd"].agg(list).to_dict()
print(stats.mannwhitneyu(grouped["neg"], grouped["pos"]))
pathway_data.groupby("ER")["Zd"].describe()

## Analysis by Tumor Subtype

In [None]:
hr_status = pd.read_csv(
    root / "raw/Welm/receptor_status_combined.pdx.final.csv", index_col=0
)

# derive boolean subtype flags from the "pos" / "neg" receptor calls

# TNBC = "neg" for all three receptors; the test is restricted to the ER / PR / HER2
# columns so that any other column in the file cannot affect the flag
hr_status["is_TNBC"] = (hr_status[["ER", "PR", "HER2"]] == "neg").all(axis=1)
hr_status["is_HER2"] = hr_status["HER2"] == "pos"
hr_status["is_ER"] = hr_status["ER"] == "pos"
hr_status["is_PR"] = hr_status["PR"] == "pos"
# hormone-receptor positive = ER and/or PR positive
hr_status["is_HR"] = (hr_status["ER"] == "pos") | (hr_status["PR"] == "pos")

hr_status.head()

In [None]:
y_true_df = pdmc_ds.obs.copy()
y_true_df["y_true"] = y_true_df.groupby("drug_id")["label"].transform(stats.zscore)
y_true_df = y_true_df[["cell_id", "drug_id", "y_true"]]
y_true_df.head()

In [None]:
screendl_sa_preds = (
    model_results["ScreenDL"]
    .drop(columns="y_true")
    .query("cell_id not in @DROP_TUMORS")
    .query("model == 'ScreenDL-SA (ALL)'")
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
    .merge(hr_status, on="cell_id", how="left")
    .dropna(subset=["y_pred"])
    .copy()
)

screendl_sa_preds.head()

In [None]:
def get_significance_str(p: float) -> str:
    """Return the star annotation for a p-value.

    "***" for p < 0.001, "**" for p < 0.01, "*" for p < 0.05, otherwise "ns".
    """
    # cutoffs are checked from most to least stringent; the first hit wins
    cutoffs = ((0.001, "***"), (0.01, "**"), (0.05, "*"))
    for cutoff, stars in cutoffs:
        if p < cutoff:
            return stars
    return "ns"

In [None]:
# Default boxplot styling with the box size overridden to 35 for the panels below.
boxplot_config = {**DEFAULT_BOXPLOT_CONFIG, "size": 35}

In [None]:
# Non-platinum chemotherapies: every drug typed "Chemo" except the three
# platinum agents.
chemo_drugs = [
    k
    for k, v in drug_to_type.items()
    if v == "Chemo" and k not in ["Carboplatin", "Cisplatin", "Oxaliplatin"]
]

# .copy() gives this cell its own frame, so the column assignments below do
# not write into a filtered slice of screendl_sa_preds
source = (
    screendl_sa_preds.query("drug_id in @chemo_drugs")
    .dropna(subset=["is_TNBC"])
    .copy()
)
# rows with missing status were dropped above, so the cast to bool is lossless
source["has_biomarker"] = source["is_TNBC"].astype(bool)
source["x"] = source["has_biomarker"].map({True: "Yes", False: "No"})

# the boxplot reads its data from this JSON file; make sure the target
# directory exists before writing
url = "./temp/chemo_preds.json"
Path(url).parent.mkdir(parents=True, exist_ok=True)
source.to_json(url, orient="records")

# One-sided Mann-Whitney U test: are predicted responses to non-platinum
# chemotherapy greater in non-TNBC tumors than in TNBC tumors?
vals = source.groupby("has_biomarker")["y_pred"].agg(list)
U, p = stats.mannwhitneyu(vals[False], vals[True], alternative="greater")
print(f"U-statistic: {U:.2f}, p-value: {p:.10f}")

label = get_significance_str(p)

# y-position just above the max of both boxes:
y_max = source["y_pred"].max()
y = y_max + 1.0

# single annotation row: a bracket spanning both boxes plus the significance label
annots = pd.DataFrame([{"x1": "Yes", "x2": "No", "y": y, "label": label}])

box = (
    alt.Chart(url)
    .mark_boxplot(**boxplot_config)
    .encode(
        alt.X("x:N").axis(labelAngle=0).scale(domain=("Yes", "No")).title("TNBC"),
        alt.Y(
            "y_pred:Q",
            axis=alt.Axis(grid=False, offset=10, values=[-6, -4, -2, 0, 2, 4, 6]),
        )
        .scale(domain=(-6, 6))
        .title("Predicted Response (Zd)"),
        alt.Color("x:N")
        .scale(domain=("Yes", "No"), range=("#55A24A", "lightgray"))
        .legend(None)
        .title(None),
    )
    .properties(
        width=85,
        height=250,
        title=alt.TitleParams(
            ["Chemotherapy", "(Non-Platinum)"],
            fontSize=10,
            font="arial",
            fontWeight="normal",
            offset=0,
        ),
    )
)

# horizontal bracket between the two groups
rule = alt.Chart(annots).mark_rule().encode(x="x1:N", x2="x2:N", y="y:Q")

# significance label drawn on the bracket
text = (
    alt.Chart(annots)
    .mark_text(dy=-5, dx=20.5)
    .encode(alt.X("x1:N"), alt.X2("x2:N"), alt.Y("y:Q"), alt.Text("label:N"))
)

chemo_chart = box + rule + text
configure_chart(chemo_chart)

In [None]:
# Platinum agents; each one gets its own TNBC vs non-TNBC panel.
platinum_agents = [
    "Carboplatin",
    "Cisplatin",
    "Oxaliplatin",
]

platinum_charts = []
for i, drug_id in enumerate(platinum_agents):
    # .copy() gives the loop its own frame, so the column assignments below
    # do not write into a filtered slice of screendl_sa_preds
    drug_source = (
        screendl_sa_preds.query("drug_id == @drug_id")
        .dropna(subset=["is_TNBC"])
        .copy()
    )
    # rows with missing status were dropped above, so the cast to bool is lossless
    drug_source["has_biomarker"] = drug_source["is_TNBC"].astype(bool)
    drug_source["x"] = drug_source["has_biomarker"].map({True: "Yes", False: "No"})

    # One-sided Mann-Whitney U test: are predicted responses to this platinum
    # agent greater in non-TNBC tumors than in TNBC tumors?
    vals = drug_source.groupby("has_biomarker")["y_pred"].agg(list)
    U, p = stats.mannwhitneyu(vals[False], vals[True], alternative="greater")
    print(f"{drug_id} U-statistic: {U:.2f}, p-value: {p:.5f}")

    label = get_significance_str(p)

    # y-position just above the max of both boxes:
    y_max = drug_source["y_pred"].max()
    y = y_max + 1.0

    # single annotation row: a bracket spanning both boxes plus the significance label
    annots = pd.DataFrame([{"x1": "Yes", "x2": "No", "y": y, "label": label}])

    box = (
        alt.Chart(drug_source)
        .mark_boxplot(**boxplot_config)
        .encode(
            alt.X("x:N").axis(labelAngle=0).scale(domain=("Yes", "No")).title("TNBC"),
            alt.Y(
                "y_pred:Q",
                # only the first panel draws the y-axis
                axis=(
                    None
                    if i > 0
                    else alt.Axis(grid=False, offset=10, values=[-6, -4, -2, 0, 2, 4, 6])
                ),
            )
            .scale(domain=(-6, 6))
            .title("Predicted Response (Zd)"),
            alt.Color("x:N")
            .scale(domain=("Yes", "No"), range=("#55A24A", "lightgray"))
            .legend(None)
            .title(None),
        )
        .properties(
            width=85,
            height=250,
            title=alt.TitleParams(
                drug_id,
                fontSize=10,
                font="arial",
                fontWeight="normal",
                offset=0 if i == 0 else 5,  # HACK: align titles
            ),
        )
    )

    # horizontal bracket between the two groups
    rule = alt.Chart(annots).mark_rule().encode(x="x1:N", x2="x2:N", y="y:Q")

    # significance label drawn on the bracket
    text = (
        alt.Chart(annots)
        .mark_text(dy=-5, dx=20.5)
        .encode(alt.X("x1:N"), alt.X2("x2:N"), alt.Y("y:Q"), alt.Text("label:N"))
    )

    platinum_charts.append(box + rule + text)

platinum_chart = alt.hconcat(*platinum_charts, spacing=20)
configure_chart(platinum_chart).display()

In [None]:
# PI3K/AKT/mTOR-pathway drugs; each one gets its own HR+/HER2- panel.
pik3ca_mtor_drugs = [
    "Alpelisib",
    "AZD5363",  # Capivasertib
    "Everolimus",
]

pik3ca_mtor_charts = []
for i, drug_id in enumerate(pik3ca_mtor_drugs):
    # .copy() gives the loop its own frame, so the column assignments below
    # do not write into a filtered slice of screendl_sa_preds
    drug_source = (
        screendl_sa_preds.query("drug_id == @drug_id")
        .dropna(subset=["is_HR", "is_HER2"])
        .copy()
    )
    # The status flags come from a left merge and may be object dtype if any
    # tumor lacked receptor data. Cast to bool before negating so `~` is a
    # logical NOT rather than a bitwise inversion of Python bools. Rows with
    # missing status were dropped above, so the cast is lossless.
    is_hr = drug_source["is_HR"].astype(bool)
    is_her2 = drug_source["is_HER2"].astype(bool)
    drug_source["has_biomarker"] = is_hr & ~is_her2
    drug_source["x"] = drug_source["has_biomarker"].map({True: "Yes", False: "No"})

    # One-sided Mann-Whitney U test: are predicted responses to this drug
    # greater in tumors that are not HR+/HER2- than in HR+/HER2- tumors?
    vals = drug_source.groupby("has_biomarker")["y_pred"].agg(list)
    U, p = stats.mannwhitneyu(vals[False], vals[True], alternative="greater")
    print(f"{drug_id} U-statistic: {U:.2f}, p-value: {p:.5f}")

    label = get_significance_str(p)

    # y-position just above the max of both boxes:
    y_max = drug_source["y_pred"].max()
    y = y_max + 1.0

    # single annotation row: a bracket spanning both boxes plus the significance label
    annots = pd.DataFrame([{"x1": "Yes", "x2": "No", "y": y, "label": label}])

    box = (
        alt.Chart(drug_source)
        .mark_boxplot(**boxplot_config)
        .encode(
            alt.X("x:N")
            .axis(labelAngle=0)
            .scale(domain=("Yes", "No"))
            .title("HR+/HER2-"),
            alt.Y(
                "y_pred:Q",
                # only the first panel draws the y-axis
                axis=(
                    None
                    if i > 0
                    else alt.Axis(grid=False, offset=10, values=[-6, -4, -2, 0, 2, 4, 6])
                ),
            )
            .scale(domain=(-6, 6))
            .title("Predicted Response (Zd)"),
            alt.Color("x:N")
            .scale(domain=("Yes", "No"), range=("#55A24A", "lightgray"))
            .legend(None)
            .title(None),
        )
        .properties(
            width=85,
            height=250,
            title=alt.TitleParams(
                drug_id,
                fontSize=10,
                font="arial",
                fontWeight="normal",
                offset=0 if i == 0 else 5,  # HACK: align titles
            ),
        )
    )

    # horizontal bracket between the two groups
    rule = alt.Chart(annots).mark_rule().encode(x="x1:N", x2="x2:N", y="y:Q")

    # significance label drawn on the bracket
    text = (
        alt.Chart(annots)
        .mark_text(dy=-5, dx=20.5)
        .encode(alt.X("x1:N"), alt.X2("x2:N"), alt.Y("y:Q"), alt.Text("label:N"))
    )

    pik3ca_mtor_charts.append(box + rule + text)

pik3ca_mtor_chart = alt.hconcat(*pik3ca_mtor_charts, spacing=20)
configure_chart(pik3ca_mtor_chart).display()

In [None]:
# HER2-targeting drugs; each one gets its own HER2+ vs HER2- panel.
her2_drugs = [
    "Lapatinib",
    "Neratinib",
    "Tucatinib",
]

her2_charts = []
for i, drug_id in enumerate(her2_drugs):
    # .copy() gives the loop its own frame, so the column assignments below
    # do not write into a filtered slice of screendl_sa_preds
    drug_source = (
        screendl_sa_preds.query("drug_id == @drug_id")
        .dropna(subset=["is_HER2"])
        .copy()
    )
    # rows with missing status were dropped above, so the cast to bool is lossless
    drug_source["has_biomarker"] = drug_source["is_HER2"].astype(bool)
    drug_source["x"] = drug_source["has_biomarker"].map({True: "Yes", False: "No"})

    # One-sided Mann-Whitney U test: are predicted responses to this drug
    # greater in HER2- tumors than in HER2+ tumors?
    vals = drug_source.groupby("has_biomarker")["y_pred"].agg(list)
    U, p = stats.mannwhitneyu(vals[False], vals[True], alternative="greater")
    print(f"{drug_id} U-statistic: {U:.2f}, p-value: {p:.5f}")

    label = get_significance_str(p)

    # y-position just above the max of both boxes:
    y_max = drug_source["y_pred"].max()
    y = y_max + 1.0

    # single annotation row: a bracket spanning both boxes plus the significance label
    annots = pd.DataFrame([{"x1": "Yes", "x2": "No", "y": y, "label": label}])

    box = (
        alt.Chart(drug_source)
        .mark_boxplot(**boxplot_config)
        .encode(
            alt.X("x:N").axis(labelAngle=0).scale(domain=("Yes", "No")).title("HER2+"),
            alt.Y(
                "y_pred:Q",
                # only the first panel draws the y-axis
                axis=(
                    None
                    if i > 0
                    else alt.Axis(grid=False, offset=10, values=[-6, -4, -2, 0, 2, 4, 6])
                ),
            )
            .scale(domain=(-6, 6))
            .title("Predicted Response (Zd)"),
            alt.Color("x:N")
            .scale(domain=("Yes", "No"), range=("#55A24A", "lightgray"))
            .legend(None)
            .title(None),
        )
        .properties(
            width=85,
            height=250,
            title=alt.TitleParams(
                drug_id,
                fontSize=10,
                font="arial",
                fontWeight="normal",
                offset=0 if i == 0 else 5,  # HACK: align titles
            ),
        )
    )

    # horizontal bracket between the two groups
    rule = alt.Chart(annots).mark_rule().encode(x="x1:N", x2="x2:N", y="y:Q")

    # significance label drawn on the bracket
    text = (
        alt.Chart(annots)
        .mark_text(dy=-5, dx=20.5)
        .encode(alt.X("x1:N"), alt.X2("x2:N"), alt.Y("y:Q"), alt.Text("label:N"))
    )

    her2_charts.append(box + rule + text)

her2_chart = alt.hconcat(*her2_charts, spacing=20)
configure_chart(her2_chart).display()

In [None]:
# Combined biomarker figure: the four panel groups side by side, left to right.
_biomarker_panels = [pik3ca_mtor_chart, her2_chart, platinum_chart, chemo_chart]
configure_chart(alt.hconcat(*_biomarker_panels, spacing=40))

## ScreenAhead with vs without domain-specific fine-tuning

In [None]:
# Identify the multirun whose per-fold predictions are loaded below.
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2025-06-21_21-05-53"

run_regex = "*/predictions.csv"
run_dir = output_dir / path_fmt.format(dataset, model, date)

# raw model tags in the prediction files -> display names used in the figures
fixed_model_names = {
    "base": "ScreenDL-PT",
    "xfer": model + "-FT",
    "screen (fine-tune)": model + "-SA (+FT)",
    "screen (no-fine-tune)": model + "-SA (-FT)",
}

# NOTE: predictions are already rescaled for ScreenDL
results_df = load_multirun_predictions(run_dir, run_regex, splits=None)
results_df = results_df.assign(
    model=results_df["model"].map(fixed_model_names),
    # a missing was_screened flag is treated as "not screened"
    was_screened=results_df["was_screened"].fillna(False),
)

In [None]:
# Collapse repeated predictions to one row per (model, drug, tumor): keep the
# first observed value and take the 20%-trimmed mean of the predictions.
def _trimmed_mean(values):
    return stats.trim_mean(values, 0.2)


_labelled = results_df.dropna(subset=["y_true"])
_per_pair = _labelled.groupby(["model", "drug_id", "cell_id"])
results_df = _per_pair.agg({"y_true": "first", "y_pred": _trimmed_mean}).reset_index()
results_df.head()

In [None]:
from utils.plot import NPGPalette

# Display order of the ScreenDL variants compared in this section.
MODELS = [
    "ScreenDL-PT",
    "ScreenDL-FT",
    "ScreenDL-SA (-FT)",
    "ScreenDL-SA (+FT)",
]

# One palette entry per model, in the same order as MODELS.
# NOTE(review): both ScreenDL-SA variants use PURPLE_DARK — confirm that
# sharing one color is intended.
_model_palette = (
    NPGPalette.PURPLE_LIGHT,
    NPGPalette.PURPLE,
    NPGPalette.PURPLE_DARK,
    NPGPalette.PURPLE_DARK,
)

MODEL_COLOR_SCALE = alt.Scale(
    domain=MODELS,
    range=tuple(color.value for color in _model_palette),
)

In [None]:
# Rebuild the observed responses: `label` z-scored within each drug.
_obs = pdmc_ds.obs
_zscored_label = _obs.groupby("drug_id")["label"].transform(stats.zscore)
y_true_df = _obs.assign(y_true=_zscored_label).loc[
    :, ["cell_id", "drug_id", "y_true"]
]
y_true_df.sort_values(["cell_id", "drug_id"]).head()

In [None]:
# Swap in the per-drug z-scored observed responses, dropping pairs without one.
_scored = results_df.drop(columns="y_true").merge(
    y_true_df, on=["cell_id", "drug_id"], how="left"
)
_scored = _scored.dropna(subset="y_true")

# one score per (model, drug), computed by eval_utils.pcorr
_pcc = _scored.groupby(["model", "drug_id"]).apply(eval_utils.pcorr, min_obs=MIN_OBS)
ensembl_pcc_result = _pcc.to_frame("pcc").reset_index()

# annotate each drug with its type, pathway, and whether it is absent from
# cell_ds.drug_ids
ensembl_pcc_result = ensembl_pcc_result.assign(
    drug_type=ensembl_pcc_result["drug_id"].map(drug_to_type),
    pathway=ensembl_pcc_result["drug_id"].map(drug_to_pathway),
    is_ood_drug=~ensembl_pcc_result["drug_id"].isin(cell_ds.drug_ids),
)

ensembl_pcc_result.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Paired comparison of ScreenAhead without vs with domain-specific fine-tuning.
m1 = "ScreenDL-SA (-FT)"
m2 = "ScreenDL-SA (+FT)"

# drugs x models table of scores; drugs missing a score for any model are dropped
_pcc_by_drug_model = ensembl_pcc_result.set_index(["drug_id", "model"])["pcc"]
temp_ = _pcc_by_drug_model.unstack().dropna()

stats.wilcoxon(temp_[m1], temp_[m2])

In [None]:
# Label each drug by whether it was screened in cell lines (i.e. is not OOD).
source = ensembl_pcc_result.assign(
    screened_in_cells=lambda df: df["is_ood_drug"].map({True: "No", False: "Yes"})
)

# encoding shared by both panels
_model_color = alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None)

# left panel: one box per model
_x_model = (
    alt.X("model:N").axis(labelAngle=-45, labelPadding=5).sort(MODELS).title(None)
)
_y_with_axis = (
    alt.Y("pcc:Q")
    .axis(tickCount=5, grid=False, titlePadding=10)
    .scale(domain=[-1, 1])
    .title("Pearson Correlation")
)
boxes_1 = (
    alt.Chart(source)
    .mark_boxplot(**DEFAULT_BOXPLOT_CONFIG)
    .encode(_x_model, _y_with_axis, _model_color)
    .properties(width=35 * len(MODELS), height=250)
)

# right panel: one column per model, split by whether the drug was screened in
# cell lines; the y-axis is hidden here
_x_screened = (
    alt.X("screened_in_cells:N")
    .axis(labelAngle=0, labelPadding=5, orient="bottom")
    .sort(["No", "Yes"])
    .title(None)
)
_y_no_axis = (
    alt.Y("pcc:Q").axis(None).scale(domain=[-1, 1]).title("Pearson Correlation")
)
_model_columns = (
    alt.Column("model:N").header(orient="top").spacing(10).sort(MODELS).title(None)
)
boxes_2 = (
    alt.Chart(source)
    .mark_boxplot(**DEFAULT_BOXPLOT_CONFIG)
    .encode(_x_screened, _y_no_axis, _model_color, _model_columns)
    .properties(width=35 * 2, height=250)
)

configure_chart(alt.hconcat(boxes_1, boxes_2, spacing=10))