### Setup
Always run these cells before everything else

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Forward facecolor "w" (white) to the inline backend when it prints figures.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# Load the autoreload extension and select mode 2, so that imported modules
# (e.g. the project's utils package) are reloaded before code is executed.
%load_ext autoreload
%autoreload 2

In [None]:
# SETUP AND IMPORTS
import sys

import os

# Root of the graphdep project. Data, outputs and figures in this notebook are
# all addressed relative to this directory. The GRAPHDEP_WD environment
# variable overrides the default so the notebook can run on another machine
# without editing this cell; when it is unset the original path is used.
wd_path = os.environ.get("GRAPHDEP_WD", "/home/dk538/rds/hpc-work/graphdep")

# Make the project's packages importable. The membership check keeps sys.path
# from growing by one duplicate entry each time the setup cell is re-run.
if wd_path not in sys.path:
    sys.path.append(wd_path)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from utils.data_utils import Manual, get_data_loaders, process_data

# from models.baselines import SingleGeneLasso, SingleGeneLinear, SingleGeneSVR
from functools import partial
from numpy.linalg import norm
from utils.comptools import rep_renamer
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from scipy.stats import spearmanr

import matplotlib.font_manager as fm
import urllib.request

from utils.comp_utils import PerfComp

# Download the Source Sans 3 font files (regular, bold, italic).
# A file that is already present and non-empty is not fetched again, so
# re-running the setup cell needs no network access after the first download.
import os

font_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-Regular.ttf?raw=True"
font_path = f"{wd_path}/results_analysis/figures/SourceSans3-Regular.ttf"  # Specify where to save the font
font_bold_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-Bold.ttf?raw=True"
font_bold_path = f"{wd_path}/results_analysis/figures/SourceSans3-Bold.ttf"  # Specify where to save the font
font_it_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-It.ttf?raw=True"
font_it_path = f"{wd_path}/results_analysis/figures/SourceSans3-It.ttf"  # Specify where to save the font

# All three files are saved into the same figures directory; create it if needed.
os.makedirs(os.path.dirname(font_path), exist_ok=True)

for _font_src, _font_dst in (
    (font_url, font_path),
    (font_bold_url, font_bold_path),
    (font_it_url, font_it_path),
):
    # A zero-byte file is treated as missing and is downloaded again.
    if not os.path.isfile(_font_dst) or os.path.getsize(_font_dst) == 0:
        urllib.request.urlretrieve(_font_src, _font_dst)

# in a terminal, run
# cp ~/rds/hpc-work/graphdep/results_analysis/figures/*ttf ~/.local/share/fonts
# fc-cache -f -v
# rm -fr ~/.cache/matplotlib

# Then restart Jupyter kernel

# Look up "Source Sans 3" through matplotlib's font manager (the return value
# is discarded). See the terminal steps above for installing the font files.
fm.findfont("Source Sans 3", rebuild_if_missing=True)
# fm.findfont("Source Sans 3:style=italic", rebuild_if_missing=True)

# Set font globally for Matplotlib
from matplotlib import rc

# Order matters below: start from matplotlib's default style, then apply the
# seaborn theme, then override the font settings on top of both.
plt.style.use("default")
sns.set_theme(
    context="paper",
    style="ticks",
    palette="colorblind",
    rc={
        # 1-wide axes spines and major ticks, drawn in grey ...
        "axes.linewidth": 1,
        "xtick.major.width": 1,
        "ytick.major.width": 1,
        "axes.edgecolor": "grey",
        # ... while the tick labels themselves stay black.
        "xtick.labelcolor": "black",
        "xtick.color": "grey",
        "ytick.labelcolor": "black",
        "ytick.color": "grey",
    },
)

# Use Source Sans 3 as the sans-serif family for all text.
rc("font", **{"family": "sans-serif", "sans-serif": ["Source Sans 3"]})
# Use a custom mathtext font set whose italic face is Source Sans 3 italic.
plt.rcParams["mathtext.fontset"] = "custom"
plt.rcParams["mathtext.it"] = "Source Sans 3:italic"


def type_renamer(x):
    """Return ``x`` with every space and every forward slash replaced by ``_``.

    Example: ``"Bladder/Urinary Tract"`` becomes ``"Bladder_Urinary_Tract"``.
    """
    return x.replace(" ", "_").replace("/", "_")


def get_similar_drugs(drug, drug_df, metric="cosine"):
    """Rank every column of ``drug_df`` by its similarity to column ``drug``.

    Each column (including ``drug`` itself) is paired with the ``drug`` column,
    rows where either value is NaN are dropped for that pair, and the
    similarity is computed on the remaining rows.

    Parameters
    ----------
    drug : hashable
        Label of the reference column in ``drug_df``.
    drug_df : pandas.DataFrame
        Frame whose columns are compared (presumably one column per drug --
        confirm against the caller).
    metric : {"cosine", "pearson"}, default "cosine"
        ``"cosine"`` uses the cosine of the angle between the two columns;
        ``"pearson"`` uses the Pearson correlation coefficient.

    Returns
    -------
    pandas.DataFrame
        Columns ``drugA`` (always ``drug``), ``drugB`` (the compared column)
        and ``sim``, sorted by ``sim`` in descending order.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported values. Previously an
        unknown metric produced no rows and failed later with an unrelated
        ``KeyError`` while sorting.
    """
    if metric not in ("cosine", "pearson"):
        raise ValueError(
            f"Unsupported metric {metric!r}; expected 'cosine' or 'pearson'."
        )

    sim_dict_list = []
    for col in drug_df.columns:
        # Pairwise-complete observations only: drop rows missing in either column.
        test_df = pd.DataFrame({drug: drug_df[drug], col: drug_df[col]})
        test_df = test_df.dropna(axis=0)
        if metric == "cosine":
            sim = np.dot(test_df[drug], test_df[col]) / (
                norm(test_df[drug]) * norm(test_df[col])
            )
        else:
            # The p-value is not needed for the ranking.
            sim, _ = pearsonr(test_df[drug], test_df[col])
        sim_dict_list.append({"drugA": drug, "drugB": col, "sim": sim})

    # Naming the columns explicitly keeps the result well-formed (and sortable)
    # even when there was nothing to compare.
    sim_df = pd.DataFrame(sim_dict_list, columns=["drugA", "drugB", "sim"])
    return sim_df.sort_values(by="sim", ascending=False)

# 1 iCoVAE representations

## 1.0 Data loading and summaries

### 1.0.1 Data loading

In [None]:
# Dataset identifier; later cells also embed it in figure and output paths.
dataset_name = "depmap_gdsc"

# PROCESS DATASET
# process_data (project helper) returns five objects; x, s, y and test_samples
# are used below and in later cells, c is not used in this cell.
(x, s, c, y, test_samples) = process_data(
    dataset=dataset_name,
    wd_path=wd_path,
    experiment="h16",
)

# Options forwarded to Manual via `params`.
# NOTE(review): the names suggest variance filters on x and s -- confirm in
# utils.data_utils.
dataset_params = dict(var_filt_x=1500, var_filt_s=None)

# Dataset object used for the sample summaries below: Lapatinib as target with
# EGFR as the only constraint.
dataset = Manual(
    x=x,
    s=s,
    y=y,
    constraints=["EGFR"],
    target="LAPATINIB",
    params=dataset_params,
)

In [None]:
# Print the size of each sample group held by the dataset, one number per
# line, followed by the number of held-out test samples.
for _group in ("x_s_y_samples", "x_s_samples", "x_y_samples", "x_only_samples"):
    print(len(getattr(dataset, _group)))
print(len(test_samples))

In [None]:
# Check for breast cancer drivers in DepMap knockouts
## For making 'Drivers' representation
# breast_drivers = pd.read_excel("../data/breast_cancer_drivers.xlsx", header=None)[
#     0
# ].tolist()

# breast_drivers_filt = set(breast_drivers).intersection(set(eff_df.columns))

# # Difference is in MLLT4, which has since been renamed to AFDN, so we add this back in
# print(set(breast_drivers).difference(breast_drivers_filt))

# breast_drivers_filt = sorted(list(breast_drivers_filt) + ["AFDN"])

# np.savetxt("../data/breast_cancer_drivers_filt.txt", breast_drivers_filt, fmt="%s")

In [None]:
# np.loadtxt("../data/breast_cancer_drivers_filt.txt", dtype="str").tolist()

### 1.0.2 Sample summaries

In [None]:
def type_renamer(x):
    """Replace each space and each forward slash in ``x`` with an underscore.

    NOTE(review): this re-defines the identical helper from the setup cell;
    the later definition silently shadows the earlier one.
    """
    return x.translate(str.maketrans({" ": "_", "/": "_"}))

In [None]:
# Cell-line annotation table, indexed by ModelID. Lineage names are normalised
# with type_renamer (spaces and slashes become underscores) so they match the
# keys of `lineage_labels` below.
model = pd.read_csv(f"{wd_path}/data/depmap23q2/Model.csv")

model["OncotreeLineage"] = model["OncotreeLineage"].map(type_renamer)

model = model.set_index("ModelID")

# Display label for each normalised lineage name, used on plot axes.
# (The original literal listed "Bladder_Urinary_Tract" twice with the same
# value; the redundant entry has been dropped, the mapping is unchanged.)
# NOTE(review): the representation-loop cell further down rebinds
# `lineage_labels` (to a dict keyed by the un-normalised names, e.g.
# "Peripheral Nervous System") and `pal`; re-run this cell before re-running
# the lineage bar plots after that loop.
lineage_labels = {
    "Bone": "Bone",
    "Peripheral_Nervous_System": "PNS",
    "Liver": "Liver",
    "Cervix": "Cervix",
    "Kidney": "Kidney",
    "Bladder_Urinary_Tract": "Bladder/UT",
    "Prostate": "Prostate",
    "Lung": "Lung",
    "Lymphoid": "Lymphoid",
    "CNS_Brain": "CNS/Brain",
    "Skin": "Skin",
    "Esophagus_Stomach": "Esophagus/Stomach",
    "Bowel": "Bowel",
    "Ovary_Fallopian_Tube": "Ovary/Fallopian Tube",
    "Head_and_Neck": "Head and Neck",
    "Pancreas": "Pancreas",
    "Breast": "Breast",
    "Soft_Tissue": "Soft Tissue",
    "Biliary_Tract": "Biliary Tract",
    "Eye": "Eye",
    "Thyroid": "Thyroid",
    "Ampulla_of_Vater": "Ampulla of Vater",
    "Testis": "Testis",
    "Vulva_Vagina": "Vulva/Vagina",
    "Myeloid": "Myeloid",
    "Uterus": "Uterus",
    "Fibroblast": "Fibroblast",
    "Pleura": "Pleura",
    "Adrenal_Gland": "Adrenal Gland",
}

# Base colour palette; the lineage bar plots index into it.
pal = sns.color_palette("colorblind")

In [None]:
# Number of (x,s,y) samples per lineage, most frequent lineage first.
x_s_y_lineage = (
    model.loc[dataset.x_s_y_samples][["OncotreeLineage"]]
    .groupby("OncotreeLineage")
    .size()
    .reset_index(name="count")
    .sort_values("count", ascending=False)
)
print(x_s_y_lineage)

f, ax = plt.subplots(1, 1, figsize=(3, 4))

# Plot a copy that carries display labels; a lineage that is missing from
# `lineage_labels` raises KeyError here.
x_s_y_lineage_plot = x_s_y_lineage.copy()
x_s_y_lineage_plot["OncotreeLineage"] = x_s_y_lineage_plot["OncotreeLineage"].map(
    lambda lineage: lineage_labels[lineage]
)

# Horizontal bar chart drawn on the figure created above (the current axes).
sns.barplot(data=x_s_y_lineage_plot, y="OncotreeLineage", x="count", color=pal[0])
sns.despine()
ax.set(xlabel="Count", ylabel="", title="(x,s,y) samples")

# Save a vector (SVG) and a 600 dpi raster (PNG) version of the figure.
f.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_s_y_lineages_new.svg",
    bbox_inches="tight",
)
f.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_s_y_lineages_new.png",
    bbox_inches="tight",
    dpi=600,
)

In [None]:
# Number of (x,s) samples per lineage, most frequent lineage first.
x_s_lineage = model.loc[dataset.x_s_samples][["OncotreeLineage"]]
x_s_lineage = x_s_lineage.groupby("OncotreeLineage").size().reset_index(name="count")
x_s_lineage = x_s_lineage.sort_values("count", ascending=False)
print(x_s_lineage)
f, ax = plt.subplots(1, 1, figsize=(3, 4))
# Plot a copy that carries display labels; a lineage that is missing from
# `lineage_labels` raises KeyError here.
x_s_lineage_plot = x_s_lineage.copy()
x_s_lineage_plot["OncotreeLineage"] = x_s_lineage_plot["OncotreeLineage"].map(
    lambda x: lineage_labels[x]
)
# Draw explicitly on the axes created above rather than on the implicit
# current axes.
sns.barplot(
    data=x_s_lineage_plot, y="OncotreeLineage", x="count", color=pal[1], ax=ax
)
sns.despine()
ax.set_xlabel("Count")
ax.set_ylabel("")
ax.set_title("(x,s) samples")
# The SVG used to be written to "..._x_s_y_lineages_new.svg", which overwrote
# the (x,s,y) figure; it now uses the same stem as the PNG below.
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_s_lineages_new.svg",
    bbox_inches="tight",
)
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_s_lineages_new.png",
    bbox_inches="tight",
    dpi=600,
)

In [None]:
# Number of (x,y) samples per lineage, most frequent lineage first.
x_y_lineage = model.loc[dataset.x_y_samples][["OncotreeLineage"]]
x_y_lineage = x_y_lineage.groupby("OncotreeLineage").size().reset_index(name="count")
x_y_lineage = x_y_lineage.sort_values("count", ascending=False)
print(x_y_lineage)
f, ax = plt.subplots(1, 1, figsize=(3, 4))
# Plot a copy that carries display labels; a lineage that is missing from
# `lineage_labels` raises KeyError here.
x_y_lineage_plot = x_y_lineage.copy()
x_y_lineage_plot["OncotreeLineage"] = x_y_lineage_plot["OncotreeLineage"].map(
    lambda x: lineage_labels[x]
)
# Draw explicitly on the axes created above rather than on the implicit
# current axes.
sns.barplot(
    data=x_y_lineage_plot, y="OncotreeLineage", x="count", color=pal[2], ax=ax
)
sns.despine()
ax.set_xlabel("Count")
ax.set_ylabel("")
ax.set_title("(x,y) samples")
# The SVG used to be written to "..._x_s_y_lineages_new.svg", which overwrote
# the (x,s,y) figure; it now uses the same stem as the PNG below.
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_y_lineages_new.svg",
    bbox_inches="tight",
)
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_y_lineages_new.png",
    bbox_inches="tight",
    dpi=600,
)

In [None]:
# Number of x-only samples per lineage, most frequent lineage first.
x_only_lineage = model.loc[dataset.x_only_samples][["OncotreeLineage"]]
x_only_lineage = (
    x_only_lineage.groupby("OncotreeLineage").size().reset_index(name="count")
)
x_only_lineage = x_only_lineage.sort_values("count", ascending=False)
print(x_only_lineage)
f, ax = plt.subplots(1, 1, figsize=(3, 4))
# Plot a copy that carries display labels; a lineage that is missing from
# `lineage_labels` raises KeyError here.
x_only_lineage_plot = x_only_lineage.copy()
x_only_lineage_plot["OncotreeLineage"] = x_only_lineage_plot["OncotreeLineage"].map(
    lambda x: lineage_labels[x]
)
# Draw explicitly on the axes created above rather than on the implicit
# current axes.
sns.barplot(
    data=x_only_lineage_plot, y="OncotreeLineage", x="count", color=pal[3], ax=ax
)
sns.despine()
ax.set_xlabel("Count")
ax.set_ylabel("")
ax.set_title("(x) samples")
# The SVG used to be written to "..._x_s_y_lineages_new.svg", which overwrote
# the (x,s,y) figure; it now uses the same stem as the PNG below.
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_only_lineages_new.svg",
    bbox_inches="tight",
)
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_x_only_lineages_new.png",
    bbox_inches="tight",
    dpi=600,
)

## 1.1 Evaluating representations

### 1.1.1 Representation PCA

In [None]:
# Drug name -> target / drug-class label.
# The original literal listed "5-Fluorouracil" and "PD173074" twice each (with
# identical values); the repeated entries are removed, so the resulting dict
# and the derived `drugs` list are unchanged.
# NOTE(review): "Vinca alkyloid" looks like a misspelling of "Vinca alkaloid".
# It is left as-is because code outside this cell may match on the exact string.
target_dict = {
    "AZD5991": "MCL1/BCL2",
    "Alpelisib": "PI3K",
    "AZD8186": "PI3K",
    "Gefitinib": "EGFR/ERBB2",
    "Lapatinib": "EGFR/ERBB2",
    "Sorafenib": "VEGFR",
    "Docetaxel": "Taxane",
    "Paclitaxel": "Taxane",
    "Taselisib": "PI3K",
    "AZD6482": "PI3K",
    "Palbociclib": "CDK4/6",
    "AZD3759": "EGFR/ERBB2",
    "Afatinib": "EGFR/ERBB2",
    "Afuresertib": "AKT",
    "Serdemetan": "MDM2",
    "Oxaliplatin": "Platinum",
    "GSK1904529A": "IGF1R",
    "Buparlisib": "PI3K",
    "Linsitinib": "IGF1R",
    "Ipatasertib": "AKT",
    "CZC24832": "PI3K",
    "Sabutoclax": "MCL1/BCL2",
    "MK-8776": "CHEK1",
    "Ribociclib": "CDK4/6",
    "Cisplatin": "Platinum",
    "Osimertinib": "EGFR/ERBB2",
    "Erlotinib": "EGFR/ERBB2",
    "AZD6738": "ATR",
    "Olaparib": "PARP",
    "Niraparib": "PARP",
    "Veliparib": "PARP",
    "MK-1775": "WEE1",
    "Cyclophosphamide": "Alkylating agent",
    "5-Fluorouracil": "Antimetabolite",
    "Epirubicin": "Anthracycline",
    "Tamoxifen": "ER",
    "Methotrexate": "Antimetabolite",
    "Venetoclax": "MCL1/BCL2",
    "AZD5153": "BRD4",
    "JQ1": "BRD4",
    "PD173074": "FGFR",
    "Sapitinib": "EGFR/ERBB2",
    "AZD4547": "FGFR",
    "Vorinostat": "HDAC",
    "Refametinib": "MEK",
    "Selumetinib": "MEK",
    "Trametinib": "MEK",
    "Axitinib": "VEGFR",
    "GSK2830371A": "PPM1D",
    "CCT007093": "PPM1D",
    "Gemcitabine": "Antimetabolite",
    "Irinotecan": "TOP1",
    "VE-822": "ATR",
    "Crizotinib": "ALK/ROS1",
    "Cytarabine": "Antimetabolite",
    "Entinostat": "HDAC",
    "Foretinib": "VEGFR",
    "Fulvestrant": "ER",
    "Motesanib": "VEGFR",
    "Navitoclax": "MCL1/BCL2",
    "Pyridostatin": "G4",
    "Rapamycin": "mTOR",
    "Temsirolimus": "mTOR",
    "Tanespimycin": "HSP90",
    "Uprosertib": "AKT",
    "AZD5363": "AKT",
    "Dabrafenib": "BRAF",
    "Temozolomide": "Alkylating agent",
    "Vinblastine": "Vinca alkyloid",
    "Vinorelbine": "Vinca alkyloid",
}


# FULL DRUG LIST FOR H16
# Alphabetically sorted drug names (the keys of target_dict).
drugs = sorted(target_dict.keys())

In [None]:
## PLOT REPRESENTATIONS PCA

from matplotlib.lines import Line2D
import json
from copy import deepcopy
from tqdm import tqdm

# Which 2-D embedding the loop below computes for plotted drugs; the loop
# branches on "TSNE", "UMAP" and "PCA".
rep_type = "TSNE"
# Drugs for which embeddings are computed inside the loop.
# NOTE(review): these names are upper-case, while the loop variable comes from
# `drugs` (mixed-case keys such as "Lapatinib"); `drug in drugs_plot` is
# case-sensitive, so as written only "AZD6738" can match -- confirm intent.
drugs_plot = ["AZD6738", "LAPATINIB", "PALBOCICLIB"]
# Experiment tag; part of the output folder path read by the loop.
expi_type = "h16"
# Genes that receive their own colour in `bar_pal`; all others are grey.
color_genes = ["ATR", "EGFR", "ERBB2"]
# Seeds whose saved outputs are read: 10, 20, ..., 100.
seeds = list(range(10, 101, 10))
# seeds = [20]

# Load representations for each drug
# Accumulators filled by the loop below (each is an independent empty list).
p_list, eff_pred_list, sil_list_test, sil_list_train, knn_acc_list = (
    [],
    [],
    [],
    [],
    [],
)
for drug in tqdm(drugs):
    # print(drug)
    # ccvae_folder = f"R:/graphdep/data/outputs/ElasticNet_{drug.upper()}_s{seed}_ccvae_{drug.upper()}_{dataset}_{target}_s{seed}_{expi_type}"
    # vae_folder = f"R:/graphdep/data/outputs/ElasticNet_{drug.upper()}_s{seed}_vae_{dataset}_{target}_s{seed}_{expi_type}"
    ccvae_folder = f"{wd_path}/data/outputs/{dataset_name}/{drug.upper()}_new/{expi_type}/pico/ElasticNet_icovae"
    vae_folder = f"{wd_path}/data/outputs/{dataset_name}/{drug.upper()}_new/{expi_type}/pico/ElasticNet_vae"
    try:
        with open(f"{ccvae_folder}/args_best_s10.txt") as f:
            args = json.load(f)
        # with torch.no_grad():
        #     reg_model = torch.load(f"{ccvae_folder}/regressor_s{seed}.pt", map_location=torch.device("cpu"))
        #     print(reg_model)
        #     reg_weights = reg_model.reg.mean_coeff.numpy()
        #     reg_bias = reg_model.reg.mean_bias.numpy()
        # print(reg_weights)
        # print(reg_bias)
        with open(f"{vae_folder}/args_best_s10.txt") as f:
            vae_args = json.load(f)
    except:
        print(f"{drug} not found...")
        continue

    # resp = resp.dropna(axis=0).reset_index(drop=True)
    # zdim = args["zdim"]
    # vae_zdim = vae_args["zdim"]
    genes = args["constraints"]
    # genes = ast.literal_eval(genes)
    genes = [gene.strip() for gene in genes]

    # print(" ".join(genes))
    n_genes = len(genes)
    dataset = Manual(
        x=x, s=s, y=y, constraints=genes, target=drug, params=dataset_params
    )

    for seed in seeds:
        # Consider constrained part of CCVAE and unconstrained separately
        ccvae_test_z = pd.read_csv(f"{ccvae_folder}/z_pred_test_s{seed}.csv")
        ccvae_train_z = pd.read_csv(f"{ccvae_folder}/z_pred_train_s{seed}.csv")

        # Consider whole VAE rep
        vae_test_z = pd.read_csv(f"{vae_folder}/z_pred_test_s{seed}.csv")
        vae_train_z = pd.read_csv(f"{vae_folder}/z_pred_train_s{seed}.csv")

        # Get x after variance filtering
        data_loaders = get_data_loaders(
            dataset,
            test_samples,
            32,
            fold=0,
            seed=seed,
            val_split=0.2,
            stage="p",
            save_folder=None,
            hopt=False,
            verbose=True,
        )

        # Get x_train and x_test from subsets data_loaders["train"].dataset and data_loaders["test"].dataset
        # Cannot use .x directly on Subset
        x_train = data_loaders["train"].dataset.dataset.x[
            data_loaders["train"].dataset.indices
        ]
        x_test = data_loaders["test"].dataset.dataset.x[
            data_loaders["test"].dataset.indices
        ]
        pca = PCA(n_components=32)
        pca_train_z = pca.fit_transform(x_train)
        pca_test_z = pca.transform(x_test)
        pca_train_z = pd.DataFrame(
            pca_train_z, columns=[f"pc_{i}" for i in range(pca_train_z.shape[1])]
        )
        pca_test_z = pd.DataFrame(
            pca_test_z, columns=[f"pc_{i}" for i in range(pca_test_z.shape[1])]
        )
        pca_train_z["ind"] = ccvae_train_z["ind"].values.tolist()
        pca_test_z["ind"] = ccvae_test_z["ind"].values.tolist()

        ccvae_test_z["ModelID"] = ccvae_test_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )
        vae_test_z["ModelID"] = ccvae_test_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )
        ccvae_train_z["ModelID"] = ccvae_train_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )
        vae_train_z["ModelID"] = ccvae_train_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )
        pca_test_z["ModelID"] = pca_test_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )
        pca_train_z["ModelID"] = pca_train_z["ind"].apply(
            lambda x: dataset.idx_to_sample[x]
        )

        ccvae_test_z = ccvae_test_z.set_index("ModelID")
        vae_test_z = vae_test_z.set_index("ModelID")
        ccvae_train_z = ccvae_train_z.set_index("ModelID")
        vae_train_z = vae_train_z.set_index("ModelID")
        pca_test_z = pca_test_z.set_index("ModelID")
        pca_train_z = pca_train_z.set_index("ModelID")

        def add_type_info(df, model):
            """Left-join lineage and subtype annotations onto ``df`` by index.

            Every row of ``df`` is kept. ``OncotreeLineage`` and
            ``LegacySubSubtype`` are taken from the row of ``model`` with the
            same index label and are NaN where ``model`` has no such row.
            """
            annotations = model[["OncotreeLineage", "LegacySubSubtype"]]
            return df.merge(
                annotations,
                how="left",
                left_index=True,
                right_index=True,
            )

        # Add type info
        ccvae_test_z = add_type_info(ccvae_test_z, model)
        vae_test_z = add_type_info(vae_test_z, model)
        ccvae_train_z = add_type_info(ccvae_train_z, model)
        vae_train_z = add_type_info(vae_train_z, model)
        pca_test_z = add_type_info(pca_test_z, model)
        pca_train_z = add_type_info(pca_train_z, model)

        s_copy = deepcopy(dataset.s)
        s_df = pd.DataFrame(s_copy, columns=dataset.s_features).reset_index()
        s_df = s_df.rename(lambda x: dataset.idx_to_sample[x], axis=0)

        def add_dep_info(df, dep_df):
            """Left-join all columns of ``dep_df`` onto ``df`` by index.

            Every row of ``df`` is kept; columns coming from ``dep_df`` are NaN
            for index labels that ``dep_df`` does not contain.
            """
            return df.merge(
                dep_df,
                how="left",
                left_index=True,
                right_index=True,
            )

        # Add dependency info if we have it
        ccvae_test_z = add_dep_info(ccvae_test_z, s_df)
        ccvae_train_z = add_dep_info(ccvae_train_z, s_df)
        vae_test_z = add_dep_info(vae_test_z, s_df)
        vae_train_z = add_dep_info(vae_train_z, s_df)
        pca_test_z = add_dep_info(pca_test_z, s_df)
        pca_train_z = add_dep_info(pca_train_z, s_df)

        # Add mutation info if we have it
        # genes_mut = [gene for gene in genes if gene in mut_df.columns]
        # if len(genes_mut) > 0:
        #     ccvae_test_z = pd.merge(ccvae_test_z, mut_df[genes_mut + ["ModelID"]], left_on="ModelID", right_on="ModelID", how="left", suffixes=("", "_mut")).drop("ModelID", axis=1)
        #     vae_test_z = pd.merge(vae_test_z, mut_df[genes_mut + ["ModelID"]], left_on="ModelID", right_on="ModelID", how="left", suffixes=("", "_mut")).drop("ModelID", axis=1)

        # ccvae_train_z = pd.merge(ccvae_train_z, mut_df[genes_mut + ["ModelID"]], left_on="ModelID", right_on="ModelID", how="left", suffixes=("", "_mut")).drop("ModelID", axis=1)
        # vae_train_z = pd.merge(vae_train_z, mut_df[genes_mut + ["ModelID"]], left_on="ModelID", right_on="ModelID", how="left", suffixes=("", "_mut")).drop("ModelID", axis=1)

        # Get rep
        ccvae_test_z_rep = ccvae_test_z.iloc[
            :, ccvae_test_z.columns.str.startswith("z")
        ]
        ccvae_test_z_rep = ccvae_test_z_rep.rename(
            mapper=partial(rep_renamer, genes=genes), axis=1
        )

        ccvae_train_z_rep = ccvae_train_z.iloc[
            :, ccvae_train_z.columns.str.startswith("z")
        ]
        ccvae_train_z_rep = ccvae_train_z_rep.rename(
            mapper=partial(rep_renamer, genes=genes), axis=1
        )

        cons_cols = [f"z_{gene}" for gene in genes]

        ccvae_test_z_rep_c = ccvae_test_z_rep.iloc[
            :, ccvae_test_z_rep.columns.isin(cons_cols)
        ]
        ccvae_test_z_rep = ccvae_test_z_rep.iloc[
            :, ~ccvae_test_z_rep.columns.isin(cons_cols)
        ]

        ccvae_train_z_rep_c = ccvae_train_z_rep.iloc[
            :, ccvae_train_z_rep.columns.isin(cons_cols)
        ]
        ccvae_train_z_rep = ccvae_train_z_rep.iloc[
            :, ~ccvae_train_z_rep.columns.isin(cons_cols)
        ]

        vae_test_z_rep = vae_test_z.iloc[:, vae_test_z.columns.str.startswith("z")]
        vae_train_z_rep = vae_train_z.iloc[:, vae_train_z.columns.str.startswith("z")]

        # CALCULATE SILHOUETTE SCORES FOR REPS AND CANCER TYPE
        from sklearn.metrics import silhouette_samples
        from sklearn.preprocessing import StandardScaler
        # NORMALIZE VALUES

        scaler_c = StandardScaler()
        ccvae_train_z_scaled = scaler_c.fit_transform(
            ccvae_train_z.iloc[:, ccvae_train_z.columns.str.startswith("z")]
        )
        scaler_vae = StandardScaler()
        vae_train_z_scaled = scaler_vae.fit_transform(
            vae_train_z.iloc[:, vae_train_z.columns.str.startswith("z")]
        )
        scaler_pca = StandardScaler()
        pca_train_z_scaled = scaler_pca.fit_transform(
            pca_train_z.iloc[:, pca_train_z.columns.str.startswith("pc")]
        )

        ccvae_test_z_scaled = scaler_c.transform(
            ccvae_test_z.iloc[:, ccvae_test_z.columns.str.startswith("z")]
        )
        vae_test_z_scaled = scaler_vae.transform(
            vae_test_z.iloc[:, vae_test_z.columns.str.startswith("z")]
        )
        pca_test_z_scaled = scaler_pca.transform(
            pca_test_z.iloc[:, pca_test_z.columns.str.startswith("pc")]
        )

        # TRAINING SET
        ccvae_train_silhouette = silhouette_samples(
            ccvae_train_z_scaled, ccvae_train_z["OncotreeLineage"]
        )
        vae_train_silhouette = silhouette_samples(
            vae_train_z_scaled, vae_train_z["OncotreeLineage"]
        )
        pca_train_silhouette = silhouette_samples(
            pca_train_z_scaled, pca_train_z["OncotreeLineage"]
        )
        # TEST SET
        ccvae_test_silhouette = silhouette_samples(
            ccvae_test_z_scaled, ccvae_test_z["OncotreeLineage"]
        )
        vae_test_silhouette = silhouette_samples(
            vae_test_z_scaled, vae_test_z["OncotreeLineage"]
        )
        pca_test_silhouette = silhouette_samples(
            pca_test_z_scaled, pca_test_z["OncotreeLineage"]
        )
        for i, label in enumerate(ccvae_test_z["OncotreeLineage"]):
            sil_list_test.append(
                {
                    "label": label,
                    "drug": drug,
                    "seed": seed,
                    "ccvae_test_silhouette": ccvae_test_silhouette[i],
                    "vae_test_silhouette": vae_test_silhouette[i],
                    "pca_test_silhouette": pca_test_silhouette[i],
                }
            )

        for i, label in enumerate(ccvae_train_z["OncotreeLineage"]):
            sil_list_train.append(
                {
                    "label": label,
                    "drug": drug,
                    "seed": seed,
                    "ccvae_train_silhouette": ccvae_train_silhouette[i],
                    "vae_train_silhouette": vae_train_silhouette[i],
                    "pca_train_silhouette": pca_train_silhouette[i],
                }
            )

        # SIMPLE KNN PERFORMANCE
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        # TRAINING SET
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, ccvae_train_z_scaled, ccvae_train_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "train",
                "model": "iCoVAE",
                "knn_acc": np.mean(scores),
            }
        )
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, vae_train_z_scaled, vae_train_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "train",
                "model": "VAE",
                "knn_acc": np.mean(scores),
            }
        )
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, pca_train_z_scaled, pca_train_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "train",
                "model": "PCA",
                "knn_acc": np.mean(scores),
            }
        )
        # TEST SET
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, ccvae_test_z_scaled, ccvae_test_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "test",
                "model": "iCoVAE",
                "knn_acc": np.mean(scores),
            }
        )
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, vae_test_z_scaled, vae_test_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "test",
                "model": "VAE",
                "knn_acc": np.mean(scores),
            }
        )
        clf = KNeighborsClassifier(n_neighbors=5)
        scores = cross_val_score(
            clf, pca_test_z_scaled, pca_test_z["OncotreeLineage"], cv=5
        )
        knn_acc_list.append(
            {
                "drug": drug,
                "seed": seed,
                "dataset": "test",
                "model": "PCA",
                "knn_acc": np.mean(scores),
            }
        )

        if rep_type == "TSNE" and drug in drugs_plot:
            ccvae_tsne_c_train = TSNE(n_components=2, perplexity=10).fit_transform(
                ccvae_train_z_rep_c
            )
            ccvae_tsne_train = TSNE(n_components=2, perplexity=10).fit_transform(
                ccvae_train_z_rep
            )
            vae_tsne_train = TSNE(n_components=2, perplexity=10).fit_transform(
                vae_train_z_rep
            )

            ccvae_tsne_c = TSNE(n_components=2, perplexity=10).fit_transform(
                ccvae_test_z_rep_c
            )
            ccvae_tsne = TSNE(n_components=2, perplexity=10).fit_transform(
                ccvae_test_z_rep
            )
            vae_tsne = TSNE(n_components=2, perplexity=10).fit_transform(vae_test_z_rep)

            ccvae_train_rep_c = pd.DataFrame(
                ccvae_tsne_c_train,
                columns=["TSNE0", "TSNE1"],
                index=ccvae_train_z_rep_c.index,
            )
            ccvae_train_rep = pd.DataFrame(
                ccvae_tsne_train,
                columns=["TSNE0", "TSNE1"],
                index=ccvae_train_z_rep.index,
            )
            vae_train_rep = pd.DataFrame(
                vae_tsne_train,
                columns=["TSNE0", "TSNE1"],
                index=vae_train_z_rep.index,
            )

            ccvae_test_rep_c = pd.DataFrame(
                ccvae_tsne_c,
                columns=["TSNE0", "TSNE1"],
                index=ccvae_test_z_rep_c.index,
            )
            ccvae_test_rep = pd.DataFrame(
                ccvae_tsne,
                columns=["TSNE0", "TSNE1"],
                index=ccvae_test_z_rep.index,
            )
            vae_test_rep = pd.DataFrame(
                vae_tsne,
                columns=["TSNE0", "TSNE1"],
                index=vae_test_z_rep.index,
            )

            x_lab = "TSNE0"
            y_lab = "TSNE1"

        elif rep_type == "UMAP" and drug in drugs_plot:
            ccvae_tsne_c = TSNE(n_components=2, perplexity=10).fit(ccvae_train_z_rep_c)
            ccvae_tsne = TSNE(n_components=2, perplexity=10).fit(ccvae_train_z_rep)
            vae_tsne = TSNE(n_components=2, perplexity=10).fit(vae_train_z_rep)

            ccvae_train_rep_c = pd.DataFrame(
                ccvae_tsne_c.transform(ccvae_train_z_rep_c),
                columns=["TSNE_0", "TSNE_1"],
            )
            ccvae_train_rep = pd.DataFrame(
                ccvae_tsne.transform(ccvae_train_z_rep), columns=["TSNE_0", "TSNE_1"]
            )
            vae_train_rep = pd.DataFrame(
                vae_tsne.transform(vae_train_z_rep), columns=["TSNE_0", "TSNE_1"]
            )

            ccvae_test_rep_c = pd.DataFrame(
                ccvae_tsne_c.transform(ccvae_test_z_rep_c), columns=["TSNE_0", "TSNE_1"]
            )
            ccvae_test_rep = pd.DataFrame(
                ccvae_tsne.transform(ccvae_test_z_rep), columns=["TSNE_0", "TSNE_1"]
            )
            vae_test_rep = pd.DataFrame(
                vae_tsne.transform(vae_test_z_rep), columns=["TSNE_0", "TSNE_1"]
            )

            x_lab = "TSNE1"
            y_lab = "TSNE2"

        elif rep_type == "PCA" and drug in drugs_plot:
            ccvae_tsne_c = PCA(n_components=4).fit(ccvae_train_z_rep_c)
            ccvae_tsne = PCA(n_components=4).fit(ccvae_train_z_rep)
            vae_tsne = PCA(n_components=4).fit(vae_train_z_rep)

            ccvae_train_rep_c = pd.DataFrame(
                ccvae_tsne_c.transform(ccvae_train_z_rep_c),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=ccvae_train_z_rep_c.index,
            )
            ccvae_train_rep = pd.DataFrame(
                ccvae_tsne.transform(ccvae_train_z_rep),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=ccvae_train_z_rep.index,
            )
            vae_train_rep = pd.DataFrame(
                vae_tsne.transform(vae_train_z_rep),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=vae_train_z_rep.index,
            )

            ccvae_test_rep_c = pd.DataFrame(
                ccvae_tsne_c.transform(ccvae_test_z_rep_c),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=ccvae_test_z_rep_c.index,
            )
            ccvae_test_rep = pd.DataFrame(
                ccvae_tsne.transform(ccvae_test_z_rep),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=ccvae_test_z_rep.index,
            )
            vae_test_rep = pd.DataFrame(
                vae_tsne.transform(vae_test_z_rep),
                columns=["PC1", "PC2", "PC3", "PC4"],
                index=vae_test_z_rep.index,
            )

            x_lab = "PC1"
            y_lab = "PC2"
            y_lab2 = "PC3"
            y_lab3 = "PC4"

        if drug in drugs_plot:
            ccvae_test_z_c = pd.concat([ccvae_test_z, ccvae_test_rep_c], axis=1)
            ccvae_test_z = pd.concat([ccvae_test_z, ccvae_test_rep], axis=1)
            vae_test_z = pd.concat([vae_test_z, vae_test_rep], axis=1)
        # ccvae_test_z_c["set"] = "test"
        ccvae_test_z["set"] = "test"
        vae_test_z["set"] = "test"

        if drug in drugs_plot:
            ccvae_train_z_c = pd.concat([ccvae_train_z, ccvae_train_rep_c], axis=1)
            ccvae_train_z = pd.concat([ccvae_train_z, ccvae_train_rep], axis=1)
            vae_train_z = pd.concat([vae_train_z, vae_train_rep], axis=1)
        # ccvae_train_z_c["set"] = "train"
        ccvae_train_z["set"] = "train"
        vae_train_z["set"] = "train"

        # ccvae_z_c = pd.concat([ccvae_test_z_c, ccvae_train_z_c], axis=0)
        ccvae_z = pd.concat([ccvae_test_z, ccvae_train_z], axis=0)
        vae_z = pd.concat([vae_test_z, vae_train_z], axis=0)

        # Create sequential palettes
        pal0 = sns.light_palette(pal[0], as_cmap=True, reverse=True)
        pal1 = sns.light_palette(pal[1], as_cmap=True, reverse=True)
        pal2 = sns.light_palette(pal[2], as_cmap=True, reverse=True)

        pal_point = sns.color_palette("colorblind", n_colors=3).as_hex()

        bar_pal = {gene: "#888888" for gene in genes if gene not in color_genes}
        for i, gene in enumerate(color_genes):
            bar_pal[gene] = pal_point[i]

        # KO PREDICTION SCATTERPLOTS
        # ccvae_z_c_plot = ccvae_test_z_c.drop("LegacySubSubtype", axis=1).dropna(axis=0)
        ccvae_z_plot = ccvae_test_z.drop("LegacySubSubtype", axis=1).dropna(axis=0)
        vae_z_plot = vae_test_z.drop("LegacySubSubtype", axis=1).dropna(axis=0)
        # f_sc, ax_sc = plt.subplots(1,3, figsize=(9,3))
        for i, gene in enumerate(genes):
            # sns.scatterplot(ccvae_z_c_plot, x=f"z_{i}", y=genes[i], hue="OncotreeLineage", style="set", ax=axes_1[i], size=10, markers=["D", "o"], legend=False)
            try:
                r_p, _ = pearsonr(ccvae_z_plot[f"z_{i}"], ccvae_z_plot[f"{gene}_s"])
                r_s, _ = spearmanr(ccvae_z_plot[f"z_{i}"], ccvae_z_plot[f"{gene}_s"])
                eff_pred_list.append(
                    {
                        "gene": gene,
                        "pearson_r": r_p,
                        "spearman_r": r_s,
                        "seed": seed,
                        "drug": drug,
                    }
                )
            except:
                print(f"{drug} failed...")
                continue

        if drug in drugs_plot:
            ccvae_z_c_plot_rep = ccvae_test_z_c.drop(
                ["LegacySubSubtype"] + [f"{gene}_s" for gene in genes], axis=1
            ).dropna(axis=0)
            ccvae_z_plot_rep = ccvae_test_z.drop(
                ["LegacySubSubtype"] + [f"{gene}_s" for gene in genes], axis=1
            ).dropna(axis=0)
            vae_z_plot_rep = vae_test_z.drop(
                ["LegacySubSubtype"] + [f"{gene}_s" for gene in genes], axis=1
            ).dropna(axis=0)

            # print(len(ccvae_z_plot_rep))

            pal = sns.color_palette("colorblind")
            all_lineages = ccvae_z_plot_rep["OncotreeLineage"].unique()
            col_lineages = [
                "Bone",
                "Peripheral Nervous System",
                "Liver",
                "Cervix",
                "Kidney",
                "Bladder/Urinary Tract",
                "Fibroblast",
                "Uterus",
                "Pleura",
                "Thyroid",
            ]
            lineage_labels = {
                "Bone": "Bone",
                "Peripheral Nervous System": "PNS",
                "Liver": "Liver",
                "Cervix": "Cervix",
                "Kidney": "Kidney",
                "Bladder/Urinary Tract": "Bladder/UT",
                "Fibroblast": "Fibroblast",
                "Uterus": "Uterus",
                "Pleura": "Pleura",
                "Thyroid": "Thyroid",
            }
            pal_reps = {}
            for lin in all_lineages:
                if lin in col_lineages:
                    pal_reps[lin] = pal[col_lineages.index(lin)]
                else:
                    pal_reps[lin] = "lightgrey"

            # PCS VS LINEAGE AND DRUG RESPONSE
            fig_1, axes_1 = plt.subplots(1, 2, figsize=(8, 3))
            # sns.scatterplot(ccvae_z_c_plot_rep, x=x_lab, y=y_lab, hue="OncotreeLineage", palette=pal_reps, style="set", size=drug.upper(), sizes=(1,100), ax=axes_1[0], markers=["o", "D"], legend=False, **{"alpha":1})
            # axes_1[0].set_xticks([])
            # axes_1[0].set_yticks([])
            # axes_1[0].set_xlabel(x_lab)
            # axes_1[0].set_ylabel(y_lab)
            # axes_1[0].set_title("iCoVAE linked", fontweight="bold")
            # sns.despine(left=True, bottom=True, ax=axes_1[0])
            sns.scatterplot(
                data=ccvae_z_plot_rep,
                x=x_lab,
                y=y_lab,
                hue="OncotreeLineage",
                palette=pal_reps,
                s=30,
                ax=axes_1[0],
                markers=["o", "D"],
                legend=False,
                **{"alpha": 1},
            )
            axes_1[0].set_xticks([])
            axes_1[0].set_yticks([])
            axes_1[0].set_xlabel(x_lab)
            axes_1[0].set_ylabel(y_lab)
            axes_1[0].set_title("iCoVAE", fontweight="regular", fontsize=10)
            sns.despine(left=True, bottom=True, ax=axes_1[0])
            # print(vae_z_plot_rep)
            sns.scatterplot(
                data=vae_z_plot_rep,
                x=x_lab,
                y=y_lab,
                hue="OncotreeLineage",
                palette=pal_reps,
                s=30,
                ax=axes_1[1],
                markers=["o", "D"],
                legend="brief",
                **{"alpha": 1},
            )
            axes_1[1].set_xticks([])
            axes_1[1].set_yticks([])
            axes_1[1].set_xlabel(x_lab)
            axes_1[1].set_ylabel(y_lab)
            axes_1[1].set_title("VAE", fontweight="regular", fontsize=10)
            sns.despine(left=True, bottom=True, ax=axes_1[1])
            # fig.legend(lines, labels, loc="upper left", bbox_to_anchor=(1,1), frameon="false")
            axes_1[1].get_legend().remove()

            handles = []
            labels = []
            for lin, col in pal_reps.items():
                if lin in col_lineages:
                    handles.append(
                        Line2D(
                            [0],
                            [0],
                            label=lineage_labels[lin],
                            marker="o",
                            markersize=4,
                            markeredgecolor="none",
                            markerfacecolor=col,
                            linestyle="",
                        )
                    )

            handles.append(
                Line2D(
                    [0],
                    [0],
                    label="Other",
                    marker="o",
                    markersize=4,
                    markeredgecolor="none",
                    markerfacecolor="lightgrey",
                    linestyle="",
                )
            )

            plt.subplots_adjust(wspace=0.65)

            leg1 = axes_1[0].legend(
                handles=handles,
                bbox_to_anchor=(1, 0.5),
                loc="center left",
                frameon=False,
                ncol=1,
                fontsize=10,
            )

            axes_1[0].add_artist(leg1)

            # fig_1._legend.remove()

            # f_1, axes_1 = plt.subplots(n_genes, 1, figsize=(1, 12))

            fig_1.savefig(
                f"{wd_path}/results_analysis/figures/{dataset_name}_rep_{drug}_{expi_type}_{seed}_new.png",
                dpi=600,
                bbox_extra_artists=(leg1,),
            )
            fig_1.savefig(
                f"{wd_path}/results_analysis/figures/{dataset_name}_rep_{drug}_{expi_type}_{seed}_new.svg",
                bbox_extra_artists=(leg1,),
            )

In [None]:
# LINEAGES HISTPLOT - THIS WILL BE FOR THE FINAL SEED
# Horizontal histogram of the number of samples per lineage, split by the
# "set" column (train/test). `ccvae_z` and `lineage_labels` come from the
# preceding cell.
f, ax = plt.subplots(1, 1, figsize=(4, 5))
ccvae_z_summ = ccvae_z.copy()
# Use the short label where one is defined; keep the original lineage name
# otherwise instead of raising a KeyError for lineages without an abbreviation.
ccvae_z_summ["OncotreeLineage"] = ccvae_z_summ["OncotreeLineage"].map(
    lambda x: lineage_labels.get(x, x)
)
sns.histplot(
    data=ccvae_z_summ,
    y="OncotreeLineage",
    hue="set",
    ax=ax,
    hue_order=["train", "test"],
)
sns.despine()
sns.move_legend(
    ax, loc="upper center", frameon=False, ncol=2, title="", bbox_to_anchor=(0.5, 1.05)
)
ax.set_xlabel("Count")
ax.set_ylabel("")
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_h16_lineages_new.svg",
    bbox_inches="tight",
)
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_h16_lineages_new.png",
    bbox_inches="tight",
    dpi=600,
)

In [None]:
# Average the collected KNN accuracy records within each
# (drug, model, dataset) combination.
knn_acc_list_df = pd.DataFrame.from_dict(knn_acc_list)
knn_acc_plot = knn_acc_list_df.groupby(["drug", "model", "dataset"]).mean()

# Plot KNN accuracies: one bar group per dataset split, one bar per feature
# extractor.
knn_models = ["iCoVAE", "VAE", "PCA"]
fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
sns.barplot(
    data=knn_acc_plot.reset_index(),
    x="dataset",
    y="knn_acc",
    hue="model",
    hue_order=knn_models,
    order=["train", "test"],
    width=0.6,
    ax=ax,
)
# The "train"/"test" categories are shown as "Common"/"Rare".
ax.set_xticklabels(["Common", "Rare"])
ax.set_xlabel("Dataset")
ax.set_ylabel("Accuracy")
sns.despine(ax=ax)

# Rebuild the legend above the axes, reusing the handles seaborn created.
legend_handles, _ = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=list(knn_models),
    title="Feature extractor",
    fontsize="medium",
    ncol=3,
    frameon=False,
    loc="upper center",
    bbox_to_anchor=(0.475, 1.3),
)

knn_fig_stem = f"{wd_path}/results_analysis/figures/{dataset_name}_knn_acc"
plt.savefig(f"{knn_fig_stem}.png", dpi=600, bbox_inches="tight")
plt.savefig(f"{knn_fig_stem}.svg", bbox_inches="tight")

### 1.1.2 Prediction of gene effect

In [None]:
# Gather the per-gene correlation records (gene, pearson_r, spearman_r, seed,
# drug) into one frame and show the averages for every drug/gene pair.
gene_corr_df = pd.DataFrame(eff_pred_list)
gene_corr_means = gene_corr_df.groupby(["drug", "gene"]).mean()
print(gene_corr_means)

In [None]:
# All gene effect predictions
# Box + strip plot of the per-gene mean Spearman correlation, grouped by the
# OncoKB gene type (oncogene / tumour suppressor / other), with the highest
# genes of every group labelled by name.
fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
# numeric_only=True: gene_corr_df also carries the string column "drug", which
# pandas >= 2.0 refuses to average (older versions dropped it silently).
gene_corr_df_plot = gene_corr_df.groupby("gene").mean(numeric_only=True)
gene_corr_df_plot["gene_type"] = "Other"
print(gene_corr_df_plot)
palette = sns.color_palette("colorblind")

gene_type_order = ["Oncogene", "TSG", "Other"]
gene_type_palette = {
    "Oncogene": palette[0],
    "TSG": palette[1],
    "Other": palette[2],
}

# Add gene types oncogene or tumor suppressor from OncoKB
# Download from https://www.oncokb.org/9122534a-044d-4693-8d5c-ebc10a822e7c

oncokb_df = pd.read_csv(
    f"{wd_path}/data/downloads/oncokb_cancer_gene_list.tsv", sep="\t"
)
# Possible entries in Gene Type are ONCOGENE, TSG, ONCOGENE_AND_TSG, AND INSUFFICIENT_EVIDENCE
# Genes annotated as both are shown as oncogenes; INSUFFICIENT_EVIDENCE, any
# unknown value and genes missing from OncoKB all stay "Other".
oncokb_to_label = {
    "ONCOGENE": "Oncogene",
    "TSG": "TSG",
    "ONCOGENE_AND_TSG": "Oncogene",
}
# Keep the first OncoKB row per symbol so the lookup is unambiguous.
oncokb_gene_type = oncokb_df.drop_duplicates("Hugo Symbol").set_index("Hugo Symbol")[
    "Gene Type"
]
gene_corr_df_plot["gene_type"] = (
    gene_corr_df_plot.index.to_series()
    .map(oncokb_gene_type)
    .map(oncokb_to_label)
    .fillna("Other")
)

g = sns.boxplot(
    data=gene_corr_df_plot,
    x="gene_type",
    y="spearman_r",
    hue="gene_type",
    palette=gene_type_palette,
    order=gene_type_order,
    ax=ax,
    whis=1.5,
    width=0.6,
    fliersize=0,
)
# Overlay stripplot aligned with boxplot. `hue` is assigned explicitly because
# passing `palette` without `hue` is deprecated in seaborn 0.13.
sns.stripplot(
    data=gene_corr_df_plot,
    x="gene_type",
    y="spearman_r",
    hue="gene_type",
    palette=gene_type_palette,
    order=gene_type_order,
    ax=ax,
    size=5,
    alpha=0.7,
    jitter=False,
    legend=False,
)
sns.despine(ax=ax)
ax.set_ylabel("Spearman correlation")
ax.set_xlabel("")

# Label the genes above q3 + cutoff * IQR in every group. The cutoffs are
# hand-picked per group and differ from the boxplot's whis=1.5.
gene_type_cutoffs = {"Oncogene": 0.5, "TSG": 1.2, "Other": 1.72}
# Fixed (x, y) offsets of the text relative to the point; cycled so that more
# than three labelled genes in one group no longer raise an IndexError.
label_offsets = [(0.0, 0.05), (0.4, 0.15), (-0.4, 0.15)]
for x, gene_type in enumerate(gene_type_order):
    data = gene_corr_df_plot.loc[
        gene_corr_df_plot["gene_type"] == gene_type, "spearman_r"
    ]
    q1 = data.quantile(0.25)
    q3 = data.quantile(0.75)
    iqr = q3 - q1
    upper_bound = q3 + gene_type_cutoffs[gene_type] * iqr
    outliers = data[data > upper_bound]
    for i, (gene_name, y) in enumerate(outliers.items()):
        x_shift, y_shift = label_offsets[i % len(label_offsets)]
        ax.text(
            x + x_shift,
            y + y_shift,
            gene_name,
            horizontalalignment="center",
            fontsize=8,
        )
        # Thin connector from the point to its label.
        ax.plot([x, x + x_shift], [y, y + y_shift], color="grey", linewidth=0.5)

# Add dashed line at y = 0
ax.axhline(0, color="grey", linestyle="--", linewidth=1)

plt.tight_layout()

# save as png
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_gene_effect_correlation_boxplot_new.png",
    dpi=600,
    bbox_inches="tight",
)
# save as svg
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_gene_effect_correlation_boxplot_new.svg",
    bbox_inches="tight",
)

In [None]:
# Barplot of number of genes in each category -- majority of genes are not recorded as oncogene or TSG in OncoKB
fig, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))
gene_corr_df_plot_count = (
    gene_corr_df_plot.groupby("gene_type").size().reset_index(name="counts")
)
# Build the colours locally: other cells rebind the global `palette` to a dict
# keyed by model name, which would break `palette[0]` when cells are run out
# of order.
count_order = ["Oncogene", "TSG", "Other"]
count_colors = sns.color_palette("colorblind")
count_palette = {
    "Oncogene": count_colors[0],
    "TSG": count_colors[1],
    "Other": count_colors[2],
}
# `hue` mirrors `x` because passing `palette` without `hue` is deprecated in
# seaborn 0.13; the redundant legend is switched off.
sns.barplot(
    data=gene_corr_df_plot_count,
    x="gene_type",
    y="counts",
    hue="gene_type",
    palette=count_palette,
    order=count_order,
    ax=ax,
    width=0.6,
    legend=False,
)
sns.despine(ax=ax)
ax.set_ylabel("Number of genes")
ax.set_xlabel("")
plt.tight_layout()
# save as png
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_constraint_counts_new.png",
    dpi=600,
    bbox_inches="tight",
)
# save as svg
plt.savefig(
    f"{wd_path}/results_analysis/figures/{dataset_name}_constraint_counts_new.svg",
    bbox_inches="tight",
)

# 2 Performance comparisons

## 2.1 Prediction performance on DepMap/GDSC experiments

### 2.1.1 Comparison across all drugs

In [None]:
# Drug name -> drug class / molecular target label used for grouping in plots.
# Every drug is listed exactly once (the earlier duplicate entries for
# "5-Fluorouracil" and "PD173074" carried identical values and were dropped;
# key order is unchanged).
# NOTE(review): "Vinca alkyloid" is probably a typo for "Vinca alkaloid"; left
# untouched in case the label is referenced elsewhere -- confirm before renaming.
target_dict = {
    "AZD5991": "MCL1/BCL2",
    "Alpelisib": "PI3K",
    "AZD8186": "PI3K",
    "Gefitinib": "EGFR/ERBB2",
    "Lapatinib": "EGFR/ERBB2",
    "Sorafenib": "VEGFR",
    "Docetaxel": "Taxane",
    "Paclitaxel": "Taxane",
    "Taselisib": "PI3K",
    # "AZD6482": "PI3K",
    "Palbociclib": "CDK4/6",
    "AZD3759": "EGFR/ERBB2",
    "Afatinib": "EGFR/ERBB2",
    "Afuresertib": "AKT",
    "Serdemetan": "MDM2",
    "Oxaliplatin": "Platinum",
    "GSK1904529A": "IGF1R",
    "Buparlisib": "PI3K",
    "Linsitinib": "IGF1R",
    "Ipatasertib": "AKT",
    "CZC24832": "PI3K",
    "Sabutoclax": "MCL1/BCL2",
    "MK-8776": "CHEK1",
    "Ribociclib": "CDK4/6",
    "Cisplatin": "Platinum",
    "Osimertinib": "EGFR/ERBB2",
    "Erlotinib": "EGFR/ERBB2",
    "AZD6738": "ATR",
    "Olaparib": "PARP",
    "Niraparib": "PARP",
    "Veliparib": "PARP",
    "MK-1775": "WEE1",
    "Cyclophosphamide": "Alkylating agent",
    "5-Fluorouracil": "Antimet.",
    "Epirubicin": "Anthracycline",
    "Tamoxifen": "ER",
    "Methotrexate": "Antimet.",
    "Venetoclax": "MCL1/BCL2",
    # "AZD5153": "BRD4",
    "PD173074": "FGFR",
    "Sapitinib": "EGFR/ERBB2",
    "AZD4547": "FGFR",
    "Vorinostat": "HDAC",
    "Refametinib": "MEK",
    "Selumetinib": "MEK",
    "Trametinib": "MEK",
    "Axitinib": "VEGFR",
    "Gemcitabine": "Antimet.",
    "Irinotecan": "TOP1",
    "VE-822": "ATR",
    "Crizotinib": "ALK/ROS1",
    "Cytarabine": "Antimet.",
    "Entinostat": "HDAC",
    "Foretinib": "VEGFR",
    "Fulvestrant": "ER",
    "Motesanib": "VEGFR",
    "Navitoclax": "MCL1/BCL2",
    "Pyridostatin": "G4",
    # "Rapamycin": "mTOR",
    "Temsirolimus": "mTOR",
    "Tanespimycin": "HSP90",
    "Uprosertib": "AKT",
    "AZD5363": "AKT",
    "Dabrafenib": "BRAF",
    "Temozolomide": "Alkylating agent",
    "Vinblastine": "Vinca alkyloid",
    "Vinorelbine": "Vinca alkyloid",
}
# FULL DRUG LIST FOR H16
drugs = sorted(target_dict.keys())

# drugs = ["AZD6738"]

# drugs = [drug.upper() for drug in drugs]
# target_dict = {key.upper(): val for key, val in target_dict.items()}


# DRUG LIST FOR BREAST ONLY EXPERIMENTS
# drugs = ["Lapatinib", "Afuresertib", "Epirubicin", "Docetaxel", "Ipatasertib", "Cisplatin", "Oxaliplatin", "AZD6738", "5-Fluorouracil", "Cyclophosphamide",
# "Ribociclib", "Palbociclib", "Niraparib", "Olaparib", "Veliparib", "Alpelisib", "Tamoxifen"]

# DRUG LIST FOR PAPER REPORTING -- BREAST
# drugs = ["Alpelisib", "Oxaliplatin", "Ipatasertib", "Docetaxel", "AZD6738", "Erlotinib", "Osimertinib", "Linsitinib"]

# SELECT DRUGS PLOT
# drugs = ["AZD6738", "Palbociclib", "AZD4547", "Linsitinib", "Trametinib", "Oxaliplatin", "Sorafenib"]

# Experiment configuration shared by the performance-comparison cells below.
dataset_name = "depmap_gdsc"
experiment = "h16"
seeds = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
# Feature extractors and regression models to compare.
fes = ["icovae", "vae", "pca", "none"]
models = ["SVR", "ElasticNet", "RandomForestRegressor", "nn"]
perf_dict_list = []

In [None]:
# Alphabetical listing of every drug that has a target annotation.
sorted(target_dict)

In [None]:
# Build the performance-comparison helper (PerfComp from utils.comp_utils) for
# the selected drugs, experiment name, dataset and working directory.
# NOTE(review): `drugs` is already sorted where it is defined, so the extra
# sorted() only matters if `drugs` was overridden by hand.
comp = PerfComp(sorted(drugs), experiment, dataset=dataset_name, wd_path=wd_path)

In [None]:
# ONLY RUN IF NEED TO RECALCULATE, OTHERWISE LOAD BELOW
# Computes the performance table for every feature extractor in `fes` and
# every regression model in `models`.
# NOTE(review): use_cache=False presumably forces a full recomputation instead
# of reusing cached results -- confirm against PerfComp.calculate_perf.
perf_df = comp.calculate_perf(fes, models, use_cache=False)

In [None]:
# Persist the performance table when it exists in this kernel, then (re)load it
# from disk so downstream cells always work from the saved CSV. The write is
# guarded so that this cell can also be used to load previously saved results
# on a fresh kernel, without recalculating first.
# NOTE(review): the path is relative to the notebook's working directory,
# unlike the figure paths built from `wd_path` elsewhere -- confirm it points
# at the intended folder.
perf_csv_path = f"../figures/perf_df_latest_{experiment}_new.csv"
if "perf_df" in globals():
    perf_df.to_csv(perf_csv_path)
# to_csv stores the index as an unnamed first column; drop it again on load.
perf_df = pd.read_csv(perf_csv_path).drop("Unnamed: 0", axis=1)

In [None]:
# Number of distinct drugs in the performance table (a missing value, if any,
# counts as one entry).
perf_df["target"].nunique(dropna=False)

In [None]:
# For every drug, count how many (model, feature extractor) combinations have
# a non-missing mean in each column; drugs with the most "rmse" entries first.
perf_mean_by_combo = perf_df.groupby(["target", "model", "fe"]).mean()
perf_mean_by_combo.groupby("target").count().sort_values("rmse", ascending=False)

In [None]:
# Mean performance of every (model, feature extractor) combination on
# Trametinib, ordered from highest to lowest "spearman_r_val".
trametinib_rows = perf_df.loc[perf_df["target"] == "Trametinib"]
trametinib_means = trametinib_rows.groupby(["target", "model", "fe"]).mean()
trametinib_means.sort_values("spearman_r_val", ascending=False)

In [None]:
# Results reporting for paper
# Mean performance per (drug, model, feature extractor) for the drugs listed
# below; only the first 50 rows are displayed.
paper_drugs = [
    "AZD6738",
    "Linsitinib",
    "Axitinib",
    "Lapatinib",
    "Methotrexate",
    "Oxaliplatin",
    "Trametinib",
]
paper_rows = perf_df[perf_df["target"].isin(paper_drugs)]
paper_rows.groupby(["target", "model", "fe"]).mean().head(50)

In [None]:
# PERF PLOT ONE BOXPLOT
# For every metric, and for both the plain column and its "_val" counterpart,
# draw one bar + strip plot of the per-drug mean performance (x: regression
# model, colour: feature extractor), print Wilcoxon signed-rank tests of iCoVAE
# against the VAE and Raw feature sets, and save the figure as SVG and PNG.
from matplotlib.lines import Line2D
from scipy.stats import wilcoxon

pal = sns.color_palette("colorblind", n_colors=5)
palette = {"iCoVAE": pal[0], "VAE": pal[1], "PCA": pal[2], "Raw": "grey"}

# Display names for the feature extractors and regression models.
fe_dict = {"none": "Raw", "vae": "VAE", "icovae": "iCoVAE", "pca": "PCA"}
model_dict = {
    "SVR": "SVR",
    "ElasticNet": "ElasticNet",
    "nn": "MLP",
    "RandomForestRegressor": "RF",
}

hue_order = ["iCoVAE", "VAE", "PCA", "Raw"]
order = ["ElasticNet", "SVR", "RF", "MLP"]
drug_targets = sorted(list(set(target_dict.values())))

# y-axis label, keyed by the metric name without its "_val" suffix.
metric_ylabels = {
    "spearman_r": "Spearman correlation",
    "pearson_r": "Pearson correlation",
    "rmse": "RMSE",
}

for metric in ["spearman_r_val", "pearson_r_val", "rmse_val"]:
    # "spearman_r_val" -> "spearman_r"
    base_metric = "_".join(metric.split("_")[:-1])

    # The plotting frame is the same for both plot_metric values, so it is
    # prepared once per metric.
    perf_df_plot = perf_df.reset_index()  # .drop(["n_genes"], axis=1)
    perf_df_plot["fe"] = perf_df_plot["fe"].map(fe_dict)
    perf_df_plot["model"] = perf_df_plot["model"].map(model_dict)
    perf_df_plot["drug_target"] = perf_df_plot["target"].map(target_dict)

    # If validation, drop nans
    if metric.split("_")[-1] == "val":
        perf_df_plot = perf_df_plot.dropna(axis=0)

    # One value per (feature extractor, model, drug). numeric_only=True keeps
    # the string column "drug_target" out of the average, which pandas >= 2.0
    # would otherwise reject with a TypeError.
    perf_df_mean = perf_df_plot.groupby(["fe", "model", "target"]).mean(
        numeric_only=True
    )

    for plot_metric in [base_metric, metric]:
        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

        sns.barplot(
            data=perf_df_mean,
            x="model",
            y=plot_metric,
            ax=ax,
            hue="fe",
            palette=palette,
            width=0.8,
            order=order,
            hue_order=hue_order,
            errorbar=("sd", 1),
            err_kws={"linewidth": 1},
            estimator="mean",
        )
        sns.stripplot(
            data=perf_df_mean,
            x="model",
            y=plot_metric,
            ax=ax,
            hue="fe",
            dodge=True,
            palette=palette,
            alpha=0.3,
            jitter=True,
            order=order,
            linewidth=0.2,
            hue_order=hue_order,
            size=4,
        )
        ax.set_ylabel(metric_ylabels[base_metric])
        ax.set_xlabel("Regression model")
        sns.despine(ax=ax)

        # Paired tests on the un-averaged rows of each regression model.
        # NOTE(review): wilcoxon pairs the two samples by position, so this
        # assumes both feature extractors list their rows in the same order --
        # confirm, or pair explicitly on the identifying columns.
        for model_type in order:
            perf_df_model = perf_df_plot[perf_df_plot["model"] == model_type]
            group1 = perf_df_model.loc[perf_df_model["fe"] == "iCoVAE", plot_metric]
            for baseline in ["VAE", "Raw"]:
                group2 = perf_df_model.loc[
                    perf_df_model["fe"] == baseline, plot_metric
                ]
                try:
                    stat, p = wilcoxon(group1, group2)
                except Exception as err:
                    # Still best effort, but report why the test was skipped
                    # (e.g. samples of different length) instead of hiding it.
                    print(
                        f"Wilcoxon test for iCoVAE vs. {baseline} {model_type}, {plot_metric} skipped: {err}"
                    )
                    continue
                print(
                    f"Wilcoxon test for iCoVAE vs. {baseline} {model_type}, {plot_metric}: stat={stat}, p={p}"
                )

        # Replace seaborn's combined bar/strip legend with one marker per
        # feature extractor, placed above the axes.
        fe_handles = [
            Line2D([], [], marker="o", linestyle="none", color=palette[fe], label=fe)
            for fe in hue_order
        ]
        leg1 = ax.legend(
            handles=fe_handles,
            bbox_to_anchor=(0.4, 1.25),
            loc="upper center",
            frameon=False,
            ncol=4,
            title="Feature extractor",
        )

        # Start the y-axis at zero (values below zero are not shown).
        ax.set_ylim(0)

        extra_artists = (leg1,)
        fig.canvas.draw()
        fig.savefig(
            f"{wd_path}/results_analysis/figures/perf_comp_singlebox_{plot_metric}_{experiment}_new.svg",
            bbox_inches="tight",
            bbox_extra_artists=extra_artists,
        )
        fig.savefig(
            f"{wd_path}/results_analysis/figures/perf_comp_singlebox_{plot_metric}_{experiment}_new.png",
            bbox_inches="tight",
            dpi=1200,
            bbox_extra_artists=extra_artists,
        )

In [None]:
# PERF PLOT ONE BOXPLOT - GEN GAP
# For every metric, plot the generalisation gap, i.e. the percentage difference
# between the metric and its "_val" counterpart relative to the "_val" value,
# as a bar + strip plot of per-drug means (x: regression model, colour: feature
# extractor). Wilcoxon signed-rank tests of iCoVAE against the VAE and Raw
# feature sets are printed and each figure is saved as SVG and PNG.
# Imported here as well so the cell does not depend on another cell having
# brought Line2D into scope.
from matplotlib.lines import Line2D
from scipy.stats import wilcoxon

pal = sns.color_palette("colorblind", n_colors=5)
palette = {"iCoVAE": pal[0], "VAE": pal[1], "PCA": pal[2], "Raw": "grey"}

# Display names for the feature extractors and regression models.
fe_dict = {"none": "Raw", "vae": "VAE", "icovae": "iCoVAE", "pca": "PCA"}
model_dict = {
    "SVR": "SVR",
    "ElasticNet": "ElasticNet",
    "nn": "MLP",
    "RandomForestRegressor": "RF",
}

hue_order = ["iCoVAE", "VAE", "PCA", "Raw"]
order = ["ElasticNet", "SVR", "RF", "MLP"]
drug_targets = sorted(list(set(target_dict.values())))

# y-axis labels. The backslash is doubled so the string contains the literal
# "\%" that mathtext needs without relying on an invalid escape sequence.
gap_ylabels = {
    "spearman_r": "Generalisation gap\n($\\%$ diff. Spearman correlation)",
    "pearson_r": "Generalisation gap\n($\\%$ diff. Pearson correlation)",
    "rmse": "Generalisation gap\n($\\%$ diff. RMSE)",
}

for metric in ["spearman_r_val", "pearson_r_val", "rmse_val"]:
    # "spearman_r_val" -> "spearman_r"; one figure is produced per metric.
    plot_metric = "_".join(metric.split("_")[:-1])
    gap_col = f"{plot_metric}_gap"

    perf_df_plot = perf_df.reset_index()  # .drop(["n_genes"], axis=1)
    # A validation score of exactly zero would give +/-inf; turn those into NaN
    # so that the dropna below removes them instead of corrupting the means.
    perf_df_plot[gap_col] = (
        100
        * (perf_df_plot[plot_metric] - perf_df_plot[f"{plot_metric}_val"])
        / perf_df_plot[f"{plot_metric}_val"]
    ).replace([np.inf, -np.inf], np.nan)

    perf_df_plot["fe"] = perf_df_plot["fe"].map(fe_dict)
    perf_df_plot["model"] = perf_df_plot["model"].map(model_dict)
    perf_df_plot["drug_target"] = perf_df_plot["target"].map(target_dict)

    # If validation, drop nans
    if metric.split("_")[-1] == "val":
        perf_df_plot = perf_df_plot.dropna(axis=0)

    # One value per (feature extractor, model, drug). numeric_only=True keeps
    # the string column "drug_target" out of the average, which pandas >= 2.0
    # would otherwise reject with a TypeError.
    perf_df_mean = perf_df_plot.groupby(["fe", "model", "target"]).mean(
        numeric_only=True
    )

    fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))

    sns.barplot(
        data=perf_df_mean,
        x="model",
        y=gap_col,
        ax=ax,
        hue="fe",
        palette=palette,
        width=0.8,
        order=order,
        hue_order=hue_order,
        errorbar=("sd", 1),
        err_kws={"linewidth": 1},
    )
    sns.stripplot(
        data=perf_df_mean,
        x="model",
        y=gap_col,
        ax=ax,
        hue="fe",
        dodge=True,
        palette=palette,
        alpha=0.3,
        jitter=True,
        order=order,
        linewidth=0.2,
        hue_order=hue_order,
        size=4,
    )
    ax.set_ylabel(gap_ylabels[plot_metric])
    ax.set_xlabel("Regression model")
    sns.despine(ax=ax)

    # Paired tests on the un-averaged rows of each regression model.
    # NOTE(review): wilcoxon pairs the two samples by position, so this assumes
    # both feature extractors list their rows in the same order -- confirm, or
    # pair explicitly on the identifying columns.
    for model_type in order:
        perf_df_model = perf_df_plot[perf_df_plot["model"] == model_type]
        group1 = perf_df_model.loc[perf_df_model["fe"] == "iCoVAE", gap_col]
        for baseline in ["VAE", "Raw"]:
            group2 = perf_df_model.loc[perf_df_model["fe"] == baseline, gap_col]
            try:
                stat, p = wilcoxon(group1, group2)
            except Exception as err:
                # Still best effort, but report why the test was skipped
                # (e.g. samples of different length) instead of hiding it.
                print(
                    f"Wilcoxon test for iCoVAE vs. {baseline} {model_type}, {gap_col} skipped: {err}"
                )
                continue
            print(
                f"Wilcoxon test for iCoVAE vs. {baseline} {model_type}, {gap_col}: stat={stat}, p={p}"
            )

    # Replace seaborn's combined bar/strip legend with one marker per feature
    # extractor, placed above the axes.
    fe_handles = [
        Line2D([], [], marker="o", linestyle="none", color=palette[fe], label=fe)
        for fe in hue_order
    ]
    leg1 = ax.legend(
        handles=fe_handles,
        bbox_to_anchor=(0.4, 1.25),
        loc="upper center",
        frameon=False,
        ncol=4,
        title="Feature extractor",
    )
    ax.axhline(0, color="grey", linestyle="--", linewidth=1)

    # Fixed window; gaps outside [-100, 25] % are clipped from view.
    ax.set_ylim(-100, 25)

    extra_artists = (leg1,)
    fig.canvas.draw()
    fig.savefig(
        f"{wd_path}/results_analysis/figures/perf_comp_singlebox_{plot_metric}_{experiment}_gengap_new.svg",
        bbox_inches="tight",
        bbox_extra_artists=extra_artists,
    )
    fig.savefig(
        f"{wd_path}/results_analysis/figures/perf_comp_singlebox_{plot_metric}_{experiment}_gengap_new.png",
        bbox_inches="tight",
        dpi=1200,
        bbox_extra_artists=extra_artists,
    )

In [None]:
# Get ranking of models for each metric
# For every feature extractor and drug, keep the regression model that is best
# on the validation metric, then rank the feature extractors within each drug
# on both the plain metric and its "_val" counterpart. Prints, per feature
# extractor, how often each rank occurs and the mean rank.
for metric in ["rmse_val", "pearson_r_val", "spearman_r_val"]:
    # Both the aggregate and the best-model selection depend only on `metric`,
    # so they are computed once and shared by the two plot_metric passes.
    perf_df_mean = perf_df.groupby(["fe", "target", "model"]).mean().reset_index()

    if metric.split("_")[-1] == "val":
        perf_df_mean = perf_df_mean.dropna(axis=0)

    # Lower is better for RMSE, higher is better for the correlations.
    metric_by_fe_target = perf_df_mean.groupby(["fe", "target"])[metric]
    if metric == "rmse_val":
        best_idx = metric_by_fe_target.idxmin()
    else:
        best_idx = metric_by_fe_target.idxmax()

    for plot_metric in ["_".join(metric.split("_")[:-1]), metric]:
        perf_df_best = perf_df_mean.loc[best_idx].copy()

        # Rank only the column of interest; ranking the whole frame would also
        # (pointlessly) rank the string columns "fe" and "model".
        lower_is_better = plot_metric in ["rmse", "rmse_val"]
        perf_df_best[f"{plot_metric}_rank"] = perf_df_best.groupby("target")[
            plot_metric
        ].rank(ascending=lower_is_better)

        rank_by_fe = perf_df_best[["fe", f"{plot_metric}_rank"]].groupby("fe")
        print(rank_by_fe.value_counts())
        print(rank_by_fe.mean())

In [None]:
# Best regression model per (feature extractor, drug), all drugs on one axis
# (not including drug class).
from matplotlib.lines import Line2D  # near other imports

# Set to a drug name to restrict the figure to that drug; None plots all drugs.
single_drug = None

# --- Loop-invariant plot configuration --------------------------------------
pal = sns.color_palette("colorblind", n_colors=5)
palette = {"iCoVAE": pal[0], "VAE": pal[1], "PCA": pal[2], "Raw": "grey"}

# Raw identifiers stored in perf_df -> display labels.
fe_dict = {"none": "Raw", "vae": "VAE", "icovae": "iCoVAE", "pca": "PCA"}
model_dict = {
    "SVR": "SVR",
    "ElasticNet": "ElasticNet",
    "nn": "MLP",
    "RandomForestRegressor": "Random forest",
}

hue_order = ["iCoVAE", "VAE", "PCA", "Raw"]
model_markers = {
    "SVR": "o",
    "ElasticNet": "s",
    "Random forest": "D",
    "MLP": "^",
}
metric_axis_labels = {
    "rmse": "RMSE",
    "pearson": "Pearson correlation",
    "spearman": "Spearman correlation",
}
index_cols = ["fe", "target", "drug_target", "model"]

# NOTE(review): not used for plotting in this cell; kept in case another cell
# reads it.
drug_targets = sorted(list(set(target_dict.values())))

for metric in ["spearman_r_val", "pearson_r_val", "rmse_val"]:
    # `metric` selects the best model. RMSE is an error, so its best model has
    # the LOWEST value; the correlations are maximised. The test is on the
    # metric family because the loop only yields "*_val" names (comparing
    # against the bare string "rmse" never matches).
    metric_family = metric.split("_")[0]
    lower_is_better = metric_family == "rmse"

    # Plot both the un-suffixed column and the `_val` column.
    for plot_metric in ["_".join(metric.split("_")[:-1]), metric]:
        perf_df_plot = perf_df.reset_index()  # .drop(["n_genes"], axis=1)

        perf_df_plot["fe"] = perf_df_plot["fe"].map(fe_dict)
        perf_df_plot["model"] = perf_df_plot["model"].map(model_dict)
        perf_df_plot["drug_target"] = perf_df_plot["target"].map(target_dict)

        # If validation, drop nans
        if metric.split("_")[-1] == "val":
            perf_df_plot = perf_df_plot.dropna(axis=0)

        # Mean score per (extractor, drug, model), then the best model per
        # (extractor, drug).
        perf_df_best = perf_df_plot.groupby(index_cols).mean().reset_index()
        best_by_fe_target = perf_df_best.groupby(["fe", "target", "drug_target"])[
            metric
        ]
        best_idx = (
            best_by_fe_target.idxmin()
            if lower_is_better
            else best_by_fe_target.idxmax()
        )
        perf_df_best = perf_df_best.loc[best_idx].set_index(index_cols)

        # Keep only the rows that belong to the selected models.
        perf_df_plot = (
            perf_df_plot.set_index(index_cols)
            .loc[perf_df_best.index]
            .reset_index()
            .sort_values("target", ascending=True)
        )

        if single_drug is not None:
            # Copy so the column assignment below does not write to a slice.
            perf_df_plot = perf_df_plot[perf_df_plot["target"] == single_drug].copy()

        # One hue level per (extractor, model): colour encodes the extractor,
        # marker encodes the model.
        perf_df_plot["fe_model"] = perf_df_plot["fe"] + "__" + perf_df_plot["model"]
        present_fe_models = set(perf_df_plot["fe_model"])
        fe_model_order = [
            f"{fe}__{model}"
            for fe in hue_order
            for model in model_markers
            if f"{fe}__{model}" in present_fe_models
        ]
        palette_fe_model = {k: palette[k.split("__")[0]] for k in fe_model_order}
        markers_fe_model = [model_markers[k.split("__")[1]] for k in fe_model_order]

        fig, ax = plt.subplots(1, 1, figsize=(11, 4))

        # Plot order should be sorted by mean value for fe=iCoVAE (best first).
        order = (
            perf_df_plot[perf_df_plot["fe"] == "iCoVAE"]
            .groupby("target")[plot_metric]
            .mean()
            .sort_values(ascending=lower_is_better)
            .index.tolist()
        )

        sns.pointplot(
            data=perf_df_plot,
            y=plot_metric,
            x="target",
            hue="fe_model",
            order=order,
            hue_order=fe_model_order,
            palette=palette_fe_model,
            markers=markers_fe_model,
            dodge=0.4,
            markersize=4,
            linestyle="none",  # replaces the deprecated join=False
            errorbar=("sd", 1),
            err_kws={"linewidth": 1.5, "alpha": 0.5},
            capsize=0.25,
            ax=ax,
        )
        ax.set_ylabel(metric_axis_labels[metric_family])
        ax.set_xlabel("")

        ax.grid(visible=True, axis="y")
        if not lower_is_better:
            # Fixed ranges so the correlation figures are comparable.
            if plot_metric.split("_")[-1] != "val":
                y_lim = (0, 0.7)
                y_ticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
            else:
                y_lim = (0.2, 0.85)
                y_ticks = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
            ax.set_ylim(*y_lim)
            ax.set_yticks(y_ticks, y_ticks)

        # Proxy handles: one legend for colour (extractor), one for marker
        # (model). These replace the per-hue-level legend seaborn creates.
        fe_handles = [
            Line2D([], [], marker="o", linestyle="none", color=palette[fe], label=fe)
            for fe in hue_order
        ]
        model_handles = [
            Line2D(
                [], [], marker=marker, linestyle="none", color="darkgrey", label=model
            )
            for model, marker in model_markers.items()
        ]

        leg1 = ax.legend(
            handles=fe_handles,
            bbox_to_anchor=(0.1, 1.2),
            loc="upper left",
            frameon=False,
            ncol=4,
            title="Feature extractor",
        )
        # An Axes holds a single "current" legend, so the first one is
        # re-attached as a plain artist before the second legend replaces it.
        ax.add_artist(leg1)
        leg2 = ax.legend(
            handles=model_handles,
            bbox_to_anchor=(0.5, 1.2),
            loc="upper left",
            frameon=False,
            ncol=4,
            title="Regression model",
        )

        sns.despine(ax=ax)

        # Rotate x ticks 90 degrees
        ax.tick_params(axis="x", labelrotation=90)

        fig.tight_layout()

        # Same file stem for both formats; the drug name is inserted only when
        # the figure is restricted to a single drug.
        drug_suffix = f"_{single_drug}" if single_drug is not None else ""
        out_stem = f"./figures/perf_comp_{plot_metric}_{experiment}{drug_suffix}_all_new"
        fig.savefig(
            f"{out_stem}.png",
            bbox_inches="tight",
            bbox_extra_artists=(leg1, leg2),
            dpi=600,
        )
        fig.savefig(
            f"{out_stem}.svg",
            bbox_inches="tight",
            bbox_extra_artists=(leg1, leg2),
        )

In [None]:
# Best regression model per feature extractor for ONE drug, drawn as a small
# horizontal point plot without a legend (not including drug class).

single_drug = "5-Fluorouracil"

# --- Loop-invariant plot configuration --------------------------------------
pal = sns.color_palette("colorblind", n_colors=5)
palette = {"iCoVAE": pal[0], "VAE": pal[1], "PCA": pal[2], "Raw": "grey"}

# Raw identifiers stored in perf_df -> display labels.
fe_dict = {"none": "Raw", "vae": "VAE", "icovae": "iCoVAE", "pca": "PCA"}
model_dict = {
    "SVR": "SVR",
    "ElasticNet": "ElasticNet",
    "nn": "MLP",
    "RandomForestRegressor": "Random forest",
}

hue_order = ["iCoVAE", "VAE", "PCA", "Raw"]
model_markers = {
    "SVR": "o",
    "ElasticNet": "s",
    "Random forest": "D",
    "MLP": "^",
}
metric_axis_labels = {
    "rmse": "RMSE",
    "pearson": "Pearson correlation",
    "spearman": "Spearman correlation",
}
index_cols = ["fe", "target", "drug_target", "model"]

for metric in ["spearman_r_val", "pearson_r_val", "rmse_val"]:
    # `metric` selects the best model. RMSE is an error, so its best model has
    # the LOWEST value; the correlations are maximised. The test is on the
    # metric family because the loop only yields "*_val" names (comparing
    # against the bare string "rmse" never matches).
    metric_family = metric.split("_")[0]
    lower_is_better = metric_family == "rmse"

    # Plot both the un-suffixed column and the `_val` column.
    for plot_metric in ["_".join(metric.split("_")[:-1]), metric]:
        perf_df_plot = perf_df.reset_index()  # .drop(["n_genes"], axis=1)

        perf_df_plot["fe"] = perf_df_plot["fe"].map(fe_dict)
        perf_df_plot["model"] = perf_df_plot["model"].map(model_dict)
        perf_df_plot["drug_target"] = perf_df_plot["target"].map(target_dict)

        # If validation, drop nans
        if metric.split("_")[-1] == "val":
            perf_df_plot = perf_df_plot.dropna(axis=0)

        # Mean score per (extractor, drug, model), then the best model per
        # (extractor, drug).
        perf_df_best = perf_df_plot.groupby(index_cols).mean().reset_index()
        best_by_fe_target = perf_df_best.groupby(["fe", "target", "drug_target"])[
            metric
        ]
        best_idx = (
            best_by_fe_target.idxmin()
            if lower_is_better
            else best_by_fe_target.idxmax()
        )
        perf_df_best = perf_df_best.loc[best_idx].set_index(index_cols)

        # Keep only the rows that belong to the selected models.
        perf_df_plot = (
            perf_df_plot.set_index(index_cols)
            .loc[perf_df_best.index]
            .reset_index()
            .sort_values("target", ascending=True)
        )

        if single_drug is not None:
            # Copy so the column assignment below does not write to a slice.
            perf_df_plot = perf_df_plot[perf_df_plot["target"] == single_drug].copy()

        # One hue level per (extractor, model): colour encodes the extractor,
        # marker encodes the model.
        perf_df_plot["fe_model"] = perf_df_plot["fe"] + "__" + perf_df_plot["model"]
        present_fe_models = set(perf_df_plot["fe_model"])
        fe_model_order = [
            f"{fe}__{model}"
            for fe in hue_order
            for model in model_markers
            if f"{fe}__{model}" in present_fe_models
        ]
        palette_fe_model = {k: palette[k.split("__")[0]] for k in fe_model_order}
        markers_fe_model = [model_markers[k.split("__")[1]] for k in fe_model_order]

        fig, ax = plt.subplots(1, 1, figsize=(3, 1))

        sns.pointplot(
            data=perf_df_plot,
            x=plot_metric,
            y="target",
            hue="fe_model",
            hue_order=fe_model_order,
            palette=palette_fe_model,
            markers=markers_fe_model,
            dodge=0.4,
            markersize=4,
            linestyle="none",  # replaces the deprecated join=False
            errorbar=("sd", 1),
            err_kws={"linewidth": 1.5, "alpha": 0.5},
            capsize=0.25,
            ax=ax,
        )
        ax.set_xlabel(metric_axis_labels[metric_family])
        ax.set_ylabel("")

        ax.grid(visible=True, axis="x")
        if not lower_is_better:
            # Fixed ranges so the correlation figures are comparable.
            if plot_metric.split("_")[-1] != "val":
                x_lim = (0, 0.7)
                x_ticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
            else:
                x_lim = (0.2, 0.85)
                x_ticks = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
            ax.set_xlim(*x_lim)
            ax.set_xticks(x_ticks, x_ticks)

        # This small panel is saved WITHOUT a legend: drop the one seaborn
        # attaches automatically for the hue variable.
        auto_legend = ax.get_legend()
        if auto_legend is not None:
            auto_legend.remove()
        sns.despine(ax=ax)

        fig.tight_layout()

        out_stem = (
            f"./figures/perf_comp_{plot_metric}_{experiment}_{single_drug}_all_new"
        )
        fig.savefig(f"{out_stem}.png", bbox_inches="tight", dpi=600)
        fig.savefig(f"{out_stem}.svg", bbox_inches="tight")

### 2.1.2 Performance comparison split by type

In [None]:
# Settings for the per-lineage ("by type") comparison of a single drug.
drug, experiment = "AZD5991", "h16"
fes, models = ["icovae", "vae"], ["ElasticNet"]

# Comparison helper for this drug / experiment on the depmap_gdsc dataset.
comp = PerfComp(
    [drug],
    experiment,
    dataset="depmap_gdsc",
    wd_path=wd_path,
)

# Performance table for the requested extractors and models, computed with
# by_type=True and use_cache=True.
perf_df_type = comp.calculate_perf(
    fes,
    models,
    by_type=True,
    use_cache=True,
)

In [None]:
# Mean of every remaining column for each (drug, model, extractor, lineage)
# combination, flattened back into ordinary columns.
perf_df_means = (
    perf_df_type
    .groupby(by=["target", "model", "fe", "lineage"])
    .mean()
    .reset_index()
)

# Show the liver rows only.
perf_df_means.loc[perf_df_means["lineage"].eq("Liver")]

In [None]:
# Display the per-lineage performance table for inspection.
perf_df_type

In [None]:
# PERF COMP BEST TYPE -- TEST
# For each metric: keep the best model per (extractor, drug, lineage), then
# draw one panel per drug with lineages on the y axis and extractors as hue.

# Full lineage names -> short tick labels.
lineage_dict = {
    "Prostate": "Prostate",
    "Uterus": "Uterus",
    "Cervix": "Cervix",
    "Bone": "Bone",
    "Liver": "Liver",
    "Bladder/Urinary Tract": "Bladder/UT",
    "Thyroid": "Thyroid",
    "Peripheral Nervous System": "PNS",
    "Kidney": "Kidney",
    "Pleura": "Pleura",
    "Testis": "Testis",
    "Ovary/Fallopian Tube": "Ovary/FT",
    "Lung": "Lung",
    "Breast": "Breast",
    "CNS/Brain": "CNS/Brain",
    "Head and Neck": "Head/Neck",
    "Pancreas": "Pancreas",
    "Soft Tissue": "Soft Tissue",
    "Esophagus/Stomach": "Stomach",
    "Bowel": "Bowel",
    "Skin": "Skin",
    "Lymphoid": "Lymphoid",
    "Myeloid": "Myeloid",
}

# NOTE(review): not used for plotting in this cell; kept in case another cell
# reads it.
drug_targets = sorted(list(set(target_dict.values())))

pal = sns.color_palette("colorblind", n_colors=5)
palette = {"iCoVAE": pal[0], "VAE": pal[1], "NN": pal[2]}

# Raw identifiers stored in perf_df_type -> display labels.
fe_dict = {"nn": "NN", "vae": "VAE", "icovae": "iCoVAE"}
model_dict = {"SVR": "SVR", "ElasticNet": "ElasticNet", "transfer": "Transfer"}

hue_order = ["VAE", "iCoVAE"]
metric_axis_labels = {
    "spearman": "Spearman correlation",
    "pearson": "Pearson correlation",
    "rmse": "RMSE",
}
index_cols = ["fe", "target", "model", "lineage"]

for metric in [
    "rmse",
    "pearson_r",
    "spearman_r",
    "rmse_val",
    "pearson_r_val",
    "spearman_r_val",
]:
    # RMSE is an error (smaller is better); the correlations are maximised.
    metric_family = metric.split("_")[0]
    lower_is_better = metric_family == "rmse"

    perf_df_plot = perf_df_type.copy()

    perf_df_plot["fe"] = perf_df_plot["fe"].map(fe_dict)
    perf_df_plot["model"] = perf_df_plot["model"].map(model_dict)
    perf_df_plot["lineage"] = perf_df_plot["lineage"].map(lineage_dict)

    # Only the current metric is needed; rows with a missing value in any of
    # the kept columns are dropped.
    perf_df_plot = perf_df_plot[["fe", "target", "model", "lineage", metric]].dropna(
        axis=0
    )

    # Mean score per (extractor, drug, model, lineage), then the best model
    # per (extractor, drug, lineage): lowest for RMSE, highest otherwise.
    perf_df_best = perf_df_plot.groupby(index_cols).mean().reset_index()
    best_by_group = perf_df_best.groupby(["fe", "target", "lineage"])[metric]
    best_idx = best_by_group.idxmin() if lower_is_better else best_by_group.idxmax()
    perf_df_best = perf_df_best.loc[best_idx].set_index(index_cols)

    # Keep only rows of the selected models, best-scoring rows first.
    perf_df_plot = (
        perf_df_plot.set_index(index_cols)
        .loc[perf_df_best.index]
        .reset_index()
        .sort_values(metric, ascending=lower_is_better)
    )

    f = sns.catplot(
        perf_df_plot,
        x=metric,
        y="lineage",
        hue="fe",
        col="target",
        palette=palette,
        hue_order=hue_order,
        markersize=4,
        kind="point",
        linestyle="none",
        height=2.5,
        aspect=0.85,
        sharey=True,
        capsize=0.25,
        errorbar=("sd", 1),
        err_kws={"linewidth": 1.5, "alpha": 0.5},
    )
    f.set_axis_labels(metric_axis_labels[metric_family], "")

    axes = f.axes.flatten()
    for ax in axes:
        ax.tick_params(
            top=False,
            bottom=True,
            left=True,
            right=False,
            labelleft=True,
            labelbottom=True,
        )

        ax.set_title("")

        ax.grid(visible=True, axis="x")
        if not lower_is_better:
            # Fixed correlation range with a reference line at zero.
            ax.set_xlim(-0.2, 1.05)
            ax.set_xticks(
                [-0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                [-0.2, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            )
            ax.axvline(0, linestyle="--", color="grey")

    # Replace the figure-level seaborn legend with a compact legend placed
    # above the first panel.
    handles, _ = plt.gca().get_legend_handles_labels()

    leg1 = axes[0].legend(
        handles=handles,
        bbox_to_anchor=(0.4, 1.2),
        loc="upper center",
        frameon=False,
        ncol=3,
    )

    # Public accessor; it is None when the grid has no legend.
    if f.legend is not None:
        f.legend.remove()

    f.figure.tight_layout()

    # Figures for the "*_val" metrics carry a "_train" suffix in the file name.
    split_suffix = "_train" if metric.split("_")[-1] == "val" else ""
    out_stem = f"./figures/perf_comp_{metric}_type_{drug}{split_suffix}"
    f.figure.savefig(
        f"{out_stem}.png",
        bbox_inches="tight",
        bbox_extra_artists=(leg1,),
        dpi=600,
    )
    f.figure.savefig(
        f"{out_stem}.svg",
        bbox_inches="tight",
        bbox_extra_artists=(leg1,),
    )

### 2.1.4 Correlation of latent features with drug response

In [None]:
# PRED CORR WITH DRUG RESPONSE
# Point plot of the Spearman correlation between predicted and true response,
# split by mutation status ("wt", "mut", "all") of `gene_mut`.
pred_corrs_df = pd.DataFrame(pred_corrs)
print(pred_corrs_df)
pred_corrs_df["abs_corr"] = pred_corrs_df["s_corr"].abs()

pal = sns.color_palette("colorblind")
palette = {"mut": pal[1], "wt": pal[0], "all": "grey"}

# Mean over the "all" rows; not used by the plot below.
pred_corrs_mean = (
    pred_corrs_df[pred_corrs_df["mut_stat"] == "all"]
    .groupby(["mut_stat"])
    .mean()
    .sort_values("abs_corr", ascending=False)
    .reset_index()
    .head(15)
)

f, ax = plt.subplots(1, 1, figsize=(3, 0.5))
ax.axvline(0, color="grey", linestyle="--", alpha=0.5)
sns.pointplot(
    data=pred_corrs_df,
    x="s_corr",
    hue="mut_stat",
    hue_order=["wt", "mut", "all"],
    palette=palette,
    linestyle="none",
    markersize=4,
    errorbar=("sd", 1),
    capsize=0.25,
    err_kws={"linewidth": 1.5, "alpha": 0.5},
    ax=ax,  # draw on the axes created above, not the implicit current axes
)
sns.despine(ax=ax, left=True)
ax.set_ylabel("")
ax.set_xlabel(f"Spearman correlation\n pred. vs true {drug.lower()} ln(IC50)")
ax.tick_params(axis="y", which="major", labelsize=10)
ax.set_yticks([], [])

# Relabel the hue levels; the legend title is the part of `gene_mut` before
# the first underscore, set in italics.
handles, _ = ax.get_legend_handles_labels()
ax.legend(
    handles=handles,
    labels=["wild-type", "mutant", "all"],
    title=gene_mut.split("_")[0],
    fontsize="medium",
    ncol=2,
    frameon=False,
    bbox_to_anchor=(0.5, 3),
    loc="upper center",
    title_fontproperties={"style": "italic", "size": "medium"},
)

f.savefig(f"./figures/mut_wt_pred_{drug}_{gene_mut}.png", dpi=600, bbox_inches="tight")
f.savefig(f"./figures/mut_wt_pred_{drug}_{gene_mut}.svg", bbox_inches="tight")

In [None]:
# MUT WT Z CORR WITH DRUG RESPONSE
# Two figures of the Spearman correlation between latent dimensions and drug
# response, split by mutation status of `gene_mut`:
#   1. a point plot over every latent dimension;
#   2. a box plot of hand-picked dimensions plus all remaining ones pooled.
dim_corrs_df = pd.DataFrame(dim_corrs)

# Mean per (dimension, mutation status), computed once and printed sorted in
# both directions (50 highest, then 50 lowest mean correlations).
dim_means = dim_corrs_df.groupby(["dim", "mut_stat"]).mean()
print(dim_means.sort_values("s_corr", ascending=False).head(50))
print(dim_means.sort_values("s_corr").head(50))
dim_corrs_df["abs_corr"] = dim_corrs_df["s_corr"].abs()

# Defined here so the cell does not depend on a `palette` left behind by
# another cell.
pal = sns.color_palette("colorblind")
palette = {"mut": pal[1], "wt": pal[0], "all": "grey"}

# Mean over the "all" rows, 15 largest absolute correlations; not used by the
# plots below.
dim_corrs_mean = (
    dim_corrs_df[dim_corrs_df["mut_stat"] == "all"]
    .groupby(["dim", "mut_stat"])
    .mean()
    .sort_values("abs_corr", ascending=False)
    .reset_index()
    .head(15)
)

# Pointplot with all dimensions
f, ax = plt.subplots(1, 1, figsize=(3, 5))
ax.axvline(0, color="grey", linestyle="--", alpha=0.5)
sns.pointplot(
    data=dim_corrs_df,
    x="s_corr",
    y="dim",
    hue="mut_stat",
    hue_order=["wt", "mut"],
    palette=palette,
    linestyle="none",
    markersize=4,
    errorbar=("sd", 1),
    capsize=0.25,
    err_kws={"linewidth": 1.5, "alpha": 0.5},
    ax=ax,  # draw on the axes created above, not the implicit current axes
)
sns.despine(ax=ax)
ax.set_ylabel("")
ax.set_xlabel(f"Spearman correlation\nwith {drug.lower()} ln(IC50)")
ax.tick_params(axis="y", which="major", labelsize=10)
handles, _ = ax.get_legend_handles_labels()
ax.legend(
    handles=handles,
    labels=["wild-type", "mutant"],
    title=gene_mut.split("_")[0],
    fontsize="medium",
    ncol=3,
    frameon=False,
    bbox_to_anchor=(0.45, 1.15),
    loc="upper center",
    title_fontproperties={"style": "italic", "size": "medium"},
)
f.savefig(f"./figures/mut_wt_corr_{drug}_{gene_mut}.png", dpi=600, bbox_inches="tight")
f.savefig(f"./figures/mut_wt_corr_{drug}_{gene_mut}.svg", bbox_inches="tight")

# Boxplot with selected dimensions then all others
sel_genes_p53 = ["MDM2", "MDM4", "PPM1D"]
sel_genes_other = ["H2AZ1", "TTF2", "DCLRE1B"]
sel_dims_p53 = [f"$z_{{{gene}}}$" for gene in sel_genes_p53]
sel_dims_other = [f"$z_{{{gene}}}$" for gene in sel_genes_other]

# Category per row: "<group>_<dim>" for a selected dimension, "Other" for the
# rest. The group prefix is stripped again for the tick labels further down.
dim_corrs_df["dim_type"] = dim_corrs_df["dim"].apply(
    lambda x: f"$p53$-related_{x}"
    if x in sel_dims_p53
    else (f"Predictive_{x}" if x in sel_dims_other else "Other")
)

# Least frequent categories first.
order = list(reversed(dim_corrs_df["dim_type"].value_counts().index.tolist()))
print(order)
# Switch order of 3rd and 4th elements
# NOTE(review): this manual swap (and the separator lines at 2.5 / 5.5 below)
# depends on the order value_counts() returns; confirm with the printed lists.
order = [order[0], order[1], order[3], order[2]] + order[4:]
print(order)
f, ax = plt.subplots(1, 1, figsize=(2.5, 2.5))

ax.axvline(0, color="grey", linestyle="--", alpha=0.5)
# Horizontal separators after the 3rd and 6th categories.
ax.axhline(2.5, color="black", linestyle="-", alpha=0.5)
ax.axhline(5.5, color="black", linestyle="-", alpha=0.5)

sns.boxplot(
    data=dim_corrs_df[dim_corrs_df["mut_stat"] != "all"],
    x="s_corr",
    y="dim_type",
    hue="mut_stat",
    order=order,
    hue_order=["wt", "mut"],
    palette=palette,
    fliersize=2,
    width=0.6,
    dodge=True,
    ax=ax,  # draw on the axes created above, not the implicit current axes
)
sns.despine(ax=ax)
ax.set_ylabel("")
ax.set_xlabel(f"Spearman correlation\nwith {drug.lower()} response ln(IC50)")
handles, _ = ax.get_legend_handles_labels()
ax.legend(
    handles=handles,
    labels=["wild-type", "mutant"],
    title=gene_mut.split("_")[0],
    fontsize="medium",
    ncol=3,
    frameon=False,
    bbox_to_anchor=(0.475, 1.3),
    loc="upper center",
    title_fontproperties={"style": "italic", "size": "medium"},
)


# Set the y tick labels to just the part of the name after the first
# underscore (i.e. the dimension, without its group prefix).
new_labels = ["_".join(label.split("_")[1:]) for label in order]
# "Other" has no underscore, so the split above leaves an empty string.
new_labels[-1] = "Other"
print(new_labels)
ax.set_yticklabels(new_labels)
ax.set_xlim(-0.2, 0.6)
ax.tick_params(axis="y", which="major", labelsize=10)
f.savefig(
    f"./figures/mut_wt_corr_{drug}_{gene_mut}_new.png", dpi=600, bbox_inches="tight"
)
f.savefig(f"./figures/mut_wt_corr_{drug}_{gene_mut}_new.svg", bbox_inches="tight")

### 2.1.5 Permutation feature importance

In [None]:
### FEATURE IMPORTANCE
from utils.comp_utils import calculate_feat_imps

# Which trained model to analyse.
dataset_name = "depmap_gdsc"
experiment = "h16"
target = "5-Fluorouracil"
fe = "icovae"
model = "ElasticNet"

# Seeds 10, 20, ..., 100.
seeds = list(range(10, 101, 10))
train = True

# (An older, disabled loader based on PerfComp._load_preds used to sit here.)

# The output directory name uses the upper-cased drug name plus a "_new"
# suffix, followed by the experiment and the "<model>_<fe>" sub-folder.
pred_dict_list, constraints, confounders, pi_df = calculate_feat_imps(
    enc=fe,
    reg=model,
    model_path=f"{wd_path}/data/outputs/{dataset_name}/{target.upper()}_new/{experiment}/pico/{model}_{fe}",
    target="",
    seeds=seeds,
    train=train,
)

In [None]:
# PERMUTATION FEATURE IMPORTANCE NEW
# For each metric ("s", "r", "rmse") this cell draws one horizontal bar per
# latent dimension showing the percentage change of the metric after permuting
# that dimension (mean over seeds, error bar = 1 SD, individual seeds as dots),
# then saves the figure as SVG and PNG under ./figures/.
#
# Reads from earlier cells: pi_df, train, target, fe, model, experiment.
# Leaves behind (among others): pi_df_plot and metric from the LAST iteration.
from matplotlib.colors import rgb_to_hsv

n_feats = 32  # maximum number of dimensions plotted; also sets the figure height

cbar_labels = {
    "r": r"$\Delta \% \rho_{{p}}$",
    "s": r"$\Delta \% \rho_{{s}}$",
    "rmse": r"$\Delta\%$ RMSE",
}
metric_labels = {
    "s": f"Feature importance\n({cbar_labels['s']})",
    "r": f"Feature importance\n({cbar_labels['r']})",
    "rmse": f"Feature importance\n({cbar_labels['rmse']})",
}

# Hues (in degrees) of the first two colours of the colorblind palette; they
# form the two ends of the diverging colour map used for the bars.
pal = sns.color_palette("colorblind")
hue_0 = rgb_to_hsv(pal[0])[0] * 360
hue_1 = rgb_to_hsv(pal[1])[0] * 360

# --- Metric-independent preparation (done once instead of once per metric) ---
# Percentage change of every metric after permutation, relative to baseline.
pi_df_seed = pi_df.copy()
for fi_metric in cbar_labels:
    pi_df_seed[f"fi_{fi_metric}"] = (
        100
        * (pi_df_seed[f"{fi_metric}_perm"] - pi_df_seed[fi_metric])
        / pi_df_seed[fi_metric]
    )
# Turn "<prefix>_<name>" into the math label "$z_{<name>}$".
pi_df_seed["dim"] = pi_df_seed["dim"].apply(lambda x: x.split("_")[1])
pi_df_seed["dim"] = pi_df_seed["dim"].apply(lambda x: rf"$z_{{{x}}}$")
# Group by dim and seed and take mean over perturbation iterations
pi_df_seed = pi_df_seed.groupby(["dim", "seed"]).mean().reset_index()
# Per-dimension means over seeds, used only to order the bars.
dim_means = pi_df_seed.groupby("dim").mean()

for metric in ["s", "r", "rmse"]:
    f, ax = plt.subplots(
        1,
        1,
        figsize=(3, n_feats / 6),
    )

    # Order dimensions by their mean permuted metric (ascending) and keep at
    # most n_feats of them. The truncation is applied to plot_order itself so
    # that the .loc[plot_order] lookups below cannot hit dimensions that were
    # filtered out of the data (previously a KeyError with > n_feats dims).
    # NOTE(review): ascending order puts the most affected dimensions first for
    # the correlations but last for rmse — confirm this is intended.
    plot_order = dim_means.sort_values(by=f"{metric}_perm").index.tolist()[:n_feats]
    pi_df_plot = pi_df_seed[pi_df_seed["dim"].isin(plot_order)]

    # dim x seed table of the importance values. It is not used by the bar plot
    # below; it is kept because later cells may read this notebook-level name.
    pi_df_plot_hm = (
        pd.melt(
            pi_df_plot[["dim", f"fi_{metric}", "seed"]],
            id_vars=["dim", "seed"],
            value_vars=[f"fi_{metric}"],
        )
        .drop("variable", axis=1)
        .sort_values("value")
        .reset_index()
    )
    pi_df_plot_hm = pi_df_plot_hm.pivot_table(
        index="dim", columns="seed", values="value", sort=True
    )
    pi_df_plot_hm = pi_df_plot_hm.loc[plot_order]

    # Put the rows in plotting order so seaborn draws the bars in that order.
    pi_df_plot = pi_df_plot.set_index("dim").loc[plot_order].reset_index()

    print(pi_df_plot.groupby(["dim"]).mean().sort_values(f"fi_{metric}"))

    # The two hues are swapped for rmse so that "metric got worse after
    # permutation" (rmse up, correlation down) maps to the pal[0] hue for
    # every metric.
    if metric == "rmse":
        h_neg, h_pos = hue_1, hue_0
    else:
        h_neg, h_pos = hue_0, hue_1
    pal_cm = sns.diverging_palette(
        h_neg,
        h_pos,
        s=100,
        center="light",
        as_cmap=True,
    )

    # BARPLOT
    sns.barplot(
        data=pi_df_plot,
        y="dim",
        x=f"fi_{metric}",
        estimator="mean",
        errorbar=("sd", 1),
        ax=ax,
        capsize=0.25,
        dodge=False,
        err_kws={"linewidth": 1, "alpha": 0.5},
    )
    sns.stripplot(
        data=pi_df_plot,
        y="dim",
        x=f"fi_{metric}",
        color="black",
        size=2,
        alpha=0.5,
        ax=ax,
        jitter=True,
    )
    ax.axvline(0, linestyle="--", color="grey", alpha=0.5)
    ax.set_ylim(len(plot_order) - 0.5, -0.5)
    y_tick_pos = list(range(len(plot_order)))

    # The default tick labels are removed; each dimension is labelled with a
    # coloured text next to its bar instead.
    ax.set_yticks([], [])
    ax.set_ylabel("")
    ax.set_xlabel(metric_labels[metric])

    # Colour each bar by whether its error bar lies entirely above zero,
    # entirely below zero, or straddles zero (grey).
    # Assumes ax.lines[i] is the error bar of ax.patches[i]; the axvline was
    # added after the bar plot, so it comes after the error bars in ax.lines.
    # Labels sit at twice the current upper x limit; adding text or recolouring
    # bars does not change that limit, so it is read once.
    label_x = ax.get_xlim()[1] * 2.0
    # zip() stops at the shorter sequence, so an unexpected extra patch can no
    # longer index past the end of plot_order / y_tick_pos.
    for i, (bar, dim_label) in enumerate(zip(ax.patches, plot_order)):
        # Neutral defaults: the label colour is always defined, even if the
        # height guard below is not met (previously `color` could be unbound
        # or carried over from the previous bar).
        bar_color, label_color = "lightgrey", "grey"
        if bar.get_height() > 0:
            err_x = ax.lines[i].get_xdata()
            err_first, err_last = err_x[0], err_x[-1]
            if (err_last > 0) and (err_first > 0):
                bar_color, label_color = pal_cm(0.75), pal_cm(1.0)
            elif (err_last < 0) and (err_first < 0):
                bar_color, label_color = pal_cm(0.25), pal_cm(0.0)
            bar.set_color(bar_color)

        ax.text(
            label_x,
            y_tick_pos[i],
            dim_label,
            horizontalalignment="left",
            verticalalignment="center",
            color=label_color,
            fontsize=10,
        )

    # For the correlation metrics the x axis is flipped, so a drop in
    # correlation points to the right, like an increase in rmse does.
    if metric != "rmse":
        ax.invert_xaxis()

    sns.despine(ax=ax, bottom=True, left=True)

    plt.subplots_adjust(wspace=0.00, hspace=0.05)

    # Same file names as before: "..._new_train.*" when train is truthy,
    # "..._new.*" otherwise.
    fig_suffix = "_new_train" if train else "_new"
    fig_stem = f"./figures/fi_{target}_{fe}_{model}_{experiment}_{metric}{fig_suffix}"
    plt.savefig(f"{fig_stem}.svg", bbox_inches="tight")
    plt.savefig(f"{fig_stem}.png", bbox_inches="tight", dpi=600)

In [None]:
# Per-dimension means of the plotted frame, ordered by fi_s, written to CSV and
# previewed.
# NOTE(review): `pi_df_plot` and `metric` are whatever the LAST iteration of the
# plotting loop left behind, so the file name carries that metric while the
# rows are sorted by fi_s — confirm this is the intended naming.
per_dim_mean = pi_df_plot.groupby(["dim"]).mean()
pi_df_summ = per_dim_mean.sort_values("fi_s")

csv_path = (
    f"./figures/fi_{target}_{fe}_{model}_{experiment}_{metric}_train.csv"
    if train
    else f"./figures/fi_{target}_{fe}_{model}_{experiment}_{metric}.csv"
)
pi_df_summ.to_csv(csv_path)

pi_df_summ.head(50)