# ScreenAhead Functionally-Related Therapies

## Contents

In [None]:
# python scripts/experiments/screenahead_related_drugs.py -m dataset.split.id=1 experiment.drug_id="5-Fluorouracil","Leflunomide","Epirubicin","Piperlongumine","Vinblastine","Oxaliplatin","Docetaxel","Gemcitabine","Cytarabine","Cisplatin","Alisertib","Afatinib","Erlotinib","Dabrafenib","Alpelisib","Trametinib","Olaparib","Nilotinib","Fulvestrant","Irinotecan"

In [None]:
from __future__ import annotations

import json
import random

import altair as alt
import pandas as pd
import numpy as np
import typing as t
import sklearn.metrics as skm

from pathlib import Path
from scipy import stats
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

from cdrpy.datasets import Dataset
from cdrpy.data.preprocess import GroupStandardScaler
from cdrpy.mapper import BatchedResponseGenerator
from cdrpy.metrics import tf_metrics

from screendl import model as screendl
from screendl.utils import evaluation as eval_utils
from screendl.utils.drug_selectors import get_response_matrix

In [None]:
# Relative path to the datastore directory; the input and output paths used
# below are all built from it.
root = Path("../../../datastore")

In [None]:
# Load the drug -> type annotation file and translate its raw type strings into
# the short display labels used in this notebook.
drug_types_path = root / "processed/DrugAnnotations/drug_types.json"
fixed_drug_types = {"chemotherapy": "Chemo", "targeted": "Targeted", "other": "Other"}
with open(drug_types_path, "r") as fh:
    # A raw type outside `fixed_drug_types` raises KeyError here.
    drug_to_type = {k: fixed_drug_types[v] for k,v in json.load(fh).items()}

In [None]:
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2"

# Cell and drug annotation tables, indexed by their first column.
cell_meta = pd.read_csv(dataset_dir / "MetaCellAnnotations.csv", index_col=0)
drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
# Attach the display drug type; drugs absent from `drug_to_type` get a missing value.
drug_meta["type"] = drug_meta.index.map(drug_to_type)

# Feature encoders for cells (gene expression file) and drugs (Morgan
# fingerprint file), loaded through the ScreenDL helpers.
cell_encoders = screendl.load_cell_features(
    dataset_dir / "ScreenDL/FeatureGeneExpression.csv"
)

drug_encoders = screendl.load_drug_features(
    dataset_dir / "ScreenDL/FeatureMorganFingerprints.csv"
)

# Bundle the response labels, metadata and encoders into a single dataset object.
D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    cell_encoders=cell_encoders,
    drug_encoders=drug_encoders,
    name=dataset_dir.name,
)

print(D)

In [None]:
# Z-score the response labels separately within each drug.
D.obs["label"] = D.obs.groupby("drug_id")["label"].transform(stats.zscore)
# Standardize the "exp" cell features, writing the result back into the
# encoder's existing data container.
exp_enc = D.cell_encoders["exp"]
exp_enc.data[:] = StandardScaler().fit_transform(exp_enc.data)

In [None]:
# Input widths are taken from the last dimension of each encoder.
exp_dim = D.cell_encoders["exp"].shape[-1]
mol_dim = D.drug_encoders["mol"].shape[-1]

# Build the ScreenDL network. No normalization layers are passed because the
# expression features were already standardized in the previous cell.
model = screendl.create_model(
    exp_dim,
    mol_dim,
    exp_norm_layer=None,
    cnv_norm_layer=None,
    exp_hidden_dims=[512, 256, 128, 64],
    mol_hidden_dims=[256, 128, 64],
    shared_hidden_dims=[128, 64],
    activation="leaky_relu",
    use_noise=True,
    noise_stddev=0.3,
    use_l2=False,
)

In [None]:
# Print the layer-by-layer architecture (layer names are needed further down to
# pull out the drug embedding sub-network).
model.summary()

In [None]:
# Batches are drawn from the whole dataset `D`; no validation data is passed to
# `fit` in this cell.
# NOTE(review): 256 is presumably the batch size - confirm against BatchedResponseGenerator.
gen = BatchedResponseGenerator(D, 256)
seq = gen.flow_from_dataset(D, shuffle=True, seed=4114)

model.compile(
    optimizer=keras.optimizers.Adam(1e-4, weight_decay=1e-4),
    loss="mean_squared_error",
    metrics=["mse", tf_metrics.pearson],
)

# `hx` holds the value returned by `fit` (the Keras training history).
hx = model.fit(seq, epochs=15)

In [None]:
# Carve a sub-model out of the trained network that maps the drug input to the
# output of the "mol_mlp_3" layer, i.e. a learned drug embedding.
emb_input = model.get_layer("mol_input").input
emb_output = model.get_layer("mol_mlp_3").output

emb_model = keras.Model(emb_input, emb_output)

# Embed every drug in the "mol" encoder.
X_drug = D.drug_encoders["mol"].data
X_drug_embed = emb_model.predict(X_drug)

In [None]:
# Project the drug embeddings to 2D with t-SNE (seeded for reproducibility) and
# keep the drug ids alongside the coordinates for plotting.
X_drug_embed_2d = TSNE(2, random_state=1771).fit_transform(X_drug_embed)
X_drug_embed_2d = pd.DataFrame(X_drug_embed_2d, columns=["x", "y"])
X_drug_embed_2d["drug_id"] = X_drug.index
X_drug_embed_2d.head()

In [None]:
# Cluster the drug embeddings (not the 2D projection) into 9 groups.
kmeans = KMeans(9, random_state=1771, n_init="auto")
_ = kmeans.fit(X_drug_embed)

# Cluster ids are shifted to start at 1; pathway/target annotations are looked
# up per drug id from the drug metadata.
X_drug_embed_2d["cluster"] = kmeans.labels_ + 1
X_drug_embed_2d["pathway"] = X_drug_embed_2d["drug_id"].map(D.drug_meta["target_pathway"])
X_drug_embed_2d["targets"] = X_drug_embed_2d["drug_id"].map(D.drug_meta["targets"])

In [None]:
# Target pathways highlighted in colour in the pathway scatter plot; drugs in
# any other pathway are drawn in light gray. The list order is also used as the
# colour-scale domain order.
FOCUSED_PATHWAYS = [
    "EGFR signaling",
    "ERK MAPK signaling",
    "PI3K/MTOR signaling",
    "Apoptosis regulation",
    "DNA replication",
    "Genome integrity",
    "IGF1R signaling"
]

In [None]:
# Boolean flag: does the drug's pathway belong to the highlighted set?
X_drug_embed_2d["focused_pathway"] = X_drug_embed_2d["pathway"].isin(FOCUSED_PATHWAYS)

In [None]:
def parse_targets(item: t.Any) -> t.List[str]:
    """Splits a comma-separated target string into whitespace-stripped names.

    Any non-string input (for example a missing value) yields an empty list.
    """
    if isinstance(item, str):
        return [piece.strip() for piece in item.split(",")]
    return []


# Replace the raw comma-separated target strings with lists of target names.
X_drug_embed_2d["targets"] = X_drug_embed_2d["targets"].map(parse_targets)

In [None]:
def assign_focused_targets(targets: t.List[str]) -> str | None:
    """Maps a drug's target list onto one highlighted target group.

    The groups are tested in a fixed priority order and the first match wins, so
    a drug hitting several groups is reported under the earliest one. Returns
    None when no group matches.
    """

    def hits_any(*names: str) -> bool:
        # Exact-name membership test.
        return any(name in targets for name in names)

    def hits_prefix(prefix: str) -> bool:
        # Prefix test, e.g. "TOP" also matches longer names starting with it.
        return any(target.startswith(prefix) for target in targets)

    if hits_any("EGFR", "ERBB2"):
        return "EGFR/HER2"
    # PARP is matched as a substring anywhere in the target name.
    if any("PARP" in target for target in targets):
        return "PARP1/2"
    if hits_any("ERK1", "ERK2"):
        return "ERK1/2"
    if hits_any("MEK1", "MEK2"):
        return "MEK1/2"
    if hits_any("AKT1", "AKT2", "AKT3", "AKT"):
        return "AKT1/2/3"
    if hits_prefix("TOP"):
        return "TOP1/2"
    if hits_prefix("MTOR"):
        return "MTOR"
    if hits_any("BRD2", "BRD3", "BRD4"):
        return "BRD2/3/4"
    return None


# Highlighted target group per drug (None when the drug matches no group) ...
X_drug_embed_2d["selected_targets"] = X_drug_embed_2d["targets"].map(
    assign_focused_targets
)
# ... and a boolean flag that is True only where a group label (a string) was assigned.
X_drug_embed_2d["focused_targets"] = X_drug_embed_2d["selected_targets"].map(
    lambda x: isinstance(x, str)
)

In [None]:
# Axis styling applied to every figure through `configure_chart`.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
}

# Keyword arguments unpacked into `mark_boxplot` for the boxplot figure.
BOXPLOT_CONFIG = {
    "size": 28,
    "median": alt.MarkConfig(fill="black"),
    "box": alt.MarkConfig(stroke="black"),
    "ticks": alt.MarkConfig(size=10),
    "outliers": alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
}

def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Returns `chart` with the notebook's shared view, axis, header and legend styling."""
    header_style = {
        "titleFont": "arial",
        "titleFontStyle": "regular",
        "titlePadding": 10,
        "labelFont": "arial",
    }
    legend_style = {
        "titleFontSize": 10,
        "labelFontSize": 10,
        "titleFont": "arial",
        "labelFont": "arial",
    }

    # Each configure_* call returns a chart, so the steps are simply chained.
    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    styled = styled.configure_header(**header_style)
    return styled.configure_legend(**legend_style)

In [None]:
# t-SNE scatter of the drug embeddings, coloured by highlighted protein target
# group. Drugs outside every group are drawn smaller and in light gray.
target_chart = (
    alt.Chart(X_drug_embed_2d)
    .mark_circle(size=50)
    .encode(
        alt.X("x:Q")
        .axis(ticks=False, labels=False, grid=False, titlePadding=10)
        .title("TSNE1"),
        alt.Y("y:Q")
        .axis(ticks=False, labels=False, grid=False, titlePadding=10)
        .title("TSNE2"),
        # `== True` is intentional: this builds an Altair expression that is
        # evaluated per datum, so Python's `is` / truthiness cannot be used.
        alt.condition(
            alt.datum.focused_targets == True,
            alt.Color("selected_targets:N")
            .scale(
                domain=sorted(
                    [
                        "BRD2/3/4",
                        "EGFR/HER2",
                        "PARP1/2",
                        "ERK1/2",
                        "MEK1/2",
                        "AKT1/2/3",
                        "TOP1/2",
                        "MTOR",
                    ]
                )
            )
            .legend(orient="right")
            .title("Protein Targets"),
            alt.ColorValue("lightgray"),
        ),
        alt.condition(
            alt.datum.focused_targets == True,
            alt.SizeValue(60),
            alt.SizeValue(30),
        ),
        tooltip=["drug_id:N", "pathway:N", "targets"],
    )
    .properties(width=350, height=250)
)

# Display only: the configured copy is not assigned, so `target_chart` itself
# stays unconfigured for later reuse in the combined figure.
target_chart.configure_view(strokeOpacity=0)

In [None]:
# t-SNE scatter of the drug embeddings, coloured by target pathway for the
# pathways in FOCUSED_PATHWAYS; all other drugs are smaller and light gray.
pathway_chart = (
    alt.Chart(X_drug_embed_2d)
    .mark_circle()
    .encode(
        alt.X("x:Q")
        .axis(ticks=False, labels=False, grid=False, titlePadding=10)
        .title("TSNE1"),
        alt.Y("y:Q")
        .axis(ticks=False, labels=False, grid=False, titlePadding=10)
        .title("TSNE2"),
        # `== True` is intentional: this builds an Altair expression that is
        # evaluated per datum, so Python's `is` / truthiness cannot be used.
        alt.condition(
            alt.datum.focused_pathway == True,
            alt.Color("pathway:N")
            .scale(domain=FOCUSED_PATHWAYS)
            .legend(orient="right")
            .title("Target Pathway"),
            alt.ColorValue("lightgray"),
        ),
        alt.condition(
            alt.datum.focused_pathway == True,
            alt.SizeValue(60),
            alt.SizeValue(30),
        ),
        tooltip=["drug_id:N", "pathway:N"],
    )
    .properties(width=350, height=250)
)

# Display only: the configured copy is not assigned, so `pathway_chart` itself
# stays unconfigured for later reuse in the combined figure.
pathway_chart.configure_view(strokeOpacity=0)

In [None]:
def load_multirun_predictions(multirun_dir: str | Path, regex: str) -> pd.DataFrame:
    """Loads predictions from a multirun.

    Args:
        multirun_dir: Directory containing one sub-directory per run.
        regex: Pattern of the prediction files relative to `multirun_dir`.
            Despite the name this is a glob pattern (it is passed to
            `Path.glob`), e.g. "*/predictions_sa.csv".

    Returns:
        The concatenation of all matching CSV files, with an extra `split_id`
        column holding the name of the directory each file was found in.

    Raises:
        ValueError: If no file matches the pattern.
    """
    multirun_dir = Path(multirun_dir)

    # Sort so the row order does not depend on filesystem enumeration order.
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        # pd.concat would also raise ValueError here, but without saying which
        # directory/pattern came up empty.
        raise ValueError(f"No files matching {regex!r} found in {multirun_dir}")

    frames = [pd.read_csv(f).assign(split_id=f.parent.name) for f in file_list]
    return pd.concat(frames)

In [None]:
output_dir = root / "outputs"
path_fmt = "experiments/screenahead_related_drugs/{0}/{1}/multiruns/{2}"
# Display names for the model identifiers found in the prediction files.
fixed_models = {"ScreenDL": "ScreenDL-PT", "ScreenDL-SA": "ScreenDL-SA"}

dataset = "CellModelPassports-GDSCv1v2"
# Timestamped multirun directories whose predictions are pooled below.
dates = [
    "2024-11-27_11-57-01",
    "2024-11-27_21-45-38",
    "2024-11-28_07-57-15",
    "2024-11-28_07-58-09",
    "2024-11-28_07-58-39",
    "2024-11-28_14-19-01",
    "2024-11-28_17-47-32",
    "2024-11-28_17-56-25",
    "2024-11-28_20-44-49",
    "2024-11-29_08-06-14",
]
# Named `model_name` rather than `model` so this cell does not rebind `model`,
# the name the trained Keras network is created under earlier in the notebook.
model_name = "ScreenDL"

fold_results = []
for date in dates:
    run_dir = output_dir / path_fmt.format(dataset, model_name, date)
    run_results = load_multirun_predictions(run_dir, "*/predictions_sa.csv")
    # Model identifiers missing from `fixed_models` become NaN here.
    run_results["model"] = run_results["model"].map(fixed_models)
    fold_results.append(run_results)

results_df = pd.concat(fold_results)
results_df.head()

In [None]:
id_vars = ["model", "drug_id", "n_drugs", "n_best_drugs"]

# One Pearson correlation per (model, drug, n_drugs, n_best_drugs) group.
# dropna=False keeps groups whose key columns contain missing values instead of
# silently discarding them.
pcc_metrics = (
    results_df.groupby(id_vars, dropna=False)
    .apply(eval_utils.pcorr)
    .to_frame(name="pcc")
    .reset_index()
    .sort_values(id_vars)
)


def assign_label(row):
    """Returns the x-axis label for one metrics row.

    Rows with `n_drugs == 0` are labelled "base"; every other row is labelled
    with its `n_best_drugs` value rendered as an integer string.
    """
    is_baseline = row["n_drugs"] == 0
    return "base" if is_baseline else str(int(row["n_best_drugs"]))


pcc_metrics["label"] = pcc_metrics.apply(assign_label, axis=1)
# Labels in order of first appearance (the frame is already sorted by `id_vars`);
# used as the explicit x-axis sort order in the boxplot.
order = list(pcc_metrics["label"].unique())

pcc_metrics.head()

In [None]:
# Boxplot of the per-group Pearson correlations for each label, with the "base"
# box in blue and all other boxes in green.
boxes = (
    alt.Chart(pcc_metrics)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("label:O")
        .sort(order)
        .axis(labelAngle=0, titlePadding=10)
        .title("No. Functionally-Related Drugs"),
        alt.Y("pcc:Q")
        .scale(domain=(0, 1))
        .axis(grid=False, titlePadding=10)
        .title("Pearson Correlation"),
        alt.condition(
            alt.datum.label == "base",
            alt.ColorValue("#4C78A8"),
            alt.ColorValue("#53A24B"),
        ),
    )
    .properties(width=450, height=220)
)

# Display only; `boxes` itself stays unconfigured for reuse in the final figure.
configure_chart(boxes)

In [None]:
# Correlations indexed by (model, n_best_drugs, drug_id) so that each slice
# below is a per-drug series.
temp_ = pcc_metrics.set_index(["model", "n_best_drugs", "drug_id"]).sort_index()["pcc"]

# NOTE(review): selecting with pd.NA assumes the ScreenDL-PT rows carry a missing
# n_best_drugs - confirm `xs` resolves this key in the pandas version in use.
x1 = temp_.xs(("ScreenDL-PT", pd.NA))
x2 = temp_.xs(("ScreenDL-SA", 0))
# Paired test: the values are paired by position, so this presumes both slices
# cover the same drugs in the same (sorted) order - TODO confirm.
print(stats.wilcoxon(x2, x1))

In [None]:
temp_ = pcc_metrics.set_index(["model", "n_best_drugs", "drug_id"]).sort_index()["pcc"]

# Paired Wilcoxon test between consecutive n_best_drugs settings (0 vs 1, ...,
# 9 vs 10) for ScreenDL-SA. Values are paired by position within each slice.
for i in range(10):
    x1 = temp_.xs(("ScreenDL-SA", i))
    x2 = temp_.xs(("ScreenDL-SA", i + 1))
    print(i, i + 1, stats.wilcoxon(x2, x1))

## Performance improvement based on functional similarity

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun.

    Args:
        multirun_dir: Directory containing one sub-directory per run.
        regex: Pattern of the prediction files relative to `multirun_dir`.
            Despite the name this is a glob pattern (it is passed to
            `Path.glob`), e.g. "*/predictions.csv".
        splits: If given, only rows whose `split_group` value is in this list
            are kept; the files must then contain a `split_group` column.

    Returns:
        The concatenation of all matching CSV files, optionally filtered.

    Raises:
        ValueError: If no file matches the pattern.
    """
    multirun_dir = Path(multirun_dir)

    # Sort so the row order does not depend on filesystem enumeration order.
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        # pd.concat would also raise ValueError here, but without saying which
        # directory/pattern came up empty.
        raise ValueError(f"No files matching {regex!r} found in {multirun_dir}")

    pred_df = pd.concat(map(pd.read_csv, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
# Prediction tables keyed by model display name; filled in by the next cell.
model_results = {}

In [None]:
output_dir = root / "outputs"
path_fmt = "experiments/screenahead_drug_selection/{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2"
date = "2024-04-17_10-14-39"
# Named `model_name` rather than `model` so this cell does not rebind `model`,
# the name the trained Keras network is created under earlier in the notebook.
model_name = "ScreenDL"

run_dir = output_dir / path_fmt.format(dataset, model_name, date)


def rescale(df, col, by):
    """Z-scores `df[col]` separately within each group defined by `by`."""
    return df.groupby(by)[col].transform(stats.zscore)


# Baseline predictions: test-split rows only, with observed and predicted
# responses z-scored per drug.
model_results["ScreenDL-PT"] = (
    load_multirun_predictions(run_dir, "*/predictions.csv", splits=["test"])
    .assign(
        y_true=lambda df: rescale(df, "y_true", ["drug_id"]),
        y_pred=lambda df: rescale(df, "y_pred", ["drug_id"]),
    )
    .assign(model="ScreenDL-PT")
)

# ScreenAhead predictions restricted to the 'principal' selector with 20 drugs.
# After the query the selector/n_drugs keys are constant, so the z-scoring is
# effectively per drug here as well.
model_results["ScreenDL-SA"] = (
    load_multirun_predictions(run_dir, "*/predictions_sa.csv")
    .query("selector_type == 'principal' and n_drugs == 20")
    .assign(
        y_true=lambda df: rescale(df, "y_true", ["selector_type", "n_drugs", "drug_id"]),
        y_pred=lambda df: rescale(df, "y_pred", ["selector_type", "n_drugs", "drug_id"]),
    )
    .assign(model="ScreenDL-SA")
)

In [None]:
# Short alias for the ScreenDL-SA predictions analysed in the cells below.
X: pd.DataFrame = model_results["ScreenDL-SA"]

In [None]:
M = get_response_matrix(D)
# Pairwise correlation between the columns of the transposed response matrix,
# used below as the functional similarity between drugs.
f_sims = M.T.corr()
# Zero the diagonal (each column's correlation with itself) so self-similarity
# never counts as a match. Done with `mask` rather than writing through
# `f_sims.values`: that array is not guaranteed to be a writable view of the
# frame (it is read-only under pandas Copy-on-Write).
f_sims = f_sims.mask(np.eye(len(f_sims), dtype=bool), 0)
f_sims.head()

In [None]:
# For every cell line, record for each drug its highest functional similarity
# to any of the drugs that were screened in that cell line.
X_with_sims = []
for _, group in X.groupby("cell_id"):
    screened_drugs = group.query("was_screened == True")["drug_id"].to_list()
    # Build a new frame with `assign` instead of setting a column on the frame
    # yielded by the groupby iteration, which can raise SettingWithCopyWarning.
    # The values are passed as a plain list so they are placed by position,
    # matching the row order of `group`.
    X_with_sims.append(
        group.assign(
            max_f_sim=f_sims[screened_drugs]
            .loc[group["drug_id"]]
            .max(axis=1)
            .to_list()
        )
    )

X_with_sims: pd.DataFrame = pd.concat(X_with_sims)
X_with_sims.head()

In [None]:
# Bin edges from 0 to 1 in steps of 0.05. With pd.cut's defaults the bins are
# right-closed, so similarities <= 0 fall outside every bin and get a missing bin.
bins = np.arange(0, 1.05, 0.05)
X_with_sims["max_f_sim_bin"] = pd.cut(
    X_with_sims["max_f_sim"],
    bins=bins,
)

# Human-readable "(left, right]" label for each interval, used on the plot axis.
bin_to_str = lambda x: f"({x.left}, {x.right}]"
X_with_sims["max_f_sim_bin_str"] = X_with_sims["max_f_sim_bin"].map(bin_to_str)
X_with_sims.head()

In [None]:
# Absolute and squared error of the ScreenDL-SA predictions.
X_with_sims["ae"] = (X_with_sims["y_true"] - X_with_sims["y_pred"]).abs()
X_with_sims["se"] = (X_with_sims["y_true"] - X_with_sims["y_pred"]) ** 2

# The same error measures for the ScreenDL-PT baseline, reduced to the join keys
# plus the two error columns.
base_err = (
    model_results["ScreenDL-PT"]
    .assign(base_ae=lambda df: (df["y_true"] - df["y_pred"]).abs())
    .assign(base_se=lambda df: (df["y_true"] - df["y_pred"]) ** 2)
    .filter(items=["cell_id", "drug_id", "base_ae", "base_se"])
)
# Default (inner) merge: only (cell, drug) pairs present in both tables survive.
X_with_sims = X_with_sims.merge(base_err, on=["cell_id", "drug_id"])

In [None]:
# Mean errors per similarity bin, computed over unscreened drugs only, plus the
# change in error of ScreenDL-SA relative to the baseline (negative = improvement).
X_with_sims_agg = (
    X_with_sims.query("was_screened == False")
    # `max_f_sim_bin` is categorical. Passing `observed` explicitly avoids the
    # pandas FutureWarning and keeps the result independent of the version's
    # default; empty bins were discarded by the trailing dropna() anyway.
    .groupby("max_f_sim_bin", observed=True)
    .aggregate(
        {
            "base_ae": "mean",
            "ae": "mean",
            "base_se": "mean",
            "se": "mean",
            "max_f_sim_bin_str": "first",
        }
    )
    .assign(
        delta_mae=lambda df: df["ae"] - df["base_ae"],
        delta_mse=lambda df: df["se"] - df["base_se"],
    )
    .dropna()
)

In [None]:
# The index holds the bin intervals; expose their upper and lower bounds as
# plain columns.
X_with_sims_agg["max_interval"] = [x.right for x in X_with_sims_agg.index.to_list()]
X_with_sims_agg["min_interval"] = [x.left for x in X_with_sims_agg.index.to_list()]
X_with_sims_agg.head()

In [None]:
# Drop the interval index; the string bin label column is used for plotting.
source = X_with_sims_agg.reset_index(drop=True)
# Bin labels in row order, used as the explicit x-axis sort order.
sorted_bins = source["max_f_sim_bin_str"].unique().tolist()

# Long format for the scatter: one row per (bin, model) with that model's MAE.
source_points = (
    source.reset_index(drop=True)
    .melt(
        id_vars=["max_f_sim_bin_str"],
        value_vars=["base_ae", "ae"],
        var_name="model",
        value_name="mae",
    )
    .assign(
        model=lambda df: df["model"].map({"base_ae": "ScreenDL-PT", "ae": "ScreenDL-SA"})
    )
)

# Wide format for the bar chart of the change in MAE.
source_bars = source.assign(model="ScreenDL-SA")

In [None]:
# Pearson correlation between the change in MAE and the rank of each bin's
# upper bound (only the bin variable is rank-transformed).
stats.pearsonr(
    source_bars["delta_mae"],
    source_bars["max_interval"].rank(),
)

In [None]:
# Top panel: MAE per similarity bin for both models.
points = (
    alt.Chart(source_points)
    .mark_circle(size=80, opacity=1.0, stroke="black", strokeWidth=0.5)
    .encode(
        alt.X("max_f_sim_bin_str:O")
        .sort(sorted_bins)
        .axis(titlePadding=10, labelAngle=-60, labelPadding=10)
        .title("Binned Max Functional Similarity"),
        alt.Y("mae:Q")
        .scale(domain=(0, 1.1))
        .axis(grid=False, values=(0, 0.5, 1.0), titlePadding=5)
        .title("Mean Absolute Error (MAE)"),
        alt.Color("model:N")
        .scale(domain=("ScreenDL-PT", "ScreenDL-SA"), range=("#4C78A8", "#5CA453"))
        .legend(orient="none", legendX=350, legendY=0)
        .title(None),
    )
    .properties(width=450, height=200)
)

# Bottom panel: change in MAE (ScreenDL-SA minus baseline) per similarity bin.
bars = (
    alt.Chart(source_bars)
    .mark_bar(color="gray", stroke="black", strokeWidth=1, opacity=1)
    .encode(
        alt.X("max_f_sim_bin_str:O")
        .sort(sorted_bins)
        .axis(grid=False, labelAngle=-60, labelPadding=10, titlePadding=10)
        .title("Binned Max Functional Similarity"),
        alt.Y("delta_mae:Q")
        .axis(grid=False, titlePadding=10)
        .scale(domain=(-0.5, 0.5))
        .title("Change in MAE"),
    )
    .properties(width=450, height=110)
)

mae_chart = alt.vconcat(points, bars, spacing=40)
# Display only; `mae_chart` itself stays unconfigured for reuse in the final figure.
configure_chart(mae_chart)

In [None]:
# Assemble the full figure from the unconfigured sub-charts. Colour scales are
# resolved independently so each sub-chart keeps its own palette and legend.
panel_1 = alt.vconcat(mae_chart, boxes, spacing=40).resolve_scale(color="independent")
panel_2 = alt.vconcat(pathway_chart, target_chart, spacing=40).resolve_scale(
    color="independent"
)

final_chart = alt.hconcat(panel_1, panel_2, spacing=40)
# Styling is applied once, at the top level of the combined figure.
configure_chart(final_chart)