# Comparison of ScreenDL with Biomarker-Only Models in Cell Lines

* [Dabrafenib + BRAF](#dabrafenib--braf)
* [Capivasertib + PIK3CA/AKT1/PTEN](#capivasertib--pik3caakt1pten)

In [None]:
from __future__ import annotations

import altair as alt
import pandas as pd
import numpy as np
import sklearn.metrics as skm
import typing as t

from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf
from tqdm import tqdm

from cdrpy.datasets import Dataset
from cdrpy.data.preprocess import GroupStandardScaler

from screendl.utils import evaluation as eval_utils

In [None]:
# Root of the local data store; every input/output path below is built from it.
root = Path("../../../datastore")

# Dataset and model identifiers used to locate the results on disk.
dataset = "CellModelPassports-GDSCv1v2"
model = "ScreenDL"

In [None]:
# Directory holding the label file and the annotation tables for the dataset.
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2"

# Annotation tables; the first CSV column becomes the index.
cell_meta = pd.read_csv(dataset_dir / "MetaCellAnnotations.csv", index_col=0)
drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)

# Build the response dataset from the LabelsLogIC50.csv labels plus the annotations.
D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name=dataset_dir.name,
)

print(D)

In [None]:
# Mutation table; later cells filter it on the `model_id`, `gene_symbol`,
# `cancer_driver` and `protein_mutation` columns.
mut_data = pd.read_csv(root / "raw/CellModelPassports/mutations_all_20230202.csv")
mut_data.head()

In [None]:
# Model labels used to select/order columns in the result tables below.
MODELS = ["ScreenDL-PT", "ScreenDL-SA"]
# Pixel dimensions shared by the charts built by `Plotter`.
HEIGHT = 200
WIDTH = 200

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun.

    Parameters
    ----------
    multirun_dir
        Directory that is searched for prediction files.
    regex
        Pattern handed to ``Path.glob`` (a glob pattern, despite the name),
        relative to `multirun_dir`.
    splits
        If given, only rows whose ``split_group`` is in this list are kept.

    Returns
    -------
    The concatenated prediction tables with an added integer ``fold`` column.
    The fold is parsed from the text after the last underscore in the name of
    each file's parent directory.

    Raises
    ------
    ValueError
        If no file matches `regex`, or a parent directory name does not end
        in an integer.
    """
    multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        fold_id = file_path.parent.stem.split("_")[-1]
        fold_pred_df = pd.read_csv(file_path)
        fold_pred_df["fold"] = int(fold_id)
        return fold_pred_df

    # Path.glob yields files in arbitrary order; sort so the row order of the
    # result is reproducible across machines and runs.
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        # Same exception type pd.concat would raise, but with an actionable message.
        raise ValueError(
            f"No prediction files matching {regex!r} found in {multirun_dir}"
        )

    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
def rescale_predictions(df: pd.DataFrame) -> pd.DataFrame:
    """Rescales observed and predicted values using scalers fit on the train rows.

    Two independent ``GroupStandardScaler`` instances are used, one for
    ``y_true`` and one for ``y_pred``. Each is fit on the ``train`` rows
    (grouped by ``drug_id``) and then applied to the ``test`` rows.

    Parameters
    ----------
    df
        Table with ``split_group``, ``drug_id``, ``y_true`` and ``y_pred``
        columns.

    Returns
    -------
    The rescaled train rows followed by the rescaled test rows. Rows whose
    ``split_group`` is neither ``"train"`` nor ``"test"`` are dropped.
    """
    # Explicit copies: the column assignments below then operate on owned
    # frames, which avoids SettingWithCopyWarning and leaves `df` untouched.
    df_trn = df[df["split_group"] == "train"].copy()
    df_tst = df[df["split_group"] == "test"].copy()

    # Scale the observed responses.
    gss = GroupStandardScaler()
    df_trn["y_true"] = gss.fit_transform(df_trn[["y_true"]], groups=df_trn["drug_id"])
    df_tst["y_true"] = gss.transform(df_tst[["y_true"]], groups=df_tst["drug_id"])

    # Scale the predictions with a separate scaler.
    gss = GroupStandardScaler()
    df_trn["y_pred"] = gss.fit_transform(df_trn[["y_pred"]], groups=df_trn["drug_id"])
    df_tst["y_pred"] = gss.transform(df_tst[["y_pred"]], groups=df_tst["drug_id"])

    return pd.concat([df_trn, df_tst])

In [None]:
# Keyword arguments forwarded to `configure_axis` in `configure_chart`.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titlePadding": 5,
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
}

# Keyword arguments forwarded to `configure_boxplot` in `configure_chart`.
BOXPLOT_CONFIG = {
    "size": 25,
    "median": alt.MarkConfig(fill="black"),
    "box": alt.MarkConfig(stroke="black"),
    "ticks": alt.MarkConfig(size=10),
    "outliers": alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
}


def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Applies the shared view, axis, header and boxplot styling to `chart`."""
    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    styled = styled.configure_header(labelFont="arial")
    styled = styled.configure_boxplot(**BOXPLOT_CONFIG)
    return styled

In [None]:
class Plotter:
    """Builds plots for comparisons with biomarker-only models.

    The full figure is a horizontal concatenation of four panels: a boxplot of
    observed response by mutation status, followed by observed-vs-predicted
    scatter plots for the biomarker-only model, ScreenDL-PT and ScreenDL-SA.

    Parameters
    ----------
    D_bm, D_pt, D_sa
        Result tables for the biomarker-only model, ScreenDL-PT and
        ScreenDL-SA. Each must provide ``cell_id``, ``y_true``, ``y_pred`` and
        ``is_MUT`` columns.
    x_domain, y_domain
        ``(min, max)`` axis limits for the x and y scales.
    biomarker_name
        Axis title of the boxplot panel.
    """

    def __init__(
        self,
        D_bm: pd.DataFrame,
        D_pt: pd.DataFrame,
        D_sa: pd.DataFrame,
        x_domain: t.Tuple[float, float],
        y_domain: t.Tuple[float, float],
        biomarker_name: str,
    ) -> None:
        # Descending sort on the status label puts "WT" rows before "MUT" rows.
        self.D_bm = D_bm.sort_values("is_MUT", ascending=False)
        self.D_pt = D_pt.sort_values("is_MUT", ascending=False)
        self.D_sa = D_sa.sort_values("is_MUT", ascending=False)
        self.x_domain = x_domain
        self.y_domain = y_domain
        self.biomarker_name = biomarker_name

    def plot(self) -> alt.Chart:
        """Renders the full plot."""
        bm_boxes = self.make_boxes(self.D_bm)
        bm_scatter = self.make_scatter(self.D_bm, include_mut_lines=False)
        pt_scatter = self.make_scatter(self.D_pt, include_mut_lines=True)
        sa_scatter = self.make_scatter(self.D_sa, include_mut_lines=True)
        return alt.hconcat(bm_boxes, bm_scatter, pt_scatter, sa_scatter)

    def make_boxes(self, D: pd.DataFrame) -> alt.Chart:
        """Makes the boxplots stratified by mutation status."""
        return (
            alt.Chart(D, width=35 * 2, height=HEIGHT)
            .mark_boxplot()
            .encode(
                alt.X("is_MUT:N")
                .axis(labelAngle=0, title=self.biomarker_name)
                .scale(domain=("MUT", "WT")),
                self.true_y_encoding,
                self.color_encoding,
            )
        )

    def _regression_line(
        self,
        base: alt.Chart,
        extent: t.Any,
        mut_status: t.Optional[str] = None,
        **mark_kwargs: t.Any,
    ) -> alt.Chart:
        """Builds one regression-line layer of predicted on observed response.

        When `mut_status` is given, only rows with that ``is_MUT`` value are
        used for the fit. `mark_kwargs` are forwarded to ``mark_line``.
        """
        source = base
        if mut_status is not None:
            source = base.transform_filter(alt.datum.is_MUT == mut_status)

        return (
            source.transform_regression(
                "y_true",
                "y_pred",
                extent=extent,
            )
            .mark_line(**mark_kwargs)
            .encode(self.true_x_encoding, self.pred_y_encoding)
        )

    def make_scatter(self, D: pd.DataFrame, include_mut_lines: bool) -> alt.Chart:
        """Makes the scatter plots.

        Layers, bottom to top: a regression line over all rows; optional dashed
        regression lines for the MUT and WT subsets; the individual points.
        """
        base = alt.Chart(
            D.sort_values("is_MUT", ascending=False),
            width=WIDTH,
            height=HEIGHT,
        )

        # Draw the regression lines across the whole x scale domain.
        extent = self.true_x_encoding.to_dict()["scale"]["domain"]

        chart = self._regression_line(base, extent, stroke="black", strokeWidth=1)

        if include_mut_lines:
            chart += self._regression_line(
                base,
                extent,
                mut_status="MUT",
                stroke="#5CA453",
                strokeWidth=2.5,
                strokeDash=[4, 4],
            )
            chart += self._regression_line(
                base,
                extent,
                mut_status="WT",
                stroke="darkgray",
                strokeWidth=2.5,
                strokeDash=[4, 4],
            )

        chart += base.mark_circle(stroke="black").encode(
            self.true_x_encoding,
            self.pred_y_encoding,
            self.color_encoding,
            self.size_encoding,
            self.opacity_encoding,
            self.stroke_width_encoding,
            tooltip=["cell_id:N", "y_true:Q", "y_pred:Q"],
        )

        return chart

    @property
    def true_x_encoding(self) -> alt.X:
        """X encoding of the observed response, limited to `x_domain`."""
        return (
            alt.X("y_true:Q")
            .scale(domain=self.x_domain)
            .axis(tickCount=6, grid=False)
            .title("Observed Z-Score ln(IC50)")
        )

    @property
    def true_y_encoding(self) -> alt.Y:
        """Y encoding of the observed response, limited to `y_domain`."""
        return (
            alt.Y("y_true:Q")
            .scale(domain=self.y_domain)
            .axis(tickCount=6, grid=False)
            .title("Observed Z-Score ln(IC50)")
        )

    @property
    def pred_y_encoding(self) -> alt.Y:
        """Y encoding of the predicted response, limited to `y_domain`."""
        return (
            alt.Y("y_pred:Q")
            .scale(domain=self.y_domain)
            .axis(tickCount=6, grid=False)
            .title("Predicted Z-Score ln(IC50)")
        )

    @property
    def color_encoding(self) -> alt.Color:
        """Color by mutation status (MUT green, WT dark gray, U light gray), no legend."""
        return (
            alt.Color("is_MUT:N")
            .scale(domain=("MUT", "WT", "U"), range=("#5CA453", "darkgray", "lightgray"))
            .legend(None)
        )

    @property
    def size_encoding(self) -> t.Any:
        """Conditional point size: 40 for WT rows, 80 otherwise."""
        return alt.condition(
            alt.datum.is_MUT == "WT", alt.SizeValue(40), alt.SizeValue(80)
        )

    @property
    def stroke_width_encoding(self) -> t.Any:
        """Conditional outline width: 0.25 for WT rows, 0.5 otherwise."""
        return alt.condition(
            alt.datum.is_MUT == "WT",
            alt.StrokeWidthValue(0.25),
            alt.StrokeWidthValue(0.5),
        )

    @property
    def opacity_encoding(self) -> t.Any:
        """Conditional opacity: 0.7 for WT rows, otherwise scaled from `is_MUT`."""
        return alt.condition(
            alt.datum.is_MUT == "WT",
            alt.OpacityValue(0.7),
            alt.Opacity("is_MUT:N")
            .scale(domain=("MUT", "WT"), range=(0.9, 0.7))
            .legend(None),
        )

In [None]:
# Maps a model label (e.g. "ScreenDL-PT") to its table of predictions.
model_results: t.Dict[str, pd.DataFrame] = {}

In [None]:
# load ScreenDL results
output_dir = root / "outputs"
dataset = "CellModelPassports-GDSCv1v2"
model = "ScreenDL"
date = "2024-04-18_12-54-20"

path_fmt = "screenahead/{0}/{1}/multiruns/{2}"

run_dir = output_dir / path_fmt.format(dataset, model, date)
# Glob patterns (relative to run_dir) for the two kinds of prediction files.
run_regex = "*/predictions.csv"
run_regex_sa = "*/predictions_sa.csv"

# "-PT" results: train and test rows, rescaled separately within each split_id.
model_results[model + "-PT"] = (
    load_multirun_predictions(run_dir, run_regex, splits=["train", "test"])
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model + "-PT")
)

# "-SA" results: the predictions_sa.csv rows are combined with the "-PT" train
# rows so that `rescale_predictions` has train rows to fit its scalers on.
# NOTE(review): `temp_` holds train rows that have already been through
# `rescale_predictions`, so the scalers fitted here see rescaled (not raw)
# train values - confirm this is intended.
temp_ = model_results[model + "-PT"].query("split_group == 'train'").copy()
model_results[model + "-SA"] = (
    load_multirun_predictions(run_dir, run_regex_sa, splits=None)
    .pipe(lambda df: pd.concat([df, temp_]))
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model + "-SA")
)

In [None]:
# Stack the per-model tables under a fresh 0..n-1 index, then split into
# train and test rows.
model_results_df = pd.concat(model_results.values()).reset_index(drop=True)
model_results_df_trn = model_results_df.query("split_group == 'train'")
model_results_df_tst = model_results_df.query("split_group == 'test'")
model_results_df_tst.head()

## Dabrafenib + BRAF

In [None]:
# Drug analysed in this section and the gene(s) that define its biomarker.
DRUG_ID = "Dabrafenib"
BIOMARKER_GENE_IDS = ["BRAF"]

In [None]:
# Apply eval_utils.pcorr to the test rows of every (model, drug) pair
# (presumably a Pearson correlation of y_true vs y_pred - confirm in
# screendl.utils.evaluation); unstack(0) moves the models into columns.
corrs = (
    model_results_df_tst.groupby(["model", "drug_id"]).apply(eval_utils.pcorr).unstack(0)
)
corrs.loc[DRUG_ID][MODELS]

In [None]:
# Test-set rows (all models) for the selected drug; copied so later edits do
# not affect model_results_df_tst.
results_drug: pd.DataFrame = model_results_df_tst.query("drug_id == @DRUG_ID").copy()
results_drug.head()

In [None]:
# Mutations that are: in a cell line with results for this drug, in a biomarker
# gene, flagged as cancer drivers, and whose protein change contains "V600".
# NOTE(review): `str.contains` yields NaN for missing `protein_mutation`
# values, which would break the boolean filter - confirm the column has no
# missing values at this point.
gene_muts = (
    mut_data.query("model_id in @results_drug.cell_id")
    .query("gene_symbol in @BIOMARKER_GENE_IDS")
    .query("cancer_driver == True")
    .query("protein_mutation.str.contains('V600')")
)
gene_muts.head()

In [None]:
# Labels for the mutation-status column. `isin` only produces True/False here,
# so the pd.NA -> "U" entry is not expected to be hit by this cell.
MUT_CATEGORIES = {True: "MUT", False: "WT", pd.NA: "U"}
# Flag each row by whether its cell line appears in gene_muts, relabel the flag
# and keep only the columns needed downstream.
results_drug_muts = (
    results_drug.assign(is_MUT=lambda df: df["cell_id"].isin(gene_muts["model_id"]))
    .replace({"is_MUT": MUT_CATEGORIES})
    .filter(items=["cell_id", "drug_id", "y_true", "y_pred", "model", "is_MUT"])
)
results_drug_muts.head()

In [None]:
# eval_utils.pcorr per model for the selected drug, models as columns.
corrs = results_drug_muts.groupby(["model", "drug_id"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Same metric, additionally stratified by mutation status.
corrs = results_drug_muts.groupby(["model", "drug_id", "is_MUT"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Biomarker-only baseline: take the observed responses once (from the
# ScreenDL-PT rows), drop unknown-status lines, and "predict" each line as the
# mean observed response of its mutation-status group in this same table.
biomarker_only_result = (
    results_drug_muts.query("model == 'ScreenDL-PT'")
    .filter(items=["cell_id", "drug_id", "y_true", "is_MUT"])
    .query("is_MUT != 'U'")
    .assign(y_pred=lambda df: df.groupby("is_MUT")["y_true"].transform("mean"))
)

biomarker_only_result.groupby("is_MUT")["y_true"].describe()

In [None]:
# Same metric as above, evaluated for the biomarker-only baseline.
eval_utils.pcorr(biomarker_only_result)

In [None]:
# Mann-Whitney U test comparing observed response of mutant vs wild-type lines.
mut_resps = biomarker_only_result.query("is_MUT == 'MUT'")["y_true"]
wt_resps = biomarker_only_result.query("is_MUT == 'WT'")["y_true"]
stats.mannwhitneyu(mut_resps, wt_resps)

In [None]:
# Four-panel comparison figure for Dabrafenib; kept in `dabrafenib_chart` so it
# can be reused in the combined figure at the end of the notebook.
plotter = Plotter(
    D_bm=biomarker_only_result,
    D_pt=results_drug_muts.query("model == 'ScreenDL-PT'"),
    D_sa=results_drug_muts.query("model == 'ScreenDL-SA'"),
    x_domain=(-6.5, 4.5),
    y_domain=(-6.5, 4.5),
    biomarker_name="BRAF V600",
)
dabrafenib_chart = plotter.plot()
configure_chart(dabrafenib_chart)

### auROC Comparison

As in previous notebooks, we define responders as the 30% most sensitive cell lines and quantify the auROC for ScreenDL and for the biomarker-only model.

In [None]:
# Keep cell lines with exactly two rows for this drug in results_drug
# (presumably one row per model, i.e. lines scored by both - confirm).
tumor_counts = results_drug["cell_id"].value_counts()
keep_tumors = tumor_counts[tumor_counts == 2].index.to_list()

In [None]:
# Observed labels for the selected drug, taken from the full dataset.
y_true_df = (
    D.obs.rename(columns={"label": "y_true"})
    .drop(columns="id")
    .query("drug_id == @DRUG_ID")
    .copy()
)
# Z-score the labels within the drug, then mark the lines at or below the 30th
# percentile (lowest values) as class 1.
y_true_df["y_true"] = y_true_df.groupby("drug_id")["y_true"].transform(stats.zscore)
y_true_df["y_true_cls"] = y_true_df["y_true"] <= y_true_df["y_true"].quantile(0.30)
y_true_df["y_true_cls"] = y_true_df["y_true_cls"].astype(int)
# Lookup from cell line to class label.
tumor_to_y_true_cls = y_true_df.set_index("cell_id")["y_true_cls"].to_dict()

In [None]:
# auROC of the biomarker-only baseline. Predictions are negated because class 1
# corresponds to low response values, while roc_auc_score expects higher scores
# for the positive class.
temp = biomarker_only_result.query("cell_id in @keep_tumors").copy()
temp["y_true_cls"] = temp["cell_id"].map(tumor_to_y_true_cls)
skm.roc_auc_score(temp["y_true_cls"], -1 * temp["y_pred"])

In [None]:
# auROC per model on the same cell lines, using negated predictions as scores.
auroc_metrics = (
    results_drug_muts.assign(y_true_cls=lambda df: df["cell_id"].map(tumor_to_y_true_cls))
    .query("cell_id in @keep_tumors")
    .groupby("model")
    .apply(lambda g: skm.roc_auc_score(g["y_true_cls"], -1 * g["y_pred"]))
)

auroc_metrics

### Recovery of BRAF V600 Mutant Poor Responders

In [None]:
# Observed labels for the selected drug, taken from the full dataset.
y_true_df = (
    D.obs.rename(columns={"label": "y_true"})
    .drop(columns="id")
    .query("drug_id == @DRUG_ID")
    .copy()
)
# Z-score within the drug; here class 1 marks lines with z-score >= -0.5.
y_true_df["y_true"] = y_true_df.groupby("drug_id")["y_true"].transform(stats.zscore)
y_true_df["y_true_cls"] = y_true_df["y_true"] >= -0.5
y_true_df["y_true_cls"] = y_true_df["y_true_cls"].astype(int)
# Lookup from cell line to class label.
tumor_to_y_true_cls = y_true_df.set_index("cell_id")["y_true_cls"].to_dict()

In [None]:
# Keep cell lines with exactly two rows for this drug in results_drug
# (presumably one row per model - confirm).
tumor_counts = results_drug["cell_id"].value_counts()
keep_tumors = tumor_counts[tumor_counts == 2].index.to_list()

In [None]:
# Rows for biomarker-positive (MUT) cell lines only.
results_drug_muts_only = results_drug_muts.query("is_MUT == 'MUT'")

In [None]:
# So we can see that ScreenAhead reduces false positives:
# some tumors were predicted to be responders but were not, and ScreenAhead helped to correct this.

In [None]:
# Mutant cell lines whose observed response is >= -0.5, read from the
# ScreenDL-SA rows.
t_ids = (
    results_drug_muts.query("is_MUT == 'MUT'")
    .query("y_true >= -0.5")
    .query("model == 'ScreenDL-SA'")["cell_id"]
    .unique()
)

# Copy of the full label table with labels z-scored within each drug.
obs = D.obs.copy()
obs["label"] = obs.groupby("drug_id")["label"].transform(stats.zscore)
print(len(t_ids))

In [None]:
def agg_func(g: pd.DataFrame) -> pd.Series:
    """Summarises how well `y_pred` identifies the rows flagged by `is_res`.

    `g` must provide a boolean ``is_res`` column (the positive class) and a
    ``y_pred`` column used directly as the score, i.e. higher predictions count
    as more likely positive.

    Returns a Series with the auROC, the average precision, the number of
    positives with ``y_pred >= -0.5`` (``tp``) and the total number of
    positives (``total_pos``).
    """
    return pd.Series(
        {
            "auROC": skm.roc_auc_score(g["is_res"], g["y_pred"]),
            "average_precision": skm.average_precision_score(g["is_res"], g["y_pred"]),
            "tp": g[g["y_pred"] >= -0.5]["is_res"].sum(),  # not optimal threshold
            "total_pos": g["is_res"].sum(),
        }
    )


# Mutant cell lines only; the positive class is an observed response >= -0.5.
temp = results_drug_muts.query("is_MUT == 'MUT'").query("cell_id in @keep_tumors")
temp.assign(is_res=lambda df: df["y_true"] >= -0.5).groupby("model").apply(agg_func)

In [None]:
# All test-set rows (every drug) for the cell lines selected in t_ids.
temp = model_results_df_tst.query("cell_id in @t_ids")
# For every (model, cell line): the row with the lowest predicted response,
# i.e. the drug the model ranks as most effective for that line.
best_resp = temp.groupby(["model", "cell_id"], as_index=False).apply(
    lambda g: g.loc[g["y_pred"].idxmin()]
)
# Rows for the biomarker-matched drug itself.
drug_resp = temp.query("drug_id == @DRUG_ID")

cols = ["model", "cell_id", "drug_id", "y_true"]
# Per model, count the lines where the model's top-ranked drug has a lower
# observed response (`y_true_best`) than the biomarker-matched drug
# (`y_true_drug`).
(
    drug_resp[cols]
    .merge(best_resp[cols], on=["model", "cell_id"], suffixes=("_drug", "_best"))
    .assign(is_better=lambda df: df["y_true_best"] < df["y_true_drug"])
    .groupby("model")
    .apply(
        lambda g: pd.Series(
            {
                "total": len(g),
                "better": g["is_better"].sum(),
                "worse": (~g["is_better"]).sum(),
            }
        )
    )
    .unstack(0)
)

## Capivasertib + PIK3CA/AKT1/PTEN

In [None]:
# Drug analysed in this section (the section heading calls it Capivasertib)
# and the genes that define its biomarker.
DRUG_ID = "AZD5363"
BIOMARKER_GENE_IDS = ["PIK3CA", "AKT1", "PTEN"]

In [None]:
# Apply eval_utils.pcorr to the test rows of every (model, drug) pair
# (presumably a Pearson correlation - confirm in screendl.utils.evaluation);
# unstack(0) moves the models into columns.
corrs = (
    model_results_df_tst.groupby(["model", "drug_id"]).apply(eval_utils.pcorr).unstack(0)
)
corrs.loc[DRUG_ID][MODELS]

In [None]:
# Test-set rows (all models) for the selected drug; copied so later edits do
# not affect model_results_df_tst.
results_drug: pd.DataFrame = model_results_df_tst.query("drug_id == @DRUG_ID").copy()
results_drug.head()

In [None]:
# Mutations that are: in a cell line with results for this drug, in any of the
# biomarker genes, and flagged as cancer drivers.
gene_muts = (
    mut_data.query("model_id in @results_drug.cell_id")
    .query("gene_symbol in @BIOMARKER_GENE_IDS")
    .query("cancer_driver == True")
)
gene_muts.head()

In [None]:
# Labels for the mutation-status column. `isin` only produces True/False here,
# so the pd.NA -> "U" entry is not expected to be hit by this cell.
MUT_CATEGORIES = {True: "MUT", False: "WT", pd.NA: "U"}
# Flag each row by whether its cell line appears in gene_muts, relabel the flag
# and keep only the columns needed downstream.
results_drug_muts = (
    results_drug.assign(is_MUT=lambda df: df["cell_id"].isin(gene_muts["model_id"]))
    .replace({"is_MUT": MUT_CATEGORIES})
    .filter(items=["cell_id", "drug_id", "y_true", "y_pred", "model", "is_MUT"])
)
results_drug_muts.head()

In [None]:
# eval_utils.pcorr per model for the selected drug, models as columns.
corrs = results_drug_muts.groupby(["model", "drug_id"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Same metric, additionally stratified by mutation status.
corrs = results_drug_muts.groupby(["model", "drug_id", "is_MUT"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Wild-type cell lines for this drug.
WT_cell_ids = results_drug_muts.query("is_MUT == 'WT'")["cell_id"].unique()
# Observed exceptional responders among them: the label is below the 10th
# percentile of all labels recorded for this drug in the full dataset.
WT_exceptional_true = (
    D.obs.query("drug_id == @DRUG_ID")
    .query("label < label.quantile(0.1)")
    .query("cell_id in @WT_cell_ids")["cell_id"]
    .unique()
    .tolist()
)
len(WT_exceptional_true)

In [None]:
# Predicted exceptional responders per model: the prediction is below that
# model's 10th percentile over all test rows for this drug; the result is then
# restricted to wild-type lines. Yields {model: set of cell ids}.
WT_exceptional_pred = (
    results_drug_muts.assign(
        pred_ex_responder=lambda df: df.groupby("model")["y_pred"].transform(
            lambda x: x < x.quantile(0.1)
        )
    )
    .query("pred_ex_responder == True")
    .query("cell_id in @WT_cell_ids")
    .groupby("model")["cell_id"]
    .agg(set)
    .to_dict()
)

In [None]:
# ScreenDL-PT: true positives, false negatives and false positives among the
# wild-type exceptional responders.
tp = set(WT_exceptional_true).intersection(WT_exceptional_pred["ScreenDL-PT"])
fn = set(WT_exceptional_true).difference(WT_exceptional_pred["ScreenDL-PT"])
fp = set(WT_exceptional_pred["ScreenDL-PT"]).difference(WT_exceptional_true)
len(tp), len(fn), len(fp)

In [None]:
# ScreenDL-SA: true positives, false negatives and false positives among the
# wild-type exceptional responders.
tp = set(WT_exceptional_true).intersection(WT_exceptional_pred["ScreenDL-SA"])
fn = set(WT_exceptional_true).difference(WT_exceptional_pred["ScreenDL-SA"])
fp = set(WT_exceptional_pred["ScreenDL-SA"]).difference(WT_exceptional_true)
len(tp), len(fn), len(fp)

In [None]:
# Biomarker-only baseline: take the observed responses once (from the
# ScreenDL-PT rows), drop unknown-status lines, and "predict" each line as the
# mean observed response of its mutation-status group in this same table.
biomarker_only_result = (
    results_drug_muts.query("model == 'ScreenDL-PT'")
    .filter(items=["cell_id", "drug_id", "y_true", "is_MUT"])
    .query("is_MUT != 'U'")
    .assign(y_pred=lambda df: df.groupby("is_MUT")["y_true"].transform("mean"))
)

biomarker_only_result.groupby("is_MUT")["y_true"].describe()

In [None]:
# Same metric as above, evaluated for the biomarker-only baseline.
eval_utils.pcorr(biomarker_only_result)

In [None]:
# Mann-Whitney U test comparing observed response of mutant vs wild-type lines.
mut_resps = biomarker_only_result.query("is_MUT == 'MUT'")["y_true"]
wt_resps = biomarker_only_result.query("is_MUT == 'WT'")["y_true"]
stats.mannwhitneyu(mut_resps, wt_resps)

In [None]:
# Four-panel comparison figure for this drug; kept in `capivasertib_chart` so
# it can be reused in the combined figure at the end of the notebook.
# NOTE(review): the boxplot axis title is "PIK3CA", but mutation status in this
# section is defined by BIOMARKER_GENE_IDS (PIK3CA, AKT1 and PTEN) - confirm
# the shorter label is intentional.
plotter = Plotter(
    D_bm=biomarker_only_result,
    D_pt=results_drug_muts.query("model == 'ScreenDL-PT'"),
    D_sa=results_drug_muts.query("model == 'ScreenDL-SA'"),
    x_domain=(-4.5, 4.5),
    y_domain=(-4.5, 4.5),
    biomarker_name="PIK3CA",
)
capivasertib_chart = plotter.plot()
configure_chart(capivasertib_chart)

### auROC Comparison

As in previous notebooks, we define responders as the 30% most sensitive cell lines and quantify the auROC for ScreenDL and for the biomarker-only model.

In [None]:
# Keep cell lines with exactly two rows for this drug in results_drug
# (presumably one row per model, i.e. lines scored by both - confirm).
tumor_counts = results_drug["cell_id"].value_counts()
keep_tumors = tumor_counts[tumor_counts == 2].index.to_list()

In [None]:
# Observed labels for the selected drug, taken from the full dataset.
y_true_df = (
    D.obs.rename(columns={"label": "y_true"})
    .drop(columns="id")
    .query("drug_id == @DRUG_ID")
    .copy()
)
# Z-score the labels within the drug, then mark the lines at or below the 30th
# percentile (lowest values) as class 1.
y_true_df["y_true"] = y_true_df.groupby("drug_id")["y_true"].transform(stats.zscore)
y_true_df["y_true_cls"] = y_true_df["y_true"] <= y_true_df["y_true"].quantile(0.30)
y_true_df["y_true_cls"] = y_true_df["y_true_cls"].astype(int)
# Lookup from cell line to class label.
tumor_to_y_true_cls = y_true_df.set_index("cell_id")["y_true_cls"].to_dict()

In [None]:
# auROC of the biomarker-only baseline. Predictions are negated because class 1
# corresponds to low response values, while roc_auc_score expects higher scores
# for the positive class.
temp = biomarker_only_result.query("cell_id in @keep_tumors").copy()
temp["y_true_cls"] = temp["cell_id"].map(tumor_to_y_true_cls)
skm.roc_auc_score(temp["y_true_cls"], -1 * temp["y_pred"])

In [None]:
# auROC per model on the same cell lines, using negated predictions as scores.
auroc_metrics = (
    results_drug_muts.assign(y_true_cls=lambda df: df["cell_id"].map(tumor_to_y_true_cls))
    .query("cell_id in @keep_tumors")
    .groupby("model")
    .apply(lambda g: skm.roc_auc_score(g["y_true_cls"], -1 * g["y_pred"]))
)

auroc_metrics

### Recovery of Non-Mutant Exceptional Responders

In [None]:
# Observed labels for the selected drug, taken from the full dataset.
y_true_df = (
    D.obs.rename(columns={"label": "y_true"})
    .drop(columns="id")
    .query("drug_id == @DRUG_ID")
    .copy()
)
# Z-score within the drug; here class 1 marks lines with z-score <= -1.
y_true_df["y_true"] = y_true_df.groupby("drug_id")["y_true"].transform(stats.zscore)
y_true_df["y_true_cls"] = y_true_df["y_true"] <= -1
y_true_df["y_true_cls"] = y_true_df["y_true_cls"].astype(int)
# Lookup from cell line to class label.
tumor_to_y_true_cls = y_true_df.set_index("cell_id")["y_true_cls"].to_dict()

In [None]:
# Keep cell lines with exactly two rows for this drug in results_drug
# (presumably one row per model - confirm).
tumor_counts = results_drug["cell_id"].value_counts()
keep_tumors = tumor_counts[tumor_counts == 2].index.to_list()

In [None]:
# Rows for wild-type cell lines only.
results_drug_wt_only = results_drug_muts.query("is_MUT == 'WT'")

# Train and test rows for this drug, restricted to the kept wild-type lines
# and grouped by (fold, model) so each fold can be processed separately.
trn_grouped = (
    model_results_df_trn.query("drug_id == @DRUG_ID")
    .query("cell_id in @keep_tumors")
    .query("cell_id in @results_drug_wt_only.cell_id")
    .groupby(["fold", "model"])
)
tst_grouped = (
    model_results_df_tst.query("drug_id == @DRUG_ID")
    .query("cell_id in @keep_tumors")
    .query("cell_id in @results_drug_wt_only.cell_id")
    .groupby(["fold", "model"])
)

In [None]:
def get_mcc_threshold(t_df: pd.DataFrame, thresholds: np.ndarray) -> float:
    """Selects the threshold that maximises the MCC on the training set.

    Scores are the negated ``y_pred`` values, and a row is called positive when
    its score is at or above the candidate threshold. The returned value is
    therefore expressed on the negated-prediction scale. On ties, the first
    maximising threshold wins.
    """
    labels = t_df["is_res"]
    scores = -1 * t_df["y_pred"]

    # NOTE: we are looking for sensitivity here so we want >= threshold
    mcc_per_threshold = [
        skm.matthews_corrcoef(labels, (scores >= cutoff).astype(int))
        for cutoff in thresholds
    ]

    return thresholds[np.argmax(mcc_per_threshold)]


# Candidate thresholds on the negated-prediction scale used by
# `get_mcc_threshold`.
# NOTE(review): on the y_pred scale this grid corresponds to cut-offs between
# -2 and 4 - confirm this is the intended search range.
thresholds = np.linspace(-4, 2, 100)
# One 2x2 confusion matrix per model, accumulated over folds
# (rows: observed class, columns: predicted class).
confusions = {m: np.zeros((2, 2), dtype=int) for m in MODELS}
# The loop variables are deliberately not called `fold`/`model`, so the
# notebook-level `model` constant is not overwritten by this cell.
for (fold_id, model_name), t_df in trn_grouped:
    # Training rows: attach the class label as `is_res` and replace y_true with
    # the z-scored value from y_true_df.
    t_df = (
        t_df.assign(is_res=lambda df: df["cell_id"].map(tumor_to_y_true_cls))
        .drop(columns="y_true")
        .merge(y_true_df, on=["cell_id", "drug_id"])
    )
    # Matching test rows for the same fold and model.
    e_df = (
        tst_grouped.get_group((fold_id, model_name))
        .drop(columns="y_true")
        .merge(y_true_df, on=["cell_id", "drug_id"])
    )

    # The threshold is selected on negated predictions; flip the sign back so
    # it can be compared with y_pred directly.
    mcc_thresh = -1 * get_mcc_threshold(t_df, thresholds)

    y_pred = (e_df["y_pred"] <= mcc_thresh).astype(int)
    y_true = e_df["y_true_cls"]

    C = skm.confusion_matrix(y_true, y_pred, labels=[0, 1])
    confusions[model_name] += C

In [None]:
def _confusion_to_long(matrix: np.ndarray) -> pd.DataFrame:
    """Reshapes a confusion matrix into long form (y_true, y_pred, count)."""
    wide = pd.DataFrame(matrix).rename_axis(index="y_true")
    long = wide.melt(ignore_index=False, var_name="y_pred", value_name="count")
    return long.reset_index()


# Long-form confusion counts for each model, ready for plotting.
source_pt = _confusion_to_long(confusions["ScreenDL-PT"])
source_sa = _confusion_to_long(confusions["ScreenDL-SA"])

In [None]:
# Confusion-matrix heatmap for ScreenDL-PT; both axes are reversed so class 1
# comes first.
base_pt = (
    alt.Chart(source_pt)
    .mark_rect()
    .encode(
        alt.X("y_pred:N").axis(labelAngle=0).scale(reverse=True).title("Predicted"),
        alt.Y("y_true:N").scale(reverse=True).title("Observed"),
        alt.Color("count:Q").scale(scheme="blues").legend(gradientLength=230).title(None),
    )
    .properties(width=230, height=230)
)


# Overlay the counts as text: black for counts below 300, white otherwise.
pt_mtx_plot = base_pt + base_pt.mark_text(baseline="middle").encode(
    alt.Text("count:Q", format=".0f"),
    alt.condition(
        alt.datum.count < 300, alt.ColorValue("black"), alt.ColorValue("white")
    ),
)

configure_chart(pt_mtx_plot)

In [None]:
# Confusion-matrix heatmap for ScreenDL-SA; both axes are reversed so class 1
# comes first.
base_sa = (
    alt.Chart(source_sa)
    .mark_rect()
    .encode(
        alt.X("y_pred:N").axis(labelAngle=0).scale(reverse=True).title("Predicted"),
        alt.Y("y_true:N").scale(reverse=True).title("Observed"),
        alt.Color("count:Q").scale(scheme="blues").legend(gradientLength=230).title(None),
    )
    .properties(width=230, height=230)
)


# Overlay the counts as text: black for counts below 300, white otherwise.
sa_mtx_plot = base_sa + base_sa.mark_text(baseline="middle").encode(
    alt.Text("count:Q", format=".0f"),
    alt.condition(
        alt.datum.count < 300, alt.ColorValue("black"), alt.ColorValue("white")
    ),
)

configure_chart(sa_mtx_plot)

In [None]:
# Combined figure: the Dabrafenib panels stacked above the Capivasertib panels.
final_figure = alt.vconcat(dabrafenib_chart, capivasertib_chart, spacing=20)

configure_chart(final_figure)