In [None]:
# Widen the notebook container to the full browser width.
# `display`/`HTML` come from the public `IPython.display` module; importing `display`
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import cv2

import scanpy as sc

import cuml
from micron2.spatial import get_neighbors, categorical_neighbors, sliding_window_niches, k_neighbor_niches
from micron2.clustering import cluster_leiden, run_tsne, plot_embedding, cluster_leiden_cu

from sklearn.cluster import MiniBatchKMeans

import scrna
import tqdm.auto as tqdm

In [None]:
# List the dataset files available on disk (IPython shell escape; output is informational only).
!ls /storage/codex/datasets_v1

In [None]:
# Load the merged bladder dataset (AnnData) and show its summary repr.
h5ad_path = "/storage/codex/datasets_v1/bladder_merged_v6.h5ad"
adata = sc.read_h5ad(h5ad_path)
adata

In [None]:
# Number of cells per sample.
sample_counts = adata.obs['sample_id'].value_counts()
sample_counts

In [None]:
# Number of cells per `subtype_gating` label.
gating_counts = adata.obs['subtype_gating'].value_counts()
gating_counts

# Neighborhood niche profiles across all samples in the dataset

In [None]:
# Per-cell neighborhood profiles, computed separately for each sample.
# Each cell's `subtype_rescued` label is encoded as an integer code (cluster_levels holds
# the sorted labels). k_neighbor_niches is called per sample with k=10, max_dist=100,
# aggregate='sum'. Rows are concatenated sample-by-sample, and OBS_reshuffle records the
# obs_names in that concatenated order, so later cells can realign to adata.obs_names.
cluster_levels, all_clusters = np.unique(adata.obs['subtype_rescued'].copy(), return_inverse=True)

all_neighbor_profiles = []
obs_reshuffle = []
with tqdm.tqdm(np.unique(adata.obs.sample_id)) as pbar:
    for sample in pbar:
        pbar.set_description(sample)
        in_sample = (adata.obs.sample_id == sample).to_numpy()
        obs_reshuffle.append(np.array(adata.obs_names[in_sample]))

        # Explicit copy so the in-place y-flip below can never modify adata.obsm.
        sample_coords = adata.obsm['coordinates'][in_sample].copy()
        sample_coords[:, 1] = -sample_coords[:, 1]  # flip the height coordinate

        profiles = k_neighbor_niches(sample_coords, all_clusters[in_sample], k=10,
                                     u_clusters=np.arange(len(cluster_levels)),
                                     aggregate='sum', max_dist=100)
        all_neighbor_profiles.append(profiles.copy())

all_neighbor_profiles = np.concatenate(all_neighbor_profiles, axis=0)
OBS_reshuffle = np.concatenate(obs_reshuffle)

In [None]:
# Cluster the neighborhood profiles into spatial niches (fixed random_state for reproducibility).
N_NICHES = 15
MBKM = MiniBatchKMeans(n_clusters=N_NICHES, batch_size=1000, n_init=10, random_state=999)
niches = MBKM.fit_predict(all_neighbor_profiles)

# Alternative clustering kept for reference (not used):
# niches = cluster_leiden_cu(all_neighbor_profiles, resolution=0.3, nn_metric='correlation')

# Relabel to contiguous codes 0..len(niche_levels)-1. After this line `niches` holds the
# new codes, while `niche_levels` holds the original KMeans labels. The two differ
# whenever some label is unused, so count by code rather than by original label.
niche_levels, niches = np.unique(niches, return_inverse=True)
print(len(niche_levels))
niche_sizes = np.bincount(niches, minlength=len(niche_levels))
for code, size in enumerate(niche_sizes):
    print(f'{code}: {size}')

In [None]:
# Attach niche codes to adata.obs. Codes are in sample-by-sample (OBS_reshuffle) order,
# so realign them to adata.obs_names explicitly instead of relying on implicit
# DataFrame-to-column alignment.
niche_by_cell = pd.Series(niches, index=OBS_reshuffle).reindex(adata.obs_names)
assert not niche_by_cell.isna().any(), 'some cells have no niche assignment'
# Zero-padded string labels ('00', '01', ...) so categories sort in numeric order.
adata.obs['cell_niches'] = pd.Categorical(
    [f'{int(x):02d}' for x in niche_by_cell.to_numpy()])

In [None]:
# Store neighborhood profiles in adata, realigned from the sample-by-sample order
# (OBS_reshuffle) to adata.obs_names. Columns are the subtype levels (cluster_levels).
profiles_df = pd.DataFrame(
    all_neighbor_profiles, index=OBS_reshuffle, columns=cluster_levels
).loc[adata.obs_names]
adata.obsm['niche_profiles'] = profiles_df.to_numpy()
adata.uns['niche_profiles_colnames'] = cluster_levels
profiles_df

In [None]:
# Spatial map colored by niche. Figure width is scaled by the ratio of the largest
# absolute x and y shifted coordinates so the map keeps its aspect.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig, ax = plt.subplots(figsize=(r * 7, 7), dpi=300)
sc.pl.embedding(adata, basis='coordinates_shift', color='cell_niches',
                ax=ax, s=0.5, palette='tab20')

In [None]:
# Manually override selected entries of the subtype palette in place. Positions index
# into adata.uns['subtype_rescued_colors'] (-1 is the last entry).
# NOTE(review): assumes 'subtype_rescued_colors' already exists in adata.uns (from the
# loaded file or an earlier plot) — confirm on a fresh kernel.
palette_overrides = [(4, "#800410"), (8, "#6e597a"), (10, "#91078f"),
                     (12, "#cc510a"), (-1, "#1fd1ce")]
for position, hex_color in palette_overrides:
    adata.uns['subtype_rescued_colors'][position] = hex_color

In [None]:
# Spatial map colored by subtype_rescued. Width is scaled by the x/y ratio of the
# largest absolute shifted coordinates.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig, ax = plt.subplots(figsize=(r * 10, 10), dpi=300)
sc.pl.embedding(adata, basis='coordinates_shift', color='subtype_rescued',
                ax=ax, s=1)

In [None]:
# Stray `scrna.plot_group_percents` expression removed: it only echoed the function
# object's repr (leftover interactive inspection). Use help(scrna.plot_group_percents)
# to read its documentation.

In [None]:
# Niche composition by subtype, and subtype occupation by niche (scrna.plot_group_percents).
# Build the subtype -> color map locally so this cell does not depend on a name defined
# in a later cell. Pair each subtype category with its palette entry.
subtype_categories = pd.Categorical(adata.obs['subtype_rescued']).categories
subtype_colors = dict(zip(subtype_categories, adata.uns['subtype_rescued_colors']))

scrna.plot_group_percents(adata, 'subtype_rescued', 'cell_niches',
                          title='niche composition',
                          colors=subtype_colors,
                          annotate_total=False)

scrna.plot_group_percents(adata, 'cell_niches', 'subtype_rescued', title='niche occupation',
                          annotate_total=False)

In [None]:
from matplotlib import rcParams
rcParams['figure.facecolor'] = (1, 1, 1, 1)  # opaque white figure background

# Sorted subtype labels; these fix the column order of the niche x subtype table.
u_celltypes = np.unique(adata.obs.subtype_rescued)


def count_fn(x):
    """Return the fraction of rows in `x` belonging to each subtype.

    Parameters
    ----------
    x : pandas.DataFrame
        Rows of one group. Must have a 'subtype_rescued' column whose values all
        appear in the global `u_celltypes`.

    Returns
    -------
    pandas.Series
        Fractions indexed by `u_celltypes` (0 for subtypes absent from `x`).
    """
    labels, counts = np.unique(x['subtype_rescued'], return_counts=True)
    # `np.int` was removed in NumPy 1.24. Use a float buffer and place each count at its
    # label's position in the sorted `u_celltypes`.
    fractions = np.zeros(len(u_celltypes), dtype=float)
    fractions[np.searchsorted(u_celltypes, labels)] = counts
    return pd.Series(fractions / counts.sum(), index=u_celltypes)


# Niche x subtype table of subtype fractions within each niche (rows sum to 1).
cols = ['cell_niches', 'subtype_rescued']
df = adata.obs.loc[:, cols].copy()
# observed=True: only niches that actually contain cells (avoids empty groups).
df = df.groupby('cell_niches', observed=True).apply(count_fn)

sns.clustermap(df, square=True, yticklabels=True, lw=1, cmap='Reds',
               dendrogram_ratio=(0.2, 0.05),
               annot=True,
               fmt='2.2f'
               )


In [None]:
# For each subtype, print the niche where it makes up the largest fraction
# (column-wise argmax of the niche x subtype fraction table `df`).
top_niche = df.index.to_numpy()[df.to_numpy().argmax(axis=0)]
for celltype, niche in zip(u_celltypes, top_niche):
    print(celltype, niche)

In [None]:
# Pie chart of subtype composition within one niche. Colors are looked up per label, so
# they stay matched to subtypes when some subtypes are absent from the niche. Reusing
# the full palette positionally would shift colors in that case.
niche_id = '00'
subtype_palette = dict(zip(pd.Categorical(adata.obs['subtype_rescued']).categories,
                           adata.uns['subtype_rescued_colors']))
in_niche = adata.obs['cell_niches'] == niche_id
l, u = np.unique(adata.obs.loc[in_niche, 'subtype_rescued'], return_counts=True)
_ = plt.pie(u, labels=l, autopct='%2.2f', pctdistance=2.3, rotatelabels=True,
            colors=[subtype_palette[label] for label in l])

In [None]:
# Number of neighborhood-profile columns (one per subtype level).
n_profile_cols = len(adata.uns['niche_profiles_colnames'])
n_profile_cols

In [None]:
# Manual annotation of the KMeans niches. Keys are the integer niche codes (0..14 for
# the 15-cluster run); several codes deliberately share a group label.
niche_groups = {
    0: 'Epithelial-core',
    1: 'EpCDH-Mac',
    2: 'EpKRT',
    3: 'Stromal-Mac',
    4: 'Ep-EpCDH',
    5: 'Bcell',
    6: 'EpKRT',
    7: 'Mac-Ep',
    8: 'EpCDH',
    9: 'Stromal',
    10: 'Epithelial-core',
    11: 'EpCDH-Ep',
    12: 'EpCDH-Stromal',
    13: 'Ep-Stromal',
    14: 'Immune-mix',
}
# cell_niches holds zero-padded strings ('00', '01', ...), so convert to int before the
# lookup. Indexing the int-keyed dict with the raw strings raises KeyError.
unlabeled = {n for n in adata.obs['cell_niches'].unique() if int(n) not in niche_groups}
assert not unlabeled, f'niches without a label: {sorted(unlabeled)}'
niche_groups_v = [niche_groups[int(n)] for n in adata.obs['cell_niches']]
adata.obs['niche_labels'] = pd.Categorical(niche_groups_v)

In [None]:
# Spatial map colored by the annotated niche label. Width is scaled by the x/y ratio
# of the largest absolute shifted coordinates.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig, ax = plt.subplots(figsize=(r * 15, 15), dpi=300)
sc.pl.embedding(adata, basis='coordinates_shift', color='niche_labels',
                ax=ax, s=2)

In [None]:
# Override one niche-label color in place. Position -5 counts from the end of the palette.
# NOTE(review): the palette follows the sorted label categories; with the current
# niche_groups this entry appears to be 'Epithelial-core'. Indexing by label name would
# be more robust — confirm.
label_palette = adata.uns['niche_labels_colors']
label_palette[-5] = '#65a8a2'

In [None]:
# Clustermap of the niche x subtype fractions, with a side bar coloring each niche row
# by its niche_labels group. seaborn reindexes row_colors by df.index, so row_colors
# must use the same keys ('00', '01', ...). An integer 0..19 index matches no row and
# every color falls back to white.
label_to_color = dict(zip(np.unique(niche_groups_v), adata.uns['niche_labels_colors']))
row_color_values = []
for niche_id in df.index:
    label = niche_groups[int(niche_id)]
    color = label_to_color[label]
    print(niche_id, label, color)
    row_color_values.append(color)
row_colors = pd.DataFrame({'label': row_color_values}, index=df.index)

sns.clustermap(df, square=True, yticklabels=True, lw=1, cmap='Reds',
               dendrogram_ratio=(0.2, 0.05),
               annot=True,
               fmt='2.2f',
               row_colors=row_colors,
               figsize=(6, 10)
               )

In [None]:
# Show the documentation of scrna.plot_group_percents (interactive lookup of its arguments).
help(scrna.plot_group_percents)

In [None]:
# Per-sample niche-label composition; sort_by='Immune-mix' presumably orders samples by
# that label's share (see help above). Each sorted label is paired with its palette entry.
niche_label_levels = np.unique(adata.obs.niche_labels)
colors = dict(zip(niche_label_levels, adata.uns['niche_labels_colors']))
scrna.plot_group_percents(adata, 'niche_labels', 'sample_id', sort_by='Immune-mix',
                          colors=colors, ncol=1)

In [None]:
# One pie per subtype showing its summed non-epithelial neighborhood profile (profile
# columns whose name contains 'Epithelial' are dropped). The grid is sized from the
# number of subtypes, so it no longer overflows a fixed 5x3 layout. Unused axes are hidden.
subtypes = np.unique(adata.obs.subtype_rescued)
n_cols = 3
n_rows = max(1, int(np.ceil(len(subtypes) / n_cols)))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=180,
                        squeeze=False)
axs = axs.ravel()
profile_columns = adata.uns['niche_profiles_colnames']
for ax, s in zip(axs, subtypes):
    in_subtype = (adata.obs.subtype_rescued == s).to_numpy()
    d = pd.DataFrame(adata.obsm['niche_profiles'][in_subtype], columns=profile_columns)
    d = d.loc[:, ~d.columns.str.contains('Epithelial')]
    p = d.sum(axis=0)
    ax.set_title(s)
    if p.sum() == 0:
        # Nothing to draw: no non-epithelial neighbors for this subtype.
        ax.axis('off')
        continue
    _ = ax.pie(p, labels=p.index, rotatelabels=True,
               labeldistance=0.7)
for ax in axs[len(subtypes):]:
    ax.axis('off')


In [None]:
# One pie per epithelial subtype, showing its summed neighborhood profile over all
# profile columns (epithelial neighbors included).
epithelial_subtypes = ['Epithelial', 'Epithelial_KRT', 'Epithelial_CDH']
fig, axs = plt.subplots(1, 3, figsize=(3 * 3, 1 * 3), dpi=180)
axs = axs.ravel()
profile_columns = adata.uns['niche_profiles_colnames']
for ax, s in zip(axs, epithelial_subtypes):
    subtype_rows = adata.obsm['niche_profiles'][adata.obs.subtype_rescued == s]
    p = pd.DataFrame(subtype_rows, columns=profile_columns).sum(axis=0)
    _ = ax.pie(p, labels=p.index, rotatelabels=True,
               labeldistance=0.9)
    ax.set_title(s)