# GDS Analyses

## Contents

- [Data Loading](#data-loading)
- [GDS Predicts Response to Individual Therapies in PDMCs](#gds-predicts-response-to-individual-therapies-in-pdmcs)
- [GDS Predicts Response to Individual Therapies in Cell Lines](#gds-predicts-response-to-individual-therapies-in-cell-lines)
- [ScreenAhead With GDS vs All Drugs](#screenahead-with-gds-vs-all-drugs)

In [None]:
from __future__ import annotations

import random

import altair as alt
import altair_forge as af
import pandas as pd
import numpy as np
import typing as t
import sklearn.metrics as skm

from pathlib import Path
from scipy import stats

from cdrpy.datasets import Dataset
from cdrpy.data.preprocess import GroupStandardScaler

from screendl import model as screendl
from screendl.utils import evaluation as eval_utils

## Data Loading

In [None]:
# Seeds Python's stdlib `random` module only. The pandas `.sample` calls below
# pass their own `random_state`, so they are not affected by this seed.
random.seed(1771)

In [None]:
# Root of the data/output store, relative to the notebook's working directory.
root = Path("../../../datastore")

In [None]:
# Combined cell line + PDMC dataset directory.
dataset_dir = root / "inputs/CellModelPassportsGDSCv1v2HCIv10AllDrugsHallmarkCombat"

# Drug metadata and Morgan-fingerprint drug features.
drug_meta = pd.read_csv(dataset_dir / "MetaDrugAnnotations.csv", index_col=0)
drug_encoders = screendl.load_drug_features(
    dataset_dir / "ScreenDL/FeatureMorganFingerprints.csv"
)

# Sample metadata (includes the "domain" column used below) and gene-expression
# sample features.
cell_meta = pd.read_csv(dataset_dir / "MetaSampleAnnotations.csv", index_col=0)
cell_encoders = screendl.load_cell_features(
    dataset_dir / "ScreenDL/FeatureGeneExpression.csv"
)

# Full dataset with log-IC50 response labels.
D = Dataset.from_csv(
    dataset_dir / "LabelsLogIC50.csv",
    cell_encoders=cell_encoders,
    drug_encoders=drug_encoders,
    cell_meta=cell_meta,
    drug_meta=drug_meta,
    name="CellModelPassportsGDSCHCI",
)

# Split samples by metadata domain: "CELL" samples vs. "PDMC" samples.
cell_ids = D.cell_meta[D.cell_meta["domain"] == "CELL"].index
pdmc_ids = D.cell_meta[D.cell_meta["domain"] == "PDMC"].index

cell_ds = D.select_cells(cell_ids, name="cell_ds")
pdmc_ds = D.select_cells(pdmc_ids, name="pdmc_ds")

print(cell_ds)
print(pdmc_ds)

## GDS Predicts Response to Individual Therapies in PDMCs

In [None]:
# Display names of every model compared below, in display order. This list is
# also the domain of MODEL_COLOR_SCALE, and the source of `.loc[MODELS]`
# ordering and the `MODELS[-3:]` ScreenDL subset.
MODELS = ["DualGCN", "HiDRA", "DeepCDR", "ScreenDL-PT", "ScreenDL-SA (Zd)", "ScreenDL-SA (GDS)"]

In [None]:
# Shared axis styling, applied to charts by configure_chart.
AXIS_CONFIG = {
    "titleFont": "arial",
    "titleFontStyle": "regular",
    "labelFont": "arial",
    "tickColor": "black",
    "domainColor": "black",
}

# Fixed color per model, matched to MODELS by position, so each model keeps the
# same color in every chart.
MODEL_COLOR_SCALE = alt.Scale(
    domain=MODELS,
    range=("lightgray", "darkgray", "gray", "#4C78A8", "#5CA453", "#9E765F"),
)

# Shared boxplot styling: black median fill, black box outline, and black
# outlier markers.
BOXPLOT_CONFIG = {
    "size": 25,
    "median": alt.MarkConfig(fill="black"),
    "box": alt.MarkConfig(stroke="black"),
    "ticks": alt.MarkConfig(size=10),
    "outliers": alt.MarkConfig(stroke="black", size=15, strokeWidth=1.5),
}

def configure_chart(chart: alt.Chart) -> alt.Chart:
    """Return ``chart`` with the notebook's shared display configuration.

    Removes the view border stroke, applies ``AXIS_CONFIG`` to the axes, sets
    the header label font to Arial, and sets the legend title and label font
    sizes to 10.
    """
    styled = chart.configure_view(strokeOpacity=0)
    styled = styled.configure_axis(**AXIS_CONFIG)
    styled = styled.configure_header(labelFont="arial")
    return styled.configure_legend(titleFontSize=10, labelFontSize=10)

In [None]:
seed = 123
n_pdmcs = 30
n_drugs = 60

# Zd: per-drug z-score of the response label.
# GDS: each PDMC's mean Zd across all of its screened drugs.
pdmc_obs = pdmc_ds.obs.drop(columns="id").copy()
pdmc_obs["Zd"] = pdmc_obs.groupby("drug_id")["label"].transform(stats.zscore)
pdmc_obs["GDS"] = pdmc_obs.groupby("cell_id")["Zd"].transform("mean")

# sample PDMCs: split observations into n_pdmcs GDS quantile bins and draw one
# PDMC per bin, so the sampled PDMCs span the GDS range. Bins are computed over
# rows (PDMC-drug pairs), so PDMCs with more screened drugs weigh more.
pdmc_obs["bin"] = pd.qcut(pdmc_obs["GDS"], n_pdmcs, labels=range(n_pdmcs)).astype(int)
sampled_cells = pdmc_obs.groupby("bin")["cell_id"].sample(1, random_state=seed)
pdmc_obs_sampled = pdmc_obs.query("cell_id in @sampled_cells")

# sample drugs from those screened in the sampled PDMCs
unique_drugs = pdmc_obs_sampled["drug_id"].drop_duplicates()
sampled_drugs = unique_drugs.sample(n_drugs, random_state=seed)
pdmc_obs_sampled = pdmc_obs_sampled.query("drug_id in @sampled_drugs")

pdmc_obs_sampled.head()

In [None]:
# Scatter of per-pair Zd against each sampled PDMC's GDS, colored by PDMC.
points = (
    alt.Chart(pdmc_obs_sampled)
    .mark_circle(size=50)
    .encode(
        alt.X("GDS:Q")
        .axis(format=".2", grid=False, tickCount=5)
        .scale()
        .title(["Global Drug Sensitivity", "(Mean Z-Score ln(IC50))"]),
        alt.Y("Zd:Q")
        .axis(grid=False, tickCount=5)
        .scale(domain=(-4, 4))
        .title("Z-Score ln(IC50)"),
        alt.Color("cell_id:N").sort("-x").legend(None),
    )
)

# Dashed regression line of Zd on GDS, drawn over GDS in [-1, 1].
reg_line = points.transform_regression("GDS", "Zd", extent=[-1, 1]).mark_line(
    stroke="black", strokeWidth=1.5, strokeDash=[3, 3], point=False
)

# The unconfigured layer is kept in pdmc_gds_chart for reuse in the combined
# figure; configure_chart is applied here only for display.
pdmc_gds_chart = (points + reg_line).properties(width=460, height=250)
configure_chart(pdmc_gds_chart)

In [None]:
# Correlation of Zd with GDS over all PDMC observations, not just the sampled
# subset. GDS is the mean of Zd, so each Zd also contributes to its own GDS.
stats.pearsonr(pdmc_obs["Zd"], pdmc_obs["GDS"])

## GDS Predicts Response to Individual Therapies in Cell Lines

In [None]:
seed = 123
n_cells = 30
n_drugs = 60

# Zd: per-drug z-score of the response label.
# GDS: each cell line's mean Zd across all of its screened drugs.
cell_obs = cell_ds.obs.drop(columns="id").copy()
cell_obs["Zd"] = cell_obs.groupby("drug_id")["label"].transform(stats.zscore)
cell_obs["GDS"] = cell_obs.groupby("cell_id")["Zd"].transform("mean")

# sample cell lines: split observations into n_cells GDS quantile bins and draw
# one cell line per bin, so the sampled lines span the GDS range. Bins are
# computed over rows (cell-drug pairs), so lines with more drugs weigh more.
cell_obs["bin"] = pd.qcut(cell_obs["GDS"], n_cells, labels=range(n_cells)).astype(int)
sampled_cells = cell_obs.groupby("bin")["cell_id"].sample(1, random_state=seed)
cell_obs_sampled = cell_obs.query("cell_id in @sampled_cells")

# sample drugs from those screened in the sampled cell lines
unique_drugs = cell_obs_sampled["drug_id"].drop_duplicates()
sampled_drugs = unique_drugs.sample(n_drugs, random_state=seed)
cell_obs_sampled = cell_obs_sampled.query("drug_id in @sampled_drugs")

cell_obs_sampled.head()

In [None]:
# Scatter of per-pair Zd against each sampled cell line's GDS, colored by line.
points = (
    alt.Chart(cell_obs_sampled)
    .mark_circle(size=40)
    .encode(
        alt.X("GDS:Q")
        .axis(format=".2", grid=False, tickCount=5)
        .scale(domain=(-1.5, 1.5))
        .title(["Global Drug Sensitivity", "(Mean Z-Score ln(IC50))"]),
        alt.Y("Zd:Q").axis(grid=False).scale(domain=(-5, 5)).title("Z-Score ln(IC50)"),
        alt.Color("cell_id:N").sort("-x").legend(None),
    )
)

# Dashed regression line of Zd on GDS, drawn over GDS in [-1.5, 1.5] to match
# the x-axis domain.
reg_line = points.transform_regression("GDS", "Zd", extent=[-1.5, 1.5]).mark_line(
    stroke="black", strokeWidth=1.5, strokeDash=[3, 3], point=False
)

# The unconfigured layer is kept in cell_gds_chart for reuse in the combined
# figure; configure_chart is applied here only for display.
cell_gds_chart = (points + reg_line).properties(width=460, height=250)
configure_chart(cell_gds_chart)

In [None]:
# Correlation of Zd with GDS over all cell-line observations, not just the
# sampled subset. GDS is the mean of Zd, so each Zd contributes to its own GDS.
stats.pearsonr(cell_obs["Zd"], cell_obs["GDS"])

## ScreenAhead With GDS vs All Drugs

In [None]:
def load_multirun_predictions(
    multirun_dir: str | Path, regex: str, splits: list[str] | None = None
) -> pd.DataFrame:
    """Loads predictions from a multirun."""
    if isinstance(multirun_dir, str):
        multirun_dir = Path(multirun_dir)

    file_list = multirun_dir.glob(regex)
    pred_df = pd.concat(map(pd.read_csv, file_list))

    if splits is not None:
        pred_df = pred_df[pred_df["split_group"].isin(splits)]

    return pred_df

In [None]:
def rescale_predictions(df: pd.DataFrame) -> pd.DataFrame:
    """Rescales labels and predictions per drug using train-set statistics.

    ``y_true`` and ``y_pred`` each get their own ``GroupStandardScaler``,
    grouped by ``drug_id``. Each scaler is fit on the rows with
    ``split_group == "train"`` and then applied to the rows with
    ``split_group == "test"`` (presumably per-drug z-scoring; see cdrpy).

    Returns the rescaled train rows followed by the rescaled test rows. Rows in
    any other split group are not included, and ``df`` itself is not modified.
    """
    # Explicit copies, so the column assignments below operate on independent
    # frames instead of raising SettingWithCopyWarning on a slice of ``df``.
    df_trn = df[df["split_group"] == "train"].copy()
    df_tst = df[df["split_group"] == "test"].copy()

    # Fit a fresh scaler per column so y_true and y_pred use their own stats.
    for col in ("y_true", "y_pred"):
        gss = GroupStandardScaler()
        df_trn[col] = gss.fit_transform(df_trn[[col]], groups=df_trn["drug_id"])
        df_tst[col] = gss.transform(df_tst[[col]], groups=df_tst["drug_id"])

    return pd.concat([df_trn, df_tst])

In [None]:
# Per-model prediction frames, keyed by display model name.
model_results: t.Dict[str, pd.DataFrame] = {}
output_dir = root / "outputs"
# Baseline output layout, formatted as (dataset, model, run date).
path_fmt = "basic/{0}/{1}/multiruns/{2}"
# Rename the baselines' "fold"/"split" columns to the split_id / split_group
# names that load_multirun_predictions and rescale_predictions read.
column_mapper = {"fold": "split_id", "split": "split_group"}
dataset = "CellModelPassportsGDSCv1v2Hallmark"

In [None]:
# HiDRA results

model = "HiDRA-legacy"
date = "2024-04-17_19-29-28"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

# Load train/test predictions from every fold and rescale each fold separately.
# The result is stored under the name without the "-legacy" suffix ("HiDRA").
model_results[model.split("-")[0]] = (
    load_multirun_predictions(run_dir, run_regex, splits=["train", "test"])
    .rename(columns=column_mapper)
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model.split("-")[0])
)

In [None]:
# DualGCN results

model = "DualGCN-legacy"
dates = [
    # NOTE: exceeded 72 hr maximum timelimit so folds are split over multiple runs
    "2024-04-12_09-31-07",
    "2024-04-12_09-32-20",
    "2024-04-14_08-02-47",
    "2024-04-14_08-03-56",
    "2024-04-15_16-17-18",
]

# Gather the train/test predictions from each partial run.
temp = []
for date in dates:
    run_dir = output_dir / path_fmt.format(dataset, model, date)
    run_regex = "*/predictions.csv"
    temp.append(load_multirun_predictions(run_dir, run_regex, splits=["train", "test"]))

# Combine the partial runs, rescale each fold, and store as "DualGCN".
model_results[model.split("-")[0]] = (
    pd.concat(temp)
    .rename(columns=column_mapper)
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model.split("-")[0])
)

In [None]:
# DeepCDR results

model = "DeepCDR-legacy"
date = "2024-04-02_09-27-37"

run_dir = output_dir / path_fmt.format(dataset, model, date)
run_regex = "*/predictions.csv"

# Load train/test predictions, rescale each fold, and store as "DeepCDR".
model_results[model.split("-")[0]] = (
    load_multirun_predictions(run_dir, run_regex, splits=["train", "test"])
    .rename(columns=column_mapper)
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model.split("-")[0])
)

In [None]:
model = "ScreenDL"
date = "2024-04-18_17-35-37"

# ScreenDL runs use a different output layout from the baselines. This uses a
# separate name rather than overwriting the shared `path_fmt`, so re-running
# the baseline cells still resolves their directories correctly.
sa_path_fmt = "experiments/sa_mr/{0}/{1}/multiruns/{2}"
run_dir = output_dir / sa_path_fmt.format(dataset, model, date)

# No column_mapper rename here: these files presumably already use
# split_id / split_group (the groupby below relies on split_id).
model_results[model + "-PT"] = (
    load_multirun_predictions(run_dir, "*/predictions.csv", splits=["train", "test"])
    .groupby("split_id", as_index=False)
    .apply(rescale_predictions)
    .assign(model=model + "-PT")
)

In [None]:
# Reuse the ScreenDL-PT train rows as train reference rows for each ScreenAhead
# model, so rescale_predictions has train rows to fit on for every SA group.
# (Presumably predictions_sa.csv contains no train rows of its own — confirm.)
# NOTE(review): these rows come from model_results["ScreenDL-PT"], which has
# already been through rescale_predictions. The SA test rows below are
# therefore rescaled with statistics fit on already-rescaled train values.
# Confirm this is intended.
temp_ = model_results[model + "-PT"].query("split_group == 'train'")
temp_ = pd.concat(
    [
        temp_.assign(model="ScreenDL-SA"),
        temp_.assign(model="ScreenDL-SA (GDS)"),
    ]
)

# Map the SA model labels to the display names used in MODELS. Labels not in
# name_map (including "ScreenDL-SA (GDS)" from temp_) pass through unchanged.
name_map = {"ScreenDL-SA(MR)": "ScreenDL-SA (GDS)", "ScreenDL-SA": "ScreenDL-SA (Zd)"}
fix_names = lambda x: name_map.get(x, x)
model_results[model + "-SA"] = (
    # splits=None: keep every split group present in predictions_sa.csv.
    load_multirun_predictions(run_dir, "*/predictions_sa.csv", splits=None)
    .pipe(lambda df: pd.concat([df, temp_]))
    .assign(model=lambda df: df["model"].map(fix_names))
    .groupby(["model", "split_id"], as_index=False)
    .apply(rescale_predictions)
)

In [None]:
# Stack every model's predictions and keep only the held-out test rows.
model_results_df = pd.concat(model_results.values())
model_results_df_tst = model_results_df.query("split_group == 'test'")
model_results_df_tst.head()

In [None]:
# Per-model, per-drug correlation (eval_utils.pcorr) between y_true and y_pred
# on the test rows, pooled across folds (split_id is not a grouping key).
pcc_metrics = (
    model_results_df_tst.groupby(["model", "drug_id"])
    .apply(lambda g: eval_utils.pcorr(g, "y_true", "y_pred"))
    .to_frame(name="pcc")
    .reset_index()
)

# Summary statistics of the per-drug correlations, in MODELS order.
pcc_metrics.groupby("model")["pcc"].describe().loc[MODELS]

In [None]:
# Boxplot of per-drug PCC for the three ScreenDL variants (MODELS[-3:]).
# X order: ScreenDL-PT, ScreenDL-SA (GDS), ScreenDL-SA (Zd).
pcc_boxplot = (
    alt.Chart(pcc_metrics, width=30 * len(MODELS[-3:]), height=250)
    .transform_filter(alt.FieldOneOfPredicate("model", MODELS[-3:]))
    .mark_boxplot(**BOXPLOT_CONFIG)
    .encode(
        alt.X("model:N")
        .axis(labelAngle=-45, labelPadding=5)
        .sort([MODELS[-3], MODELS[-1], MODELS[-2]])
        .title(None),
        alt.Y("pcc:Q")
        .axis(titlePadding=10, tickCount=6, grid=False)
        .scale(domain=[-0.2, 1])
        .title("Pearson Correlation"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(None),
    )
)

configure_chart(pcc_boxplot)

In [None]:
# Reference table of cell line-drug pairs, taken from cell_obs: per-drug
# z-scored response (Zd) and the cell line's global drug sensitivity (GDS).
X = cell_obs[["cell_id", "drug_id", "Zd", "GDS"]].copy()

# compute absolute error for GDS, i.e. the error of a "GDS-only" model that
# predicts Zd with the cell line's GDS. Then split the pairs into 30
# equal-frequency (quantile) bins of that error.
X["GDS_ae"] = (X["GDS"] - X["Zd"]).abs()
X["GDS_ae_bin"] = pd.qcut(X["GDS_ae"], 30)
X.head()

- The expected result here is that, as the expected MAE of the GDS-only model increases (i.e., we move along the x-axis), we should see the observed MAE from each model cross the GDS-only line.
- This shows that ScreenAhead improves performance even for cell line-drug pairs for which GDS is not very predictive of drug response i.e., those cell line-drug pairs for which we observe exceptional sensitivity or resistance.
- In addition, the observation that the improvement from ScreenAhead is consistent across the plot from left to right suggests that ScreenAhead uses more information than GDS alone. If ScreenAhead only used GDS, we would expect its improvement to decrease as we move from left to right.

In [None]:
# For each model and GDS-error bin: the model's mean absolute error against Zd,
# next to the bin's mean GDS-only error. Errors are measured against Zd from X,
# not against the per-fold rescaled y_true.
binned_mae = (
    model_results_df_tst[["model", "cell_id", "drug_id", "y_pred"]]
    .merge(X, on=["cell_id", "drug_id"])
    .assign(model_ae=lambda df: (df["y_pred"] - df["Zd"]).abs())
    .groupby(["model", "GDS_ae_bin"])
    .aggregate({"GDS_ae": "mean", "model_ae": "mean"})
    .reset_index()
)

In [None]:
# One point per (model, bin): GDS-only MAE on x, model MAE on y. The interval
# bin column is dropped before charting.
chart = (
    alt.Chart(binned_mae.drop(columns=["GDS_ae_bin"]))
    .mark_circle(size=50)
    .encode(
        alt.X("GDS_ae:Q")
        .axis(grid=False, titlePadding=10, values=[0, 0.5, 1, 1.5, 2, 2.5])
        .scale(domain=(0, 2.5))
        .title("Expected Mean Absolute Error (GDS Only Model)"),
        alt.Y("model_ae:Q")
        .axis(grid=False, titlePadding=10, values=[0, 0.5, 1, 1.5, 2, 2.5])
        .scale(domain=(0, 2.5))
        .title("Mean Absolute Error"),
        alt.Color("model:N", scale=MODEL_COLOR_SCALE)
        .legend(orient="none", legendX=50, legendY=-30, direction="horizontal", columns=3)
        .title(None),
    )
    .properties(width=420, height=250)
)

# Dashed identity line y = x. Points below it have lower MAE than the GDS-only
# model in that bin.
line = (
    alt.Chart(pd.DataFrame({"x": [0, 2.5], "y": [0, 2.5]}))
    .mark_line(stroke="black", strokeDash=[3, 3], strokeWidth=1)
    .encode(alt.X("x:Q"), alt.Y("y:Q"))
)

mae_chart = line + chart

configure_chart(mae_chart)

In [None]:
def _delta_ae_vs_pt(
    preds: pd.DataFrame,
    features: pd.DataFrame,
    sa_model: str,
    other_sa_model: str,
) -> pd.DataFrame:
    """Per (cell_id, drug_id) change in absolute error of ``sa_model`` vs. ScreenDL-PT.

    ``delta_ae`` is ``|y_pred - y_true|`` of ``sa_model`` minus that of
    ``ScreenDL-PT``, so negative values mean ``sa_model`` is closer to
    ``y_true``. The ``other_sa_model`` column is dropped. The result is merged
    with ``features`` on ``cell_id``/``drug_id`` and tagged with
    ``model=sa_model``, and any row with a missing value is removed.
    """
    return (
        preds[["model", "cell_id", "drug_id", "y_pred", "y_true"]]
        .query("model.str.startswith('ScreenDL')")
        .assign(model_ae=lambda df: (df["y_pred"] - df["y_true"]).abs())
        # one column of absolute errors per ScreenDL model
        .set_index(["cell_id", "drug_id", "model"])["model_ae"]
        .unstack()
        .assign(delta_ae=lambda df: df[sa_model] - df["ScreenDL-PT"])
        .drop(columns=[other_sa_model])
        .reset_index()
        .merge(features, on=["cell_id", "drug_id"])
        .assign(model=sa_model)
        .dropna()
    )


delta_ae_sa = _delta_ae_vs_pt(
    model_results_df_tst,
    X,
    sa_model="ScreenDL-SA (Zd)",
    other_sa_model="ScreenDL-SA (GDS)",
)

delta_ae_sa_gds = _delta_ae_vs_pt(
    model_results_df_tst,
    X,
    sa_model="ScreenDL-SA (GDS)",
    other_sa_model="ScreenDL-SA (Zd)",
)

In [None]:
def agg(df: pd.DataFrame) -> pd.Series:
    """Aggregators for computing mean and CI.

    Summarizes one GDS-error bin: the mean GDS-only absolute error
    (``GDS_ae``), and the mean change in absolute error (``delta_ae``) with a
    two-sided 95% Student-t confidence interval around that mean.

    Parameters
    ----------
    df : pd.DataFrame
        Rows of a single bin; must contain ``GDS_ae`` and ``delta_ae``.

    Returns
    -------
    pd.Series
        ``GDS_ae_mean``, ``delta_ae_mean``, ``delta_ae_mean_lower`` and
        ``delta_ae_mean_upper``.
    """
    delta_ae = df["delta_ae"]
    delta_ae_mean = delta_ae.mean()
    # The confidence level is passed positionally. SciPy renamed this argument
    # from `alpha` to `confidence` in 1.9 and removed `alpha` in 1.11, so the
    # positional form works on both old and new releases.
    delta_ae_lower, delta_ae_upper = stats.t.interval(
        0.95,
        df=len(df) - 1,
        loc=delta_ae_mean,
        scale=stats.sem(delta_ae),
    )
    return pd.Series(
        {
            "GDS_ae_mean": df["GDS_ae"].mean(),
            "delta_ae_mean": delta_ae_mean,
            "delta_ae_mean_lower": delta_ae_lower,
            "delta_ae_mean_upper": delta_ae_upper,
        }
    )


# Summarize each GDS-error bin: mean GDS-only error, plus mean delta_ae and its
# 95% CI. GDS_ae_bin is categorical (from pd.qcut). observed=False matches the
# current pandas default, so every bin is kept; passing it explicitly silences
# the pandas >= 2.1 FutureWarning about that default changing.
delta_ae_sa_agg = (
    delta_ae_sa.groupby("GDS_ae_bin", observed=False)
    .apply(agg)
    .reset_index(drop=True)
    # Label must match the MODELS / MODEL_COLOR_SCALE domain.
    .assign(model="ScreenDL-SA (Zd)")
)

delta_ae_sa_gds_agg = (
    delta_ae_sa_gds.groupby("GDS_ae_bin", observed=False)
    .apply(agg)
    .reset_index(drop=True)
    .assign(model="ScreenDL-SA (GDS)")
)

delta_ae_agg = pd.concat([delta_ae_sa_agg, delta_ae_sa_gds_agg])

- The expected result is that ScreenAhead improves performance the most where GDS is already a good predictor. Where GDS is less predictive, performance should still improve, but by a smaller amount that levels off to a roughly constant value. This constant probably reflects the rate of positive transfer across drugs, so it would be interesting to compare it across different drug selection methods.

In [None]:
# Shared x encoding: the bin's mean GDS-only absolute error.
base = (
    alt.Chart(delta_ae_sa_agg)
    .encode(
        alt.X("GDS_ae_mean:Q")
        .axis(grid=False, titlePadding=10, values=[0, 0.5, 1, 1.5, 2])
        .scale(nice=True)
        .title("Expected Mean Absolute Error (GDS Only Model)")
    )
    .properties(width=300, height=250)
)

# Line + points: mean change in absolute error of ScreenDL-SA (Zd) vs.
# ScreenDL-PT per bin. Negative values mean ScreenAhead reduced the error.
points = base.mark_line(
    color="gray",
    opacity=1,
    # interpolate="basis",
    point=alt.MarkConfig(size=50, color="gray"),
).encode(
    alt.Y("delta_ae_mean:Q")
    .axis(grid=False, titlePadding=10, tickCount=5)
    .scale(zero=True)
    .title(["Difference in Absolute Error", "ScreenDL-SA (Zd) vs. ScreenDL-PT"]),
    # alt.Color("model:N", scale=MODEL_COLOR_SCALE).legend(values=MODELS[-2:])
)

# Shaded band: the 95% CI of the per-bin mean computed by agg.
interval = base.mark_area(color="lightgray", opacity=0.5).encode(
    alt.Y("delta_ae_mean_lower:Q")
    .axis(grid=False, titlePadding=10, tickCount=5)
    .scale(zero=False),
    alt.Y2("delta_ae_mean_upper:Q"),
)

delta_mae_chart = interval + points
configure_chart(delta_mae_chart)

In [None]:
# Combined figure. Top row: GDS vs. Zd scatter plots for cell lines and PDMCs.
# Bottom row: PCC boxplot, binned-MAE scatter, and delta-AE curve. Color scales
# are resolved independently so each subplot keeps its own color mapping.
upper_panel = alt.hconcat(cell_gds_chart, pdmc_gds_chart, spacing=40)
upper_panel = upper_panel.resolve_scale(color="independent")
lower_panel = alt.hconcat(pcc_boxplot, mae_chart, delta_mae_chart, spacing=40)
lower_panel = lower_panel.resolve_scale(color="independent")
configure_chart(alt.vconcat(upper_panel, lower_panel, spacing=40))