In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from matplotlib import rcParams 
from matplotlib import pyplot as plt 
import seaborn as sns

import scanpy as sc
import pandas as pd
import numpy as np
import os

import cuml
from micron2.spatial import get_neighbors, categorical_neighbors, sliding_window_niches, k_neighbor_niches
from micron2.clustering import cluster_leiden, run_tsne, plot_embedding, cluster_leiden_cu

from sklearn.cluster import MiniBatchKMeans

In [None]:
# List the contents of the dataset directory (shell escape).
!ls /storage/codex/datasets_v1

In [None]:
# Read the merged bladder dataset from disk.
adata = sc.read_h5ad("/storage/codex/datasets_v1/bladder_merged_v5.h5ad")

# Integer-encode the gated cell-type labels: `cluster_levels` holds the sorted
# unique labels, and `all_clusters[i]` is the position in `cluster_levels` of
# the label carried by observation i.
celltype_labels = np.array(adata.obs['celltype_gating'])
cluster_levels, all_clusters = np.unique(celltype_labels, return_inverse=True)

# Bare expression so the object summary is rendered as the cell output.
adata

In [None]:
# Count the observations (rows of `adata.obs`) belonging to each sample_id.
adata.obs['sample_id'].value_counts()

In [None]:
# RGBA (1,1,1,1): fully opaque white figure background.
rcParams['figure.facecolor'] = (1,1,1,1)

# Number of k-means clusters ("niches") fitted on the neighborhood profiles.
n_niches = 10

# Restrict the dataset to the reference sample; the same boolean mask selects
# both the AnnData rows and the matching integer cluster codes.
ref_sample = '210226_Bladder_TMA1_reg30_v4'
ref_mask = adata.obs['sample_id'] == ref_sample
ad = adata[ref_mask].copy()
coords = ad.obsm['coordinates']
clusters = all_clusters[ref_mask]

# Build one neighborhood profile per observation over all cluster levels.
# NOTE(review): presumably this aggregates (sums) the cluster labels of each
# cell's k=20 nearest neighbors within max_dist=200 -- confirm against
# micron2.spatial.k_neighbor_niches.
neighbor_profiles = k_neighbor_niches(
    coords,
    clusters,
    k=20,
    u_clusters=np.arange(len(cluster_levels)),
    aggregate='sum',
    max_dist=200,
    backend='sklearn',
)

# Cluster the profiles into `n_niches` groups; random_state is fixed so the
# niche labels are reproducible across runs.
MBKM = MiniBatchKMeans(
    n_clusters=n_niches,
    batch_size=1000,
    n_init=10,
    random_state=999,
)
niches = MBKM.fit_predict(neighbor_profiles)

# Plot the reference sample in its spatial coordinates, colored by niche label.
rcParams.update({'figure.figsize': (4,4), 'figure.dpi': 180})
ad.obs['niches'] = pd.Categorical(niches)
sc.pl.embedding(ad, basis='coordinates', color='niches', s=4)

In [None]:
# Switch the working subset (`ad`, `coords`, `clusters`) to the query sample.
q_sample = '210226_Bladder_TMA1_reg28_v4'
q_mask = adata.obs['sample_id'] == q_sample
ad = adata[q_mask].copy()
coords = ad.obsm['coordinates']
clusters = all_clusters[q_mask]

# Neighborhood profiles for the query sample, built with the same arguments
# (k, cluster levels, aggregation, distance cutoff, backend) as the reference
# sample so the two sets of profiles share the same columns.
q_neighbor_profiles = k_neighbor_niches(
    coords,
    clusters,
    k=20,
    u_clusters=np.arange(len(cluster_levels)),
    aggregate='sum',
    max_dist=200,
    backend='sklearn',
)

In [None]:
# Score every observation of the query sample against each reference niche.
#
# For niche i, the "centroid" is the mean reference neighborhood profile of the
# observations assigned to that niche; the score of a query observation is the
# dot product of its own neighborhood profile with that centroid.

# Guard against hidden state: `ad` is reused for both the reference and the
# query sample, so make sure it currently holds the query sample (and lines up
# with `q_neighbor_profiles`) before writing scores onto it.
assert (ad.obs['sample_id'] == q_sample).all(), \
    "`ad` is not the query sample; re-run the query-sample cell before scoring"
assert q_neighbor_profiles.shape[0] == ad.n_obs, \
    "q_neighbor_profiles does not have one row per observation in `ad`"

# Small panels: one is drawn per niche. (The earlier (4,2) figsize was
# overwritten before any figure was created, so it is set only once here.)
rcParams['figure.figsize'] = (2,2)
rcParams['figure.dpi'] = 90

# Track the largest score over all niches so every panel shares one color scale.
vmax = 0
for i in range(n_niches):
    # 1-D mean profile of the reference observations assigned to niche i.
    centroid = neighbor_profiles[niches==i,:].mean(axis=0)
    # (n_obs, n_levels) @ (n_levels,) -> one score per query observation,
    # stored as a plain 1-D column.
    score = np.matmul(q_neighbor_profiles, centroid)
    ad.obs[f'niche_{i}_score'] = score
    # `max(vmax, x)` keeps vmax unless x is strictly larger, so a NaN score
    # (e.g. from an empty niche) never replaces it.
    vmax = max(vmax, score.max())

sc.pl.embedding(ad, basis='coordinates', color=[f'niche_{i}_score' for i in range(n_niches)],
                vmax=vmax)