In [None]:
import logging
import os
from pathlib import Path

import scanpy as sc
import numpy as np
import pandas as pd
import scipy
import anndata
import squidpy as sq
import matplotlib.pyplot as plt
import tifffile
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from gensim.corpora import Dictionary
from gensim.models import LdaModel
from umap import UMAP

In [None]:
from mgitools.os_helpers import listfiles

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from mip.gating import get_ideal_window

In [None]:
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)

In [None]:
analysis_dir = '/diskmnt/Projects/Users/estorrs/multiplex_data/analysis/brca_dcis_v1'
sc.settings.figdir = os.path.join(analysis_dir, 'figures')
Path(os.path.join(analysis_dir, 'figures')).mkdir(parents=True, exist_ok=True)

In [None]:
fps = sorted(listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/brca',
                       regex=r'dcis_neighborhood_analysis/preprocessed_adata.h5ad$'))
fps

In [None]:
def cell_to_neighbors(adata, radius=50):
    """Map every cell id to the ids of all cells within `radius` of its centroid.

    Neighborhoods come from a radius query on the (`centroid_row`,
    `centroid_col`) columns of `adata.obs`. The query points are passed
    explicitly to `radius_neighbors_graph`, so (per scikit-learn's docs) each
    cell is also reported as its own neighbor.

    Parameters
    ----------
    adata : AnnData-like
        Object whose `.obs` DataFrame has `centroid_row` and `centroid_col`
        columns; `.obs.index` supplies the cell ids.
    radius : float
        Neighborhood radius, in the same units as the centroid columns.

    Returns
    -------
    dict
        `{cell_id: [neighbor_cell_id, ...]}` with one entry per row of
        `adata.obs`, in `obs` order.
    """
    index = adata.obs.index
    X = adata.obs[['centroid_row', 'centroid_col']].values
    nbrs = NearestNeighbors(algorithm='ball_tree').fit(X)

    g = nbrs.radius_neighbors_graph(X, radius=radius)
    rows, cols, _ = scipy.sparse.find(g)

    # Seed an entry for every cell, then append every (row, col) pair of the
    # graph -- including the first neighbor found for each cell.
    neighbors = {cid: [] for cid in index}
    for r, c in zip(rows, cols):
        neighbors[index[r]].append(index[c])
    return neighbors

In [None]:
# Sample id = the path component three directories above each h5ad file.
sample_to_adata = {path.split('/')[-4]: sc.read_h5ad(path) for path in fps}
sample_to_adata.keys()

In [None]:
# Keep only QC-passing cells, prefix cell ids with the sample id so they are
# unique across samples, and store the filtered AnnData back in
# `sample_to_adata`. Later cells look cells up in `cell_to_nbhrs` and
# `lda_adata` by these prefixed ids.
# NOTE: not idempotent -- re-running re-prefixes the ids; reload
# `sample_to_adata` first.
cell_to_nbhrs = {}
for sample, a in list(sample_to_adata.items()):
    a = a[a.obs['passes_qc']].copy()
    a.obs.index = [f'{sample}_{x}' for x in a.obs.index.to_list()]
    sample_to_adata[sample] = a
    cell_to_nbhrs.update(cell_to_neighbors(a, radius=100))

In [None]:
# NOTE(review): `sample_to_pseudo` (built in the next cell) is read by the
# `display_on_img` cells further down; uncomment these cells before running them,
# otherwise those cells raise NameError on a fresh kernel.
# fps = sorted(listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/', regex=r'pseudo.tiff$'))
# fps

In [None]:
# sample_to_pseudo = {fp.split('/')[-3]:tifffile.imread(fp) for fp in fps}

In [None]:
# sample_to_adata.keys()

In [None]:
# One LDA "document" per cell: the harmonized cell types of every cell in its
# neighborhood (`cell_to_nbhrs`). `cells` holds the cell ids in the same order as `docs`.
cells = []
docs = []
for s, a in sample_to_adata.items():
    obs_ids = a.obs.index.to_list()
    cell_to_cell_type = dict(zip(a.obs.index, a.obs['harmonized_cell_type']))
    sample_docs = [
        [cell_to_cell_type[neighbor] for neighbor in cell_to_nbhrs[cell_id]]
        for cell_id in obs_ids
    ]
    docs.extend(sample_docs)
    print(s, len(docs))
    cells.extend(obs_ids)

In [None]:
dictionary = Dictionary(docs)
corpus = [dictionary.doc2bow(doc) for doc in docs]

In [None]:
len(dictionary), len(corpus), len(cells)

In [None]:
# LDA hyperparameters.
num_topics = 10
passes = 2
iterations = 100
chunksize = len(corpus)  # the whole corpus is processed as a single chunk
# gensim logs a perplexity estimate every `eval_every` updates. This is enabled
# here (10); it costs extra time, so set it to None to skip the evaluation.
eval_every = 10

In [None]:
# Touch the Dictionary once so gensim fills its lazily built `id2token`
# mapping (gensim LDA tutorial idiom), then hand that mapping to the model.
temp = dictionary[0]
id2word = dictionary.id2token

# A fixed random_state makes the fit reproducible: the topic ids are
# hand-labelled in `topic_map` and used downstream, so an unseeded refit
# would silently permute them.
model = LdaModel(
    corpus=corpus,
    id2word=id2word,
    chunksize=chunksize,
    alpha='auto',
    eta='auto',
    iterations=iterations,
    num_topics=num_topics,
    passes=passes,
    eval_every=eval_every,
    random_state=0,
)

In [None]:
top_topics = model.top_topics(corpus)
# Mean coherence over the `num_topics` topics; coherence is the second element of each entry.
avg_topic_coherence = sum(entry[1] for entry in top_topics) / num_topics

In [None]:
def transformed_corpus_to_emb(tc, n_topics):
    """Convert a sparse per-document topic list into a dense 2-D array.

    Parameters
    ----------
    tc : iterable of iterable of (int, float)
        One entry per document, each listing `(topic_id, weight)` pairs
        (e.g. the output of `model[corpus]`). Topics missing from an entry
        get weight 0.
    n_topics : int
        Number of columns (topics) in the output.

    Returns
    -------
    numpy.ndarray
        Float array of shape `(n_documents, n_topics)`. An empty `tc` yields
        shape `(0, n_topics)` instead of a 1-D empty array.
    """
    rows = []
    for entity in tc:
        row = np.zeros(n_topics, dtype=float)
        for topic, value in entity:
            row[topic] = value
        rows.append(row)
    if not rows:
        return np.zeros((0, n_topics), dtype=float)
    return np.vstack(rows)
    

In [None]:
transformed = model[corpus]
embs = transformed_corpus_to_emb(transformed, num_topics)
embs.shape

In [None]:
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

In [None]:
# Cluster cells by topic mixture. `n_init` is pinned because its default changed
# across scikit-learn releases (10 -> 'auto'); leaving it implicit would change
# the cluster ids that the hand-written `metacluster_to_cluster` mapping relies on.
kmeans = KMeans(n_clusters=20, n_init=10, random_state=0).fit(embs)
set(kmeans.labels_)

In [None]:
df = pd.DataFrame(data=embs, columns=np.arange(num_topics), index=cells)
lda_adata = anndata.AnnData(df)
lda_adata

In [None]:
lda_adata.obs['LDA_kmeans_cluster'] = [str(x) for x in kmeans.labels_]

In [None]:
lda_adata.write_h5ad(os.path.join(analysis_dir, 'lda.h5ad'))

In [None]:
sc.pl.matrixplot(lda_adata, var_names=lda_adata.var.index, groupby='LDA_kmeans_cluster', dendrogram=True,
                 save='lda_kmeans_cluster.pdf')

In [None]:
topic_df = pd.DataFrame(data=model.get_topics(), columns=[dictionary.get(i) for i in range(len(dictionary))],
                        index=np.arange(num_topics))
import seaborn as sns
sns.clustermap(topic_df, cmap='Blues')
plt.tight_layout()
plt.savefig(os.path.join(analysis_dir, 'figures', 'topic_heatmap.pdf'))

In [None]:
# Hand-assigned labels for the LDA topics, keyed by topic id (0..9).
# NOTE(review): only meaningful for this exact fitted model -- a refit can permute topic ids.
topic_map = dict(enumerate([
    'Immune - M1/M2 macrophage',
    'Normalish Duct - center',
    'Stroma - Fibroblast',
    'Noise',
    'Tumor - boundary',
    'Immune - Mixed T cell',
    'Normalish Duct - boundary',
    'Tumor - center',
    'Vasculature',
    'Immune - CD8 T cell dominant',
]))


In [None]:
# Copy each sample's topic weights and k-means labels from `lda_adata` into its own AnnData.
for sample, a in sample_to_adata.items():
    lda_sub = lda_adata[a.obs.index]
    for t in lda_adata.var.index:
        a.obs[f'topic_{t}'] = lda_sub[:, t].X.flatten()
    # The cluster label does not depend on the topic, so assign it once per
    # sample rather than once per topic.
    a.obs['LDA_kmeans_cluster'] = lda_sub.obs['LDA_kmeans_cluster'].to_list()
    a.uns['topic_map'] = topic_map

In [None]:
def visualize_roi(adata, scale=1000, size=10, color='harmonized_cell_type'):
    """Scatter-plot the cell centroids of one sample, coloured by an obs column.

    Parameters
    ----------
    adata : AnnData
        Must have `centroid_col`, `centroid_row` and `centroid_row_inverted` in `.obs`.
    scale : float
        Coordinate units per inch of figure size.
    size : float
        Marker size passed to `sc.pl.scatter`.
    color : str
        `.obs` column used to colour the points.
    """
    # Figure size follows the tissue extent. Clamp to at least 1 inch so small
    # regions (max coordinate < scale) don't produce a zero-sized figure.
    width = max(1, int(adata.obs['centroid_col'].max() / scale))
    height = max(1, int(adata.obs['centroid_row'].max() / scale))
    _, ax = plt.subplots(figsize=(width, height))
    sc.pl.scatter(adata, x='centroid_col', y='centroid_row_inverted',
                  color=color, size=size, ax=ax)
    
def visualize_topics(adata, size=2, n_cols=5, scale=3, topic_names=None):
    """Draw one spatial scatter panel per LDA topic, coloured by topic weight.

    Parameters
    ----------
    adata : AnnData
        Needs `centroid_col`, `centroid_row`, `centroid_row_inverted` and one
        `topic_<i>` column per topic (i = 0, 1, ...) in `.obs`.
    size : float
        Marker size.
    n_cols : int
        Panels per row.
    scale : float
        Panel width in inches; panel height is scaled by the tissue aspect ratio.
    topic_names : mapping of int -> str, optional
        Panel titles keyed by topic index; defaults to the notebook-level `topic_map`.
    """
    names = topic_map if topic_names is None else topic_names
    ratio = adata.obs['centroid_row'].max() / adata.obs['centroid_col'].max()
    n_topics = sum(1 for c in adata.obs.columns if str(c).startswith('topic_'))
    # Ceiling division: exactly as many rows as needed, no empty trailing row.
    n_rows = max(1, -(-n_topics // n_cols))
    # squeeze=False keeps `axs` 2-D even for a single row.
    _, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False,
                          figsize=(n_cols * scale, n_rows * (ratio * scale)))
    for i, ax in enumerate(axs.flat):
        if i >= n_topics:
            ax.set_visible(False)  # hide unused trailing panels
            continue
        ax.scatter(adata.obs['centroid_col'], adata.obs['centroid_row_inverted'],
                   s=size, c=adata.obs[f'topic_{i}'])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(names[i])

In [None]:
visualize_topics(next(iter(sample_to_adata.values())))

In [None]:
for s, a in sample_to_adata.items():
    print(s)
    visualize_roi(a)

In [None]:
for s, a in sample_to_adata.items():
    print(s)
    visualize_roi(a, color='LDA_kmeans_cluster')

In [None]:
# Hand-curated grouping of k-means cluster ids into named metaclusters.
# NOTE(review): every cluster id produced by `kmeans` must appear in one of the
# lists below, otherwise the per-cell metacluster lookup in the next cell fails.
metacluster_to_cluster = {
    'Tumor - Pure': [3, 17],
    'Tumor - Vasculature': [2],
    'Tumor - Immune Cold': [14, 1],
    'Macrophage - Mixed': [0],
    'Macrophage - TAM': [8],
    'Endothelial': [16],
    'Fibroblast - Deserted': [],
    'Noise': [12],
}
# Inverse mapping {str(cluster_id): metacluster}; string keys match `LDA_kmeans_cluster`.
cluster_to_metacluster = {str(v): k for k, vs in metacluster_to_cluster.items() for v in vs}
sorted(cluster_to_metacluster.items())

In [None]:
cell_to_kmeans = {c: str(k) for c, k in zip(cells, kmeans.labels_)}

# Fail with a readable message (instead of a bare KeyError mid-comprehension)
# when some k-means clusters have not been assigned a metacluster.
unassigned = sorted(set(cell_to_kmeans.values()) - set(cluster_to_metacluster), key=int)
if unassigned:
    raise ValueError(f'k-means clusters without a metacluster: {unassigned}; '
                     'add them to metacluster_to_cluster')
cell_to_metacluster = {c: cluster_to_metacluster[cell_to_kmeans[c]] for c in cells}

# Attach the LDA embedding, k-means label and metacluster to every sample.
for s, a in sample_to_adata.items():
    f = df.loc[a.obs.index.to_list()]
    a.obsm['X_lda'] = f.values
    a.obs['LDA_kmeans_cluster'] = [cell_to_kmeans[c] for c in a.obs.index.to_list()]
    a.obs['metacluster'] = [cell_to_metacluster[c] for c in a.obs.index.to_list()]

In [None]:
plt.rcParams["figure.figsize"] = (8, 8)
plt.rcParams["figure.dpi"] = 120

In [None]:
a.uns['metacluster_colors']

In [None]:
def show_cluster(adata, cluster, cluster_col='metacluster', radius=300):
    """Render `sc.pl.spatial` cropped to a window chosen for one cluster.

    The window is obtained from
    `get_ideal_window(adata, radius=radius, cell_type=cluster, cell_type_col=cluster_col)`;
    all of `adata` is plotted, coloured by `cluster_col`, and cropped to it.

    Returns
    -------
    tuple
        The window bounds `(r1, r2, c1, c2)` as returned by `get_ideal_window`.
    """
    window = get_ideal_window(
        adata, radius=radius, cell_type=cluster, cell_type_col=cluster_col,
        return_filtered=False)
    row_lo, row_hi, col_lo, col_hi = window
    crop = [col_lo, col_hi, row_lo, row_hi]  # crop_coord is given column range first
    sc.pl.spatial(adata, color=cluster_col, crop_coord=crop, size=1.)
    return row_lo, row_hi, col_lo, col_hi
    
def display_on_img(adata, img, cluster, cluster_col='metacluster', radius=300, show_all=False,
                  pallete=sns.color_palette('tab20'), s=5, edgecolors='black', ax=None, legend=True,
                  pallete_map=None):
    """Overlay cell centroids on a crop of `img` around a window chosen for `cluster`.

    The window `(r1, r2, c1, c2)` and a filtered AnnData come from
    `get_ideal_window(..., cell_type=cluster, cell_type_col=cluster_col, return_filtered=True)`.
    `img[r1:r2, c1:c2]` is shown, and centroids are shifted into crop coordinates.

    Parameters
    ----------
    adata : AnnData
        Needs `centroid_row`, `centroid_col` and `cluster_col` in `.obs`.
    img : array
        Image indexed as `img[row, col]`, in the same frame as the centroids.
    cluster, cluster_col, radius
        Forwarded to `get_ideal_window`.
    show_all : bool
        If True, draw every cell strictly inside the window with one colour per
        `cluster_col` category. Otherwise draw only the filtered cells from
        `get_ideal_window`, in red.
    pallete : sequence of colours
        Cycled over the sorted categories when `pallete_map` is None (must be
        non-empty in that case).
    s, edgecolors
        Marker size and edge colour for `ax.scatter`.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when None.
    legend : bool
        Add a legend (only when `show_all`).
    pallete_map : dict, optional
        Explicit `{category: colour}` mapping; takes precedence over `pallete`.

    Returns
    -------
    tuple
        The window bounds `(r1, r2, c1, c2)`.
    """
    import itertools

    f, (r1, r2, c1, c2) = get_ideal_window(
        adata, radius=radius, cell_type=cluster, cell_type_col=cluster_col,
        return_filtered=True)

    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(img[r1:r2, c1:c2])

    if show_all:
        # All cells whose centroid lies strictly inside the window.
        in_rows = (adata.obs['centroid_row'] > r1) & (adata.obs['centroid_row'] < r2)
        in_cols = (adata.obs['centroid_col'] > c1) & (adata.obs['centroid_col'] < c2)
        f = adata[in_rows & in_cols]

        # Cycle the palette so every category is drawn even when there are more
        # categories than colours (zip() would silently drop the extras).
        colors = itertools.cycle(pallete)
        for ct in sorted(set(f.obs[cluster_col])):
            color = pallete_map[ct] if pallete_map is not None else next(colors)
            fx = f[f.obs[cluster_col] == ct]
            ax.scatter(fx.obs['centroid_col'] - c1, fx.obs['centroid_row'] - r1, c=color, label=ct, s=s,
                       edgecolors=edgecolors)

        if legend:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    else:
        ax.scatter(f.obs['centroid_col'] - c1, f.obs['centroid_row'] - r1, c='red', s=s,
                   edgecolors=edgecolors)
    return r1, r2, c1, c2

In [None]:
# Focus on a single sample.
s, a = 'HT206B1-H1', sample_to_adata['HT206B1-H1']
(a.n_obs, a.n_vars)

In [None]:
display_on_img(a, sample_to_pseudo[s], 'Tumor - Infiltrating T cell', cluster_col='metacluster', radius=1000,
               show_all=True, s=1, edgecolors=None)

In [None]:
for s, a in sample_to_adata.items():
    print(s)
    display_on_img(a, sample_to_pseudo[s], metacluster, cluster_col='metacluster', radius=300)
    
    plt.show()

In [None]:
metacluster = 'Tumor - Infiltrating T cell'

fig, axs = plt.subplots(ncols=len(sample_to_adata), figsize=(20, 5))
m = {}
for s, a in sample_to_adata.items():
    m.update({ct:c for ct, c in zip(sorted(set(a.obs['metacluster'])), sns.color_palette('tab20'))})
for i, (s, a) in enumerate(sample_to_adata.items()):
    print(s)
    ax = axs[i]
    display_on_img(a, sample_to_pseudo[s], metacluster, cluster_col='metacluster', radius=300, show_all=False,
                   s=5, ax=ax, edgecolors=None)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
metacluster = 'Tumor - Pure'

fig, axs = plt.subplots(ncols=len(sample_to_adata), figsize=(20, 5))
m = {}
for s, a in sample_to_adata.items():
    m.update({ct:c for ct, c in zip(sorted(set(a.obs['metacluster'])), sns.color_palette('tab20'))})
for i, (s, a) in enumerate(sample_to_adata.items()):
    print(s)
    ax = axs[i]
    display_on_img(a, sample_to_pseudo[s], metacluster, cluster_col='metacluster', radius=500, show_all=False,
                   s=5, ax=ax, edgecolors=None)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
metacluster = 'Tumor - Infiltrating T cell'
for s, a in sample_to_adata.items():
    print(s)
    display_on_img(a, sample_to_pseudo[s], metacluster, cluster_col='metacluster', radius=300, show_all=True,
                   s=20, pallete_map=m)
    plt.show()

In [None]:
metacluster = 'Noise'
for s, a in sample_to_adata.items():
    print(s)
    display_on_img(a, sample_to_pseudo[s], metacluster, cluster_col='metacluster', radius=300, show_all=True,
                   s=20)
    plt.show()

In [None]:
# Persist every annotated sample to <htan>/<sample>/level_4/metacluster_lda.h5ad.
for s, a in sample_to_adata.items():
    a.write_h5ad(f'/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/{s}/level_4/metacluster_lda.h5ad')

In [None]:
from collections import Counter
data, idxs = [], []
cols = sorted(set(cell_to_metacluster.values()))
for s, a in sample_to_adata.items():
    counts = Counter(a.obs['metacluster'])
    data.append([counts.get(c, 0) for c in cols])
    idxs.append(s)
df = pd.DataFrame(data=data, index=idxs, columns=cols)
df

In [None]:
ax = df.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
ax = (df / df.sum(axis=1).values.reshape(-1, 1)).plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
data, idxs = [], []
cols = sorted(set(a.obs['cell_type']))
for s, a in sample_to_adata.items():
    counts = Counter(a.obs['cell_type'])
    data.append([counts.get(c, 0) for c in cols])
    idxs.append(s)
df = pd.DataFrame(data=data, index=idxs, columns=cols)
df

In [None]:
ax = df.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
ax = (df / df.sum(axis=1).values.reshape(-1, 1)).plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
sq.gr.spatial_neighbors(a, key_added='spatial')

In [None]:
sq.gr.interaction_matrix(a, cluster_key="metacluster")

In [None]:
sq.pl.interaction_matrix(a, cluster_key="metacluster", vmax=10000)

In [None]:
sq.gr.co_occurrence(a, cluster_key="metacluster", n_splits=1, n_jobs=40,
                    interval=[32, 64, 128, 256, 512, 1028])

In [None]:
sq.pl.co_occurrence(
    a,
    cluster_key="metacluster",
    clusters=["Myoepithelium"],
    figsize=(15, 4),
)

In [None]:
sq.pl.co_occurrence(
    a,
    cluster_key="metacluster",
    clusters=["Immune - T cell"],
    figsize=(15, 4),
)

In [None]:
# One co-occurrence plot per metacluster present in `a`.
for c in sorted(set(a.obs['metacluster'])):
    sq.pl.co_occurrence(a, clusters=[c], cluster_key="metacluster", figsize=(15, 4))

In [None]:
fps = sorted(listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/', regex=r'metacluster_lda.h5ad$'))
fps

In [None]:
sample_to_adata = {fp.split('/')[-3]:sc.read_h5ad(fp) for fp in fps}

In [None]:
# For every sample: spatial graph, then interaction matrix and co-occurrence
# for metaclusters and for cell types, then save the result.
for s, a in sample_to_adata.items():
    print(s)
    sq.gr.spatial_neighbors(a, key_added='spatial')
    for key in ("metacluster", "cell_type"):
        sq.gr.interaction_matrix(a, cluster_key=key)
        sq.gr.co_occurrence(a, cluster_key=key, n_splits=1, n_jobs=40, interval=[50, 100, 200, 500, 1000])
    a.write_h5ad(f'/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/{s}/level_4/metacluster_spatial_analysis.h5ad')
    