In [None]:
import pathlib
import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc

from ALLCools.clustering import ConsensusClustering, Dendrogram, get_pc_centers
from ALLCools.plot import *

In [None]:
# ---------------------------------------------------------------------------
# Configuration — every tunable value for this clustering round lives here.
# ---------------------------------------------------------------------------

# Name of the adata.obs column that receives the cluster labels; it is also
# the prefix of the files saved at the end of the notebook.
clustering_name = 'L1'

# --- Inputs -----------------------------------------------------------------
metadata_path = './CellMetadata.AfterQC.pdpkl'
adata_path = './adata.with_coords.h5ad'
# Embedding prefix passed as `coord_base` to the plotting calls, which read
# the coordinates from adata.obs.
coord_base = 'tsne'

# --- ConsensusClustering: the most influential parameters -------------------
n_neighbors = 25
leiden_resolution = 1
# Final target that limits the total number of clusters: a higher accuracy
# gives more conservative clustering results with fewer clusters.
target_accuracy = 0.96
min_cluster_size = 20

# --- ConsensusClustering: remaining parameters -------------------------------
metric = 'euclidean'
consensus_rate = 0.7
leiden_repeats = 500
random_state = 0
train_frac = 0.5
train_max_n = 500
max_iter = 50
n_jobs = 40

# --- Dendrogram via multiscale bootstrap resampling (pvclust) ----------------
# In the cells visible here these are only used by the commented-out
# pvclust-based Dendrogram cell.
nboot = 10_000
method_dist = 'correlation'
method_hclust = 'average'

# --- Plotting ----------------------------------------------------------------
# NOTE(review): not referenced by any cell visible here.
plot_type = 'static'
# Set to True to draw the optional merge-step diagnostic plots.
plot_merge_steps = False

In [None]:
# Load the AnnData object that clustering runs on, plus the post-QC cell
# metadata table.
adata = anndata.read_h5ad(adata_path)
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
cell_meta = pd.read_pickle(metadata_path)

In [None]:
# Gather every ConsensusClustering setting from the config cell into a single
# mapping (easy to inspect or log), then build the (unfitted) model from it.
consensus_params = dict(
    model=None,
    n_neighbors=n_neighbors,
    metric=metric,
    min_cluster_size=min_cluster_size,
    leiden_repeats=leiden_repeats,
    leiden_resolution=leiden_resolution,
    consensus_rate=consensus_rate,
    random_state=random_state,
    train_frac=train_frac,
    train_max_n=train_max_n,
    max_iter=max_iter,
    n_jobs=n_jobs,
    target_accuracy=target_accuracy,
)
cc = ConsensusClustering(**consensus_params)

In [None]:
# Clustering runs on the PCA embedding stored in adata.obsm. Fail fast with an
# explicit message when PCA has not been computed yet.
pca_key = 'X_pca'
if pca_key not in adata.obsm:
    raise KeyError(
        'X_pca do not exist in the adata file, run PCA first before clustering.'
    )
# Kept as the cell's last top-level expression so the notebook displays its result.
cc.fit_predict(adata.obsm[pca_key])

In [None]:
# Optional diagnostic, toggled by `plot_merge_steps` in the config cell:
# ConsensusClustering's merge-process plot (presumably how clusters were
# merged during fitting — see the ALLCools docs).
if plot_merge_steps:
    cc.plot_merge_process(plot_size=3)

In [None]:
# fig, axes = cc.plot_leiden_cases(coord_data=adata.obs,
#                                  coord_base=coord_base)

In [None]:
# Optional diagnostic, toggled by `plot_merge_steps`: per-step plots drawn on
# the `coord_base` embedding whose coordinates are read from adata.obs.
if plot_merge_steps:
    cc.plot_steps(coord_data=adata.obs, coord_base=coord_base)


In [None]:
# Attach the consensus cluster labels to the cells and show them on the
# embedding, with each cluster annotated by its label.
adata.obs[clustering_name] = cc.label

fig, ax = plt.subplots(figsize=(4, 4), dpi=250)
# NOTE(review): 'tab20' has 20 colours; with more clusters colours will likely
# repeat — confirm how categorical_scatter handles this.
_ = categorical_scatter(
    ax=ax,
    data=adata.obs,
    coord_base=coord_base,
    hue=clustering_name,
    text_anno=clustering_name,
    palette='tab20',
    show_legend=True,
)

In [None]:
# Attach each cell's prediction probability and colour the embedding by it.
# The colour range is pinned to [0, 1] and clusters are labelled in place.
adata.obs[clustering_name + '_proba'] = cc.label_proba

fig, ax = plt.subplots(figsize=(4, 4), dpi=250)
_ = continuous_scatter(
    ax=ax,
    data=adata.obs,
    coord_base=coord_base,
    hue=clustering_name + '_proba',
    hue_norm=(0, 1),
    text_anno=clustering_name,
)

In [None]:
# Per-cluster distribution of the prediction probability stored in the
# `<clustering_name>_proba` column of adata.obs.
#
# seaborn 0.13 renamed violinplot's `scale` argument to `density_norm` and
# deprecated `scale`. Pass whichever keyword the installed seaborn expects so
# this cell works on both older and newer releases (same 'width' scaling).
_seaborn_version = tuple(int(part) for part in sns.__version__.split('.')[:2])
_violin_width_kwargs = (
    {'density_norm': 'width'} if _seaborn_version >= (0, 13) else {'scale': 'width'}
)

fig, ax = plt.subplots(figsize=(6, 3), dpi=300)

sns.violinplot(data=adata.obs,
               x=clustering_name,
               y=clustering_name + '_proba',
               linewidth=0.5,
               cut=0,  # do not extend the density beyond the observed values
               ax=ax,
               **_violin_width_kwargs)
ax.set(ylim=(0, 1), title='Prediction Probability Per Cluster')
ax.xaxis.set_tick_params(rotation=90)
ax.grid(linewidth=0.5, color='gray', linestyle='--')
sns.despine(ax=ax)


In [None]:
# # using the cluster centroids in PC space to calculate dendrogram
# pc_center = get_pc_centers(adata, group=clustering_name)

# # calculate the cluster dendrogram using R package pvclust
# dendro = Dendrogram(nboot=nboot,
#                     method_dist=method_dist,
#                     method_hclust=method_hclust,
#                     n_jobs=n_jobs)
# dendro.fit(pc_center)

In [None]:
# fig, ax = plt.subplots(figsize=(9, 3), dpi=250)
# _ = plot_dendrogram(dendro=dendro.dendrogram,
#                     linkage_df=dendro.linkage,
#                     ax=ax,
#                     plot_non_singleton=False,
#                     line_hue=dendro.edge_stats['au'], # au is the branch confidence score, see pvclust documentation
#                     line_hue_norm=(0.5, 1))

In [None]:
# Build a cluster dendrogram with scanpy, but only when there are at least two
# distinct labels. scanpy groups by a categorical column, so the labels are
# converted to 'category' first.
if adata.obs[clustering_name].nunique(dropna=False) > 1:
    adata.obs[clustering_name] = adata.obs[clustering_name].astype('category')
    # NOTE(review): with n_pcs=0 and no use_rep, scanpy computes the
    # dendrogram from adata.X rather than X_pca (per scanpy's `n_pcs` docs) —
    # confirm this is the intended representation.
    sc.tl.dendrogram(adata, clustering_name, n_pcs=0)

In [None]:
# Draw the scanpy-computed dendrogram. It exists only when there are at least
# two clusters, hence the same guard as the cell that computes it.
if adata.obs[clustering_name].nunique(dropna=False) > 1:
    fig, ax = plt.subplots(figsize=(9, 3), dpi=80)
    dendro_result = adata.uns[f'dendrogram_{clustering_name}']
    _ = plot_dendrogram(
        dendro=dendro_result['dendrogram_info'],
        linkage_df=pd.DataFrame(dendro_result['linkage']),
        ax=ax,
        plot_non_singleton=False,
    )

In [None]:
# Persist the fitted clustering model, then write the labelled AnnData back.
# NOTE(review): "Concensus" is misspelled in the file name, but it is kept
# unchanged because downstream steps may look the file up by this exact name.
model_path = f'{clustering_name}.ConcensusClustering.model.lib'
cc.save(model_path)
# Enable together with the commented-out pvclust Dendrogram cell:
# dendro.save(f'{clustering_name}.Dendrogram.lib')
# NOTE(review): this overwrites the input file at adata_path in place.
adata.write_h5ad(adata_path)
