# Performance Evaluation in Breast Cancer PDxOs

## Contents

- [Data Loading](#data-loading)
- [Performance in Original PDxO Cohort](#performance-in-original-pdxo-cohort)
- [Performance across all PDxO/drug pairs](#performance-across-all-pdxo-drug-pairs)
    - [High confidence drugs](#high-confidence-drugs)
    - [Performance boxplots](#performance-boxplots)
    - [Performance stratified by ID vs OOD drugs](#performance-stratified-by-id-vs-ood-drugs)
    - [Performance stratified by drug mechanism](#performance-stratified-by-drug-mechanism)
    - [auROC analysis](#auroc-analysis)
    - [Response rate analysis](#response-rate-analysis)
    - [Fig 3. ScreenDL achieves accurate response prediction in high-risk/metastatic breast cancer PDxO models](#fig-3-screendl-achieves-accurate-response-prediction-in-high-riskmetastatic-breast-cancer-pdxo-models)
- [Performance across common PDxO/drug pairs](#performance-across-common-pdxodrug-pairs)
- [ScreenAhead with vs without domain-specific fine-tuning](#screenahead-with-vs-without-domain-specific-fine-tuning)

In [None]:
# DEEPCDR_ROOT="/scratch/ucgd/lustre-labs/marth/scratch/u0871891/projects/screendl/pkg/DeepCDR/prog" python scripts/runners/run.py -m model=DeepCDR-legacy dataset.preprocess.norm=global dataset=CellModelPassports-GDSCv1v2-HCI-Mutations
# HIDRA_ROOT="/scratch/ucgd/lustre-labs/marth/scratch/u0871891/projects/screendl/pkg/HiDRA" python scripts/runners/run.py -m model=HiDRA-legacy dataset.preprocess.norm=global dataset=CellModelPassports-GDSCv1v2-HCI

In [None]:
from __future__ import annotations

import json
import itertools

import altair as alt
import altair_forge as af
import pandas as pd
import numpy as np
import sklearn.metrics as skm

from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf

from cdrpy.datasets import Dataset

from screendl import model as screendl
from screendl.utils import evaluation as eval_utils
from screendl.utils.drug_selectors import get_response_matrix

## Data Loading

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun."""
    if isinstance(multirun_dir, str):
        multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        fold_id = file_path.parent.stem.split("_")[-1]
        fold_pred_df = pd.read_csv(file_path)
        fold_pred_df["fold"] = int(fold_id)
        return fold_pred_df

    file_list = multirun_dir.glob(regex)
    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
# Root of the shared datastore, relative to the notebook's working directory.
root = Path("..", "..", "..", "datastore")

In [None]:
# Raw drug -> type labels; these are mapped to display names in the next cell.
drug_types_path = root / "processed/DrugAnnotations/drug_types.json"
drug_to_type = json.loads(drug_types_path.read_text())

In [None]:
# Normalize the raw type labels to figure display names. An unknown raw label
# raises KeyError rather than being silently dropped.
fixed_drug_types = {"chemotherapy": "Chemo", "targeted": "Targeted", "other": "Other"}
drug_to_type = dict(
    zip(drug_to_type.keys(), map(fixed_drug_types.__getitem__, drug_to_type.values()))
)

In [None]:
# Load the combined cell line + PDxO dataset with ScreenDL input features, then
# split it by the "domain" column of the sample annotations.
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2-HCI"
_feature_dir = dataset_dir / "ScreenDL"

drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
drug_meta["type"] = drug_meta.index.map(drug_to_type)
drug_encoders = screendl.load_drug_features(_feature_dir / "FeatureMorganFingerprints.csv")

cell_meta = pd.read_csv(dataset_dir / "MetaSampleAnnotations.csv", index_col=0)
cell_encoders = screendl.load_cell_features(_feature_dir / "FeatureGeneExpression.csv")

D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_encoders=cell_encoders,
    drug_encoders=drug_encoders,
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name="CellModelPassports-GDSC-HCI",
)

_sample_domain = D.cell_meta["domain"]
cell_ids = _sample_domain.index[_sample_domain == "CELL"]
pdmc_ids = _sample_domain.index[_sample_domain == "PDMC"]

cell_ds = D.select_cells(cell_ids, name="cell_ds")
pdmc_ds = D.select_cells(pdmc_ids, name="pdmc_ds")

for _ds in (cell_ds, pdmc_ds):
    print(_ds)

In [None]:
# Manual target-pathway annotations. They are merged over the `target_pathway`
# column from MetaDrugAnnotations.csv with dict.update, so these entries take
# precedence for any drug present in both.
DRUG_TO_PATHWAY_EXT = {
    "4mu8C": "Other", # targets IRE1 RNase which is involved in the unfolded protein response
    "A-966492": "Genome integrity",
    "Abemaciclib": "Cell cycle",
    "AMG232": "p53 pathway",
    "APG1387": "Apoptosis regulation",
    "ASLAN-002": "RTK signaling",
    "AZD0156": "Genome integrity",
    "AZD4573": "Cell cycle",
    "AZD5363": "PI3K/MTOR signaling",
    "Berzosertib": "Genome integrity",
    "Birinapant": "Apoptosis regulation",
    "Carboplatin": "DNA replication",
    "Ceritinib": "RTK signaling",
    "Cobimetinib": "ERK MAPK signaling",
    "Copanlisib": "PI3K/MTOR signaling",
    "Crenigacestat": "Other",
    "Emavusertib": "Other, kinases",
    "Endoxifen": "Hormone-related",
    "Enzalutamide": "Hormone-related",
    "Epacadostat": "Metabolism",
    "EPZ011989": "Chromatin histone methylation",
    "Erdafitinib": "RTK signaling",
    "Eribulin": "Mitosis",
    "Everolimus": "PI3K/MTOR signaling",
    "GDC-0152": "Apoptosis regulation",
    "GDC-0917": "Apoptosis regulation",
    "Ixazomib": "Protein stability and degradation",
    "Megestrol Acetate": "Hormone-related",
    "Methoxyamine": "DNA replication",
    "MIK665": "Apoptosis regulation",
    "Nedisertib": "Genome integrity",
    "Neratinib": "EGFR/HER2 signaling",  # could also be RTK signaling (HER2/ERBB2 and EGFR)
    "Onalespib": "Protein stability and degradation",
    "ONC206": "Other, kinases",
    "Pamiparib": "Genome integrity",
    "Pelcitoclax": "Apoptosis regulation",
    "Pevonedistat": "Genome integrity",
    # NOTE(review): RO4929097 is a gamma-secretase (Notch) inhibitor, like Crenigacestat
    # which is labeled "Other" above -- confirm the intended category.
    "RO4929097": "WNT signaling",
    "Sapanisertib": "PI3K/MTOR signaling",
    "Selinexor": "Other",
    "Sotorasib": "ERK MAPK signaling",
    "TAK-243": "Protein stability and degradation",
    "Telaglenastat": "Other",
    "Tivantinib": "RTK signaling",
    "Tolinapant": "Apoptosis regulation",
    "Triapine": "DNA replication",
    # Tucatinib is a HER2-selective tyrosine kinase inhibitor (grouped with Neratinib),
    # not an apoptosis regulator.
    "Tucatinib": "EGFR/HER2 signaling",
    "Vemurafenib": "ERK MAPK signaling",
    "ZW4864": "WNT signaling",
}

In [None]:
# Development code -> generic drug name, used for figure labels.
FIXED_DRUG_NAMES = dict(
    [
        ("MK-1775", "Adavosertib"),
        ("AZD5363", "Capivasertib"),
        ("VE-822", "Berzosertib"),
        ("EPZ5676", "Pinometostat"),
    ]
)

In [None]:
# Drug -> target pathway: start from the annotation file, rename "EGFR signaling"
# to the combined "EGFR/HER2 signaling" label, then apply the manual overrides.
_pathway_meta_path = root / "inputs/CellModelPassportsGDSCv1v2Hallmark/MetaDrugAnnotations.csv"
drug_meta_ext = pd.read_csv(
    _pathway_meta_path,
    index_col=0,
    usecols=["drug_id", "targets", "target_pathway"],
)

fixed_pathways = {"EGFR signaling": "EGFR/HER2 signaling"}
drug_to_pathway = {
    drug_id: fixed_pathways.get(pathway, pathway)
    for drug_id, pathway in drug_meta_ext["target_pathway"].to_dict().items()
}
drug_to_pathway.update(DRUG_TO_PATHWAY_EXT)

In [None]:
model_results = {}
output_dir = root / "outputs"


def rescale(df: pd.DataFrame, col: str) -> pd.Series:
    """Z-score ``df[col]`` within each (``fold``, ``drug_id``) group.

    Returns a Series aligned to ``df``'s index, suitable for ``DataFrame.assign``.
    """
    return df.groupby(["fold", "drug_id"])[col].transform(stats.zscore)

In [None]:
# DeepCDR results (run on the "-Mutations" variant of the dataset)

path_fmt = "experiments/pdx_benchmarking/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI-Mutations"
model = "DeepCDR-legacy"
date = "2024-11-21_11-17-57"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

_short_name = model.split("-")[0]  # "DeepCDR-legacy" -> "DeepCDR"
_temp = load_multirun_predictions(run_dir, run_regex, splits=["test"])
# z-score labels and predictions per (fold, drug); ScreenDL outputs are already rescaled
_temp = _temp.assign(
    y_true=rescale(_temp, "y_true"),
    y_pred=rescale(_temp, "y_pred"),
    model=_short_name,
    was_screened=False,
)
model_results[_short_name] = _temp

In [None]:
# HiDRA results

path_fmt = "experiments/pdx_benchmarking/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "HiDRA-legacy"
date = "2024-11-26_21-29-14"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

_short_name = model.split("-")[0]  # "HiDRA-legacy" -> "HiDRA"
_temp = load_multirun_predictions(run_dir, run_regex, splits=["test"])
# z-score labels and predictions per (fold, drug); ScreenDL outputs are already rescaled
_temp = _temp.assign(
    y_true=rescale(_temp, "y_true"),
    y_pred=rescale(_temp, "y_pred"),
    model=_short_name,
    was_screened=False,
)
model_results[_short_name] = _temp

In [None]:
# ScreenDL results: the raw "base"/"xfer"/"screen" model labels from one multirun are
# renamed to ScreenDL-PT / ScreenDL-FT / ScreenDL-SA (ALL).
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2024-11-27_11-58-34"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

fixed_model_names = {
    "base": f"{model}-PT",
    "xfer": f"{model}-FT",
    "screen": f"{model}-SA (ALL)",
}

# NOTE: predictions are already rescaled for ScreenDL
_screendl_preds = load_multirun_predictions(run_dir, run_regex, splits=None)
model_results[model] = _screendl_preds.assign(
    model=_screendl_preds["model"].map(fixed_model_names),
    was_screened=_screendl_preds["was_screened"].fillna(False),
)

In [None]:
# Check the ScreenDL multirun config and show the "xfer" settings.
_multirun_cfg_path = run_dir / "multirun.yaml"
conf = OmegaConf.load(_multirun_cfg_path)
# ScreenAhead must have been run without a drug exclusion list
assert conf.screenahead.opt.exclude_drugs is None
_xfer_yaml = OmegaConf.to_yaml(conf.xfer)
print(_xfer_yaml)

In [None]:
# Stack every model's predictions into a single long-format frame.
model_results_df = pd.concat(list(model_results.values()))
model_results_df.head()

In [None]:
# Per drug: number of PDxOs for which ScreenAhead screened it (True column) or did
# not (False column), sorted by the "not screened" count.
_sa_all_rows = model_results_df[model_results_df["model"] == "ScreenDL-SA (ALL)"]
_unique_pairs = _sa_all_rows.drop_duplicates(["cell_id", "drug_id"])
_screen_counts = _unique_pairs.groupby(["drug_id", "was_screened"])["cell_id"].nunique()
drug_num_times_screened = _screen_counts.unstack().fillna(0).sort_values(by=False)

drug_num_times_screened.describe()

In [None]:
# number of distinct folds contributing predictions for each model
model_results_df.groupby("model")["fold"].agg("nunique")

In [None]:
# Within each fold, keep only the (cell_id, drug_id) pairs that every model predicted.
n_models = model_results_df["model"].nunique()
model_results_common = []
for _fold, _fold_df in model_results_df.groupby("fold"):
    _models_per_pair = _fold_df.groupby(["cell_id", "drug_id"])["model"].nunique()
    _shared_pairs = _models_per_pair.index[_models_per_pair.eq(n_models)]
    model_results_common.append(
        _fold_df.set_index(["cell_id", "drug_id"]).loc[_shared_pairs].reset_index()
    )

model_results_common_df = pd.concat(model_results_common)
model_results_common_df.head()

In [None]:
# ScreenAhead NBS variant: the SA (ALL) predictions restricted to pairs with
# was_screened == False, appended as a separate "model".
_is_sa_all = model_results_df["model"] == "ScreenDL-SA (ALL)"
_not_screened = model_results_df["was_screened"].eq(False)
_temp_nbs = model_results_df.loc[_is_sa_all & _not_screened].assign(
    model="ScreenDL-SA (NBS)"
)

model_results_df = pd.concat([model_results_df, _temp_nbs])

In [None]:
# Display order of the compared models and the color assigned to each.
_MODEL_COLORS = {
    "HiDRA": "darkgray",
    "DeepCDR": "gray",
    "ScreenDL-PT": "#4C78A8",
    "ScreenDL-FT": "#B278A2",
    "ScreenDL-SA (NBS)": "#89D27A",
    "ScreenDL-SA (ALL)": "#5CA453",
}
MODELS = list(_MODEL_COLORS)

MODEL_COLOR_SCALE = alt.Scale(domain=MODELS, range=tuple(_MODEL_COLORS.values()))

## Performance in original PDxO cohort

In [None]:
# HCI PDxO model IDs making up the original PDxO cohort.
ORIGINAL_PDXO_IDS = [
    f"HCI{model_num:03d}"
    for model_num in (1, 2, 3, 5, 8, 10, 11, 12, 15, 16, 17, 19, 23, 24, 25, 27)
]

In [None]:
# Per-drug Pearson correlation of the ensemble predictions (20%-trimmed mean over
# folds) for the models compared on the original PDxO cohort.
# Rows are restricted to ORIGINAL_PDXO_IDS. ScreenDL-PT comes from a separate run
# (experiments/pdxo_validation) whose predictions may include PDxOs outside this
# cohort; without the filter, the models would not be scored on the same tumors.
original_pdxo_ensembl_metrics = (
    model_results_df
    .query("model in ['DeepCDR', 'HiDRA', 'ScreenDL-PT']")
    .query("cell_id in @ORIGINAL_PDXO_IDS")
    .groupby(["model", "drug_id", "cell_id"])
    .aggregate({"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2)})
    .groupby(["model", "drug_id"])
    .apply(eval_utils.pcorr)
    .to_frame("pcc")
    .reset_index()
)

original_pdxo_ensembl_metrics.groupby("model")["pcc"].describe().loc[MODELS[:3]]

In [None]:
# Shared mark styling for every boxplot in the notebook.
BOXPLOT_CONFIG = dict(
    size=30,
    median=alt.MarkConfig(fill="black"),
    box=alt.MarkConfig(stroke="black"),
    ticks=alt.MarkConfig(size=10),
    outliers=alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
)

# Shared axis styling, applied through configure_chart.
AXIS_CONFIG = dict(
    titleFont="arial",
    titleFontStyle="regular",
    labelFont="arial",
    tickColor="black",
    domainColor="black",
)

In [None]:
def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Apply the notebook-wide figure styling and return the configured chart.

    Used for every figure here (box plots, bar/dot plots, layered and concatenated
    charts). Hides the view border, applies ``AXIS_CONFIG`` to all axes and sets
    header labels to Arial.
    """
    borderless = chart.configure_view(strokeOpacity=0)
    styled_axes = borderless.configure_axis(**AXIS_CONFIG)
    return styled_axes.configure_header(labelFont="arial")

In [None]:
# Boxplot of per-drug Pearson correlations, one box per model (original cohort).
_n_original_models = original_pdxo_ensembl_metrics["model"].nunique()

original_pdxo_boxes = (
    alt.Chart(
        original_pdxo_ensembl_metrics,
        width=40 * _n_original_models,
        height=250,
    )
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        x=alt.X("model:N")
        .axis(labelAngle=-45, labelPadding=5)
        .scale(paddingOuter=0.1)
        .sort(MODELS)
        .title(None),
        y=alt.Y("pcc:Q")
        .axis(titlePadding=10, tickCount=5, grid=False)
        .scale(domain=[-1, 1])
        .title("Pearson Correlation"),
        color=alt.Color("model:N", legend=None, scale=MODEL_COLOR_SCALE),
    )
)

configure_chart(original_pdxo_boxes)

## Performance across all PDxO-drug pairs

In [None]:
# cell_id of every HiDRA prediction row (duplicates retained)
old_tumors = model_results_df.loc[model_results_df["model"].eq("HiDRA"), "cell_id"]

In [None]:
# Pearson correlation per (model, fold, drug), then averaged across folds per drug.
pdxo_fold_metrics = model_results_df.groupby(["model", "fold", "drug_id"]).apply(
    lambda fold_drug_df: eval_utils.pcorr(fold_drug_df, min_obs=10)
)

_fold_mean_pcc = pdxo_fold_metrics.groupby(["model", "drug_id"]).mean()
pdxo_agg_metrics = _fold_mean_pcc.rename("pcc").reset_index()
pdxo_agg_metrics["drug_type"] = pdxo_agg_metrics["drug_id"].map(drug_to_type)
pdxo_agg_metrics["pathway"] = pdxo_agg_metrics["drug_id"].map(drug_to_pathway)

pdxo_agg_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Confirm that every SA ensemble member screened the same drugs. When pooled across
# members, each tumor should have exactly `n_drugs` distinct screened drugs, so the
# per-tumor counts must collapse to that single value.
_screened_rows = model_results_df[model_results_df["was_screened"].eq(True)]
num_screened_drugs_per_tumor = (
    _screened_rows.groupby("cell_id")["drug_id"].nunique().value_counts()
)

assert len(num_screened_drugs_per_tumor) == 1
assert num_screened_drugs_per_tumor.index[0] == conf.screenahead.opt.n_drugs

In [None]:
# Treat each model's per-fold runs as an ensemble. The prediction for a PDxO/drug
# pair is the 20%-trimmed mean of the per-fold predictions (y_true is taken from the
# first row). The Pearson correlation is then computed per (model, drug).
_ensemble_preds = model_results_df.groupby(["model", "drug_id", "cell_id"]).aggregate(
    {"y_true": "first", "y_pred": lambda preds: stats.trim_mean(preds, 0.2)}
)
_ensemble_pcc = _ensemble_preds.groupby(["model", "drug_id"]).apply(
    lambda drug_df: eval_utils.pcorr(drug_df, min_obs=10)
)

ensembl_metrics = _ensemble_pcc.rename("pcc").reset_index()
ensembl_metrics["drug_type"] = ensembl_metrics["drug_id"].map(drug_to_type)
ensembl_metrics["pathway"] = ensembl_metrics["drug_id"].map(drug_to_pathway)
# OOD = drug absent from the cell line data
ensembl_metrics["is_ood_drug"] = ~ensembl_metrics["drug_id"].isin(cell_ds.drug_ids)

ensembl_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# ScreenDL-PT correlations on OOD drugs, best first.
_pt_ood_rows = ensembl_metrics[
    ensembl_metrics["model"].eq("ScreenDL-PT") & ensembl_metrics["is_ood_drug"].eq(True)
]
top_pt_ood_drugs = _pt_ood_rows.sort_values(by="pcc", ascending=False)

top_pt_ood_drugs.head(15)

In [None]:
# Keep only in-distribution drugs (those also present in the cell line data).
_is_id_drug = ensembl_metrics["drug_id"].isin(cell_ds.drug_ids)
ensemble_metrics_no_ood_drugs = ensembl_metrics[_is_id_drug]
ensemble_metrics_no_ood_drugs.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Paired per-drug comparison of ensemble Pearson correlations: ScreenDL-FT minus
# ScreenDL-PT, over drugs that have a value for both models.
temp = ensembl_metrics.set_index(["drug_id", "model"])["pcc"].unstack()

m1 = "ScreenDL-PT"
m2 = "ScreenDL-FT"
X = temp[[m1, m2]].dropna()

x1 = X[m1]
x2 = X[m2]

deltas = x2 - x1
# Print the count as an integer (it was previously formatted as a float, e.g. "45.00").
n_improved = int((deltas > 0).sum())
pct_improved = 100 * n_improved / len(deltas) if len(deltas) else float("nan")
print(f"Mean Delta: {deltas.mean()}")
print(f"No. Improved: {n_improved:d}")
print(f"Pct. Improved: {pct_improved:.2f}")

### High confidence drugs

In [None]:
def summarize_high_conf_drugs(df: pd.DataFrame, threshold: float = 0.5) -> pd.Series:
    """Count all drugs and high-confidence drugs in a per-drug metrics frame.

    Parameters
    ----------
    df
        Metrics frame with a ``pcc`` (Pearson correlation) column, expected to
        hold one row per drug.
    threshold
        Minimum ``pcc`` for a drug to count as high-confidence (default 0.5).

    Returns
    -------
    pd.Series
        ``n_drugs``: number of rows, including rows whose ``pcc`` is NaN.
        ``n_high_conf_drugs``: number of rows with ``pcc >= threshold``.
    """
    n_high_conf = int((df["pcc"] >= threshold).sum())
    return pd.Series({"n_drugs": df.shape[0], "n_high_conf_drugs": n_high_conf})

In [None]:
# Number and share of high-confidence drugs for each ScreenAhead variant.
_is_sa_variant = ensembl_metrics["model"].str.contains("ScreenDL-SA")
sa_ensembl_metrics = ensembl_metrics[_is_sa_variant]

_high_conf_summary = (
    sa_ensembl_metrics.groupby("model").apply(summarize_high_conf_drugs).dropna()
)
_high_conf_summary.assign(
    pct_high_conf=round(
        100 * _high_conf_summary["n_high_conf_drugs"] / _high_conf_summary["n_drugs"], 2
    )
)

In [None]:
# Bar chart of ScreenDL-SA (ALL) drugs with Pearson r >= 0.5, sorted by correlation
# and colored by target pathway.
_is_sa_all_model = sa_ensembl_metrics["model"] == "ScreenDL-SA (ALL)"
_is_high_conf = sa_ensembl_metrics["pcc"] >= 0.5
source = sa_ensembl_metrics[_is_sa_all_model & _is_high_conf].assign(
    drug_name=lambda df: df["drug_id"].map(lambda drug: FIXED_DRUG_NAMES.get(drug, drug))
)

# Legend order: pathways in order of first appearance among the bars (highest r first).
color_domain = (
    source.sort_values("pcc", ascending=False)["pathway"].drop_duplicates().tolist()
)

high_conf_drugs_chart_all = (
    alt.Chart(source)
    .mark_bar(stroke="black", size=10.5, strokeWidth=1, opacity=1)
    .encode(
        x=alt.X("drug_name:N", sort="-y")
        .axis(domainColor="black", labelAngle=-65)
        .scale(paddingOuter=0.15)
        .title(None),
        y=alt.Y("pcc:Q")
        .axis(grid=False, tickCount=5, domainColor="black", titlePadding=10)
        .scale(domain=(0.0, 1.0))
        .title("Pearson Correlation"),
        color=alt.Color("pathway:N")
        .scale(domain=color_domain, scheme="tableau20")
        .legend(columns=4, orient="bottom", offset=15)
        .title(None),
    )
    .properties(height=170, width=12.5 * len(source))
)

configure_chart(high_conf_drugs_chart_all)

### Performance boxplots

In [None]:
# Wilcoxon signed-rank test of ScreenDL-PT vs ScreenDL-FT per-drug correlations
# (in-distribution drugs scored by both models).
temp_ = ensemble_metrics_no_ood_drugs.pivot(
    index="drug_id", columns="model", values="pcc"
).loc[:, MODELS]

m1, m2 = "ScreenDL-PT", "ScreenDL-FT"

X = temp_.loc[:, [m1, m2]].dropna()
stats.wilcoxon(X[m1], X[m2])

In [None]:
# Left: per-drug Pearson r for each model (in-distribution drugs only).
# Right: the same values split by drug type, one column per model.
source = ensemble_metrics_no_ood_drugs.copy()

boxes_all = (
    alt.Chart(source, width=37 * len(MODELS), height=220)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        x=alt.X("model:N").axis(labelAngle=-45, labelPadding=5).sort(MODELS).title(None),
        y=alt.Y("pcc:Q")
        .axis(titlePadding=10, tickCount=4, grid=False)
        .scale(domain=[-0.8, 1])
        .title("Pearson Correlation"),
        color=alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
)

_typed_source = source.dropna(subset=["drug_type"])
boxes_types = (
    alt.Chart(_typed_source, width=33 * 2, height=220)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        x=alt.X("drug_type:N")
        .axis(labelAngle=-45, labelPadding=5, orient="bottom")
        .sort(["Targeted", "Chemo"])
        .title(None),
        y=alt.Y("pcc:Q").axis(None).scale(domain=[-0.8, 1]).title(None),
        color=alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
        column=alt.Column("model:N").header(orient="top").spacing(10).sort(MODELS).title(None),
    )
)

ensemble_boxes = alt.hconcat(boxes_all, boxes_types, spacing=-5)
configure_chart(ensemble_boxes)

### Performance stratified by ID vs OOD drugs

In [None]:
# Per-model Pearson r split by whether the drug is also in the cell line data
# ("Yes" = in-distribution, "No" = OOD).
source = ensembl_metrics.copy()
source["screened_in_cells"] = source["is_ood_drug"].map({False: "Yes", True: "No"})

boxes_ood_vs_id_all_drugs = (
    alt.Chart(source, width=37 * 2, height=250)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        x=alt.X("screened_in_cells:N")
        .axis(labelAngle=0, labelPadding=5, orient="bottom")
        .sort(["No", "Yes"])
        .title(None),
        y=alt.Y("pcc:Q")
        .axis(grid=False, tickCount=6)
        .scale(domain=[-0.8, 1])
        .title("Pearson Correlation"),
        color=alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
        column=alt.Column("model:N").header(orient="top").spacing(10).sort(MODELS).title(None),
    )
)

configure_chart(boxes_ood_vs_id_all_drugs)

In [None]:
# Paired FT-vs-PT comparison, repeated separately for in-distribution (False) and
# OOD (True) drugs.
m1 = "ScreenDL-PT"
m2 = "ScreenDL-FT"
for ood_status, group in ensembl_metrics.groupby("is_ood_drug"):
    print(f"OOD: {ood_status}")
    temp = group.set_index(["drug_id", "model"])["pcc"].unstack()
    X = temp[[m1, m2]].dropna()
    x1, x2 = X[m1], X[m2]
    deltas = x2 - x1
    # Print the count as an integer (it was previously formatted as a float).
    n_improved = int((deltas > 0).sum())
    pct_improved = 100 * n_improved / len(deltas) if len(deltas) else float("nan")
    print(f"Mean Delta: {deltas.mean()}")
    print(f"No. Improved: {n_improved:d}")
    print(f"Pct. Improved: {pct_improved:.2f}\n")

In [None]:
# TODO: finish this note (the sentence is incomplete): "We also observed a marked improvement in predictions for OOD drugs, with a ..."

In [None]:
# NOTE: we would expect ScreenAhead not to work for OOD drugs as well unless they have
# high chemical similarity to an existing drug and the chemical similarity correlates
# with similar biological function - we may get some benefit from FT here

In [None]:
# TODO: add comparison showing in distribution vs OOD performance in extended data figure
# TODO: show all high-confidence drugs
# TODO: add pattern showing whether the drug was in vs out of distribution for the high-confidence drugs
# TODO: show auROC and response rate for all drugs

### Performance stratified by drug mechanism

In [None]:
# Keep pathways that have at least 2 drugs for every model (pathways missing for any
# model are dropped), then take the median per-drug Pearson r per (model, pathway).
_drugs_per_pathway = ensembl_metrics.groupby(["model", "pathway"])["drug_id"].nunique()
pathway_counts = _drugs_per_pathway.unstack(level=0).dropna().min(axis=1)
keep_pathways = pathway_counts.index[pathway_counts >= 2]

_kept_rows = ensembl_metrics[ensembl_metrics["pathway"].isin(keep_pathways)]
median_pathway_metrics = (
    _kept_rows.groupby(["model", "pathway"])["pcc"].agg(["median", "size"]).reset_index()
)

# Pathways ordered by ScreenDL-SA (ALL) median, best first.
_sa_all_medians = median_pathway_metrics[
    median_pathway_metrics["model"] == "ScreenDL-SA (ALL)"
]
pathway_order = _sa_all_medians.sort_values("median", ascending=False)["pathway"].tolist()

In [None]:
# Left panel: median per-drug Pearson r per pathway for each model.
# Right panel: number of drugs per pathway (taken from the ScreenDL-PT rows).
_pathway_panel_height = 15 * len(pathway_order)

_median_x = alt.X(
    "median:Q",
    axis=alt.Axis(
        titlePadding=10,
        values=[-0.2, 0.0, 0.2, 0.4, 0.6, 0.8],
        grid=False,
    ),
    scale=alt.Scale(domain=(-0.3, 0.9)),
    title="Median Pearson Correlation Per Drug",
)
_model_legend = alt.Legend(
    orient="top",
    title=None,
    symbolStrokeWidth=1,
    columns=3,
    direction="vertical",
)

circles = (
    alt.Chart(median_pathway_metrics, width=220, height=_pathway_panel_height)
    .mark_circle(size=80, opacity=0.8, stroke="black", strokeWidth=0.5)
    .encode(
        _median_x,
        alt.Y("pathway:N", sort=pathway_order, title=None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE, legend=_model_legend),
        tooltip=["median:Q", "pathway:N"],
    )
)

_pt_pathway_rows = median_pathway_metrics[median_pathway_metrics["model"] == "ScreenDL-PT"]
_count_base = alt.Chart(
    _pt_pathway_rows,
    width=50,
    height=_pathway_panel_height,
).encode(
    alt.X(
        "size:Q",
        axis=alt.Axis(grid=False, values=[0, 15], titlePadding=10),
        scale=alt.Scale(domain=(0, 15)),
        title="No. Drugs",
    ),
    alt.Y(
        "pathway:N",
        axis=alt.Axis(ticks=False, labels=False, offset=0, domainOpacity=0),
        sort=pathway_order,
        title=None,
    ),
    text="size",
)

_count_bars = _count_base.mark_bar(stroke="black", strokeWidth=0.5, size=13, color="#999999")
bars = _count_bars + _count_bars.mark_text(align="left", dx=4, fontSize=10)

pathway_performance_chart = alt.hconcat(circles, bars, spacing=5)
configure_chart(pathway_performance_chart)

In [None]:
# Percentage of pathways where each non-ScreenAhead model has the highest median r.
n_pathways = median_pathway_metrics["pathway"].nunique()
_non_sa_rows = median_pathway_metrics[
    ~median_pathway_metrics["model"].str.contains("ScreenDL-SA")
]
best_model_by_pathway = (
    _non_sa_rows.groupby("pathway")
    .apply(lambda grp: grp.loc[grp["median"].idxmax(), "model"])
    .value_counts()
)
best_model_by_pathway / n_pathways * 100

In [None]:
# fraction of pathways for which fine-tuning improved performance (FT vs PT)
temp = median_pathway_metrics.pivot(index="pathway", columns="model", values="median")
_ft_beats_pt = temp["ScreenDL-FT"] > temp["ScreenDL-PT"]
_ft_beats_pt.sum() / len(temp)

### auROC analysis

In [None]:
def compute_roc_auc(y_true: pd.Series, y_pred: pd.Series, min_obs: int = 10) -> float:
    """Return the auROC of ``y_pred`` for the binary labels ``y_true``.

    Scores are negated so that a lower ``y_pred`` ranks as more likely positive
    (``y_true == 1``). Returns NaN when there are fewer than ``min_obs``
    observations or when ``y_true`` contains at most one distinct class.
    """
    has_single_class = y_true.nunique() <= 1
    has_too_few = len(y_true) < min_obs
    if has_single_class or has_too_few:
        return np.nan
    return skm.roc_auc_score(y_true, -y_pred)

In [None]:
# One ensemble prediction per (model, drug, PDxO): 20%-trimmed mean over folds.
_ensemble_aggs = {"y_true": "first", "y_pred": lambda preds: stats.trim_mean(preds, 0.2)}
ensembl_results_df = model_results_df.groupby(["model", "drug_id", "cell_id"]).agg(
    _ensemble_aggs
)
ensembl_results_df = ensembl_results_df.reset_index()

In [None]:
# Binary responder labels. The observed label is z-scored within each
# (model, drug_id) group, and a PDxO is a responder (1) when its value is below that
# group's 30th percentile.
y_true_df = ensembl_results_df[["model", "cell_id", "drug_id"]].merge(
    pdmc_ds.obs, on=["cell_id", "drug_id"]
)
y_true_df["label"] = y_true_df.groupby(["model", "drug_id"])["label"].transform(
    stats.zscore
)

y_true_df["y_true_class"] = (
    y_true_df.groupby(["model", "drug_id"])["label"]
    .transform(lambda labels: labels < labels.quantile(0.3))
    .astype(int)
)

_class_labels = y_true_df.drop(columns=["label", "id"])
ensembl_results_df = ensembl_results_df.merge(
    _class_labels, on=["model", "cell_id", "drug_id"]
)

In [None]:
# Per-drug auROC of the ensemble predictions for each model.
_auroc_by_drug = ensembl_results_df.groupby(["model", "drug_id"]).apply(
    lambda drug_df: compute_roc_auc(drug_df["y_true_class"], drug_df["y_pred"])
)
ensembl_auroc_metrics = _auroc_by_drug.rename("auROC").reset_index()

ensembl_auroc_metrics.groupby("model")["auROC"].describe().loc[MODELS]

In [None]:
# Same per-drug auROC, restricted to in-distribution drugs.
_id_drug_results = ensembl_results_df[ensembl_results_df["drug_id"].isin(cell_ds.drug_ids)]
_id_auroc_by_drug = _id_drug_results.groupby(["model", "drug_id"]).apply(
    lambda drug_df: compute_roc_auc(drug_df["y_true_class"], drug_df["y_pred"])
)
ensembl_auroc_metrics_no_ood_drugs = _id_auroc_by_drug.rename("auROC").reset_index()

ensembl_auroc_metrics_no_ood_drugs.groupby("model")["auROC"].describe().loc[MODELS]

In [None]:
# Wilcoxon signed-rank test (scipy defaults) on per-drug auROC for every model pair.
# Drugs lacking an auROC for any model are dropped; no multiple-testing correction.
X = ensembl_auroc_metrics.pivot(index="drug_id", columns="model", values="auROC").dropna()

res = pd.DataFrame(
    [
        [first, second, *stats.wilcoxon(X[first], X[second])]
        for first, second in itertools.combinations(MODELS, 2)
    ],
    columns=["m1", "m2", "statistic", "pvalue"],
).set_index(["m1", "m2"])
res.round(3)

In [None]:
# Median per-drug auROC per model (bars) with IQR error bars and a dashed
# chance-level reference line at 0.5.
bars = (
    alt.Chart(ensembl_auroc_metrics)
    .mark_bar(stroke="black", size=17, strokeWidth=1)
    .encode(
        x=alt.X("median(auROC):Q")
        .axis(grid=False, tickCount=5, domainColor="black", titlePadding=10)
        .scale(domain=(0.4, 0.9))
        .title("auROC"),
        y=alt.Y("model:N")
        .axis(domainColor="black")
        .scale(domain=MODELS[::-1], paddingOuter=0.15)
        .title(None),
        color=alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(width=275, height=20 * len(MODELS))
)

rule = (
    alt.Chart(pd.DataFrame({"x": [0.5]}))
    .mark_rule(stroke="black", strokeDash=[4, 3], strokeWidth=1.25)
    .encode(x="x:Q")
)

error_bars = (
    alt.Chart(ensembl_auroc_metrics)
    .mark_errorbar(
        extent="iqr", ticks=alt.MarkConfig(size=5, color="black", strokeWidth=1)
    )
    .encode(y="model:N", x="auROC:Q")
)

auroc_metrics_chart = alt.layer(bars, error_bars, rule)
configure_chart(auroc_metrics_chart)

### Response rate analysis

In [None]:
# Baseline: the expected response rate if one drug were picked at random for each
# PDxO, estimated from 1000 Monte Carlo draws. Each draw uses its own fixed seed so
# the printed baseline is reproducible across re-runs (the draws were unseeded before).
# NOTE(review): y_true_df is built per (model, cell_id, drug_id), so a drug scored by
# more models is proportionally more likely to be drawn -- confirm this is intended.
n_random_draws = 1000
random_response_rates = []
for draw_idx in range(n_random_draws):
    selected_random = y_true_df.groupby("cell_id").sample(1, random_state=draw_idx)
    response_rate = selected_random["y_true_class"].sum() / len(selected_random)
    random_response_rates.append(response_rate)

random_response_rate = np.mean(random_response_rates)
print(f"Random Response Rate {random_response_rate:.2f}")

In [None]:
def select_best_therapy(df: pd.DataFrame, on_: str = "y_pred") -> pd.Series:
    """Return the row of ``df`` with the smallest value in column ``on_``.

    A lower value is treated as the better therapy; this is the same rule as the
    per-(model, cell_id) drug selection in the next cell. Ties resolve to the
    first such row and NaN values in ``on_`` are skipped (``Series.idxmin``
    semantics).

    Args:
        df: Candidate rows, e.g. all drugs for one (model, PDxO) group.
        on_: Name of the column to minimise.

    Returns:
        The selected row as a Series. The row is looked up by index label, so if
        ``df`` has duplicate index labels a DataFrame of all matching rows is
        returned instead.
    """
    return df.loc[df[on_].idxmin()]

In [None]:
# Top-ranked drug per (model, PDxO): the row with the lowest y_pred. Uses the
# select_best_therapy helper defined above instead of re-implementing the same
# rule as an inline lambda.
ensembl_selected_drugs = (
    ensembl_results_df.groupby(["model", "cell_id"], as_index=False)
    .apply(select_best_therapy)
    .reset_index(drop=True)
)

# Response rate per model: fraction of PDxOs whose selected drug has a truthy
# y_true_class (presumably 1/True = responder; missing values count as 0 because
# the denominator is the full group size). Rows follow MODELS order and the
# ScreenDL-SA (NBS) model is excluded from this summary.
ensembl_response_rates = (
    ensembl_selected_drugs
    .groupby("model")["y_true_class"]
    .apply(lambda x: x.sum() / len(x))
    .to_frame(name="response_rate")
    .loc[MODELS]
    .reset_index()
    .query("model != 'ScreenDL-SA (NBS)'")
)

ensembl_response_rates

In [None]:
# For ScreenDL-SA (ALL): how many PDxOs had a top-ranked drug with each
# was_screened value (one count per PDxO after de-duplicating on cell_id).
_sa_picks = ensembl_selected_drugs[ensembl_selected_drugs["model"] == "ScreenDL-SA (ALL)"]
_sa_picks = _sa_picks.merge(model_results_df, on=["model", "cell_id", "drug_id"])
_sa_picks = _sa_picks.drop_duplicates(subset="cell_id")

sa_selected_screened_counts = _sa_picks["was_screened"].value_counts()

sa_selected_screened_counts

In [None]:
# Relative change (%) in response rate going from ScreenDL-PT to ScreenDL-FT.
_rate_by_model = ensembl_response_rates.set_index("model")["response_rate"]
pt_rate = _rate_by_model.loc["ScreenDL-PT"]
ft_rate = _rate_by_model.loc["ScreenDL-FT"]
(ft_rate - pt_rate) / pt_rate * 100

In [None]:
# Models to plot: those present in ensembl_response_rates (ScreenDL-SA (NBS)
# was filtered out above), kept in MODELS order. Derived from the data rather
# than by list position so reordering MODELS cannot drop the wrong model.
_plotted_models = set(ensembl_response_rates["model"])
y_order = [m for m in MODELS if m in _plotted_models]

base = alt.Chart(ensembl_response_rates)

bars = (
    base.mark_bar(stroke="black", size=17, strokeWidth=1)
    .encode(
        alt.X("mean(response_rate):Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 0.8))
        .title("Response Rate (%)"),
        alt.Y("model:N")
        .axis(domainColor="black")
        .scale(paddingOuter=0.15, domain=list(reversed(y_order)))
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    # one 20px row per plotted model
    .properties(width=275, height=20 * len(y_order))
)

# Dashed reference line at a 30% response rate.
# NOTE(review): presumably the random-selection baseline (cf. random_response_rate);
# it is hard-coded here -- confirm it matches the printed baseline.
rule = (
    alt.Chart(pd.DataFrame({"x": [0.3]}))
    .mark_rule(stroke="black", strokeDash=[4, 3], strokeWidth=1.25)
    .encode(x="x:Q")
)

# percentage label to the right of each bar
text = base.mark_text(align="left", dx=6, fontSize=10).encode(
    alt.X("mean(response_rate):Q").title("Response Rate (%)"),
    alt.Y("model:N").title(None),
    alt.Text("mean(response_rate):Q", format=".1%"),
)


response_rate_chart = alt.layer(bars, rule, text)

configure_chart(response_rate_chart)

## PDxO Screening Heatmap

In [None]:
M = get_response_matrix(pdmc_ds, impute=False).T

# Pearson correlation of Olaparib vs Talazoparib responses over the PDxOs that
# have a measurement for both drugs.
X = M.loc[:, ["Olaparib", "Talazoparib"]].dropna()
olaparib_resp, talazoparib_resp = X["Olaparib"], X["Talazoparib"]
stats.pearsonr(olaparib_resp, talazoparib_resp)

In [None]:
# Same pairwise check for Birinapant vs Tolinapant (PDxOs with both measured).
_drug_pair = ["Birinapant", "Tolinapant"]
X = M[_drug_pair].dropna()
stats.pearsonr(*(X[drug] for drug in _drug_pair))

In [None]:
# Receptor status table per PDxO: the combined Welm CSV plus manually added models.
hr_status = pd.read_csv(root / "raw/Welm/receptor_status_combined.csv")

# Each row follows hr_status.columns -- presumably (sample_id, ER, PR, HER2);
# NOTE(review): confirm against the CSV header.
hr_status_extras = [
    ["HCI027BR", "neg", "neg", "neg"],
    ["HCI011E2", "pos", "pos", "neg"],
    ["HCI017E2", "pos", "pos", "neg"],
    ["HCI015BGR", "neg", "neg", "neg"],
    ["HCI015BR", "neg", "neg", "neg"],
    ["HCI027BS", "neg", "neg", "neg"],
    ["HCI048CR", "pos", "neg", "neg"],
    ["HCI054CR", "neg", "neg", "neg"],
    ["HCI053CR", "neg", "neg", "neg"],
    ["TOW85", "neg", "neg", "pos"],
    ["HCI023BR", "neg", "neg", "neg"],
    ["HCI045VR", "neg", "neg", "neg"],
    ["HCI064", "pos", "pos", "pos"],
]

hr_status = pd.concat(
    [hr_status, pd.DataFrame(hr_status_extras, columns=hr_status.columns)]
).drop_duplicates()
# drop_duplicates() only removes identical rows. A manual entry that conflicts
# with the CSV would leave two rows for one sample and break label lookups such
# as .map(hr_status["ER"]) later on, so fail here with a clear error instead.
hr_status = hr_status.set_index("sample_id", verify_integrity=True)

# https://pdxportal.research.bcm.edu/pdxportal/collections/Breast?dswid=-7519
hr_status.loc["BCM3277"] = ["pos", "pos", "neg"]

hr_status.head()

In [None]:
# test for enhanced sensitivity to PI3K/MTOR targeting drugs in ER+ lines
_is_pathway_drug = drug_meta_ext["target_pathway"] == "PI3K/MTOR signaling"
pathway_drugs = drug_meta_ext.index[_is_pathway_drug].to_list()

# Long format: one row per (cell_id, drug_id) with a measured response.
pathway_data = (
    M.filter(items=pathway_drugs, axis=1)
    .melt(ignore_index=False, value_name="Zd")
    .dropna()
    .reset_index()
)
# Re-standardise responses within each drug, then attach each PDxO's ER status
# (NaN for PDxOs absent from hr_status, which groupby then ignores).
pathway_data["Zd"] = pathway_data.groupby("drug_id")["Zd"].transform(stats.zscore)
pathway_data["ER"] = pathway_data["cell_id"].map(hr_status["ER"])

# Mann-Whitney U test of ER- vs ER+ z-scores.
zd_by_er = {status: list(zd) for status, zd in pathway_data.groupby("ER")["Zd"]}
print(stats.mannwhitneyu(zd_by_er["neg"], zd_by_er["pos"]))
pathway_data.groupby("ER")["Zd"].describe()

In [None]:
# test for increased platinum sensitivity in TNBC lines
# TNBC call per PDxO from its receptor columns: True if every receptor is "neg",
# False if any receptor is "pos", and <NA> otherwise (incomplete status with no
# positive receptor). <NA> PDxOs are left out of the comparison rather than
# being counted as non-TNBC.
_all_neg = hr_status.eq("neg").all(axis=1)
_status_known = _all_neg | hr_status.eq("pos").any(axis=1)
is_tnbc = _all_neg.astype("boolean").where(_status_known)

temp_ = (
    M["Carboplatin"]
    .dropna()
    .to_frame(name="Zd")
    .transform(stats.zscore)
    # PDxOs missing from hr_status also get <NA> here and are dropped by groupby
    .join(is_tnbc.to_frame("is_TNBC"))
)
grouped = temp_.groupby("is_TNBC")["Zd"].agg(list).to_dict()
print(stats.mannwhitneyu(grouped[True], grouped[False]))
temp_.groupby("is_TNBC")["Zd"].describe()

In [None]:
M = get_response_matrix(pdmc_ds, impute=False).T

# Keep drugs measured in at least 70% of PDxOs, then only the PDxOs with no
# missing values across those drugs, so the heatmap matrix has no gaps.
n_tumors = M.shape[0]
n_screened = M.notna().sum()
keep_drugs = n_screened.index[n_screened >= n_tumors * 0.7]
keep_tumors = M.index[M[keep_drugs].notna().all(axis=1)]

# Z-score each drug's column over the retained PDxOs, then transpose so drugs
# are rows; relabel drugs via FIXED_DRUG_NAMES where an entry exists.
M = M.loc[keep_tumors, keep_drugs].transform(stats.zscore).T
M.index = M.index.map(lambda x: FIXED_DRUG_NAMES.get(x, x))
print(M.shape)

In [None]:
# Column margin of the heatmap: receptor status per PDxO in long format,
# upper-cased to match the scale domains below.
col_margin_data = (
    hr_status.loc[M.columns]
    .dropna()
    .melt(ignore_index=False, var_name="receptor", value_name="status")
    .reset_index()
    # upper-case every string column; avoids DataFrame.applymap, which is
    # deprecated since pandas 2.1 (renamed DataFrame.map)
    .apply(lambda col: col.str.upper())
)
col_margin_data.columns = ["x", "y", "value"]

col_margin_y_scale = alt.Scale(domain=("ER", "PR", "HER2"))
col_margin_z_scale = alt.Scale(domain=("NEG", "POS"), range=("white", "black"))

In [None]:
# Row margin of the heatmap: one cell per drug row, coloured by target pathway.
_heatmap_drugs = M.index
row_margin_data = pd.DataFrame(
    dict(x=1, y=_heatmap_drugs, value=_heatmap_drugs.map(drug_to_pathway))
)
row_margin_z_scale = alt.Scale(scheme="tableau20")

In [None]:
# Clustered heatmap of the z-scored response matrix M (drugs as rows, PDxOs as
# columns), with target-pathway and receptor-status margins.
hmap = af.cluster_heatmap(
    M,
    height=350,
    width=760,
    row_dendro_size=40,
    row_margin_data=row_margin_data,
    row_margin_z_scale=row_margin_z_scale,
    row_margin_legend_title="Target Pathway",
    row_margin_legend_config=alt.LegendConfig(symbolStrokeColor="black"),
    col_dendro_size=40,
    col_margin_data=col_margin_data,
    col_margin_z_scale=col_margin_z_scale,
    col_margin_y_scale=col_margin_y_scale,
    col_margin_legend_title="Receptor Status",
    col_margin_legend_config=alt.LegendConfig(columns=2, symbolStrokeColor="black"),
    legend_title="Z-Score ln(IC50)",
    legend_config=alt.LegendConfig(
        gradientLength=100, gradientThickness=15, direction="horizontal", tickCount=5
    ),
    legend_spacing=15,
)

(
    hmap.configure_axis(
        labelFont="arial", tickColor="black", tickSize=3
    ).configure_legend(
        titleFont="arial",
        titleFontStyle="italic",
        titleFontWeight="bold",
        labelFont="arial",
        # "normal", not "regular": font-style only accepts normal/italic/oblique,
        # and an invalid value can invalidate the whole canvas font string
        labelFontStyle="normal",
    )
)

### Fig 3. ScreenDL achieves accurate response prediction in high-risk/metastatic breast cancer PDxO models

In [None]:
# Fig 3 layout: clustered heatmap on top; below it, two columns of panels.
# Each column keeps its colour scales independent.
left_column = alt.vconcat(
    ensemble_boxes, high_conf_drugs_chart_all
).resolve_scale(color="independent")

right_column = alt.vconcat(
    pathway_performance_chart,
    auroc_metrics_chart,
    response_rate_chart,
    spacing=30,
).resolve_scale(color="independent")

fig3_bottom = alt.hconcat(left_column, right_column).resolve_legend(color="independent")
configure_chart(alt.vconcat(hmap, fig3_bottom, spacing=20)).configure_legend(
    titleFontStyle="italic", titleFont="arial", labelFont="arial"
)

## Performance across common PDxO/drug pairs

In [None]:
# "ScreenDL-SA (NBS)" pseudo-model: ScreenDL-SA (ALL) predictions restricted to
# rows with was_screened == False.
_temp_nbs = (
    model_results_common_df.query("model == 'ScreenDL-SA (ALL)'")
    .query("was_screened == False")
    .assign(model="ScreenDL-SA (NBS)")
)

# Drop any existing NBS rows before appending so re-running this cell does not
# duplicate them (duplicates would change the downstream trimmed means).
model_results_common_df = pd.concat(
    [model_results_common_df.query("model != 'ScreenDL-SA (NBS)'"), _temp_nbs]
)

In [None]:
# One row per (model, drug, PDxO): the first observed response and a trimmed
# mean (proportiontocut=0.2) of the predictions.
_common_per_pair = model_results_common_df.groupby(
    ["model", "drug_id", "cell_id"]
).aggregate({"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2)})

# Per-(model, drug) Pearson correlation via eval_utils.pcorr (min_obs=10),
# annotated with drug type, target pathway and whether the drug is absent from
# the cell-line dataset (is_ood_drug).
ensembl_metrics_common = (
    _common_per_pair.groupby(["model", "drug_id"])
    .apply(lambda g: eval_utils.pcorr(g, min_obs=10))
    .to_frame("pcc")
    .reset_index()
    .assign(
        drug_type=lambda df: df["drug_id"].map(drug_to_type),
        pathway=lambda df: df["drug_id"].map(drug_to_pathway),
        is_ood_drug=lambda df: ~df["drug_id"].isin(cell_ds.drug_ids),
    )
)

ensembl_metrics_common.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Keep only drugs that also appear in the cell-line dataset.
_in_cell_ds = ensembl_metrics_common["drug_id"].isin(cell_ds.drug_ids)
ensemble_metrics_common_no_ood_drugs = ensembl_metrics_common[_in_cell_ds]
ensemble_metrics_common_no_ood_drugs.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
source = ensemble_metrics_common_no_ood_drugs.copy()

# Left panel: per-drug PCC distribution per model. Right panel: the same
# values faceted by model and split by drug type. Both share the y-range.
_pcc_domain = [-0.8, 1]
_panel_height = 220

boxes_all = (
    alt.Chart(source, width=35 * len(MODELS), height=_panel_height)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("model:N").axis(labelAngle=-45, labelPadding=5).sort(MODELS).title(None),
        alt.Y("pcc:Q")
        .axis(titlePadding=10, tickCount=4, grid=False)
        .scale(domain=_pcc_domain)
        .title("Pearson Correlation"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
)

_typed_source = source.dropna(subset=["drug_type"])  # drugs with a known type only
boxes_types = (
    alt.Chart(_typed_source, width=35 * 2, height=_panel_height)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("drug_type:N")
        .axis(labelAngle=-45, labelPadding=5, orient="bottom")
        .sort(["Targeted", "Chemo"])
        .title(None),
        alt.Y("pcc:Q").axis(None).scale(domain=_pcc_domain).title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
        alt.Column("model:N").header(orient="top").spacing(10).sort(MODELS).title(None),
    )
)

ensemble_boxes_common = alt.hconcat(boxes_all, boxes_types, spacing=-5)
configure_chart(ensemble_boxes_common)

## ScreenAhead with vs without domain-specific fine-tuning

In [None]:
# Location of the ScreenAhead +/- fine-tuning multirun outputs.
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2024-11-27_13-20-41"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"  # glob-style pattern, relative to run_dir

# Raw run labels -> display names. Every label is derived from `model` so the
# names stay consistent if a different model is loaded.
fixed_model_names = {
    "base": f"{model}-PT",
    "xfer": f"{model}-FT",
    "screen (fine-tune)": f"{model}-SA (+FT)",
    "screen (no-fine-tune)": f"{model}-SA (-FT)",
}

# NOTE: predictions are already rescaled for ScreenDL
results_df = load_multirun_predictions(run_dir, run_regex, splits=None).assign(
    model=lambda df: df["model"].map(fixed_model_names),
    # missing was_screened values are treated as "not screened"
    was_screened=lambda df: df["was_screened"].fillna(False),
)

In [None]:
# Collapse multiple predictions per (model, drug, PDxO) -- presumably one per
# run -- into the first observed response and a trimmed mean
# (proportiontocut=0.2) of the predictions. Other columns are dropped.
_per_pair = results_df.groupby(["model", "drug_id", "cell_id"])
results_df = _per_pair.agg(
    {"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2)}
).reset_index()
results_df.head()

In [None]:
# Model labels for the ScreenAhead +/- fine-tuning comparison and their colours.
# NOTE: this rebinds the notebook-wide MODELS / MODEL_COLOR_SCALE that earlier
# cells (e.g. the auROC and response-rate charts) read; re-running those cells
# after this one will use these four labels.
_model_colors = {
    "ScreenDL-PT": "#4C78A8",
    "ScreenDL-FT": "#B278A2",
    "ScreenDL-SA (-FT)": "#89D27A",
    "ScreenDL-SA (+FT)": "#5CA453",
}
MODELS = list(_model_colors)

MODEL_COLOR_SCALE = alt.Scale(
    domain=MODELS,
    range=tuple(_model_colors.values()),
)

In [None]:
# Per-(model, drug) Pearson correlation via eval_utils.pcorr (min_obs=10).
_pcc = (
    results_df.groupby(["model", "drug_id"])
    .apply(lambda g: eval_utils.pcorr(g, min_obs=10))
    .to_frame("pcc")
    .reset_index()
)
# Annotate each drug with its type, target pathway and whether it is absent
# from the cell-line dataset.
_pcc["drug_type"] = _pcc["drug_id"].map(drug_to_type)
_pcc["pathway"] = _pcc["drug_id"].map(drug_to_pathway)
_pcc["is_ood_drug"] = ~_pcc["drug_id"].isin(cell_ds.drug_ids)
ensembl_pcc_result = _pcc

ensembl_pcc_result.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Wilcoxon signed-rank test (scipy defaults) of per-drug PCC: ScreenDL-SA
# without (-FT) vs with (+FT) fine-tuning, over drugs with a PCC for every model.
pcc_wide = ensembl_pcc_result.pivot(
    index="drug_id", columns="model", values="pcc"
).dropna()
stats.wilcoxon(pcc_wide["ScreenDL-SA (-FT)"], pcc_wide["ScreenDL-SA (+FT)"])

In [None]:
source = ensembl_pcc_result.copy()
# "Yes" when the drug is in the cell-line dataset, "No" for OOD drugs.
source["screened_in_cells"] = source["is_ood_drug"].map({True: "No", False: "Yes"})

# Left: PCC distribution per model. Right: faceted by model, split by whether
# the drug was screened in cell lines. Both panels share the y-range.
_pcc_range = [-1, 1]
_panel_height = 250

boxes_1 = (
    alt.Chart(source, width=35 * len(MODELS), height=_panel_height)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("model:N")
        .axis(labelAngle=-45, labelPadding=5)
        .sort(MODELS)
        .title(None),
        alt.Y("pcc:Q")
        .axis(tickCount=5, grid=False, titlePadding=10)
        .scale(domain=_pcc_range)
        .title("Pearson Correlation"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
)

boxes_2 = (
    alt.Chart(source, width=35 * 2, height=_panel_height)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("screened_in_cells:N")
        .axis(labelAngle=0, labelPadding=5, orient="bottom")
        .sort(["No", "Yes"])
        .title(None),
        alt.Y("pcc:Q").axis(None).scale(domain=_pcc_range).title("Pearson Correlation"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
        alt.Column("model:N").header(orient="top").spacing(10).sort(MODELS).title(None),
    )
)

configure_chart(alt.hconcat(boxes_1, boxes_2, spacing=10))