# Local Explanations for PDxO Drug Selection with SHAP

## Contents

- [Configuration](#configuration)
- [Data Loading](#data-loading)
    - [Load metadata](#load-metadata)
    - [Load pretrained model](#load-pretrained-model)
- [Generate SHAP explanations for the optimal precision therapy](#generate-shap-explanations-for-the-optimal-precision-therapy)
- [Visualize shared subnetwork embeddings](#visualize-shared-subnetwork-embeddings)

In [None]:
from __future__ import annotations

import altair as alt
import pandas as pd
import numpy as np
import typing as t

from functools import partial
from pathlib import Path
from scipy import stats
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import euclidean_distances
from statsmodels.stats.multitest import multipletests
from tensorflow import keras
from tqdm import tqdm, trange

from altair_forge.utils import get_domain

from cdrpy.util.io import read_gmt

from screendl.utils.ensemble import ScreenDLEnsembleWrapper
from screendl.pipelines.core.screendl import (
    apply_preprocessing_pipeline,
    load_dataset,
    load_pretraining_configs,
    split_dataset,
)

from utils.const import DRUG_TO_PATHWAY_EXT
from utils.plot import configure_chart
from utils.stats import combine_pvalues_cauchy
from utils.shap import (
    shap_adapter,
    run_shap_gsea,
    get_ensemble_embeddings_for_drug,
    _get_x_gexp,
    _agg_shap_values,
)

## Configuration

In [None]:
# when True, only drugs with a known (non-missing) response for the focused tumor
# are eligible when selecting the top-ranked drugs; when False, all drugs are ranked
CONSIDER_SCREENED_DRUGS_ONLY = True

# SHAP params
SHAP_ZSCORE_THRESHOLD = -1.5  # z-score threshold to consider a gene for GSEA
SHAP_NUM_BG_TUMORS = 5  # number of tumors in the SHAP background distribution
SHAP_NUM_BG_SAMPLES = 10000  # number of samples from background distribution

# number of reference tumors (nearest neighbors of the focused tumor) to consider
NUM_REFERENCE_TUMORS = 5

## Data Loading

In [None]:
# root directory of the datastore (relative path, resolved against the working directory)
root = Path("../../../datastore")

### Load metadata

In [None]:
# load drug metadata
# NOTE(review): with `index_col=0` the index is presumably `drug_id` — confirm column order
drug_meta = pd.read_csv(
    root / "inputs/CellModelPassports-GDSCv1v2/MetaDrugAnnotations.csv",
    index_col=0,
    usecols=["drug_id", "targets", "target_pathway"],
)

# map each drug to its target pathway, renaming pathways listed in `fixed_pathways`
drug_to_pathway = drug_meta["target_pathway"].to_dict()
fixed_pathways = {"EGFR signaling": "EGFR/HER2 signaling"}
drug_to_pathway = {k: fixed_pathways.get(v, v) for k, v in drug_to_pathway.items()}
# entries from DRUG_TO_PATHWAY_EXT are added last and overwrite any existing keys
drug_to_pathway.update(DRUG_TO_PATHWAY_EXT)

In [None]:
# load receptor status annotations (the ER/PR/HER2 columns are used later to flag TNBC)
hr_status = pd.read_csv(root / "raw/Welm/receptor_status_combined.pdx.final.csv")
hr_status.head()

In [None]:
# load MSigDB gene sets
gmt_files = [
    root / "raw/MSigDB/h.all.v2023.1.Hs.symbols.gmt",
    root / "raw/MSigDB/c2.cgp.v2024.1.Hs.breast_cancer.symbols.gmt",
]

# merge the gene sets from all files; on a name collision the later file wins
GENE_SETS = {}
for gmt_file in gmt_files:
    GENE_SETS.update(read_gmt(gmt_file))

print(f"Considering {len(GENE_SETS):,} gene sets for overrepresentation analysis.")

### Load pretrained model

In [None]:
def load_ensemble(
    pt_dir: Path,
    tumor_id: str,
    model_type: t.Literal["ScreenDL-PT", "ScreenDL-FT", "ScreenDL-SA"],
) -> ScreenDLEnsembleWrapper:
    """Load the ScreenDL ensemble model.

    Member model files are searched one directory level below ``pt_dir`` and
    loaded in numeric order of their parent directory name, so that member
    ``i`` lines up with the dataset that ``load_datasets`` reads from
    ``pt_dir / str(i)``.

    Parameters
    ----------
    pt_dir
        Directory holding one sub-directory per ensemble member.
    tumor_id
        Tumor identifier; only used in the file name for "ScreenDL-SA" models.
    model_type
        Which saved model to load from each member directory.

    Returns
    -------
    ScreenDLEnsembleWrapper
        Wrapper around the loaded member models.

    Raises
    ------
    FileNotFoundError
        If no model file matches the pattern under ``pt_dir``.
    """
    pattern = (
        f"*/{model_type}.model"
        if model_type != "ScreenDL-SA"
        else f"*/{model_type}.{tumor_id}.model"
    )

    def _member_order(file: Path) -> tuple[bool, int, str]:
        # numeric directories sort by value (2 before 10); any other directory
        # name sorts after them, alphabetically
        name = file.parent.name
        is_numeric = name.isdecimal()
        return (not is_numeric, int(name) if is_numeric else 0, name)

    # glob yields matches in arbitrary, filesystem-dependent order; members are
    # later paired with datasets by position, so the order must be deterministic
    files = sorted(pt_dir.glob(pattern), key=_member_order)
    if not files:
        raise FileNotFoundError(f"No model files matching '{pattern}' under {pt_dir}")

    members = [keras.models.load_model(file) for file in files]
    return ScreenDLEnsembleWrapper(members)


def load_datasets(pt_dir: Path, model: ScreenDLEnsembleWrapper) -> list:
    """Build one preprocessed dataset per ensemble member.

    For member index ``i`` the pretraining config is read from
    ``pt_dir / str(i)``, the dataset is loaded and split three ways, the
    preprocessing pipeline is applied using that same directory, and the
    third split is kept (presumably the evaluation split — confirm).
    """
    n_members = len(model.members)
    kept_splits = []
    for member_idx in trange(n_members):
        member_dir = pt_dir / str(member_idx)
        pt_cfg, _ = load_pretraining_configs(member_dir)
        split_t, split_v, split_e = split_dataset(pt_cfg, load_dataset(pt_cfg))
        processed = apply_preprocessing_pipeline(member_dir, split_t, split_v, split_e)
        split_t, split_v, split_e = processed
        kept_splits.append(split_e)
    return kept_splits

In [None]:
# select the dataset, model, and the (run timestamp, focused tumor) pair to explain
dataset = "CellModelPassports-GDSCv1v2-HCI"
model_name = "ScreenDL"
date, tumor_id = ("2025-06-24_09-26-32", "BCM15163")
# date, tumor_id = ("2025-06-24_09-26-35", "HCI051")

# configure pt directory path
pt_dir = root / f"outputs/core/{dataset}/{model_name}/multiruns/{date}"

# load the ensemble model ("ScreenDL-SA" model file names include `tumor_id`)
model = load_ensemble(pt_dir, tumor_id, "ScreenDL-SA")

# load and preprocess datasets (one per ensemble member)
datasets = load_datasets(pt_dir, model)

In [None]:
def get_x(D, t_ids, d_ids):
    """Encode (tumor, drug) pairs as model inputs for dataset ``D``.

    Returns a two-element list: the rows of ``D.cell_encoders["exp"].data``
    selected by ``t_ids`` and the rows of ``D.drug_encoders["mol"].data``
    selected by ``d_ids``, each as an array (``.values``) in the order given.
    """
    return [
        D.cell_encoders["exp"].data.loc[t_ids, :].values,
        D.drug_encoders["mol"].data.loc[d_ids, :].values,
    ]

In [None]:
# score every (tumor, drug) pair with the ensemble, one drug at a time
d_ids = datasets[0].drug_encoders["mol"].data.index.to_list()
t_ids = sorted(set(datasets[0].cell_ids))

all_preds = []
for d_id in tqdm(d_ids):
    # pair every tumor with the current drug; one encoded input per dataset
    encode_fn = partial(get_x, t_ids=t_ids, d_ids=[d_id] * len(t_ids))
    X = list(map(encode_fn, datasets))
    y_pred = model(X, map_inputs=True, training=False, trim_frac=0.2).numpy()
    all_preds.append(pd.DataFrame({"cell_id": t_ids, "drug_id": d_id, "y_pred": y_pred}))

all_preds_df = pd.concat(all_preds, ignore_index=True)
# z-score the predictions within each drug (across tumors)
all_preds_df["y_pred"] = all_preds_df.groupby("drug_id")["y_pred"].transform(stats.zscore)
all_preds_df.head()

In [None]:
# observed responses, with the label column renamed to `y_true`
y_true_df = datasets[0].obs.rename(columns={"label": "y_true"}).drop(columns="id")

# predictions for the focused tumor, sorted ascending so the lowest predicted
# z-score comes first, annotated with pathway and observed response (if any)
tumor_preds_df = (
    all_preds_df.query("cell_id == @tumor_id")
    .assign(pathway=lambda df: df["drug_id"].map(drug_to_pathway))
    .sort_values("y_pred")
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
)

# top 10 candidates, optionally restricted to drugs with an observed response
tumor_preds_df_topk = (
    tumor_preds_df.query("y_true.notna()").iloc[:10]
    if CONSIDER_SCREENED_DRUGS_ONLY
    else tumor_preds_df.iloc[:10]
)

# the best drug is the top-ranked candidate (lowest predicted z-score)
best_drug = tumor_preds_df_topk.iloc[0]["drug_id"]
print(f"Best drug: {best_drug} ({drug_to_pathway[best_drug]})")
tumor_preds_df_topk.head()

## Generate SHAP explanations for the optimal precision therapy

In [None]:
# get predicted GDS for each tumor
# expression inputs for all tumors, one array per ensemble member's dataset
x_tumors = [_get_x_gexp(D).loc[t_ids].values for D in datasets]

# use null drug as baseline to get predicted GDS
# (an all-zero drug feature row, repeated once per tumor)
x_drug_0 = np.zeros((1, datasets[0].drug_encoders["mol"].shape[-1]))
x_drug_0 = [np.tile(x_drug_0, (len(x_t), 1)) for x_t in x_tumors]

# same ensemble call settings as used for the per-drug predictions
predict_func = partial(model, map_inputs=True, training=False, trim_frac=0.2)
y_pred_gds = predict_func(list(zip(x_tumors, x_drug_0)))
y_pred_gds = pd.Series(y_pred_gds.numpy(), index=t_ids)

In [None]:
# compute GDS and RDS
# RDS = per-drug prediction minus the tumor's null-drug (GDS) prediction
# NOTE(review): `y_pred` was z-scored per drug earlier while `y_pred_gds` is the
# raw model output — confirm the two are meant to be subtracted on these scales
all_preds_df["y_pred_gds"] = all_preds_df["cell_id"].map(y_pred_gds)
all_preds_df["y_pred_rds"] = all_preds_df["y_pred"] - all_preds_df["y_pred_gds"]
all_preds_df["y_pred_rds_abs"] = all_preds_df["y_pred_rds"].abs()

# identify tumors with near-zero RDS for the best drug to use as reference for SHAP
# (all tumors except the focused one, ordered by increasing |RDS|)
rds_sorted_tumors = (
    all_preds_df.query("drug_id == @best_drug")
    .query("cell_id != @tumor_id")
    .sort_values("y_pred_rds_abs")["cell_id"]
    .to_list()
)

In [None]:
# SHAP for the focused tumor and the best drug; `x` is a (member, dataset) pair
shap_func = lambda x: shap_adapter(
    x[0],
    x[1],
    drug_id=best_drug,
    tumor_id=tumor_id,
    n_bg_samples=SHAP_NUM_BG_TUMORS,
    n_shap_samples=SHAP_NUM_BG_SAMPLES,
    sorted_tumors=rds_sorted_tumors,
)

# we run SHAP for each ensemble member
shap_values = list(map(shap_func, zip(model.members, datasets)))

# extract tumor results (drug is held constant)
# `idx` records which ensemble member each set of values came from
shap_values_t = [x[0].assign(idx=i) for i, x in enumerate(shap_values)]
shap_values_t = pd.concat(shap_values_t, ignore_index=True)

In [None]:
# aggregate gene-level SHAP values across ensemble members
shap_values_t_agg = _agg_shap_values(shap_values_t)
shap_values_t_agg.head()

In [None]:
# determine the number of genes to consider for GSEA
# (count of genes whose aggregated SHAP value has a z-score below the threshold)
shap_zscores = stats.zscore(shap_values_t_agg["value"])
TOP_N_GENES = shap_values_t_agg.loc[shap_zscores < SHAP_ZSCORE_THRESHOLD].shape[0]
print(f"Considering {TOP_N_GENES:,} genes with z < {SHAP_ZSCORE_THRESHOLD:.2f} (z-score)")

# run gene set ORA on the top N sensitivity genes from SHAP
gsea_results = run_shap_gsea(
    shap_values_t,
    tumor_id=tumor_id,
    gene_sets=GENE_SETS,
    datasets=datasets,
    top_n=TOP_N_GENES,
    by_abs=False,
    reversed=False,
)

# multiple testing correction
# NOTE(review): the significance filter applied afterwards uses the raw `pval`,
# not `fdr` — confirm that is intended
gsea_results["fdr"] = multipletests(gsea_results["pval"], method="fdr_bh")[1]
gsea_results["log_pval"] = -np.log10(gsea_results["pval"])
gsea_results["log_fdr"] = -np.log10(gsea_results["fdr"])

gsea_results.sort_values("pval").head()

In [None]:
# filter for significant results
# (nominal p-value < 0.05, ordered from most to least significant)
gsea_source = gsea_results[(gsea_results["pval"] < 0.05)].sort_values("pval").copy()
SIG_GENE_SET_IDS = gsea_source["set"].to_list()
SIG_GENE_SETS = {k: v for k, v in GENE_SETS.items() if k in SIG_GENE_SET_IDS}

# assign y-axis labels for plotting
# label format: "<gene set>  [N=<overlap>]"
gsea_source["y"] = gsea_source.apply(lambda r: f"{r['set']}  [N={r['overlap']}]", axis=1)
gene_set_to_y_label = dict(zip(gsea_source["set"], gsea_source["y"]))

In [None]:
# embeddings from the "shared_mlp_2" layer for the best drug, indexed by tumor ID
x_embed = get_ensemble_embeddings_for_drug(
    model, datasets, drug_id=best_drug, layer_name="shared_mlp_2"
)

# standardize the embeddings
x_embed = x_embed.transform(stats.zscore)

# tack on the predicted values for the best drug
# temp = all_preds_df.query("drug_id == @best_drug").set_index("cell_id")[["y_pred"]]
# x_embed_with_pred = x_embed.copy()
# x_embed_with_pred = x_embed_with_pred.join(temp, how="left")

# split into the focused tumor (kept as a one-row frame) and all other tumors
x_embed_this = x_embed.loc[[tumor_id]]
x_embed_rest = x_embed.drop(tumor_id)

# identify reference tumors (nearest neighbors in embedding space)
# sims = cosine_distances(x_embed_this, x_embed_rest).flatten()
# note: despite the name, `sims` holds Euclidean distances (smaller = closer),
# so the ascending argsort below selects the nearest tumors
sims = euclidean_distances(x_embed_this, x_embed_rest).flatten()
ref_tumor_inds = np.argsort(sims)[:NUM_REFERENCE_TUMORS]
ref_tumors = x_embed_rest.index[ref_tumor_inds].to_list()
print(ref_tumors)

In [None]:
# compute SHAP values for each reference tumor, one run per ensemble member
# NOTE(review): unlike the focused-tumor call, `sorted_tumors` is not passed here,
# so `shap_adapter` presumably falls back to its default — confirm this is intended
ref_shap_values_t = []
for t_id in ref_tumors:
    # call shap_adapter directly instead of through a lambda that closes over the
    # loop variable and rebinds the `shap_func` used for the focused tumor
    _shap_values = [
        shap_adapter(
            member,
            D,
            drug_id=best_drug,
            tumor_id=t_id,
            n_bg_samples=SHAP_NUM_BG_TUMORS,
            n_shap_samples=SHAP_NUM_BG_SAMPLES,
        )
        for member, D in zip(model.members, datasets)
    ]
    # keep the tumor results and tag them with the ensemble member index
    _shap_values_t = [x[0].assign(idx=i) for i, x in enumerate(_shap_values)]
    _shap_values_t = pd.concat(_shap_values_t, ignore_index=True).assign(cell_id=t_id)
    ref_shap_values_t.append(_shap_values_t)

In [None]:
# re-derive the significant gene sets from `gsea_source` (same statements as in
# the filtering cell above; convenient when re-running from this point)
SIG_GENE_SET_IDS = gsea_source["set"].to_list()
SIG_GENE_SETS = {k: v for k, v in GENE_SETS.items() if k in SIG_GENE_SET_IDS}

In [None]:
# same GSEA settings as for the focused tumor, restricted to its significant gene sets
gsea_params = {
    "datasets": datasets,
    "gene_sets": SIG_GENE_SETS,
    "top_n": TOP_N_GENES,
    "by_abs": False,
    "reversed": False,
}

# one GSEA result table per reference tumor, stacked with a `cell_id` column
ref_gsea_source = [
    run_shap_gsea(x, tumor_id=t, **gsea_params).assign(cell_id=t)
    for x, t in zip(ref_shap_values_t, ref_tumors)
]
ref_gsea_source = pd.concat(ref_gsea_source, ignore_index=True)

# aggregate gene set p-values across reference tumors
# (`transform` broadcasts the combined value to every row of the gene set)
grouped = ref_gsea_source.groupby("set")
ref_gsea_source["pval_agg"] = grouped["pval"].transform(combine_pvalues_cauchy)
ref_gsea_source["log_pval"] = -np.log10(ref_gsea_source["pval_agg"])
ref_gsea_source["y"] = ref_gsea_source["set"].map(gene_set_to_y_label)

# NOTE(review): this preview sorts by the per-tumor `pval` of the first row kept
# for each set rather than by `pval_agg` — confirm which ordering is intended
ref_gsea_source.drop_duplicates(["set"]).sort_values("pval").head()

In [None]:
# compute average expression of gene set genes
# one frame per (gene set, dataset): the per-tumor mean over the gene set's genes
gexp_source = []
for gs in SIG_GENE_SET_IDS:
    for ds in datasets:
        # `filter(items=...)` keeps only those gene set members that are columns
        gs_exp = (
            _get_x_gexp(ds)
            .filter(items=GENE_SETS[gs])
            .mean(axis=1)
            .to_frame(name="z")
            .assign(set=gs)
            .rename_axis(index="cell_id")
            .reset_index()
        )
        gexp_source.append(gs_exp)

# average the per-dataset values for each (tumor, gene set) pair
gexp_source = (
    pd.concat(gexp_source, ignore_index=True)
    .groupby(["cell_id", "set"])
    .mean()
    .reset_index()
)

# grab the y-axis labels from GSEA results
gexp_source["y"] = gexp_source["set"].map(gene_set_to_y_label)
gexp_source["is_ref"] = gexp_source["cell_id"].isin(ref_tumors)
# reference tumors sort last (False < True)
gexp_source = gexp_source.sort_values("is_ref")

gexp_source.head()

In [None]:
# panel size (pixels) for the SHAP / gene set charts
HEIGHT = 350
WIDTH = 150

# shared legend placement: a horizontal legend as wide as the panel, positioned
# 60px above it (`legendX` evaluates to 0 for the values used here)
legend_config = {
    "orient": "none",
    "legendX": WIDTH / 2 - (WIDTH / 1 / 2),
    "legendY": -60,
    "gradientLength": WIDTH / 1,
    "direction": "horizontal",
}

In [None]:
# the charts below load their data from these JSON files by URL rather than
# embedding it, so make sure the output directory exists before writing
Path("./temp").mkdir(parents=True, exist_ok=True)

gsea_source_url = "./temp/gsea_source.json"
gsea_source.to_json(gsea_source_url, orient="records")

# the focused tumor is excluded here; it is plotted separately from the DataFrame
gexp_source_url = "./temp/gexp_source.json"
gexp_source.query("cell_id != @tumor_id").to_json(gexp_source_url, orient="records")

ref_gsea_source_url = "./temp/ref_gsea_source.json"
ref_gsea_source["is_focused"] = ref_gsea_source["cell_id"] == tumor_id
ref_gsea_source["is_ref"] = ref_gsea_source["cell_id"].isin(ref_tumors)
ref_gsea_source.to_json(ref_gsea_source_url, orient="records")

In [None]:
# shared x-axis domain and ticks for the SHAP attribution panels
# (upper bound is at least 0.05; ticks every 0.05)
vals = list(gsea_source["total_mean"]) + list(ref_gsea_source["total_mean"])
attr_domain = get_domain(vals, step=0.05)
attr_domain = (attr_domain[0], max(0.05, attr_domain[1]))
attr_ticks = np.arange(attr_domain[0], attr_domain[1] + 0.01, step=0.05)

# expression domain made symmetric around zero using the larger absolute bound
vals = list(gexp_source["z"]) + list(ref_gsea_source["z"])
gexp_domain = get_domain(vals, step=0.5)
gexp_domain_abs_max = max([abs(gexp_domain[0]), abs(gexp_domain[1])])
# gexp_domain = (-gexp_domain_abs_max - 0.5, gexp_domain_abs_max + 0.5)
gexp_domain = (-gexp_domain_abs_max, gexp_domain_abs_max)

# gene sets ordered by increasing p-value for the focused tumor
y_order = gsea_source.sort_values("pval")["y"].to_list()

In [None]:
# light gray vertical rule at x = 0, layered into the panels below
vline = (
    alt.Chart(pd.DataFrame({"x": [0.0]}))
    .mark_rule(color="lightgray")
    .encode(alt.X("x:Q"))
)

In [None]:
# bar chart for the focused tumor: `total_mean` per gene set (rows ordered by
# `y_order`), with bars colored by -log10(p-value)
attr_plot = (
    alt.Chart(gsea_source_url, height=HEIGHT, width=WIDTH)
    .mark_bar()
    .encode(
        alt.X("total_mean:Q")
        .axis(grid=False, values=attr_ticks)
        .scale(
            domainMid=0,
            domainMin=attr_domain[0],
            domainMax=attr_domain[1],
            nice=False,
        )
        .title("Combined SHAP Value"),
        alt.Y("y:N").sort(y_order).axis(labelLimit=500).title(None),
        alt.Color("log_pval:Q")
        .scale(scheme="greens", domainMin=0)
        .legend(**legend_config)
        .title(["Significance", "-log10(P-Value)"]),
    )
)

In [None]:
# focused tumor only: one large gray diamond per gene set at its mean expression
gexp_plot = (
    alt.Chart(gexp_source.query("cell_id == @tumor_id"), height=HEIGHT, width=WIDTH)
    .mark_point(opacity=1, size=200, color="#636363", filled=True, shape="diamond")
    .encode(
        alt.X("z:Q")
        .axis(grid=False)
        .scale(domainMid=0, domainMin=gexp_domain[0], domainMax=gexp_domain[1])
        .title("Gene Set Expression"),
        alt.Y("y:N").sort(y_order).axis(labels=False).title(None),
        tooltip=[
            alt.Tooltip("cell_id:N").title("Tumor ID"),
            alt.Tooltip("set:N").title("Gene Set"),
            alt.Tooltip("z:Q").format(".2f").title("Expression"),
        ],
    )
)

# all other tumors: jittered points per gene set; reference tumors are drawn as
# larger gray triangles, the rest as circles colored by expression
ref_gexp_plot = (
    alt.Chart(gexp_source_url, height=HEIGHT, width=WIDTH)
    # Box-Muller transform: turns two uniform draws into a normal sample for jitter
    .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
    .mark_point(filled=True)
    .encode(
        alt.X("z:Q")
        .axis(grid=False)
        .scale(domainMid=0, domainMin=gexp_domain[0], domainMax=gexp_domain[1])
        .title("Gene Set Expression"),
        alt.Y("y:N").sort(y_order).axis(labels=False).title(None),
        alt.YOffset("jitter:Q"),
        # fixed gray for reference tumors, diverging color scale for everything else
        alt.condition(
            alt.datum.is_ref,
            alt.ColorValue("#969696"),
            alt.Color("z:Q")
            .scale(
                scheme="redblue",
                reverse=True,
                domainMid=0,
                domainMin=gexp_domain[0],
                domainMax=gexp_domain[1],
            )
            .legend(**legend_config, tickCount=3)
            .title(["Gene Set Expression", "Mean Z-Score log2(TPM+1)"]),
        ),
        alt.Shape("is_ref:N")
        .scale(domain=[True, False], range=["triangle", "circle"])
        .legend(None)
        .title(None),
        alt.Size("is_ref:N")
        .scale(domain=[True, False], range=[60, 30])
        .legend(None)
        .title(None),
        alt.Opacity("is_ref:N")
        .scale(domain=[True, False], range=[0.8, 0.6])
        .legend(None)
        .title(None),
        tooltip=[
            alt.Tooltip("cell_id:N").title("Tumor ID"),
            alt.Tooltip("set:N").title("Gene Set"),
            alt.Tooltip("z:Q").format(".2f").title("Expression"),
        ],
    )
)

In [None]:
# base chart over the reference-tumor GSEA results (loaded from the JSON file)
ref_base = alt.Chart(ref_gsea_source_url, height=HEIGHT, width=WIDTH)

# bars show the mean `total_mean` across reference tumors for each gene set,
# colored by the mean of `log_pval`
ref_attr_bars = ref_base.mark_bar().encode(
    alt.X("mean(total_mean):Q")
    .axis(grid=False, values=attr_ticks)
    .scale(
        domainMid=0,
        domainMin=attr_domain[0],
        domainMax=attr_domain[1],
        nice=False,
    )
    .title("Combined SHAP Value"),
    alt.Y("y:N").sort(y_order).axis(labelLimit=500, labels=False).title(None),
    alt.Color("mean(log_pval):Q")
    .scale(scheme="warmgreys", domainMin=0)
    .legend(**legend_config, tickCount=4)
    .title(["Significance", "-log10(P-Value)"]),
)

# ref_attr_pts = (
#     ref_base.mark_circle(stroke="black", color="white", size=40, opacity=1, strokeWidth=1)
#     .transform_calculate(jitter="sqrt(-2*log(random()))*cos(2*PI*random())")
#     .encode(
#         alt.X("total_mean:Q")
#         .axis(grid=False, values=attr_ticks)
#         .scale(
#             domainMid=0,
#             domainMin=attr_domain[0],
#             domainMax=attr_domain[1],
#             nice=False,
#         )
#         .title("Combined SHAP Value"),
#         alt.Y("y:N").sort(y_order).axis(labelLimit=500, labels=False).title(None),
#         alt.YOffset("jitter:Q"),
#     )
# )

# error bars (standard deviation across reference tumors) on top of the mean bars
ref_attr_errs = ref_base.mark_errorbar(
    extent="stdev",
    ticks=alt.MarkConfig(color="black", height=5),
    rule=alt.MarkConfig(color="black", strokeWidth=1),
).encode(
    alt.X("mean(total_mean):Q")
    .axis(grid=False)
    .scale(domainMid=0, domainMin=attr_domain[0], domainMax=attr_domain[1])
    .title("Combined SHAP Value"),
    alt.Y("y:N").sort(y_order).axis(labelLimit=500, labels=False).title(None),
)

# ref_attr_plot = ref_attr_bars + ref_attr_errs + vline + ref_attr_pts
# layer order: zero line first, then bars, then error bars
ref_attr_plot = vline + ref_attr_bars + ref_attr_errs

In [None]:
# three side-by-side panels: focused-tumor attributions, reference-tumor
# attributions, and gene set expression (focused tumor layered over the rest)
shap_chart = alt.hconcat(
    vline + attr_plot,
    ref_attr_plot,
    vline + ref_gexp_plot + gexp_plot,
    spacing=10,
)
# each panel keeps its own color scale and legend
shap_chart = shap_chart.resolve_scale(color="independent")
configure_chart(shap_chart)

## Visualize shared subnetwork embeddings

In [None]:
def assign_shape(r: pd.Series) -> str:
    """Return the marker shape name for a tumor row.

    The focused tumor takes precedence ("diamond"), followed by reference
    tumors ("triangle"); every other tumor is drawn as a "circle".
    """
    if r["is_focused"]:
        return "diamond"
    return "triangle" if r["is_ref"] else "circle"

In [None]:
# tSNE embeddings of the shared subnetwork for each tumor paired with the best drug
# (fixed `random_state` for a reproducible layout)
x_embed_2d = TSNE(2, random_state=1771).fit_transform(x_embed)
x_embed_2d = pd.DataFrame(x_embed_2d, columns=["x", "y"])

# add annotations
# (rows are in the same order as `x_embed`, so its index supplies the tumor IDs)
x_embed_2d["cell_id"] = x_embed.index
x_embed_2d["drug_id"] = best_drug
x_embed_2d["is_focused"] = x_embed_2d["cell_id"] == tumor_id
x_embed_2d["is_ref"] = x_embed_2d["cell_id"].isin(ref_tumors)
x_embed_2d["is_ref_or_focused"] = x_embed_2d["is_ref"] | x_embed_2d["is_focused"]
x_embed_2d["shape"] = x_embed_2d.apply(assign_shape, axis=1)

# merge in predictions and true labels
# NOTE(review): the `hr_status` merge uses the default inner join, so tumors
# without a receptor status row are dropped from the plot — confirm intended
x_embed_2d = (
    x_embed_2d.merge(all_preds_df, on=["cell_id", "drug_id"], how="left")
    .merge(y_true_df, on=["cell_id", "drug_id"], how="left")
    .merge(hr_status, on="cell_id")
    .assign(
        is_TNBC=lambda df: (df["ER"] == "neg")
        & (df["PR"] == "neg")
        & (df["HER2"] == "neg")
    )
    # focused / reference rows sort last (presumably so they are drawn on top)
    .sort_values(["is_focused", "is_ref"], ascending=[True, True])
)

# add one column per significant gene set holding each tumor's mean expression
for item in SIG_GENE_SETS:
    x_embed_2d[item] = x_embed_2d["cell_id"].map(
        gexp_source.query("set == @item").set_index("cell_id")["z"]
    )

x_embed_2d.head()

In [None]:
# correlation of each significant gene set's mean expression with the predicted
# response to the best drug, across tumors
x_embed_2d[SIG_GENE_SET_IDS].corrwith(x_embed_2d["y_pred"]).sort_values()

In [None]:
# tumor x drug matrix of observed labels; report how the best drug's observed
# response correlates with the predicted GDS across tumors
M = datasets[0].obs.pivot_table(index="cell_id", columns="drug_id", values="label")
M.corrwith(y_pred_gds).loc[best_drug]

In [None]:
# replace "." with "_" in gene set / column names to avoid issues with Altair
sorted_gene_sets = [s.replace(".", "_") for s in SIG_GENE_SET_IDS]
# regex=False treats "." as a literal dot regardless of the pandas version's
# default (as a regular expression it would match any character)
x_embed_2d.columns = x_embed_2d.columns.str.replace(".", "_", regex=False)

In [None]:
# show only the first two gene sets (the list is ordered by increasing p-value)
display_sets = sorted_gene_sets[:2]

In [None]:
# panel size and legend placement for the embedding charts; these rebind the
# HEIGHT / WIDTH / legend_config values used for the SHAP charts above
HEIGHT = 120
WIDTH = 180

legend_config = {
    "orient": "none",
    "legendX": WIDTH / 2 - (WIDTH / 1 / 2),
    "legendY": -60,
    "gradientLength": WIDTH / 1,
    "direction": "horizontal",
    "titleLimit": 1000,
}

In [None]:
# axis domains for the tSNE scatter, padded by one unit on each side
x_domain = get_domain(x_embed_2d["x"], step=1)
x_domain = [x_domain[0] - 1, x_domain[1] + 1]
y_domain = get_domain(x_embed_2d["y"], step=1)
y_domain = [y_domain[0] - 1, y_domain[1] + 1]

In [None]:
# shared scatter of the tSNE coordinates; the charts below add a color encoding
base = (
    alt.Chart(x_embed_2d, width=WIDTH, height=HEIGHT)
    .mark_point(invalid=None, filled=True, stroke="black", opacity=0.8)
    .encode(
        alt.X("x:Q")
        .scale(domain=x_domain)
        .axis(ticks=False, labels=False, grid=False)
        .title("TSNE1"),
        alt.Y("y:Q")
        .scale(domain=y_domain)
        .axis(ticks=False, labels=False, grid=False)
        .title("TSNE2"),
        # the `shape` column already holds the mark shape names (identity mapping)
        alt.Shape("shape:N")
        .scale(
            domain=["diamond", "triangle", "circle"],
            range=["diamond", "triangle", "circle"],
        )
        .legend(None),
        # the focused tumor is drawn larger than all other tumors
        alt.condition(
            alt.datum.is_focused == True,
            alt.SizeValue(250),
            alt.SizeValue(60),
        ),
        # only the focused and reference tumors get a visible outline
        alt.condition(
            alt.datum.is_ref_or_focused == True,
            alt.StrokeWidthValue(1.0),
            alt.StrokeWidthValue(0.0),
        ),
        tooltip=[
            alt.Tooltip("cell_id:N").title("Tumor ID"),
            alt.Tooltip("y_pred:Q").format(".2f").title("Pred Response"),
        ],
    )
)

# embedding colored by the predicted response to the best drug
pred_chart = base.encode(
    alt.Color("y_pred:Q")
    .scale(scheme="redyellowblue", domainMid=0)
    .legend(**legend_config, tickCount=4)
    .title([best_drug, "Predicted Zd"])
)

def make_gexp_chart(gs):
    """Return the embedding scatter colored by the column named ``gs``.

    ``gs`` is used both as the quantitative color field and as the first line
    of the legend title; the tooltip shows the tumor ID and that column's value.
    """
    return base.encode(
        alt.Color(f"{gs}:Q")
        .scale(scheme="redblue", domainMid=0, reverse=True)
        .legend(**legend_config, tickCount=4)
        .title([gs, "Mean Z-Score log2(TPM+1)"]),
        tooltip=[
            alt.Tooltip("cell_id:N").title("Tumor ID"),
            alt.Tooltip(f"{gs}:Q").format(".2f").title("Expression"),
        ],
    )


# one panel per displayed gene set, side by side
gexp_chart = alt.hconcat(*list(map(make_gexp_chart, display_sets)), spacing=50)

# embedding colored by triple-negative status
# NOTE(review): `hr_chart` is built here but not included in `embed_chart` below
hr_chart = base.encode(
    alt.Color("is_TNBC:N").scale().title(["", "Triple Negative"]).legend(**legend_config),
    tooltip=[
        alt.Tooltip("cell_id:N").title("Tumor ID"),
        alt.Tooltip("is_TNBC:N").title("TNBC Status"),
    ],
)

# stack the prediction panel above the gene set panels; the gene set panels share
# one color scale but each keeps its own legend
embed_chart = alt.vconcat(
    pred_chart,
    gexp_chart.resolve_scale(color="shared").resolve_legend(color="independent"),
    spacing=30,
).resolve_scale(color="independent")

configure_chart(embed_chart)

In [None]:
# combined figure: SHAP panels on the left, embedding panels on the right
configure_chart(
    alt.hconcat(shap_chart, embed_chart, spacing=60).resolve_scale(shape="independent")
)

In [None]:
# preview the significant gene sets for the focused tumor
gsea_source.head()

In [None]:
# mean predicted response (z-scored per drug) to the best drug across reference tumors
ref_preds = all_preds_df.query("cell_id in @ref_tumors").query("drug_id == @best_drug")
mean_ref_pred = ref_preds["y_pred"].mean()
print(f"Mean predicted response for reference tumors: {mean_ref_pred:.2f}")

In [None]:
# observed responses of the reference tumors to the best drug (where available)
y_true_df.query("cell_id in @ref_tumors").query("drug_id == @best_drug")