# Comparison of ScreenAhead Drug Selection Algorithms

## Contents

- [Visualize Drug Selection Using PFA](#visualize-drug-selection-using-pfa)
- [PFA With an Increasing Number of Drugs Screened](#pfa-with-an-increasing-number-of-drugs-screened)
- [Statistical Comparisons of Drug Selection Algorithms](#statistical-comparisons-of-drug-selection-algorithms)

In [None]:
from __future__ import annotations

import json
import itertools

import altair as alt
import pandas as pd
import numpy as np
import typing as t
import sklearn.metrics as skm

from pathlib import Path
from scipy import stats
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

from cdrpy.datasets import Dataset

from screendl.utils.drug_selectors import get_response_matrix
from screendl.utils.drug_selectors import PrincipalDrugSelector

In [None]:
# Keyword arguments unpacked into `mark_boxplot` so that every boxplot in this
# notebook shares the same box size and median/box/tick/outlier styling.
BOXPLOT_CONFIG = {
    "size": 26,
    "median": alt.MarkConfig(fill="black"),
    "box": alt.MarkConfig(stroke="black"),
    "ticks": alt.MarkConfig(size=10),
    "outliers": alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
}

# Keyword arguments unpacked into `configure_axis` by `configure_chart`.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
    "titlePadding": 10,
}


def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Returns `chart` with the shared view, axis, header and legend styling.

    The view border is hidden, axes are styled with `AXIS_CONFIG`, and facet
    headers and legends use the arial font.
    """
    header_style = {
        "titleFont": "arial",
        "titleFontStyle": "regular",
        "titlePadding": 10,
        "labelFont": "arial",
    }
    legend_style = {
        "titleFontSize": 10,
        "labelFontSize": 10,
        "titleFont": "arial",
        "labelFont": "arial",
    }

    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    styled = styled.configure_header(**header_style)
    styled = styled.configure_legend(**legend_style)
    return styled

In [None]:
# Root of the data store; a relative path, so it is resolved against the
# working directory the notebook is run from.
root = Path("../../../datastore")

In [None]:
# Build a drug -> display label mapping. Every type found in the annotation
# file must be a key of `fixed_drug_types`; an unknown type raises KeyError.
drug_types_path = root / "processed/DrugAnnotations/drug_types.json"
fixed_drug_types = {"chemotherapy": "Chemo", "targeted": "Targeted", "other": "Other"}
# Pin the encoding so reading the JSON file does not depend on the platform's
# default locale encoding.
with open(drug_types_path, "r", encoding="utf-8") as fh:
    drug_to_type = {k: fixed_drug_types[v] for k, v in json.load(fh).items()}

In [None]:
# Load the cell and drug annotation tables (first CSV column is the index)
# and build the response dataset from the LabelsLogIC50.csv file.
dataset_dir = root / "inputs/CellModelPassports-GDSCv1v2"

cell_meta = pd.read_csv(dataset_dir / "MetaCellAnnotations.csv", index_col=0)
drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
# drugs missing from `drug_to_type` get NaN as their type
drug_meta["type"] = drug_meta.index.map(drug_to_type)

D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name=dataset_dir.name,
)

print(D)

## Visualize Drug Selection Using PFA

In [None]:
# Response matrix with one row per drug (its index is used as `drug_id`
# below); each row is z-scored independently.
# NOTE(review): presumably `get_response_matrix` returns a matrix without
# missing values, since `stats.zscore` propagates NaN by default - confirm.
M = get_response_matrix(D, na_threshold=0.9)
M = M.apply(stats.zscore, axis=1)

# 2-D t-SNE embedding of the drugs, used only for visualization
M_2d = TSNE(2, random_state=1771).fit_transform(M)
M_2d = pd.DataFrame(M_2d, columns=["tsne_1", "tsne_2"])
M_2d["drug_id"] = M.index

# NOTE(review): `select(20)` is assumed to return the ids of 20 drugs - confirm
selector = PrincipalDrugSelector(D, seed=1441)
selected_drugs = selector.select(20)

# 1 if the drug was picked by the selector, 0 otherwise
M_2d["was_selected"] = M_2d["drug_id"].isin(selected_drugs).astype(int)
M_2d.head()

In [None]:
# Cluster the drugs into 20 groups on the z-scored response matrix (not on
# the t-SNE coordinates); the labels are used to color the scatter plot.
kmeans = KMeans(20, random_state=1771, n_init="auto")
_ = kmeans.fit(M)

# `M_2d` rows are in the same order as `M`, so labels can be assigned directly
M_2d["cluster"] = kmeans.labels_

In [None]:
# Axis limits: the embedding's extent, rounded outwards, plus a margin of 2
x_min = np.floor(M_2d["tsne_1"].min()) - 2
x_max = np.ceil(M_2d["tsne_1"].max()) + 2
y_min = np.floor(M_2d["tsne_2"].min()) - 2
y_max = np.ceil(M_2d["tsne_2"].max()) + 2

# Scatter plot of the drug embedding colored by k-means cluster. Rows are
# sorted by `was_selected` so that selected drugs come last in the data;
# selected drugs get a larger point and a black outline.
principal_drug_chart = (
    alt.Chart(M_2d.sort_values("was_selected"))
    .mark_circle(stroke="black")
    .encode(
        alt.X("tsne_1:Q")
        .axis(grid=False, ticks=False, labels=False)
        .scale(domain=(x_min, x_max))
        .title("TSNE1"),
        alt.Y("tsne_2:Q")
        .axis(grid=False, ticks=False, labels=False)
        .scale(domain=(y_min, y_max))
        .title("TSNE2"),
        alt.Color("cluster:N").title(None).scale(scheme="tableau20").legend(None),
        # both levels map to the same opacity (0.9)
        # NOTE(review): unlike Size and StrokeWidth below, this channel does
        # not pass `.legend(None)` - confirm the opacity legend is intended.
        alt.Opacity("was_selected:N").scale(domain=(0, 1), range=(0.9, 0.9)),
        alt.Size("was_selected:N").scale(domain=(0, 1), range=(30, 40)).legend(None),
        # stroke width 0 hides the outline for unselected drugs
        alt.StrokeWidth("was_selected:N")
        .scale(domain=(0, 1), range=(0, 1.5))
        .legend(None),
        tooltip=["drug_id:N"],
    )
    .properties(width=300, height=200)
)

principal_drug_chart.configure_view(strokeOpacity=0).display()

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads and concatenates prediction files from a multirun.

    Parameters
    ----------
    multirun_dir
        Directory that `regex` is globbed against.
    regex
        Glob pattern selecting the CSV files to load. Despite its name, this
        is passed to `Path.glob` and is not a regular expression.
    splits
        If given, only rows whose `split_group` is in this list are kept.

    Returns
    -------
        The row-wise concatenation of all matching files. The original row
        index of each file is preserved, so index values may repeat.

    Raises
    ------
    ValueError
        If no file in `multirun_dir` matches `regex`.
    """
    multirun_dir = Path(multirun_dir)

    # sort so the row order does not depend on filesystem enumeration order
    file_list = sorted(multirun_dir.glob(regex))
    if not file_list:
        raise ValueError(f"No files matching {regex!r} found in {multirun_dir}")

    pred_df = pd.concat(map(pd.read_csv, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
def pcorr(
    df: pd.DataFrame, c1: str = "y_true", c2: str = "y_pred", min_samples: int = 10
) -> float:
    """Computes the Pearson correlation between two columns of `df`.

    Parameters
    ----------
    df
        Table holding the two columns to compare.
    c1
        Name of the first column.
    c2
        Name of the second column.
    min_samples
        Minimum number of rows required to compute a correlation.

    Returns
    -------
        The Pearson correlation coefficient, or NaN if `df` has fewer than
        `min_samples` rows.
    """
    if df.shape[0] < min_samples:
        return np.nan
    return stats.pearsonr(df[c1], df[c2])[0]

def mse(
    df: pd.DataFrame, c1: str = "y_true", c2: str = "y_pred", min_samples: int = 5
) -> float:
    """Computes the mean squared error between two columns of `df`.

    Parameters
    ----------
    df
        Table holding the two columns to compare.
    c1
        Name of the column passed as the true values.
    c2
        Name of the column passed as the predicted values.
    min_samples
        Minimum number of rows required to compute the error.

    Returns
    -------
        The mean squared error, or NaN if `df` has fewer than `min_samples`
        rows.
    """
    if df.shape[0] < min_samples:
        return np.nan
    return skm.mean_squared_error(df[c1], df[c2])

## PFA With an Increasing Number of Drugs Screened

In [None]:
# Locate the multirun directory for this dataset / model / timestamp
output_dir = root / "outputs"
path_fmt = "experiments/screenahead_drug_selection/{0}/{1}/multiruns/{2}"

dataset = "CellModelPassports-GDSCv1v2"
date = "2024-04-17_10-14-39"
model = "ScreenDL"

run_dir = output_dir / path_fmt.format(dataset, model, date)


def rescale(df, col, by):
    """Z-scores `df[col]` within each group defined by the `by` columns."""
    return df.groupby(by)[col].transform(stats.zscore)


# Baseline predictions (test split only), z-scored per drug
base_results = load_multirun_predictions(
    run_dir, "*/predictions.csv", splits=["test"]
)
base_results = base_results.assign(
    y_true=rescale(base_results, "y_true", ["drug_id"]),
    y_pred=rescale(base_results, "y_pred", ["drug_id"]),
)

# ScreenAhead predictions, z-scored per (selector, no. drugs screened, drug)
sa_results = load_multirun_predictions(run_dir, "*/predictions_sa.csv")
sa_results = sa_results.assign(
    y_true=rescale(sa_results, "y_true", ["selector_type", "n_drugs", "drug_id"]),
    y_pred=rescale(sa_results, "y_pred", ["selector_type", "n_drugs", "drug_id"]),
)

In [None]:
# Restrict the comparison to these four drug selection strategies
keep_selectors = ["uniform", "agglomerative", "principal", "random"]
sa_results = sa_results[sa_results["selector_type"].isin(keep_selectors)]

In [None]:
# Number of distinct selectors with a prediction for each
# (cell, drug, no. drugs screened) triple
grouped = sa_results.groupby(["cell_id", "drug_id", "n_drugs"])
counts = grouped["selector_type"].nunique()

# Keep only triples covered by the maximum observed number of selectors, so
# that all selectors are compared on the same pairs.
# NOTE(review): presumably a triple is missing for a selector when that
# selector screened the pair (so it was not predicted) - confirm.
sa_results = (
    sa_results.set_index(["cell_id", "drug_id", "n_drugs"])
    .loc[counts[counts == counts.max()].index]
    .reset_index()
)

In [None]:
# Per-drug Pearson correlation between observed and predicted response;
# `pcorr` returns NaN for groups with too few rows.
base_drug_corrs = base_results.groupby("drug_id").apply(pcorr)
sa_drug_corrs = sa_results.groupby(["selector_type", "n_drugs", "drug_id"]).apply(pcorr)
# Median per-drug correlation: one row per selector, one column per n_drugs
sa_drug_corrs.groupby(["selector_type", "n_drugs"]).median().unstack(1)

In [None]:
# The baseline (no drugs screened) is plotted as `n_drugs == 0`
base_source = base_drug_corrs.to_frame(name="pcc").reset_index()
base_source["n_drugs"] = 0

# Per-drug correlations for the single selector shown in this panel
selector_type = "principal"
selector_source = sa_drug_corrs.xs(selector_type).to_frame(name="pcc").reset_index()

source = pd.concat([base_source, selector_source])

# Boxplots of per-drug correlation against the number of drugs screened
num_drugs_chart = (
    alt.Chart(source)
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("n_drugs:O").axis(labelAngle=0, grid=False).title("No. Drugs Screened"),
        alt.Y("pcc:Q")
        .axis(grid=False)
        .scale(alt.Scale(domain=(0, 1)))
        .title("Pearson Correlation"),
        # color the baseline box differently from the ScreenAhead boxes
        alt.condition(
            alt.datum.n_drugs == 0, alt.ColorValue("#4D79A9"), alt.ColorValue("#53A24B")
        ),
    )
    # scale the chart width with the number of boxes
    .properties(width=35 * source["n_drugs"].nunique(), height=200)
)

# Styled copy for display only; the unconfigured chart is reused when the
# panels are concatenated at the end of the notebook.
configure_chart(num_drugs_chart)

## Statistical Comparisons of Drug Selection Algorithms

In [None]:
# Pair each informed selector's per-drug correlation with the random
# selector's correlation for the same (n_drugs, drug_id).
corrs_vs_random = (
    sa_drug_corrs.loc[["uniform", "agglomerative", "principal"]]
    .to_frame("pcc_informed")
    .join(sa_drug_corrs.loc["random"].to_frame("pcc_random"))
    # `pcorr` returns NaN for drugs with too few observations and the join
    # leaves NaN where "random" has no entry. A NaN on either side makes `>`
    # evaluate to False, which would silently count as a win for "random".
    .dropna(subset=["pcc_informed", "pcc_random"])
    .reset_index()
)
corrs_vs_random["delta"] = corrs_vs_random["pcc_informed"] - corrs_vs_random["pcc_random"]
# The informed selector wins only on a strictly higher correlation; ties go
# to "random".
corrs_vs_random["winner"] = np.where(
    corrs_vs_random["pcc_informed"] > corrs_vs_random["pcc_random"],
    corrs_vs_random["selector_type"],
    "random",
)

corrs_vs_random.head(10)

In [None]:
# Number of drugs won by each side, per informed selector and n_drugs
wins_vs_random = (
    corrs_vs_random.groupby(["selector_type", "n_drugs"])["winner"]
    .value_counts()
    .to_frame("count")
    .reset_index()
)

# `order` is the stacking order within a bar (informed selector = 0,
# random = 1); `col_order` is the facet column of each informed selector.
wins_vs_random["order"] = (wins_vs_random["winner"] == "random").astype(int)
mapper = {"agglomerative": 0, "uniform": 1, "principal": 2}
wins_vs_random["col_order"] = wins_vs_random["selector_type"].map(mapper)
# Share of wins (in percent) within each (selector, n_drugs) comparison
grouped = wins_vs_random.groupby(["selector_type", "n_drugs"])
wins_vs_random["win_pct"] = grouped["count"].transform(lambda g: 100 * g / g.sum())
wins_vs_random.head()

In [None]:
# All informed selectors share one color; "random" is drawn in light gray
color_domain = ("random", "uniform", "agglomerative", "principal")
color_range = ("lightgray", "#53A24B", "#53A24B", "#53A24B")

# Normalized stacked bars of win counts, one facet column per informed selector
share_chart = (
    alt.Chart(wins_vs_random)
    .mark_bar(stroke="black", strokeWidth=1, size=19)
    .encode(
        alt.X("n_drugs:O")
        .axis(labelAngle=0, grid=False, titlePadding=10, domainOpacity=0)
        .title("No. Drugs"),
        # NOTE(review): the title says "(%)" but the normalized stack and the
        # tick values are fractions in [0, 1] - confirm the intended labeling.
        alt.Y("sum(count):Q")
        .stack("normalize")
        .axis(grid=False, offset=5, titlePadding=10, values=(0, 0.25, 0.5, 0.75, 1))
        .title("Win Share (%)"),
        alt.Color("winner:N")
        .scale(domain=color_domain, range=color_range)
        .legend(None),
        alt.Column("col_order:N", spacing=10).header(None),
        alt.Order("order:O"),
    )
    .properties(height=200, width=22 * 5)
)

# Styled copy for display only; the unconfigured chart is reused when the
# panels are concatenated at the end of the notebook.
configure_chart(share_chart)

In [None]:
# Single-letter labels for the selectors. The key order matters: it defines
# the selector pairs and the axis domains used in the pairwise comparison.
selector_map = {
    "random": "R",
    "uniform": "M",
    "agglomerative": "A",
    "principal": "P",
}

wins_vs_random["selector"] = wins_vs_random["selector_type"].map(selector_map)

In [None]:
def get_metrics(g):
    """Summarizes a paired comparison of `pcc_1` against `pcc_2` for one group.

    Returns a Series with the p-value of a Wilcoxon signed-rank test on the
    paired columns, the number of rows where `pcc_1` is strictly greater than
    `pcc_2` ("wins"), the number of rows ("total") and their ratio ("share").
    """
    first, second = g["pcc_1"], g["pcc_2"]
    test_result = stats.wilcoxon(first, second)
    n_pairs = g.shape[0]
    n_wins = (first > second).sum()
    return pd.Series(
        {
            "pvalue": test_result.pvalue,
            "share": n_wins / n_pairs,
            "wins": n_wins,
            "total": n_pairs,
        }
    )

# Pairwise comparison of all selectors on per-drug correlations, computed
# separately for each number of drugs screened.
temp = sa_drug_corrs.to_frame(name="pcc").dropna()
combs = itertools.combinations(selector_map, 2)
result = []
# Each pair is unpacked so that `selector_1` is the later key of
# `selector_map` and `selector_2` the earlier one.
for selector_2, selector_1 in combs:
    x_selector_1 = temp.xs(selector_1)
    x_selector_2 = temp.xs(selector_2)

    # Inner join: only (n_drugs, drug_id) pairs with a correlation for both
    # selectors are compared. A left join would introduce NaN into `pcc_2`,
    # which propagates into the Wilcoxon p-value.
    metrics = (
        x_selector_1.join(x_selector_2, how="inner", lsuffix="_1", rsuffix="_2")
        .groupby("n_drugs")
        .apply(get_metrics)
    )

    metrics["selector_1"] = selector_1
    metrics["selector_2"] = selector_2
    # "" marks an exact tie; it has no entry in `selector_map`, so it becomes
    # NaN when the labels are mapped below.
    metrics["best_selector"] = metrics["share"].apply(
        lambda x: selector_1 if x > 0.5 else (selector_2 if x < 0.5 else "")
    )

    result.append(metrics)

# Replace selector names with their single-letter labels and derive the
# plotting columns.
source = pd.concat(result).reset_index()
source["selector_1"] = source["selector_1"].map(selector_map)
source["selector_2"] = source["selector_2"].map(selector_map)
source["best_selector"] = source["best_selector"].map(selector_map)
source["log_pvalue"] = -np.log10(source["pvalue"])
# 1 if the difference is significant at the 0.05 level, 0 otherwise
source["reject"] = (source["pvalue"] < 0.05).astype(int)

source.head()

In [None]:
# One circle per selector pair, colored by -log10(p-value); a black outline
# marks pairs whose difference is significant (`reject == 1`).
points = (
    alt.Chart()
    .mark_circle(size=650, stroke="black")
    .encode(
        alt.X("selector_1:N")
        .title(None)
        .axis(labelAngle=0, orient="top")
        .scale(domain=list(selector_map.values())),
        alt.Y("selector_2:N").title(None).scale(domain=list(selector_map.values())),
        alt.Color("log_pvalue:Q")
        .scale(nice=True)
        .legend(title=None, gradientLength=150, gradientThickness=20, tickCount=2),
        alt.StrokeWidth("reject:N").scale(domain=(1, 0), range=(2, 0)).legend(None),
    )
    .properties(width=150, height=150)
)

# Letter of the better selector, drawn on top of each circle
text = (
    alt.Chart()
    .mark_text(size=10)
    .encode(
        alt.X("selector_1:N")
        .title(None)
        .axis(labelAngle=0)
        .scale(domain=list(selector_map.values())[1:]),
        alt.Y("selector_2:N").title(None).scale(domain=list(selector_map.values())[:-1]),
        # `== None` builds an Altair expression through operator overloading;
        # do not rewrite it as `is None`. Ties (no best selector) get no text.
        alt.condition(
            alt.datum.best_selector == None,
            alt.TextValue(""),
            alt.Text("best_selector:N"),
        ),
        # switch to white text above a fixed -log10(p) threshold of 14
        alt.condition(
            alt.datum.log_pvalue > 14, alt.ColorValue("white"), alt.ColorValue("black")
        ),
    )
)

# Layer circles and text; one facet column per number of drugs screened
pairwise_selector_chart = alt.layer(points, text).facet(
    column=alt.Column("n_drugs:O").title("No. Drugs").header(orient="bottom"),
    data=source,
    spacing=10,
)

pairwise_selector_chart.configure_view(strokeOpacity=0).display()

In [None]:
# Assemble the final figure: three panels on top, pairwise comparison below.
# Color scales are resolved independently so each panel keeps its own colors.
top = alt.hconcat(num_drugs_chart, principal_drug_chart, share_chart, spacing=40)
top = top.resolve_scale(color="independent")

chart = alt.vconcat(top, pairwise_selector_chart, spacing=45)
configure_chart(chart.resolve_scale(color="independent"))