In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import anndata as ad
import numpy as np
import pandas as pd
import pickle
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
import json

sys.path.append("..")
from src import *
sns.set_theme("paper")

In [None]:
# Location of the JSON configuration; its entries are looked up by key below.
config_path = "/data_nfs/je30bery/melanoma_data/config.json"
with open(config_path) as config_file:
    config = json.load(config_file)

In [None]:
# Sample table for the "Melanoma" group, restricted to high-quality entries.
data = get_data_csv(
    dataset="Melanoma",
    groups=["Melanoma"],
    high_quality_only=True,
    config_path=config_path,
)
# Values of the "file_path" column; used below to decide which fields of view to keep.
fovs = data["file_path"].values

In [None]:
# Antibody name -> gene symbol mapping; a value is either one symbol or a list of symbols.
with open(config["antibody_gene_mapping"], "rb") as mapping_file:
    antibody_gene_symbols = json.load(mapping_file)

In [None]:
# Build one row-per-cell DataFrame from the pickled per-FOV AnnData objects:
# expression columns (renamed to gene symbols), condition, cell type and a
# flag telling whether the cell lies inside a ROI of `model`.
model = "tumor_stage_clf"

# Antibody channels that are not translated to gene symbols.
skipped_antibodies = {"CD45RA", "CD45RO", "PPB", "CD66abce"}
# Samples that need to be excluded because they lack channels.
fovs_missing_channels = {
    'Melanoma_29_202006031146_1', 'Melanoma_29_202006031146_2', 'Melanoma_29_202006031146_3', 'Melanoma_29_202006031146_4',
    'Nevi_01_201712121140_1', 'Nevi_01_201712121140_2', 'Melanoma_35_202009031055_1', 'Melanoma_35_202009031055_2',
    'Melanoma_35_202009031055_3', 'Melanoma_35_202009031055_4',
}

# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
with open(config["segmentation_results"] + "/anndata_files/adata_cell.pickle", "rb") as pickle_file:
    adata_by_key = pickle.load(pickle_file)

dfs = list()
ne = NeighborEnricher(None, config=config)

for key in adata_by_key:
    adata = adata_by_key[key]
    fov_df = pd.DataFrame()
    fov_df["fov"] = adata.obsm["field_of_view"]

    # Each AnnData object is treated as a single field of view, identified by
    # the first unique value of its "field_of_view" annotation.
    fov = np.unique(fov_df["fov"])[0]
    # Decide whether to keep the FOV before doing any per-column work.
    if fov not in fovs or fov in fovs_missing_channels:
        continue

    raw_df = pd.DataFrame(adata.X, columns=adata.var["gene_symbol"])
    for antibody in raw_df.columns:
        if antibody in skipped_antibodies:
            continue
        symbol = antibody_gene_symbols[antibody]
        # An antibody mapped to several gene symbols contributes one
        # identical column per symbol.
        symbols = symbol if isinstance(symbol, list) else [symbol]
        for gene in symbols:
            fov_df[gene] = raw_df[antibody]
    fov_df["condition"] = adata.obsm["Group"]

    cell_types = ne.get_cell_types(fov)
    try:
        roi_cells = ne.get_roi_cells(fov, model=model)
    except Exception as exc:
        # Best effort: FOVs without retrievable ROI cells are left out, but
        # the reason is reported instead of being swallowed silently.
        print(f"Skipping {fov}: could not get ROI cells ({exc!r})")
        continue

    # presumably `roi_cells` holds 1-based cell labels preceded by one
    # non-cell entry, hence the [1:] and the -1 -- TODO confirm against
    # NeighborEnricher.get_roi_cells
    roi_info = np.zeros(len(cell_types), dtype=bool)
    roi_info[roi_cells[1:] - 1] = True
    fov_df["cell_types"] = cell_types
    fov_df["in_roi"] = roi_info
    dfs.append(fov_df)

df = pd.concat(dfs, ignore_index=True)
# Drop every column that contains at least one missing value.
df = df.dropna(axis="columns")

In [None]:
# Give every cell type (in sorted order) its own colour from the "hls" palette.
unique_cell_types = np.unique(df["cell_types"].values)
cell_types = sorted(unique_cell_types)
pal = sns.color_palette("hls", len(cell_types))
cell_type_colors = dict(zip(cell_types, pal))

In [None]:
# Remove a fixed set of gene columns and give the remaining ones display names.
# presumably the removed columns are the duplicates created when one antibody
# maps to several gene symbols -- verify against the antibody/gene mapping
dropped_genes = ['HLA-B', 'HLA-C', 'HLA-DRB1', 'HLA-DRB5', 'NFX1', 'CD3E', 'CD3G', 'CD8B', "COL4A1"]
display_names = {
    "NCR3LG1": "DLG4",
    "HLA-A": "HLA-ABC",
    "HLA-DRA": "HLA-DR",
    "CD8A": "CD8",
    "COL4A2": "COL4",
}
df = df.drop(columns=dropped_genes).rename(columns=display_names)

In [None]:
# Attach per-sample metadata (tumor stage, histology id) to every cell through
# its field of view, then drop the "fov" column.
# Both steps are guarded so that re-running the cell is harmless: `data` is
# re-indexed only once, and the lookup is skipped once "fov" has been removed.
if data.index.name != "file_path":
    data = data.set_index("file_path")

if "fov" in df.columns:
    # Vectorised lookup instead of a per-row `.loc` call.
    # Only the first two characters of the "Tumor stage" string are kept.
    df["tumor_stage"] = df["fov"].map(data["Tumor stage"].str[:2])
    df["histo_id"] = df["fov"].map(data["Histo-ID"])
    df = df.drop(columns="fov")

In [None]:
# Merge the four T stages into two coarse groups; values that are not listed
# are passed through unchanged by `replace`.
coarse_stage_by_stage = {
    "T1": "T1, T2",
    "T2": "T1, T2",
    "T3": "T3, T4",
    "T4": "T3, T4",
}
df["coarse_tumor_stage"] = df["tumor_stage"].replace(coarse_stage_by_stage)

In [None]:
# Colour per coarse tumor stage.
palette = {"T1, T2": "#f57d05", "T3, T4": "#8634b3"}

f, ax = plt.subplots(1, figsize=(6, 3))
# Cells inside ROIs only, with display-friendly column names.
ins = (
    df[df["in_roi"]]
    .drop("in_roi", axis=1)
    .rename({"coarse_tumor_stage": "Tumor stage", "cell_types": "Cell type"}, axis=1)
)
# One histogram per stage, so stat="probability" is computed within each stage.
for stage in ("T1, T2", "T3, T4"):
    stage_cells = ins[ins["Tumor stage"] == stage].sort_values("Cell type")
    sns.histplot(
        stage_cells,
        x="Cell type",
        hue="Tumor stage",
        multiple="dodge",
        shrink=0.4,
        stat="probability",
        ax=ax,
        palette=[palette[stage]],
        legend=False,
    )

# Both histograms are drawn at the same x positions; move the bars of the
# container labelled "_container0" to the left and all others to the right.
for container in ax.containers:
    shift = -0.2 if container.get_label() == "_container0" else 0.2
    for bar in container:
        bar.set_x(bar.get_x() + shift)


# Break long tick labels before "cells" / "keratinocytes".
tick_labels = [
    label.get_text().replace(" cells", "\ncells").replace(" keratinocytes", "\nkeratinocytes")
    for label in ax.get_xticklabels()
]
ax.set_xticks(list(range(len(tick_labels))), tick_labels, rotation=90)

legend_handles = [Patch(color=color, label=key) for key, color in palette.items()]
ax.legend(handles=legend_handles, loc=(0.275, 1), ncol=2, frameon=False)
plt.tight_layout()
#plt.savefig("../result_plots/cell_type_distri.pdf", bbox_inches='tight') 

In [None]:
# Expression of selected markers per coarse tumor stage, for cells inside ROIs
# (top row) and for all cells (bottom row). Each panel title shows the
# difference of the stage medians, median(T1, T2) - median(T3, T4).
plot = True
markers = ['LAMP1', 'IL2RA', 'EGFR', 'RIMS3']
# Values at or above this per-marker quantile are removed before plotting and
# before taking medians. One quantile is used for both rows so the two median
# differences are comparable (the all-cells row previously used 0.9).
trim_quantile = 0.95

id_vars = ['condition', 'cell_types', 'tumor_stage', 'coarse_tumor_stage', 'histo_id']
long_names = {"coarse_tumor_stage": "Tumor stage", "value": "Expression"}
# Long format: one row per (cell, gene).
melted_in = (
    df[df["in_roi"]]
    .drop(columns="in_roi")
    .melt(id_vars=id_vars, var_name="Gene")
    .rename(columns=long_names)
)
melted_rest = (
    df.drop(columns="in_roi")
    .melt(id_vars=id_vars, var_name="Gene")
    .rename(columns=long_names)
)

if plot:
    # squeeze=False keeps `axs` two-dimensional even for a single marker.
    fig, axs = plt.subplots(2, len(markers), figsize=(10, 5), sharex=True, sharey="col", squeeze=False)

t_ins = list()
t_rest = list()
# One entry per figure row: source frame, list collecting the median
# differences, title template, y label of the first column, hide x label?
panel_rows = (
    (melted_in, t_ins, "\n{m}\n$Δ_{{med, ROI}}={t:.3f}$", "Expression\n ROI cells", True),
    (melted_rest, t_rest, "{m}\n$Δ_{{med, all}}={t:.3f}$", "Expression\nin all cells", False),
)
for i, m in enumerate(markers):
    for row, (melted, diffs, title, ylabel, hide_xlabel) in enumerate(panel_rows):
        subset = melted[melted["Gene"] == m]
        cutoff = np.quantile(subset["Expression"], trim_quantile)
        subset = subset[subset["Expression"] < cutoff]

        early = subset.loc[subset["Tumor stage"] == "T1, T2", "Expression"]
        late = subset.loc[subset["Tumor stage"] == "T3, T4", "Expression"]
        t = np.median(early) - np.median(late)
        diffs.append(t)

        if not plot:
            continue
        ax = axs[row, i]
        ax.set_title(title.format(m=m, t=np.round(t, 3)))
        sns.violinplot(subset, x="Tumor stage", y="Expression", cut=0, ax=ax, hue="Tumor stage", order=["T1, T2", "T3, T4"], legend=False, palette=palette)
        # Depending on the seaborn version a per-axes legend may or may not
        # exist at this point; the figure-level legend below replaces it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if hide_xlabel:
            ax.set(xlabel=None)
        # Only the first column carries a y label.
        ax.set(ylabel=ylabel if i == 0 else None)
        #ax.set_yscale("log")
if plot:
    legend_handles = [Patch(color=color, label=key) for key, color in palette.items()]
    fig.legend(handles=legend_handles, loc="upper center", ncol=2, frameon=False)
    fig.tight_layout()
#plt.savefig("../result_plots/expression_per_tumor_stage_all_rois.pdf", bbox_inches='tight') 
#plt.savefig("../result_plots/expression_per_tumor_stage_all_vs_rois.pdf", bbox_inches='tight') 

In [None]:
# Difference of stage medians, median(T1, T2) - median(T3, T4), for every
# marker: once for cells inside ROIs (`t_ins`) and once for all cells
# (`t_rest`). Set `plot = True` to also draw one violin panel per marker.
plot = False
# Values at or above this per-marker quantile are removed before taking medians.
trim_quantile = 0.95

id_vars = ['condition', 'cell_types', 'tumor_stage', 'coarse_tumor_stage', 'histo_id']
# Every column that is not metadata is a marker; selecting by name does not
# depend on the metadata columns being the last six.
markers = df.columns.drop(id_vars + ["in_roi"])

long_names = {"coarse_tumor_stage": "Tumor stage", "value": "Expression"}
# Long format: one row per (cell, gene).
melted_in = (
    df[df["in_roi"]]
    .drop(columns="in_roi")
    .melt(id_vars=id_vars, var_name="Gene")
    .rename(columns=long_names)
)
melted_rest = (
    df.drop(columns="in_roi")
    .melt(id_vars=id_vars, var_name="Gene")
    .rename(columns=long_names)
)

if plot:
    # squeeze=False keeps `axs` two-dimensional even for a single marker.
    fig, axs = plt.subplots(2, len(markers), figsize=(10, 5), sharex=True, sharey="col", squeeze=False)

t_ins = list()
t_rest = list()
# One entry per figure row: source frame, list collecting the median
# differences, title template, y label of the first column, hide x label?
panel_rows = (
    (melted_in, t_ins, "\n{m}\n$Δ_{{med, ROI}}={t:.3f}$", "Expression\n ROI cells", True),
    (melted_rest, t_rest, "{m}\n$Δ_{{med, all}}={t:.3f}$", "Expression\nin all cells", False),
)
for i, m in enumerate(markers):
    for row, (melted, diffs, title, ylabel, hide_xlabel) in enumerate(panel_rows):
        subset = melted[melted["Gene"] == m]
        cutoff = np.quantile(subset["Expression"], trim_quantile)
        subset = subset[subset["Expression"] < cutoff]

        early = subset.loc[subset["Tumor stage"] == "T1, T2", "Expression"]
        late = subset.loc[subset["Tumor stage"] == "T3, T4", "Expression"]
        t = np.median(early) - np.median(late)
        diffs.append(t)

        if not plot:
            continue
        ax = axs[row, i]
        ax.set_title(title.format(m=m, t=np.round(t, 3)))
        sns.violinplot(subset, x="Tumor stage", y="Expression", cut=0, ax=ax, hue="Tumor stage", order=["T1, T2", "T3, T4"], legend=False, palette=palette)
        # Depending on the seaborn version a per-axes legend may or may not
        # exist at this point; the figure-level legend below replaces it.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
        if hide_xlabel:
            ax.set(xlabel=None)
        # Only the first column carries a y label.
        ax.set(ylabel=ylabel if i == 0 else None)
        #ax.set_yscale("log")
if plot:
    legend_handles = [Patch(color=color, label=key) for key, color in palette.items()]
    fig.legend(handles=legend_handles, loc="upper center", ncol=2, frameon=False)
    fig.tight_layout()
#plt.savefig("../result_plots/expression_per_tumor_stage_all_rois.pdf", bbox_inches='tight') 
#plt.savefig("../result_plots/expression_per_tumor_stage_all_vs_rois.pdf", bbox_inches='tight') 

In [None]:
# Per-marker median stage difference for all cells and for ROI cells, plus the
# absolute gap between the two.
roi_diffs = np.array(t_ins)
all_diffs = np.array(t_rest)
plt.plot(all_diffs, label="all")
plt.plot(roi_diffs, label="roi")
# Labelled so the third curve also shows up in the legend.
plt.plot(np.abs(roi_diffs - all_diffs), label="|roi - all|")
plt.xticks(range(len(markers)), labels=markers, rotation=90)
plt.ylabel("median(T1, T2) - median(T3, T4)")
plt.legend();

In [None]:
# Positions of markers whose ROI and all-cells median differences are more
# than 0.02 apart and whose ROI difference is the larger one in magnitude.
roi_diffs = np.array(t_ins)
all_diffs = np.array(t_rest)
gap_is_large = np.abs(roi_diffs - all_diffs) > 0.02
roi_is_stronger = np.abs(roi_diffs) > np.abs(all_diffs)
np.where(gap_is_large & roi_is_stronger)

In [None]:
# Names of markers whose ROI and all-cells median differences are more than
# 0.02 apart and whose ROI difference is the larger one in magnitude.
roi_diffs = np.array(t_ins)
all_diffs = np.array(t_rest)
gap_is_large = np.abs(roi_diffs - all_diffs) > 0.02
roi_is_stronger = np.abs(roi_diffs) > np.abs(all_diffs)
markers[np.where(gap_is_large & roi_is_stronger)]

In [None]:
# Distinct conditions among the melted ROI rows and how many rows each has.
condition_values, condition_counts = np.unique(melted_in["condition"], return_counts=True)
(condition_values, condition_counts)

In [None]:

# Violin plots of selected genes per cell type for cells inside ROIs, split by
# coarse tumor stage.
# The selection is defined here so the cell also runs on a fresh kernel.
interesting_roi = {"Endothelial cells": ["GJA1", "NOTCH1", "TP73"]}

sns.set_theme("paper")
palette = {"T1, T2": "#f57d05", "T3, T4": "#8634b3"}
# One panel per cell type; squeeze=False keeps the axes indexable for any count.
f, axs = plt.subplots(1, len(interesting_roi), figsize=(4 * len(interesting_roi), 3), squeeze=False)
legend_handles = [Patch(color=color, label=key) for key, color in palette.items()]
for ax, (cell_type, genes) in zip(axs[0], interesting_roi.items()):
    subset = melted_in[(melted_in["cell_types"] == cell_type) & (melted_in["Gene"].isin(genes))]
    # NOTE(review): the cutoff is one quantile over the pooled expression of
    # all selected genes, not a per-gene quantile -- confirm this is intended.
    cutoff = np.quantile(subset["Expression"], 0.95)
    subset = subset[subset["Expression"] < cutoff]
    sns.violinplot(data=subset, x="Gene", y="Expression", ax=ax, hue="Tumor stage", legend=False, palette=palette, hue_order=["T1, T2", "T3, T4"], order=genes)
    ax.set_title(cell_type)
    ax.legend(handles=legend_handles, loc=(0.23, 1.15), ncol=2, frameon=False)
    ax.set(ylabel=f"Expression in {cell_type}")

plt.tight_layout()
#plt.suptitle("Marker distribution in cell types in ROIs vs. rest")
#plt.savefig("../result_plots/violins_cell_types_ROI.pdf")

In [None]:
# Violin plots of selected genes per cell type, split by coarse tumor stage.
# `melted` selects the long-format frame to plot from (here: ROI cells).
interesting_roi = {"Endothelial cells": ["GJA1", "NOTCH1", "TP73"]}
melted = melted_in
sns.set_theme("paper")
palette = {"T1, T2": "#f57d05", "T3, T4": "#8634b3"}
# One panel per cell type; squeeze=False keeps the axes indexable for any count.
f, axs = plt.subplots(1, len(interesting_roi), figsize=(4 * len(interesting_roi), 3), squeeze=False)
legend_handles = [Patch(color=color, label=key) for key, color in palette.items()]
for ax, (cell_type, genes) in zip(axs[0], interesting_roi.items()):
    subset = melted[(melted["cell_types"] == cell_type) & (melted["Gene"].isin(genes))]
    # NOTE(review): the cutoff is one quantile over the pooled expression of
    # all selected genes, not a per-gene quantile -- confirm this is intended.
    cutoff = np.quantile(subset["Expression"], 0.95)
    subset = subset[subset["Expression"] < cutoff]
    sns.violinplot(data=subset, x="Gene", y="Expression", ax=ax, hue="Tumor stage", legend=False, palette=palette, hue_order=["T1, T2", "T3, T4"], order=genes)
    ax.set_title(cell_type)
    ax.legend(handles=legend_handles, loc=(0.23, 1.15), ncol=2, frameon=False)
    ax.set(ylabel=f"Expression in {cell_type}")

plt.tight_layout()
#plt.suptitle("Marker distribution in cell types in ROIs vs. rest")
#plt.savefig("../result_plots/violins_cell_types_ROI.pdf")

In [None]:
#np.logical_or((np.abs(diffdiff)<0.05),(np.abs(med_diffs_roi)<0.05))