In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")
import theoretic as th
from sklearn.mixture import GaussianMixture
import anndata as ad
from scipy.sparse import csr_matrix
import numpy as np
import pandas as pd
import pickle
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import os
import cv2
from matplotlib.patches import Patch
import json
sys.path.append("../../MAGICAL/data_utils")
from data_utils import *
from tqdm import tqdm
plt.style.use("dark_background")

In [None]:
# Base directory and the melanoma dataset directory beneath it.
base = "/data_nfs/"
data = os.path.join(base, "datasets/melc/melanoma/")

# Names of all entries found in the dataset's "processed" folder.
processed_root = os.path.join(data, "processed")
fovs = os.listdir(processed_root)

In [None]:
# Read the antibody-to-gene-symbol mapping from its JSON file.
antibody_mapping_path = "../data/antibody_gene_mapping/antibodies.json"
with open(antibody_mapping_path, "rb") as mapping_file:
    antibody_gene_symbols = json.load(mapping_file)

In [None]:
# Display the loaded antibody-to-gene-symbol mapping for inspection.
antibody_gene_symbols

In [None]:
# Fields of view that are left out of the analysis.
EXCLUDED_FOVS = {
    'Melanoma_29_202006031146_1', 'Melanoma_29_202006031146_2',
    'Melanoma_29_202006031146_3', 'Melanoma_29_202006031146_4',
    'Nevi_01_201712121140_1', 'Nevi_01_201712121140_2',
    'Melanoma_35_202009031055_1', 'Melanoma_35_202009031055_2',
    'Melanoma_35_202009031055_3', 'Melanoma_35_202009031055_4',
}
# Columns of the raw expression table that are not mapped to gene symbols.
SKIPPED_MARKERS = {"CD45RA", "CD45RO", "PPB", 'CD66abce'}

# NOTE(security): unpickling can execute arbitrary code; only load this file
# from a trusted location.
adata_pickle_path = os.path.join(base, 'datasets/melc/melanoma/segmented/anndata_files/adata_cell.pickle')
with open(adata_pickle_path, 'rb') as adata_pickle:  # close the handle deterministically
    x = pickle.load(adata_pickle)

dfs = list()
coords = dict()
for k in x:
    anndata = x[k]

    # Decide whether to skip this entry before doing any per-marker work.
    fov = np.unique(anndata.obsm["field_of_view"])[0]
    if fov in EXCLUDED_FOVS:
        continue

    raw_df = pd.DataFrame(anndata.X, columns=anndata.var["gene_symbol"])

    # Collect one column per gene symbol. An antibody may map to a single
    # symbol or to a list of symbols; in the latter case its values are copied
    # to every listed symbol.
    columns_by_symbol = dict()
    for c in raw_df.columns:
        if c in SKIPPED_MARKERS:
            continue
        symbol = antibody_gene_symbols[c]
        symbols = symbol if isinstance(symbol, list) else [symbol]
        for s in symbols:
            columns_by_symbol[s] = raw_df[c]

    # Build the frame in one go instead of inserting columns one by one.
    df = pd.DataFrame(columns_by_symbol)
    df["fov"] = anndata.obsm["field_of_view"]
    df["condition"] = anndata.obsm["Group"]

    dfs.append(df)
    coords[fov] = anndata.uns["cell_coordinates"]

# Stack all fields of view and keep only columns without missing values.
df = pd.concat(dfs, ignore_index=True)
df = df.dropna(axis="columns")

# Expression matrix without the two metadata columns.
expression = df.drop(["fov", "condition"], axis="columns")

adata = ad.AnnData(expression)
# Variable names live in `var_names`; the earlier `adata.vars = ...` line only
# attached an unused ad-hoc attribute and has been removed.
adata.var_names = list(expression.columns)
adata.obs["field_of_view"] = list(df["fov"].astype(str))
adata.obs["condition"] = list(df["condition"].astype(str))
adata.obs_names = [f"Cell_{i:d}" for i in range(adata.n_obs)]
#adata.uns["cell_coordinates"] = coords
#sc.pp.neighbors(adata)
#sc.tl.umap(adata)
#sc.tl.pca(adata)

In [None]:
# Prefer the cached reference table. Fall back to fetching it through
# `theoretic` only when the cache is missing or cannot be read/parsed, so that
# unrelated errors (typos, interrupts, ...) are no longer silently swallowed.
try:
    reference = pd.read_csv("../data/theoretic_reference_data/skin_reference.csv", index_col="Unnamed: 0")
except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError):
    reference = th.get_hpa_reference('skin')
    # Keep only the reference columns that are also measured in `adata`.
    # NOTE(review): this filter is applied on the fallback path only; the cached
    # CSV is presumably already restricted to these genes - confirm.
    usable_genes = [c for c in reference.columns if c in adata.var_names]
    reference = reference[usable_genes]

In [None]:
# Inspect the reference values stored in the "CD4" column.
reference["CD4"]

In [None]:
# Distinct "file_path" values of the table returned for high-quality data only.
# `get_data_csv` is not defined in this notebook; presumably it comes from the
# `data_utils` star import.
high_quality_table = get_data_csv(high_quality_only=True)
high_quality = np.unique(high_quality_table["file_path"])

In [None]:
# Keep only cells whose field of view is listed in `high_quality`.
in_high_quality_fov = adata.obs["field_of_view"].isin(high_quality)
adata = adata[in_high_quality_fov]

In [None]:
# Per-field-of-view AnnData copies, the running list of assigned cell types,
# and the switch that enables the overlay plots further down.
samples = {}
all_cells = []
plot = False

for hq in high_quality:
    sample = adata[adata.obs["field_of_view"] == hq].copy()
    if len(sample) == 0:
        # Report fields of view without any cells and leave them out.
        print(hq)
        continue
    samples[hq] = sample

In [None]:
# One "hls" palette colour per reference cell type, assigned in sorted order.
cell_types = sorted(reference.index)
pal = sns.color_palette("hls", len(np.unique(cell_types)))
cell_type_colors = {}
for position, cell_type_name in enumerate(cell_types):
    cell_type_colors[cell_type_name] = pal[position]

In [None]:
# Fields of view for which no overlay plot is produced.
PLOT_EXCLUDED_FOVS = {
    'Melanoma_29_202006031146_1', 'Melanoma_29_202006031146_2',
    'Melanoma_29_202006031146_3', 'Melanoma_29_202006031146_4',
    'Nevi_01_201712121140_1', 'Nevi_01_201712121140_2',
    'Melanoma_35_202009031055_1', 'Melanoma_35_202009031055_2',
    'Melanoma_35_202009031055_3', 'Melanoma_35_202009031055_4',
}


def _plot_cell_types_on_propidium(fov, cell_type_labels):
    """Overlay cell-type colours on the propidium image of one field of view and save it as PNG.

    Parameters
    ----------
    fov : str
        Field-of-view name; used to locate the segmentation mask
        (``<data>/segmented/<fov>_cells.npy``) and the processed image folder.
    cell_type_labels : pandas.Series
        One cell type per segmented cell.
        NOTE(review): assumed to be ordered like the sorted non-zero labels of
        the segmentation mask - confirm.

    Raises
    ------
    AssertionError
        If there are no cells, or the number of mask labels (including the
        background label 0) does not equal ``len(cell_type_labels) + 1``.
    FileNotFoundError
        If the processed folder contains no file with "propidium" in its name.
    """
    segmented = os.path.join(data, "segmented", f'{fov}_cells.npy')
    with open(segmented, "rb") as openfile:
        seg_file = np.load(openfile)

    cell_ids = np.unique(seg_file.flatten())
    assert len(cell_type_labels) > 0, "no cells found for this fov"
    assert len(cell_ids) == len(cell_type_labels) + 1

    # Paint every segmented cell with the colour of its assigned type; label 0
    # is skipped and stays black.
    cell_types_on_seg = np.zeros((seg_file.shape[0], seg_file.shape[1], 3))
    for i, cell in enumerate(cell_ids):
        if cell == 0:
            continue
        # `.iloc` makes the positional lookup explicit: the Series is indexed by
        # cell names, so a plain integer key relies on deprecated fallback behaviour.
        cell_types_on_seg[seg_file == cell] = cell_type_colors[cell_type_labels.iloc[i - 1]]

    processed_dir = os.path.join(data, "processed", fov)
    propidium_files = [f for f in os.listdir(processed_dir) if "propidium" in f.lower()]
    if not propidium_files:
        raise FileNotFoundError(f"no propidium image found in {processed_dir}")
    prop_iodide = cv2.imread(os.path.join(processed_dir, propidium_files[0]))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(cv2.addWeighted(prop_iodide, 0.4, (cell_types_on_seg * 255).astype(np.uint8), 0.6, 0))
    legend_handles = [Patch(color=color, label=key) for key, color in cell_type_colors.items()]
    ax.legend(handles=legend_handles, loc=(1.01, 0.7))
    ax.set_title(fov)

    fig.tight_layout()
    ax.axis("off")
    fig.savefig(f"../result_plots/cell_types_on_prop/cell_type_analysis_{fov}_cells.png")
    # One figure is created per field of view; close it so they do not pile up in memory.
    plt.close(fig)


forest = list()
# Reset together with `forest` so that re-running this cell does not count cells twice.
all_cells = list()
cell_type_dir = os.path.join(base, "je30bery/melanoma_data/MAGICAL/data/cell_types")

for hq in samples.keys():
    print(hq)
    if "Nevi_03" in hq:
        continue

    # The cell types are read from `obs["cell_type"]` right after this call, so
    # `identify_cell_types` is expected to annotate the sample in place.
    tree = th.identify_cell_types(samples[hq], reference.copy(), min_fold_change=2, z_score_cutoff=1.96/4)
    forest.append(tree)

    cell_types = list(samples[hq].obs["cell_type"])
    all_cells += cell_types

    with open(os.path.join(cell_type_dir, hq + "_cell_types.pkl"), "wb") as fp:
        pickle.dump(cell_types, fp)

    if plot and hq not in PLOT_EXCLUDED_FOVS:
        # `samples[hq]` only contains cells of this field of view, so no extra
        # filtering by field of view is needed.
        _plot_cell_types_on_propidium(hq, samples[hq].obs["cell_type"])

In [None]:
# One column per tree; each entry reads "<cell type>: <first mapped gene>" for
# the corresponding split of that tree.
split_labels_by_tree = {}
for i, tree in enumerate(forest):
    split_labels_by_tree[i] = [
        f"{st.cell_type}: {st.mapped_genes[0]}" for st in tree.split_tuples
    ]
df = pd.DataFrame(split_labels_by_tree)

In [None]:
# After transposing, each row is one tree and each column one split position.
# NOTE: this cell rebinds `df` to its transpose, so running it twice in a row
# transposes the table back again.
trees_by_split = df.T

# Sort on (up to) the first ten split columns so identical trees end up next to
# each other. Slicing the column index avoids a KeyError when the trees have
# fewer than ten splits.
sort_columns = list(trees_by_split.columns[:10])
df = trees_by_split.sort_values(by=sort_columns)

# One row per distinct sequence of splits.
unique_trees = df.drop_duplicates()

In [None]:
# Turn every row into a tuple and count how often each tuple occurs.
row_signatures = df.apply(tuple, axis=1)
row_counts = row_signatures.value_counts()

In [None]:
def assignment_heatmap(reference, assignment):
    """Build the (cell type x gene) table shown in the assignment heatmaps.

    Every column of ``reference`` is scaled by its column maximum; the result is
    then restricted to the cell types (rows) and genes (columns) named in
    ``assignment``, in the order in which they appear there.

    Parameters
    ----------
    reference : pandas.DataFrame
        Reference table with cell types as index and genes as columns. It is
        not modified.
    assignment : iterable of str
        Entries of the form ``"<cell type>: <gene>"``.

    Returns
    -------
    pandas.DataFrame
        Numeric table with one row per entry's cell type and one column per
        entry's gene. Index and column axes carry no name.

    Raises
    ------
    ValueError
        If an entry does not contain the ``": "`` separator.
    KeyError
        If a cell type or gene is not present in ``reference``.
    """
    cell_types = list()
    genes = list()
    for entry in assignment:
        # Split on the first ": " only.
        cell_type, separator, gene = entry.partition(": ")
        if not separator:
            raise ValueError(f"expected '<cell type>: <gene>', got {entry!r}")
        cell_types.append(cell_type)
        genes.append(gene)

    # Scale every gene column by its maximum; this creates a new frame and
    # leaves `reference` untouched.
    normalized = reference / reference.max(axis=0)

    # Select all rows and columns in one step. This keeps the numeric dtype
    # (growing an empty frame row by row does not guarantee it) and also works
    # when the same gene is listed more than once.
    hm = normalized.loc[cell_types, genes]

    # Drop axis names inherited from `reference` (e.g. "Unnamed: 0" after
    # `read_csv`) so they do not appear as axis labels in plots.
    return hm.rename_axis(None, axis=0).rename_axis(None, axis=1)

In [None]:
plt.style.use("default")

# Lay the panels out three per row and add as many rows as needed; with nine
# distinct trees this gives the 3x3 grid at a figure size of (14, 10).
n_cols = 3
n_trees = len(row_counts)
n_rows = max(1, (n_trees + n_cols - 1) // n_cols)
f, axs = plt.subplots(n_rows, n_cols, figsize=(14, 10 * n_rows / 3), squeeze=False)

# `row_counts` is indexed by the split tuples themselves, so the count is taken
# from `.items()` rather than looked up with an integer key.
for i, (row, count) in enumerate(row_counts.items()):
    ax = axs[i // n_cols, i % n_cols]
    hm = assignment_heatmap(reference, row)

    n_samples = int(count)
    if n_samples == 1:
        ax.set_title(f"{n_samples} sample")
    else:
        ax.set_title(f"{n_samples} samples")
    sns.heatmap(hm, ax=ax, square=True, cbar=False)

# Hide grid positions that did not receive a heatmap.
for unused_ax in axs.flat[n_trees:]:
    unused_ax.set_visible(False)

plt.tight_layout()
plt.savefig("../result_plots/heatmap.pdf")

In [None]:
# `forest` is a plain Python list, which cannot be indexed with a list of
# positions (that raises TypeError), so the trees are picked one by one.
# NOTE(review): the positions are hard-coded and presumably refer to one
# representative sample per distinct tree - confirm they match the current data.
unique_tree_positions = [0, 5, 7, 27, 45, 46, 51, 52, 57]
unique_trees = [forest[position] for position in unique_tree_positions]

In [None]:
# Stack the per-sample AnnData objects along the observation axis (axis=0),
# passing the sample names as keys and joining with 'outer'.
sample_names = list(samples.keys())
sample_adatas = samples.values()
concatenated_ann_data = ad.concat(
    sample_adatas,
    keys=sample_names,
    join='outer',
    axis=0,
)

In [None]:
# Genes to keep for the embedding. The definition had been commented out, which
# left `genes` undefined when the notebook is run from a fresh kernel.
# NOTE(review): restored from the commented-out list - confirm it is still the
# intended selection.
genes = ['MLANA', 'CD3G', 'CD14', 'NOTCH3', 'PPARG', 'KRT14', 'TP63', 'EGFR', 'CSPG4', 'CD36']

# Restrict the variables to the selected genes. `.copy()` turns the resulting
# view into an independent object before later cells write results into it.
is_selected_gene = concatenated_ann_data.var_names.isin(genes)
concatenated_ann_data = concatenated_ann_data[:, is_selected_gene].copy()

In [None]:
# Compute the neighbourhood graph, then the UMAP embedding that is built on it,
# and plot the embedding coloured by the assigned cell type.
sc.pp.neighbors(concatenated_ann_data)
sc.tl.umap(concatenated_ann_data)
sc.pl.umap(concatenated_ann_data, color="cell_type")

In [None]:
# Run PCA and plot components 1 and 2, coloured by the assigned cell type.
sc.tl.pca(concatenated_ann_data)
sc.pl.pca(concatenated_ann_data, color="cell_type", components="1, 2")

In [None]:
# Values stored by the PCA step: the variance ratios and the per-gene loadings.
pca_info = concatenated_ann_data.uns['pca']
var_exp = pca_info['variance_ratio']
loadings = concatenated_ann_data.varm['PCs']

# Column 2 of the loadings is selected here. Indexing is zero-based, so this is
# the third column even though the variable is called `loadings_pc1`.
pc_index = 2
loadings_pc1 = loadings[:, pc_index]

# Rank the genes by the absolute value of their loading, largest first.
descending_order = np.argsort(np.abs(loadings_pc1))[::-1]
sorted_genes = concatenated_ann_data.var_names[descending_order]
print(sorted_genes)

In [None]:
# Show a summary of the concatenated AnnData object.
concatenated_ann_data

In [None]:
# `hq` is not set in this cell: it is whatever value the most recently executed
# loop over the samples left behind.
# NOTE(review): set the field of view explicitly if a specific one is wanted.
fov = hq
segmented = os.path.join(data, "segmented", f'{fov}_cells.npy')
with open(segmented, "rb") as openfile:
    seg_file = np.load(openfile)

# `samples[hq]` only holds cells of this field of view, so its cell types can
# be used directly.
cell_types = samples[hq].obs["cell_type"]
cell_ids = np.unique(seg_file.flatten())
assert len(cell_types) > 0, "no cells found for this fov"
assert len(cell_ids) == len(cell_types) + 1

# Paint every segmented cell with the colour of its assigned type; label 0 is
# skipped and stays black.
cell_types_on_seg = np.zeros((seg_file.shape[0], seg_file.shape[1], 3))
# Wrapping the array (not the enumerate object) lets tqdm show the total.
for i, cell in enumerate(tqdm(cell_ids)):
    if cell == 0:
        continue
    # `.iloc` makes the positional lookup explicit: the Series is indexed by
    # cell names, so a plain integer key relies on deprecated fallback behaviour.
    cell_types_on_seg[seg_file == cell] = cell_type_colors[cell_types.iloc[i - 1]]

# Load the first processed image whose file name mentions "propidium".
processed_dir = os.path.join(data, "processed", fov)
propidium_files = [f for f in os.listdir(processed_dir) if "propidium" in f.lower()]
if not propidium_files:
    raise FileNotFoundError(f"no propidium image found in {processed_dir}")
prop_iodide = cv2.imread(os.path.join(processed_dir, propidium_files[0]))

In [None]:
# Blend the propidium image with the cell-type colour mask and draw the result.
fig, ax = plt.subplots(figsize=(5, 5))
color_mask = (cell_types_on_seg * 255).astype(np.uint8)
blended = cv2.addWeighted(prop_iodide, 0.4, color_mask, 0.6, 0)
ax.imshow(blended)

# One legend patch per cell type, placed below the image in three columns.
legend_handles = []
for key, color in cell_type_colors.items():
    legend_handles.append(Patch(color=color, label=key))

# Invisible, empty scatter kept from the original legend set-up.
ax.scatter([], [], label='Legend', alpha=0)
ax.legend(handles=legend_handles, loc=(0, -0.15), ncol=3, frameon=False, prop={"size": 6})

fig.tight_layout()
ax.axis("off")
fig.savefig(
    f"../result_plots/cell_types_on_prop/cell_type_analysis_{fov}_cells.pdf",
    format='pdf',
    bbox_inches='tight',
    pad_inches=0.1,
    transparent=True,
)

In [None]:
# Create a new 6x4 figure and save it as a PDF.
# NOTE(review): nothing is drawn between creating and saving the figure, so the
# saved file contains an empty figure - the plotting code appears to be missing.
plt.figure(figsize=(6, 4))

plt.tight_layout()
plt.savefig("../result_plots/cell_type_assignment.pdf", format='pdf', bbox_inches='tight', pad_inches=0.1, transparent=True)

In [None]:
# Display `hm`. It is not assigned in this cell, so this shows whatever value an
# earlier cell left behind.
hm

In [None]:
# Single-column table of every assigned cell type, in sorted order.
sorted_cell_types = sorted(all_cells)
df = pd.DataFrame({"cell_type": sorted_cell_types})

In [None]:
# Histogram of assigned cell types. The legend would only repeat the x-axis
# labels, so it is switched off directly instead of being replaced by an empty
# legend afterwards.
sns.histplot(df, x="cell_type", hue="cell_type", palette=cell_type_colors, shrink=0.7, legend=False)
plt.xticks(rotation=45, ha="right")
# `forest` holds one tree per sample that was actually cell-typed. `samples` can
# also contain the skipped "Nevi_03" samples, which contribute no cells to
# `all_cells`, so `len(forest)` is the sample count that matches the cell count.
plt.title(f"Cell type distribution across {len(forest)} samples, {len(all_cells)} cells")
plt.tight_layout()
plt.savefig("../result_plots/cell_types.pdf")