In [None]:
import logging
import os
from pathlib import Path

import scanpy as sc
import numpy as np
import pandas as pd
import scipy
import anndata
import squidpy as sq
import matplotlib.pyplot as plt
import tifffile
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.metrics import roc_auc_score, roc_curve
from gensim.corpora import Dictionary
from gensim.models import LdaModel
from umap import UMAP

In [None]:
from mgitools.os_helpers import listfiles

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from mip.gating import get_ideal_window

In [None]:
# Configure root logging: DEBUG level and above, prefixed with timestamp and level name.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)

In [None]:
# Root output folder for this analysis; scanpy saves its figures under <analysis_dir>/figures.
analysis_dir = '/diskmnt/Projects/Users/estorrs/multiplex_data/analysis/brca_dcis_v1'
Path(analysis_dir, 'figures').mkdir(parents=True, exist_ok=True)
sc.settings.figdir = os.path.join(analysis_dir, 'figures')

In [None]:
# Paths of the preprocessed AnnData files matching the regex below, sorted for a stable order.
# NOTE(review): `listfiles` is a project helper; presumably it walks the directory recursively — confirm.
fps = sorted(listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/brca',
                       regex=r'dcis_neighborhood_analysis/preprocessed_adata.h5ad$'))
fps

In [None]:
def cell_to_neighbors(adata, radius=50):
    """Map every cell to the cells whose centroids lie within `radius` of it.

    Parameters
    ----------
    adata
        Object whose ``.obs`` table has ``centroid_row`` and ``centroid_col``
        columns and whose ``.obs.index`` holds the cell identifiers.
    radius
        Search radius, in the same units as the centroid columns.

    Returns
    -------
    dict
        ``{cell_id: [neighbor_id, ...]}`` with one entry per (cell, neighbor)
        pair found in the radius-neighbors graph. The fitted points are passed
        back in as the query points, so a cell is normally listed among its own
        neighbors (distance 0).
    """
    coords = adata.obs[['centroid_row', 'centroid_col']].values
    nbrs = NearestNeighbors(algorithm='ball_tree').fit(coords)

    graph = nbrs.radius_neighbors_graph(coords, radius=radius)
    rows, cols, _ = scipy.sparse.find(graph)

    # Resolve positions to ids through a plain array instead of indexing the pandas Index per pair.
    cell_ids = adata.obs.index.to_numpy()

    cell_to_neighbhors = {}
    for r, c in zip(rows, cols):
        # setdefault creates the list AND appends on first sight of a cell; a bare
        # "create the list, else append" would silently drop each cell's first neighbor.
        cell_to_neighbhors.setdefault(cell_ids[r], []).append(cell_ids[c])

    return cell_to_neighbhors

In [None]:
# Load every sample, drop QC failures and 'Noise' cells, prefix cell ids with the sample
# name so they are unique across samples, and collect each cell's spatial neighbors.
sample_to_adata = {}
cell_to_nbhrs = {}
for fp in fps:
    # The sample name is taken as the 4th path component from the end.
    sample = fp.split('/')[-4]
    a = sc.read_h5ad(fp)
    print(sample, a.shape)

    if 'passes_qc' in a.obs.columns:
        a = a[a.obs['passes_qc']]

    # Materialize the filtered subset: slicing returns a view, and the writes to
    # `.obs` below should land on a real object rather than trigger an implicit copy.
    a = a[a.obs['harmonized_cell_type'] != 'Noise'].copy()

    # Keep the original per-sample id in a column before the index is rewritten.
    a.obs['cell_id'] = a.obs.index.to_list()
    a.obs.index = [f'{sample}_{x}' for x in a.obs.index.to_list()]
    cell_to_nbhrs.update(cell_to_neighbors(a, radius=100))
    sample_to_adata[sample] = a

In [None]:
# fps = sorted(listfiles('/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/', regex=r'pseudo.tiff$'))
# fps

In [None]:
# sample_to_pseudo = {fp.split('/')[-3]:tifffile.imread(fp) for fp in fps}

In [None]:
# sample_to_adata.keys()

In [None]:
# One "document" per cell: the cell types of its spatial neighbors. `cells` keeps the
# cell ids in the same order as `docs`.
cells = []
docs = []
for s, a in sample_to_adata.items():
    ids = a.obs.index.to_list()
    cell_to_cell_type = dict(zip(ids, a.obs['harmonized_cell_type'].to_list()))
    for cell_id in ids:
        docs.append([cell_to_cell_type[neighbor] for neighbor in cell_to_nbhrs[cell_id]])
    print(s, len(docs))
    cells.extend(ids)

In [None]:
# Build the token vocabulary from the neighborhood documents, then encode each
# document as a bag-of-words (token id, count) list for gensim.
dictionary = Dictionary(docs)
corpus = [dictionary.doc2bow(doc) for doc in docs]

In [None]:
# Sanity check: vocabulary size, number of documents, number of cells.
len(dictionary), len(corpus), len(cells)

In [None]:
# LDA hyperparameters, passed to LdaModel below.
num_topics = 10
chunksize = len(corpus)  # one chunk spanning the whole corpus
passes = 2
iterations = 100
eval_every = 10 # NOTE(review): evaluation is enabled here (value 10); presumably it adds runtime — set to None to disable, confirm against gensim docs

In [None]:
# Touching an item makes gensim populate the `id2token` mapping before it is read.
temp = dictionary[0]
id2word = dictionary.id2token

# `random_state` pins the fit so topic indices are stable across reruns. The hand-written
# `topic_map` and `metacluster_to_cluster` annotations further down refer to topics and
# clusters by index, so they must be re-checked whenever the model is refit.
model = LdaModel(
    corpus=corpus,
    id2word=id2word,
    chunksize=chunksize,
    alpha='auto',
    eta='auto',
    iterations=iterations,
    num_topics=num_topics,
    passes=passes,
    eval_every=eval_every,
    random_state=0
)

In [None]:
# Average the per-topic coherence score (second element of each `top_topics` entry).
top_topics = model.top_topics(corpus)
avg_topic_coherence = sum(score for _topic, score in top_topics) / num_topics

In [None]:
def transformed_corpus_to_emb(tc, n_topics):
    """Expand sparse per-document topic weights into a dense matrix.

    Parameters
    ----------
    tc
        Iterable of documents, each an iterable of ``(topic_index, weight)`` pairs.
        Topics missing from a document get weight 0.
    n_topics
        Number of columns in the output.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n_documents, n_topics)``. The shape is 2-D and the
        dtype is float even when `tc` is empty or a document lists no topics.
    """
    rows = []
    for entity in tc:
        row = np.zeros(n_topics, dtype=float)
        for topic, value in entity:
            row[topic] = value
        rows.append(row)

    if not rows:
        # Keep the column dimension so downstream DataFrame construction still works.
        return np.zeros((0, n_topics), dtype=float)
    return np.vstack(rows)
    

In [None]:
# Apply the fitted model to every document and densify the result:
# one row per cell, one column per topic.
transformed = model[corpus]
embs = transformed_corpus_to_emb(transformed, num_topics)
embs.shape

In [None]:
# Raise the root logger threshold so only CRITICAL records are emitted from here on.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

In [None]:
# Cluster the per-cell topic embeddings into 20 groups (seeded); show the distinct labels.
kmeans = KMeans(n_clusters=20, random_state=0).fit(embs)
set(kmeans.labels_)

In [None]:
# Wrap the embeddings in an AnnData: rows are the cell ids in `cells`,
# columns are the integer topic ids 0..num_topics-1.
df = pd.DataFrame(data=embs, columns=np.arange(num_topics), index=cells)
lda_adata = anndata.AnnData(df)
lda_adata

In [None]:
# Store the k-means labels as strings; `kmeans.labels_` is in the same row order as `embs`/`cells`.
lda_adata.obs['LDA_kmeans_cluster'] = [str(x) for x in kmeans.labels_]

In [None]:
# Persist the topic embeddings and cluster labels.
lda_adata.write_h5ad(os.path.join(analysis_dir, 'lda.h5ad'))

In [None]:
# Topic weights summarized per k-means cluster; the figure is saved via scanpy's `save=`.
sc.pl.matrixplot(lda_adata, var_names=lda_adata.var.index, groupby='LDA_kmeans_cluster', dendrogram=True,
                 save='lda_kmeans_cluster.pdf')

In [None]:
# Topic-by-token matrix from the model: rows are topic ids, columns are the
# dictionary tokens (cell types) in id order. Clustered heatmap saved to the figures folder.
topic_df = pd.DataFrame(data=model.get_topics(), columns=[dictionary.get(i) for i in range(len(dictionary))],
                        index=np.arange(num_topics))
import seaborn as sns
sns.clustermap(topic_df, cmap='Blues')
plt.tight_layout()
# Saves the current matplotlib figure (the clustermap just drawn).
plt.savefig(os.path.join(analysis_dir, 'figures', 'topic_heatmap.pdf'))

In [None]:
# Hand-assigned names for the 10 LDA topics, keyed by topic index.
# NOTE(review): these labels are tied to one particular model fit — presumably read off
# the topic heatmap above; re-check them whenever the LDA model is refit.
topic_map = {
    0: 'Immune - M2 macrophage dominant',
    1: 'Mixed Stroma/Immune',
    2: 'Endothelial',
    3: 'Tumor boundary - Immune Low',
    4: 'Immune - M1 macrophage dominant',
    5: 'Tumor boundary - Immune High',
    6: 'Normal Duct',
    7: 'Immune - T cell High',
    8: 'Stroma',
    9: 'Tumor'
}


In [None]:
# Copy per-cell topic weights and k-means labels from `lda_adata` onto each sample.
for sample, a in sample_to_adata.items():
    # Slice once per sample; nothing below depends on an individual topic.
    sub = lda_adata[a.obs.index]
    a.obsm['LDA_topics'] = sub.X
    a.obs['LDA_kmeans_cluster'] = sub.obs['LDA_kmeans_cluster'].to_list()
    # Same weights as 'LDA_topics', with human-readable topic names as column labels.
    a.obsm['LDA_topics_named'] = pd.DataFrame(data=sub.X,
                                              columns=[topic_map[c] for c in range(lda_adata.shape[1])],
                                              index=a.obs.index.to_list())
    a.uns['topic_map'] = topic_map

In [None]:
def visualize_roi(adata, scale=1000, size=10, color='harmonized_cell_type'):
    """Scatter-plot every cell of a sample at its centroid, colored by an `obs` column.

    Parameters
    ----------
    adata
        AnnData with `centroid_col`, `centroid_row` and `centroid_row_inverted` in `.obs`.
    scale
        Divisor turning the maximum centroid coordinates into figure inches.
    size
        Marker size passed to `sc.pl.scatter`.
    color
        Name of the `.obs` column used for coloring.
    """
    # Clamp to at least one inch per side: for ROIs smaller than `scale` the integer
    # division yields 0, which matplotlib rejects as a figure size.
    width = max(1, int(max(adata.obs['centroid_col']) / scale))
    height = max(1, int(max(adata.obs['centroid_row']) / scale))
    fig, ax = plt.subplots(figsize=(width, height))
    sc.pl.scatter(adata, x='centroid_col', y='centroid_row_inverted',
                  color=color, size=size, ax=ax)
    
def visualize_topics(adata, size=2, n_cols=5, scale=3):
    """Draw one spatial scatter panel per LDA topic, colored by that topic's weight.

    Topics are read from the `adata.obsm['LDA_topics_named']` DataFrame (one column
    per topic, column label = topic name), so the panel count and titles always
    match what is stored on the object.

    Parameters
    ----------
    adata
        AnnData with `centroid_col`, `centroid_row`, `centroid_row_inverted` in `.obs`
        and a `LDA_topics_named` DataFrame in `.obsm`.
    size
        Marker size for the scatter points.
    n_cols
        Number of panels per row.
    scale
        Width of one panel in inches; panel height follows the ROI aspect ratio.
    """
    topics = adata.obsm['LDA_topics_named']
    names = list(topics.columns)
    n_topics = len(names)

    # Ceiling division (no empty extra row when n_topics is a multiple of n_cols);
    # keep at least one row so the figure is always valid.
    n_rows = max(1, -(-n_topics // n_cols))
    ratio = max(adata.obs['centroid_row']) / max(adata.obs['centroid_col'])

    # squeeze=False keeps `axs` 2-D even for a single row or column.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(n_cols * scale, n_rows * (ratio * scale)),
                            squeeze=False)

    for i, ax in enumerate(axs.flat):
        if i >= n_topics:
            # Hide panels that have no topic to show.
            ax.axis('off')
            continue
        ax.scatter(adata.obs['centroid_col'], adata.obs['centroid_row_inverted'],
                   s=size, c=topics.iloc[:, i])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(names[i])

In [None]:
# Topic maps for the first sample in `sample_to_adata`.
visualize_topics(next(iter(sample_to_adata.values())))

In [None]:
# Spatial plot of every sample, colored by the default column (harmonized cell type).
for s, a in sample_to_adata.items():
    print(s)
    visualize_roi(a)

In [None]:
# Spatial plot of every sample, colored by LDA k-means cluster.
for s, a in sample_to_adata.items():
    print(s)
    visualize_roi(a, color='LDA_kmeans_cluster')

Export the k-means cluster assignments as QiTissue cluster files (one CSV per sample).

In [None]:
# One CSV per sample: index 'CellID' (the original per-sample cell id), single column 'Cluster'.
for sample, a in sample_to_adata.items():
    directory = os.path.join(analysis_dir, 'qitissue')
    Path(directory).mkdir(parents=True, exist_ok=True)
    df = (a.obs[['LDA_kmeans_cluster', 'cell_id']]
          .set_index('cell_id')
          .rename_axis('CellID')
          .rename(columns={'LDA_kmeans_cluster': 'Cluster'}))
    df.to_csv(os.path.join(directory, f'kmeans_cluster_file_{sample}.csv'))

In [None]:
# Display the topic index -> name mapping for reference.
topic_map

In [None]:
# Same matrix plot as before, on a copy whose variables carry the topic names
# instead of the topic indices.
p = lda_adata.copy()
p.var.index = [topic_map[int(i)] for i in lda_adata.var.index]
sc.pl.matrixplot(p, var_names=p.var.index, groupby='LDA_kmeans_cluster', dendrogram=True,
                 save='lda_kmeans_cluster_named.pdf')

In [None]:
# Topic-by-token heatmap again, with rows labeled by topic name rather than index.
topic_df = pd.DataFrame(data=model.get_topics(), columns=[dictionary.get(i) for i in range(len(dictionary))],
                        index=[topic_map[i] for i in np.arange(num_topics)])
import seaborn as sns
sns.clustermap(topic_df, cmap='Blues')
plt.tight_layout()
# Saves the current matplotlib figure (the clustermap just drawn).
plt.savefig(os.path.join(analysis_dir, 'figures', 'topic_heatmap_named.pdf'))

In [None]:
# Hand-curated grouping of the 20 k-means clusters (by integer label) into named metaclusters.
# NOTE(review): like `topic_map`, this is tied to one specific k-means/LDA fit — presumably
# derived from the matrix plots above; re-check after any refit.
metacluster_to_cluster = {
    'Tumor Boundary - Stroma Enriched 1': [13],
    'Tumor': [1, 14],
    'Myoepithelial': [4, 10, 17],
    'Tumor Boundary - M2 Macrophage Enriched': [0],
    'Tumor Boundary - Mixed Immune Enriched': [16],
    'Tumor Boundary - Stroma Enriched 2': [8, 3],
    'Macrophage M1 enriched': [2, 15],
    'Tumor Boundary - Mixed': [9],
    'Mixed Immune/Stroma': [18, 6, 5],
    'Stroma': [7],
    'T cell Enriched': [11],
    'Endothelial': [19],
    'Macrophage M2 - enriched': [12],

}
# Invert to {cluster label as string: metacluster name}; string keys match the
# string labels stored in 'LDA_kmeans_cluster'.
cluster_to_metacluster = {str(v):k for k, vs in metacluster_to_cluster.items() for v in vs}
sorted(cluster_to_metacluster.items())

In [None]:
# Per-cell lookups: cell id -> k-means label (as string) -> metacluster name.
cell_to_kmeans = dict(zip(cells, map(str, kmeans.labels_)))
cell_to_metacluster = {c: cluster_to_metacluster[k] for c, k in cell_to_kmeans.items()}

# Attach the LDA embedding and the metacluster label to every sample.
for s, a in sample_to_adata.items():
    ids = a.obs.index.to_list()
    a.obsm['X_lda'] = lda_adata[ids].X
    a.obs['metacluster'] = [cell_to_metacluster[c] for c in ids]

In [None]:
# Spatial plot of every sample, colored by metacluster.
for s, a in sample_to_adata.items():
    print(s)
    visualize_roi(a, color='metacluster')

In [None]:
# Alphabetical list of ALL defined metaclusters; position + 1 is the integer id used in
# the exported files. Built from the full mapping rather than from one sample, so the ids
# are identical for every sample and every label a sample can contain is present.
order = sorted(metacluster_to_cluster)
[(i + 1, o) for i, o in enumerate(order)]

In [None]:
# One CSV per sample: index 'CellID', column 'Cluster' holding the 1-based position of
# the cell's metacluster in `order`.
for sample, a in sample_to_adata.items():
    df = (a.obs[['metacluster', 'cell_id']]
          .set_index('cell_id')
          .rename_axis('CellID')
          .rename(columns={'metacluster': 'Cluster'}))
    df['Cluster'] = [order.index(x) + 1 for x in df['Cluster']]
    directory = os.path.join(analysis_dir, 'qitissue')
    Path(directory).mkdir(parents=True, exist_ok=True)
    df.to_csv(os.path.join(directory, f'metacluster_{sample}.csv'))

In [None]:
# Save each annotated sample under <analysis_dir>/metaclustered.
for s, a in sample_to_adata.items():
    directory = os.path.join(analysis_dir, 'metaclustered')
    Path(directory).mkdir(parents=True, exist_ok=True)
    # topic map wont save with int keys, so store them as strings
    a.uns['topic_map'] = {str(k): v for k, v in a.uns['topic_map'].items()}
    a.write_h5ad(os.path.join(directory, f'{s}.h5ad'))

In [None]:
from collections import Counter
# Samples x metaclusters table of cell counts (0 where a sample has no such cells).
cols = sorted(set(cell_to_metacluster.values()))
data, idxs = [], []
for s, a in sample_to_adata.items():
    tally = Counter(a.obs['metacluster'])
    idxs.append(s)
    data.append([tally.get(name, 0) for name in cols])
df = pd.DataFrame(data=data, index=idxs, columns=cols)
df

In [None]:
# Stacked bar chart of absolute cell counts per sample; legend placed outside the axes.
ax = df.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
# Same chart with each sample's counts normalized to fractions of its row total.
fractions = df.div(df.sum(axis=1), axis=0)
ax = fractions.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
# Samples x cell types table of cell counts (0 where a sample has no such cells).
data, idxs = [], []
# Take the union of cell types over ALL samples. Reading them from whichever sample a
# previous loop left bound to `a` would silently drop types absent from that one sample.
cols = sorted(set().union(*(a.obs['harmonized_cell_type'].unique()
                            for a in sample_to_adata.values())))
for s, a in sample_to_adata.items():
    counts = Counter(a.obs['harmonized_cell_type'])
    data.append([counts.get(c, 0) for c in cols])
    idxs.append(s)
df = pd.DataFrame(data=data, index=idxs, columns=cols)
df

In [None]:
# Stacked bar chart of absolute cell-type counts per sample; legend placed outside the axes.
ax = df.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
# Same chart with each sample's counts normalized to fractions of its row total.
fractions = df.div(df.sum(axis=1), axis=0)
ax = fractions.plot(kind='bar', stacked=True, color=sns.color_palette('tab20'))
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

For the spatial analyses to work, the AnnData objects need to be formatted the way squidpy expects: cell coordinates stored in `obsm['spatial']`.

In [None]:
# Store (x, y) = (centroid_col, centroid_row) coordinates under obsm['spatial'].
for s, a in sample_to_adata.items():
    a.obsm['spatial'] = a.obs[['centroid_col', 'centroid_row']].values

In [None]:
# For every sample: build the spatial neighbor graph, then compute the interaction matrix
# and co-occurrence scores for both labelings (metacluster and harmonized cell type),
# and save the result under <analysis_dir>/distance_metrics.
directory = os.path.join(analysis_dir, 'distance_metrics')
Path(directory).mkdir(parents=True, exist_ok=True)
for s, a in sample_to_adata.items():
    print(s)
    sq.gr.spatial_neighbors(a, key_added='spatial')
    sq.gr.interaction_matrix(a, cluster_key="metacluster")
    # NOTE(review): `interval` values are presumably distances in centroid coordinate
    # units, and n_jobs=40 assumes a large machine — confirm both.
    sq.gr.co_occurrence(a, cluster_key="metacluster", n_splits=1, n_jobs=40,
                        interval=[32, 64, 128, 256])
    sq.gr.interaction_matrix(a, cluster_key="harmonized_cell_type")
    sq.gr.co_occurrence(a, cluster_key="harmonized_cell_type", n_splits=1, n_jobs=40,
                        interval=[32, 64, 128, 256])
    a.write_h5ad(os.path.join(directory, f'{s}.h5ad'))

In [None]:
# sq.pl.interaction_matrix(a, cluster_key="metacluster", vmax=10000)

In [None]:
# Co-occurrence curves relative to the 'Tumor' metacluster, one figure per sample.
for s, a in sample_to_adata.items():
    sq.pl.co_occurrence(
        a,
        cluster_key="metacluster",
        clusters=["Tumor"],
        figsize=(15, 4),
    )
    # Title is applied to the current axes of the figure squidpy just drew.
    plt.title(s)
    plt.show()

In [None]:
# Co-occurrence curves relative to the 'Tumor' cell type, one figure per sample.
for s, a in sample_to_adata.items():
    sq.pl.co_occurrence(
        a,
        cluster_key="harmonized_cell_type",
        clusters=["Tumor"],
        figsize=(15, 4),
    )
    # Title is applied to the current axes of the figure squidpy just drew.
    plt.title(s)
    plt.show()

In [None]:
# Per sample: among Tumor cells, one ROC curve per metacluster, scoring how well the raw
# MGP value separates cells of that metacluster from the remaining Tumor cells.
palette = sns.color_palette('tab20')
for s, a in sample_to_adata.items():
    f = a[a.obs['harmonized_cell_type']=='Tumor']
    scores = f.raw[:, 'MGP'].X.flatten()
    for i, metacluster in enumerate(order):
        # One-vs-rest labels for the current metacluster.
        y_true = [int(m == metacluster) for m in f.obs['metacluster']]
        fpr, tpr, thresholds = roc_curve(y_true, scores)
        plt.plot(fpr, tpr, label=metacluster, c=palette[i])
    plt.title(s)
    plt.legend(bbox_to_anchor=(1,1), loc="upper left")
    plt.show()

In [None]:
# Collect the raw MGP value of every Tumor cell together with its sample and metacluster,
# then compare the distributions per sample and metacluster.
frames = []
for s, a in sample_to_adata.items():
    # Copy so the new obs column is written to a real object, not a view.
    f = a[a.obs['harmonized_cell_type']=='Tumor'].copy()
    f.obs['tumor_MGP'] = f.raw[:, 'MGP'].X.flatten()
    frames.append(f.obs[['sample_id', 'tumor_MGP', 'harmonized_cell_type', 'metacluster']])

# Concatenate once at the end: growing a DataFrame inside the loop re-copies all
# previously collected rows on every iteration.
data = pd.concat(frames, axis=0)

fig, ax = plt.subplots(figsize=(14, 8))
sns.boxplot(data=data, x='sample_id', y='tumor_MGP', hue='metacluster', ax=ax, palette=sns.color_palette('tab20'))
plt.xticks(rotation=90)
plt.legend(bbox_to_anchor=(1,1), loc="upper left")
plt.show()

In [None]:
# Overlay the distribution of raw MGP values in Tumor cells for every sample, using a
# fixed-size random subsample (drawn with replacement) so samples are comparable.
# The generator is seeded so the figure is reproducible across reruns.
rng = np.random.default_rng(0)
for s, a in sample_to_adata.items():
    ls = a.raw[a.obs['harmonized_cell_type']=='Tumor', 'MGP'].X.flatten()
    ls = rng.choice(ls, size=5000)
    # NOTE(review): `sns.distplot` is deprecated in recent seaborn releases; consider
    # migrating to `sns.histplot`/`sns.kdeplot` once the installed version is confirmed.
    sns.distplot(ls, label=s)
# Draw the legend once, after all samples have been added.
plt.legend()

In [None]:
from skimage.segmentation import find_boundaries
import mip.utils as utils
from mip.gating import gate_region
# For each sample, load the MGP channel image and derive a cell-boundary image from the
# segmentation labels; both are kept in memory for gating below.
base_dir = '/diskmnt/Projects/Users/estorrs/multiplex_data/codex/htan/brca'
sample_to_mgp_img = {}
sample_to_boundary_img = {}
for sample, a in sample_to_adata.items():
    print(sample)
    # NOTE(review): `extract_ome_tiff` is a project helper; it is used here as a
    # channel-name -> image mapping — confirm it loads every channel.
    channel_to_img = utils.extract_ome_tiff(os.path.join(base_dir, sample, 'level_2', f'{sample}.ome.tiff'))
    seg_img = tifffile.imread(os.path.join(base_dir, sample, 'level_3', 'segmentation', 'cell_segmentation.tif'))
    boundary_img = find_boundaries(seg_img)
    
    # Copy the one channel that is needed so the full channel dict can be released.
    sample_to_mgp_img[sample] = channel_to_img['MGP'].copy()
    sample_to_boundary_img[sample] = boundary_img
    
    del channel_to_img        

In [None]:
# Inspect MGP gating for one sample, restricted to Tumor cells.
# NOTE(review): `gate_region` is a project helper; presumably it visualizes a region with
# cells above `default_value` highlighted to help pick a threshold — confirm.
sample = 'HT323B1-H3'
gate_region(sample_to_adata[sample], 'MGP', channel_img=sample_to_mgp_img[sample],
            boundary_img=sample_to_boundary_img[sample],
            cell_type='Tumor', cell_type_col='harmonized_cell_type',
            default_value=7., radius=2000)

In [None]:
# Per-sample cutoff on the raw MGP value; Tumor cells at or above it are called MGP positive.
# NOTE(review): values presumably chosen by manual gating with `gate_region` — confirm provenance.
mgp_thresholds = {
    'HT206B1-H1': 8.57,
    'HT206B1_H1_06252022': 8.3,
    'HT323B1-H1A1': 8.36,
    'HT323B1-H1A4': 8.36,
    'HT323B1-H3': 8.75,
    'HT397B1-H2A2': 7.8,
    'HT397B1-H3A1': 7.25
}

In [None]:
# Flag a cell as MGP positive when it is a Tumor cell whose raw MGP value reaches the
# sample's threshold; every other cell is False.
for sample, val in mgp_thresholds.items():
    a = sample_to_adata[sample]
    mgp_values = a.raw[:, 'MGP'].X.flatten()
    a.obs['is_mgp_positive'] = [
        bool(ct == 'Tumor' and x >= mgp_thresholds[sample])
        for ct, x in zip(a.obs['harmonized_cell_type'], mgp_values)
    ]

In [None]:
# Spatial check of the MGP-positive calls for one sample.
visualize_roi(sample_to_adata['HT206B1-H1'], color='is_mgp_positive')

In [None]:
def call_mgp_cell(is_mgp, ct, m):
    """Return the label for one cell, splitting Tumor cells by MGP status.

    Tumor cells (`ct == 'Tumor'`) become 'Tumor - MGP positive' or
    'Tumor - MGP negative' depending on `is_mgp`; any other cell keeps `m`.
    """
    if ct != 'Tumor':
        return m
    return 'Tumor - MGP positive' if is_mgp else 'Tumor - MGP negative'
    
# Derive two categorical labelings in which Tumor cells are split by MGP status:
# one based on the metacluster, one based on the harmonized cell type.
for s, a in sample_to_adata.items():
    mgp_flags = a.obs['is_mgp_positive']
    cell_types = a.obs['harmonized_cell_type']

    a.obs['metacluster_with_mgp'] = pd.Categorical(
        [call_mgp_cell(flag, ct, m)
         for flag, ct, m in zip(mgp_flags, cell_types, a.obs['metacluster'])])

    a.obs['harmonized_cell_type_with_mgp'] = pd.Categorical(
        [call_mgp_cell(flag, ct, ct)
         for flag, ct in zip(mgp_flags, cell_types)])
    

In [None]:
# Spatial check of the MGP-aware metacluster labels for one sample.
visualize_roi(sample_to_adata['HT206B1-H1'], color='metacluster_with_mgp')

In [None]:
# Spatial check of the MGP-aware cell-type labels for one sample.
visualize_roi(sample_to_adata['HT206B1-H1'], color='harmonized_cell_type_with_mgp')

In [None]:
# Recompute interaction matrices and co-occurrence scores for the MGP-aware labelings and
# overwrite each sample's file in <analysis_dir>/distance_metrics.
# NOTE: relies on the spatial neighbor graph and the output directory created by the
# earlier distance-metrics cell.
directory = os.path.join(analysis_dir, 'distance_metrics')
for s, a in sample_to_adata.items():
    print(s)
    sq.gr.interaction_matrix(a, cluster_key="metacluster_with_mgp")
    sq.gr.co_occurrence(a, cluster_key="metacluster_with_mgp", n_splits=1, n_jobs=40,
                        interval=[32, 64, 128, 256])
    sq.gr.interaction_matrix(a, cluster_key="harmonized_cell_type_with_mgp")
    sq.gr.co_occurrence(a, cluster_key="harmonized_cell_type_with_mgp", n_splits=1, n_jobs=40,
                        interval=[32, 64, 128, 256])
    a.write_h5ad(os.path.join(directory, f'{s}.h5ad'))

In [None]:
# Pick one sample for closer inspection; note this rebinds the loop variable `a`.
a = sample_to_adata['HT206B1-H1']
a

In [None]:
# Interaction matrix for the selected sample, color scale capped at 10000.
sq.pl.interaction_matrix(a, cluster_key="metacluster_with_mgp", vmax=10000)

In [None]:
# Interaction matrix for every sample on a shared color scale (capped at 50000).
for sample, a in sample_to_adata.items():
    sq.pl.interaction_matrix(a, cluster_key="metacluster_with_mgp", vmax=50000)
    # Title is applied to the current axes of the figure squidpy just drew.
    plt.title(sample)
    plt.show()