# Performance Evaluation in PDX Models

## Contents

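- [Data Loading](#data-loading)
- [Waterfall plots for all drugs vs drug selected by ScreenDL-SA](#waterfall-plots-for-all-drugs-vs-drug-selected-by-screendl-sa)
- [Delta tumor volume visuals for selected drugs](#delta-tumor-volume-visuals-for-selected-drugs)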

In [None]:
# Commands that produced the DeepCDR and HiDRA benchmark runs loaded further below
# (their model/dataset overrides match the run directories used in those cells).
# DEEPCDR_ROOT="/scratch/ucgd/lustre-work/marth/u0871891/projects/screendl/pkg/DeepCDR/prog" python scripts/experiments/pdx_benchmarking.py -m model=DeepCDR-legacy dataset.preprocess.norm=global dataset=CellModelPassports-GDSCv1v2-HCI-Mutations
# HIDRA_ROOT="/scratch/ucgd/lustre-work/marth/u0871891/projects/screendl/pkg/HiDRA" python scripts/experiments/pdx_benchmarking.py -m model=HiDRA-legacy dataset.preprocess.norm=global dataset=CellModelPassports-GDSCv1v2-HCI

In [None]:
from __future__ import annotations

import altair as alt
import pandas as pd
import numpy as np
import typing as t
import sklearn.metrics as skm

from pathlib import Path
from scipy import stats
from omegaconf import OmegaConf

from cdrpy.feat.transformers import GroupStandardScaler

from screendl.utils import evaluation as eval_utils

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Load and stack the per-run prediction CSVs of a multirun sweep.

    Args:
        multirun_dir: Root directory of the multirun.
        regex: Glob pattern (not a regular expression, despite the name) relative
            to ``multirun_dir`` that matches the per-run CSVs, e.g.
            ``"*/predictions.csv"``.
        splits: If given, keep only rows whose ``split_group`` is in this list.

    Returns:
        All matched predictions concatenated, with an added integer ``iter``
        column parsed from the text after the last ``_`` in each file's parent
        directory name. Files are read in sorted path order so the result does
        not depend on filesystem enumeration order.

    Raises:
        ValueError: If no file matches ``regex``.
    """
    multirun_dir = Path(multirun_dir)

    def load_run(file_path: Path) -> pd.DataFrame:
        # the run directory name encodes the iteration, e.g. ".../3/predictions.csv"
        iter_id = file_path.parent.stem.split("_")[-1]
        iter_pred_df = pd.read_csv(file_path)
        iter_pred_df["iter"] = int(iter_id)
        return iter_pred_df

    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        raise ValueError(f"No files matching {regex!r} found in {multirun_dir}")

    pred_df = pd.concat(map(load_run, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
def rescale_predictions(
    pdxo_df: pd.DataFrame, pdx_df: pd.DataFrame
) -> t.Tuple[pd.DataFrame, pd.DataFrame]:
    """Standardize predictions per drug against the PDxO background.

    ``pdxo_df["y_true"]`` and ``pdxo_df["y_pred"]`` are each fit/transformed with
    their own ``GroupStandardScaler`` grouped by ``drug_id``. The ``y_pred``
    scaler fit on the PDxO predictions is then applied to ``pdx_df["y_pred"]``,
    so PDX scores are expressed relative to the PDxO distribution.
    ``pdx_df["y_true"]`` is left unchanged.

    Args:
        pdxo_df: PDxO predictions with ``drug_id``, ``y_true`` and ``y_pred``.
        pdx_df: PDX predictions with ``drug_id`` and ``y_pred``.

    Returns:
        ``(pdxo_df, pdx_df)`` as rescaled copies; the inputs are not modified.
    """
    # work on copies so callers' frames are never mutated
    pdxo_df = pdxo_df.copy()
    pdx_df = pdx_df.copy()

    true_scaler = GroupStandardScaler()
    pdxo_df["y_true"] = true_scaler.fit_transform(
        pdxo_df[["y_true"]], groups=pdxo_df["drug_id"]
    )
    # NOTE: PDX y_true is deliberately not rescaled (it is not transformed here)

    pred_scaler = GroupStandardScaler()
    pdxo_df["y_pred"] = pred_scaler.fit_transform(
        pdxo_df[["y_pred"]], groups=pdxo_df["drug_id"]
    )
    pdx_df["y_pred"] = pred_scaler.transform(pdx_df[["y_pred"]], groups=pdx_df["drug_id"])

    return pdxo_df, pdx_df

In [None]:
def auroc(df: pd.DataFrame, col1: str = "y_true", col2: str = "y_pred") -> float:
    """ROC AUC of the labels in ``df[col1]`` scored by the negated ``df[col2]``.

    Scores are negated so that lower predictions rank as more likely positive.
    Returns NaN when ``df[col1]`` has fewer than two distinct values, since the
    AUC is undefined then.
    """
    labels = df[col1]
    if labels.nunique() > 1:
        return skm.roc_auc_score(labels, -1 * df[col2])
    return np.nan


def select_best_therapy(df: pd.DataFrame, on_: str = "y_pred") -> pd.Series:
    """Return the row of ``df`` with the smallest value in column ``on_``.

    Lower scores mean higher predicted sensitivity in this notebook, so the
    minimum is the "best" therapy. Selection is positional (NaNs skipped, first
    minimum wins on ties), so exactly one row is returned even when ``df`` has
    duplicate index labels.
    """
    return df.iloc[df[on_].argmin()]


def get_response_rate(df: pd.DataFrame, col: str = "y_true") -> float:
    """Mean of the 0/1 (or boolean) response indicator ``df[col]``."""
    responses = df[col]
    return responses.mean()

## Data Loading

In [None]:
# Root of the local datastore; every input/output path below is resolved from it.
root = Path("../../../datastore")

In [None]:
# Sample annotations for the combined Cell Model Passports + HCI dataset
# (presumably one row per sample, keyed by the index column).
data_dir = root / "inputs/CellModelPassports-GDSCv1v2-HCI"
cell_meta = pd.read_csv(data_dir / "MetaSampleAnnotations.csv", index_col=0)

In [None]:
# Observed PDX responses; later cells use its cell_id, drug_id, label (summed as
# responders), mRECIST, r_best and r_avg columns.
raw_pdx_obs = pd.read_csv(root / "processed/WelmPDX/ScreenClinicalResponseV14B20.csv")
raw_pdx_obs.head()

In [None]:
# PDxO drug-response labels from the training dataset.
raw_pdxo_obs = pd.read_csv(
    root / "inputs/CellModelPassports-GDSCv1v2-HCI/LabelsLogIC50.csv"
)

# Drop the Cell Model Passports samples (ids starting with "SIDM") so only the HCI
# samples remain. .copy() makes the column assignment below act on an owned frame
# rather than a filtered view (avoids SettingWithCopyWarning).
raw_pdxo_obs = raw_pdxo_obs[~raw_pdxo_obs["cell_id"].str.startswith("SIDM")].copy()
# z-score the labels within each drug
raw_pdxo_obs["label"] = raw_pdxo_obs.groupby("drug_id")["label"].transform(stats.zscore)
raw_pdxo_obs.head()

In [None]:
# Raw PDxO screening summaries used as non-model baselines. Both metrics are
# z-scored per drug, and after the sign flip lower = more sensitive for both.
raw_pdxo_screen = pd.read_csv(root / "processed/WelmFinal/ScreenDoseResponse.csv")
grouped = raw_pdxo_screen.groupby("drug_name")
# NOTE: a large GR_AOC is better so we multiply Zd values by -1
raw_pdxo_screen["z_GR_AOC"] = grouped["GR_AOC"].transform(lambda x: stats.zscore(x) * -1)
raw_pdxo_screen["z_LN_IC50"] = grouped["LN_IC50"].transform(stats.zscore)
raw_pdxo_screen.head()

In [None]:
# Raw PDX tumor-volume data, split into vehicle (control) rows and drug rows.
# NOTE(review): this is the V13 raw-data export, while the responses above come from
# ScreenClinicalResponseV14B20 — confirm the versions are meant to differ.
raw_pdx_data = pd.read_csv(root / "processed/WelmPDX/ScreenClinicalResponseV13B20RawData.csv")
raw_pdx_data_ctrl = raw_pdx_data[raw_pdx_data["drug_name"] == "Vehicle"]
raw_pdx_data_drug = raw_pdx_data[raw_pdx_data["drug_name"] != "Vehicle"]

In [None]:
# DeepCDR results

output_dir = root / "outputs/experiments/pdx_benchmarking"
path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI-Mutations"
model = "DeepCDR-legacy"
date = "2024-11-21_11-17-57"

run_dir = output_dir / path_fmt.format(dataset, model, date)


deepcdr_pdxo_result = load_multirun_predictions(
    run_dir, "*/predictions.csv", splits=["test"]
)
deepcdr_pdx_result = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# NOTE: we do the zscore transformation against the PDxO background.
# Each multirun iteration is rescaled on its own: scalers are fit on that
# iteration's PDxO test predictions and applied to the same iteration's PDX
# predictions. Iterating the groupby directly avoids `GroupBy.grouper`
# (deprecated in pandas 2.2).
deepcdr_pdx_by_iter = {it: grp for it, grp in deepcdr_pdx_result.groupby("iter")}
pdxo_groups, pdx_groups = [], []
for iter_id, pdxo_group in deepcdr_pdxo_result.groupby("iter"):
    pdxo_group, pdx_group = rescale_predictions(
        pdxo_group.copy(), deepcdr_pdx_by_iter[iter_id].copy()
    )
    pdxo_groups.append(pdxo_group)
    pdx_groups.append(pdx_group)

deepcdr_pdxo_result = pd.concat(pdxo_groups).reset_index(drop=True)
deepcdr_pdx_result = pd.concat(pdx_groups).reset_index(drop=True)

In [None]:
# HiDRA results

output_dir = root / "outputs/experiments/pdx_benchmarking"
path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "HiDRA-legacy"
date = "2024-11-21_11-15-25"

run_dir = output_dir / path_fmt.format(dataset, model, date)

hidra_pdxo_result = load_multirun_predictions(
    run_dir, "*/predictions.csv", splits=["test"]
)
hidra_pdx_result = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# Rescale each multirun iteration against its own PDxO background (same scheme
# as for DeepCDR). Iterating the groupby directly avoids `GroupBy.grouper`
# (deprecated in pandas 2.2).
hidra_pdx_by_iter = {it: grp for it, grp in hidra_pdx_result.groupby("iter")}
pdxo_groups, pdx_groups = [], []
for iter_id, pdxo_group in hidra_pdxo_result.groupby("iter"):
    pdxo_group, pdx_group = rescale_predictions(
        pdxo_group.copy(), hidra_pdx_by_iter[iter_id].copy()
    )
    pdxo_groups.append(pdxo_group)
    pdx_groups.append(pdx_group)

hidra_pdxo_result = pd.concat(pdxo_groups).reset_index(drop=True)
hidra_pdx_result = pd.concat(pdx_groups).reset_index(drop=True)

In [None]:
# ScreenDL results (PDX validation runs). Unlike DeepCDR/HiDRA above, no
# rescale_predictions step is applied to these predictions in this cell.
output_dir = root / "outputs/experiments/pdx_validation"

path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"

# new data final pub V12.3 fixed script
date = "2024-11-21_14-08-27"

run_dir = output_dir / path_fmt.format(dataset, model, date)

screendl_pdxo_results = load_multirun_predictions(run_dir, "*/predictions_pdxo.csv")
screendl_pdx_results = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")

# rename the three ScreenDL variants for display; any model value not listed in
# `mapper` becomes NaN after .map
mapper = {"base": "ScreenDL-PT", "xfer": "ScreenDL-FT", "screen": "ScreenDL-SA"}
screendl_pdxo_results["model"] = screendl_pdxo_results["model"].map(mapper)
screendl_pdx_results["model"] = screendl_pdx_results["model"].map(mapper)

In [None]:
# show the transfer-learning ("xfer") section of the multirun config
conf = OmegaConf.load(run_dir / "multirun.yaml")
print(OmegaConf.to_yaml(conf.xfer))

In [None]:
# display order of the models; presumably these match the "model" column values
# written by the HiDRA/DeepCDR runs — TODO confirm
MODELS = ["HiDRA", "DeepCDR", "ScreenDL-PT", "ScreenDL-FT", "ScreenDL-SA"]

In [None]:
# stack all models' predictions into single PDX and PDxO tables
pdx_result = pd.concat([hidra_pdx_result, deepcdr_pdx_result, screendl_pdx_results])
pdxo_result = pd.concat([hidra_pdxo_result, deepcdr_pdxo_result, screendl_pdxo_results])

In [None]:
# add in additional annotations
# (all observed-PDX columns except id/label, joined on (cell_id, drug_id)).
# NOTE(review): a left merge repeats rows if raw_pdx_obs has several rows per
# (cell_id, drug_id) — confirm it is unique on that pair.
temp = raw_pdx_obs.drop(columns=["id", "label"])
pdxo_result = pdxo_result.merge(temp, how="left", on=["cell_id", "drug_id"])
pdxo_result = pdxo_result.sort_values(["model", "iter", "cell_id", "drug_id"])
pdx_result = pdx_result.merge(temp, how="left", on=["cell_id", "drug_id"])
pdx_result = pdx_result.sort_values(["model", "iter", "cell_id", "drug_id"])

In [None]:
# filter the results
# (restrict every model's PDX predictions to the PDXs and drugs ScreenDL predicted)
keep_drugs = screendl_pdx_results["drug_id"].unique()
keep_pdxs = screendl_pdx_results["cell_id"].unique()

pdx_result = pdx_result[pdx_result["cell_id"].isin(keep_pdxs)]
pdx_result = pdx_result[pdx_result["drug_id"].isin(keep_drugs)]

In [None]:
# Ensemble the multirun iterations: per (model, PDxO, drug), y_pred is the 20%
# trimmed mean across iterations (scipy trim_mean cuts 20% from each tail).
# y_true keeps its first value — assumes it agrees across iterations; TODO confirm
# for DeepCDR/HiDRA, whose y_true is rescaled per iteration.
ensembl_pdxo_result = (
    pdxo_result.groupby(["model", "cell_id", "drug_id"])
    # .aggregate({"y_pred": "mean", "y_true": "first"})
    .aggregate({"y_pred": lambda x: stats.trim_mean(x, 0.2), "y_true": "first"})
    .reset_index()
)

In [None]:
# per-drug correlation between y_true and y_pred across PDxOs, via eval_utils.pcorr
# (presumably Pearson; stored as "pcc"), summarized per model
ensembl_pdxo_corrs = (
    ensembl_pdxo_result.groupby(["model", "drug_id"])
    .apply(eval_utils.pcorr)
    .to_frame(name="pcc")
    .reset_index()
)

ensembl_pdxo_corrs.groupby("model").describe().loc[MODELS].T

In [None]:
# per-drug mean squared error across PDxOs, summarized per model
ensembl_pdxo_mse = (
    ensembl_pdxo_result.groupby(["model", "drug_id"])
    .apply(lambda g: ((g["y_true"] - g["y_pred"]) ** 2).sum() / len(g))
    .to_frame(name="mse")
    .reset_index()
)

ensembl_pdxo_mse.groupby("model").describe().loc[MODELS].T

In [None]:
# filter drugs and PDX samples for evaluation

pdx_result_eval = pdx_result.copy()

# drugs that we know don't generalize from PDxO screening
# NON_GEN_DRUGS = ["Palbociclib", "Abemaciclib", "Tamoxifen"]
# NON_GEN_DRUGS = ["Tamoxifen"]
# pdx_result_eval = pdx_result_eval.query("drug_id not in @NON_GEN_DRUGS")

# drugs that did not work in any PDX models -> this is likely due to legacy dosing.
# Responders are counted over all observed PDXs (raw_pdx_obs), not only those kept here.
responders_per_drug = raw_pdx_obs.groupby("drug_id")["label"].sum()
DRUGS_WITH_RESPONDERS = responders_per_drug[responders_per_drug >= 1].index
pdx_result_eval = pdx_result_eval.query("drug_id in @DRUGS_WITH_RESPONDERS")

# drugs with at least 20 PDxOs screened so we have good Zd background
# pdxos_per_drug = pdxo_result.groupby("drug_id")["cell_id"].nunique()
# DRUGS_WITH_ENOUGH_PDXOS = pdxos_per_drug[pdxos_per_drug >= 20].index
# pdx_result_eval = pdx_result_eval.query("drug_id in @DRUGS_WITH_ENOUGH_PDXOS")

# drop PDX samples with fewer than 2 drugs to choose from (selection would be trivial)
drugs_per_PDX = pdx_result_eval.groupby("cell_id")["drug_id"].nunique()
GOOD_PDX_SAMPLES = drugs_per_PDX[drugs_per_PDX > 1].index
pdx_result_eval = pdx_result_eval.query("cell_id in @GOOD_PDX_SAMPLES")

# filter for PDXs predicted by all models
# model_counts = pdx_result_eval.groupby("cell_id")["model"].nunique()
# keep_cells = model_counts[model_counts == pdx_result_eval["model"].nunique()].index
# pdx_result_eval = pdx_result_eval.query("cell_id in @keep_cells")

uniq_cells = sorted(pdx_result_eval["cell_id"].unique())
uniq_drugs = sorted(pdx_result_eval["drug_id"].unique())
print(uniq_cells)
print(uniq_drugs)

print(f"No. PDXs: {len(uniq_cells)}")
print(f"No. Drugs: {len(uniq_drugs)}")

# base response rate over all unique (PDX, drug) pairs that remain
base_rr = get_response_rate(pdx_result_eval.drop_duplicates(["cell_id", "drug_id"]))
print(f"Response Rate: {base_rr:.2f}")

pdx_result_eval.head()

In [None]:
# Ensemble the PDX predictions across iterations (20% trimmed mean of y_pred);
# y_true and mRECIST keep their first value per (model, drug, PDX).
ensembl_pdx_result_eval = (
    pdx_result_eval.groupby(["model", "drug_id", "cell_id"])
    # .aggregate({"y_true": "first", "y_pred": "mean", "mRECIST": "first"})
    .aggregate({"y_true": "first", "y_pred": lambda x: stats.trim_mean(x, 0.2), "mRECIST": "first"})
    .reset_index()
)

In [None]:
# number of distinct drugs left for evaluation
ensembl_pdx_result_eval["drug_id"].nunique()

In [None]:
def _raw_screen_predictions(score_col: str, model_name: str) -> pd.DataFrame:
    """Pair one per-drug PDxO screen z-score with the observed PDX responses.

    ``raw_pdxo_screen[score_col]`` is used directly as ``y_pred`` under the label
    ``model_name``. It is inner-joined to the observed PDX labels and mRECIST
    calls, then restricted to the PDXs and drugs in ``ensembl_pdx_result_eval``.
    """
    preds = (
        raw_pdxo_screen.rename(
            columns={"model_id": "cell_id", "drug_name": "drug_id", score_col: "y_pred"}
        )
        .filter(items=["cell_id", "drug_id", "y_pred"])
        .assign(model=model_name)
    )
    return (
        raw_pdx_obs.rename(columns={"label": "y_true"})
        .filter(items=["cell_id", "drug_id", "y_true", "mRECIST"])
        .drop_duplicates()
        .merge(preds, on=["cell_id", "drug_id"])
        .query("cell_id in @ensembl_pdx_result_eval.cell_id")
        .query("drug_id in @ensembl_pdx_result_eval.drug_id")
    )


raw_screen_preds_gr_aoc = _raw_screen_predictions("z_GR_AOC", "Screen - GR AOC")
raw_screen_preds_ln_ic50 = _raw_screen_predictions("z_LN_IC50", "Screen - ln(IC50)")

raw_screen_pdx_result_eval = pd.concat(
    [raw_screen_preds_ln_ic50, raw_screen_preds_gr_aoc]
)

raw_screen_pdx_result_eval.head()

In [None]:
# display order of the raw-screen baselines
SCREEN_MODELS = ["Screen - ln(IC50)", "Screen - GR AOC"]

In [None]:
# per-drug PDxO correlations computed only on PDxOs whose cell_id is NOT one of the
# evaluated PDXs; these "pcc" values are joined onto the drug selection below
ensembl_pdxo_corrs_eval = (
    ensembl_pdxo_result
    .query("cell_id not in @ensembl_pdx_result_eval.cell_id")
    .groupby(["model", "drug_id"])
    .apply(eval_utils.pcorr)
    .to_frame(name="pcc")
    .reset_index()
)

ensembl_pdxo_corrs_eval.groupby("model").describe().loc[MODELS].T

In [None]:
# per-drug PDxO MSE on the same subset (PDxOs not among the evaluated PDXs)
ensembl_pdxo_mse_eval = (
    ensembl_pdxo_result.query("cell_id not in @ensembl_pdx_result_eval.cell_id")
    .groupby(["model", "drug_id"])
    .apply(lambda g: np.mean((g["y_true"] - g["y_pred"]) ** 2))
    .to_frame(name="mse")
    .reset_index()
)

ensembl_pdxo_mse_eval.groupby("model").describe().loc[MODELS].T.round(4)

In [None]:
# NOTE: use the PDxO data for the other PDX models to find the optimal value of C

In [None]:
# NOTE: will need to do this for each model independently

In [None]:
# For each (model, PDX), select the candidate drug with the lowest ensembled y_pred.
# The inner merge keeps only drugs that have a (non-NaN) held-out PDxO correlation,
# so a model's candidate set can be smaller than in ensembl_pdx_result_eval.
ensembl_selected_drugs = (
    ensembl_pdx_result_eval.merge(
        ensembl_pdxo_corrs_eval.dropna(), on=["model", "drug_id"]
    )
    # .query("pcc >= 0")  # don't select low confidence drugs
    .groupby(["model", "cell_id"], as_index=False)
    .apply(lambda df: select_best_therapy(df, on_="y_pred"))
    .sort_values(["model", "cell_id"])
)

# clinical benefit rate per model = mean y_true of the selected drugs
ensembl_CBRs = ensembl_selected_drugs.groupby("model").apply(get_response_rate)
ensembl_CBRs.loc[MODELS]

In [None]:
# number of PDXs each model made a selection for
ensembl_selected_drugs.groupby("model")["cell_id"].nunique()

In [None]:
# same selection for the raw-screen baselines (lowest z-score wins)
raw_screen_selected_drugs = (
    raw_screen_pdx_result_eval.groupby(["model", "cell_id"], as_index=False)
    .apply(lambda df: select_best_therapy(df, on_="y_pred"))
    .sort_values(["model", "cell_id"])
)

raw_screen_CBRs = raw_screen_selected_drugs.groupby("model").apply(get_response_rate)
raw_screen_CBRs.loc[SCREEN_MODELS]

In [None]:
# objective response rate: a selection counts as a response when mRECIST is PR or CR
ensembl_ORRs = (
    ensembl_selected_drugs.assign(
        y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
    )
    .groupby("model")
    .apply(get_response_rate)
)

ensembl_ORRs.loc[MODELS]

In [None]:
# objective response rate for the raw-screen baselines
raw_screen_ORRs = (
    raw_screen_selected_drugs.assign(
        y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
    )
    .groupby("model")
    .apply(get_response_rate)
)

raw_screen_ORRs.loc[SCREEN_MODELS]

In [None]:
# number of candidate drugs per PDX (rows) for each model/baseline (columns)
temp = pd.concat([ensembl_pdx_result_eval, raw_screen_pdx_result_eval])
grouped = temp.groupby(["model", "cell_id"])
grouped["drug_id"].nunique().unstack(0)[MODELS + SCREEN_MODELS]

In [None]:
# Plot sources: per-model CBR (mean y_true of selected drugs) as a long
# (model, response_rate) table, in MODELS / SCREEN_MODELS order.
model_CBR_source = (
    ensembl_selected_drugs.groupby("model")
    .apply(get_response_rate)
    .loc[MODELS]
    .to_frame(name="response_rate")
    .reset_index()
)

screen_CBR_source = (
    raw_screen_selected_drugs.groupby("model")
    .apply(get_response_rate)
    .loc[SCREEN_MODELS]
    .to_frame(name="response_rate")
    .reset_index()
)

In [None]:
# Plot sources for ORR (selection counted as a response when mRECIST is PR or CR).
model_ORR_source = (
    ensembl_selected_drugs.assign(
        y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
    )
    .groupby("model")
    .apply(get_response_rate)
    .loc[MODELS]
    .to_frame(name="response_rate")
    .reset_index()
)

screen_ORR_source = (
    raw_screen_selected_drugs.assign(
        y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
    )
    .groupby("model")
    .apply(get_response_rate)
    .loc[SCREEN_MODELS]
    .to_frame(name="response_rate")
    .reset_index()
)

In [None]:
# clinical benefit rate for random selection
# Monte Carlo baseline: pick one candidate drug uniformly at random for each PDX,
# record the response rate (mean y_true), repeat `max_iters` times and average.
# A seeded generator makes the baseline reproducible across notebook runs.
options = ensembl_pdx_result_eval.drop_duplicates(["cell_id", "drug_id"])
max_iters = 1000
rng = np.random.default_rng(0)
iter_CBRs_models = []
for _ in range(max_iters):
    iter_rr = get_response_rate(options.groupby("cell_id").sample(1, random_state=rng))
    iter_CBRs_models.append(iter_rr)
rand_CBR_models = np.mean(iter_CBRs_models)
print(f"Random Selection CBR (Models): {rand_CBR_models}")

In [None]:
# clinical benefit rate for random selection among the raw-screen candidates
# (same seeded Monte Carlo scheme as for the models)
options = raw_screen_pdx_result_eval.drop_duplicates(["cell_id", "drug_id"])
max_iters = 1000
rng = np.random.default_rng(0)
iter_CBRs_screen = []
for _ in range(max_iters):
    iter_rr = get_response_rate(options.groupby("cell_id").sample(1, random_state=rng))
    iter_CBRs_screen.append(iter_rr)
rand_CBR_screen = np.mean(iter_CBRs_screen)
print(f"Random Selection CBR (Screen): {rand_CBR_screen}")

In [None]:
# objective response rate for random selection: a pick counts as a response when
# its mRECIST call is PR or CR
options = ensembl_pdx_result_eval.drop_duplicates(["cell_id", "drug_id"]).assign(
    y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
)

max_iters = 1000
rng = np.random.default_rng(0)
iter_ORRs_models = []
for _ in range(max_iters):
    iter_rr = get_response_rate(options.groupby("cell_id").sample(1, random_state=rng))
    iter_ORRs_models.append(iter_rr)
rand_ORR_models = np.mean(iter_ORRs_models)
print(f"Random Selection ORR (Models): {rand_ORR_models}")

In [None]:
# objective response rate for random selection among the raw-screen candidates
options = raw_screen_pdx_result_eval.drop_duplicates(["cell_id", "drug_id"]).assign(
    y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
)

max_iters = 1000
rng = np.random.default_rng(0)
iter_ORRs_screen = []
for _ in range(max_iters):
    iter_rr = get_response_rate(options.groupby("cell_id").sample(1, random_state=rng))
    iter_ORRs_screen.append(iter_rr)
rand_ORR_screen = np.mean(iter_ORRs_screen)
print(f"Random Selection ORR (Screen): {rand_ORR_screen}")

In [None]:
# Bounds on CBR: the fraction of PDXs that respond if the worst (min) / best (max)
# candidate drug were always picked.
num_PDX_samples = ensembl_pdx_result_eval["cell_id"].nunique()

grouped = ensembl_pdx_result_eval.groupby("cell_id")
min_CBR_models = grouped["y_true"].min().sum() / num_PDX_samples
max_CBR_models = grouped["y_true"].max().sum() / num_PDX_samples
print(f"Min Achievable CBR (Models): {min_CBR_models}")
print(f"Max Achievable CBR (Models): {max_CBR_models}")

In [None]:
# same CBR bounds over the raw-screen candidate set
num_PDX_samples = raw_screen_pdx_result_eval["cell_id"].nunique()

grouped = raw_screen_pdx_result_eval.groupby("cell_id")
min_CBR_screen = grouped["y_true"].min().sum() / num_PDX_samples
max_CBR_screen = grouped["y_true"].max().sum() / num_PDX_samples
print(f"Min Achievable CBR (Screen): {min_CBR_screen}")
print(f"Max Achievable CBR (Screen): {max_CBR_screen}")

In [None]:
num_PDX_samples = ensembl_pdx_result_eval["cell_id"].nunique()

# objective response (PR/CR) indicator for every candidate (PDX, drug) pair
grouped = ensembl_pdx_result_eval.assign(
    y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
).groupby("cell_id")["y_true"]

# bounds on ORR: always picking the worst (min) / best (max) candidate per PDX
min_ORR_models = grouped.min().sum() / num_PDX_samples
max_ORR_models = grouped.max().sum() / num_PDX_samples
print(f"Min Achievable ORR (Models): {min_ORR_models}")
print(f"Max Achievable ORR (Models): {max_ORR_models}")

In [None]:
num_PDX_samples = raw_screen_pdx_result_eval["cell_id"].nunique()

# objective response (PR/CR) indicator for every raw-screen candidate pair
grouped = raw_screen_pdx_result_eval.assign(
    y_true=lambda df: df["mRECIST"].isin(["PR", "CR"]).astype(int)
).groupby("cell_id")["y_true"]

# bounds on ORR: always picking the worst (min) / best (max) candidate per PDX
min_ORR_screen = grouped.min().sum() / num_PDX_samples
max_ORR_screen = grouped.max().sum() / num_PDX_samples
print(f"Min Achievable ORR (Screen): {min_ORR_screen}")
print(f"Max Achievable ORR (Screen): {max_ORR_screen}")

In [None]:
# Shared color scale: colors are matched to the domain entries by position
# (two screen baselines first, then MODELS in order).
MODEL_COLOR_SCALE = alt.Scale(
    domain=["Screen - ln(IC50)", "Screen - GR AOC"] + MODELS,
    range=("#FDBFD3", "#FF9DA5", "darkgray", "gray", "#4C78A8", "#B278A2", "#5CA453"),
)

# Axis styling applied to every chart by configure_chart.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
}


def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Apply the notebook's shared figure styling to ``chart``.

    Hides the view outline, applies ``AXIS_CONFIG`` to all axes and sets the
    header label font. Used for the bar-chart and waterfall figures (including
    concatenated charts).
    """
    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    return styled.configure_header(labelFont="arial")

In [None]:
# CBR bar chart for the models: one bar per model, with dashed reference lines at
# the random-selection, minimum and maximum achievable CBR (all three rules share
# the same dash style) and percentage labels pinned to the top of the y axis.
base_models = alt.Chart(model_CBR_source)

bars_models = (
    base_models.mark_bar(stroke="black", size=25, strokeWidth=1)
    .encode(
        alt.Y("response_rate:Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 1))
        .title("Clinical Benefit Rate (%)"),
        alt.X("model:N")
        .axis(domainColor="black", labelAngle=-45)
        .scale(domain=MODELS, paddingOuter=0.15)
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(height=230, width=31 * len(MODELS))
)

rand_CBR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [rand_CBR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

min_CBR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [min_CBR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

max_CBR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [max_CBR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

# value labels drawn at y=1 (top of the axis) above each bar
text_models = base_models.mark_text(align="center", dy=-15, fontSize=10).encode(
    alt.YValue(1),
    alt.X("model:N").title(None),
    alt.Text("response_rate:Q", format=".0%"),
)


CBR_chart_models = alt.layer(
    bars_models,
    rand_CBR_rule_models,
    min_CBR_rule_models,
    max_CBR_rule_models,
    text_models,
)

In [None]:
# CBR bar chart for the raw-screen baselines, with the same reference lines
# (random / min / max achievable CBR over the screen candidate set).
base_screen = alt.Chart(screen_CBR_source)

bars_screen = (
    base_screen.mark_bar(stroke="black", size=25, strokeWidth=1)
    .encode(
        alt.Y("response_rate:Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 1))
        .title("Clinical Benefit Rate (%)"),
        alt.X("model:N")
        .axis(domainColor="black", labelAngle=-45)
        .scale(domain=SCREEN_MODELS, paddingOuter=0.15)
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(height=230, width=32 * len(SCREEN_MODELS))
)

rand_CBR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [rand_CBR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

min_CBR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [min_CBR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

max_CBR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [max_CBR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

# value labels drawn at y=1 (top of the axis) above each bar
text_screen = base_screen.mark_text(align="center", dy=-15, fontSize=10).encode(
    alt.YValue(1),
    alt.X("model:N").title(None),
    alt.Text("response_rate:Q", format=".0%"),
)


CBR_chart_screen = alt.layer(
    bars_screen,
    rand_CBR_rule_screen,
    min_CBR_rule_screen,
    max_CBR_rule_screen,
    text_screen,
)

In [None]:
# CBR panels side by side (screen baselines | models) sharing one y axis
CBR_chart = alt.hconcat(CBR_chart_screen, CBR_chart_models).resolve_axis(y="shared")
# configure_chart(CBR_chart)

In [None]:
# largest single Monte Carlo draw of the random-selection ORR baseline (models)
max(iter_ORRs_models)

In [None]:
# ORR bar chart for the models, with dashed reference lines at the random-selection,
# minimum and maximum achievable ORR.
base_models = alt.Chart(model_ORR_source)

bars_models = (
    base_models.mark_bar(stroke="black", size=25, strokeWidth=1)
    .encode(
        alt.Y("response_rate:Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 1))
        .title("Objective Response Rate (%)"),
        alt.X("model:N")
        .axis(domainColor="black", labelAngle=-45)
        .scale(domain=MODELS, paddingOuter=0.15)
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(height=230, width=31 * len(MODELS))
)

rand_ORR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [rand_ORR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

min_ORR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [min_ORR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

max_ORR_rule_models = (
    alt.Chart(pd.DataFrame({"y": [max_ORR_models]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

# value labels drawn at y=1 (top of the axis) above each bar
text_models = base_models.mark_text(align="center", dy=-15, fontSize=10).encode(
    alt.YValue(1),
    alt.X("model:N").title(None),
    alt.Text("response_rate:Q", format=".0%"),
)


ORR_chart_models = alt.layer(
    bars_models,
    rand_ORR_rule_models,
    min_ORR_rule_models,
    max_ORR_rule_models,
    text_models,
)

In [None]:
# ORR bar chart for the raw-screen baselines, with random / min / max reference lines.
base_screen = alt.Chart(screen_ORR_source)

bars_screen = (
    base_screen.mark_bar(stroke="black", size=25, strokeWidth=1)
    .encode(
        alt.Y("response_rate:Q")
        .axis(grid=False, tickCount=4, domainColor="black", titlePadding=10, format="%")
        .scale(domain=(0, 1))
        .title("Objective Response Rate (%)"),
        alt.X("model:N")
        .axis(domainColor="black", labelAngle=-45)
        .scale(domain=SCREEN_MODELS, paddingOuter=0.15)
        .title(None),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
    .properties(height=230, width=32 * len(SCREEN_MODELS))
)

rand_ORR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [rand_ORR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

min_ORR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [min_ORR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

max_ORR_rule_screen = (
    alt.Chart(pd.DataFrame({"y": [max_ORR_screen]}))
    .mark_rule(stroke="black", strokeDash=[3, 3], strokeWidth=1.5)
    .encode(y="y:Q")
)

# value labels drawn at y=1 (top of the axis) above each bar
text_screen = base_screen.mark_text(align="center", dy=-15, fontSize=10).encode(
    alt.YValue(1),
    alt.X("model:N").title(None),
    alt.Text("response_rate:Q", format=".0%"),
)


ORR_chart_screen = alt.layer(
    bars_screen,
    rand_ORR_rule_screen,
    min_ORR_rule_screen,
    max_ORR_rule_screen,
    text_screen,
)

In [None]:
# ORR panels side by side (screen baselines | models) sharing one y axis
ORR_chart = alt.hconcat(ORR_chart_screen, ORR_chart_models).resolve_axis(y="shared")

In [None]:
# final figure: CBR panels next to ORR panels, with the shared styling applied
chart = alt.hconcat(CBR_chart, ORR_chart, spacing=40)
configure_chart(chart)

In [None]:
# fraction of the maximum attainable ORR achieved by each model
model_ORR_source.set_index("model")["response_rate"] / max_ORR_models

In [None]:
# fraction of the maximum attainable ORR achieved by each screen baseline
screen_ORR_source.set_index("model")["response_rate"] / max_ORR_screen

In [None]:
# inspect the screen-based selections for a single PDX
raw_screen_selected_drugs.query("cell_id == 'HCI001'")

In [None]:
# ORR (PR/CR) of ScreenDL-FT's selections with one extra responder added, over 20.
# NOTE(review): both the "+ 1" and the denominator 20 are hard-coded and their
# origin is not visible here — confirm they still match the current evaluation set.
(
    ensembl_selected_drugs.query("model == 'ScreenDL-FT'")
    .assign(y_true=lambda df: df["mRECIST"].isin(["CR", "PR"]))["y_true"]
    .sum()
    + 1
) / 20

## Waterfall plots for all drugs vs drug selected by ScreenDL-SA

In [None]:
# One row per observed (PDX, drug) pair for the PDXs that appear in
# ensembl_selected_drugs (any model). was_selected is True when ScreenDL-SA picked
# that drug for that PDX. r_avg is divided by 100 so the "%" axis format below
# shows it as a percentage.
X = (
    raw_pdx_obs[["cell_id", "drug_id", "r_best", "r_avg", "mRECIST"]]
    .merge(
        ensembl_selected_drugs.query("model == 'ScreenDL-SA'")
        .filter(items=["cell_id", "drug_id"])
        .assign(was_selected=True),
        how="left",
    )
    .query("cell_id in @ensembl_selected_drugs.cell_id")
    .fillna({"was_selected": False})
    .assign(
        x=lambda df: df["cell_id"] + " + " + df["drug_id"],
        r_avg=lambda df: df["r_avg"] / 100,
    )
)

In [None]:
# Waterfall bars of BestAvgResponse (r_avg) per (PDX, drug) pair, split into the
# pairs ScreenDL-SA did not select and those it did. Bars are sorted by descending
# y, colored by mRECIST, and y is clamped to [-100%, +100%]. Each panel's width
# scales with its number of bars.
not_selected_chart = (
    alt.Chart(X.query("was_selected == False"))
    .mark_bar(size=13.5, stroke="black", strokeWidth=1)
    .encode(
        alt.X("x:N")
        .sort("-y")
        .axis(grid=False, labels=False, ticks=False, offset=-125)
        .scale(paddingOuter=0.2)
        .title(None),
        alt.Y("r_avg:Q")
        .axis(grid=False, tickCount=5, titlePadding=10, format="%")
        .scale(domain=(-1, 1), clamp=True)
        .title(["Change in tumor volume (%)", "(BestAvgResponse)"]),
        alt.Color("mRECIST:N").scale(
            domain=("CR", "PR", "SD", "PD"),
            range=("#9ECAE9", "#89D27A", "#F2CF5B", "#FF9D98"),
        ),
    )
    .properties(width=16 * X.query("was_selected == False").shape[0], height=250)
)

selected_chart = (
    alt.Chart(X.query("was_selected == True"))
    .mark_bar(size=13.5, stroke="black", strokeWidth=1)
    .encode(
        alt.X("x:N")
        .sort("-y")
        .axis(grid=False, labels=False, ticks=False, offset=-125)
        .scale(paddingOuter=0.2)
        .title(None),
        alt.Y("r_avg:Q")
        .axis(grid=False, tickCount=5, titlePadding=10, format="%")
        .scale(domain=(-1, 1), clamp=True)
        .title(["Change in tumor volume (%)", "(BestAvgResponse)"]),
        alt.Color("mRECIST:N").scale(
            domain=("CR", "PR", "SD", "PD"),
            range=("#9ECAE9", "#89D27A", "#F2CF5B", "#FF9D98"),
        ),
    )
    .properties(width=16 * X.query("was_selected == True").shape[0], height=250)
)

configure_chart(alt.hconcat(selected_chart, not_selected_chart, spacing=40))

In [None]:
# 2x2 table of clinical benefit (CR/PR/SD) vs. whether ScreenDL-SA selected the
# (PDX, drug) pair: rows = CBR (False, True), columns = was_selected (False, True).
# Unobserved combinations are filled with 0 so fisher_exact always receives a
# complete 2x2 integer table instead of one containing NaN.
CBR_ctab_models = (
    X.assign(CBR=lambda df: df["mRECIST"].isin(["CR", "PR", "SD"]))
    .groupby(["was_selected", "CBR"])
    .size()
    .unstack(0, fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)

stats.fisher_exact(CBR_ctab_models)

In [None]:
# 2x2 table of objective response (CR/PR) vs. ScreenDL-SA selection:
# rows = ORR (False, True), columns = was_selected (False, True); missing
# combinations are filled with 0 so fisher_exact always gets a full 2x2 table.
ORR_ctab_models = (
    X.assign(ORR=lambda df: df["mRECIST"].isin(["CR", "PR"]))
    .groupby(["was_selected", "ORR"])
    .size()
    .unstack(0, fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)

stats.fisher_exact(ORR_ctab_models)

In [None]:
# Same per-pair table as for ScreenDL-SA, but was_selected now marks the drug picked
# by the "Screen - ln(IC50)" baseline. NOTE: this reassigns X, so the ScreenDL-SA
# version is gone once this cell has run.
X = (
    raw_pdx_obs[["cell_id", "drug_id", "r_best", "r_avg", "mRECIST"]]
    .merge(
        raw_screen_selected_drugs.query("model == 'Screen - ln(IC50)'")
        .filter(items=["cell_id", "drug_id"])
        .assign(was_selected=True),
        how="left",
    )
    .query("cell_id in @raw_screen_selected_drugs.cell_id")
    .fillna({"was_selected": False})
    .assign(
        x=lambda df: df["cell_id"] + " + " + df["drug_id"],
        r_avg=lambda df: df["r_avg"] / 100,
    )
)

In [None]:
# 2x2 table of clinical benefit (CR/PR/SD) vs. selection by the ln(IC50) screen
# baseline: rows = CBR (False, True), columns = was_selected (False, True);
# missing combinations are filled with 0 for fisher_exact.
CBR_ctab_screen = (
    X.assign(CBR=lambda df: df["mRECIST"].isin(["CR", "PR", "SD"]))
    .groupby(["was_selected", "CBR"])
    .size()
    .unstack(0, fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)

stats.fisher_exact(CBR_ctab_screen)

In [None]:
# 2x2 table of objective response (CR/PR) vs. selection by the ln(IC50) screen
# baseline: rows = ORR (False, True), columns = was_selected (False, True);
# missing combinations are filled with 0 for fisher_exact.
ORR_ctab_screen = (
    X.assign(ORR=lambda df: df["mRECIST"].isin(["CR", "PR"]))
    .groupby(["was_selected", "ORR"])
    .size()
    .unstack(0, fill_value=0)
    .reindex(index=[False, True], columns=[False, True], fill_value=0)
)

stats.fisher_exact(ORR_ctab_screen)

In [None]:
# look at fisher exact test for proportion of CBR and ORR in selected vs not selected
# (done with stats.fisher_exact on the contingency tables above)

In [None]:
# FIXME: add percentage and number of PDXs to top of panel
# FIXME: try using GRAOC metric to select drugs for the raw PDxO screening

## Delta tumor volume visuals for selected drugs

In [None]:
# ScreenDL-SA's selected drug per PDX joined with the observed PDX data; mRECIST is
# kept from the selection table, so it is dropped from raw_pdx_obs before merging.
temp = ensembl_selected_drugs.query("model == 'ScreenDL-SA'")
temp = temp.merge(raw_pdx_obs.drop(columns="mRECIST"), on=["cell_id", "drug_id"]).copy()
# ordered categorical so sorting by mRECIST gives CR < PR < SD < PD
temp["mRECIST"] = pd.Categorical(
    temp["mRECIST"], categories=["CR", "PR", "SD", "PD"], ordered=True
)

In [None]:
def get_header(r: pd.Series) -> str:
    """Facet title for one selected row: "<cell_id> + <drug_id> (<mRECIST>)"."""
    return f"{r['cell_id']} + {r['drug_id']} ({r['mRECIST']})"


temp["header"] = temp.apply(get_header, axis=1)
sample_to_header = dict(zip(temp["cell_id"], temp["header"]))
# order PDXs/facets by response category (CR < PR < SD < PD), then by r_best
temp_sorted = temp.sort_values(["mRECIST", "r_best"])
pdx_order = temp_sorted["cell_id"].to_list()
facet_order = temp_sorted["header"].to_list()

In [None]:
# One (cell_id, drug_id, exp_id) entry per selected PDX experiment, used as keys
# into the time-course tables, which name the first two fields sample_id / drug_name.
# NOTE(review): pd.MultiIndex.from_frame(...) would state this intent more explicitly.
sel_inds = pd.Index(temp[["cell_id", "drug_id", "exp_id"]])
drug_source = (
    raw_pdx_data_drug.set_index(["sample_id", "drug_name", "exp_id"])
    .loc[sel_inds]
    .reset_index()
)
# vehicle rows are looked up by (sample_id, exp_id) only — presumably the control
# arm of the same experiment
ctrl_source = (
    raw_pdx_data_ctrl.set_index(["sample_id", "exp_id"])
    .loc[[(x[0], x[2]) for x in sel_inds]]
    .reset_index()
)

source = pd.concat([ctrl_source, drug_source])
# attach each PDX's facet title to all of its rows (drug and vehicle)
source["header"] = source["sample_id"].map(sample_to_header)

In [None]:
# drug legend order: drugs in order of first appearance along pdx_order
drug_order = (
    drug_source.drop_duplicates(["sample_id", "drug_name"])
    .set_index("sample_id")
    .loc[pdx_order]["drug_name"]
    .unique()
)

In [None]:
# NOTE: we filter out observations beyond 40 days after treatment start
# One facet per selected (PDX, drug) pair: relative tumor volume over time, one line
# per mouse; vehicle lines are drawn in dark gray, drug lines colored by drug.
base = (
    alt.Chart(source)
    .transform_filter(alt.datum.day <= 40)
    .encode(
        alt.X("day:Q")
        .axis(labelAngle=0, tickCount=4, grid=False, titlePadding=10)
        .scale(domainMax=40, domainMin=0)
        .title("Time after treatment (d)"),
        alt.Y("rel_tumor_vol_pct:Q")
        .axis(grid=False, tickCount=4, minExtent=35, titlePadding=5)
        .scale(domainMin=0, nice=True)
        .title("Relative tumor volume (%)"),
        alt.condition(
            alt.datum.drug_name == "Vehicle",
            alt.ColorValue("darkgray"),
            alt.Color("drug_name:N")
            .scale(domain=list(drug_order))
            .legend(orient="right", columns=1)
            .title(None),
        ),
        alt.Detail("mouse_id:N"),
    )
)

# facets follow facet_order (response category, then r_best); each facet gets
# its own x and y scales
chart = (
    base.mark_line(strokeWidth=1, point=alt.MarkConfig(size=20), clip=False)
    .properties(width=150, height=120)
    .facet(alt.Facet("header:N", sort=facet_order, header=None), columns=6, spacing=30)
    .resolve_scale(y="independent", x="independent")
    .configure_view(strokeOpacity=0)
    .configure_axis(
        titleFont="arial",
        titleFontStyle="regular",
        titleFontSize=11,
        labelFont="arial",
        tickColor="black",
        domainColor="black",
    )
    .configure_header(
        titleFont="arial",
        titleFontStyle="regular",
        titleFontSize=11,
    )
)

chart.display()

In [None]:
# Screen-only (ln(IC50)) drug selections joined with the observed PDX outcomes;
# the selection table's own mRECIST column is dropped in favor of raw_pdx_obs's.
temp = raw_screen_selected_drugs.query("model == 'Screen - ln(IC50)'")
temp = temp.drop(columns="mRECIST").merge(raw_pdx_obs, on=["cell_id", "drug_id"]).copy()

# ordered response categories: best (CR) to worst (PD)
temp["mRECIST"] = temp["mRECIST"].astype(
    pd.CategoricalDtype(categories=["CR", "PR", "SD", "PD"], ordered=True)
)
pdx_order = temp.sort_values(by=["mRECIST", "r_best"])["cell_id"].tolist()

In [None]:
# (sample, drug, experiment) keys of the selected PDX/drug pairs
sel_inds = pd.Index(temp[["cell_id", "drug_id", "exp_id"]])

# treated-arm tumor-volume rows for exactly those keys
drug_source = raw_pdx_data_drug.set_index(["sample_id", "drug_name", "exp_id"])
drug_source = drug_source.loc[sel_inds].reset_index()

# control-arm rows for the same (sample, experiment), ignoring the drug part of the key
ctrl_source = raw_pdx_data_ctrl.set_index(["sample_id", "exp_id"])
ctrl_source = ctrl_source.loc[
    [(sample_id, exp_id) for sample_id, _drug_id, exp_id in sel_inds]
].reset_index()

# single plotting frame: control rows first, then treated rows
source = pd.concat([ctrl_source, drug_source])

In [None]:
# Distinct treated-arm drug names in the order their samples appear in `pdx_order`;
# used as the color-scale domain so the legend follows the facet order.
drug_order = (
    drug_source.drop_duplicates(subset=["sample_id", "drug_name"])
    .set_index("sample_id")["drug_name"]
    .loc[pdx_order]
    .unique()
)

In [None]:
# display the selected pairs ordered by response category
temp.sort_values(by="mRECIST")

In [None]:
# NOTE: we filter out observations beyond 40 days after treatment start
# Tumor-volume trajectories, one line per mouse (detail=mouse_id). Rows whose
# drug_name is "Vehicle" are drawn in dark gray; treated rows are colored by drug.
base = (
    alt.Chart(source)
    .transform_filter(alt.datum.day <= 40)
    .encode(
        x=alt.X("day:Q")
        .title("Time after treatment (d)")
        .scale(domainMin=0, domainMax=40)
        .axis(labelAngle=0, tickCount=4, grid=False, titlePadding=10),
        y=alt.Y("rel_tumor_vol_pct:Q")
        .title("Relative tumor volume (%)")
        .scale(domainMin=0, nice=True)
        .axis(grid=False, tickCount=4, minExtent=35, titlePadding=5),
        color=alt.condition(
            alt.datum.drug_name == "Vehicle",
            alt.ColorValue("darkgray"),
            alt.Color("drug_name:N")
            .title(None)
            .scale(domain=list(drug_order))
            .legend(orient="top"),
        ),
        detail=alt.Detail("mouse_id:N"),
    )
)

# One panel per PDX sample (ordered by `pdx_order`), five panels per row,
# each panel with its own x and y scales.
chart = (
    base.mark_line(strokeWidth=1, point=alt.MarkConfig(size=20), clip=False)
    .properties(width=140, height=110)
    .facet(
        facet=alt.Facet("sample_id:N", header=None, sort=pdx_order),
        columns=5,
        spacing=30,
    )
    .resolve_scale(x="independent", y="independent")
    .configure_view(strokeOpacity=0)
    .configure_axis(
        labelFont="arial",
        titleFont="arial",
        titleFontSize=11,
        titleFontStyle="regular",
        tickColor="black",
        domainColor="black",
    )
    .configure_header(titleFont="arial", titleFontSize=11, titleFontStyle="regular")
)

chart.display()

## Comparison of ScreenDL runs

In [None]:
output_dir = root / "outputs/experiments/pdx_validation"

path_fmt = "{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2-HCI"
model = "ScreenDL"
mapper = {"base": "ScreenDL-PT", "xfer": "ScreenDL-FT", "screen": "ScreenDL-SA"}

# timestamps of the replicate ScreenDL multiruns being compared
dates = [
    "2024-11-21_11-23-58",
    "2024-11-21_14-08-27",
    "2024-11-21_15-08-17",
    "2024-11-21_18-14-24",
    "2024-11-22_08-11-05",
    "2024-11-22_08-11-31",
    "2024-11-22_19-40-35",
    "2024-11-22_19-41-05",
    "2024-11-23_07-33-53",
    "2024-11-23_08-25-39",
]


def _load_pdx_predictions_for_date(date: str) -> pd.DataFrame:
    """Load the PDX predictions of one ScreenDL multirun, tagged with its date.

    Only ``predictions_pdx.csv`` is read; the PDxO predictions are not used by this
    comparison. Working with locals keeps this cell from rebinding notebook-level
    names such as ``screendl_pdx_results`` / ``screendl_pdxo_results``.
    """
    run_dir = output_dir / path_fmt.format(dataset, model, date)
    run_preds = load_multirun_predictions(run_dir, "*/predictions_pdx.csv")
    # raw model keys -> display names; keys missing from `mapper` become NaN and
    # are then dropped by the groupby below (pandas groupby dropna=True default)
    run_preds["model"] = run_preds["model"].map(mapper)
    return run_preds.assign(date=date)


temp = pd.concat([_load_pdx_predictions_for_date(date) for date in dates])

# add in additional annotations, then collapse replicate iterations: for each
# (model, date, drug, PDX) keep the first observed response/mRECIST and a
# trimmed mean prediction (20% cut from each tail)
temp = (
    temp.merge(
        raw_pdx_obs.drop(columns=["id", "label"]), how="left", on=["cell_id", "drug_id"]
    )
    .sort_values(["model", "iter", "cell_id", "drug_id"])
    .groupby(["model", "date", "drug_id", "cell_id"])
    .aggregate(
        {
            "y_true": "first",
            "y_pred": lambda x: stats.trim_mean(x, 0.2),
            "mRECIST": "first",
        }
    )
    .reset_index()
)

In [None]:
# keep only drugs with at least one responder (summed `label`) in the PDX cohort
responders_per_drug = raw_pdx_obs.groupby("drug_id")["label"].sum()
DRUGS_WITH_RESPONDERS = responders_per_drug.index[responders_per_drug.to_numpy() >= 1]
temp = temp[temp["drug_id"].isin(DRUGS_WITH_RESPONDERS)]

# drop PDX samples with fewer than 2 drugs to choose from (no real selection possible)
drugs_per_PDX = temp.groupby("cell_id")["drug_id"].nunique()
GOOD_PDX_SAMPLES = drugs_per_PDX.index[drugs_per_PDX.to_numpy() > 1]
temp = temp[temp["cell_id"].isin(GOOD_PDX_SAMPLES)]

In [None]:
# per (model, run date, PDX): the therapy chosen by `select_best_therapy` on y_pred
ensembl_selected_drugs = temp.groupby(
    ["model", "date", "cell_id"], as_index=False
).apply(lambda pdx_rows: select_best_therapy(pdx_rows, on_="y_pred"))
ensembl_selected_drugs = ensembl_selected_drugs.sort_values(["model", "date", "cell_id"])

# response rate (CBR, scored on y_true) per (model, date); show the mean over dates
ensembl_CBRs = ensembl_selected_drugs.groupby(["model", "date"]).apply(get_response_rate)
ensembl_CBRs.unstack(level=0)[MODELS[2:]].mean()

In [None]:
# objective response rate: y_true is redefined so only PR/CR count as responders
ensembl_ORRs = (
    ensembl_selected_drugs.assign(
        y_true=lambda selected: selected["mRECIST"].isin(["PR", "CR"]).astype(int)
    )
    .groupby(["model", "date"])
    .apply(get_response_rate)
)

# mean ORR over run dates for the models in MODELS[2:]
ensembl_ORRs.unstack(level=0)[MODELS[2:]].mean()

In [None]:
# per-date ORR table (dates x models) for the models in MODELS[2:]
ensembl_ORRs.unstack(level=0)[MODELS[2:]]

In [None]:
import scipy.spatial.distance as ssd

In [None]:
# rows = run dates, columns = "<model> CBR" / "<model> ORR"
X = ensembl_CBRs.unstack(level=0).join(
    ensembl_ORRs.unstack(level=0), lsuffix=" CBR", rsuffix=" ORR"
)
# run date whose CBR/ORR profile has the smallest cosine distance to the across-date mean
X.index[ssd.cdist(X, X.mean().to_numpy()[np.newaxis, :], metric="cosine").argmin()]

## Drug-Level Analysis

In [None]:
# NOTE: add discussion point -> even if we screened with every single drug, the machine does better (especially with the ensemble)
# NOTE: we would expect a very low response rate in this population as these patients will be resistant to lots of drugs
# NOTE: random selection mirrors the expected response rate to randomly selected chemotherapies

In [None]:
# NOTE: while the raw screening has fewer drugs to choose from, none of the drugs
# that are missing are chosen by any of the models
# FIXME: check why these are worse than in the old jupyter notebooks (filtering tumors)

In [None]:
# results_f2.groupby("cell_id")["drug_id"].nunique()

# raw_screen_preds.query("cell_id in @results_f2.cell_id").query(
#     "drug_id in @results_f2.drug_id"
# ).groupby("cell_id")["drug_id"].nunique()

## PDxO-Level Performance

In [None]:
# ensembl_pdx_result = (
#     results_f2.groupby(["model", "drug_id", "cell_id"])
#     .agg(
#         {
#             "y_true": "first",
#             "y_pred": "mean",
#             "mRECIST": "first",
#             "r_best": "first",
#             "r_avg": "first",
#         }
#     )
#     .reset_index()
# )

# func = lambda df: pd.Series(
#     {
#         "auROC": auroc(df),
#         "n_clinical_benefit_drugs": df["y_true"].sum(),
#         "n_total_drugs": len(df),
#     }
# )
# ensembl_cell_metrics = (
#     ensembl_pdx_result.groupby(["model", "cell_id"]).apply(func).dropna()
# )

# ensembl_cell_metrics.groupby("model")["auROC"].describe().loc[MODELS]

In [None]:
# interested_lines = (
#     ensembl_cell_metrics.loc["ScreenDL-SA"].query("n_total_drugs >= 5").index.to_list()
# )

# query = "model == 'ScreenDL-SA' and cell_id in @interested_lines"
# interested_result = ensembl_pdx_result.query(query).copy()
# grouped = interested_result.groupby("cell_id")
# interested_result["rank"] = grouped["y_pred"].transform("rank")
# interested_result.head()

In [None]:
# tv_metrics = (
#     interested_result.groupby("cell_id")
#     .apply(lambda g: eval_utils.pcorr(g, "y_pred", "r_avg", min_obs=5))
#     .to_frame(name="pcc")
#     .reset_index()
# )

# tv_metrics

In [None]:
# query = "model == 'ScreenDL-SA' and cell_id in @interested_lines"
# grouped = ensembl_pdx_result.query(query).groupby("cell_id")
# charts = []
# for i, (tumor_id, group) in enumerate(grouped):
#     n_drugs = group["drug_id"].nunique()
#     group["rank"] = group["y_pred"].rank()

#     x_encoding = alt.X("rank:O")
#     y_encoding = alt.Y("r_avg:Q")
#     c_encoding = alt.Color("mRECIST:O").scale(
#         domain=(["CR", "PR", "SD", "PD"]),
#         range=(["#2978B8", "#5BA3CF", "#9CC8E2", "#F87F2C"]),
#     )

#     if i > 0:
#         # only show y-axis on the first panel
#         y_encoding = y_encoding.axis(None)

#     chart = (
#         alt.Chart(group)
#         .mark_bar()
#         .encode(x_encoding, y_encoding, c_encoding)
#         .properties(width=15 * n_drugs, height=200, title=tumor_id)
#     )
#     charts.append(chart)


# alt.hconcat(*charts).resolve_scale(y="shared").resolve_axis(y="shared")

In [None]:
# group.sort_values("y_pred")

In [None]:
# NOTE: show PDX auROC for lines with at least 5 drugs screened

In [None]:
# ensembl_cell_metrics.loc["ScreenDL-FT"].query("n_total_drugs >= 5")

In [None]:
# ensembl_cell_metrics.loc["ScreenDL-SA"].query("n_total_drugs >= 5")

# Scratch

## Learning Curves

In [None]:
# path = (
#     root
#     / "outputs/experiments/pdmc_learning_curves/CellModelPassportsGDSCv1v2HCIv9AllDrugsHallmarkCombat/ScreenDL/runs/2024-11-10_19-29-17"
#     # / "outputs/experiments/pdmc_learning_curves/CellModelPassportsGDSCv1v2HCIv9AllDrugsHallmarkCombat/ScreenDL/runs/2024-11-10_14-40-42"
#     # / "outputs/experiments/pdmc_learning_curves/CellModelPassportsGDSCv1v2HCIv9AllDrugsHallmarkCombat/ScreenDL/runs/2024-11-10_16-05-31"
# )

# df = pd.read_csv(path / "predictions_bg.csv", low_memory=False)
# # df = pd.read_csv(path / "predictions.csv", low_memory=False)
# print(df.shape)
# df.head()

In [None]:
# FIXME: don't do preds vs background for the LC analysis

In [None]:
# pcc = (
#     df.query("model == 'tune'")
#     .groupby(["fold", "iter", "n_tumors", "drug_id"])
#     .apply(eval_utils.pcorr)
#     .to_frame("pcc")
#     .dropna()
#     .reset_index()
# )

In [None]:
# from scipy.optimize import curve_fit
# from scipy import stats

In [None]:
# mae_metrics = (
#     df.query("model == 'tune'")
#     # .groupby(["iter", "n_tumors"])
#     .groupby(["fold", "iter", "n_tumors"])
#     .apply(lambda g: skm.mean_absolute_error(g["y_true"], g["y_pred"]))
#     .to_frame("mae")
#     .reset_index()
# )

In [None]:
# mae_drug_metrics: pd.DataFrame = (
#     df.query("model == 'tune'")
#     .groupby(["fold", "iter", "n_tumors", "drug_id"])
#     .apply(lambda g: skm.mean_absolute_percentage_error(g["y_true"], g["y_pred"]))
#     .to_frame("mae")
#     .reset_index()
# )

In [None]:
# def power_law(x, a, b):
#     return a * np.power(x, b)


# mae_metrics_geq_xmin = mae_metrics.query("n_tumors > 32")
# x_data = mae_metrics_geq_xmin["n_tumors"].values
# y_data = mae_metrics_geq_xmin["mae"].values

# params, covariance = curve_fit(power_law, x_data, y_data)
# perr = np.sqrt(np.diag(covariance))
# a, b = params

# x_fit = np.linspace(4, 128, 100)
# y_fit = power_law(x_fit, a, b)

# alpha = 0.1  # 90% confidence level

# dof = max(0, len(x_data) - len(params))  # degrees of freedom
# t_val = stats.t.ppf(1 - alpha / 2, dof)  # t-value for the confidence interval

# # Define functions to calculate the upper and lower bounds
# y_upper = power_law(x_fit, a + perr[0] * t_val, b + perr[1] * t_val)
# y_lower = power_law(x_fit, a - perr[0] * t_val, b - perr[1] * t_val)

In [None]:
# NOTE: 
# 1. for discussion -> may need to develop other PDxOs for other cancer types -> action item - develop organoids
# 2. we still need ds ft for maximal performance

In [None]:
# scale_config = dict(type="log", base=2)
# scale_config = dict()

# circles = (
#     alt.Chart(mae_metrics)
#     .mark_circle()
#     .encode(
#         alt.X("n_tumors:Q").scale(**scale_config).title("No. PDxO Lines"),
#         alt.Y("mae:Q").scale(**scale_config).title("Mean Absolute Error"),
#         alt.Color("fold:N"),
#     )
# )

# power_law_source = pd.DataFrame(
#     {"n_tumors": x_fit, "mae": y_fit, "mae_upper": y_upper, "mae_lower": y_lower}
# )
# power_law_base = alt.Chart(power_law_source).encode(
#     alt.X("n_tumors:Q").scale(**scale_config)
# )

# power_law_line = power_law_base.mark_line().encode(alt.Y("mae:Q").scale(**scale_config))
# power_law_bound = power_law_base.mark_area(color="lightgray", opacity=0.5).encode(
#     alt.Y("mae_lower:Q").scale(**scale_config), alt.Y2("mae_upper:Q")
# )

# chart: alt.Chart = power_law_bound + power_law_line + circles
# chart.resolve_scale(color="independent").properties(width=700)

In [None]:
# func = lambda g: pd.Series(
#     {
#         "pcc": eval_utils.pcorr(g, min_obs=5),
#         # "n_tumors_test": len(g),
#         # "n_tumors": g["n_tumors"].iloc[0],
#     }
# )

# pcc_drug_metrics: pd.DataFrame = (
#     df.groupby(["iter", "n_tumors", "drug_id"]).apply(func).reset_index().dropna()
# )

# pcc_drug_metrics.head()

In [None]:
# pcc_drug_metrics["drug_id"].unique()

In [None]:
# source = pcc_drug_metrics.groupby(["drug_id", "n_tumors"])["pcc"].mean().reset_index()
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Olaparib"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Alisertib"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Afatinib"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "AZD5363"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "5-azacytidine"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Lapatinib"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Vinorelbine"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "SN-38"]
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Selumetinib"]
# chart = alt.Chart(source).mark_circle().encode(
#     alt.X("n_tumors:Q").scale(type="log", base=2),
#     alt.Y("pcc:Q"),
#     alt.Color("fold:N"),
# )

# chart.properties(width=700)

In [None]:
# source = pcc_drug_metrics.groupby(["drug_id", "n_tumors"])["pcc"].mean().reset_index()
# source = pcc_drug_metrics[pcc_drug_metrics["drug_id"] == "Alisertib"]
# chart = alt.Chart(source).mark_circle().encode(
#     alt.X("n_tumors:Q").scale(type="log", base=2),
#     alt.Y("pcc:Q"),
#     alt.Color("fold:N"),
# )

# chart.properties(width=700)

In [None]:
# circles = (
#     alt.Chart(mae_metrics)
#     .mark_circle()
#     .encode(alt.X("n_tumors:Q"), alt.Y("mae:Q"), alt.Color("fold:N"))
# )

# line = alt.Chart(pl_source).mark_line().encode(alt.X("n_tumors:Q"), alt.Y("mae:Q"))

# (circles + line).resolve_scale(color="independent").properties(width=700)

In [None]:
# tmpdir = Path("./temp")
# tmpdir.mkdir(exist_ok=True)

In [None]:
# url = str(tmpdir / "temp.json")
# pcc.to_json(url, orient="records")

In [None]:
# source = pcc.groupby(["n_tumors", "drug_id"])["pcc"].mean().to_frame("pcc").reset_index()
# alt.Chart(source).transform_filter(
#     # alt.datum.drug_id == "Ipatasertib"
#     alt.datum.drug_id == "Alisertib"
# ).mark_circle().encode(alt.X("n_tumors:Q"), alt.Y("pcc:Q"), alt.Color("drug_id:N"))

In [None]:
# NOTE: look at MAE for the residual response

In [None]:
# Look at the resulting ranking for a given tumor?
# -> what is the best z-score vs best predicted z-score

In [None]:
# dataset_dir = root / "inputs/CellModelPassportsGDSCv1v2HCIv9AllDrugsHallmarkCombat"

# drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
# cell_meta = pd.read_csv(dataset_dir / "MetaSampleAnnotations.csv", index_col=0)

# D = Dataset.from_csv(
#     dataset_dir / "LabelsLogIC50.csv",
#     cell_meta=cell_meta,
#     drug_meta=drug_meta,
#     name="CellModelPassportsGDSCHCI",
# )

# cell_ids = D.cell_meta[D.cell_meta["domain"] == "CELL"].index
# pdmc_ids = D.cell_meta[D.cell_meta["domain"] == "PDMC"].index

# cell_ds = D.select_cells(cell_ids, name="cell_ds")
# pdmc_ds = D.select_cells(pdmc_ids, name="pdmc_ds")

# print(cell_ds)
# print(pdmc_ds)

In [None]:
# NOTE: this is very dependent on the number fold and the

In [None]:
# NOTE: why would we still want to evaluate predictions against the background?
# -> we need to have relative response in order to select drugs
# -> so by computing vs the background dataset, we can better assess our
#    ability to predict differential response