In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import anndata as ad
import numpy as np
import pandas as pd
import pickle
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
import json

sys.path.append("..")
from src import *
sns.set_theme("paper")

In [None]:
# Location of the project configuration file (JSON).
config_path = "/data_nfs/je30bery/melanoma_data/config.json"

# Parse the configuration once; later cells look up paths in `config` by key.
with open(config_path, "r") as config_file:
    config = json.load(config_file)

In [None]:
# Table returned by the project helper for the "Melanoma" dataset/group,
# restricted via high_quality_only=True.
data = get_data_csv(
    dataset="Melanoma",
    groups=["Melanoma"],
    high_quality_only=True,
)

# Values of the "file_path" column; the anndata loop compares each
# field-of-view identifier against this array to decide what to keep.
fovs = data["file_path"].values

In [None]:
# Mapping from antibody name to gene symbol; a value may be a single symbol
# or a list of symbols (the anndata loop handles both cases).
mapping_path = config["antibody_gene_mapping"]
with open(mapping_path, "rb") as mapping_file:
    antibody_gene_symbols = json.load(mapping_file)

In [None]:
# NOTE(review): this cell repeats the load performed by the preceding cell;
# re-running it is harmless but redundant.
with open(config["antibody_gene_mapping"], "rb") as mapping_file:
    antibody_gene_symbols = json.load(mapping_file)

In [None]:
# Build one per-cell DataFrame from the anndata files: marker expression
# (renamed to gene symbols), condition, cell type and an in-ROI flag.
model = "tumor_stage_clf"

# Fields of view that are skipped even when they pass the `fovs` filter.
excluded_fovs = (
    'Melanoma_29_202006031146_1', 'Melanoma_29_202006031146_2',
    'Melanoma_29_202006031146_3', 'Melanoma_29_202006031146_4',
    'Nevi_01_201712121140_1', 'Nevi_01_201712121140_2',
    'Melanoma_35_202009031055_1', 'Melanoma_35_202009031055_2',
    'Melanoma_35_202009031055_3', 'Melanoma_35_202009031055_4',
)

# Antibody channels that are not mapped to gene columns.
skipped_antibodies = ["CD45RA", "CD45RO", "PPB", 'CD66abce']

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# The context manager closes the file handle as soon as the object is loaded.
with open(config["segmentation_results"] + "/anndata_files/adata_cell.pickle", 'rb') as pickle_file:
    x = pickle.load(pickle_file)
dfs = list()
ne = NeighborEnricher(None, config=config)

for k in x:
    anndata = x[k]
    df = pd.DataFrame()
    df["fov"] = anndata.obsm["field_of_view"]

    # Decide whether this field of view is kept before doing any per-marker work.
    # Only the first unique value is inspected, i.e. each entry is assumed to
    # hold a single field of view.
    fov = np.unique(df["fov"])[0]
    if fov not in fovs or fov in excluded_fovs:
        continue

    raw_df = pd.DataFrame(anndata.X, columns=anndata.var["gene_symbol"])
    for c in raw_df.columns:
        if c in skipped_antibodies:
            continue
        symbol = antibody_gene_symbols[c]
        # An antibody may map to several gene symbols; each symbol receives
        # a copy of the same signal.
        if isinstance(symbol, list):
            for s in symbol:
                df[s] = raw_df[c]
        else:
            df[symbol] = raw_df[c]

    df["condition"] = anndata.obsm["Group"]

    cell_types = ne.get_cell_types(fov)
    roi_cells = ne.get_roi_cells(fov, model=model)
    # The first entry of roi_cells must be 0 and is dropped; the remaining
    # entries are used as 1-based cell indices (presumably 0 is the
    # background label -- TODO confirm against NeighborEnricher).
    assert roi_cells[0] == 0
    roi_info = np.zeros(len(cell_types), dtype=bool)
    roi_info[roi_cells[1:] - 1] = True
    assert len(cell_types) == len(roi_info)
    assert len(cell_types) == len(df)
    # Keep these two columns last: later cells select markers via df.columns[:-2].
    df["cell_types"] = cell_types
    df["in_roi"] = roi_info
    dfs.append(df)

df = pd.concat(dfs, ignore_index=True)
# Drop every column that has a missing value in at least one row
# (e.g. markers that are absent from some fields of view).
df = df.dropna(axis="columns")

In [None]:
# One distinct "hls" colour per cell type, assigned in sorted label order.
cell_types = sorted(np.unique(df["cell_types"].values))
n_cell_type_colors = len(np.unique(cell_types))
pal = sns.color_palette("hls", n_cell_type_colors)
cell_type_colors = dict(zip(cell_types, pal))

In [None]:
# Reduce the frame to one column per marker plus "cell_types" and "in_roi".
# The dropped gene columns are presumably the extra copies created when one
# antibody maps to several gene symbols -- TODO confirm; the survivors are
# renamed to antibody-level names.
redundant_gene_columns = [
    'HLA-B', 'HLA-C', 'HLA-DRB1', 'HLA-DRB5', 'NFX1',
    'CD3E', 'CD3G', 'CD8B', "COL4A1",
]
antibody_level_names = {
    "HLA-A": "HLA-ABC",
    "HLA-DRA": "HLA-DR",
    "CD8A": "CD8",
    "COL4A2": "COL4",
}
df = (
    df
    .drop(columns=redundant_gene_columns)
    .rename(columns=antibody_level_names)
    .drop(columns=["fov", "condition"])
)

In [None]:
# Null distributions of four "pseudo-ROI vs. rest" statistics per marker.
# In every draw each cell is assigned to the pseudo-ROI independently with
# probability r / n (the observed fraction of ROI cells), so the pseudo-ROI
# size varies around r rather than being fixed.
mean_ratio = dict()
med_ratio = dict()
med_diff = dict()
mean_diff = dict()

r = np.sum(df["in_roi"])
n = len(df)

n_permutations = 1000
# Fixed seed so that the null distributions (and the z-scores derived from
# them) are reproducible after a kernel restart.
rng = np.random.default_rng(0)

# The last two columns are "cell_types" and "in_roi"; everything before is a marker.
# `tqdm` is not imported in this notebook's import cell; it is expected to be
# provided by `from src import *`.
for marker in df.columns[:-2]:
    expr = np.array(df[marker].values)
    mean_ratio[marker] = list()
    med_ratio[marker] = list()
    mean_diff[marker] = list()
    med_diff[marker] = list()
    for _ in tqdm(range(n_permutations), leave=False):
        # Bernoulli(r / n) membership mask, drawn once and reused for both groups.
        pseudo_roi = rng.random(n) < r / n
        in_roi = expr[pseudo_roi]
        rest = expr[~pseudo_roi]
        # The small epsilon guards the ratios against division by zero.
        mean_ratio[marker].append(np.mean(in_roi) / (np.mean(rest) + 1e-10))
        med_ratio[marker].append(np.median(in_roi) / (np.median(rest) + 1e-10))
        # Differences need no epsilon; leaving it out keeps the null statistic
        # identical in form to the observed difference used for the z-scores.
        mean_diff[marker].append(np.mean(in_roi) - np.mean(rest))
        med_diff[marker].append(np.median(in_roi) - np.median(rest))

In [None]:
# z-score of each observed ROI-vs-rest statistic against its null distribution:
# (observed - mean(null)) / (std(null) + 1e-5), one row per marker.
marker_columns = df.columns[:-2]  # last two columns are "cell_types" and "in_roi"

# Split the cells once instead of re-filtering the frame for every statistic.
roi_cells_df = df[df["in_roi"] == True]
rest_cells_df = df[df["in_roi"] == False]

# Column name -> per-marker null samples computed in the previous cell.
null_distributions = {
    "med_diff": med_diff,
    "mean_diff": mean_diff,
    "med_ratio": med_ratio,
    "mean_ratio": mean_ratio,
}

# dtype=float keeps the table numeric instead of object-typed.
z_scores = pd.DataFrame(index=marker_columns, columns=list(null_distributions), dtype=float)
for m in marker_columns:
    roi_values = roi_cells_df[m]
    rest_values = rest_cells_df[m]
    observed = {
        "med_diff": np.median(roi_values) - np.median(rest_values),
        "mean_diff": np.mean(roi_values) - np.mean(rest_values),
        # Same epsilon as in the null ratios to avoid division by zero.
        "med_ratio": np.median(roi_values) / (np.median(rest_values) + 1e-10),
        "mean_ratio": np.mean(roi_values) / (np.mean(rest_values) + 1e-10),
    }
    for statistic, null_by_marker in null_distributions.items():
        null_samples = null_by_marker[m]
        # 1e-5 keeps the z-score finite when the null distribution is degenerate.
        z_scores.loc[m, statistic] = (
            (observed[statistic] - np.mean(null_samples)) / (np.std(null_samples) + 1e-5)
        )

In [None]:
# Show the five markers with the largest |z| for three of the statistics
# ("med_ratio" is not printed here).
abs_z_scores = np.abs(z_scores)
for statistic in ("med_diff", "mean_diff", "mean_ratio"):
    top_five = abs_z_scores.sort_values(statistic, ascending=False).head(5)
    print(top_five)

In [None]:
# One histogram per marker of the null median ratio, with the null mean
# (solid magenta), mean +/- 1 std (dashed magenta) and the observed ratio (red).
# Also fills `zs` (z-scores) and `ps` (two-sided normal p-values) in marker order.
med = True  # NOTE(review): not read by any cell visible here.
zs = list()
ps = list()

# Size the grid from the number of markers so that more than 36 markers do not
# index past the axes array; 9 columns and 1.25 in per row match the former
# 4 x 9 / (20, 5) layout.
markers_to_plot = list(med_ratio.keys())
n_cols = 9
n_rows = max(1, -(-len(markers_to_plot) // n_cols))  # ceiling division
f, axs = plt.subplots(n_rows, n_cols, figsize=(20, 1.25 * n_rows), squeeze=False)

for i, m in enumerate(markers_to_plot):
    col = i % n_cols
    row = i // n_cols
    panel = axs[row, col]

    null_mean = np.mean(med_ratio[m])
    null_std = np.std(med_ratio[m])

    sns.histplot(med_ratio[m], bins=100, stat="probability", ax=panel, kde=True)
    actual_ratio = np.median(df[df["in_roi"] == True][m]) / (np.median(df[df["in_roi"] == False][m]) + 1e-10)

    z = (actual_ratio - null_mean) / (null_std + 1e-5)
    panel.axvline(null_mean, ymin=0, ymax=1, color="m")
    panel.axvline(null_mean + null_std, ymin=0, ymax=1, color="m", ls="--")
    panel.axvline(null_mean - null_std, ymin=0, ymax=1, color="m", ls="--")

    zs.append(z)
    # Two-sided p-value under a normal approximation of the null.
    # `norm` is not imported in this notebook's import cell; it is expected to
    # be scipy.stats.norm provided by `from src import *`.
    p = norm.sf(abs(z)) * 2
    ps.append(p)
    panel.axvline(actual_ratio, ymin=0, ymax=1, color="r")
    panel.set_title(f"{m}: z={z:.2f}, p={p:.3f}")

# Hide grid cells that have no marker assigned.
for unused_panel in axs.ravel()[len(markers_to_plot):]:
    unused_panel.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Pair each marker with the z-score at the same position in `zs` (filled in
# med_ratio key order by the histogram cell when the notebook runs top to
# bottom), then rank the markers by absolute z-score, largest first.
z_dict = dict(zip(med_ratio.keys(), zs))
ordered_keys = sorted(z_dict, key=lambda gene: abs(z_dict[gene]), reverse=True)

In [None]:
# Display the markers ranked by decreasing |z| (see the sort in the previous cell).
ordered_keys

In [None]:
# Recompute the z-scores from the *mean* ratio and store them per marker.
# `z_dict` (used by the boxplot cells) and `zs` are overwritten here.
zs = list()
z_dictf = dict()  # NOTE(review): never filled or read by any cell visible here.
z_dict = dict()
# Iterate the same dict that supplies the null samples, and key each z-score
# by its marker directly instead of pairing two separately ordered sequences.
for m in mean_ratio.keys():
    actual_ratio = np.mean(df[df["in_roi"] == True][m]) / (np.mean(df[df["in_roi"] == False][m]) + 1e-10)
    z = (actual_ratio - np.mean(mean_ratio[m])) / (np.std(mean_ratio[m]) + 1e-5)
    zs.append(z)
    z_dict[m] = z

In [None]:
# Long-format table: one row per (cell, marker) with its expression value, the
# marker's |z| from `z_dict`, and a readable ROI label.
melted = df.drop("cell_types", axis=1).melt(id_vars="in_roi", var_name="Gene")
melted["z"] = melted["Gene"].apply(lambda gene: np.abs(z_dict[gene]))
melted = melted.sort_values("z", ascending=False)
melted = melted.rename({"value": "Expression", "in_roi": "In ROI"}, axis=1)
# Assign the relabelled column back explicitly: an inplace replace on
# melted["In ROI"] acts on a temporary Series and does not update the frame
# under pandas Copy-on-Write.
melted["In ROI"] = melted["In ROI"].map({True: "In ROI", False: "Rest"})

In [None]:
# Box plots of every marker's expression, ROI cells vs. the rest, with markers
# in alphabetical order; saved as a PDF.
sns.set_theme("paper")
f = plt.figure(figsize=(10, 4))
palette = ["#B51B1B", "#117211"]

melted = df.drop("cell_types", axis=1).melt(id_vars="in_roi", var_name="Gene")
melted["z"] = melted["Gene"].apply(lambda gene: np.abs(z_dict[gene]))
melted = melted.sort_values("Gene", ascending=True)
melted = melted.rename({"value": "Expression", "in_roi": "In ROI"}, axis=1)
# Assign the relabelled column back explicitly: an inplace replace on
# melted["In ROI"] acts on a temporary Series and does not update the frame
# under pandas Copy-on-Write, which would leave the hue column boolean and
# make hue_order below match nothing.
melted["In ROI"] = melted["In ROI"].map({True: "In ROI", False: "Rest"})

ax = sns.boxplot(
    data=melted,
    x="Gene",
    y="Expression",
    showfliers=False,
    hue="In ROI",
    hue_order=["In ROI", "Rest"],
    palette=palette,
)
plt.tight_layout()

plt.xticks(rotation=90)
sns.move_legend(ax, loc=(0.4, 1), ncol=2, frameon=False, title=None)
markers = [label.get_text() for label in ax.get_xticklabels()]
# Second layout pass so the rotated labels and the moved legend fit the figure.
plt.tight_layout()
plt.savefig("../result_plots/boxplots_cell_types_ROI_all.pdf")

In [None]:
# Box plots for four selected markers, ROI cells vs. the rest, ordered by
# decreasing |z| from `z_dict`; tick labels are annotated with the mean-difference
# z-score from `z_scores`; saved as a PDF.
# NOTE(review): the ordering uses `z_dict` while the labels show
# z_scores["mean_diff"] -- confirm that mixing the two statistics is intended.
sns.set_theme("paper")
f = plt.figure(figsize=(6, 4))
palette = ["#B51B1B", "#117211"]

melted = df.drop("cell_types", axis=1).melt(id_vars="in_roi", var_name="Gene")
# .copy() makes the filtered frame independent, so adding columns below does
# not write into a slice of the full table.
melted = melted[melted["Gene"].isin(["KRT14","EGFR", "GJA1","TP73"])].copy()

melted["z"] = melted["Gene"].apply(lambda gene: np.abs(z_dict[gene]))
melted = melted.sort_values("z", ascending=False)
melted = melted.rename({"value": "Expression", "in_roi": "In ROI"}, axis=1)
# Assign the relabelled column back explicitly: an inplace replace on
# melted["In ROI"] acts on a temporary Series and does not update the frame
# under pandas Copy-on-Write.
melted["In ROI"] = melted["In ROI"].map({True: "In ROI", False: "Rest"})

ax = sns.boxplot(
    data=melted,
    x="Gene",
    y="Expression",
    showfliers=False,
    hue="In ROI",
    hue_order=["In ROI", "Rest"],
    palette=palette,
)
plt.tight_layout()

plt.xticks(rotation=0)
sns.move_legend(ax, loc=(0.275, 1), ncol=2, frameon=False, title=None)
markers = [label.get_text() for label in ax.get_xticklabels()]
zs = z_scores.loc[markers, "mean_diff"]
print(zs)
# Look each z-score up by marker label; `zs` is indexed by marker name, so an
# integer position is not a valid key.
xticklabels = [f"{gene}\n$z_{{Δ_{{mean}}}}={zs.loc[gene]:.2f}$" for gene in markers]
ax.set_xticklabels(xticklabels)
plt.tight_layout()
plt.savefig("../result_plots/boxplots_cell_types_ROI.pdf")