# Performance Evaluation in PDX Models

## Contents

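- [Data Loading](#data-loading)
- [Aggregate Ensemble](#aggregate-ensemble)
- [Identify PDX samples for evaluation](#identify-pdx-samples-for-evaluation)
- [Confusion Matrices](#confusion-matrices)
- [Waterfall Plots](#waterfall-plots)
- [Final Combined Figure](#final-combined-figure)
- [Delta Tumor Volume Plots](#delta-tumor-volume-plots)
- [Check for Bias in the Screening Data](#check-for-bias-in-the-screening-data)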

In [None]:
from __future__ import annotations

import altair as alt
import pandas as pd
import numpy as np
import typing as t

from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf

from cdrpy.feat.transformers import GroupStandardScaler
from cdrpy.datasets import Dataset

from screendl.utils import evaluation as eval_utils

from utils.plot import DEFAULT_BOXPLOT_CONFIG, MODEL_COLORS, configure_chart
from utils.eval import ResponseRateEvaluator, auroc, select_best_therapy

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Load and stack the per-run prediction CSVs of a multirun.

    Every file matched by ``regex`` is read with :func:`pandas.read_csv` and tagged
    with an integer ``iter`` column parsed from its parent directory name (the text
    after the last ``_``, or the whole name when there is no ``_``).

    Parameters
    ----------
    multirun_dir : str or Path
        Root directory of the multirun.
    regex : str
        Glob pattern (not a regular expression) relative to ``multirun_dir``,
        e.g. ``"*/predictions.csv"``.
    splits : list of str, optional
        If given, keep only rows whose ``split_group`` is one of these values.

    Returns
    -------
    pd.DataFrame
        All matched predictions concatenated, in sorted file-path order. The
        index is the per-file index (not reset).

    Raises
    ------
    FileNotFoundError
        If ``regex`` matches no files under ``multirun_dir``.
    """
    multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        # run directories are expected to end in "_<int>" or be a bare integer
        iter_id = file_path.parent.stem.split("_")[-1]
        iter_pred_df = pd.read_csv(file_path)
        iter_pred_df["iter"] = int(iter_id)
        return iter_pred_df

    # sort so the row order is reproducible (glob order is filesystem-dependent)
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        raise FileNotFoundError(
            f"No files matching {regex!r} found under {multirun_dir}"
        )

    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
def rescale_predictions(
    pdxo_df: pd.DataFrame, pdx_df: pd.DataFrame
) -> t.Tuple[pd.DataFrame, pd.DataFrame]:
    """Rescale PDXO labels/predictions per drug and map PDX predictions onto that scale.

    Uses cdrpy's ``GroupStandardScaler`` grouped by ``drug_id``:

    - ``pdxo_df["y_true"]`` and ``pdxo_df["y_pred"]`` are each fit/transformed with
      their own scaler.
    - The ``y_pred`` scaler fitted on PDXO is then applied (transform only) to
      ``pdx_df["y_pred"]``.
    - ``pdx_df["y_true"]`` is left unchanged.

    The input frames are not modified; rescaled copies are returned.

    Parameters
    ----------
    pdxo_df : pd.DataFrame
        PDXO results with ``drug_id``, ``y_true`` and ``y_pred`` columns.
    pdx_df : pd.DataFrame
        PDX results with ``drug_id`` and ``y_pred`` columns.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        The rescaled ``(pdxo_df, pdx_df)``.
    """
    # work on copies so callers' frames (possibly slices) are not mutated
    pdxo_df = pdxo_df.copy()
    pdx_df = pdx_df.copy()

    y_true_scaler = GroupStandardScaler()
    pdxo_df["y_true"] = y_true_scaler.fit_transform(
        pdxo_df[["y_true"]], groups=pdxo_df["drug_id"]
    )
    # NOTE: PDX y_true is intentionally not rescaled here.

    # predictions: fit on PDXO, reuse the same per-drug scaling for PDX
    y_pred_scaler = GroupStandardScaler()
    pdxo_df["y_pred"] = y_pred_scaler.fit_transform(
        pdxo_df[["y_pred"]], groups=pdxo_df["drug_id"]
    )
    pdx_df["y_pred"] = y_pred_scaler.transform(
        pdx_df[["y_pred"]], groups=pdx_df["drug_id"]
    )

    return pdxo_df, pdx_df

In [None]:
# Models compared throughout the notebook: two external deep-learning models
# followed by the three ScreenDL variants.
MODELS = ["HiDRA", "DeepCDR", "ScreenDL-PT", "ScreenDL-FT", "ScreenDL-SA"]

# Extended model list: the same models with the Ridge / Random Forest baselines
# inserted between the external models and the ScreenDL variants.
_BASELINE_MODELS = [
    "Ridge (C)",
    "Ridge (P)",
    "Random Forest (C)",
    "Random Forest (P)",
]
MODELS_EXT = [*MODELS[:2], *_BASELINE_MODELS, *MODELS[2:]]

In [None]:
# Restrict the shared model palette to the models plotted in this notebook and
# build an Altair color scale whose domain/range follow MODEL_COLORS' order.
model_to_color = {k: v for k, v in MODEL_COLORS.items() if k in MODELS_EXT}
MODEL_COLOR_SCALE = alt.Scale(
    domain=list(model_to_color.keys()), range=list(model_to_color.values())
)

## Data Loading

In [None]:
# Root of the local datastore; most input/output paths below are built from it.
root = Path("../../../datastore")

In [None]:
# Load the LogIC50 label set with its drug/sample annotations and split it by the
# sample `domain` annotation into cell lines ("CELL") and PDMCs ("PDMC").
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2-HCI-v1.0.0"

drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
cell_meta = pd.read_csv(dataset_dir / "MetaSampleAnnotations.csv", index_col=0)

D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name="CellModelPassports-GDSC-HCI",
)

cell_ids = D.cell_meta[D.cell_meta["domain"] == "CELL"].index
pdmc_ids = D.cell_meta[D.cell_meta["domain"] == "PDMC"].index

cell_ds = D.select_cells(cell_ids, name="cell_ds")
pdmc_ds = D.select_cells(pdmc_ids, name="pdmc_ds")
# z-score PDMC labels within each drug
pdmc_ds.obs["label"] = pdmc_ds.obs.groupby("drug_id")["label"].transform(stats.zscore)

print(cell_ds)
print(pdmc_ds)

In [None]:
# PDX in vivo response table; later cells use its cell_id, drug_id, label and
# mRECIST columns.
raw_pdx_obs = pd.read_csv(root / "processed/WelmPDX/ScreenClinicalResponseV14B20.csv")
raw_pdx_obs.head()

In [None]:
# PDXO observed responses: the LogIC50 label file without the "SIDM"-prefixed
# samples (presumably the Cell Model Passports cell lines), z-scored per drug.
raw_pdxo_obs = pd.read_csv(
    root / "inputs/CellModelPassports-GDSCv1v2-HCI-v1.0.0/LabelsLogIC50.csv"
)

# .copy() so the per-drug z-score below writes to an independent frame rather
# than a view of the unfiltered one (avoids SettingWithCopyWarning)
raw_pdxo_obs = raw_pdxo_obs[~raw_pdxo_obs["cell_id"].str.startswith("SIDM")].copy()
raw_pdxo_obs["label"] = raw_pdxo_obs.groupby("drug_id")["label"].transform(stats.zscore)
raw_pdxo_obs.head()

In [None]:
# PDXO drug screen dose-response summaries, with per-drug z-scores of GR_AOC and
# LN_IC50 added as z_GR_AOC / z_LN_IC50.
raw_pdxo_screen = pd.read_csv(
    root / "processed/WelmBreastPDMC-v1.0.0/internal/ScreenDoseResponse.csv"
)

# silence invalid-value warnings from zscore (presumably drugs whose values are
# constant or missing, giving 0/0)
with np.errstate(invalid="ignore"):
    # NOTE: a large GR_AOC is better so we multiply Zd values by -1 to flip sign
    grouped = raw_pdxo_screen.groupby("drug_name")
    raw_pdxo_screen["z_GR_AOC"] = grouped["GR_AOC"].transform(
        lambda x: stats.zscore(x) * -1
    )
    raw_pdxo_screen["z_LN_IC50"] = grouped["LN_IC50"].transform(stats.zscore)

raw_pdxo_screen.head()

In [None]:
# Longitudinal PDX tumor-volume data (used for the delta tumor volume plots),
# split into vehicle-control rows and drug-treated rows.
raw_pdx_data = pd.read_csv(
    root / "processed/WelmPDX/ScreenClinicalResponseV14B20RawData.csv"
)
raw_pdx_data_ctrl = raw_pdx_data[raw_pdx_data["drug_name"] == "Vehicle"]
raw_pdx_data_drug = raw_pdx_data[raw_pdx_data["drug_name"] != "Vehicle"]

In [None]:
# DeepCDR results
output_dir = root / "outputs/experiments/pdx_benchmarking"
path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI-Mutations"
model = "DeepCDR-legacy"
date = "2025-06-15_09-41-23"

run_dir = output_dir / path_fmt.format(dataset, model, date)


deepcdr_pdxo_result = load_multirun_predictions(
    run_dir, "*/predictions.csv", splits=["test"]
)
deepcdr_pdx_result = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# NOTE: predictions are made for each tumor-drug pair so we can just z-score across
# the entire distribution of predictions for each drug, separately for every run.
# Only PDX runs that also have PDXO predictions are kept, and rows are ordered by
# iter. (This replaces the manual per-iter loop, which relied on the deprecated
# `DataFrameGroupBy.grouper` attribute.)
deepcdr_pdx_result = deepcdr_pdx_result[
    deepcdr_pdx_result["iter"].isin(deepcdr_pdxo_result["iter"])
]
deepcdr_pdxo_result = deepcdr_pdxo_result.sort_values("iter", kind="stable").reset_index(
    drop=True
)
deepcdr_pdx_result = deepcdr_pdx_result.sort_values("iter", kind="stable").reset_index(
    drop=True
)
deepcdr_pdxo_result["y_pred"] = deepcdr_pdxo_result.groupby(["iter", "drug_id"])[
    "y_pred"
].transform(stats.zscore)
deepcdr_pdx_result["y_pred"] = deepcdr_pdx_result.groupby(["iter", "drug_id"])[
    "y_pred"
].transform(stats.zscore)

In [None]:
# HiDRA results
output_dir = root / "outputs/experiments/pdx_benchmarking"
path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "HiDRA-legacy"
date = "2025-06-15_14-14-43"

run_dir = output_dir / path_fmt.format(dataset, model, date)

hidra_pdxo_result = load_multirun_predictions(
    run_dir, "*/predictions.csv", splits=["test"]
)
hidra_pdx_result = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# z-score predictions per drug, separately for every run (iter). Only PDX runs that
# also have PDXO predictions are kept, and rows are ordered by iter. (This replaces
# the manual per-iter loop, which relied on the deprecated
# `DataFrameGroupBy.grouper` attribute.)
hidra_pdx_result = hidra_pdx_result[
    hidra_pdx_result["iter"].isin(hidra_pdxo_result["iter"])
]
hidra_pdxo_result = hidra_pdxo_result.sort_values("iter", kind="stable").reset_index(
    drop=True
)
hidra_pdx_result = hidra_pdx_result.sort_values("iter", kind="stable").reset_index(
    drop=True
)
hidra_pdxo_result["y_pred"] = hidra_pdxo_result.groupby(["iter", "drug_id"])[
    "y_pred"
].transform(stats.zscore)
hidra_pdx_result["y_pred"] = hidra_pdx_result.groupby(["iter", "drug_id"])[
    "y_pred"
].transform(stats.zscore)

In [None]:
# ScreenDL results
# Load PDXO and PDX predictions for all runs of the ScreenDL multirun.
output_dir = root / "outputs/experiments/pdx_validation"

path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2025-06-13_09-21-25"

run_dir = output_dir / path_fmt.format(dataset, model, date)
screendl_pdxo_results = load_multirun_predictions(run_dir, "*/predictions_pdxo.csv")
screendl_pdx_results = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# fix model names
# NOTE: Series.map turns any model name not listed in `mapper` into NaN.
mapper = {"base": "ScreenDL-PT", "xfer": "ScreenDL-FT", "screen": "ScreenDL-SA"}
screendl_pdxo_results["model"] = screendl_pdxo_results["model"].map(mapper)
screendl_pdx_results["model"] = screendl_pdx_results["model"].map(mapper)

In [None]:
# Show the "xfer" section of the multirun config for the run_dir loaded last
# (the ScreenDL multirun).
conf = OmegaConf.load(run_dir / "multirun.yaml")
print(OmegaConf.to_yaml(conf.xfer))

In [None]:
# linear/nonlinear baselines
output_dir = root / "outputs/experiments/pdx_baselines"
path_fmt = "{0}/runs/{1}/predictions.pdxo.csv"
dataset = "CellModelPassports-GDSCv1v2-HCI"
# (run date, label for models without a "-0" suffix, label for "-0" models)
run_dates = [
    ("2025-06-28_14-06-23", "Ridge (C)", "Ridge (P)"),
    ("2025-06-28_12-27-18", "Random Forest (C)", "Random Forest (P)"),
]

baseline_results = []
for run_date, name, name0 in run_dates:
    file_path = output_dir / path_fmt.format(dataset, run_date)
    # run config; kept for interactive inspection (not used below)
    cfg = OmegaConf.load(file_path.parent / ".hydra/config.yaml")
    run_preds = pd.read_csv(file_path)
    # relabel rows directly instead of binding a lambda to the loop variables
    run_preds["model"] = [
        name0 if model_name.endswith("-0") else name
        for model_name in run_preds["model"]
    ]
    baseline_results.append(run_preds)

baseline_results = pd.concat(baseline_results)
baseline_results.sort_values(["cell_id", "drug_id"]).dropna(subset="y_true").head()

In [None]:
# Stack all models' PDX predictions (deep models + baselines) and PDXO predictions.
# NOTE(review): the baseline predictions come from files named
# "predictions.pdxo.csv" but are added to the PDX results only -- confirm these
# files hold PDX predictions.
combined_pdx_result = pd.concat(
    [hidra_pdx_result, deepcdr_pdx_result, screendl_pdx_results, baseline_results]
)
combined_pdxo_result = pd.concat(
    [hidra_pdxo_result, deepcdr_pdxo_result, screendl_pdxo_results]
)

## Aggregate Ensemble

In [None]:
# Ensemble each model's runs: 20%-trimmed mean of y_pred per model/drug/tumor.
# trim_mean of a single value is that value, so no special case for n == 1 is
# needed (and the aggregator always returns a scalar, never a length-1 Series).
agg_funcs = {
    "y_true": "first",
    "y_pred": lambda x: stats.trim_mean(x, 0.2),
}

# observed PDX outcomes (mRECIST etc.) to attach to the ensemble predictions
y_true_df = raw_pdx_obs.drop(columns=["id", "label"])

ensemble_pdx_result = (
    combined_pdx_result.query("cell_id in @screendl_pdx_results.cell_id.unique()")
    .groupby(["model", "drug_id", "cell_id"])
    .aggregate(agg_funcs)
    .reset_index()
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
)

# binary outcomes: pre-clinical benefit (SD/PR/CR), objective response (PR/CR),
# complete response (CR); missing mRECIST gives False
ensemble_pdx_result = ensemble_pdx_result.assign(
    y_true_pcbr=lambda df: df["mRECIST"].isin(["SD", "PR", "CR"]),
    y_true_orr=lambda df: df["mRECIST"].isin(["PR", "CR"]),
    y_true_cr=lambda df: df["mRECIST"].eq("CR"),
)

# rank drugs within each model/tumor by prediction (1 = lowest y_pred)
grouped = ensemble_pdx_result.groupby(["model", "cell_id"])
ensemble_pdx_result["rank_pred"] = grouped["y_pred"].transform(
    lambda x: x.rank(ascending=True, method="first")
)

ensemble_pdx_result.head()

## Identify PDX samples for evaluation

In [None]:
# Keep only tumor-drug pairs with an observed mRECIST outcome.
ensemble_pdx_result_f = ensemble_pdx_result.dropna(subset="mRECIST").copy()

# drugs that did not work in any PDX models -> this is likely due to legacy dosing
# (drop drugs with zero pre-clinical-benefit responders across all tumors)
responders_per_drug = ensemble_pdx_result_f.groupby("drug_id")["y_true_pcbr"].sum()
DRUGS_WITH_RESPONDERS = responders_per_drug[responders_per_drug >= 1].index
ensemble_pdx_result_f = ensemble_pdx_result_f.query("drug_id in @DRUGS_WITH_RESPONDERS")

# drop PDX samples with fewer than 2 drugs to choose from
drugs_per_PDX = ensemble_pdx_result_f.groupby("cell_id")["drug_id"].nunique()
GOOD_PDX_SAMPLES = drugs_per_PDX[drugs_per_PDX >= 2].index
ensemble_pdx_result_f = ensemble_pdx_result_f.query("cell_id in @GOOD_PDX_SAMPLES")

uniq_cells = sorted(list(ensemble_pdx_result_f["cell_id"].unique()))
uniq_drugs = sorted(list(ensemble_pdx_result_f["drug_id"].unique()))
print(uniq_cells)
print(uniq_drugs)

print(f"No. PDXs: {len(uniq_cells)}")
print(f"No. Drugs: {len(uniq_drugs)}")

# overall response rates across the unique evaluated tumor-drug pairs
temp = ensemble_pdx_result_f.drop_duplicates(["cell_id", "drug_id"])
base_pcbr = temp["y_true_pcbr"].mean()
base_orr = temp["y_true_orr"].mean()
print(f"pCBR: {base_pcbr:.2f}")
print(f"ORR: {base_orr:.2f}")

In [None]:
# Sanity check: number of evaluated drugs per model (rows) and tumor (columns);
# shows only tumors whose drug count differs between models.
counts = (
    ensemble_pdx_result_f.groupby(["model", "cell_id"], dropna=False)["drug_id"]
    .nunique(dropna=False)
    .unstack()
)
counts.loc[:, counts.nunique(dropna=False) > 1]

In [None]:
# Prepend the raw-screen baseline ("Screen - Zd") to the shared color scale; the
# guard keeps re-runs of this cell from inserting it twice.
if "Screen - Zd" not in MODEL_COLOR_SCALE.domain:
    MODEL_COLOR_SCALE.domain.insert(0, "Screen - Zd")
    MODEL_COLOR_SCALE.range.insert(0, "lightgray")

In [None]:
# Model groups for the response-rate panels: each panel compares one reference
# (HiDRA + classical baselines, DeepCDR, or the raw screen) against the ScreenDL
# variants (MODELS[2:]).
groups = [
    [
        "HiDRA",
        "Ridge (C)",
        "Ridge (P)",
        "Random Forest (C)",
        "Random Forest (P)",
        *MODELS[2:],
    ],
    [
        "DeepCDR",
        *MODELS[2:],
    ],
    [
        "Screen - Zd",
        *MODELS[2:],
    ],
]

In [None]:
def render_rr_barchart(
    rr_metrics: pd.DataFrame,
    title: str,
    y_var: str,
    axis: bool = True,
) -> alt.Chart:
    """Draw one bar per model for a response-rate column, annotated with its value.

    Parameters
    ----------
    rr_metrics : pd.DataFrame
        One row per model, with a ``model`` column and the ``y_var`` column
        holding a fraction in [0, 1].
    title : str
        Y-axis title.
    y_var : str
        Name of the column plotted on the y axis.
    axis : bool, default True
        Whether to draw the y axis (disable for panels placed to the right).

    Returns
    -------
    alt.Chart
        Layered chart: bars colored by ``MODEL_COLOR_SCALE`` in the row order of
        ``rr_metrics``, plus rotated percentage labels anchored at y = 0.
    """
    model_order = rr_metrics["model"].to_list()

    if axis:
        y_axis = alt.Axis(
            grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%"
        )
    else:
        y_axis = None

    x_enc = (
        alt.X("model:N")
        .sort(model_order)
        .axis(domainColor="black", labelAngle=-45)
        .scale(paddingOuter=0.15, domain=model_order)
        .title(None)
    )
    shared = alt.Chart(rr_metrics).encode(x_enc)

    bar_layer = (
        shared.mark_bar(stroke="black", size=26, strokeWidth=1)
        .encode(
            alt.Y(f"{y_var}:Q", axis=y_axis).scale(domain=(0, 1)).title(title),
            alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
            tooltip=[alt.Tooltip(f"{y_var}:Q", title=y_var)],
        )
        .properties(height=220, width=30 * len(rr_metrics))
    )

    label_style = dict(
        align="left",
        yOffset=-15,
        xOffset=-1,
        fontSize=10,
        font="arial",
        angle=315,
        baseline="middle",
    )
    label_layer = shared.mark_text(**label_style).encode(
        alt.YValue(0),
        alt.Text(f"{y_var}:Q", format=".1%"),
    )

    return bar_layer + label_layer

In [None]:
# Use the raw PDXO screen directly as a "model": its per-drug z-scored LN_IC50
# becomes y_pred under the model name "Screen - Zd".
raw_screen_preds_pdxo_ln_ic50 = (
    raw_pdxo_screen.rename(
        columns={"model_id": "cell_id", "drug_name": "drug_id", "z_LN_IC50": "y_pred"}
    )
    .filter(items=["cell_id", "drug_id", "y_pred"])
    .assign(model="Screen - Zd")
)

# Pair the screen "predictions" with observed PDX outcomes, restricted to the
# tumors and drugs kept for evaluation.
raw_screen_preds_pdx_ln_ic50 = (
    raw_pdx_obs.rename(columns={"label": "y_true"})
    .filter(items=["cell_id", "drug_id", "y_true", "mRECIST"])
    .drop_duplicates()
    .merge(raw_screen_preds_pdxo_ln_ic50, on=["cell_id", "drug_id"])
    .query("cell_id in @ensemble_pdx_result_f.cell_id")
    .query("drug_id in @ensemble_pdx_result_f.drug_id")
    .dropna(subset="y_pred")
)

# same binary outcome columns as the model results
raw_screen_pdx_result_f = raw_screen_preds_pdx_ln_ic50.assign(
    y_true_pcbr=lambda df: df["mRECIST"].isin(["SD", "PR", "CR"]),
    y_true_orr=lambda df: df["mRECIST"].isin(["PR", "CR"]),
    y_true_cr=lambda df: df["mRECIST"].eq("CR"),
).copy()

raw_screen_pdx_result_f.head()

In [None]:
# Model ensembles plus the raw-screen baseline, used by the response-rate panels.
combined_source = pd.concat([ensemble_pdx_result_f, raw_screen_pdx_result_f])

In [None]:
def get_baseline_rr(
    df: pd.DataFrame, n_iter: int = 1000, random_state: int | None = None
) -> t.Tuple[float, float]:
    """Estimate pCBR and ORR obtained by picking one drug per tumor at random.

    Rows are first deduplicated on (``cell_id``, ``drug_id``). Each iteration then
    samples one drug per ``cell_id`` uniformly at random and takes the mean of
    ``y_true_pcbr`` and ``y_true_orr`` over the sampled rows. The returned values
    are the averages of those means over ``n_iter`` iterations.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``cell_id``, ``drug_id``, ``y_true_pcbr`` and ``y_true_orr``.
    n_iter : int, default 1000
        Number of random draws.
    random_state : int, optional
        Seed for reproducible sampling. ``None`` (default) samples from NumPy's
        global random state, as before.

    Returns
    -------
    tuple of (float, float)
        ``(random_pcbr, random_orr)``.
    """
    outcome_cols = ["y_true_pcbr", "y_true_orr"]
    # groupby is loop-invariant, so build it once
    by_tumor = df.drop_duplicates(["cell_id", "drug_id"]).groupby("cell_id")

    # one Generator shared across iterations: draws differ per iteration but the
    # whole sequence is reproducible for a fixed seed
    rng = None if random_state is None else np.random.default_rng(random_state)

    rand_pcbr = []
    rand_orr = []
    for _ in range(n_iter):
        rr = by_tumor.sample(1, random_state=rng)[outcome_cols].mean()
        rand_pcbr.append(rr["y_true_pcbr"])
        rand_orr.append(rr["y_true_orr"])

    return np.mean(rand_pcbr), np.mean(rand_orr)

In [None]:
# Pre-clinical benefit rate (pCBR) of each model's selected therapy, one panel per
# model group, with the random-selection baseline drawn as a dashed rule.
rre = ResponseRateEvaluator(y_true_var="y_true_pcbr", n_iter=10000)

charts = []
for i, group in enumerate(groups):
    temp = combined_source.query("model in @group").copy()

    # keep only tumor-drug pairs that have predictions from every model in the
    # group, so all bars in a panel are scored on the same pairs
    # (.to_numpy() makes the mask positional; the concatenated index may repeat)
    models_per_pair = temp.groupby(["cell_id", "drug_id"])["model"].transform("nunique")
    temp_f = temp[(models_per_pair == len(group)).to_numpy()].copy()
    n_valid_pairs = temp_f.groupby(["cell_id", "drug_id"]).ngroups

    ensemble_rr_metrics = (
        temp_f.reset_index()
        .dropna(subset="y_pred")
        .astype({"y_true_pcbr": int})
        .groupby("model")
        .apply(rre.eval)
        .to_frame("pCBR")
        .loc[group]
        .reset_index()
    )

    rand_pcbr, _ = get_baseline_rr(temp_f)
    print(f"Random pCBR: {rand_pcbr:.2f} (n = {n_valid_pairs})")

    bars = render_rr_barchart(
        ensemble_rr_metrics,
        title="Pre-Clinical Benefit Rate (%)",
        y_var="pCBR",
        axis=(i == 0),
    )

    rule = (
        alt.Chart(pd.DataFrame({"y": [rand_pcbr]}))
        .mark_rule(stroke="black", strokeWidth=1.5, strokeDash=[2, 2])
        .encode(y="y:Q")
    )

    charts.append(bars + rule)

pcbr_chart = alt.hconcat(*charts, spacing=10)
configure_chart(pcbr_chart)

In [None]:
# Objective response rate (ORR) of each model's selected therapy, one panel per
# model group, with the random-selection baseline drawn as a dashed rule.
rre = ResponseRateEvaluator(y_true_var="y_true_orr", n_iter=10000)

charts = []
for i, group in enumerate(groups):
    temp = combined_source.query("model in @group").copy()

    # keep only tumor-drug pairs that have predictions from every model in the
    # group, so all bars in a panel are scored on the same pairs
    # (.to_numpy() makes the mask positional; the concatenated index may repeat)
    models_per_pair = temp.groupby(["cell_id", "drug_id"])["model"].transform("nunique")
    temp_f = temp[(models_per_pair == len(group)).to_numpy()].copy()
    n_valid_pairs = temp_f.groupby(["cell_id", "drug_id"]).ngroups

    # rows are already in `group` order via .loc[group]
    ensemble_rr_metrics = (
        temp_f.reset_index()
        .dropna(subset="y_pred")
        .astype({"y_true_orr": int})
        .groupby("model")
        .apply(rre.eval)
        .to_frame("ORR")
        .loc[group]
        .reset_index()
    )

    _, rand_orr = get_baseline_rr(temp_f)
    print(f"Random ORR: {rand_orr:.2f} (n = {n_valid_pairs})")

    bars = render_rr_barchart(
        ensemble_rr_metrics,
        title="Objective Response Rate (%)",
        y_var="ORR",
        axis=(i == 0),
    )

    rule = (
        alt.Chart(pd.DataFrame({"y": [rand_orr]}))
        .mark_rule(stroke="black", strokeWidth=1.5, strokeDash=[2, 2])
        .encode(y="y:Q")
    )

    charts.append(bars + rule)

orr_chart = alt.hconcat(*charts, spacing=10)
configure_chart(orr_chart)

## Confusion Matrices

In [None]:
# One selected therapy per model and tumor, chosen by utils.eval.select_best_therapy
# from the ensemble predictions (y_pred).
ensemble_selected_drugs = (
    ensemble_pdx_result_f.dropna(subset="y_pred")
    .astype({"y_true_pcbr": int, "y_true_orr": int})
    .groupby(["model", "cell_id"], as_index=False)
    .apply(select_best_therapy, y_pred_var="y_pred")
)

ensemble_selected_drugs.query("model == 'ScreenDL-SA'").head()

In [None]:
# Per-model counts of (pCBR outcome, was the drug selected?) over the evaluated
# tumor-drug pairs, then Fisher's exact test on the ScreenDL-SA 2x2 table.
# NOTE(review): fisher_exact needs a complete 2x2 table; if one combination has
# no rows, unstack() fills it with NaN.
pcbr_confusion_matrices = (
    ensemble_pdx_result_f.dropna(subset="mRECIST")
    .merge(
        ensemble_selected_drugs[["model", "cell_id", "drug_id"]].assign(
            was_selected=True
        ),
        on=["model", "cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
    .groupby("model")[["y_true_pcbr", "was_selected"]]
    .value_counts(dropna=False)
)

stats.fisher_exact(pcbr_confusion_matrices.xs("ScreenDL-SA").unstack().values)

In [None]:
# ScreenDL-SA pCBR counts: print Fisher's exact test and show the 2x2 table.
pcbr_source = pcbr_confusion_matrices.xs("ScreenDL-SA")
print(stats.fisher_exact(pcbr_source.unstack().values))
pcbr_source.unstack()

In [None]:
# Heatmap of the ScreenDL-SA pCBR confusion matrix (observed benefit vs. selected).
# NOTE: stored as `pcbr_heatmap` rather than `base_pcbr` so it does not overwrite
# the baseline pCBR rate computed earlier under that name.
pcbr_heatmap = (
    alt.Chart(
        pcbr_source.to_frame("count")
        .reset_index()
        .assign(
            y_true_pcbr=lambda df: df["y_true_pcbr"].map({True: "Yes", False: "No"}),
            was_selected=lambda df: df["was_selected"].map({True: "Yes", False: "No"}),
        )
    )
    .mark_rect()
    .encode(
        alt.X("was_selected:N")
        .axis(labelAngle=0)
        .scale(reverse=True)
        .title("ScreenDL Optimal Therapy"),
        alt.Y("y_true_pcbr:N").scale(reverse=True).title("Pre-Clinical Benefit"),
        alt.Color("count:Q")
        .scale(scheme="blues")
        .legend(gradientLength=220)
        .title(None),
        tooltip=["y_true_pcbr:N", "was_selected:N", "count:Q"],
    )
    .properties(width=220, height=220)
)


# overlay the counts; black text on low-count (light) cells, white otherwise
pcbr_confusion_mtx_plot = pcbr_heatmap + pcbr_heatmap.mark_text(baseline="middle").encode(
    alt.Text("count:Q", format=".0f"),
    alt.condition(alt.datum.count < 10, alt.ColorValue("black"), alt.ColorValue("white")),
)

configure_chart(pcbr_confusion_mtx_plot)

In [None]:
# Per-model counts of (objective response, was the drug selected?) over the
# evaluated tumor-drug pairs, then Fisher's exact test for ScreenDL-SA.
# NOTE(review): fisher_exact needs a complete 2x2 table; if one combination has
# no rows, unstack() fills it with NaN.
orr_confusion_matrices = (
    ensemble_pdx_result_f.dropna(subset="mRECIST")
    .merge(
        ensemble_selected_drugs[["model", "cell_id", "drug_id"]].assign(
            was_selected=True
        ),
        on=["model", "cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
    .groupby("model")[["y_true_orr", "was_selected"]]
    .value_counts(dropna=False)
)

stats.fisher_exact(orr_confusion_matrices.xs("ScreenDL-SA").unstack().values)

In [None]:
# ScreenDL-SA ORR counts: print Fisher's exact test and show the 2x2 table.
orr_source = orr_confusion_matrices.xs("ScreenDL-SA")
print(stats.fisher_exact(orr_source.unstack().values))
orr_source.unstack()

In [None]:
# Heatmap of the ScreenDL-SA ORR confusion matrix (objective response vs. selected).
# NOTE: stored as `orr_heatmap` rather than `base_orr` so it does not overwrite
# the baseline ORR computed earlier under that name.
orr_heatmap = (
    alt.Chart(
        orr_source.to_frame("count")
        .reset_index()
        .assign(
            y_true_orr=lambda df: df["y_true_orr"].map({True: "Yes", False: "No"}),
            was_selected=lambda df: df["was_selected"].map({True: "Yes", False: "No"}),
        )
    )
    .mark_rect()
    .encode(
        alt.X("was_selected:N")
        .axis(labelAngle=0)
        .scale(reverse=True)
        .title("ScreenDL Optimal Therapy"),
        alt.Y("y_true_orr:N").scale(reverse=True).title("Objective Response"),
        alt.Color("count:Q").scale(scheme="blues").legend(gradientLength=220).title(None),
        tooltip=["y_true_orr:N", "was_selected:N", "count:Q"],
    )
    .properties(width=220, height=220)
)


# overlay the counts; black text on low-count (light) cells, white otherwise
orr_confusion_mtx_plot = orr_heatmap + orr_heatmap.mark_text(baseline="middle").encode(
    alt.Text("count:Q", format=".0f"),
    alt.condition(alt.datum.count < 30, alt.ColorValue("black"), alt.ColorValue("white")),
)

configure_chart(orr_confusion_mtx_plot)

In [None]:
# pCBR and ORR confusion matrices side by side, each with its own color legend.
configure_chart(
    alt.hconcat(
        pcbr_confusion_mtx_plot, orr_confusion_mtx_plot, spacing=50
    ).resolve_legend(color="independent")
)

In [None]:
# Best therapy per tumor chosen directly from the raw screen z-scores
# ("Screen - Zd"), using the same select_best_therapy rule as the models.
raw_screen_selected_drugs = (
    raw_screen_pdx_result_f.dropna(subset="y_pred")
    .astype({"y_true_pcbr": int, "y_true_orr": int})
    .groupby(["model", "cell_id"], as_index=False)
    .apply(select_best_therapy, y_pred_var="y_pred")
)

raw_screen_selected_drugs.query("model == 'Screen - Zd'").head()

In [None]:
# statistical tests for raw screening data
# pCBR outcome vs. selection counts for the raw screen, then Fisher's exact test.
pcbr_confusion_matrices = (
    raw_screen_pdx_result_f.dropna(subset="mRECIST")
    .merge(
        raw_screen_selected_drugs[["model", "cell_id", "drug_id"]].assign(
            was_selected=True
        ),
        on=["model", "cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
    .groupby("model")[["y_true_pcbr", "was_selected"]]
    .value_counts(dropna=False)
)

stats.fisher_exact(pcbr_confusion_matrices.xs("Screen - Zd").unstack().values)

In [None]:
# ORR outcome vs. selection counts for the raw screen, then Fisher's exact test.
orr_confusion_matrices = (
    raw_screen_pdx_result_f.dropna(subset="mRECIST")
    .merge(
        raw_screen_selected_drugs[["model", "cell_id", "drug_id"]].assign(
            was_selected=True
        ),
        on=["model", "cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
    .groupby("model")[["y_true_orr", "was_selected"]]
    .value_counts(dropna=False)
)

stats.fisher_exact(orr_confusion_matrices.xs("Screen - Zd").unstack().values)

## Waterfall Plots

In [None]:
# ScreenDL-SA PDX results flagged with whether each drug was the selected one.
# `x` is a "tumor + drug" label for the waterfall bars, and r_avg is divided by
# 100 to turn a percentage into a fraction (the axis is formatted as %).
screendl_sa_ensemble_result = (
    ensemble_pdx_result_f.query("model == 'ScreenDL-SA'")
    .merge(
        ensemble_selected_drugs.query("model == 'ScreenDL-SA'")[
            ["cell_id", "drug_id"]
        ].assign(was_selected=True),
        on=["cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
    .assign(
        x=lambda df: df["cell_id"] + " + " + df["drug_id"],
        r_avg=lambda df: df["r_avg"] / 100,
    )
)

In [None]:
def _waterfall_bars(data: pd.DataFrame) -> alt.Chart:
    """Waterfall of r_avg, one bar per tumor-drug pair, colored by mRECIST.

    Bars are sorted by descending y value. The chart width scales with the number
    of rows (16.5 px per bar).
    """
    return (
        alt.Chart(data)
        .mark_bar(size=14, stroke="black", strokeWidth=1)
        .encode(
            alt.X("x:N")
            .sort("-y")
            .axis(grid=False, labels=False, ticks=False, offset=-110)
            .scale(paddingOuter=0.2)
            .title(None),
            alt.Y("r_avg:Q")
            .axis(grid=False, tickCount=5, titlePadding=10, format="%")
            .scale(domain=(-1, 1), clamp=True)
            .title(["Change in tumor volume (%)", "(BestAvgResponse)"]),
            alt.Color("mRECIST:N").scale(
                domain=("CR", "PR", "SD", "PD"),
                range=("#9ECAE9", "#89D27A", "#F2CF5B", "#FF9D98"),
            ),
        )
        .properties(width=16.5 * data.shape[0], height=220)
    )


# one waterfall for ScreenDL-SA's selected therapies, one for all other drugs
not_selected_chart = _waterfall_bars(
    screendl_sa_ensemble_result.query("was_selected == False")
)
selected_chart = _waterfall_bars(
    screendl_sa_ensemble_result.query("was_selected == True")
)

waterfall_plot = alt.hconcat(selected_chart, not_selected_chart, spacing=60)
configure_chart(waterfall_plot)

## Final Combined Figure

In [None]:
# Final figure: pCBR/ORR panels on top, the waterfall plot below; color scales are
# resolved independently per sub-chart.
chart = alt.vconcat(
    alt.hconcat(pcbr_chart, orr_chart, spacing=60),
    alt.hconcat(waterfall_plot, spacing=60),
    spacing=50,
)

configure_chart(chart.resolve_scale(color="independent"))

## Delta Tumor Volume Plots

In [None]:
# Longitudinal tumor volumes for the ScreenDL-SA-evaluated tumor/drug/experiment
# keys, flagged by whether ScreenDL-SA selected the drug; restricted to tumors
# with at least 3 tested drugs.
source = (
    raw_pdx_data_drug.set_index(["sample_id", "drug_name", "exp_id"])
    # look up rows by (tumor, drug, experiment) keys from the evaluated results
    .loc[pd.Index(screendl_sa_ensemble_result[["cell_id", "drug_id", "exp_id"]])]
    .reset_index()
    .drop(columns="drug_id")
    .rename(columns={"drug_name": "drug_id", "sample_id": "cell_id"})
    .merge(
        ensemble_selected_drugs.query("model == 'ScreenDL-SA'")[
            ["cell_id", "drug_id"]
        ].assign(was_selected=True),
        on=["cell_id", "drug_id"],
        how="left",
    )
    .fillna({"was_selected": False})
)

keep_tumors = (
    source.groupby("cell_id")["drug_id"].nunique().loc[lambda x: x >= 3].index.to_list()
)

source = source.query("cell_id in @keep_tumors").copy()
# NOTE(review): hard-coded exclusion of mouse "M4 (mm3)" for HCI045; the reason is
# not recorded here -- document it or move it into data preprocessing.
source = source[~((source["mouse_id"] == "M4 (mm3)")  & (source["cell_id"] == "HCI045"))]

source.head()

In [None]:
# NOTE: we filter out observations beyond 40 days after treatment start
base = (
    alt.Chart(source)
    .transform_filter(alt.datum.day <= 40)
    .encode(
        alt.X("day:Q")
        .axis(labelAngle=0, tickCount=4, grid=False, titlePadding=10)
        .scale(domainMax=40, domainMin=0)
        .title("Time after treatment (d)"),
        alt.Color("was_selected:N").scale(
            domain=[True, False], range=["red", "darkgray"]
        ),
        # drug_id holds drug names (renamed from drug_name), so it is nominal;
        # it splits the lines / error bars into one series per drug
        alt.Detail("drug_id:N"),
    )
)

# mean relative tumor volume per day and drug
lines = base.mark_line(strokeWidth=1.5, point=alt.MarkConfig(size=20), clip=False).encode(
    alt.Y("mean(rel_tumor_vol_pct):Q")
    .axis(grid=False, tickCount=4, minExtent=35, titlePadding=5)
    .scale(domainMin=0, nice=True)
    .title("Relative tumor volume (%)"),
)

# standard-error bars around the daily mean
ebars = base.mark_errorbar(
    extent="stderr",
    rule=alt.MarkConfig(strokeWidth=1.5),
    ticks=alt.MarkConfig(width=5, height=5),
).encode(
    alt.Y("rel_tumor_vol_pct:Q")
    .axis(grid=False, tickCount=4, minExtent=35, titlePadding=5)
    .scale(domainMin=0, nice=True)
    .title("Relative tumor volume (%)"),
)

# one panel per tumor, each with independent axes
delta_tv_chart = (
    (lines + ebars)
    .properties(width=200, height=150)
    .facet(alt.Facet("cell_id:N", header=None), columns=6, spacing=30)
    .resolve_scale(y="independent", x="independent")
)

configure_chart(delta_tv_chart)

## Check for Bias in the Screening Data

In [None]:
# Outer-join observed PDX and PDXO responses per tumor-drug pair and flag pairs
# that were tested in PDX (non-missing PDX label).
combined_raw_obs = (
    raw_pdx_obs.drop(columns="id")
    .merge(
        raw_pdxo_obs.drop(columns="id"),
        on=["drug_id", "cell_id"],
        how="outer",
        suffixes=("_pdx", "_pdxo"),
    )
    .assign(screened_in_pdx=lambda df: df["label_pdx"].notna())
)

# filter out invalid PDX samples
# keep PDX-tested pairs only if they are among the evaluated pairs, and untested
# pairs only for tumors that remain after that filter
pairs = ensemble_pdx_result_f[["cell_id", "drug_id"]].drop_duplicates()
pairs = list(zip(pairs["cell_id"], pairs["drug_id"]))

t1 = (
    combined_raw_obs.query("screened_in_pdx == True")
    .set_index(["cell_id", "drug_id"])
    .loc[pairs]
    .reset_index()
)
t2 = combined_raw_obs.query("screened_in_pdx == False").query("cell_id in @t1.cell_id")
combined_raw_obs = pd.concat([t1, t2], ignore_index=True)

combined_raw_obs.groupby("screened_in_pdx")["label_pdxo"].describe()

In [None]:
# Mann-Whitney U test of PDXO responses: drugs not tested in PDX (first group,
# False sorts first) vs. drugs tested in PDX.
stats.mannwhitneyu(
    *combined_raw_obs.dropna(subset="label_pdxo")
    .groupby("screened_in_pdx")["label_pdxo"]
    .agg(list)
)

In [None]:
# Boxplots of observed PDXO responses for drugs tested ("Y") vs. not tested ("N")
# in PDX.
source = combined_raw_obs.copy()
source["screened_in_pdx"] = source["screened_in_pdx"].map({True: "Y", False: "N"})

y_true_boxes = (
    alt.Chart(source)
    .mark_boxplot(**DEFAULT_BOXPLOT_CONFIG)
    .encode(
        alt.X("screened_in_pdx:N").axis(labelAngle=0, titlePadding=10).title(None),
        alt.Y("label_pdxo:Q")
        .axis(grid=False, tickCount=4, titlePadding=10)
        .scale(domain=(-5, 5))
        .title("Observed Response (Zd)"),
        alt.Color("screened_in_pdx:N")
        .scale(domain=["Y", "N"], range=["#55A24A", "lightgray"])
        .legend(orient="top", columns=1)
        .title(None),
    )
    .properties(width=40 * 2, height=250)
)

configure_chart(y_true_boxes)

In [None]:
# Shared mRECIST color encoding (CR/PR/SD/PD), reused by the jitter plots below.
mRECIST_color = alt.Color("mRECIST:N").scale(
    domain=("CR", "PR", "SD", "PD"),
    range=("#9ECAE9", "#89D27A", "#F2CF5B", "#FF9D98"),
)

In [None]:
# Per-tumor jitter plot of observed PDXO responses; PDX-tested drugs are colored
# by mRECIST, untested ones are small gray points.
# tumors ordered by number of PDX-tested drugs (most first)
sorted_tumors = (
    source.query("screened_in_pdx == 'Y'")
    .groupby("cell_id")["drug_id"]
    .nunique()
    .sort_values(ascending=False)
    .index.to_list()
)

# Box-Muller transform: standard-normal horizontal jitter computed in Vega
base = (
    alt.Chart(source)
    .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
    .encode(alt.X("screened_in_pdx:N").axis(labelAngle=0).title(None))
    .properties(width=40, height=250)
)

points = base.mark_circle(size=40).encode(
    alt.Y("label_pdxo:Q")
    .axis(grid=False, offset=10)
    .scale(domain=(-5, 5))
    .title("Observed Response (Zd)"),
    alt.XOffset("jitter:Q"),
    alt.condition(
        alt.datum.screened_in_pdx == "Y",
        mRECIST_color.legend(orient="top"),
        alt.ColorValue("darkgray"),
    ),
    alt.Size("screened_in_pdx:O")
    .scale(domain=("Y", "N"), range=(55, 15))
    .legend(orient="top")
    .title("Tested in PDX?"),
    alt.Opacity("screened_in_pdx:O")
    .scale(domain=("Y", "N"), range=(0.9, 0.5))
    .legend(None)
    .title(None),
    tooltip=[
        alt.Tooltip("drug_id:N", title="Drug"),
        alt.Tooltip("cell_id:N", title="Tumor"),
    ],
)

y_true_chart = points.facet(
    alt.Column("cell_id:N").sort(sorted_tumors).title(None), spacing=10
)
configure_chart(y_true_chart)

In [None]:
# Ensemble PDXO predictions per model/drug/tumor (20%-trimmed mean of y_pred).
agg_funcs = {"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2)}

ensemble_pdxo_result = (
    combined_pdxo_result.groupby(["model", "drug_id", "cell_id"])
    .aggregate(agg_funcs)
    .reset_index()
)

# within-tumor ranks of predicted and observed response (pandas' default
# "average" tie method)
grouped = ensemble_pdxo_result.groupby(["model", "cell_id"])
ensemble_pdxo_result["rank_pred"] = grouped["y_pred"].transform("rank")
ensemble_pdxo_result["rank_true"] = grouped["y_true"].transform("rank")

# attach PDX outcomes where the pair was evaluated in PDX (NaN otherwise)
cols = ["cell_id", "drug_id", "mRECIST", "r_best", "r_avg"]
y_true_df = ensemble_pdx_result_f[cols].drop_duplicates()

ensemble_pdxo_result = ensemble_pdxo_result.merge(
    y_true_df,
    on=["cell_id", "drug_id"],
    how="left",
)

ensemble_pdxo_result.head()

In [None]:
# ScreenDL-SA PDXO predictions for the evaluated tumors, flagged by whether the
# pair was tested in PDX (has an mRECIST call) and whether the drug is in the
# PDMC dataset.
source = (
    ensemble_pdxo_result.query("model == 'ScreenDL-SA'")
    .query("cell_id in @combined_raw_obs.cell_id")
    .assign(
        screened_in_pdx=lambda df: df["mRECIST"].notna(),
        in_pdxo_dataset=lambda df: df["drug_id"].isin(set(pdmc_ds.drug_ids)),
    )
    .copy()
)

# keep PDX-tested pairs only if they are among the evaluated pairs, and untested
# pairs only for tumors that remain after that filter
pairs = ensemble_pdx_result_f[["cell_id", "drug_id"]].drop_duplicates()
pairs = list(zip(pairs["cell_id"], pairs["drug_id"]))

t1 = (
    source.query("screened_in_pdx == True")
    .set_index(["cell_id", "drug_id"])
    .loc[pairs]
    .reset_index()
)
t2 = source.query("screened_in_pdx == False").query("cell_id in @t1.cell_id")
source = pd.concat([t1, t2], ignore_index=True)

In [None]:
# select best therapy for each PDX
selected_all = source.groupby("cell_id").apply(select_best_therapy)
mu_all = selected_all["y_pred"].mean()
print(f"Mean prediction (all drugs): {mu_all:.3f}")

# select best therapy from drugs screened in PDX
selected_pdx_only = (
    source.query("screened_in_pdx == True")
    .groupby("cell_id")
    .apply(select_best_therapy)
    .copy()
)
mu_pdx_only = selected_pdx_only["y_pred"].mean()
print(f"Mean prediction (PDX drugs only): {mu_pdx_only:.3f}")
# NOTE(review): both samples hold one selection per tumor for (presumably) the
# same tumors, i.e. they are paired; a paired test (e.g. stats.wilcoxon on values
# aligned by cell_id) may be more appropriate than Mann-Whitney U -- confirm.
print(stats.mannwhitneyu(selected_all["y_pred"], selected_pdx_only["y_pred"]))
# NOTE: converts the boolean flag to "Y"/"N" in place, so this cell is not
# idempotent -- re-run the previous cell before re-running this one.
source["screened_in_pdx"] = source["screened_in_pdx"].map({True: "Y", False: "N"})

url = "./temp/pdx_y_pred.json"
# to_json does not create missing parent directories
Path(url).parent.mkdir(parents=True, exist_ok=True)
source.to_json(url, orient="records")

In [None]:
# Stack both selections into one long frame, labeled by selection pool.
df1 = selected_pdx_only[["cell_id", "drug_id", "y_pred"]].assign(pool="Tested in PDX")
df2 = selected_all[["cell_id", "drug_id", "y_pred"]].assign(pool="All drugs")
temp = pd.concat([df1, df2], ignore_index=True)

In [None]:
# Boxplots of the predicted response of the selected therapy when choosing from
# PDX-tested drugs only vs. from all drugs.
y_pred_boxes = (
    alt.Chart(temp)
    .mark_boxplot(**DEFAULT_BOXPLOT_CONFIG)
    .encode(
        alt.X("pool:N")
        .sort(["Tested in PDX", "All drugs"])
        .axis(labelAngle=-45, titlePadding=10)
        .title(None),
        alt.Y("y_pred:Q")
        .axis(grid=False, tickCount=4, titlePadding=10)
        .scale(domain=(-3, 2))
        .title("Predicted Response (Zd)"),
        alt.Color("pool:N")
        .scale(domain=["All drugs", "Tested in PDX"], range=["gray", "lightgray"])
        .legend(orient="top", columns=1)
        .title(None),
    )
    .properties(width=40 * 2, height=250)
)

configure_chart(y_pred_boxes)

In [None]:
# number of distinct drugs and tumors in the prediction source
source["drug_id"].nunique(), source["cell_id"].nunique()

In [None]:
# Per-tumor jitter plot of ScreenDL-SA predicted responses, read from the JSON
# file written to `url` earlier; PDX-tested drugs are colored by mRECIST.
# NOTE(review): alt.Chart(url) loads the file at render time, relative to where
# the chart is rendered -- confirm the path resolves in the target viewer.
base = (
    alt.Chart(url)
    .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
    .encode(alt.X("screened_in_pdx:N").axis(labelAngle=0).title(None))
    .properties(width=40, height=250)
)

points = base.mark_circle().encode(
    alt.Y("y_pred:Q")
    .axis(grid=False, offset=10)
    .scale(domain=(-3, 2))
    .title("Predicted Response (Zd)"),
    alt.XOffset("jitter:Q"),
    alt.condition(
        alt.datum.screened_in_pdx == "Y",
        mRECIST_color.legend(orient="top"),
        alt.ColorValue("darkgray"),
    ),
    alt.Size("screened_in_pdx:O")
    .scale(domain=("Y", "N"), range=(50, 5))
    .title("Tested in PDX?")
    .legend(orient="top"),
    alt.Opacity("screened_in_pdx:O")
    .scale(domain=("Y", "N"), range=(0.9, 0.5))
    .legend(None)
    .title(None),
    tooltip=[
        alt.Tooltip("drug_id:N", title="Drug"),
        alt.Tooltip("cell_id:N", title="Tumor"),
    ],
)

y_pred_chart = points.facet(
    alt.Column("cell_id:N").sort(sorted_tumors).title(None), spacing=10
)
configure_chart(y_pred_chart)

In [None]:
# Bias figure: observed responses (top row) over predicted responses (bottom row),
# with color and size scales resolved independently.
top = alt.hconcat(y_true_boxes, y_true_chart, spacing=50).resolve_scale(
    color="independent", size="independent"
)
bot = alt.hconcat(y_pred_boxes, y_pred_chart, spacing=50).resolve_scale(
    color="independent", size="independent"
)

configure_chart(
    alt.vconcat(top, bot, spacing=70).resolve_scale(
        color="independent", size="independent"
    )
)