# Empirical validation of baselines used in SHAP

In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import sklearn.metrics as skm
import typing as t

from functools import partial
from pathlib import Path
from scipy import stats
from tensorflow import keras
from tqdm import tqdm, trange

from cdrpy.util.io import read_gmt

from screendl.utils.ensemble import ScreenDLEnsembleWrapper
from screendl.pipelines.core.screendl import (
    apply_preprocessing_pipeline,
    load_dataset,
    load_pretraining_configs,
    split_dataset,
)

from utils.const import DRUG_TO_PATHWAY_EXT as _DRUG_TO_PATHWAY_EXT

In [None]:
def load_ensemble(
    pretrain_dir: str | Path,
    model_type: t.Literal["ScreenDL-PT", "ScreenDL-FT", "ScreenDL-SA"],
    tumor_id: str,
) -> ScreenDLEnsembleWrapper:
    """Load every saved ensemble member found under ``pretrain_dir``.

    Each immediate sub-directory of ``pretrain_dir`` is searched for one saved
    model: ``{model_type}.model`` or, for ``"ScreenDL-SA"``,
    ``{model_type}.{tumor_id}.model``.

    Args:
        pretrain_dir: Directory holding one sub-directory per ensemble member.
            May be a ``str`` or a ``Path``.
        model_type: Which saved model to load from each sub-directory.
        tumor_id: Tumor identifier; only used to build the file name when
            ``model_type`` is ``"ScreenDL-SA"``.

    Returns:
        A ``ScreenDLEnsembleWrapper`` around the loaded models, ordered by
        sub-directory name (numeric names are ordered numerically).

    Raises:
        FileNotFoundError: If no saved model matches the pattern.
    """
    # Accept plain strings as the annotation promises; ``str`` has no ``glob``.
    pretrain_dir = Path(pretrain_dir)

    if model_type == "ScreenDL-SA":
        pattern = f"*/{model_type}.{tumor_id}.model"
    else:
        pattern = f"*/{model_type}.model"

    def _member_order(path: Path) -> tuple[int, int, str]:
        # Numeric run directories sort as integers so that "10" follows "9";
        # anything else falls back to plain string order after the numbers.
        name = path.parent.name
        if name.isdecimal():
            return (0, int(name), "")
        return (1, 0, name)

    # ``Path.glob`` yields matches in arbitrary order. Callers pair member i
    # with run directory ``str(i)`` by position, so fix the order here.
    files = sorted(pretrain_dir.glob(pattern), key=_member_order)
    if not files:
        raise FileNotFoundError(
            f"No saved models matching '{pattern}' found under '{pretrain_dir}'"
        )

    members = [keras.models.load_model(file) for file in files]
    return ScreenDLEnsembleWrapper(members)

In [None]:
# Base data directory; the path is relative, so it resolves against the
# notebook's current working directory.
root = Path("../../../datastore")

In [None]:
# Drug annotations restricted to three columns. ``index_col=0`` selects the
# first of the retained columns as the index — presumably ``drug_id``; verify
# against the CSV's column order.
drug_meta_ext = pd.read_csv(
    root / "inputs" / "CellModelPassports-GDSCv1v2" / "MetaDrugAnnotations.csv",
    usecols=["drug_id", "targets", "target_pathway"],
    index_col=0,
)

# Pathway labels to rename; any label not listed here is kept unchanged.
fixed_pathways = {"EGFR signaling": "EGFR/HER2 signaling"}
drug_to_pathway = {
    drug_id: fixed_pathways.get(pathway, pathway)
    for drug_id, pathway in drug_meta_ext["target_pathway"].items()
}
# Entries from the extension mapping override the CSV-derived ones.
drug_to_pathway.update(_DRUG_TO_PATHWAY_EXT)

In [None]:
# load MSigDB gene sets
gmt_dir = root / "raw" / "MSigDB"

# Collection name -> GMT file name inside ``gmt_dir``.
_gmt_files = {
    "h.all": "h.all.v2023.1.Hs.symbols.gmt",
    "c6.all": "c6.all.v2023.2.Hs.symbols.gmt",
    "c2.cgp": "c2.cgp.v2024.1.Hs.symbols.gmt",
    "c2.cgp.breast": "c2.cgp.v2024.1.Hs.breast_cancer.symbols.gmt",
    "c5.go.bp": "c5.go.bp.v2023.2.Hs.symbols.gmt",
}
GENE_SETS = {
    collection: read_gmt(gmt_dir / file_name)
    for collection, file_name in _gmt_files.items()
}

# Only these collections are merged into the working ``gene_sets`` mapping;
# on a name clash the collection listed later wins.
meta_sets = ["h.all", "c6.all", "c2.cgp.breast"]
gene_sets = {}
for name in meta_sets:
    gene_sets.update(GENE_SETS[name])

print(f"Considering {len(gene_sets):,} gene sets from {', '.join(meta_sets)}")

In [None]:
# Identifiers of the run to inspect; they are combined into the output
# directory path and the saved-model file name further down.
dataset = "CellModelPassports-GDSCv1v2-HCI"
model_name = "ScreenDL"
date = "2025-06-24_09-26-32"
tumor_id = "BCM15163"

In [None]:
# Multirun output directory; it holds one sub-directory per ensemble member.
pt_dir = root / "outputs" / "core" / dataset / model_name / "multiruns" / date

In [None]:
# Ensemble of the "ScreenDL-SA" models saved for ``tumor_id``.
model = load_ensemble(
    pretrain_dir=pt_dir,
    model_type="ScreenDL-SA",
    tumor_id=tumor_id,
)

In [None]:
# load the datasets
# One dataset per ensemble member is rebuilt from that member's run directory
# (``pt_dir / str(i)``); only the third split returned by ``split_dataset`` is
# kept — presumably the held-out/evaluation split, confirm against screendl.
X_t_0 = None
datasets = []
for i in trange(len(model.members)):
    member_dir = pt_dir / str(i)
    pt_cfg, _ = load_pretraining_configs(member_dir)
    Dt, Dv, De = split_dataset(pt_cfg, load_dataset(pt_cfg))
    if i == 0:
        # Snapshot member 0's expression features before the preprocessing
        # pipeline runs. Rows are sorted so the order is reproducible across
        # interpreter runs (plain set iteration order is not) and matches the
        # sorted tumor ids used for prediction later in the notebook.
        X_t_0 = De.cell_encoders["exp"].data.loc[sorted(set(De.cell_ids))].copy()
    Dt, Dv, De = apply_preprocessing_pipeline(member_dir, Dt, Dv, De)
    datasets.append(De)

In [None]:
# ``X_t_0`` stays None when the loading loop never ran (ensemble has no members).
assert X_t_0 is not None
# Apply ``stats.zscore`` along axis 0 (the default for ``transform``), i.e. to
# each column separately — assumes ``X_t_0`` is a DataFrame; TODO confirm.
X_t_0 = X_t_0.transform(stats.zscore, axis=0)

In [None]:
def get_x(D, t_ids, d_ids):
    """Return ``[cell_features, drug_features]`` for paired tumor/drug ids.

    Rows are taken by label from the dataset's ``"exp"`` cell encoder and
    ``"mol"`` drug encoder, in the order given by ``t_ids`` and ``d_ids``.
    """
    cell_features = D.cell_encoders["exp"].data.loc[t_ids, :].values
    drug_features = D.drug_encoders["mol"].data.loc[d_ids, :].values
    return [cell_features, drug_features]

In [None]:
# Drugs come from the first dataset's "mol" encoder index; tumors are the
# sorted, de-duplicated cell ids of that same dataset.
d_ids = datasets[0].drug_encoders["mol"].data.index.tolist()
t_ids = sorted(set(datasets[0].cell_ids))

all_preds = []
n_tumors = len(t_ids)
for d_id in tqdm(d_ids):
    # One [cell_features, drug_features] pair per dataset, each pairing every
    # tumor with the current drug; ``map_inputs=True`` is passed so the
    # per-dataset inputs are handed to the ensemble as a list.
    X = [get_x(D, t_ids=t_ids, d_ids=[d_id] * n_tumors) for D in datasets]
    y_pred = model(X, map_inputs=True).numpy()
    frame = pd.DataFrame({"cell_id": t_ids, "drug_id": d_id, "y_pred": y_pred})
    all_preds.append(frame)

In [None]:
# ``assign`` evaluates its keyword arguments in order, so ``y_pred_gds`` is
# the per-tumor mean of the already z-scored (per drug) predictions.
all_preds_df = pd.concat(all_preds, ignore_index=True).assign(
    y_pred=lambda df: df.groupby("drug_id")["y_pred"].transform(stats.zscore),
    y_pred_gds=lambda df: df.groupby("cell_id")["y_pred"].transform("mean"),
)

In [None]:
# Collapse the per-(tumor, drug) rows to the distinct (cell_id, y_pred_gds) pairs.
tumor_to_pred_gds = all_preds_df.loc[:, ["cell_id", "y_pred_gds"]].drop_duplicates()
tumor_to_pred_gds.head()

In [None]:
# Width of the all-zero drug baseline built below.
# NOTE(review): this reads ``.shape`` from the encoder object itself, not from
# its ``.data`` frame — presumably the per-sample feature shape, so ``[0]`` is
# the feature count; confirm against the encoder class.
d_dim = datasets[0].drug_encoders["mol"].shape[0]

In [None]:
# null drug -> model will predict tumor's GDS
d_baseline = np.zeros((1, d_dim))
# Stack the single baseline row once per tumor.
x_d = np.tile(d_baseline, (len(t_ids), 1))
# Expression features are taken from the first dataset only, and the same
# [x_t, x_d] input is passed with ``map_inputs=False``.
x_t = datasets[0].cell_encoders["exp"].data.loc[t_ids, :].values
pred_gds_from_baseline = model([x_t, x_d], map_inputs=False)

In [None]:
# Per-tumor table with three columns, indexed by tumor id:
#   y_pred_gds_0 - model output for the all-zero drug baseline
#   y_pred_gds   - mean z-scored prediction across drugs
#   y_true_gds   - mean observed label per tumor in the first dataset
baseline_gds = pd.Series(pred_gds_from_baseline, index=t_ids, name="y_pred_gds_0")
mean_pred_gds = tumor_to_pred_gds.set_index("cell_id")["y_pred_gds"]
true_gds = datasets[0].obs.groupby("cell_id")["label"].mean().to_frame("y_true_gds")

temp = baseline_gds.to_frame().join(mean_pred_gds).join(true_gds)

stats.pearsonr(temp["y_pred_gds_0"], temp["y_pred_gds"])

In [None]:
# using mean prediction across drugs
# Prints the Pearson correlation first, then the mean squared error.
for _metric in (stats.pearsonr, skm.mean_squared_error):
    print(_metric(temp["y_true_gds"], temp["y_pred_gds"]))

In [None]:
# using baselines
# Prints the Pearson correlation first, then the mean squared error.
for _metric in (stats.pearsonr, skm.mean_squared_error):
    print(_metric(temp["y_true_gds"], temp["y_pred_gds_0"]))