# Comparison of ScreenDL with Biomarker-Only Models in PDxOs

* [Talazoparib + BRCA1/2](#Talazoparib-+-BRCA1/2)
* [Carboplatin + BRCA1/2](#Carboplatin-+-BRCA1/2)
* [Capivasertib + PIK3CA/AKT1/PTEN](#Capivasertib-+-PIK3CA/AKT1/PTEN)
* [Alpelisib + PIK3CA](#Alpelisib-+-PIK3CA)

In [None]:
# FIXME: confirm that I am pulling in all samples here

In [None]:
from __future__ import annotations

import altair as alt
import pandas as pd
import sklearn.metrics as skm
import typing as t

from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf

from screendl.utils import evaluation as eval_utils

In [None]:
# Root of the local datastore; every input and output path below is built from it.
root = Path("../../../datastore")

# Dataset and model identifiers used to build dataset and experiment paths.
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"

In [None]:
# Somatic mutation table. Later cells read its `sample_barcode`, `gene_symbol`,
# `SIFT`, `PolyPhen`, `CLIN_SIG` and `IMPACT` columns.
mut_path = root / "processed/WelmFinal/OmicsSomaticMutationsMAF.AllSamples.csv"
mut_data = pd.read_csv(mut_path)
mut_data.head()

In [None]:
# Sample annotations, indexed by the first CSV column. Rows without a
# `sample_id_wes` value are dropped, so only annotated samples remain.
pdmc_meta_path = root / f"datasets/{dataset}/pdmc/CellLineAnnotations.csv"
pdmc_meta = pd.read_csv(pdmc_meta_path, index_col=0).dropna(subset="sample_id_wes")
pdmc_meta.head()

In [None]:
# Lookup from the annotation index to `sample_id_wes`; later cells map the
# predictions' `cell_id` column through it (presumably the index holds cell ids).
pdmc_to_wes_id = pdmc_meta["sample_id_wes"].to_dict()

In [None]:
# Model display names; used to select and order columns of the correlation tables.
MODELS = ["ScreenDL-PT", "ScreenDL-FT", "ScreenDL-SA"]

In [None]:
def _strip_score_suffix(x):
    """Returns the text before the first "(" for strings; other values pass through.

    Turns annotations such as "deleterious(0.01)" into "deleterious". Non-string
    inputs (e.g. NaN for a missing annotation) are returned unchanged.
    """
    return x.split("(")[0] if isinstance(x, str) else x


def parse_sift(x):
    """Extracts the categorical SIFT call from a "call(score)" annotation."""
    return _strip_score_suffix(x)


def parse_polyphen(x):
    """Extracts the categorical PolyPhen call from a "call(score)" annotation."""
    return _strip_score_suffix(x)

In [None]:
# Pixel dimensions of the chart panels built by `Plotter` (scatter panels use
# both; the boxplot panel uses only HEIGHT).
HEIGHT = 200
WIDTH = 200

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun.

    Every file under `multirun_dir` matching the glob pattern `regex` is read
    as a CSV and tagged with an integer `fold` column parsed from the name of
    its parent directory (the token after the last "_").

    Args:
        multirun_dir: Directory containing one sub-directory per run.
        regex: Glob pattern (not a regular expression) relative to
            `multirun_dir`, e.g. "*/predictions.csv".
        splits: If given, only rows whose `split_group` is in this list are kept.

    Returns:
        The concatenated predictions of all runs, in sorted file-path order.

    Raises:
        ValueError: If no file matches `regex`, or if a parent directory name
            does not end in an integer fold id.
    """
    multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        fold_id = file_path.parent.stem.split("_")[-1]
        fold_pred_df = pd.read_csv(file_path)
        fold_pred_df["fold"] = int(fold_id)
        return fold_pred_df

    # Glob order is unspecified; sort so the row order is reproducible.
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(
            f"No prediction files matching {regex!r} found in {multirun_dir}"
        )

    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
# Keyword arguments forwarded to `Chart.configure_axis` by `configure_chart`.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titlePadding": 5,
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
}

# Keyword arguments forwarded to `Chart.configure_boxplot` by `configure_chart`.
BOXPLOT_CONFIG = {
    "size": 25,
    "median": alt.MarkConfig(fill="black"),
    "box": alt.MarkConfig(stroke="black"),
    "ticks": alt.MarkConfig(size=10),
    "outliers": alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
}

def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Applies the shared view, axis, header and boxplot styling to a chart.

    Each `configure_*` call result is chained into the next; the final
    configured chart is returned.
    """
    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    styled = styled.configure_header(labelFont="arial")
    return styled.configure_boxplot(**BOXPLOT_CONFIG)

In [None]:
class Plotter:
    """Builds plots for comparisons with biomarker-only models.

    The figure is a horizontal concatenation of four panels: a boxplot of
    observed response by mutation status, followed by observed-vs-predicted
    scatter plots for the biomarker-only, fine-tuned and ScreenAhead frames.

    Every input frame must provide `y_true`, `y_pred` and `is_ANY` columns,
    where `is_ANY` is one of "MUT", "WT" or "U". The scatter tooltips
    additionally read `cell_id`, `is_SIFT`, `is_POLYPHEN`, `is_CLINVAR` and
    `is_HIGH`.

    Args:
        D_bm: Predictions of the biomarker-only model.
        D_ft: Predictions shown in the third panel (fine-tuned model).
        D_sa: Predictions shown in the fourth panel (ScreenAhead model).
        x_domain: (min, max) of the scatter x axis (observed response).
        y_domain: (min, max) of every y axis.
        biomarker_name: Axis title of the boxplot panel.
    """

    def __init__(
        self,
        D_bm: pd.DataFrame,
        D_ft: pd.DataFrame,
        D_sa: pd.DataFrame,
        x_domain: t.Tuple[float, float],
        y_domain: t.Tuple[float, float],
        biomarker_name: str,
    ) -> None:
        # Descending string sort puts "WT" rows first and "MUT" rows last
        # (presumably so MUT points are drawn last, i.e. on top).
        self.D_bm = D_bm.sort_values("is_ANY", ascending=False)
        self.D_ft = D_ft.sort_values("is_ANY", ascending=False)
        self.D_sa = D_sa.sort_values("is_ANY", ascending=False)
        self.x_domain = x_domain
        self.y_domain = y_domain
        self.biomarker_name = biomarker_name

    def plot(self) -> alt.Chart:
        """Renders the full plot."""
        bm_boxes = self.make_boxes(self.D_bm)
        bm_scatter = self.make_scatter(self.D_bm, include_mut_lines=False)
        ft_scatter = self.make_scatter(self.D_ft, include_mut_lines=True)
        sa_scatter = self.make_scatter(self.D_sa, include_mut_lines=True)
        return alt.hconcat(bm_boxes, bm_scatter, ft_scatter, sa_scatter)

    def make_boxes(self, D: pd.DataFrame) -> alt.Chart:
        """Makes the boxplots of observed response stratified by mutation status."""
        return (
            alt.Chart(D, width=35 * 2, height=HEIGHT)
            .mark_boxplot()
            .encode(
                alt.X("is_ANY:N")
                .axis(labelAngle=0, title=self.biomarker_name)
                .scale(domain=("MUT", "WT")),
                self.true_y_encoding,
                self.color_encoding,
            )
        )

    def make_scatter(self, D: pd.DataFrame, include_mut_lines: bool) -> alt.Chart:
        """Makes an observed-vs-predicted scatter plot with regression lines.

        Args:
            D: Frame to plot.
            include_mut_lines: If True, adds dashed regression lines fit
                separately to the "MUT" and "WT" rows on top of the overall
                (solid black) regression line.

        Returns:
            A layered chart: regression line(s) first, points last.
        """
        base = alt.Chart(
            D.sort_values("is_ANY", ascending=False),
            width=WIDTH,
            height=HEIGHT,
        )

        # Draw every regression line across the full x axis.
        extent = list(self.x_domain)

        chart = (
            base.transform_regression(
                "y_true",
                "y_pred",
                extent=extent,
            )
            .mark_line(stroke="black", strokeWidth=1)
            .encode(self.true_x_encoding, self.pred_y_encoding)
        )

        if include_mut_lines:
            # One dashed line per mutation group, colored like its points.
            group_strokes = {"MUT": "#5CA453", "WT": "darkgray"}
            for status, stroke in group_strokes.items():
                chart += (
                    base.transform_filter(alt.datum.is_ANY == status)
                    .transform_regression(
                        "y_true",
                        "y_pred",
                        extent=extent,
                    )
                    .mark_line(stroke=stroke, strokeWidth=2.5, strokeDash=[4, 4])
                    .encode(self.true_x_encoding, self.pred_y_encoding)
                )

        chart += base.mark_circle(stroke="black").encode(
            self.true_x_encoding,
            self.pred_y_encoding,
            self.color_encoding,
            self.size_encoding,
            self.opacity_encoding,
            self.stroke_width_encoding,
            tooltip=[
                "cell_id:N",
                "y_true:Q",
                "y_pred:Q",
                "is_SIFT:N",
                "is_POLYPHEN:N",
                "is_CLINVAR:N",
                "is_HIGH:N",
            ],
        )

        return chart

    @property
    def true_x_encoding(self) -> alt.X:
        """X encoding of the observed response, limited to `x_domain`."""
        return (
            alt.X("y_true:Q")
            .scale(domain=self.x_domain)
            .axis(tickCount=3, grid=False)
            .title("Observed Z-Score ln(IC50)")
        )

    @property
    def true_y_encoding(self) -> alt.Y:
        """Y encoding of the observed response, limited to `y_domain`."""
        return (
            alt.Y("y_true:Q")
            .scale(domain=self.y_domain)
            .axis(tickCount=3, grid=False)
            .title("Observed Z-Score ln(IC50)")
        )

    @property
    def pred_y_encoding(self) -> alt.Y:
        """Y encoding of the predicted response, limited to `y_domain`."""
        return (
            alt.Y("y_pred:Q")
            .scale(domain=self.y_domain)
            .axis(tickCount=3, grid=False)
            .title("Predicted Z-Score ln(IC50)")
        )

    @property
    def color_encoding(self) -> alt.Color:
        """Color by mutation status: MUT green, WT dark gray, U light gray."""
        return (
            alt.Color("is_ANY:N")
            .scale(domain=("MUT", "WT", "U"), range=("#5CA453", "darkgray", "lightgray"))
            .legend(None)
        )

    @property
    def size_encoding(self) -> t.Any:
        """Conditional point size: "U" points are smaller (40 vs 80)."""
        return alt.condition(
            alt.datum.is_ANY == "U", alt.SizeValue(40), alt.SizeValue(80)
        )

    @property
    def stroke_width_encoding(self) -> t.Any:
        """Conditional outline width: "U" points get a thinner stroke."""
        return alt.condition(
            alt.datum.is_ANY == "U", alt.StrokeWidthValue(0.25), alt.StrokeWidthValue(0.5)
        )

    @property
    def opacity_encoding(self) -> t.Any:
        """Conditional opacity: 0.7 for "U" points, 0.8 for MUT and WT points."""
        return alt.condition(
            alt.datum.is_ANY == "U",
            alt.OpacityValue(0.7),
            # Keyed on `is_ANY`, the mutation-status column the input frames
            # actually carry, so MUT and WT both resolve to 0.8.
            alt.Opacity("is_ANY:N")
            .scale(domain=("MUT", "WT"), range=(0.8, 0.8))
            .legend(None),
        )

## Talazoparib + BRCA1/2

In [None]:
# Identifies the multirun analyzed in this section.
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2024-11-27_17-19-46"    # Talazoparib excluded from screening

# Multirun directory: <root>/outputs/experiments/pdxo_validation/<dataset>/<model>/multiruns/<date>
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
run_dir = root /"outputs" / path_fmt.format(dataset, model, date)
# Glob pattern relative to run_dir: one predictions.csv per run sub-directory.
run_regex = "*/predictions.csv"

# Maps the model labels stored in the prediction files to the names in MODELS.
fixed_model_names = {
    "base": "ScreenDL-PT",
    "xfer": f"{model}-FT",
    "screen": f"{model}-SA",
}

# Load all runs, rename the models, and treat a missing `was_screened` as False.
results = load_multirun_predictions(run_dir, run_regex, splits=None).assign(
    model=lambda df: df["model"].map(fixed_model_names),
    was_screened=lambda df: df["was_screened"].fillna(False),
)
results.head()

In [None]:
# Drug analyzed in this section and the genes whose mutations define its biomarker.
DRUG_ID = "Talazoparib"
BIOMARKER_GENE_IDS = ["BRCA1", "BRCA2"]

In [None]:
# Sanity check: this run's config must list the drug under
# `screenahead.opt.exclude_drugs`.
conf = OmegaConf.load(run_dir / "multirun.yaml")
assert DRUG_ID in conf.screenahead.opt.exclude_drugs

In [None]:
# Ensemble over runs: for each (model, drug, cell) keep the first `y_true`
# and average `y_pred` across all rows of the group.
ensembl_results = (
    results.groupby(["model", "drug_id", "cell_id"])
    .agg({"y_true": "first", "y_pred": "mean"})
    .reset_index()
)

# Per-(model, drug) score from `eval_utils.pcorr` (presumably a Pearson
# correlation), pivoted so that models become columns.
corrs = ensembl_results.groupby(["model", "drug_id"]).apply(eval_utils.pcorr).unstack(0)
# Show every drug whose id contains "parib".
corrs.loc[corrs.index.str.contains("parib")][MODELS]

In [None]:
# Restrict to the drug of interest and attach each cell's WES sample id
# (NaN where the cell has no entry in `pdmc_to_wes_id`).
results_drug: pd.DataFrame = ensembl_results.query("drug_id == @DRUG_ID").copy()
results_drug["sample_id_wes"] = results_drug["cell_id"].map(pdmc_to_wes_id)
results_drug.head()

In [None]:
# Mutations in the biomarker genes for the samples that have predictions for
# this drug, with SIFT/PolyPhen reduced to their categorical calls.
# FIXME: use all samples here?
gene_muts = mut_data.loc[
    mut_data["sample_barcode"].isin(results_drug["sample_id_wes"])
    & mut_data["gene_symbol"].isin(BIOMARKER_GENE_IDS)
].assign(
    SIFT=lambda df: df["SIFT"].map(parse_sift),
    PolyPhen=lambda df: df["PolyPhen"].map(parse_polyphen),
)
gene_muts.head()

In [None]:
# Display labels for annotation flags: True -> "MUT", False -> "WT", missing -> "U".
MUT_CATEGORIES = {True: "MUT", False: "WT", pd.NA: "U"}

In [None]:
def is_sift_deleterious(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose SIFT call is exactly "deleterious"."""
    return (df["SIFT"] == "deleterious").fillna(False)


def is_polyphen_damaging(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose PolyPhen call contains "damaging"; missing -> False."""
    return df["PolyPhen"].str.contains("damaging", na=False)


def is_clinvar_pathogenic(df: pd.DataFrame) -> pd.Series:
    """Flags variants with a ClinVar significance containing "pathogenic".

    The negative lookahead keeps "conflicting_interpretations_of_pathogenicity"
    from matching, while "pathogenic" and "likely_pathogenic" still do.
    Missing annotations count as not pathogenic.
    """
    return df["CLIN_SIG"].str.contains(r"pathogenic(?!ity)", na=False)


def is_high_impact(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose IMPACT is "HIGH"."""
    return (df["IMPACT"] == "HIGH").fillna(False)


annotation_cols = ["is_SIFT", "is_POLYPHEN", "is_CLINVAR", "is_HIGH"]

# Collapse per-variant flags to one row per sample: a flag is True if any of the
# sample's variants in the biomarker genes carries it; `is_ANY` is True if any
# of the four flags is True.
gene_muts_agg = (
    gene_muts.assign(
        is_SIFT=is_sift_deleterious,
        is_POLYPHEN=is_polyphen_damaging,
        is_CLINVAR=is_clinvar_pathogenic,
        is_HIGH=is_high_impact,
    )
    .filter(items=["sample_barcode", *annotation_cols])
    .groupby("sample_barcode")
    .max()
    .assign(is_ANY=lambda df: df.max(axis=1))
    .reset_index()
)

# Left merge keeps every prediction row. Samples with no row in `gene_muts_agg`
# get missing flags, which `MUT_CATEGORIES` labels "U".
# NOTE(review): that includes samples with no recorded variant in the biomarker
# genes, not only samples lacking WES data; confirm this is intended.
results_drug_muts = (
    results_drug.merge(
        gene_muts_agg,
        left_on="sample_id_wes",
        right_on="sample_barcode",
        how="left",
    )
    .replace({c: MUT_CATEGORIES for c in annotation_cols + ["is_ANY"]})
    .drop(columns="sample_barcode")
)

results_drug_muts.head()

In [None]:
# `eval_utils.pcorr` score per model for this drug, models as columns.
corrs = results_drug_muts.groupby(["model", "drug_id"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Same score, computed separately within each mutation-status group (MUT / WT / U).
corrs = results_drug_muts.groupby(["model", "drug_id", "is_ANY"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Biomarker-only baseline: take one row per cell (the ScreenDL-PT rows), drop
# cells labelled "U", and predict each cell's response as the mean observed
# response of its mutation group. The group means are computed on the same
# cells that are evaluated.
biomarker_only_result = (
    results_drug_muts.query("model == 'ScreenDL-PT'")
    .filter(items=["cell_id", "drug_id", "y_true", "is_ANY", *annotation_cols])
    .query("is_ANY != 'U'")
    .assign(y_pred=lambda df: df.groupby("is_ANY")["y_true"].transform("mean"))
)

# Summary statistics of the observed response per mutation group.
biomarker_only_result.groupby("is_ANY")["y_true"].describe()

In [None]:
# Score of the biomarker-only baseline, for comparison with the ScreenDL models.
eval_utils.pcorr(biomarker_only_result)

In [None]:
# Mann-Whitney U test comparing observed responses of MUT and WT samples.
mut_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "MUT", "y_true"]
wt_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "WT", "y_true"]
stats.mannwhitneyu(mut_resps, wt_resps)

In [None]:
# Four-panel comparison for this drug: biomarker boxplot, then scatter plots for
# the biomarker-only baseline, ScreenDL-FT and ScreenDL-SA.
plotter = Plotter(
    D_bm=biomarker_only_result,
    D_ft=results_drug_muts.query("model == 'ScreenDL-FT'"),
    D_sa=results_drug_muts.query("model == 'ScreenDL-SA'"),
    x_domain=(-3, 3),
    y_domain=(-3, 3),
    biomarker_name="BRCA1/2"
)
# Keep the unconfigured chart for the combined figure; display a styled copy here.
talazoparib_chart = plotter.plot()
configure_chart(talazoparib_chart)

## Carboplatin + BRCA1/2

In [None]:
# Identifies the multirun analyzed in this section.
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2024-11-27_17-20-18"    # Carboplatin excluded from screening

# Multirun directory: <root>/outputs/experiments/pdxo_validation/<dataset>/<model>/multiruns/<date>
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
run_dir = root /"outputs" / path_fmt.format(dataset, model, date)
# Glob pattern relative to run_dir: one predictions.csv per run sub-directory.
run_regex = "*/predictions.csv"

# Maps the model labels stored in the prediction files to the names in MODELS.
fixed_model_names = {
    "base": "ScreenDL-PT",
    "xfer": f"{model}-FT",
    "screen": f"{model}-SA",
}

# Load all runs, rename the models, and treat a missing `was_screened` as False.
results = load_multirun_predictions(run_dir, run_regex, splits=None).assign(
    model=lambda df: df["model"].map(fixed_model_names),
    was_screened=lambda df: df["was_screened"].fillna(False),
)
results.head()

In [None]:
# Drug analyzed in this section and the genes whose mutations define its biomarker.
DRUG_ID = "Carboplatin"
BIOMARKER_GENE_IDS = ["BRCA1", "BRCA2"]

In [None]:
# Sanity check: this run's config must list the drug under
# `screenahead.opt.exclude_drugs`.
conf = OmegaConf.load(run_dir / "multirun.yaml")
assert DRUG_ID in conf.screenahead.opt.exclude_drugs

In [None]:
# Ensemble over runs: for each (model, drug, cell) keep the first `y_true`
# and average `y_pred` across all rows of the group.
ensembl_results = (
    results.groupby(["model", "drug_id", "cell_id"])
    .agg({"y_true": "first", "y_pred": "mean"})
    .reset_index()
)

# Per-(model, drug) score from `eval_utils.pcorr` (presumably a Pearson
# correlation), pivoted so that models become columns.
corrs = ensembl_results.groupby(["model", "drug_id"]).apply(eval_utils.pcorr).unstack(0)
# Show the row for the drug of interest.
corrs.loc[DRUG_ID][MODELS]

In [None]:
# Restrict to the drug of interest and attach each cell's WES sample id
# (NaN where the cell has no entry in `pdmc_to_wes_id`).
results_drug: pd.DataFrame = ensembl_results.query("drug_id == @DRUG_ID").copy()
results_drug["sample_id_wes"] = results_drug["cell_id"].map(pdmc_to_wes_id)
results_drug.head()

In [None]:
# Mutations in the biomarker genes for the samples that have predictions for
# this drug, with SIFT/PolyPhen reduced to their categorical calls.
gene_muts = mut_data.loc[
    mut_data["sample_barcode"].isin(results_drug["sample_id_wes"])
    & mut_data["gene_symbol"].isin(BIOMARKER_GENE_IDS)
].assign(
    SIFT=lambda df: df["SIFT"].map(parse_sift),
    PolyPhen=lambda df: df["PolyPhen"].map(parse_polyphen),
)
gene_muts.head()

In [None]:
# Display labels for annotation flags: True -> "MUT", False -> "WT", missing -> "U".
MUT_CATEGORIES = {True: "MUT", False: "WT", pd.NA: "U"}

In [None]:
def is_sift_deleterious(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose SIFT call is exactly "deleterious"."""
    return (df["SIFT"] == "deleterious").fillna(False)


def is_polyphen_damaging(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose PolyPhen call contains "damaging"; missing -> False."""
    return df["PolyPhen"].str.contains("damaging", na=False)


def is_clinvar_pathogenic(df: pd.DataFrame) -> pd.Series:
    """Flags variants with a ClinVar significance containing "pathogenic".

    The negative lookahead keeps "conflicting_interpretations_of_pathogenicity"
    from matching, while "pathogenic" and "likely_pathogenic" still do.
    Missing annotations count as not pathogenic.
    """
    return df["CLIN_SIG"].str.contains(r"pathogenic(?!ity)", na=False)


def is_high_impact(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose IMPACT is "HIGH"."""
    return (df["IMPACT"] == "HIGH").fillna(False)


annotation_cols = ["is_SIFT", "is_POLYPHEN", "is_CLINVAR", "is_HIGH"]

# Collapse per-variant flags to one row per sample: a flag is True if any of the
# sample's variants in the biomarker genes carries it; `is_ANY` is True if any
# of the four flags is True.
gene_muts_agg = (
    gene_muts.assign(
        is_SIFT=is_sift_deleterious,
        is_POLYPHEN=is_polyphen_damaging,
        is_CLINVAR=is_clinvar_pathogenic,
        is_HIGH=is_high_impact,
    )
    .filter(items=["sample_barcode", *annotation_cols])
    .groupby("sample_barcode")
    .max()
    .assign(is_ANY=lambda df: df.max(axis=1))
    .reset_index()
)

# Left merge keeps every prediction row. Samples with no row in `gene_muts_agg`
# get missing flags, which `MUT_CATEGORIES` labels "U".
# NOTE(review): that includes samples with no recorded variant in the biomarker
# genes, not only samples lacking WES data; confirm this is intended.
results_drug_muts = (
    results_drug.merge(
        gene_muts_agg,
        left_on="sample_id_wes",
        right_on="sample_barcode",
        how="left",
    )
    .replace({c: MUT_CATEGORIES for c in annotation_cols + ["is_ANY"]})
    .drop(columns="sample_barcode")
)

results_drug_muts.head()

In [None]:
# `eval_utils.pcorr` score per model for this drug, models as columns.
corrs = results_drug_muts.groupby(["model", "drug_id"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Same score, computed separately within each mutation-status group (MUT / WT / U).
corrs = results_drug_muts.groupby(["model", "drug_id", "is_ANY"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Biomarker-only baseline: take one row per cell (the ScreenDL-PT rows), drop
# cells labelled "U", and predict each cell's response as the mean observed
# response of its mutation group. The group means are computed on the same
# cells that are evaluated.
biomarker_only_result = (
    results_drug_muts.query("model == 'ScreenDL-PT'")
    .filter(items=["cell_id", "drug_id", "y_true", "is_ANY", *annotation_cols])
    .query("is_ANY != 'U'")
    .assign(y_pred=lambda df: df.groupby("is_ANY")["y_true"].transform("mean"))
)

# Summary statistics of the observed response per mutation group.
biomarker_only_result.groupby("is_ANY")["y_true"].describe()

In [None]:
# Score of the biomarker-only baseline, for comparison with the ScreenDL models.
eval_utils.pcorr(biomarker_only_result)

In [None]:
# Mann-Whitney U test comparing observed responses of MUT and WT samples.
mut_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "MUT", "y_true"]
wt_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "WT", "y_true"]
stats.mannwhitneyu(mut_resps, wt_resps)

In [None]:
# Four-panel comparison for this drug: biomarker boxplot, then scatter plots for
# the biomarker-only baseline, ScreenDL-FT and ScreenDL-SA.
plotter = Plotter(
    D_bm=biomarker_only_result,
    D_ft=results_drug_muts.query("model == 'ScreenDL-FT'"),
    D_sa=results_drug_muts.query("model == 'ScreenDL-SA'"),
    x_domain=(-4, 4),
    y_domain=(-4, 4),
    biomarker_name="BRCA1/2"
)
# Keep the unconfigured chart for the combined figure; display a styled copy here.
carboplatin_chart = plotter.plot()
configure_chart(carboplatin_chart)

## Capivasertib + PIK3CA/AKT1/PTEN

In [None]:
# Identifies the multirun analyzed in this section.
dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
date = "2024-11-29_08-42-26"    # Capivasertib excluded from screening

# Multirun directory: <root>/outputs/experiments/pdxo_validation/<dataset>/<model>/multiruns/<date>
path_fmt = "experiments/pdxo_validation/{0}/{1}/multiruns/{2}"
run_dir = root /"outputs" / path_fmt.format(dataset, model, date)
# Glob pattern relative to run_dir: one predictions.csv per run sub-directory.
run_regex = "*/predictions.csv"

# Maps the model labels stored in the prediction files to the names in MODELS.
fixed_model_names = {
    "base": "ScreenDL-PT",
    "xfer": f"{model}-FT",
    "screen": f"{model}-SA",
}

# Load all runs, rename the models, and treat a missing `was_screened` as False.
results = load_multirun_predictions(run_dir, run_regex, splits=None).assign(
    model=lambda df: df["model"].map(fixed_model_names),
    was_screened=lambda df: df["was_screened"].fillna(False),
)
results.head()

In [None]:
# Drug analyzed in this section and the genes whose mutations define its biomarker.
# NOTE(review): "AZD5363" is presumably the identifier under which Capivasertib
# appears in the prediction files; confirm.
DRUG_ID = "AZD5363"
BIOMARKER_GENE_IDS = ["PIK3CA", "AKT1", "PTEN"]

In [None]:
# Sanity check: this run's config must list the drug under
# `screenahead.opt.exclude_drugs`.
conf = OmegaConf.load(run_dir / "multirun.yaml")
assert DRUG_ID in conf.screenahead.opt.exclude_drugs

In [None]:
# Ensemble over runs: for each (model, drug, cell) keep the first `y_true`
# and average `y_pred` across all rows of the group.
ensembl_results = (
    results.groupby(["model", "drug_id", "cell_id"])
    .agg({"y_true": "first", "y_pred": "mean"})
    .reset_index()
)

# Per-(model, drug) score from `eval_utils.pcorr` (presumably a Pearson
# correlation), pivoted so that models become columns.
corrs = ensembl_results.groupby(["model", "drug_id"]).apply(eval_utils.pcorr).unstack(0)
# Show the row for the drug of interest.
corrs.loc[DRUG_ID][MODELS]

In [None]:
# Restrict to the drug of interest and attach each cell's WES sample id
# (NaN where the cell has no entry in `pdmc_to_wes_id`).
results_drug: pd.DataFrame = ensembl_results.query("drug_id == @DRUG_ID").copy()
results_drug["sample_id_wes"] = results_drug["cell_id"].map(pdmc_to_wes_id)
results_drug.head()

In [None]:
# Mutations in the biomarker genes for the samples that have predictions for
# this drug, with SIFT/PolyPhen reduced to their categorical calls.
gene_muts = mut_data.loc[
    mut_data["sample_barcode"].isin(results_drug["sample_id_wes"])
    & mut_data["gene_symbol"].isin(BIOMARKER_GENE_IDS)
].assign(
    SIFT=lambda df: df["SIFT"].map(parse_sift),
    PolyPhen=lambda df: df["PolyPhen"].map(parse_polyphen),
)
gene_muts.head()

In [None]:
# Display labels for annotation flags: True -> "MUT", False -> "WT", missing -> "U".
MUT_CATEGORIES = {True: "MUT", False: "WT", pd.NA: "U"}

In [None]:
def is_sift_deleterious(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose SIFT call is exactly "deleterious"."""
    return (df["SIFT"] == "deleterious").fillna(False)


def is_polyphen_damaging(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose PolyPhen call contains "damaging"; missing -> False."""
    return df["PolyPhen"].str.contains("damaging", na=False)


def is_clinvar_pathogenic(df: pd.DataFrame) -> pd.Series:
    """Flags variants with a ClinVar significance containing "pathogenic".

    The negative lookahead keeps "conflicting_interpretations_of_pathogenicity"
    from matching, while "pathogenic" and "likely_pathogenic" still do.
    Missing annotations count as not pathogenic.
    """
    return df["CLIN_SIG"].str.contains(r"pathogenic(?!ity)", na=False)


def is_high_impact(df: pd.DataFrame) -> pd.Series:
    """Flags variants whose IMPACT is "HIGH"."""
    return (df["IMPACT"] == "HIGH").fillna(False)


annotation_cols = ["is_SIFT", "is_POLYPHEN", "is_CLINVAR", "is_HIGH"]

# Collapse per-variant flags to one row per sample: a flag is True if any of the
# sample's variants in the biomarker genes carries it; `is_ANY` is True if any
# of the four flags is True.
gene_muts_agg = (
    gene_muts.assign(
        is_SIFT=is_sift_deleterious,
        is_POLYPHEN=is_polyphen_damaging,
        is_CLINVAR=is_clinvar_pathogenic,
        is_HIGH=is_high_impact,
    )
    .filter(items=["sample_barcode", *annotation_cols])
    .groupby("sample_barcode")
    .max()
    .assign(is_ANY=lambda df: df.max(axis=1))
    .reset_index()
)

# Left merge keeps every prediction row. Samples with no row in `gene_muts_agg`
# get missing flags, which `MUT_CATEGORIES` labels "U".
# NOTE(review): that includes samples with no recorded variant in the biomarker
# genes, not only samples lacking WES data; confirm this is intended.
results_drug_muts = (
    results_drug.merge(
        gene_muts_agg,
        left_on="sample_id_wes",
        right_on="sample_barcode",
        how="left",
    )
    .replace({c: MUT_CATEGORIES for c in annotation_cols + ["is_ANY"]})
    .drop(columns="sample_barcode")
)

results_drug_muts.head()

In [None]:
# `eval_utils.pcorr` score per model for this drug, models as columns.
corrs = results_drug_muts.groupby(["model", "drug_id"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Same score, computed separately within each mutation-status group (MUT / WT / U).
corrs = results_drug_muts.groupby(["model", "drug_id", "is_ANY"]).apply(eval_utils.pcorr)
corrs.unstack(0)[MODELS]

In [None]:
# Biomarker-only baseline: take one row per cell (the ScreenDL-PT rows), drop
# cells labelled "U", and predict each cell's response as the mean observed
# response of its mutation group. The group means are computed on the same
# cells that are evaluated.
biomarker_only_result = (
    results_drug_muts.query("model == 'ScreenDL-PT'")
    .filter(items=["cell_id", "drug_id", "y_true", "is_ANY", *annotation_cols])
    .query("is_ANY != 'U'")
    .assign(y_pred=lambda df: df.groupby("is_ANY")["y_true"].transform("mean"))
)

# Summary statistics of the observed response per mutation group.
biomarker_only_result.groupby("is_ANY")["y_true"].describe()

In [None]:
# Score of the biomarker-only baseline, for comparison with the ScreenDL models.
eval_utils.pcorr(biomarker_only_result)

In [None]:
# Mann-Whitney U test comparing observed responses of MUT and WT samples.
mut_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "MUT", "y_true"]
wt_resps = biomarker_only_result.loc[biomarker_only_result["is_ANY"] == "WT", "y_true"]
stats.mannwhitneyu(mut_resps, wt_resps)

In [None]:
# Four-panel comparison for this drug: biomarker boxplot, then scatter plots for
# the biomarker-only baseline, ScreenDL-FT and ScreenDL-SA.
plotter = Plotter(
    D_bm=biomarker_only_result,
    D_ft=results_drug_muts.query("model == 'ScreenDL-FT'"),
    D_sa=results_drug_muts.query("model == 'ScreenDL-SA'"),
    x_domain=(-3, 3),
    y_domain=(-3, 3),
    # Mutation status here covers all of BIOMARKER_GENE_IDS, not PIK3CA alone,
    # so the axis label names every gene.
    biomarker_name="PIK3CA/AKT1/PTEN",
)
# Keep the unconfigured chart for the combined figure; display a styled copy here.
capivasertib_chart = plotter.plot()
configure_chart(capivasertib_chart)

In [None]:
# Stack the three per-drug charts vertically into the final figure. The shared
# styling is applied once to the combined chart.
final_figure = alt.vconcat(
    capivasertib_chart, talazoparib_chart, carboplatin_chart, spacing=20
)

configure_chart(final_figure)