In [None]:
import pandas as pd

import os

import numpy as np

import sys
# Put the parent directory on the import path so the project package
# `digitalhistopathology` (imported below) can be found.
# Assumes the notebook is executed from a direct subdirectory of the repo — TODO confirm.
sys.path.append("../")

import matplotlib.pyplot as plt
import seaborn as sns
from digitalhistopathology.embeddings.gene_embedding import GeneEmbedding

import warnings
# NOTE(review): this silences *every* warning for the rest of the session
# (including pandas/anndata warnings about copies or index conversion).
# Consider narrowing it to the specific category/module that is noisy.
warnings.filterwarnings("ignore")

import anndata as ad
# NOTE(review): `glob` and `scanpy` are not used in the cells visible here.
import glob
import scanpy as sc

import gzip
import pickle

import matplotlib
# Embed fonts as TrueType (Type 42) instead of Type 3 in saved PDF/PS figures,
# so the text in exported figures stays selectable/editable.
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

In [None]:
import json

# Load the notebook configuration; later cells read paths from it
# (e.g. config['patches_info_path']).
# JSON text is UTF-8 by specification, so decode it explicitly as UTF-8 rather
# than relying on the platform's locale-dependent default encoding.
with open("../config/config_notebooks.json", "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

In [None]:
# Filtered gene-expression AnnData.
# NOTE(review): it does not appear to be used by the active code in the cells
# below (only by commented-out alternatives) — confirm before removing the read.
emb_raw_gene = ad.read_h5ad("../results/molecular/filtered_gene_expression.h5ad")

# Per-patch metadata stored as a gzip-compressed pickle; a later cell wraps it
# in a DataFrame. pickle.load can execute arbitrary code, so only point
# 'patches_info_path' at files from a trusted source.
with gzip.open(config['patches_info_path']) as f:
    patches_info = pickle.load(f)
    
    
def create_legend_gene_expression(gene_exp):
    """Bin a scalar expression value into a coarse legend category.

    Categories:
        exactly 0        -> "Not expressed"
        below 2          -> "Low expression"   (negative values land here too)
        below 4          -> "Medium expression"
        anything else    -> "High expression"

    Note: NaN fails every comparison and therefore ends up as "High expression".

    Args:
        gene_exp: a scalar expression value.

    Returns:
        str: the legend label for ``gene_exp``.
    """
    if gene_exp == 0:
        return "Not expressed"
    # Upper bounds are exclusive and checked in ascending order.
    for upper_bound, category in ((2, "Low expression"), (4, "Medium expression")):
        if gene_exp < upper_bound:
            return category
    return "High expression"


In [None]:
# Tabulate the unpickled patch metadata so later cells can filter it with boolean masks.
patches_info = pd.DataFrame(data=patches_info)


In [None]:
# UNI base
# For each patient, overlay the k-means cluster of every invasive-cancer patch
# on the slide image and save one PDF per patient.
labels = pd.read_csv("../results/benchmark/her2_final_without_A/base_models/invasive_cancer_clustering/kmeans/uni/invasive_labels_7_clusters_umap_min_dist_0.001_n_neighbors_10.csv", index_col=0)
# NOTE(review): this palette is NOT passed to the plot call below, so the plot
# function's default colours are used. Before passing it: its keys are ints
# while 'predicted_label' is stored as str, and it only covers clusters 0-4
# although the label file name says 7 clusters.
palette={0: 'orange', 2:'gray', 4: 'lime', 3: 'magenta', 1: 'cyan', 'not invasive': 'white'}
# savefig does not create missing directories; make sure the target exists.
os.makedirs("../Figures/Fig4", exist_ok=True)
for patient in ["B1", "C1", "D1", "E1", "F1", "G2", "H1"]:
    print(patient)
    # This patient's patch metadata, indexed by patch name so it can be matched
    # against the index of `labels`.
    patient_obs = patches_info[patches_info["name_origin"] == patient].set_index('name')
    # Only patches listed in `labels` carry a cluster label (invasive patches);
    # all other patches are dropped before plotting.
    invasive_obs = patient_obs[patient_obs.index.isin(labels.index)].copy()
    invasive_obs['predicted_label'] = labels.loc[invasive_obs.index, 'predicted_label'].astype(str).to_numpy()

    # X is an all-zero placeholder; the plot is assumed to read only obs columns — TODO confirm.
    subset_emb = GeneEmbedding()
    subset_emb.emb = ad.AnnData(X=np.zeros((len(invasive_obs), 1)), obs=invasive_obs)

    subset_emb.plot_spot_location_with_color_on_origin_image(color='predicted_label', s=10)

    plt.savefig(f"../Figures/Fig4/slide_uni_{patient}_predicted_labels.pdf", bbox_inches='tight')

In [None]:
# Reference colour palette for up to 9 clusters, keyed by the cluster id as a
# string ('0' .. '8'); insertion order is the order shown by the preview cell.
palette = dict(zip(
    [str(cluster_id) for cluster_id in range(9)],
    ["#66BB46", "#AD66FF", "#F9A11B", "#31C4F3", "#ACB5B6",
     "#965D59", "#EC2A90", "#2F2F8E", "#FFE340"],
))

In [None]:
# Show the palette colours as a swatch strip, in dict insertion order.
sns.palplot(sns.color_palette([*palette.values()]))

In [None]:
# UNI full koleo
# For each patient, overlay the k-means cluster of every invasive-cancer patch
# on the slide image (with a fixed per-cluster palette) and save one PDF per patient.
labels = pd.read_csv("../results/benchmark/her2_final_without_A/uni_full_models/invasive_cancer_clustering/kmeans/uni_full_koleo_16384_prototypes/invasive_labels_5_clusters_umap_min_dist_0.001_n_neighbors_250.csv", index_col=0)
# Keys are strings because 'predicted_label' is stored as str below. The
# 'not invasive' entry is never used: those patches are dropped before plotting.
palette={'0': '#F9A11B', '2': '#ACB5B6', '4': '#66BB46', '3': '#EC2A90', '1': '#31C4F3', 'not invasive': 'white'}
# savefig does not create missing directories; make sure the target exists.
os.makedirs("../Figures/Fig4", exist_ok=True)
for patient in ["B1", "C1", "D1", "E1", "F1", "G2", "H1"]:
    print(patient)
    # This patient's patch metadata, indexed by patch name so it can be matched
    # against the index of `labels`.
    patient_obs = patches_info[patches_info["name_origin"] == patient].set_index('name')
    # Only patches listed in `labels` carry a cluster label (invasive patches);
    # all other patches are dropped before plotting.
    invasive_obs = patient_obs[patient_obs.index.isin(labels.index)].copy()
    invasive_obs['predicted_label'] = labels.loc[invasive_obs.index, 'predicted_label'].astype(str).to_numpy()

    # X is an all-zero placeholder; the plot is assumed to read only obs columns — TODO confirm.
    subset_emb = GeneEmbedding()
    subset_emb.emb = ad.AnnData(X=np.zeros((len(invasive_obs), 1)), obs=invasive_obs)

    subset_emb.plot_spot_location_with_color_on_origin_image(color='predicted_label', s=10, palette=palette)

    plt.savefig(f"../Figures/Fig4/slide_uni_full_koleo_16384_{patient}_predicted_labels.pdf", bbox_inches='tight')