### Setup
Run the cells in this section before anything else.

In [None]:
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'facecolor': "w"}
%load_ext autoreload
%autoreload 2

In [None]:
# Setup
import sys
import os
import json
import random
import urllib.request

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib import rc
import sklearn.cluster
import pyreadr
from scipy.stats import pearsonr
from tqdm.auto import tqdm

# Root of the PiCo `src` directory; data, font and output paths below are built from it.
# Set the PICO_SRC environment variable to run from a different checkout.
wd_path = os.environ.get("PICO_SRC", "/home/dk538/rds/hpc-work/pico/src")
# Guard so re-running this cell does not keep appending the same entry.
if wd_path not in sys.path:
    sys.path.append(wd_path)

from utils.data_utils import gene_column_renamer  # project module; needs wd_path on sys.path

#from models.clinical import ClinicalLogisticRegression, ClinicalRandomForest, ClinicalSVC

# --- Fonts: Source Sans 3 (regular, bold, italic) ---------------------------
font_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-Regular.ttf?raw=True"
font_path = f"{wd_path}/fonts/SourceSans3-Regular.ttf"  # Specify where to save the font
font_bold_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-Bold.ttf?raw=True"
font_bold_path = f"{wd_path}/fonts/SourceSans3-Bold.ttf"  # Specify where to save the font
font_it_url = "https://github.com/adobe-fonts/source-sans/blob/release/TTF/SourceSans3-It.ttf?raw=True"
font_it_path = f"{wd_path}/fonts/SourceSans3-It.ttf"  # Specify where to save the font

os.makedirs(f"{wd_path}/fonts", exist_ok=True)
for font_src, font_dst in (
    (font_url, font_path),
    (font_bold_url, font_bold_path),
    (font_it_url, font_it_path),
):
    # Download once; later runs reuse the local copy.
    if not os.path.exists(font_dst):
        urllib.request.urlretrieve(font_src, font_dst)
    # Register the file with matplotlib directly. This avoids copying it into
    # ~/.local/share/fonts, rebuilding the font cache and restarting the kernel.
    fm.fontManager.addfont(font_dst)

fm.findfont("Source Sans 3", rebuild_if_missing=True)

# --- Global plot style -------------------------------------------------------
# Reset to matplotlib defaults, then apply the seaborn theme once.
plt.style.use("default")
sns.set_theme(
    context="paper",
    style="ticks",
    palette="colorblind",
    rc={
        "axes.linewidth": 1,
        "xtick.major.width": 1,
        "ytick.major.width": 1,
        "axes.edgecolor": "grey",
        "xtick.labelcolor": "black",
        "xtick.color": "grey",
        "ytick.labelcolor": "black",
        "ytick.color": "grey",
    },
)

# The font must be set after sns.set_theme, which resets font rcParams.
rc("font", **{"family": "sans-serif", "sans-serif": ["Source Sans 3"]})
plt.rcParams["mathtext.fontset"] = "custom"
plt.rcParams["mathtext.it"] = "Source Sans 3:italic"

# 4 External validation using TransNEO + PBCP (Sammut et al. 2021)

## 4.0 Data loading and preprocessing

### 4.0.1 Data loading

In [None]:
data_path = f"{wd_path}/data/transneo"
os.makedirs(data_path, exist_ok=True)

## DATA DOWNLOAD
# Public inputs, each fetched only if it is not already present in data_path:
#   - TransNEO RNA-seq (raw counts, validation TPM) from the Sammut et al. predictor repo
#   - feature tables (training / HER2- test / HER2+ test) from the NAT-ML repo
sammut_base_url = "https://github.com/cclab-brca/neoadjuvant-therapy-response-predictor/raw/refs/heads/master/data"
natml_base_url = "https://raw.githubusercontent.com/micrisor/NAT-ML/refs/heads/main/inputs"
for base_url, file_name in [
    (sammut_base_url, "transneo-diagnosis-RNAseq-rawcounts.tsv.gz"),
    (sammut_base_url, "transneo-diagnosis-RNAseq-validTPM.Rdata"),
    (natml_base_url, "testing_her2pos_df.csv"),
    (natml_base_url, "testing_her2neg_df.csv"),
    (natml_base_url, "training_df.csv"),
]:
    local_file = f"{data_path}/{file_name}"
    if not os.path.exists(local_file):
        urllib.request.urlretrieve(f"{base_url}/{file_name}", local_file)

# ENSEMBL ID to HGNC map and gene lengths are not downloadable here; they ship with the PiCo repo.
if not os.path.exists(f"{data_path}/gene_lengths.txt"):
    raise ValueError("gene_lengths.txt not present. Please download from PiCo github repo.")

## DATA LOADING
# resp_df = pd.read_excel(
#     f"{data_path}/Supplementary Tables.xlsx", sheet_name="Supplementary Table 1"
# ).set_index("Donor.ID")
# resp_df_val = pd.read_excel(
#     f"{data_path}/Supplementary Tables.xlsx", sheet_name="Supplementary Table 4"
# ).set_index("Donor.ID")


def _read_feature_table(file_name, drop_unnamed=True):
    """Read a NAT-ML feature CSV from data_path, indexed by Trial.ID.

    drop_unnamed removes the leftover "Unnamed: 0" index column.
    """
    table = pd.read_csv(f"{data_path}/{file_name}")
    if drop_unnamed:
        table = table.drop("Unnamed: 0", axis=1)
    return table.set_index("Trial.ID")


feat_df = _read_feature_table("training_df.csv")
# Validation features: HER2- and HER2+ test tables stacked row-wise.
# The HER2+ table is read without dropping "Unnamed: 0".
feat_df_val_1 = _read_feature_table("testing_her2neg_df.csv")
feat_df_val_2 = _read_feature_table("testing_her2pos_df.csv", drop_unnamed=False)
feat_df_val = pd.concat([feat_df_val_1, feat_df_val_2], axis=0)

### 4.0.2 RNASeq: Counts to TPM

In [None]:
def normalize_expression(X:pd.DataFrame, method:str="tpm", length:pd.Series=None):
    """
    Normalize a count matrix whose rows are samples and columns are genes.

    Parameters
    ----------
    X : pd.DataFrame
        Raw counts. Row sums are used as the library size (FPKM/RPKM), and TPM
        is scaled per row.
    method : str or None
        One of "tpm", "rpk", "fpkm", "rpkm" (case-insensitive). None returns X unchanged.
    length : pd.Series or dict-like
        Gene lengths, or gene sequences (converted with ``len``; assumes CDS and
        no introns), indexed by gene. Every column of X must be present.

    Returns
    -------
    pd.DataFrame with the same index and columns as X.

    Raises
    ------
    ValueError
        If ``method`` is not supported.
    AssertionError
        If ``length`` is None or lacks any column of X.

    # TPM
    Transcripts Per Kilobase Million
        (1) Divide the read counts by the length of each gene. This gives reads per length unit (RPK).
        (2) Sum all RPK values in a sample and divide this number by 1,000,000. This is the "per million" scaling factor.
        (3) Divide the RPK values by the "per million" scaling factor. This gives TPM (each row sums to 1e6).

    # Notes:
    Proper pseudocount addition would be the following as shown by metagenomeSeq's MRcounts
    log(X_normalized + 1)
    """
    if method is None:
        return X

    method = method.lower()
    if method not in {"fpkm", "rpkm", "rpk", "tpm"}:
        # Previously an unknown method fell through and silently returned None.
        raise ValueError(f"Unsupported normalization method: {method!r}")

    # Lengths
    assert length is not None, "If FPKM, RPKM, RPK or TPM is chosen as the method then `length` cannot be None.  It must be either a pd.Series of sequences or sequence lengths"
    # reindex (not []) so that missing genes become NaN and trip the check below,
    # instead of raising an opaque KeyError.
    length = pd.Series(length).reindex(X.columns)
    assert length.isnull().sum() == 0, "Not all of the genes in `X.columns` are in `length.index`.  Either use a different normalization or get the missing sequence lengths"
    # If sequences are given then convert to length (assumes CDS and no introns)
    if pd.api.types.is_string_dtype(length):
        length = length.map(len)

    C = X.values
    L = length.values

    if method in {"fpkm", "rpkm"}:
        N = X.sum(axis=1).values.reshape(-1, 1)  # library size per sample (row)
        return pd.DataFrame(1e9 * C / (N * L), index=X.index, columns=X.columns)

    rpk = C / L
    if method == "rpk":
        return pd.DataFrame(rpk, index=X.index, columns=X.columns)

    # method == "tpm"
    per_million_scaling_factor = (rpk.sum(axis=1) / 1e6).reshape(-1, 1)
    return pd.DataFrame(rpk / per_million_scaling_factor, index=X.index, columns=X.columns)


def counts_to_tpm(counts_df, lengths_df):
    """Convert a samples x genes raw-count matrix to TPM.

    Only genes present both in ``counts_df.columns`` and in ``lengths_df["gene"]``
    are kept (in sorted order). Genes without a length are dropped instead of
    causing a failure, and the TPM scaling is computed over the kept genes.

    Parameters
    ----------
    counts_df : pd.DataFrame
        Raw counts; rows are samples, columns are gene IDs.
    lengths_df : pd.DataFrame
        Must have "gene" and "length" columns. A gene listed several times
        (e.g. multiple canonical transcripts) uses its largest length.

    Returns
    -------
    pd.DataFrame of TPM values for the kept genes.
    """
    gene_int = sorted(set(counts_df.columns).intersection(lengths_df["gene"]))
    counts_df_filt = counts_df[gene_int]
    lengths = lengths_df.set_index("gene").loc[gene_int]
    # If multiple canonical transcripts take the longest one
    lengths = lengths.groupby(lengths.index).max()["length"]
    # Normalize the *filtered* counts. Every remaining gene then has a length;
    # the unfiltered matrix fails on genes missing from lengths_df.
    tpm_df = normalize_expression(counts_df_filt, method="tpm", length=lengths)
    return tpm_df


# Load RNASeq raw counts -- training data.
# After the transpose, rows are samples and columns are the gene IDs taken from
# the file's first (unnamed) column.
raw_counts = pd.read_csv(
    f"{data_path}/transneo-diagnosis-RNAseq-rawcounts.tsv.gz", sep="\t"
).set_index("Unnamed: 0").T

# Gene lengths for TPM normalization. The file included is as used in Sammut et al.
lengths = pd.read_csv(f"{data_path}/gene_lengths.txt", sep="\t")

# If using a BioMART export instead, uncomment and build a gene/length table:
#lengths = lengths[lengths["Ensembl Canonical"] == 1.0]
#gene_int = set(raw_counts.columns).intersection(lengths["Gene stable ID"])
#lengths_df = pd.DataFrame(
#     {
#         "gene": lengths["Gene stable ID"],
#         "length": lengths["Transcript length (including UTRs and CDS)"],
#     }
# )

tpm = counts_to_tpm(raw_counts, lengths)
# log2(TPM + 1): the CCLE-style scale also used for the validation data
log_tpm = np.log2(1 + tpm)
exp_df = log_tpm

In [None]:
exp_df

In [None]:
#exp_df.to_csv("../data/transneo_exp_train.csv")

In [None]:
# np.savetxt("./ext_val_sammut/raw_counts_ensg.txt", raw_counts.columns.astype(str).tolist(), fmt="%s")

In [None]:
# lengths["HGNC symbol"].to_csv("./ext_val_sammut/train_genes.txt", index=False)

In [None]:
# Convert genes to HGNC symbol from BioMART export (not used)

# exp_df = exp_df.rename(
#     columns=dict(zip(lengths["Gene stable ID"], lengths["HGNC symbol"]))
# )

In [None]:
# Load RNASeq log2TPM for validation set.
# Read from data_path, where the download cell stores the file; the previous
# cwd-relative "../data/transneo" path only matched when run from one directory.
# NOTE(review): pyreadr keys results by R object name and uses None for
# single-object files; confirm this file holds one unnamed object.
exp_df_val = pyreadr.read_r(
    f"{data_path}/transneo-diagnosis-RNAseq-validTPM.Rdata"
)[None]
# Convert to CCLE format (log2(TPM+1)): 2**x recovers TPM from log2 TPM, then
# re-log with the +1 pseudocount used for the training data.
exp_df_val = np.log2(2**exp_df_val + 1)

In [None]:
# This has gene symbols, so we want to check these and then convert back to ENSG
exp_df_val

In [None]:
# Write the validation gene symbols, one per line with no header or index, for
# upload to the HGNC multi-symbol checker.
val_gene_symbols = exp_df_val.reset_index()["rownames"]
val_gene_symbols.to_csv(
    "../data/transneo/val_genes.txt", header=None, sep="\t", index=False
)

In [None]:
# Inputs prepared by hand (both files downloaded 16/12/24):
#  - HGNC multi-symbol checker result for val_genes.txt
#    (https://www.genenames.org/tools/multi-symbol-checker/); header=1 takes the
#    column names from the second line of the export.
#  - BioMART (https://www.ensembl.org/biomart/martview) export mapping current
#    approved HGNC symbols to Ensembl gene IDs.
transneo_dir = "../data/transneo"
hgnc_symbol_check = pd.read_csv(f"{transneo_dir}/hgnc_symbol_checker.csv", header=1)
hgnc_to_ensg = pd.read_csv(f"{transneo_dir}/biomart_hgnc_to_ensg.csv")

In [None]:
# Input symbol -> approved HGNC symbol (later rows win on duplicate inputs)
hgnc_symbol_check_map = dict(
    zip(hgnc_symbol_check["Input"], hgnc_symbol_check["Approved symbol"])
)
# Approved HGNC symbol -> Ensembl gene ID; if a symbol has several IDs, the last row wins
hgnc_to_ensg_map = dict(
    zip(hgnc_to_ensg["HGNC symbol"], hgnc_to_ensg["Gene stable ID"])
)

In [None]:
# Row labels: input symbol -> approved symbol -> Ensembl ID; then transpose so rows are samples
exp_df_val = (
    exp_df_val
    .rename(hgnc_symbol_check_map, axis=0)
    .rename(hgnc_to_ensg_map, axis=0)
    .T
)

In [None]:
#exp_df.transpose().reset_index().groupby("Unnamed: 0").max().transpose().dropna(axis=1, how="all")
#exp_df_val.transpose().reset_index().groupby("rownames").max().transpose().dropna(axis=1, how="all")

### 4.0.3 RNASeq: Preprocessing

In [None]:
def _collapse_duplicate_genes(expr, gene_col):
    """Average columns sharing a gene ID and drop genes that are NaN in every sample.

    ``expr`` is samples x genes; its column index must be named ``gene_col``.
    """
    return (
        expr.T
        .reset_index()
        .groupby(gene_col)
        .mean()
        .T
        .dropna(axis=1, how="all")
    )


# Remove duplicate genes by taking the mean, in each dataset separately
exp_df_val = _collapse_duplicate_genes(exp_df_val, "rownames")
exp_df = _collapse_duplicate_genes(exp_df, "Unnamed: 0")
# Restrict both datasets to their shared genes, in sorted order
shared_genes = sorted(set(exp_df_val.columns) & set(exp_df.columns))
exp_df = exp_df[shared_genes]
exp_df_val = exp_df_val[shared_genes]

In [None]:
exp_df

In [None]:
exp_df_val

In [None]:
# Impute missing values per gene (column) with that gene's mean over the
# training samples. The validation set is filled with the training means, so
# no validation statistics enter the imputation.
exp_df = exp_df.fillna(exp_df.mean(axis=0))
train_gene_means = exp_df.mean(axis=0)
exp_df_val = exp_df_val.fillna(train_gene_means)

In [None]:
# print(len(shared_genes))
# np.savetxt(
#     fname="./ext_val_sammut/sammut_genes.csv", X=shared_genes, fmt="%s", delimiter=","
# )
# sammut_genes = pd.read_csv("./ext_val_sammut/sammut_genes.csv", header=None)[0].tolist()
# print(len(set(sammut_genes).symmetric_difference(shared_genes)))

In [None]:
# Cache the filtered, gene-aligned training and validation matrices
for frame, stem in (
    (exp_df, "transneo_exp_filt"),
    (exp_df_val, "transneo_exp_filt_val"),
):
    frame.to_csv(f"../data/transneo/{stem}.csv")

In [None]:
# Reload the cached matrices; the unnamed first CSV column holds the sample IDs
exp_df, exp_df_val = (
    pd.read_csv(f"../data/transneo/{stem}.csv").set_index("Unnamed: 0")
    for stem in ("transneo_exp_filt", "transneo_exp_filt_val")
)

### 4.0.4 RNASeq: Comparison to precomputed values

In [None]:
# COMPARE PROVIDED VS RECOMPUTED EXPRESSION VALUES
# For ESR1, ERBB2 and PGR, plot the log2 TPM recomputed in this notebook (x)
# against the precomputed "*.log2.tpm" features of the NAT-ML tables (y).
# Agreement puts the points on the identity line.
marker_genes = {
    "ENSG00000091831": "ESR1.log2.tpm",
    "ENSG00000141736": "ERBB2.log2.tpm",
    "ENSG00000082175": "PGR.log2.tpm",
}
# Shared axis range; the identity line spans it. Reading the limits from a
# freshly created axes, as before, gave (0, 1) and a line that was far too short.
lims = [-5, 15]

comb_df = pd.merge(feat_df, exp_df, left_index=True, right_index=True)
comb_df_val = pd.merge(feat_df_val, exp_df_val, left_index=True, right_index=True)
for ensg in marker_genes:
    # Undo the +1 of log2(TPM+1) so both axes are on a log2 TPM scale.
    # Training TPMs come from raw counts and can be 0, so a 0.001 offset keeps
    # log2 finite. The validation transform (log2(2**x + 1)) is inverted exactly.
    comb_df[ensg] = np.log2(2 ** comb_df[ensg] - 1 + 0.001)
    comb_df_val[ensg] = np.log2(2 ** comb_df_val[ensg] - 1)

for panel_df, panel_title in ((comb_df, "Training"), (comb_df_val, "Validation")):
    f, ax = plt.subplots(1, 1)
    for ensg, provided_col in marker_genes.items():
        sns.scatterplot(
            data=panel_df,
            x=ensg,
            y=provided_col,
            ax=ax,
            label=provided_col.split(".")[0],
        )
    # now plot both limits against each other (identity line)
    ax.plot(lims, lims, "k-", alpha=0.75, zorder=0)
    ax.set_xlim(*lims)
    ax.set_ylim(*lims)
    ax.set_xlabel("log2 TPM (recomputed)")
    ax.set_ylabel("log2 TPM (provided)")
    ax.set_title(panel_title)
    ax.legend(title="Gene")

In [None]:
# FUNCTIONS FOR LOADING DATA

from scipy.stats import ranksums


def plot_rcb(
    rep_df, resp_df, plot_dim, ax, plot_type="RCB.category", treat_str="F", filt=False
):
    """Plot one feature against a pathological-response variable on ``ax``.

    ``rep_df`` and ``resp_df`` are inner-joined on their index. A boolean
    column named ``treat_str`` is added that is True where ``"NAT.regimen"``
    contains ``treat_str`` (``Series.str.contains`` semantics). Rows with a
    missing regimen count as False.

    Side effect: calls ``sns.set_theme(...)``, which resets the global plot style.
    NOTE(review): this likely also overrides the custom font set at the top of
    the notebook; confirm whether that is intended.

    Parameters
    ----------
    rep_df, resp_df : pd.DataFrame
        Per-sample tables sharing an index; together they must provide
        ``plot_dim``, ``"NAT.regimen"`` and the ``plot_type`` column.
    plot_dim : str
        Column plotted as the value.
    ax : matplotlib.axes.Axes
        Axes to draw on; no new figure is created.
    plot_type : str
        "RCB.category": boxplots for pCR / RCB-I / RCB-II / RCB-III, annotated with
            Wilcoxon rank-sum p-values of pCR vs each RCB class.
        "pCR.RD": boxplots for pCR vs RD with one rank-sum p-value.
        "RCB.score": scatter of RCB score vs ``plot_dim`` with the Pearson p-value.
        any other column: plain boxplots grouped by that column.
    treat_str : str
        Regimen pattern defining the treatment flag.
    filt : bool
        If True, keep only samples whose regimen matches ``treat_str`` and hide the legend.
    """
    # Combine dataframes
    sns.set_theme(context="paper", style="whitegrid", palette="Set2")
    plot_df = pd.merge(rep_df, resp_df, left_index=True, right_index=True, how="inner")
    # na=False: a missing regimen becomes False instead of NaN, which would
    # otherwise make the boolean mask below raise.
    plot_df[treat_str] = plot_df["NAT.regimen"].str.contains(treat_str, na=False)
    if filt:
        plot_df = plot_df[plot_df[treat_str]]
    if plot_type == "RCB.category":
        pcr_z = plot_df[plot_df[plot_type] == "pCR"][plot_dim]
        rcb1_z = plot_df[plot_df[plot_type] == "RCB-I"][plot_dim]
        rcb2_z = plot_df[plot_df[plot_type] == "RCB-II"][plot_dim]
        rcb3_z = plot_df[plot_df[plot_type] == "RCB-III"][plot_dim]
        p_01 = ranksums(pcr_z, rcb1_z).pvalue
        p_02 = ranksums(pcr_z, rcb2_z).pvalue
        p_03 = ranksums(pcr_z, rcb3_z).pvalue
        if filt:
            # All remaining samples share the treatment flag; colour by RCB class instead
            treat_str = plot_type
        sns.boxplot(
            data=plot_df,
            x=plot_type,
            y=plot_dim,
            hue=treat_str,
            order=["pCR", "RCB-I", "RCB-II", "RCB-III"],
            width=0.6,
            whis=1.5,
            ax=ax,
        )
        ax.text(0.4, 1, f"p={p_01:.2g}", {"ha": "center"}, transform=ax.transAxes)
        ax.text(0.6, 0.95, f"p={p_02:.2g}", {"ha": "center"}, transform=ax.transAxes)
        ax.text(0.8, 0.90, f"p={p_03:.2g}", {"ha": "center"}, transform=ax.transAxes)
        ax.set_xticks([0, 1, 2, 3], ["pCR", "I", "II", "III"])
        if filt:
            ax.legend([], [], frameon=False)
        sns.despine(ax=ax)
    elif plot_type == "pCR.RD":
        pcr_z = plot_df[plot_df[plot_type] == "pCR"][plot_dim]
        rd_z = plot_df[plot_df[plot_type] == "RD"][plot_dim]
        p = ranksums(pcr_z, rd_z).pvalue
        if filt:
            treat_str = plot_type
        sns.boxplot(
            data=plot_df,
            x=plot_type,
            y=plot_dim,
            hue=treat_str,
            order=["pCR", "RD"],
            width=0.6,
            whis=1.5,
            ax=ax,
        )
        ax.text(0.5, 0.95, f"p={p:.2g}", {"ha": "center"}, transform=ax.transAxes)
        if filt:
            ax.legend([], [], frameon=False)
        else:
            ax.legend(title="Treatment", loc=(0.9, 0.7))
        sns.despine(ax=ax)
    elif plot_type == "RCB.score":
        # Pearson test needs complete pairs
        plot_df = plot_df[[plot_dim, plot_type, treat_str]].dropna(axis=0)
        _, p = pearsonr(plot_df[plot_dim], plot_df[plot_type])
        sns.scatterplot(data=plot_df, y=plot_type, x=plot_dim, hue=treat_str, ax=ax)
        ax.text(0.75, 0.75, f"p={p:.2g}", {"ha": "center"}, transform=ax.transAxes)
        if filt:
            ax.legend([], [], frameon=False)
        else:
            ax.legend(title="Treatment", loc=(0.9, 0.7))
        sns.despine(ax=ax)
    else:
        # Draw only on the supplied axes. A plt.figure() call here used to open
        # an extra, empty figure.
        sns.boxplot(
            data=plot_df,
            x=plot_type,
            y=plot_dim,
            hue=treat_str,
            width=0.6,
            whis=1.5,
            ax=ax,
        )
        if filt:
            ax.legend([], [], frameon=False)
        sns.despine(ax=ax)
    ax.grid(False)


def plot_rcb_density(
    rep_df, resp_df, plot_dim, plot_type="RCB.category", treat_str="F", filt=False
):
    """Plot the distribution of one feature by response group in a new figure.

    ``rep_df`` and ``resp_df`` are inner-joined on their index. A boolean column
    ``treat_str`` flags samples whose ``"NAT.regimen"`` contains ``treat_str``
    (missing regimens count as False).

    Side effect: calls ``sns.set_theme(...)``, which resets the global plot style.

    Parameters
    ----------
    rep_df, resp_df : pd.DataFrame
        Per-sample tables sharing an index.
    plot_dim : str
        Feature column to plot.
    plot_type : str
        "RCB.category": KDEs per RCB class with rank-sum p-values of pCR vs each class.
        "pCR.RD": KDEs for pCR vs RD with a rank-sum p-value; titled with ``plot_dim``.
        "RCB.score": scatter of RCB score vs ``plot_dim`` with the Pearson p-value.
        any other column: boxplots grouped by that column.
    treat_str : str
        Regimen pattern defining the treatment flag.
    filt : bool
        If True, keep only samples whose regimen matches ``treat_str``.
    """
    # Combine dataframes
    sns.set_theme(context="paper", style="whitegrid", palette="Set2")
    plot_df = pd.merge(rep_df, resp_df, left_index=True, right_index=True, how="inner")
    # na=False keeps the boolean mask valid when a regimen is missing
    plot_df[treat_str] = plot_df["NAT.regimen"].str.contains(treat_str, na=False)
    if filt:
        plot_df = plot_df[plot_df[treat_str]]
    if plot_type == "RCB.category":
        f, ax = plt.subplots(1, 1, figsize=(3, 2.5))
        pcr_z = plot_df[plot_df[plot_type] == "pCR"][plot_dim]
        rcb1_z = plot_df[plot_df[plot_type] == "RCB-I"][plot_dim]
        rcb2_z = plot_df[plot_df[plot_type] == "RCB-II"][plot_dim]
        rcb3_z = plot_df[plot_df[plot_type] == "RCB-III"][plot_dim]
        p_01 = ranksums(pcr_z, rcb1_z).pvalue
        p_02 = ranksums(pcr_z, rcb2_z).pvalue
        p_03 = ranksums(pcr_z, rcb3_z).pvalue
        sns.kdeplot(
            data=plot_df,
            x=plot_dim,
            hue=plot_type,
            hue_order=["pCR", "RCB-I", "RCB-II", "RCB-III"],
            common_norm=False,  # each class normalised separately
            ax=ax,
        )
        ax.text(0.4, 1, f"p={p_01:.2g}", {"ha": "center"}, transform=ax.transAxes)
        ax.text(0.6, 0.95, f"p={p_02:.2g}", {"ha": "center"}, transform=ax.transAxes)
        ax.text(0.8, 0.90, f"p={p_03:.2g}", {"ha": "center"}, transform=ax.transAxes)
        sns.despine(ax=ax)
    elif plot_type == "pCR.RD":
        f, ax = plt.subplots(1, 1, figsize=(3, 2.5))
        pcr_z = plot_df[plot_df[plot_type] == "pCR"][plot_dim]
        rd_z = plot_df[plot_df[plot_type] == "RD"][plot_dim]
        p = ranksums(pcr_z, rd_z).pvalue
        sns.kdeplot(
            data=plot_df,
            x=plot_dim,
            hue=plot_type,
            hue_order=["pCR", "RD"],
            common_norm=False,
            ax=ax,
        )
        ax.text(
            0.15, 0.9, f"Wilcoxon\np={p:.2g}", {"ha": "center"}, transform=ax.transAxes
        )
        ax.set_title(plot_dim)
        sns.despine(ax=ax)
    elif plot_type == "RCB.score":
        f, ax = plt.subplots(1, 1, figsize=(2, 2))
        # Pearson test needs complete pairs
        plot_df = plot_df[[plot_dim, plot_type, treat_str]].dropna(axis=0)
        _, p = pearsonr(plot_df[plot_dim], plot_df[plot_type])
        sns.scatterplot(data=plot_df, y=plot_type, x=plot_dim, hue=treat_str, ax=ax)
        ax.text(0.75, 0.75, f"p={p:.2g}", {"ha": "center"}, transform=ax.transAxes)
        sns.despine(ax=ax)
    else:
        # Create the axes explicitly: with plt.figure() alone, `ax` was undefined
        # and ax.grid() below raised a NameError.
        f, ax = plt.subplots(1, 1, figsize=(1, 2.5))
        sns.boxplot(
            data=plot_df, x=plot_type, y=plot_dim, hue=treat_str, width=0.6, whis=1.5, ax=ax
        )
        sns.despine(ax=ax)
    ax.grid(False)

In [None]:
# plot_rcb(rep_df_plot, resp_df, "z_TSPAN32", ax, plot_type="pCR.RD", filt=True)
# Grid of pCR-vs-RD boxplots, one panel per "z*" column of the PiCo_E
# representation, filled column by column.
# NOTE(review): `rep_dict` and `resp_df` are not defined in the visible part of
# this notebook (the resp_df loading above is commented out); define them first.
nrows = 4
rep_dict_plot = rep_dict["PiCo_E"].loc[
    :, rep_dict["PiCo_E"].columns.str.startswith("z")
]
n_dims = len(rep_dict_plot.columns)
ncols = int(np.ceil(n_dims / nrows))
fig, axes = plt.subplots(
    nrows,
    ncols,
    figsize=(12, 8),
    sharex=True,
    squeeze=False,  # keep `axes` 2-D even when there is only one column
)
for i, col in enumerate(rep_dict_plot.columns):
    plot_rcb(
        rep_dict_plot, resp_df, col, axes[i % nrows, i // nrows], plot_type="pCR.RD", filt=True
    )
# Hide the panels left empty when n_dims is not a multiple of nrows
for j in range(n_dims, nrows * ncols):
    axes[j % nrows, j // nrows].set_visible(False)
plt.tight_layout()

In [None]:
# One pCR-vs-RD density figure per latent ("z*") column.
# NOTE(review): key "PICo" differs from "PiCo"/"PiCo_E" used in neighbouring cells; confirm.
pico_rep = rep_dict["PICo"]
for col in pico_rep.columns:
    if not col.startswith("z"):
        continue
    plot_rcb_density(pico_rep, resp_df, col, plot_type="pCR.RD", filt=True)

In [None]:
def plot_rcb_grid(rep_df, resp_df, dims, filt=True, treat_str="F"):
    """2-D KDE of two features, one panel per RCB class (pCR, RCB-I, RCB-II, RCB-III).

    Parameters
    ----------
    rep_df, resp_df : pd.DataFrame
        Per-sample tables, inner-joined on their index; must provide ``dims``,
        ``"RCB.category"`` and ``"Chemo.Regimen"``.
    dims : list of str
        Two column names; ``dims[0]`` on x, ``dims[1]`` on y.
    filt : bool
        If True, keep only samples whose regimen contains ``treat_str``
        (missing regimens count as False).
    treat_str : str
        Regimen pattern defining the treatment flag.

    Side effect: calls ``sns.set_theme(...)``, which resets the global plot style.
    """
    sns.set_theme(context="paper", style="whitegrid", palette="Set2")
    plot_df = pd.merge(rep_df, resp_df, left_index=True, right_index=True, how="inner")
    rcb_cats = ["pCR", "RCB-I", "RCB-II", "RCB-III"]
    # NOTE(review): plot_rcb / plot_rcb_density read "NAT.regimen"; this reads
    # "Chemo.Regimen". Confirm which column resp_df actually has.
    plot_df[treat_str] = plot_df["Chemo.Regimen"].str.contains(treat_str, na=False)
    if filt:
        plot_df = plot_df[plot_df[treat_str]]
    # Shared axes make the panels directly comparable
    fig, ax = plt.subplots(1, len(rcb_cats), sharex=True, sharey=True, figsize=(8, 2))
    for i, cat in enumerate(rcb_cats):
        plot_z = plot_df[plot_df["RCB.category"] == cat][dims]
        sns.kdeplot(data=plot_z, x=dims[0], y=dims[1], ax=ax[i])
        ax[i].set_title(cat)
    sns.despine()

In [None]:
# 2-D density of two PiCo latent dimensions per RCB class; filt=True keeps only
# samples whose regimen matches the default treat_str.
# NOTE(review): needs `rep_dict` and `resp_df`, which are not defined in the
# visible part of this notebook. The key "PiCo" also differs from "PiCo_E"/"PICo"
# in the cells above; confirm.
plot_rcb_grid(rep_dict["PiCo"], resp_df, ["z_HNF1B", "z_KANK4"], filt=True)

### 4.0.5 Constraint selection

In [None]:
from utils.data_utils import Manual, get_data_loaders, get_constraints, gene_column_renamer, process_data
# SELECTING CONSTRAINTS
# Constraints for each drug in FEC-T, fetched from both GDSC and CTRP so a
# consensus across the two datasets can be formed. Drug names differ between
# the datasets (FLUOROURACIL vs 5-FLUOROURACIL; DOXORUBICIN in place of
# EPIRUBICIN), hence two parallel lists.

drugs_gdsc = ["5-FLUOROURACIL", "EPIRUBICIN", "CYCLOPHOSPHAMIDE", "PACLITAXEL"]
drugs_ctrp = ["FLUOROURACIL", "DOXORUBICIN", "CYCLOPHOSPHAMIDE", "PACLITAXEL"]
dataset_names = ["depmap_gdsc", "depmap_ctrp"]

# Keys are "<DRUG>_gdsc" / "<DRUG>_ctrp"
constraints_dict = {}
for drug_list, ds_name in zip([drugs_gdsc, drugs_ctrp], dataset_names):
    ds_suffix = ds_name.split("_")[-1]
    for drug in drug_list:
        constraints = get_constraints(
            drug=drug,
            dataset_name=ds_name,
            zdim=512,
            experiment=None,
            col_thresh=1.0,
            wd_path=wd_path,
        )
        constraints_dict[f"{drug}_{ds_suffix}"] = constraints

In [None]:
# Load the per-drug univariate association tables for both datasets and stack
# them, tagging each row with its drug and dataset.
constraint_dfs = []
for ds_name, drug_list in (("depmap_gdsc", drugs_gdsc), ("depmap_ctrp", drugs_ctrp)):
    for drug in drug_list:
        curr_df = pd.read_csv(
            f"{wd_path}/data/constraints/univar_lm_{drug}_lrt_IC50_{ds_name}_v2.csv"
        )
        curr_df["drug"] = drug
        curr_df["dataset_name"] = ds_name
        constraint_dfs.append(curr_df)

constraint_df = pd.concat(constraint_dfs, axis=0)


In [None]:
# Filter by significance after multiple-testing correction.
# .copy() so that adding "count" below writes to a real frame, not a view of
# constraint_df (avoids SettingWithCopyWarning).
constraint_df_filt = constraint_df[constraint_df["p_corrected"] < 0.05].copy()
# Count the significant rows per gene across datasets and drugs
# (max 8 = 4 drugs x 2 datasets, assuming each table lists a gene once).
constraint_df_filt["count"] = constraint_df_filt.groupby("gene")["gene"].transform("count")
# Show the top 50 genes: most frequently significant first, then by corrected p-value.
# head(200000) caps each (dataset, drug) group after sorting by p-value.
(
    constraint_df_filt
    .sort_values(["dataset_name", "drug", "p_corrected"])
    .groupby(["dataset_name", "drug"])
    .head(200000)
    .sort_values(by=["count", "p_corrected"], ascending=[False, True])
    .head(50)
)

In [None]:
# Taxane metagene gene list (HGNC symbols)
taxane_metagene = (
    "BUB1B CDADC1 MASTL CDK1 CSNK1A1L STK4 TTK EEF2K "
    "SCYL1 AURKB UGCG GBA1 GBA3 CERT1"
).split()
# Alternative (older) symbol -> symbol used in taxane_metagene
taxane_metagene_aliases = dict(CDC2="CDK1", COL4A3BP="CERT1")

In [None]:
from utils.data_utils import process_data, Manual
# Check data availability for the combined DepMap-GDSC / SCAN-B / TCGA dataset
dataset_name = "depmap_gdsc_scanb_tcga"

# PROCESS DATASET
x, s, c, y, test_samples = process_data(
    dataset=dataset_name, wd_path=wd_path, experiment="tcga_surv"
)

dataset_params = {"var_filt_x": 1500, "var_filt_s": None}
check_constraints = [
    "MCL1", "PSMC1", "FANCF", "RAD1", "PPM1D", "ESR1", "ERBB2", "CDKN2A",
    "TP53", "IGF1R", "FGFR2", "CCNE1", "PIK3CA", "GATA3", "MAP3K1", "EGFR",
]
dataset = Manual(
    x=x,
    s=s,
    c=c,
    y=y,
    constraints=check_constraints,
    target="BCFi_MONTHS",
    params=dataset_params,
)

## 4.1 Prediction

### 4.1.1 Prediction performance on TransNEO & ARTemis+PBCP

In [None]:
target = "resp.pCR"
experiment = "artemis_pbcp"
rep_types = {"vae": "VAE", "icovae_MCL1_16": "PiCo"}
#model_types = ["ElasticNet", "SVR", "RandomForestRegressor"]
model_types = ["LogisticRegression"]
# Feature set -> suffix of the saved-model directory name for that feature set
feat_sets = {"RNA": "_PGR.log2.tpm_11_norep", "Rep": "", "Clinical+Rep": "_Size.at.diagnosis_7", "Clinical+Rep+RNA": "_Size.at.diagnosis_18", "Clinical+RNA": "_Size.at.diagnosis_18_norep", "Clinical": "_Size.at.diagnosis_7_norep"}

# Non-representation features in each feature set (7 clinical, 11 RNA-derived)
clinical_feats = ["Size.at.diagnosis", "LN.at.diagnosis", "Age.at.diagnosis", "Histology", "ER.status", "HER2.status", "Grade.pre.chemotherapy"]
rna_feats = ["PGR.log2.tpm", "ESR1.log2.tpm", "ERBB2.log2.tpm", "GGI.ssgsea.notnorm", "ESC.ssgsea.notnorm", "Swanton.PaclitaxelScore", "STAT1.ssgsea.notnorm", "TIDE.Dysfunction", "TIDE.Exclusion", "Danaher.Mast.cells", "CytScore.log2"]
# Each entry is its own list, so mutating one does not affect the others
feat_sets_conf = {
    "Rep": [],
    "Clinical+Rep": list(clinical_feats),
    "Clinical": list(clinical_feats),
    "Clinical+Rep+RNA": clinical_feats + rna_feats,
    "Clinical+RNA": clinical_feats + rna_feats,
    "RNA": list(rna_feats),
}

# Mapping for feature names in plots.
# Gene names use mathtext "$GENE$" (italic via mathtext.it), matching the
# mutation labels below. The old "$\t{{PGR}}$" strings were not raw, so "\t"
# was a literal TAB character.
names_map = {
    "Danaher.Mast.cells": "Mast cell score",
    "PGR.log2.tpm": "$PGR$ expression",
    "ESR1.log2.tpm": "$ESR1$ expression",
    "ERBB2.log2.tpm": "$ERBB2$ expression",
    "HER2.status": "HER2 status",
    "Age.at.diagnosis": "Age at diagnosis",
    "LN.at.diagnosis": "LN involvement",
    "Grade.pre.chemotherapy": "Histological grade",
    "TIDE.Exclusion": "T cell exclusion",
    "TIDE.Dysfunction": "T cell dysfunction",
    "Size.at.diagnosis": "Tumour size",
    "GGI.ssgsea.notnorm": "GGI score",
    "ESC.ssgsea.notnorm": "ES cell score",
    "Swanton.PaclitaxelScore": "Taxane score",
    "CytScore.log2": "Cytolytic score",
    "STAT1.ssgsea.notnorm": "STAT1 score",
    "Histology": "Histological subtype",
    "ER.status": "ER status",
    "HRD.sum": "HRD score",
    "CodingMuts.PIK3CA": "$PIK3CA$ mutation status",
    "CodingMuts.TP53": "$TP53$ mutation status",
    "CIN.Prop": "Chromosomal instability",
    "All.TMB": "All TMB",
    "Coding.TMB": "Coding TMB",
    "Expressed.NAg": "Neoantigens",
    "HLA.LOH": "HLA LOH",
}

### 4.1.2 Performance overall

In [None]:
# LOAD TEST SET RESULTS
from utils.comp_utils import calculate_feat_imps, plot_feat_imps
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score, auc
from scipy.stats import spearmanr, pearsonr

# Directory with the saved models for this target and experiment
res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
# Model seeds 10, 20, ..., 100
seeds = list(range(10, 101, 10))
hopt_seed = 4563

# Result accumulators, initialised empty
test_metrics_df = None
val_metrics_list = []
hopt_df = None


for feat_set, ext in feat_sets.items():
    for rep_type, rep_type_label in rep_types.items():
        for model_type in model_types:
            print(f"Processing {model_type}, {rep_type}, {ext}...")
            if (rep_type_label == "PiCo") and (feat_set in ["Clinical+Rep+RNA"]) and (model_type in ["ElasticNet", "LogisticRegression"]):
                try:
                    pred_dict_list, constraints, confounders, feat_imps_df = calculate_feat_imps(enc=rep_type, reg=model_type, model_path=f"{res_root}/{model_type}_{rep_type + ext}", target=target, seeds=seeds)
                    if feat_set in ["Clinical", "Clinical+RNA", "RNA"]:
                        zdim = 0
                        norep = True
                    else:
                        zdim = 64
                        norep = False

                    if (target == "RCB.score"):
                        plot_feat_imps(feat_imps_df, target=target, constraints=constraints, confounders=feat_sets_conf[feat_set], zdim=zdim, enc=rep_type, reg=model_type, experiment=experiment, names_map=names_map, save_path=f"{model_type}_{target}_r_{rep_type+ext}", metric="r", norep=norep, sort_feats=True, top_k=16)
                        plot_feat_imps(feat_imps_df, target=target, constraints=constraints, confounders=feat_sets_conf[feat_set], zdim=zdim, enc=rep_type, reg=model_type, experiment=experiment, names_map=names_map, save_path=f"{model_type}_{target}_s_{rep_type+ext}", metric="s", norep=norep, sort_feats=True, top_k=16)
                        plot_feat_imps(feat_imps_df, target=target, constraints=constraints, confounders=feat_sets_conf[feat_set], zdim=zdim, enc=rep_type, reg=model_type, experiment=experiment, names_map=names_map, save_path=f"{model_type}_{target}_rmse_{rep_type+ext}", metric="rmse", norep=norep, sort_feats=True, top_k=16)
                    else:
                        plot_feat_imps(feat_imps_df, target=target, constraints=constraints, confounders=feat_sets_conf[feat_set], zdim=zdim, enc=rep_type, reg=model_type, experiment=experiment, names_map=names_map, save_path=f"{model_type}_{target}_auroc_{rep_type+ext}", metric="auroc", norep=norep, sort_feats=True, top_k=16)
                except:
                    print(f"Cannot produce feat imp plot for {model_type}, {rep_type}, {ext}...")
            try:
                print("Loading hopt results...")
                #best_trial = pd.read_csv(f"{res_root}/{model_type}_{rep_type + ext}/opt_study_results_s{hopt_seed}.csv").sort_values("value", ascending=False)["number"][0]
                curr_hopt_df = pd.read_csv(f"{res_root}/{model_type}_{rep_type + ext}/cv_results_best_s{hopt_seed}.csv")

                preds_val =  None

                print("Loading validation predictions...")
                for fold in range(5):
                    curr_preds_val = pd.read_csv(f"{res_root}/{model_type}_{rep_type + ext}/z_pred_val_{fold}_best_s{hopt_seed}.csv")
                    if preds_val is None:
                        preds_val = curr_preds_val
                    else:
                        preds_val = pd.concat([preds_val, curr_preds_val], axis=0)

                print("Calculating validation metrics...")
                pred_metrics_val = preds_val[["pred_0", "y"]].dropna()
                pred_metrics_auroc_val = pred_metrics_val.copy()
                pred_metrics_auroc_val["y"] = (pred_metrics_auroc_val["y"] == 0.0).astype(float)
                auc_val = roc_auc_score(pred_metrics_auroc_val["y"], pred_metrics_auroc_val["pred_0"])
                auc_neg_val = roc_auc_score(pred_metrics_auroc_val["y"], -1*pred_metrics_auroc_val["pred_0"])
                auc_val = np.max([auc_val, auc_neg_val])
                f1_val = f1_score(pred_metrics_auroc_val["y"], (pred_metrics_auroc_val["pred_0"]>0.5).astype(float))
                precision, recall, thresholds = precision_recall_curve(
                    pred_metrics_auroc_val["y"], pred_metrics_auroc_val["pred_0"], pos_label=1
                    )
                val_aupr = auc(recall, precision)
                # Store CV results as well
                print("Storing validation metrics...")
                if target == "RCB.score":
                    val_metrics_list.append({"n": len(pred_metrics_val), "rep_type": rep_type_label, "model_type": model_type, "feat_sets": feat_set, "seed": hopt_seed,
                                            "val_spearmanr": spearmanr(pred_metrics_val["pred_0"], pred_metrics_val["y"])[0], "val_pearsonr": pearsonr(pred_metrics_val["pred_0"], pred_metrics_val["y"])[0],
                                                "val_rmse": (np.sqrt(pred_metrics_val["pred_0"] - pred_metrics_val["y"])**2).mean(), "dataset": "cv"})
                elif target == "resp.pCR":
                    val_metrics_list.append({"n": len(pred_metrics_val), "rep_type": rep_type_label, "model_type": model_type, "feat_sets": feat_set, "seed": hopt_seed,
                                                "val_auroc": auc_val, "val_aupr": val_aupr, "val_f1": f1_val, "dataset": "cv"})
                
                curr_test_metrics_df = pd.read_csv(f"{res_root}/{model_type}_{rep_type + ext}/test_metrics.csv")
                curr_test_metrics_df["rep_type"] = rep_type_label
                curr_test_metrics_df["model_type"] = model_type
                curr_hopt_df["rep_type"] = rep_type_label
                curr_hopt_df["model_type"] = model_type
                # Placeholder until we add clinical features etc
                curr_test_metrics_df["feat_sets"] = feat_set
                curr_hopt_df["feat_sets"] = feat_set
                curr_hopt_df["seed"] = 4563
                curr_test_metrics_df["seed"] = seeds
                print("Combining test metrics...")
                if test_metrics_df is None:
                    test_metrics_df = curr_test_metrics_df
                else:
                    test_metrics_df = pd.concat([test_metrics_df, curr_test_metrics_df], axis=0)
                print("Combining hopt results...")
                if hopt_df is None:
                    hopt_df = curr_hopt_df
                else:
                    hopt_df = pd.concat([hopt_df, curr_hopt_df], axis=0)
            except:
                print(f"{model_type}, {rep_type}, {ext}, {hopt_seed} not found")
                continue

print("Done loading...")

test_metrics_df["dataset"] = "test"
hopt_df["dataset"] = "cv"
val_metrics_df = pd.DataFrame(val_metrics_list)

# Keep all folds for hopt_df

# Print results on external validation for reporting
test_metrics_df.groupby(["rep_type", "model_type", "feat_sets", "dataset"]).mean().reset_index()

In [None]:
# Mean permutation-importance per dimension for `feat_imps_df`, i.e. the last
# model processed by the feature-importance step of the loading cell above.
# The score column and its "_perm" counterpart come from calculate_feat_imps
# (presumably the score after permuting the feature — confirm in comp_utils).
feat_imps_plot = feat_imps_df.copy()
if target == "RCB.score":
    # Relative drop (%) of the "s" score under permutation.
    fi_col = "fi_s"
    feat_imps_plot[fi_col] = 100 * (feat_imps_plot["s"] - feat_imps_plot["s_perm"]) / feat_imps_plot["s"]
elif target == "resp.pCR":
    # Absolute drop in AUROC under permutation.
    fi_col = "fi_auroc"
    feat_imps_plot[fi_col] = feat_imps_plot["auroc"] - feat_imps_plot["auroc_perm"]
else:
    raise ValueError(f"Unsupported target for feature importances: {target!r}")
# Sort by the importance column computed for this target (sorting by
# "fi_auroc" for RCB.score would raise KeyError).
feat_imps_plot.groupby(["dim"]).mean(numeric_only=True).sort_values(fi_col, ascending=False)

In [None]:
# Load CV results
# Rank combinations by their main CV metric. The val_* columns that exist
# depend on the target: RCB.score stores correlations/RMSE, resp.pCR stores
# AUROC/AUPR/F1 (see the loading cell).
cv_sort_col = "val_spearmanr" if target == "RCB.score" else "val_auroc"
val_metrics_df.sort_values(by=cv_sort_col, ascending=False)

In [None]:
# COMBINED TRAIN AND EXT VAL PERFORMANCE -- DATASET FACET
# One two-panel figure per (metric, model): left = TransNEO cross-validation
# (val_<metric> from `val_metrics_df`, dashed), right = ARTemis+PBCP external
# validation (test_<metric> from `test_metrics_df`, solid). Each figure is
# saved to ./figures/transneo as PNG (600 dpi) and SVG.
from matplotlib.lines import Line2D

# NOTE: rebinds the notebook-level `target`; the section 4.1.3 cells below
# build their result paths from it.
target = "resp.pCR"

metrics_dict = {"resp.pCR": ["auroc", "aupr", "f1"], "RCB.score": ["spearmanr", "pearsonr", "rmse"]}
models_dict = {"resp.pCR": ["LogisticRegression"], "RCB.score": ["ElasticNet", "SVR", "RandomForestRegressor"]}

# ---- Loop-invariant plot settings ------------------------------------------
palette = sns.color_palette("colorblind")
pal_cv = {"VAE": palette[1], "PiCo": palette[0], "NA": "grey"}
pal_ev = {"VAE": palette[1], "PiCo": palette[0], "NA": "grey"}

plot_order = ["Clinical", "RNA", "Clinical+RNA", "Rep", "Clinical+Rep", "Clinical+Rep+RNA"]
facet_order = ["TransNEO\nCross-validation", "ARTemis+PBCP\nExternal validation"]

# y tick labels, in the same order as `plot_order`
names_dict = {
    "Clinical": "Clinical",
    "RNA": "RNA",
    "Clinical+RNA": "Clinical+RNA",
    "Rep": r"$\mathbf{z}$",
    "Clinical+Rep": r"Clinical+$\mathbf{z}$",
    "Clinical+Rep+RNA": r"Clinical+$\mathbf{z}$+RNA",
}

# x-axis (label, limits, ticks) per test-metric column; unlisted metrics keep
# matplotlib's defaults.
if target == "RCB.score":
    x_axis_specs = {
        "test_spearmanr": ("Spearman correlation", (0.5, 0.82), [0.5, 0.6, 0.7, 0.8]),
        "test_pearsonr": ("Pearson correlation", (0.5, 0.82), [0.5, 0.6, 0.7, 0.8]),
        "test_rmse": ("RMSE", (0.75, 1.15), [0.8, 0.9, 1.0, 1.1]),
    }
else:
    x_axis_specs = {
        "test_auroc": ("AUROC", (0.68, 0.95), [0.7, 0.8, 0.9]),
        "test_aupr": ("AUPR", (0.45, 0.85), [0.5, 0.6, 0.7, 0.8]),
        "test_f1": ("F1 score", (0.15, 0.75), [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        "test_cross_entropy": ("Cross-entropy", (0.3, 0.6), [0.3, 0.4, 0.5, 0.6]),
    }

# Point-plot styling shared by both panels (identical for every target).
pointplot_kws = dict(
    y="feat_sets",
    hue="rep_type_hue",
    order=plot_order,
    linewidth=1,
    markersize=4,
    marker="o",
    errorbar=("sd", 1),
    capsize=0.25,
)


def _add_rep_type_hue(df):
    """Return a copy of `df` with `rep_type_hue` = `rep_type`, or "NA" (grey) for feature sets without a representation."""
    out = df.copy()
    out["rep_type_hue"] = out["rep_type"].copy()
    out.loc[out["feat_sets"].isin(["Clinical+RNA", "Clinical", "RNA"]), "rep_type_hue"] = "NA"
    return out


for metric in metrics_dict[target]:
    for model in models_dict[target]:

        metric_hopt = f"val_{metric}"
        metric_test = f"test_{metric}"

        hopt_df_plot = _add_rep_type_hue(val_metrics_df[val_metrics_df["model_type"] == model])
        metrics_df_plot = _add_rep_type_hue(test_metrics_df[test_metrics_df["model_type"] == model])

        fig, axes = plt.subplots(1, 2, figsize=(4.5, 2.5))

        ## PLOTS CV RESULTS (dashed)
        g1 = sns.pointplot(
            data=hopt_df_plot,
            x=metric_hopt,
            palette=pal_cv,
            linestyle="--",
            err_kws={"linewidth": 1.5, "alpha": 0.25},
            ax=axes[0],
            **pointplot_kws,
        )

        ## PLOTS EXT VAL RESULTS (solid)
        sns.pointplot(
            data=metrics_df_plot,
            x=metric_test,
            palette=pal_ev,
            linestyle="-",
            err_kws={"linewidth": 1.2, "alpha": 0.5},
            ax=axes[1],
            legend=False,
            **pointplot_kws,
        )

        # Legends: hue entries of the current axes plus a manual VAE/PiCo colour key.
        handles, labels = plt.gca().get_legend_handles_labels()
        line_train = Line2D([0], [0], label="VAE", color=palette[1], linestyle="-", linewidth=1)
        line_val = Line2D([0], [0], label="PiCo", color=palette[0], linestyle="-", linewidth=1)
        leg1 = axes[0].legend(
            handles=handles[3:],
            bbox_to_anchor=(1.0, 1.6),
            loc="upper center",
            frameon=False,
            ncol=3,
            fontsize=10,
        )
        leg2 = axes[0].legend(
            handles=[line_train, line_val],
            bbox_to_anchor=(-0.35, 1.45),
            loc="upper center",
            frameon=False,
            ncol=1,
            fontsize=10,
        )
        # Each .legend() call replaces the previous one, so re-attach both as artists.
        axes[0].add_artist(leg1)
        axes[0].add_artist(leg2)

        # Rename y ticks
        axes[0].set_yticklabels(names_dict.values())

        sns.despine()
        x_label, x_lim, x_ticks = x_axis_specs.get(metric_test, (None, None, None))
        for i, (panel_ax, facet_title) in enumerate(zip(axes, facet_order)):
            panel_ax.set_ylabel("")
            if x_label is not None:
                panel_ax.set_xlabel(x_label, fontsize=10)
                panel_ax.set_xlim(*x_lim)
                panel_ax.set_xticks(x_ticks, x_ticks)
            panel_ax.set_title("")
            panel_ax.text(
                s=facet_title,
                x=0.5,
                y=1.1,
                transform=panel_ax.transAxes,
                fontweight="regular",
                horizontalalignment="center",
                fontsize=10,
            )
            panel_ax.grid(visible=True, axis="x")
            panel_ax.tick_params(labelsize=10)
            if i > 0:
                # The right panel shares the left panel's y categories: hide its y axis.
                sns.despine(left=True, ax=panel_ax)
                panel_ax.tick_params(
                    top=False,
                    bottom=True,
                    left=False,
                    right=False,
                    labelleft=False,
                    labelbottom=True,
                    labelsize=10,
                )

        fig.tight_layout()

        fig.savefig(
            f"./figures/transneo/perf_facet_{target}_{model}_{metric}.png",
            bbox_inches="tight",
            dpi=600,
        )
        fig.savefig(
            f"./figures/transneo/perf_facet_{target}_{model}_{metric}.svg",
            bbox_inches="tight",
        )

In [None]:
# CV RESULTS REPORTING FOR PAPER
# Mean CV metrics per combination, ranked by `metric_hopt` (the last metric
# plotted in the facet cell above). Lower is better for RMSE, higher otherwise,
# matching the EV reporting cell. numeric_only=True skips string columns such
# as "dataset", which pandas >= 2.0 refuses to average.
val_metrics_df.groupby(["feat_sets", "rep_type", "model_type"]).mean(numeric_only=True).sort_values(
    by=metric_hopt, ascending=metric_hopt == "val_rmse"
)

In [None]:
# EV RESULTS REPORTING FOR PAPER
# Mean external-validation metrics per combination, ranked by `metric_test`
# (ascending only for RMSE, where lower is better). numeric_only=True skips
# string columns such as "dataset", which pandas >= 2.0 refuses to average.
test_metrics_df.groupby(["feat_sets", "rep_type", "model_type"]).mean(numeric_only=True).sort_values(
    by=metric_test, ascending=metric_test == "test_rmse"
)

### 4.1.3 pCR prediction AUROC plots

In [None]:
# LOAD TEST SET RESULTS
# Recompute metrics directly from saved predictions for every
# (feature set, representation, model) combination:
#   * dataset="cv":   validation predictions of all CV folds of the best
#                     hyper-parameter trial (seed `hopt_seed`), pooled;
#   * dataset="test": one external-validation prediction file per seed.
# When `plot_roc` is set, ROC-curve points are collected in `roc_curves_list`.
# (Feature-importance plots are produced by the section 4.1.2 loading cell.)
from utils.comp_utils import calculate_feat_imps, plot_feat_imps_v2
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
seeds = [10,20,30,40,50,60,70,80,90,100]
hopt_seed = 4563  # seed of the hyper-parameter search whose "best" CV outputs are loaded
n_cv_folds = 5

test_metrics_list = []
roc_curves_list = []
hopt_df = None
plot_roc = True


def _score_predictions(preds):
    """Score a prediction table with columns "pred_0" (prediction) and "y" (observed).

    Rows with a missing value in either column are dropped. For AUROC, rows with
    y == 0 form the positive class and the score is sign-agnostic (the larger of
    the AUROCs of pred_0 and -pred_0). AUPR is not computed (NaN).

    Returns (n rows scored, binary labels, predictions, metrics dict).
    """
    scored = preds[["pred_0", "y"]].dropna()
    pred, obs = scored["pred_0"], scored["y"]
    y_pos = (obs == 0.0).astype(float)
    metrics = {
        "spearman_r": spearmanr(pred, obs)[0],
        "pearson_r": pearsonr(pred, obs)[0],
        # RMSE = sqrt(mean(residual ** 2)); taking sqrt per residual is NaN when pred < y
        "rmse": np.sqrt(((pred - obs) ** 2).mean()),
        "auroc": np.max([roc_auc_score(y_pos, pred), roc_auc_score(y_pos, -1 * pred)]),
        "aupr": np.nan,
    }
    return len(scored), y_pos, pred, metrics


def _roc_points(y_pos, pred, **meta):
    """ROC points (fpr, tpr) computed on -pred, each record tagged with the `meta` fields."""
    fpr, tpr, _ = roc_curve(y_pos, -1 * pred)
    return [{"fpr": f, "tpr": t, **meta} for f, t in zip(fpr, tpr)]


for feat_set, ext in feat_sets.items():
    for rep_type, rep_type_label in rep_types.items():
        for model_type in model_types:
            model_dir = f"{res_root}/{model_type}_{rep_type + ext}"
            combo = {"rep_type": rep_type_label, "model_type": model_type, "feat_sets": feat_set}

            # Cross-validation: pool the validation folds, then score once.
            preds_val = pd.concat(
                [pd.read_csv(f"{model_dir}/z_pred_val_{fold}_best_s{hopt_seed}.csv") for fold in range(n_cv_folds)],
                axis=0,
            )
            n_val, y_pos_val, pred_val, metrics_val = _score_predictions(preds_val)
            test_metrics_list.append({"n": n_val, **combo, "seed": hopt_seed, **metrics_val, "dataset": "cv"})
            if plot_roc:
                roc_curves_list.extend(_roc_points(
                    y_pos_val, pred_val,
                    model_type=model_type, rep_type=rep_type_label, feat_set=feat_set, dataset="cv", seed=hopt_seed,
                ))

            # External validation: one prediction file per seed.
            for seed in seeds:
                curr_preds = pd.read_csv(f"{model_dir}/z_pred_test_s{seed}.csv")
                n_test, y_pos_test, pred_test, metrics_test = _score_predictions(curr_preds)
                test_metrics_list.append({"n": n_test, **combo, "seed": seed, **metrics_test, "dataset": "test"})
                if plot_roc:
                    roc_curves_list.extend(_roc_points(
                        y_pos_test, pred_test,
                        model_type=model_type, rep_type=rep_type_label, feat_set=feat_set, dataset="test", seed=seed,
                    ))

test_metrics_df = pd.DataFrame.from_dict(test_metrics_list)

# Mean metrics per combination and dataset, for reporting
test_metrics_df.groupby(["rep_type", "model_type", "feat_sets", "dataset", "n"]).mean(numeric_only=True).reset_index()

In [None]:
def _roc_subset(roc_df, model_type, feat_sets_plot, dataset):
    """ROC points of one model on the given feature sets and dataset, with legend-friendly column copies."""
    mask = (
        (roc_df["model_type"] == model_type)
        & roc_df["feat_set"].isin(feat_sets_plot)
        & (roc_df["dataset"] == dataset)
    )
    subset = roc_df[mask].copy()  # explicit copy: avoids SettingWithCopyWarning on the new columns
    subset["Feature extractor"] = subset["rep_type"]
    subset["Feature set"] = subset["feat_set"]
    return subset


def _finish_roc_axes(ax, stem):
    """Step-style curves, chance diagonal, labels and legend; save ./figures/transneo/<stem>.png/.svg."""
    for line in ax.lines:
        line.set_drawstyle("steps-post")
    ax.plot([0, 1], [0, 1], linestyle="--", color="grey")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.legend(bbox_to_anchor=(0.4, 0.5), loc="upper left", frameon=False, ncol=1, title=None)
    sns.despine(ax=ax)
    ax.figure.savefig(f"./figures/transneo/{stem}.png", dpi=600, bbox_inches="tight")
    ax.figure.savefig(f"./figures/transneo/{stem}.svg", bbox_inches="tight")


if plot_roc:
    roc_curves_df = pd.DataFrame.from_dict(roc_curves_list)
    # ROC curves of one model, PiCo vs VAE, for selected feature sets: first on
    # pooled CV predictions, then on the external test set (one curve per seed).
    model_type = "LogisticRegression"
    rep_type_label = "PiCo"
    feat_sets_plot = ["Clinical", "Clinical+Rep"]  # e.g. also "Clinical+Rep+RNA", "Clinical+RNA"
    # Filename tag derived from the plotted feature sets (not from the
    # `feat_set` variable left over by the loading loop).
    feat_sets_tag = "_".join(feat_sets_plot)
    hue_order = ["PiCo", "VAE"]

    # Cross-validation
    roc_curve_data = _roc_subset(roc_curves_df, model_type, feat_sets_plot, "cv")
    fig, ax = plt.subplots(figsize=(3, 3))
    sns.lineplot(data=roc_curve_data, x="fpr", y="tpr", hue="Feature extractor", style="Feature set", ax=ax, hue_order=hue_order, errorbar=None)
    _finish_roc_axes(ax, f"roc_cv_{model_type}_{rep_type_label}_{feat_sets_tag}")

    # External test set: one translucent curve per seed; legend drawn for the first seed only.
    roc_curve_data = _roc_subset(roc_curves_df, model_type, feat_sets_plot, "test")
    fig, ax = plt.subplots(figsize=(3, 3))
    for seed_idx, seed in enumerate(roc_curve_data["seed"].unique()):
        sns.lineplot(
            data=roc_curve_data[roc_curve_data["seed"] == seed],
            x="fpr",
            y="tpr",
            hue="Feature extractor",
            style="Feature set",
            ax=ax,
            alpha=0.3,
            hue_order=hue_order,
            errorbar=None,
            legend=seed_idx == 0,
        )
    _finish_roc_axes(ax, f"roc_test_{model_type}_{rep_type_label}_{feat_sets_tag}")

In [None]:
# COMBINED TRAIN AND EXT VAL PERFORMANCE -- DATASET FACET (AUROC)
# Both panels are drawn from the recomputed `test_metrics_df`: rows with
# dataset == "cv" (TransNEO cross-validation) and dataset == "test"
# (ARTemis+PBCP external validation).
from matplotlib.lines import Line2D

target = "resp.pCR"
model = "LogisticRegression"
metric = "auroc"
# Recomputed metrics use one column name for CV and test rows alike.
metric_hopt = metric

palette = sns.color_palette("colorblind")
pal_0_desat, pal_1_desat = (sns.set_hls_values(palette[k], l=0.6, s=0.5) for k in (0, 1))
# VAE / PiCo keep their palette colours; feature sets without z are grey.
pal_cv = {"VAE": palette[1], "PiCo": palette[0], "NA": "grey"}
pal_ev = dict(pal_cv)

plot_order = ["Clinical", "RNA", "Clinical+RNA", "Rep", "Clinical+Rep", "Clinical+Rep+RNA"]
facet_order = ["TransNEO\nCross-validation", "ARTemis+PBCP\nExternal validation"]

hopt_df_plot = test_metrics_df.loc[
    test_metrics_df["model_type"].eq(model) & test_metrics_df["dataset"].eq("cv")
].copy()
print(hopt_df_plot)

# Hue: representation type, or "NA" for feature sets that contain no z.
hopt_df_plot["rep_type_hue"] = hopt_df_plot["rep_type"].where(
    ~hopt_df_plot["feat_sets"].isin(["Clinical+RNA", "Clinical", "RNA"]), "NA"
)

fig, ax = plt.subplots(1, 2, figsize=(4.5, 2.5))

# Point-plot styling shared by both panels; it is the same for every target,
# so no per-target branching is needed.
pointplot_kws = dict(
    y="feat_sets",
    hue="rep_type_hue",
    order=plot_order,
    linewidth=1,
    markersize=4,
    marker="o",
    errorbar=("sd", 1),
    capsize=0.25,
)

## PLOTS CV RESULTS (dashed)
g1 = sns.pointplot(
    data=hopt_df_plot,
    x=metric_hopt,
    palette=pal_cv,
    linestyle="--",
    err_kws={"linewidth": 1.5, "alpha": 0.25},
    ax=ax[0],
    **pointplot_kws,
)

## PLOTS EXT VAL RESULTS (solid)
metrics_df_plot = test_metrics_df[(test_metrics_df["model_type"] == model) & (test_metrics_df["dataset"] == "test")].copy(deep=True)
metrics_df_plot["rep_type_hue"] = metrics_df_plot["rep_type"].copy()
metrics_df_plot.loc[
    metrics_df_plot["feat_sets"].isin(["Clinical+RNA", "Clinical", "RNA"]), "rep_type_hue"
] = "NA"

sns.pointplot(
    data=metrics_df_plot,
    x=metric,
    palette=pal_ev,
    linestyle="-",
    err_kws={"linewidth": 1.2, "alpha": 0.5},
    ax=ax[1],
    legend=False,
    **pointplot_kws,
)

# Legends: hue entries of the current axes plus a manual VAE/PiCo colour key.
handles, labels = plt.gca().get_legend_handles_labels()

line_train = Line2D(
    [0], [0], label="VAE", color=palette[1], linestyle="-", linewidth=1
)
line_val = Line2D(
    [0], [0], label="PiCo", color=palette[0], linestyle="-", linewidth=1
)

leg1 = ax[0].legend(
    handles=handles[3:],
    bbox_to_anchor=(1.0, 1.6),
    loc="upper center",
    frameon=False,
    ncol=3,
    fontsize=10,
)

leg2 = ax[0].legend(
    handles=[line_train, line_val],
    bbox_to_anchor=(-0.35, 1.45),
    loc="upper center",
    frameon=False,
    ncol=1,
    fontsize=10,
)

# Each .legend() call replaces the previous one, so re-attach both as artists.
ax[0].add_artist(leg1)
ax[0].add_artist(leg2)

# Rename y ticks (same order as plot_order)
names_dict = {
    "Clinical": "Clinical",
    "RNA": "RNA",
    "Clinical+RNA": "Clinical+RNA",
    "Rep": r"$\mathbf{z}$",
    "Clinical+Rep": r"Clinical+$\mathbf{z}$",
    "Clinical+Rep+RNA": r"Clinical+$\mathbf{z}$+RNA",
}
ax[0].set_yticklabels(names_dict.values())

sns.despine()
for i, ax in enumerate(ax):
    ax.set_ylabel("")
    if target == "RCB.score":
        if metric == "test_spearmanr":
            ax.set_xlabel("Spearman\ncorrelation", fontsize=10)
            ax.set_xlim(0.5, 0.82)
            ax.set_xticks([0.5, 0.6, 0.7, 0.8], [0.5, 0.6, 0.7, 0.8])
        elif metric == "test_pearsonr":
            ax.set_xlabel("Pearson\ncorrelation", fontsize=10)
            ax.set_xlim(0.5, 0.82)
            ax.set_xticks([0.5, 0.6, 0.7, 0.8], [0.5, 0.6, 0.7, 0.8])
        elif metric == "test_rmse":
            ax.set_xlabel("RMSE", fontsize=10)
            ax.set_xlim(0.85, 1.15)
            ax.set_xticks([0.9, 1.0, 1.1], [0.9, 1.0, 1.1])
        elif metric == "auroc":
            ax.set_xlabel("AUROC", fontsize=10)
            ax.set_xlim(0.62, 0.95)
            ax.set_xticks([0.7, 0.8, 0.9], [0.7, 0.8, 0.9])
    else:
        if metric == "auroc":
            ax.set_xlabel("AUROC", fontsize=10)
            ax.set_xlim(0.62, 0.95)
            ax.set_xticks([0.7, 0.8, 0.9], [0.7, 0.8, 0.9])
        elif metric == "aupr":
            ax.set_xlabel("AUPR", fontsize=10)
            ax.set_xlim(0.35, 0.75)
            ax.set_xticks([0.4, 0.5, 0.6, 0.7], [0.4, 0.5, 0.6, 0.7])
        elif metric == "test_f1":
            ax.set_xlabel("F1 score", fontsize=10)
            ax.set_xlim(0.35, 0.75)
            ax.set_xticks([0.4, 0.5, 0.6, 0.7], [0.4, 0.5, 0.6, 0.7])
        elif metric == "test_cross_entropy":
            ax.set_xlabel("Cross-entropy", fontsize=10)
            ax.set_xlim(0.3, 0.6)
            ax.set_xticks([0.3, 0.4, 0.5, 0.6], [0.3, 0.4, 0.5, 0.6])
    ax.set_title("")
    ax.text(
        s=facet_order[i],
        x=0.5,
        y=1.1,
        transform=ax.transAxes,
        fontweight="regular",
        horizontalalignment="center",
        fontsize=10,
    )
    ax.grid(visible=True, axis="x")
    ax.tick_params(labelsize=10)
    if i > 0:
        sns.despine(left=True, ax=ax)
        ax.tick_params(
            top=False,
            bottom=True,
            left=False,
            right=False,
            labelleft=False,
            labelbottom=True,
            labelsize=10,
        )

fig.tight_layout()

plt.savefig(
    f"./figures/transneo/perf_facet_{target}_{metric}_auroc.png",
    #bbox_extra_artists=(leg1, leg2),
    bbox_inches="tight",
    dpi=600,
)
plt.savefig(
    f"./figures/transneo/perf_facet_{target}_{metric}_auroc.svg",
    #bbox_extra_artists=(leg1, leg2),
    bbox_inches="tight",
)

### 4.1.3 Performance by subtype

In [None]:
# LOAD TEST SET RESULTS (per receptor subtype)
# NOTE(review): this cell overwrites `test_metrics_df`, which the earlier
# performance cells read with different metric columns (e.g. "test_spearmanr");
# re-running those cells after this one will fail. Consider a distinct name.
from utils.comp_utils import calculate_feat_imps, plot_feat_imps_v2
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import roc_auc_score

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
seeds = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# Output-directory suffix (appended to the encoder name) for each feature set.
feat_sets_subtype = {"Clinical": "_Size.at.diagnosis_7_norep", "Clinical+Rep": "_Size.at.diagnosis_7", "Clinical+RNA": "_Size.at.diagnosis_18_norep", "Clinical+Rep+RNA": "_Size.at.diagnosis_18", "RNA": "_PGR.log2.tpm_11_norep", "Rep": ""}


def _subtype_metrics(preds):
    """Score one subtype's predictions.

    Rows with a missing ``pred_0`` or ``y`` are dropped first.

    Returns:
        dict with ``n`` (rows scored), ``spearman_r``, ``pearson_r``, ``rmse``
        (root-mean-square of ``pred_0 - y``) and ``auroc``.

    AUROC is computed against ``y == 0`` and reported orientation-free (max over
    the score and its negation). It is NaN when only one class is present,
    because ``roc_auc_score`` raises in that case; the correlations are NaN when
    fewer than two rows remain, because ``pearsonr`` raises there.
    """
    pred_metrics = preds[["pred_0", "y"]].dropna()
    y_bin = (pred_metrics["y"] == 0.0).astype(float)
    if y_bin.nunique() < 2:
        auc = np.nan
    else:
        auc = np.max([
            roc_auc_score(y_bin, pred_metrics["pred_0"]),
            roc_auc_score(y_bin, -1 * pred_metrics["pred_0"]),
        ])

    if len(pred_metrics) < 2:
        spearman_r = pearson_r = np.nan
    else:
        spearman_r = spearmanr(pred_metrics["pred_0"], pred_metrics["y"])[0]
        pearson_r = pearsonr(pred_metrics["pred_0"], pred_metrics["y"])[0]

    # Square the residuals, average, then take the root. The previous expression
    # took the sqrt of signed residuals (NaN for negatives) before squaring.
    residuals = pred_metrics["pred_0"] - pred_metrics["y"]
    rmse = np.sqrt((residuals ** 2).mean())

    return {
        "n": len(pred_metrics),
        "spearman_r": spearman_r,
        "pearson_r": pearson_r,
        "rmse": rmse,
        "auroc": auc,
    }


test_metrics_list = []
hopt_df = None

for feat_set, ext in feat_sets_subtype.items():
    for rep_type, rep_type_label in rep_types.items():
        for model_type in model_types:
            # (Per-model feature-importance plotting via calculate_feat_imps /
            # plot_feat_imps_v2 is currently disabled.)
            for seed in seeds:
                curr_preds = pd.read_csv(f"{res_root}/{model_type}_{rep_type + ext}/z_pred_test_s{seed}.csv")
                # Subtype membership is read from the Clinical-only run's c_4 / c_5
                # columns; per the labels below, c_4 encodes ER and c_5 HER2 status.
                curr_preds_subtype = pd.read_csv(f"{res_root}/{model_type}_{rep_type}{feat_sets_subtype['Clinical']}/z_pred_test_s{seed}.csv")

                er_pos = curr_preds_subtype["c_4"] > 0
                er_neg = curr_preds_subtype["c_4"] < 0
                her2_pos = curr_preds_subtype["c_5"] > 0
                her2_neg = curr_preds_subtype["c_5"] < 0

                subtype_preds = {
                    "ER-/HER2-": curr_preds[er_neg & her2_neg],
                    "ER+/HER2-": curr_preds[er_pos & her2_neg],
                    "ER-/HER2+": curr_preds[er_neg & her2_pos],
                    "ER+/HER2+": curr_preds[er_pos & her2_pos],
                }

                for subtype, preds in subtype_preds.items():
                    metrics = _subtype_metrics(preds)
                    test_metrics_list.append({
                        "subtype": subtype,
                        "n": metrics.pop("n"),
                        "rep_type": rep_type,
                        "model_type": model_type,
                        "feat_sets": feat_set,
                        "seed": seed,
                        **metrics,
                    })

test_metrics_df = pd.DataFrame.from_dict(test_metrics_list)
test_metrics_df["dataset"] = "test"

# Print results on external validation for reporting
test_metrics_df.groupby(["rep_type", "model_type", "feat_sets", "dataset", "subtype", "n"]).mean().reset_index()

In [None]:
# EXTERNAL-VALIDATION PERFORMANCE BY RECEPTOR SUBTYPE -- SUBTYPE FACET
# AUROC
from matplotlib.lines import Line2D

metric = "auroc"
model = "LogisticRegression"
target = "resp.pCR"

palette = sns.color_palette("colorblind")
# palette = {"PiCo_D": pal[1], "VAE": pal[0],}
pal_0_desat = sns.set_hls_values(palette[0], l=0.6, s=0.5)
pal_1_desat = sns.set_hls_values(palette[1], l=0.6, s=0.5)
pal_ev = {
    "vae": palette[1],
    "icovae_MCL1_16": palette[0],
    "NA": "grey",
}

plot_order = ["Clinical", "RNA", "Clinical+RNA", "Rep", "Clinical+Rep", "Clinical+Rep+RNA"]

facet_order = ["TransNEO\nCross-validation", "ARTemis+PBCP\nExternal validation"]

fig, axes = plt.subplots(1, 4, figsize=(9, 2.5), sharex=True)

## PLOTS EXT VAL RESULTS
metrics_df_plot = test_metrics_df[test_metrics_df["model_type"] == model].copy(deep=True)
# Feature sets without a learned representation get a neutral "NA" hue.
metrics_df_plot["rep_type_hue"] = metrics_df_plot["rep_type"].copy()
metrics_df_plot.loc[
    metrics_df_plot["feat_sets"].isin(["Clinical+RNA", "Clinical", "RNA"]), "rep_type_hue"
] = "NA"

subtypes = metrics_df_plot["subtype"].unique()

# Y tick labels, looked up in plot order so they always match the categories
# (the previous dict also repeated the "Clinical" and "RNA" keys).
names_dict = {
    "Clinical": "Clinical",
    "RNA": "RNA",
    "Clinical+RNA": "Clinical+RNA",
    "Rep": rf"$\mathbf{{z}}$",
    "Clinical+Rep": rf"Clinical+$\mathbf{{z}}$",
    "Clinical+Rep+RNA": rf"Clinical+$\mathbf{{z}}$+RNA",
}
ytick_labels = [names_dict.get(feat, feat) for feat in plot_order]

# x-axis (label, limits, ticks) per metric. Keys match the metric columns built
# by the subtype-metrics cell ("spearman_r", "pearson_r", "rmse", "auroc"); the
# RMSE entry was previously keyed "test_rmse" and could never be selected.
if target == "RCB.score":
    x_axis_cfg = {
        "spearman_r": ("Spearman\ncorrelation", (0.0, 1.0), [0.5, 0.6, 0.7, 0.8]),
        "pearson_r": ("Pearson\ncorrelation", (0.5, 0.82), [0.5, 0.6, 0.7, 0.8]),
        "rmse": ("RMSE", (0.85, 1.15), [0.9, 1.0, 1.1]),
        "auroc": ("AUROC", (0.45, 1.0), [0.5, 0.6, 0.7, 0.8, 0.9]),
    }
else:
    x_axis_cfg = {
        "auroc": ("AUROC", (0.45, 1.02), [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]),
        "aupr": ("AUPR", (0.35, 0.75), [0.4, 0.5, 0.6, 0.7]),
        "f1": ("F1 score", (0.35, 0.75), [0.4, 0.5, 0.6, 0.7]),
        "cross_entropy": ("Cross-entropy", (0.3, 0.6), [0.3, 0.4, 0.5, 0.6]),
    }

# Proxy artists for the encoder legend (loop-invariant).
line_train = Line2D([0], [0], label="VAE", color=palette[1], linestyle="-", linewidth=1)
line_val = Line2D([0], [0], label="PiCo", color=palette[0], linestyle="-", linewidth=1)

# zip() stops at the shorter sequence, so fewer subtypes than panels no longer
# raises IndexError; the legend is still attached to the last populated panel.
for i, (ax, subtype) in enumerate(zip(axes, subtypes)):
    curr_metrics_plot = metrics_df_plot[metrics_df_plot["subtype"] == subtype]
    # The RCB.score and classification branches used identical arguments.
    sns.pointplot(
        data=curr_metrics_plot,
        y="feat_sets",
        x=metric,
        hue="rep_type_hue",
        palette=pal_ev,
        order=plot_order,
        linestyle="-",
        linewidth=1,
        markersize=4,
        marker="o",
        errorbar=("sd", 1),
        capsize=0.25,
        err_kws={"linewidth": 1.2, "alpha": 0.5},
        ax=ax,
        legend=False,
    )

    if i == len(subtypes) - 1:
        # Take handles from this panel rather than whatever plt.gca() returns.
        handles, labels = ax.get_legend_handles_labels()
        leg1 = ax.legend(
            handles=handles[3:],
            bbox_to_anchor=(1.0, 1.6),
            loc="upper center",
            frameon=False,
            ncol=2,
            fontsize=10,
        )

        leg2 = ax.legend(
            handles=[line_train, line_val],
            bbox_to_anchor=(-1.15, 1.55),
            loc="upper center",
            frameon=False,
            ncol=2,
            fontsize=10,
        )

        # A later ax.legend() replaces the earlier one, so attach both explicitly.
        ax.add_artist(leg1)
        ax.add_artist(leg2)

    ax.set_yticklabels(ytick_labels)
    ax.set_ylabel("")
    if metric in x_axis_cfg:
        xlabel, xlim, xticks = x_axis_cfg[metric]
        ax.set_xlabel(xlabel, fontsize=10)
        ax.set_xlim(*xlim)
        ax.set_xticks(xticks, xticks)
    ax.set_title("")
    ax.text(
        s=f"{subtype}\n(n={curr_metrics_plot['n'].iloc[0]})",
        x=0.5,
        y=1.1,
        transform=ax.transAxes,
        fontweight="regular",
        horizontalalignment="center",
        fontsize=10,
    )
    ax.grid(visible=True, axis="x")
    ax.tick_params(labelsize=10)
    if i > 0:
        # Panels after the first keep their left spine but hide the y labels.
        ax.tick_params(
            top=False,
            bottom=True,
            left=False,
            right=False,
            labelleft=False,
            labelbottom=True,
            labelsize=10,
        )

# Remove top/right spines on every panel (left spines are kept).
sns.despine(fig=fig)
fig.tight_layout()

plt.savefig(
    f"./figures/transneo/perf_facet_{target}_{metric}_subtype.png",
    bbox_inches="tight",
    dpi=600,
)
plt.savefig(
    f"./figures/transneo/perf_facet_{target}_{metric}_subtype.svg",
    bbox_inches="tight",
)

## 4.2 Representation associations with outcome and other variables

### 4.2.1 Single feature scatter plots

In [None]:
## SCATTER PLOT FEATURES
from scipy.stats import pointbiserialr, spearmanr
from functools import partial

def rep_renamer(x, constraints, prefix="z"):
    """Rename a latent column ``"<p>_<k>"`` after the k-th named constraint.

    Args:
        x: Column name whose part after the first ``_`` is an integer index.
        constraints: Names for the leading dimensions, or ``None`` (e.g. a model
            trained without confounders), in which case nothing is renamed.
        prefix: Prefix for renamed columns.

    Returns:
        ``f"{prefix}_{constraints[k]}"`` when ``k`` indexes a named constraint,
        otherwise ``x`` unchanged.
    """
    dim = int(x.split("_")[1])
    # Guard against None so callers can pass args["confounders"] directly.
    if constraints is not None and dim < len(constraints):
        return f"{prefix}_{constraints[dim]}"
    return x


# Model whose latent representations are analysed in this cell.
model_type = "LogisticRegression"
experiment = "artemis_pbcp"
target = "resp.pCR"
rep_type = "icovae_MCL1_16"

def pbc_corr(x1, x2):
    """Correlate two variables, choosing the statistic from their cardinality.

    When the input with fewer distinct values has exactly two, the
    point-biserial correlation is used; otherwise Spearman's rank correlation.

    Returns:
        The correlation coefficient (first element of the scipy result).
    """
    fewest_levels = min(np.unique(x1).size, np.unique(x2).size)
    corr_fn = pointbiserialr if fewest_levels == 2 else spearmanr
    return corr_fn(x1, x2)[0]

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
ext = "_Size.at.diagnosis_18"

model_path = f"{res_root}/{model_type}_{rep_type + ext}"


def _named_rep(z_frame, constraints, confounders):
    """Return the ``z*`` and ``c*`` columns of ``z_frame`` renamed via ``rep_renamer``.

    ``z`` columns are named after ``constraints`` and ``c`` columns after
    ``confounders``. A ``None`` list is treated as empty, so a model trained
    without confounders never makes the renamer call ``len(None)``.
    """
    z_part = z_frame.loc[:, z_frame.columns.str.startswith("z")].rename(
        mapper=partial(rep_renamer, constraints=constraints or [], prefix="z"), axis=1
    )
    c_part = z_frame.loc[:, z_frame.columns.str.startswith("c")].rename(
        mapper=partial(rep_renamer, constraints=confounders or [], prefix="c"), axis=1
    )
    return pd.concat([z_part, c_part], axis=1)


# Per-seed Spearman correlation matrices among all representation/confounder columns.
corrs = []
corrs_val = []
for seed in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:

    # LOAD ARGUMENTS
    with open(f"{model_path}/args_best_s{seed}.txt", "r") as f:
        args = json.load(f)

    print(args)

    constraints = args["constraints"]
    confounders = args["confounders"]

    # Load predictions
    test_z = pd.read_csv(f"{model_path}/z_pred_test_s{seed}.csv")
    train_z = pd.read_csv(f"{model_path}/z_pred_train_s{seed}.csv")

    test_z_rep = _named_rep(test_z, constraints, confounders)
    train_z_rep = _named_rep(train_z, constraints, confounders)

    feat_top_corr = train_z_rep.corr(method="spearman")
    feat_top_corr["seed"] = seed
    corrs.append(feat_top_corr)

    feat_top_corr_val = test_z_rep.corr(method="spearman")
    feat_top_corr_val["seed"] = seed
    corrs_val.append(feat_top_corr_val)

# corrs = pd.DataFrame(corrs)
# corrs_val = pd.DataFrame(corrs_val)

# r, p = pearsonr(x[feat_1], x[feat_2])
# print(f"{p*len(x.columns):.3e}")

# f, ax = plt.subplots(1,1, figsize=(2,2))
# sns.scatterplot(data=x, x=feat_1, y=feat_2, hue=hue, ax=ax, palette="Set2")
# sns.despine(ax=ax)
# ax.text(s=f"$r_{{p}} = {r:.3f}, p = {p*len(x.columns):.3f}$", x=0.1, y=1.0, transform=ax.transAxes, size=8)

# feat_1_split = feat_1.split("_")
# if feat_1_split[0] == "z":
#     ax.set_xlabel(rf"$z_{{{feat_1_split[1]}}}$")
# else:
#     ax.set_xlabel(names_map[feat_1])

# ax.set_ylabel(names_map[feat_2])


# plt.savefig(f"./ext_val_sammut/figures/corr_{feat_1}_{feat_2}.png", dpi=600, bbox_inches="tight")
# plt.savefig(f"./ext_val_sammut/figures/corr_{feat_1}_{feat_2}.svg", bbox_inches="tight")

### 4.2.2 All features correlation matrix heatmap

In [None]:
## REPRESENTATION CORRELATION WITH OTHER FEATURES
# Average the per-seed correlation matrices. val=True uses the ARTemis+PBCP
# (test-split) matrices, val=False the TransNEO (train-split) ones.
val = False

per_seed_corrs = corrs_val if val else corrs
corrs_df = (
    pd.concat(per_seed_corrs, axis=0)
    .reset_index()
    .groupby("index")
    .mean()
    .reset_index()
)

pal = sns.color_palette("colorblind")


def feat_labeller(x):
    """Return the colour from the module-level ``pal`` for a feature's group.

    Names whose first character is ``"z"`` map to ``pal[0]``. Three hard-coded
    name lists map to ``pal[1]``, ``pal[2]`` and ``pal[3]`` (presumably
    expression/signature scores, immune scores and mutation/TMB features,
    judging by the names). Anything else maps to ``pal[4]``.
    """
    if x[0] == "z":
        return pal[0]

    colour_groups = (
        (
            (
                "ESR1.log2.tpm",
                "PGR.log2.tpm",
                "ERBB2.log2.tpm",
                "GGI.ssgsea.notnorm",
                "Swanton.PaclitaxelScore",
                "ESC.ssgsea.notnorm",
            ),
            1,
        ),
        (("Danaher.Mast.cells", "STAT1.ssgsea.notnorm", "TIDE.Exclusion"), 2),
        (
            (
                "CodingMuts.PIK3CA",
                "CodingMuts.TP53",
                "All.TMB",
                "Coding.TMB",
                "Expressed.NAg",
                "",
            ),
            3,
        ),
    )
    for members, colour_idx in colour_groups:
        if x in members:
            return pal[colour_idx]
    return pal[4]


# row_colors = corrs_df["index"].apply(lambda x: feat_labeller(x))
# row_colors.index = corrs_df["index"]
# row_colors = row_colors.rename(lambda x: x.split("_"))
# row_colors = row_colors.rename(lambda x: f"$z_{{{x[1]}}}$" if len(x) > 1 else names_map[x[0]])
# print(row_colors)

corrs_by_feat = corrs_df.set_index("index")


def _is_named_z(col):
    """True for ``z_<name>`` columns whose suffix is not a bare integer index."""
    parts = col.split("_")
    return len(parts) > 1 and parts[0] == "z" and not parts[1].isdigit()


def _feat_row_label(name):
    """Display label for a row: ``$z_{<k>}$`` for z rows, else a ``names_map`` lookup."""
    parts = name.split("_")
    return f"$z_{{{parts[1]}}}$" if parts[0] == "z" else names_map[parts[1]]


split_str = "_"

# Rows: every non-z variable; columns: only the constraint-named z dimensions.
keep_rows = ~corrs_by_feat.index.str.startswith("z")
keep_cols = [_is_named_z(col) for col in corrs_by_feat.columns]
corrs_df = (
    corrs_by_feat.loc[keep_rows, keep_cols]
    .rename(index=_feat_row_label)
    .rename(columns=lambda col: rf"$z_{{{col.split(split_str)[1]}}}$")
)

print(np.min(corrs_df))
print(np.max(corrs_df))

# Keep the table as z dimensions x features; the heatmap is drawn from its
# transpose, and the figure height scales with the number of features.
corrs_df = corrs_df.T
fig_width = len(corrs_df.columns)

g = sns.clustermap(
    corrs_df.T,
    yticklabels=True,
    xticklabels=True,
    figsize=(4.5, fig_width / 4.6),
    method="complete",
    cmap=sns.diverging_palette(220, 10, s=100, l=25, as_cmap=True),
    vmin=-1,
    vmax=1,
    center=0,
    cbar_pos=(0.75, 0.02, 0.02, 0.125),
    dendrogram_ratio=0.1,
)
heatmap_ax = g.ax_heatmap
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), fontsize=12)
heatmap_ax.set_yticklabels(heatmap_ax.get_ymajorticklabels(), fontsize=10)
heatmap_ax.set_ylabel("")
heatmap_ax.set_xlabel("")

plt.savefig(
    f"./figures/transneo/feat_corr_hm_{rep_type}_{'val' if val else 'train'}.png",
    bbox_inches="tight",
    dpi=600,
)
plt.savefig(
    f"./figures/transneo/feat_corr_hm_{rep_type}_{'val' if val else 'train'}.svg", bbox_inches="tight"
)

In [None]:
# Inspect the averaged correlation table left by the heatmap cell above
# (rows: constraint-named z dimensions, columns: other variables).
corrs_df

In [None]:
## REPRESENTATION CORRELATION WITH ITSELF
# val=True -> ARTemis+PBCP (test-split) matrices, False -> TransNEO (train-split).
val = True
per_seed = corrs_val if val else corrs
corrs_full = pd.concat(per_seed, axis=0).reset_index().groupby("index").mean()


def _is_named_z_strict(name):
    """True when ``name`` is ``z_<suffix>`` with a non-integer suffix."""
    parts = name.split("_")
    return parts[0] == "z" and not parts[1].isdigit()


def _z_row_label(name):
    """``$z_{<suffix>}$`` for underscore names, else a ``names_map`` lookup."""
    parts = name.split("_")
    return f"$z_{{{parts[1]}}}$" if len(parts) > 1 else names_map[parts[0]]


# Restrict both axes to the constraint-named z dimensions.
col_mask = [_is_named_z_strict(name) for name in corrs_full.columns]
row_mask = [_is_named_z_strict(name) for name in corrs_full.index]

corrs_df = corrs_full.loc[row_mask, col_mask].reset_index()

print(corrs_df)

corrs_df["index"] = corrs_df["index"].map(_z_row_label)

split_str = "_"

corrs_df = corrs_df.set_index("index").rename(
    columns=lambda col: rf"$z_{{{col.split(split_str)[1]}}}$"
)

g = sns.clustermap(
    corrs_df,
    square=True,
    yticklabels=True,
    xticklabels=True,
    figsize=(3.25, 3),
    method="complete",
    cmap=sns.diverging_palette(220, 10, s=100, l=25, as_cmap=True),
    cbar_pos=(0.77, 0.02, 0.02, 0.125),
    dendrogram_ratio=0.1,
    center=0.0,
    vmax=1,
    vmin=-1,
)
heatmap_ax = g.ax_heatmap
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), fontsize=10)
heatmap_ax.set_yticklabels(heatmap_ax.get_ymajorticklabels(), fontsize=10)
heatmap_ax.set_ylabel("")

plt.savefig(
    f"./figures/transneo/rep_corr_hm_{rep_type}_{'val' if val else 'train'}.png",
    bbox_inches="tight",
    dpi=600,
)
plt.savefig(
    f"./figures/transneo/rep_corr_hm_{rep_type}_{'val' if val else 'train'}.svg", bbox_inches="tight"
)

### 4.2.3 Single variable associations with outcome

In [None]:
# CORRELATION WITH RCB SCORE
from scipy.stats import fisher_exact, mannwhitneyu
import json
from scipy.stats import pointbiserialr, spearmanr
from functools import partial
from sklearn.metrics import roc_auc_score

def rep_renamer(x, constraints, prefix="z"):
    """Rename a latent column ``"<p>_<k>"`` after the k-th named constraint.

    Args:
        x: Column name whose part after the first ``_`` is an integer index.
        constraints: Names for the leading dimensions, or ``None`` (e.g. a model
            trained without confounders), in which case nothing is renamed.
        prefix: Prefix for renamed columns.

    Returns:
        ``f"{prefix}_{constraints[k]}"`` when ``k`` indexes a named constraint,
        otherwise ``x`` unchanged.
    """
    dim = int(x.split("_")[1])
    # Guard against None so callers can pass args["confounders"] directly.
    if constraints is not None and dim < len(constraints):
        return f"{prefix}_{constraints[dim]}"
    return x


# Model whose latent dimensions are associated with the outcome below.
model_type = "LogisticRegression"
experiment = "artemis_pbcp"
target = "resp.pCR"
rep_type = "icovae_MCL1_16"

def pbc_corr(x1, x2):
    """Correlate two variables, choosing the statistic from their cardinality.

    When the input with fewer distinct values has exactly two, the
    point-biserial correlation is used; otherwise Spearman's rank correlation.

    Returns:
        The correlation coefficient (first element of the scipy result).
    """
    fewest_levels = min(np.unique(x1).size, np.unique(x2).size)
    corr_fn = pointbiserialr if fewest_levels == 2 else spearmanr
    return corr_fn(x1, x2)[0]

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
ext = "_Size.at.diagnosis_18"

model_path = f"{res_root}/{model_type}_{rep_type + ext}"


def _named_rep_with_target(z_frame, constraints, confounders, target_col):
    """Return renamed ``z*``/``c*`` columns plus the outcome, without missing rows.

    ``z`` columns are named after ``constraints``, ``c`` columns after
    ``confounders`` (``None`` is treated as an empty list), and ``y`` is renamed
    to ``target_col``. Rows with any NaN are dropped.
    """
    z_part = z_frame.loc[:, z_frame.columns.str.startswith("z")].rename(
        mapper=partial(rep_renamer, constraints=constraints or [], prefix="z"), axis=1
    )
    c_part = z_frame.loc[:, z_frame.columns.str.startswith("c")].rename(
        mapper=partial(rep_renamer, constraints=confounders or [], prefix="c"), axis=1
    )
    y_part = z_frame[["y"]].rename({"y": target_col}, axis=1)
    return pd.concat([z_part, c_part, y_part], axis=1).dropna(axis=0)


def _outcome_assocs(rep_frame, target_col, seed):
    """One record per non-target column with its association to ``target_col``.

    Each record holds Spearman and Pearson correlations. For ``resp.pCR`` it
    also holds an orientation-free AUROC (max over the feature and its
    negation); for other targets AUROC is NaN.
    """
    records = []
    for col in rep_frame.columns:
        if col == target_col:
            continue
        if target_col == "resp.pCR":
            auc = np.max([
                roc_auc_score(rep_frame[target_col], rep_frame[col]),
                roc_auc_score(rep_frame[target_col], -1 * rep_frame[col]),
            ])
        else:
            auc = np.nan
        records.append({
            "feat": col,
            "seed": seed,
            "spearmanr": spearmanr(rep_frame[col], rep_frame[target_col])[0],
            "pearsonr": pearsonr(rep_frame[col], rep_frame[target_col])[0],
            "auroc": auc,
        })
    return records


corrs = []
corrs_val = []
for seed in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:

    # LOAD ARGUMENTS
    with open(f"{model_path}/args_best_s{seed}.txt", "r") as f:
        args = json.load(f)

    print(args)

    constraints = args["constraints"]
    confounders = args["confounders"]

    # Load predictions
    test_z = pd.read_csv(f"{model_path}/z_pred_test_s{seed}.csv")
    train_z = pd.read_csv(f"{model_path}/z_pred_train_s{seed}.csv")

    test_z_rep = _named_rep_with_target(test_z, constraints, confounders, target)
    train_z_rep = _named_rep_with_target(train_z, constraints, confounders, target)

    corrs.extend(_outcome_assocs(train_z_rep, target, seed))
    corrs_val.extend(_outcome_assocs(test_z_rep, target, seed))

In [None]:
palette = sns.color_palette("colorblind")[2:]

# The combined association table does not depend on the metric being plotted,
# so it is built once here instead of on every loop iteration.
corrs_df_val = pd.DataFrame.from_dict(corrs_val).reset_index()
corrs_df_train = pd.DataFrame.from_dict(corrs).reset_index()

corrs_df_val["dataset"] = "ARTemis+PBCP"
corrs_df_train["dataset"] = "TransNEO"

corrs_all = pd.concat([corrs_df_val, corrs_df_train], axis=0)


def _group_mean(frame, col):
    """Per-(dataset, feat) mean of ``col``, broadcast back to every row."""
    return frame.groupby(["dataset", "feat"])[col].transform("mean").astype(float)


corrs_all["mean_spearmanr"] = _group_mean(corrs_all, "spearmanr")
corrs_all["mean_pearsonr"] = _group_mean(corrs_all, "pearsonr")

corrs_all["abs_spearmanr"] = corrs_all["spearmanr"].abs().astype(float)
corrs_all["abs_pearsonr"] = corrs_all["pearsonr"].abs().astype(float)

corrs_all["mean_auroc"] = _group_mean(corrs_all, "auroc")
corrs_all["abs_auroc"] = corrs_all["auroc"].abs().astype(float)

# Display labels: "$z_{k}$" for latent dimensions, names_map lookup otherwise.
corrs_all["feat"] = corrs_all["feat"].apply(lambda x: x.split("_"))
corrs_all["feat"] = corrs_all["feat"].apply(
    lambda x: f"$z_{{{x[1]}}}$" if x[0] == "z" else names_map[x[1]]
)

for metric in ["spearmanr", "pearsonr", "auroc"]:
    # Order rows by dataset, then by the absolute per-feature mean of the metric.
    corrs_df = corrs_all.sort_values(
        by=["dataset", f"mean_{metric.split('_')[-1]}"],
        key=lambda x: abs(x) if x.dtypes == "float64" else x,
        ascending=[False, False],
    )
    print(corrs_df)

    fig, ax = plt.subplots(1, 1, figsize=(2.5, 3))

    sns.pointplot(data=corrs_df, x=metric, y="feat", hue="dataset", ax=ax, errorbar=("sd", 1),
        capsize=0.25,
        linestyle="none",
        markersize=3,
        err_kws={"linewidth": 1, "alpha": 0.5},
        palette=palette,)

    if metric == "abs_spearmanr":
        ax.set_xlabel("Abs. Spearman correlation")
        ax.set_xticks([-0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6])
    elif metric == "abs_pearsonr":
        ax.set_xlabel("Abs. Pearson correlation")
        ax.set_xticks([-0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6])
    elif metric == "spearmanr":
        ax.set_xlabel("Spearman correlation")
        ax.set_xticks([-0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6])
    elif metric == "pearsonr":
        ax.set_xlabel("Pearson correlation")
        ax.set_xticks([-0.6, -0.4, -0.2, 0.0, 0.2, 0.4, 0.6])
    elif metric == "auroc":
        ax.set_xlabel("AUROC")
        ax.set_xlim(0.4, 0.8)
        ax.set_xticks([0.4, 0.5, 0.6, 0.7, 0.8])
    else:
        raise ValueError("Invalid metric")
    ax.set_ylabel("")
    ax.tick_params(
        top=False,
        bottom=True,
        left=True,
        right=False,
        labelleft=True,
        labelbottom=True,
        labelsize=10,
    )
    sns.despine(ax=ax)

    ax.grid(visible=True, axis="x")

    # Show only the first 16 categories (the top-ranked features).
    ax.set_ylim(15.5, -0.5)
    ax.axvline(0, c="grey", lw=0.5)

    ax.legend(frameon=False, title="", ncol=1, bbox_to_anchor=(0.5, 1.0), loc="lower center", fontsize=10)

    plt.savefig(
        f"./figures/transneo/{metric}_assocs_{rep_type}.svg",
        bbox_inches="tight",
    )
    plt.savefig(
        f"./figures/transneo/{metric}_assocs_{rep_type}.png",
        bbox_inches="tight",
        dpi=600,
    )

In [None]:
# Per-(feature, cohort) means of the association statistics, ranked by mean
# absolute Spearman correlation with the outcome; show the top 50 rows.
corrs_df.groupby(["feat", "dataset"]).mean().sort_values(
    "abs_spearmanr", ascending=False
).head(50)

In [None]:
# CORRELATION WITH RCB SCORE SEPARATED BY SUBTYPE

from scipy.stats import fisher_exact, mannwhitneyu
import json
from scipy.stats import pointbiserialr, spearmanr
from functools import partial

def rep_renamer(x, constraints, prefix="z"):
    """Rename a latent column ``"<p>_<k>"`` after the k-th named constraint.

    Args:
        x: Column name whose part after the first ``_`` is an integer index.
        constraints: Names for the leading dimensions, or ``None`` (e.g. a model
            trained without confounders), in which case nothing is renamed.
        prefix: Prefix for renamed columns.

    Returns:
        ``f"{prefix}_{constraints[k]}"`` when ``k`` indexes a named constraint,
        otherwise ``x`` unchanged.
    """
    dim = int(x.split("_")[1])
    # Guard against None so callers can pass args["confounders"] directly.
    if constraints is not None and dim < len(constraints):
        return f"{prefix}_{constraints[dim]}"
    return x


# Model whose latent dimensions are associated with the outcome per subtype.
model_type = "LogisticRegression"
experiment = "artemis_pbcp"
target = "resp.pCR"
rep_type = "icovae_MCL1_16"

def pbc_corr(x1, x2):
    """Return a correlation coefficient between ``x1`` and ``x2``.

    When the smaller of the two distinct-value counts is exactly 2, the
    point-biserial correlation is used. Otherwise Spearman's rank correlation
    is used. Only the coefficient is returned, not the p-value.
    """
    n_levels = min(np.unique(x1).size, np.unique(x2).size)
    corr_fn = pointbiserialr if n_levels == 2 else spearmanr
    return corr_fn(x1, x2)[0]

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}_new/{experiment}/pico"
ext = "_Size.at.diagnosis_18"

model_path = f"{res_root}/{model_type}_{rep_type + ext}"

# roc_auc_score is otherwise only imported in a later cell; import it here so
# this cell survives Restart & Run All.
from sklearn.metrics import roc_auc_score


def _load_named_reps(split, seed, constraints, confounders):
    """Load ``{model_path}/z_pred_{split}_s{seed}.csv`` with readable column names.

    Latent (``z*``) and confounder (``c*``) columns are renamed through
    ``rep_renamer``, and ``y`` is renamed to ``target``. Rows containing any
    NaN are dropped.
    """
    z_pred = pd.read_csv(f"{model_path}/z_pred_{split}_s{seed}.csv")
    rep_z = z_pred.loc[:, z_pred.columns.str.startswith("z")].rename(
        mapper=partial(rep_renamer, constraints=constraints, prefix="z"), axis=1
    )
    rep_c = z_pred.loc[:, z_pred.columns.str.startswith("c")].rename(
        mapper=partial(rep_renamer, constraints=confounders, prefix="c"), axis=1
    )
    rep_y = z_pred[["y"]].rename({"y": target}, axis=1)
    return pd.concat([rep_z, rep_c, rep_y], axis=1).dropna(axis=0)


def _split_by_subtype(rep):
    """Split ``rep`` into ER/HER2 strata using the sign of the confounder columns.

    Rows where either status column is exactly 0 fall into no stratum.
    """
    er, her2 = rep["c_ER.status"], rep["c_HER2.status"]
    return {
        "ER-/HER2-": rep[(er < 0) & (her2 < 0)],
        "ER+/HER2-": rep[(er > 0) & (her2 < 0)],
        "ER-/HER2+": rep[(er < 0) & (her2 > 0)],
        "ER+/HER2+": rep[(er > 0) & (her2 > 0)],
    }


def _subtype_assocs(rep, subtype, seed):
    """Compute one association record per non-target column of ``rep``.

    Each record holds the Spearman and Pearson correlation with ``target``.
    For ``resp.pCR`` it also holds the direction-agnostic AUROC,
    ``max(AUC(f), AUC(-f))``.

    Statistics that cannot be computed are NaN instead of raising. This covers
    strata with fewer than two rows, and AUROC when only one response class is
    present, which ``roc_auc_score`` rejects.
    """
    y = rep[target]
    can_auc = target == "resp.pCR" and y.nunique() > 1
    records = []
    for feat in rep.columns:
        if feat == target:
            continue
        if can_auc:
            auc = np.max([roc_auc_score(y, rep[feat]), roc_auc_score(y, -1 * rep[feat])])
        else:
            auc = np.nan
        if len(rep) < 2:
            rho = r = np.nan
        else:
            rho = spearmanr(rep[feat], y)[0]
            r = pearsonr(rep[feat], y)[0]
        records.append(
            {"subtype": subtype, "feat": feat, "seed": seed, "spearmanr": rho, "pearsonr": r, "auroc": auc}
        )
    return records


corrs_subtype = []
corrs_subtype_val = []
for seed in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:

    # LOAD ARGUMENTS
    with open(f"{model_path}/args_best_s{seed}.txt", "r") as f:
        args = json.load(f)

    print(args)

    constraints = args["constraints"]
    n_constraints = len(constraints)
    confounders = args["confounders"]
    n_confounders = 0 if confounders is None else len(confounders)

    # Load predictions with named columns
    test_z_rep = _load_named_reps("test", seed, constraints, confounders)
    train_z_rep = _load_named_reps("train", seed, constraints, confounders)

    train_strata = _split_by_subtype(train_z_rep)
    test_strata = _split_by_subtype(test_z_rep)

    # TransNEO (train) -> corrs_subtype; ARTemis+PBCP (test) -> corrs_subtype_val
    for subtype in train_strata:
        corrs_subtype.extend(_subtype_assocs(train_strata[subtype], subtype, seed))
        corrs_subtype_val.extend(_subtype_assocs(test_strata[subtype], subtype, seed))

In [None]:
# Subtype-stratified association plots: one figure per metric and one panel
# per ER/HER2 subtype. Each panel shows the 15 features with the largest
# |mean metric|. Needs `corrs_subtype`, `corrs_subtype_val` and `names_map`.

# The long-format table does not depend on `metric`, so build it once.
corrs_subtype_df = pd.concat(
    [
        pd.DataFrame(corrs_subtype_val).assign(dataset="ARTemis+PBCP"),
        pd.DataFrame(corrs_subtype).assign(dataset="TransNEO"),
    ],
    axis=0,
    ignore_index=True,
)

# Per-(subtype, dataset, feature) means. Selecting the metric column before
# `transform` avoids averaging the string "subtype" column, which raises a
# TypeError on pandas >= 2.0.
mean_keys = ["subtype", "dataset", "feat"]
for m in ["spearmanr", "pearsonr", "auroc"]:
    corrs_subtype_df[f"mean_{m}"] = (
        corrs_subtype_df.groupby(mean_keys)[m].transform("mean").astype(float)
    )
    corrs_subtype_df[f"abs_{m}"] = corrs_subtype_df[m].abs().astype(float)

# Display names: latent dims as LaTeX z_{name}, confounders via names_map.
corrs_subtype_df["feat"] = corrs_subtype_df["feat"].str.split("_").apply(
    lambda x: f"$z_{{{x[1]}}}$" if x[0] == "z" else names_map[x[1]]
)

palette = sns.color_palette("colorblind")[2:]
# Alphabetical panel order.
subtypes = sorted(corrs_subtype_df["subtype"].unique())
metric_labels = {
    "spearmanr": "Spearman correlation",
    "abs_spearmanr": "Abs. Spearman correlation",
    "pearsonr": "Pearson correlation",
    "abs_pearsonr": "Abs. Pearson correlation",
    "auroc": "AUROC",
}

for metric in ["spearmanr", "pearsonr", "auroc"]:
    fig, axes = plt.subplots(
        1,
        len(subtypes),
        figsize=(3 * len(subtypes), 3.5),
        sharex=True,
        sharey=False,
        squeeze=False,
    )
    axes = axes[0]  # 1-D array of axes, also when there is a single subtype

    for ax, subtype in zip(axes, subtypes):
        df_sub = corrs_subtype_df[corrs_subtype_df["subtype"] == subtype].copy()

        # Top 15 features by |mean metric| in this subtype (both datasets pooled);
        # features outside `order` become NaN and are not drawn.
        order = (
            df_sub.groupby("feat")[metric]
            .mean()
            .sort_values(ascending=False, key=abs)
            .head(15)
            .index
        )
        df_sub["feat"] = pd.Categorical(df_sub["feat"], categories=order, ordered=True)

        sns.pointplot(
            data=df_sub.sort_values(by="dataset", ascending=False),
            x=metric,
            y="feat",
            hue="dataset",
            dodge=True,
            errorbar=("sd", 1),
            capsize=0.25,
            ax=ax,
            linestyle="",
            palette=palette,
            legend=True,
            markersize=3,
            err_kws={"linewidth": 1, "alpha": 0.5},
        )

        ax.set_title(f"{subtype}")
        ax.set_xlabel(metric_labels.get(metric, metric))
        # No-association reference: 0 for correlations, 0.5 for AUROC. A line
        # at 0 would stretch the AUROC axis down to zero.
        ax.axvline(0.5 if metric == "auroc" else 0, color="grey", lw=0.5)
        ax.grid(True, axis="x")
        ax.set_ylabel("")
        ax.tick_params(axis="y", labelsize=10)
        sns.despine(ax=ax)

    # Replace the per-axis legends with a single figure legend.
    for ax in axes:
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        ncol=len(labels),
        frameon=False,
        fontsize=10,
        bbox_to_anchor=(0.5, 0.97),
    )

    fig.tight_layout()
    fig.savefig(
        f"./figures/transneo/{metric}_assocs_subtype_{rep_type}.svg",
        bbox_inches="tight",
    )
    fig.savefig(
        f"./figures/transneo/{metric}_assocs_subtype_{rep_type}.png",
        bbox_inches="tight",
        dpi=600,
    )
    plt.show()  # display now and release the figure before the next metric

In [None]:
from sklearn.metrics import roc_auc_score
from scipy.stats import fisher_exact, mannwhitneyu
from functools import partial

def rep_renamer(x, constraints, prefix="z"):
    """Rename a positional column (``"<letter>_<index>"``) using ``constraints``.

    Returns ``f"{prefix}_{constraints[index]}"`` when ``index`` is below
    ``len(constraints)``, otherwise returns ``x`` untouched.
    """
    index_str = x.split("_")[1]
    index = int(index_str)
    in_range = index < len(constraints)
    if not in_range:
        return x
    return f"{prefix}_{constraints[index]}"

# Run selection for the ER/HER2-stratified association analysis.
# NOTE(review): the loop in this cell treats `target` as binary (`== 1` /
# `== 0` and roc_auc_score). Confirm "RCB.score" is binary in these outputs,
# or switch to "resp.pCR".
rep_type, target, experiment, model_type = (
    "icovae_MCL1_16",
    "RCB.score",
    "artemis_pbcp",
    "ElasticNet",
)

def pbc_corr(x1, x2):
    """Return a correlation coefficient between ``x1`` and ``x2``.

    Uses the point-biserial correlation when the smaller of the two
    distinct-value counts is exactly 2, and Spearman's rho otherwise. Only the
    coefficient (element 0 of the SciPy result) is returned.
    """
    fewest_levels = min(len(np.unique(x1)), len(np.unique(x2)))
    if fewest_levels != 2:
        return spearmanr(x1, x2)[0]
    return pointbiserialr(x1, x2)[0]

class SigRes:
    """Minimal test-result holder exposing ``statistic`` and ``pvalue``.

    Used so Fisher's exact results (and the fallback when the test cannot run)
    share the ``.pvalue`` attribute of SciPy's result objects.
    """

    def __init__(self, statistic, pvalue):
        self.statistic, self.pvalue = statistic, pvalue

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}/{experiment}/pico"
ext = "_Size.at.diagnosis_18"
model_path = "/".join([res_root, f"{model_type}_{rep_type}{ext}"])

# Stratification columns: `col` splits on ER status, `row` on HER2 status.
hue = None
row, col = "c_HER2.status", "c_ER.status"


def mwu_auc(x, y, x_rd, x_pcr, feat):
    """Run a Mann-Whitney U test and a direction-agnostic AUROC for one feature.

    Parameters
    ----------
    x : pd.DataFrame
        Samples of the stratum.
    y : array-like
        Labels aligned with ``x``.
    x_rd, x_pcr : pd.DataFrame
        The residual-disease and pCR subsets of the stratum.
    feat : str
        Column to test.

    Returns
    -------
    tuple
        The U-test result for ``x_rd[feat]`` vs ``x_pcr[feat]``, and
        ``max(AUROC(x[feat]), AUROC(-x[feat]))``.
    """
    u_test = mannwhitneyu(x_rd[feat], x_pcr[feat])
    scores = x[feat]
    candidate_aucs = [roc_auc_score(y, scores), roc_auc_score(y, -1 * scores)]
    return u_test, np.max(candidate_aucs)


def fisher_auc(x, y, x_rd, x_pcr, feat):
    """Run Fisher's exact test and a direction-agnostic AUROC for a discrete feature.

    Parameters
    ----------
    x : pd.DataFrame
        Samples of the stratum.
    y : array-like
        Labels aligned with ``x``.
    x_rd, x_pcr : pd.DataFrame
        The residual-disease and pCR subsets of the stratum.
    feat : str
        Column to test.

    Returns
    -------
    tuple
        A ``SigRes`` holding Fisher's statistic and p-value, and
        ``max(AUROC(x[feat]), AUROC(-x[feat]))``. When the contingency table is
        not 2x2, ``fisher_exact`` raises ValueError and ``SigRes(1.0, 1.0)`` is
        returned instead. This happens when the feature takes a single value
        within the stratum.
    """
    # Rows: feature values; columns: counts in the residual-disease / pCR subsets.
    cont_tab = pd.merge(
        x_rd[feat].value_counts(),
        x_pcr[feat].value_counts(),
        left_index=True,
        right_index=True,
        how="outer",
    ).fillna(0)

    auc = roc_auc_score(y, x[feat])
    auc_neg = roc_auc_score(y, -1 * x[feat])
    auc = np.max([auc, auc_neg])

    try:
        statistic, pvalue = fisher_exact(cont_tab)
    except ValueError:
        # Only the "table is not 2x2" failure is expected here. A bare `except`
        # would also hide real errors and KeyboardInterrupt.
        statistic, pvalue = 1.0, 1.0

    return SigRes(statistic, pvalue), auc


pcr_assocs = []
for seed in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
    # LOAD ARGUMENTS
    with open(f"{model_path}/args_best_s{seed}.txt", "r") as f:
        args = json.load(f)

    print(args)

    constraints = args["constraints"]
    n_constraints = len(constraints)
    confounders = args["confounders"]
    n_confounders = 0 if confounders is None else len(confounders)

    # Load both splits. Give latent / confounder columns readable names, rename
    # y -> target, and drop rows with any NaN.
    reps = {}
    for split in ("test", "train"):
        z_pred = pd.read_csv(f"{model_path}/z_pred_{split}_s{seed}.csv")
        rep_z = z_pred.loc[:, z_pred.columns.str.startswith("z")].rename(
            mapper=partial(rep_renamer, constraints=constraints, prefix="z"), axis=1
        )
        rep_c = z_pred.loc[:, z_pred.columns.str.startswith("c")].rename(
            mapper=partial(rep_renamer, constraints=confounders, prefix="c"), axis=1
        )
        rep_y = z_pred[["y"]].rename({"y": target}, axis=1)
        reps[split] = pd.concat([rep_z, rep_c, rep_y], axis=1).dropna(axis=0)
    test_z_rep, train_z_rep = reps["test"], reps["train"]

    for dataset, df in (("TransNEO", train_z_rep), ("ARTemis+PBCP", test_z_rep)):
        # `col` is the ER-status column and `row` the HER2-status column (config
        # above); the group labels below depend on that assignment. Building all
        # four strata from one table fixes the old ER+HER2- pCR subset, which
        # used `row > 0` and therefore duplicated the ER+HER2+ subset.
        er_pos, er_neg = df[col] > 0, df[col] < 0
        her2_pos, her2_neg = df[row] > 0, df[row] < 0
        strata = {}
        for group, mask in (
            ("ER-HER2-", er_neg & her2_neg),
            ("ER+HER2-", er_pos & her2_neg),
            ("ER-HER2+", er_neg & her2_pos),
            ("ER+HER2+", er_pos & her2_pos),
        ):
            x = df.loc[mask]
            y = x[target]
            # (all samples, labels, residual disease == 0, pCR == 1)
            strata[group] = (x, y, x.loc[y == 0], x.loc[y == 1])

        for feat in df.columns:
            if feat == target:
                continue
            # Continuous features -> Mann-Whitney U; binary -> Fisher's exact test.
            assoc_fn = mwu_auc if len(df[feat].unique()) > 2 else fisher_auc
            for group, (x, y, x_rd, x_pcr) in strata.items():
                res, auc = assoc_fn(x, y, x_rd, x_pcr, feat)
                pcr_assocs.append(
                    {
                        "feat": feat,
                        "group": group,
                        "p": res.pvalue,
                        "auc": auc,
                        "seed": seed,
                        "dataset": dataset,
                    }
                )

In [None]:
## PLOTTING
# TODO: add AUPR and an all-samples (unstratified) view.
n_feats = 20
# Facet order and titles. These must match the "group" labels written by the
# association loop. Previously `col_order` was only defined in commented-out
# code, so the title loop raised NameError.
col_order = ["ER-HER2-", "ER+HER2-", "ER-HER2+", "ER+HER2+"]

pcr_assocs_df = pd.DataFrame(pcr_assocs)
pcr_assocs_df["logp"] = np.log(pcr_assocs_df["p"])

# Display names: latent dims as LaTeX z_{name}, confounders via names_map.
pcr_assocs_df["feat"] = pcr_assocs_df["feat"].str.split("_").apply(
    lambda x: rf"$z_{{{x[1]}}}$" if x[0] == "z" else names_map[x[1]]
)

# Top `n_feats` features per subgroup, ranked by mean TransNEO AUROC over seeds.
pcr_assocs_df_mean = (
    pcr_assocs_df[pcr_assocs_df["dataset"] == "TransNEO"]
    .groupby(["feat", "group", "dataset"])
    .mean()
    .sort_values(by="auc", ascending=False)
    .groupby(["group", "dataset"])
    .head(n_feats)
    .reset_index()
    .set_index(["feat", "group"])
)

# All rows (both datasets, all seeds) for the selected (feature, subgroup) pairs.
pcr_assocs_df_plot = (
    pcr_assocs_df.set_index(["feat", "group"])
    .loc[pcr_assocs_df_mean.index]
    .reset_index()
)
pcr_assocs_df_hm = pcr_assocs_df_plot.pivot_table(
    index=["dataset", "group"], columns="feat", values="auc"
)

g1 = sns.catplot(
    data=pcr_assocs_df_plot,
    height=n_feats / 5.5,
    aspect=0.8,
    kind="point",
    col="group",
    col_order=col_order,
    sharey=False,
    sharex=True,
    x="auc",
    y="feat",
    hue="dataset",
    errorbar=("sd", 1),
    capsize=0.25,
    linestyle="none",
    markersize=3,
    err_kws={"linewidth": 1, "alpha": 0.5},
    palette="colorblind",
)
axes = g1.axes.flatten()
for ax, group in zip(axes, col_order):
    ax.set_title(group, fontweight="semibold")
    ax.set_xlabel("AUROC")
    ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    ax.grid(visible=True, axis="x")
    ax.set_xlim(0, 1)
    ax.set_ylabel("")
    # Apply to every facet (previously only the first two were resized).
    ax.tick_params(axis="y", labelsize=10)
sns.despine()

handles, labels = plt.gca().get_legend_handles_labels()

leg1 = axes[0].legend(
    handles=handles,
    bbox_to_anchor=(1.3, 1.2),
    loc="upper center",
    frameon=False,
    ncol=2,
)
axes[0].add_artist(leg1)
g1._legend.remove()

plt.savefig(
    f"./ext_val_sammut/figures/pcr_assocs_{rep_type}_{row}_{n_feats}.svg",
    bbox_inches="tight",
    bbox_extra_artists=(leg1,),
)
plt.savefig(
    f"./ext_val_sammut/figures/pcr_assocs_{rep_type}_{row}_{n_feats}.png",
    bbox_inches="tight",
    bbox_extra_artists=(leg1,),
    dpi=600,
)

In [None]:
# Results for reporting: mean p-value / AUROC per (feature, subgroup, dataset),
# top 50 by AUROC.
(
    pcr_assocs_df.groupby(["feat", "group", "dataset"])
    .mean()
    .sort_values(by="auc", ascending=False)
    .head(n=50)
)

## 4.3 Representation UMAPs

In [None]:
from umap import UMAP
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

## SCATTER PLOT FEATURES
from scipy.stats import pointbiserialr, spearmanr
from functools import partial

def rep_renamer(x, constraints, prefix="z"):
    """Translate ``"<letter>_<index>"`` into ``"<prefix>_<constraints[index]>"``.

    Names whose index is not below ``len(constraints)`` are returned as-is.
    """
    position = int(x.split("_")[1])
    if position >= len(constraints):
        return x
    return f"{prefix}_{constraints[position]}"


# Run selection for the representation UMAPs.
# Alternative representation: rep_type = "vae"
rep_type, target, experiment, model_type = (
    "icovae_MCL1_16",
    "RCB.score",
    "artemis_pbcp",
    "ElasticNet",
)

def pbc_corr(x1, x2):
    """Return a correlation coefficient between ``x1`` and ``x2``.

    Picks the point-biserial correlation if the smaller distinct-value count of
    the two inputs is exactly 2, otherwise Spearman's rho. Returns only the
    coefficient.
    """
    level_counts = [len(np.unique(v)) for v in (x1, x2)]
    result = pointbiserialr(x1, x2) if min(level_counts) == 2 else spearmanr(x1, x2)
    return result[0]

def plot_umap(x, hue, n_neighbors=20, seed=10):
    """Plot a 2-D UMAP embedding of the ``z_*`` columns of ``x`` and save it.

    The latent columns are standardised and embedded with UMAP
    (``min_dist=0.01``, ``random_state=seed``). Points are coloured by ``hue``
    and marked by the ``dataset`` column. The figure is written as PNG
    (600 dpi) and SVG under ``./figures/transneo/``.

    Parameters
    ----------
    x : pd.DataFrame
        Must contain ``z_*`` columns, a ``dataset`` column and ``hue``.
    hue : str
        Column used for colour; also selects the palette.
    n_neighbors : int
        UMAP ``n_neighbors``.
    seed : int
        UMAP ``random_state``; also used in the output file name.
    """
    pal = sns.color_palette("colorblind")

    # Palette by hue type: binary mutations, continuous scores, or categories.
    if hue in ["TP53", "PIK3CA", "PTEN"]:
        palette = {1.0: pal[0], 0.0: pal[1], "NA": "lightgrey"}
    elif hue.startswith(("z_", "GGI", "STAT1", "All.TMB", "Coding.TMB", "c_GGI")):
        palette = sns.color_palette("Blues", as_cmap=True)
    elif hue == "PAM50":
        palette = {
            "Basal": pal[0],
            "LumA": pal[1],
            "LumB": pal[2],
            "Her2": pal[3],
            "Normal": pal[4],
            "NA": "lightgrey",
        }
    elif hue == "c_MolType":
        palette = {
            "ER+/HER2-": pal[0],
            "ER+/HER2+": pal[1],
            "ER-/HER2+": pal[2],
            "ER-/HER2-": pal[3],
            "NA": "lightgrey",
        }
    else:
        palette = pal

    n_samp = len(x)
    print(f"n: {n_samp}")

    fig, ax = plt.subplots(1, 1, figsize=(3, 3))

    # Standardise each latent dimension before running UMAP.
    zs = StandardScaler().fit_transform(x.loc[:, x.columns.str.startswith("z_")])
    embedding = UMAP(
        n_components=2, n_neighbors=n_neighbors, min_dist=0.01, random_state=seed,
    ).fit_transform(zs)

    embedding_df = pd.DataFrame(embedding, index=x.index, columns=["UMAP1", "UMAP2"])
    embedding_df = pd.merge(embedding_df, x, left_index=True, right_index=True)

    sns.scatterplot(
        data=embedding_df,
        x="UMAP1",
        y="UMAP2",
        hue=hue,
        style="dataset",
        palette=palette,
        ax=ax,
        s=30,
    )
    ax.set_xticks([])
    ax.set_yticks([])

    sns.despine(ax=ax, bottom=True, top=True, left=True, right=True)

    # Drop the leading hue-title entry; the title is set on the legend below.
    # Relabel seaborn's "dataset" sub-heading. The previous `h[5].text = ...`
    # only set an unused attribute on a Line2D and raised IndexError for hues
    # with fewer levels.
    handles, labels = ax.get_legend_handles_labels()
    handles = handles[1:]
    labels = ["Dataset" if lab == "dataset" else lab for lab in labels[1:]]

    for handle in handles:
        handle.set_markersize(6.0)
    hue_map = {
        "c_PAM50": "PAM50 subtype",
        "c_NCN.PAM50": "PAM50 subtype",
        "c_ClinGroup": "Clinical subgroup",
        "TP53": "TP53 mutation status",
        "PIK3CA": "PIK3CA mutation status",
        "c_GGI.ssgsea.notnorm": "GGI score",
        "c_ER.status": "ER status",
        "c_HER2.status": "HER2 status",
        "c_Histology": "Histology",
        "c_MolType": "Molecular subtype",
        "dataset": "Dataset",
    }
    ax.legend(
        handles=handles,
        labels=labels,
        # Fall back to the raw column name for hues without a pretty label.
        title=hue_map.get(hue, hue),
        fontsize="medium",
        ncol=1,
        frameon=False,
        bbox_to_anchor=(1.20, 1.15),
        loc="upper center",
        title_fontproperties={"style": "italic", "size": "medium"},
    )

    fig.savefig(
        f"./figures/transneo/UMAP_{hue}_n{n_neighbors}_s{seed}.png",
        bbox_inches="tight",
        dpi=600,
    )
    fig.savefig(
        f"./figures/transneo/UMAP_{hue}_n{n_neighbors}_s{seed}.svg", bbox_inches="tight"
    )

res_root = f"{wd_path}/data/outputs/depmap_gdsc_transneo/{target}/{experiment}/pico"
ext = "_Size.at.diagnosis_18"

model_path = f"{res_root}/{model_type}_{rep_type + ext}"

# (ER positive, HER2 positive) booleans, concatenated as strings -> subtype label.
moltype_map = {
    "TrueFalse": "ER+/HER2-",
    "TrueTrue": "ER+/HER2+",
    "FalseTrue": "ER-/HER2+",
    "FalseFalse": "ER-/HER2-",
}

corrs = []
corrs_val = []
for seed in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:

    # LOAD ARGUMENTS
    with open(f"{model_path}/args_best_s{seed}.txt", "r") as f:
        args = json.load(f)

    print(args)

    constraints = args["constraints"]
    n_constraints = len(constraints)
    confounders = args["confounders"]
    n_confounders = 0 if confounders is None else len(confounders)

    # Load both splits and give latent / confounder columns readable names.
    reps = {}
    for split in ("test", "train"):
        z_pred = pd.read_csv(f"{model_path}/z_pred_{split}_s{seed}.csv")
        rep_z = z_pred.loc[:, z_pred.columns.str.startswith("z")].rename(
            mapper=partial(rep_renamer, constraints=constraints, prefix="z"), axis=1
        )
        rep_c = z_pred.loc[:, z_pred.columns.str.startswith("c")].rename(
            mapper=partial(rep_renamer, constraints=confounders, prefix="c"), axis=1
        )
        reps[split] = pd.concat([rep_z, rep_c], axis=1)
    test_z_rep, train_z_rep = reps["test"], reps["train"]

    # Use the same dataset label as the other figures ("ARTemis+PBCP").
    test_z_rep["dataset"] = "ARTemis+PBCP"
    train_z_rep["dataset"] = "TransNEO"
    # From here on, `train_z_rep` holds both cohorts.
    train_z_rep = pd.concat([train_z_rep, test_z_rep], axis=0).reset_index()

    train_z_rep["c_ER.status"] = train_z_rep["c_ER.status"] > 0
    train_z_rep["c_HER2.status"] = train_z_rep["c_HER2.status"] > 0
    train_z_rep["c_MolType"] = (
        train_z_rep["c_ER.status"].astype(str) + train_z_rep["c_HER2.status"].astype(str)
    ).map(moltype_map)

    plot_umap(train_z_rep, hue="c_MolType", n_neighbors=15, seed=seed)
