In [None]:
import numpy as np
import pandas as pd
import anndata
import scanpy as sc
from ALLCools.mcds import MCDS


import xarray as xr
import dask
from ALLCools.plot import *
import pathlib
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
import json

import warnings

In [None]:
# Default parameters. A later "# Parameters" cell reassigns several of these names.

# Column of the cell metadata that holds the cluster labels to annotate; also used in file names.
cluster_col = 'L1'
# Value selected along the MCDS `mc_type` dimension; also part of the DMG table file names.
mc_type = 'CHN'
# Pickled cell metadata table (read with pd.read_pickle).
metadata_path = './CellMetadata.AfterQC.pdpkl'
# AnnData file (read with anndata.read_h5ad); its `obs` columns are copied into the cell metadata.
adata_path = './adata.with_coords.h5ad'
# Optional pickled table with a previous annotation; None skips the old-annotation panels.
old_annot_path = None

# Prefix used when cluster names are auto-generated; empty means no prefix.
auto_annot_prefix = ''

# Tab-separated gene table; the notebook reads its `gene_id` and `gene_name` columns.
gene_annotation_path = '~/refs/human/hg38/gencode/v33/gencode.v33.basic.annotation.gene.flat.tsv.gz'
# Names of the observation (cell) and variable (gene) dimensions of the MCDS.
obs_dim = 'cell'
var_dim = 'gene'

# Number of top-AUROC genes plotted per cluster.
plot_top_n_markers = 10

# Maximum number of cells kept in the plotting subsample.
n_samples = 5000

# Per-pool .mcds paths; used to derive the gene-level MCDS paths when gene_mcds_path_pattern is None.
mcds_path_list = [
]

# JSON files: Region -> SupRegion mapping, and the palette passed when colouring by SupRegion.
region_supregion_path = '~/map.region-supregion.json'
supregion_color_path = '~/map.supregion-color.json'

# Path(s) handed to MCDS.open; None means "derive from mcds_path_list".
gene_mcds_path_pattern = None


In [None]:
# Parameters
# NOTE(review): this cell looks tool-generated (papermill-style parameter injection) — confirm.
# The assignments below override the defaults from the previous cell
# (e.g. cluster_col 'L1' -> "L3", n_samples 5000 -> 20000).
cluster_col = "L3"
mc_type = "CHN"
metadata_path = "./CellMetadata.AfterQC.pdpkl"
adata_path = "./adata.with_coords.h5ad"
gene_annotation_path = "~/refs/human/hg38/gencode/v33/gencode.v33.basic.annotation.gene.flat.tsv.gz"
obs_dim = "cell"
var_dim = "gene"
plot_top_n_markers = 10
n_samples = 20000
# Per-pool .mcds inputs. In this notebook the list is only used to derive the gene-level
# MCDS paths ('.mcds' -> '.gene_da_rate.mcds', 'mcds/' -> 'gene_mcds/') when
# gene_mcds_path_pattern is None.
mcds_path_list = [
    "~/mcds/hba_pool18_h1930002_A1C.mcds",
    "~/mcds/hba_pool76_h1930002_BL-La.mcds",
    "~/mcds/hba_pool88_h1930002_A38.mcds",
    "~/mcds/hba_pool2_h1930001_MTG.mcds",
    "~/mcds/hba_pool49_h1930002_SI.mcds",
    "~/mcds/hba_pool11_h1930001_ACC.mcds",
    "~/mcds/hba_pool87_h1930001_CA1C-CA2C-CA3C-DGC-CA4C.mcds",
    "~/mcds/hba_pool9_h1930001_FI.mcds",
    "~/mcds/hba_pool84_h1930001_Idg.mcds",
    "~/mcds/hba_pool23_h1930002_S1C.mcds",
    "~/mcds/hba_pool59_h1930001_DGC-CA4Cpy.mcds",
    "~/mcds/hba_pool39_h1930002_Pu.mcds",
    "~/mcds/hba_pool56_h1930001_MEC.mcds",
    "~/mcds/hba_pool46_h1930001_CBL.mcds",
    "~/mcds/hba_pool26_h1930002_Pro.mcds",
    "~/mcds/hba_pool71_h1930002_Cla.mcds",
    "~/mcds/hba_pool48_h1930002_BNST_a.mcds",
    "~/mcds/hba_pool54_h1930001_CMN.mcds",
    "~/mcds/hba_pool50_h1930001_BM_a.mcds",
    "~/mcds/hba_pool17_h1930001_A44-A45.mcds",
    "~/mcds/hba_pool64_h1930002_IC.mcds",
    "~/mcds/hba_pool66_h1930001_CA1R-CA2R-CA3R-DGR-CA4R.mcds",
    "~/mcds/hba_pool7_h1930002_V2.mcds",
    "~/mcds/hba_pool40_h1930001_TH-TL.mcds",
    "~/mcds/hba_pool3_h1930001_CBV.mcds",
    "~/mcds/hba_pool74_h1930001_IC.mcds",
    "~/mcds/hba_pool13_h1930002_FI.mcds",
    "~/mcds/hba_pool33_h1930002_A19.mcds",
    "~/mcds/hba_pool3_h1930002_CBV.mcds",
    "~/mcds/hba_pool65_h1930001_BL-La.mcds",
    "~/mcds/hba_pool43_h1930002_CaB.mcds",
    "~/mcds/hba_pool20_h1930001_A5-A7.mcds",
    "~/mcds/hba_pool51_h1930002_CMN_a.mcds",
    "~/mcds/hba_pool42_h1930001_CaB.mcds",
    "~/mcds/hba_pool52_h1930002_MD-Re_a.mcds",
    "~/mcds/hba_pool45_h1930001_SI.mcds",
    "~/mcds/hba_pool38_h1930001_Pu.mcds",
    "~/mcds/hba_pool8_h1930001_V1C.mcds",
    "~/mcds/hba_pool12_h1930002_V1C.mcds",
    "~/mcds/hba_pool10_h1930001_A46.mcds",
    "~/mcds/hba_pool81_h1930002_A25.mcds",
    "~/mcds/hba_pool69_h1930002_Amy.mcds",
    "~/mcds/hba_pool86_h1930001_A25.mcds",
    "~/mcds/hba_pool37_h1930002_GPe.mcds",
    "~/mcds/hba_pool73_h1930002_CA1C-CA2C-CA3C-DGC-CA4C.mcds",
    "~/mcds/hba_pool80_h1930002_Pul.mcds",
    "~/mcds/hba_pool27_h1930002_Ig.mcds",
    "~/mcds/hba_pool78_h1930002_PN.mcds",
    "~/mcds/hba_pool90_h1930002_CBL.mcds",
    "~/mcds/hba_pool29_h1930002_Pir.mcds",
    "~/mcds/hba_pool22_h1930001_S1C.mcds",
    "~/mcds/hba_pool67_h1930001_CA1C-CA2C-CA3C-DGC-CA4C.mcds",
    "~/mcds/hba_pool4_h1930001_M1C.mcds",
    "~/mcds/hba_pool85_h1930001_ITG.mcds",
    "~/mcds/hba_pool61_h1930001_Sub.mcds",
    "~/mcds/hba_pool41_h1930001_BNST_a.mcds",
    "~/mcds/hba_pool82_h1930002_Idg.mcds",
    "~/mcds/hba_pool44_h1930002_SEP.mcds",
    "~/mcds/hba_pool55_h1930001_Pul.mcds",
    "~/mcds/hba_pool47_h1930002_TH-TL.mcds",
    "~/mcds/hba_pool25_h1930001_Pro.mcds",
    "~/mcds/hba_pool36_h1930001_GPe.mcds",
    "~/mcds/hba_pool58_h1930001_SEP.mcds",
    "~/mcds/hba_pool75_h1930001_CA1R-CA2R-CA3R-DGR-CA4R.mcds",
    "~/mcds/hba_pool1_h1930002_MTG.mcds",
    "~/mcds/hba_pool14_h1930002_A46.mcds",
    "~/mcds/hba_pool15_h1930002_ACC.mcds",
    "~/mcds/hba_pool68_h1930001_Amy.mcds",
    "~/mcds/hba_pool79_h1930001_MD_Re.mcds",
    "~/mcds/hba_pool89_h1930002_ITG.mcds",
    "~/mcds/hba_pool63_h1930002_CA1R-CA2R-CA3R.mcds",
    "~/mcds/hba_pool31_h1930002_LEC.mcds",
    "~/mcds/hba_pool16_h1930001_A1C.mcds",
    "~/mcds/hba_pool6_h1930001_V2.mcds",
    "~/mcds/hba_pool57_h1930002_MEC.mcds",
    "~/mcds/hba_pool32_h1930001_A19.mcds",
    "~/mcds/hba_pool62_h1930002_DGR-CA4Rpy.mcds",
    "~/mcds/hba_pool30_h1930001_LEC.mcds",
    "~/mcds/hba_pool72_h1930002_CA1C-CA2C-CA3C-DGC-CA4C.mcds",
    "~/mcds/hba_pool60_h1930002_Sub.mcds",
    "~/mcds/hba_pool53_h1930002_CEN_a.mcds",
    "~/mcds/hba_pool28_h1930001_Pir.mcds",
    "~/mcds/hba_pool77_h1930001_PN.mcds",
    "~/mcds/hba_pool35_h1930002_NAC.mcds",
    "~/mcds/hba_pool70_h1930001_Cla.mcds",
    "~/mcds/hba_pool0_h1930002_CBV.mcds",
    "~/mcds/hba_pool19_h1930002_A44-A45.mcds",
    "~/mcds/hba_pool5_h1930002_M1C.mcds",
    "~/mcds/hba_pool21_h1930002_A5-A7.mcds",
    "~/mcds/hba_pool34_h1930001_NAC.mcds",
    "~/mcds/hba_pool0_h1930001_CBV.mcds",
    "~/mcds/hba_pool83_h1930001_A38.mcds",
    "~/mcds/hba_pool24_h1930001_Ig.mcds",
]
# Previous annotation table; plot_meta_and_old_annot reads it with pd.read_pickle and plots
# its L1_annot/CellClass and L2_annot/MajorType columns when they are present.
old_annot_path = (
    "~/cell_annotation_meta.pdpkl"
)
# Prefix for auto-generated cluster names (a ':' separator is appended in a later cell).
auto_annot_prefix = "VLMC"

# None -> the gene-level MCDS paths are derived from mcds_path_list in a later cell.
gene_mcds_path_pattern = None

In [None]:
# Normalise the prefix used for auto-generated cluster names:
#   blank / whitespace-only -> ''            (no prefix)
#   anything else           -> '<prefix>:'   (exactly one trailing ':')
# Surrounding whitespace is stripped. Checking for an existing ':' keeps this cell
# idempotent: re-running it no longer turns 'VLMC:' into 'VLMC::'.
auto_annot_prefix = auto_annot_prefix.strip()
if auto_annot_prefix != '' and not auto_annot_prefix.endswith(':'):
    auto_annot_prefix += ':'

In [None]:
# Gene annotation table indexed by gene_id, plus the lookup dictionaries built from it.
gene_meta = pd.read_csv(gene_annotation_path, index_col='gene_id', sep='\t')

# Series.iteritems() was removed in pandas 2.0; Series.items() is the equivalent spelling.
# If a gene_name occurs for several gene_ids, the last one in the table wins (as before).
gene_name_to_id = {name: gene_id for gene_id, name in gene_meta['gene_name'].items()}
gene_id_to_name = dict(gene_meta['gene_name'].items())

# gene_id without its '.<suffix>' part -> full gene_id as it appears in the table index.
gene_id_base_to_id = pd.Series(
    gene_meta.index,
    index=gene_meta.index.map(lambda i: i.split('.')[0]),
).to_dict()

In [None]:
# Load the Region -> SupRegion mapping and the SupRegion colour palette from JSON.
# Python's open() does not expand '~', and both configured paths start with '~/',
# so resolve the home directory explicitly before opening.
with open(pathlib.Path(region_supregion_path).expanduser()) as f:
    map_region_supregion = json.load(f)
with open(pathlib.Path(supregion_color_path).expanduser()) as f:
    palette_supregion_color = json.load(f)

In [None]:
# Load the cell metadata table and the AnnData object.
# NOTE(review): pd.read_pickle can execute arbitrary code; only point metadata_path at trusted files.
cell_meta = pd.read_pickle(metadata_path)
adata = anndata.read_h5ad(adata_path)
# Copy every adata.obs column into cell_meta. Column assignment aligns on the index, so
# rows of cell_meta whose index is absent from adata.obs end up as NaN in that column.
for col in adata.obs.columns:
    cell_meta[col] = adata.obs[col]
    
# Region values that are not keys of map_region_supregion map to NaN.
cell_meta['SupRegion'] = cell_meta['Region'].map(map_region_supregion)


In [None]:
# Plot on at most n_samples cells. The fixed random_state makes the subsample, and therefore
# every figure below, reproducible across kernel restarts.
if len(cell_meta) > n_samples:
    sample_cell_meta = cell_meta.sample(n_samples, random_state=0).copy()
else:
    # Few enough cells: use all of them (copied so later edits do not touch cell_meta).
    sample_cell_meta = cell_meta.copy()
    

In [None]:
# The one-vs-rest DMG table is optional: a missing file simply disables the per-cluster
# marker plots. Only "file not found" is tolerated; a corrupt or unreadable table now
# surfaces as an error instead of being silently treated as absent.
try:
    dmg_table = pd.read_pickle(f'{cluster_col}.{mc_type}.OneVsRestDMG.pdpkl')
except FileNotFoundError:
    dmg_table = None

In [None]:
# The pairwise DMG table is optional: a missing file simply disables the pairwise
# marker plots. Only "file not found" is tolerated; a corrupt or unreadable table now
# surfaces as an error instead of being silently treated as absent.
try:
    pw_dmg_table = pd.read_pickle(f'{cluster_col}.{mc_type}.PairwiseDMG.pdpkl')
except FileNotFoundError:
    pw_dmg_table = None

In [None]:
# Collect the set of marker genes whose values will be loaded from the gene-level MCDS.
top_markers = set()

if dmg_table is not None:
    # Index of the plot_top_n_markers highest-AUROC rows within each cluster.
    tmp = dmg_table.groupby('cluster')['AUROC'].nlargest(plot_top_n_markers).index
    # When the result is a MultiIndex, drop its first level so only the table's own
    # index labels remain (presumably gene ids — they are later used to select genes).
    top_markers.update(tmp.droplevel() if isinstance(tmp, pd.MultiIndex) else tmp)

if pw_dmg_table is not None:
    # Pairwise table: every index label is kept (no per-pair top-N cut).
    top_markers.update(pw_dmg_table.index.unique())

In [None]:
# No explicit gene-level MCDS given: derive one path per entry of mcds_path_list by
# rewriting '.mcds' -> '.gene_da_rate.mcds' and then 'mcds/' -> 'gene_mcds/'.
if gene_mcds_path_pattern is None:
    gene_mcds_path_pattern = [
        path.replace('.mcds', '.gene_da_rate.mcds').replace('mcds/', 'gene_mcds/')
        for path in mcds_path_list
    ]

In [None]:
# Open the gene-level MCDS and keep only the configured methylation context.
mcds = MCDS.open(gene_mcds_path_pattern, obs_dim=obs_dim).sel(mc_type=mc_type)

In [None]:
# Restrict the MCDS to the marker genes and the sampled cells, then load it into memory.
# The dimension names come from obs_dim / var_dim throughout, so the index lookups agree
# with the selections below instead of hard-coding 'cell' / 'gene'.
if len(top_markers) != 0:
    use_genes = mcds.get_index(var_dim)
    use_genes = use_genes[use_genes.isin(top_markers)]
    use_cells = mcds.get_index(obs_dim)
    use_cells = use_cells[use_cells.isin(sample_cell_meta.index)]
    submcds = mcds.sel({obs_dim: use_cells}).sel({var_dim: use_genes})
    submcds.load()

In [None]:
def plot_cluster_and_genes(cluster, cell_meta, cluster_col, genes_data,
                           coord_base='umap', ncols=5, axes_size=3, dpi=150, hue_norm=(0.67, 1.5)):
    """Plot an overview row for one or more clusters followed by one panel per gene.

    The first grid row holds five fixed panels (all clusters, the highlighted cluster(s),
    Donor, Region, and a legend-only Region panel). The remaining rows hold one
    continuous scatter per column of ``genes_data``.

    Parameters
    ----------
    cluster : str or list of str
        Cluster label(s) to highlight; a single string is wrapped into a list.
    cell_meta : pandas.DataFrame
        Cell table passed as ``data`` to every scatter call; must contain ``cluster_col``,
        'Donor' and 'Region' (coordinates are resolved by the plotting helpers via ``coord_base``).
    cluster_col : str
        Column of ``cell_meta`` holding the cluster labels.
    genes_data : pandas.DataFrame
        One column per gene; each column is passed as ``hue`` to ``continuous_scatter``.
        Presumably indexed like ``cell_meta`` — confirm against the caller.
    coord_base : str
        Coordinate set name forwarded to the scatter helpers.
    ncols : int
        Number of grid columns. Values below 5 are raised to 5 because the overview row
        always uses grid columns 0-4.
    axes_size : float
        Size of one panel in inches.
    dpi : int
        Figure resolution.
    hue_norm : tuple
        Forwarded to ``continuous_scatter`` for the gene panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    if isinstance(cluster, str):
        cluster = [cluster]

    # The overview row indexes gs[0, 0] .. gs[0, 4]; fewer than five columns would raise
    # an IndexError, so five is the real minimum.
    ncols = max(5, ncols)
    nrows = 1 + (genes_data.shape[1] - 1) // ncols + 1

    # figure
    fig = plt.figure(figsize=(ncols * axes_size, nrows * axes_size), dpi=dpi)
    gs = fig.add_gridspec(nrows=nrows, ncols=ncols)

    # --- overview row -------------------------------------------------------
    ax = fig.add_subplot(gs[0, 0])
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        axis_format=None,
                        hue=cluster_col,
                        palette='tab20')
    ax.set_title('All Clusters')

    ax = fig.add_subplot(gs[0, 1])
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue=cell_meta[cluster_col].isin(cluster),
                        axis_format=None,
                        palette={
                            True: 'red',
                            False: 'lightgray'
                        })
    ax.set_title('This Cluster')

    ax = fig.add_subplot(gs[0, 2])
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue='Donor',
                        axis_format=None,
                        scatter_kws=dict(s=1),
                        )
    ax.set_title('Donor')

    ax = fig.add_subplot(gs[0, 3])
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue='Region',
                        axis_format=None,
                        scatter_kws=dict(s=1),
                        )
    ax.set_title('Region')

    # Legend-only panel: the points are drawn with size 0 so only the Region legend shows.
    ax = fig.add_subplot(gs[0, 4])
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue='Region',
                        axis_format=None,
                        show_legend=True,
                        legend_kws=dict(fontsize=6, bbox_to_anchor=(0, 1)),
                        scatter_kws=dict(s=0),
                        )

    # --- gene panels --------------------------------------------------------
    # DataFrame.iteritems() was removed in pandas 2.0; items() is the equivalent.
    for i, (gene, data) in enumerate(genes_data.items()):
        col = i % ncols
        row = i // ncols + 1
        ax = fig.add_subplot(gs[row, col])

        # Only the bottom-left panel gets the small axis marker. Axes.is_first_col() /
        # is_last_row() were removed in Matplotlib 3.6, so the grid position is used directly.
        if col == 0 and row == nrows - 1:
            axis = 'tiny'
        else:
            axis = None

        continuous_scatter(ax=ax,
                           data=cell_meta,
                           hue=data,
                           axis_format=axis,
                           hue_norm=hue_norm,
                           coord_base=coord_base)
        ax.set_title(f'{data.name}')
    # map(str, ...) keeps the title working when the labels are not strings.
    fig.suptitle(f'Cluster {" : ".join(map(str, cluster))} Top Markers')
    return fig

In [None]:
def plot_meta_and_old_annot(cell_meta, old_cell_meta_path, coord_base='tsne', dpi=150, supregion=True):
    """Draw a 2 x 6 grid of metadata scatter plots, optionally with a previous annotation.

    Top row: mCHFrac, mCGFrac, FinalmCReads, then SupRegion (or Region when
    ``supregion`` is False) as a scatter and as a legend-only panel.
    Bottom row: pool_relative_read, Donor, the global ``cluster_col``, then up to two
    panels from the old annotation table.

    Parameters
    ----------
    cell_meta : pandas.DataFrame
        Cell table passed as ``data`` to the scatter helpers.
    old_cell_meta_path : str or None
        Pickled table with a previous annotation; None skips the last two panels.
        NOTE(review): read with pd.read_pickle — only use trusted files.
    coord_base : str
        Coordinate set name forwarded to the scatter helpers.
    dpi : int
        Figure resolution.
    supregion : bool
        Colour by 'SupRegion' with the global ``palette_supregion_color`` when True,
        otherwise by 'Region' with the 'tab20' palette.

    The function reads the module-level names ``cluster_col`` and
    ``palette_supregion_color`` and returns None.
    """
    panel_size = 3

    fig = plt.figure(figsize=(6 * panel_size, 2 * panel_size), dpi=dpi)
    grid = fig.add_gridspec(nrows=2, ncols=6)

    def _panel(row, col):
        # Create the axes that occupies one cell of the 2 x 6 grid.
        return fig.add_subplot(grid[row, col])

    # ---- top row: continuous QC metrics ------------------------------------
    ax = _panel(0, 0)
    continuous_scatter(data=cell_meta,
                       ax=ax,
                       coord_base=coord_base,
                       axis_format=None,
                       scatter_kws={'s': 1},
                       hue='mCHFrac')
    ax.set_title('mCHFrac')

    # These two panels keep the helper's default axis_format, as in the original layout.
    for grid_col, metric in enumerate(('mCGFrac', 'FinalmCReads'), start=1):
        ax = _panel(0, grid_col)
        continuous_scatter(data=cell_meta,
                           ax=ax,
                           coord_base=coord_base,
                           scatter_kws={'s': 1},
                           hue=metric)
        ax.set_title(metric)

    # ---- top row: region scatter + legend-only panel -----------------------
    if supregion:
        region_hue, region_palette = 'SupRegion', palette_supregion_color
    else:
        region_hue, region_palette = 'Region', 'tab20'

    ax = _panel(0, 3)
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue=region_hue,
                        palette=region_palette,
                        axis_format=None,
                        scatter_kws={'s': 1})
    ax.set_title(region_hue)

    # Point size 0: only the legend is visible in this panel (and it gets no title).
    ax = _panel(0, 4)
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue=region_hue,
                        palette=region_palette,
                        axis_format=None,
                        show_legend=True,
                        legend_kws={'fontsize': 6, 'bbox_to_anchor': (0, 1)},
                        scatter_kws={'s': 0})

    # ---- bottom row: reads, donor, clusters --------------------------------
    ax = _panel(1, 0)
    continuous_scatter(data=cell_meta,
                       ax=ax,
                       coord_base=coord_base,
                       hue='pool_relative_read',
                       axis_format=None,
                       scatter_kws={'s': 1})
    ax.set_title('Relative Reads')

    ax = _panel(1, 1)
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue='Donor',
                        axis_format=None,
                        scatter_kws={'s': 1})
    ax.set_title('Donor')

    ax = _panel(1, 2)
    categorical_scatter(data=cell_meta,
                        ax=ax,
                        coord_base=coord_base,
                        hue=cluster_col,
                        axis_format=None,
                        scatter_kws={'s': 1},
                        text_anno=cluster_col)
    ax.set_title(cluster_col)

    # ---- bottom row: previous annotation (optional) ------------------------
    if old_cell_meta_path is None:
        return

    annotated = cell_meta.copy()
    old_df = pd.read_pickle(old_cell_meta_path)

    # For each panel, the first candidate column present in the old table is used.
    old_annot_panels = (
        (3, ('L1_annot', 'CellClass')),
        (4, ('L2_annot', 'MajorType')),
    )
    for grid_col, candidates in old_annot_panels:
        label = next((name for name in candidates if name in old_df.columns), None)
        if label is None:
            continue
        annotated[label] = old_df[label]
        ax = _panel(1, grid_col)
        # Cells without an old label are left out of this panel.
        categorical_scatter(data=annotated[~annotated[label].isna()],
                            ax=ax,
                            coord_base=coord_base,
                            hue=label,
                            scatter_kws={'s': 1},
                            text_anno=label)
        ax.set_title(label)

In [None]:
def get_gene_ids(genes):
    """Translate a sequence of gene names and/or gene ids into gene ids.

    Entries starting with 'ENSG' are returned unchanged; every other entry is looked
    up in the module-level ``gene_name_to_id`` dict (a KeyError is raised for unknown names).

    Returns a list in the same order as ``genes``.
    """
    return [
        gene if gene.startswith('ENSG') else gene_name_to_id[gene]
        for gene in genes
    ]

def plot_genes(cell_meta, genes, coord_base='tsne', ncols=5, axes_size=3, dpi=150, hue_norm=(0.67, 1.5)):
    """Plot one continuous scatter per requested gene, laid out on a grid.

    Parameters
    ----------
    cell_meta : pandas.DataFrame
        Cell table passed as ``data`` to ``continuous_scatter``.
    genes : sequence of str
        Gene names or ids ('ENSG...'); resolved with ``get_gene_ids`` and used as panel titles.
    coord_base : str
        Coordinate set name forwarded to ``continuous_scatter``.
    ncols : int
        Number of grid columns.
    axes_size : float
        Size of one panel in inches.
    dpi : int
        Figure resolution.
    hue_norm : tuple
        Forwarded to ``continuous_scatter``.

    Returns
    -------
    matplotlib.figure.Figure

    Reads the module-level ``mcds`` (its 'gene_da_frac' variable).
    """
    genes = list(genes)
    # Select the genes that were actually requested. The previous body read the global
    # `known_markers`, ignoring this parameter (and failing when that global is undefined).
    genes_data = mcds.sel(gene=get_gene_ids(genes))['gene_da_frac'].to_pandas()
    genes_data.columns = genes

    nrows = (genes_data.shape[1] - 1) // ncols + 1

    # figure
    fig = plt.figure(figsize=(ncols * axes_size, nrows * axes_size), dpi=dpi)
    gs = fig.add_gridspec(nrows=nrows, ncols=ncols)

    # gene axes
    # DataFrame.iteritems() was removed in pandas 2.0; items() is the equivalent.
    for i, (gene, data) in enumerate(genes_data.items()):
        col = i % ncols
        row = i // ncols
        ax = fig.add_subplot(gs[row, col])

        # Only the bottom-left panel gets the small axis marker. Axes.is_first_col() /
        # is_last_row() were removed in Matplotlib 3.6, so the grid position is used directly.
        if col == 0 and row == nrows - 1:
            axis = 'tiny'
        else:
            axis = None

        continuous_scatter(ax=ax,
                           data=cell_meta,
                           hue=data,
                           axis_format=axis,
                           hue_norm=hue_norm,
                           coord_base=coord_base)
        ax.set_title(f'{data.name}')
    return fig

In [None]:
# Cast the cluster labels of the plotting subsample to str (cell_meta itself is left as-is).
# plot_cluster_and_genes matches them against string cluster names (e.g. the parts of a
# 'left-right' pair split on '-').
sample_cell_meta[cluster_col]=sample_cell_meta[cluster_col].astype(str)

In [None]:
# One figure per cluster: overview panels plus the top one-vs-rest marker genes.
# All warnings raised while plotting are silenced for this cell only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    if dmg_table is not None:
        for cluster, sub_dmg_table in dmg_table.groupby('cluster'):
            # Highest-AUROC rows first, keep the top plot_top_n_markers.
            genes = sub_dmg_table.sort_values('AUROC', ascending=False)[:plot_top_n_markers]
            genes_data = submcds.sel(gene=genes.index)['gene_da_frac'].to_pandas()
            # Show gene names instead of gene ids in the panel titles.
            genes_data.columns = genes_data.columns.map(gene_meta['gene_name'])

            # The cluster column of sample_cell_meta was cast to str, so pass the group key
            # as str too; a non-string key would otherwise reach .isin() as a scalar.
            fig = plot_cluster_and_genes(cluster=str(cluster),
                                         cell_meta=sample_cell_meta,
                                         cluster_col=cluster_col,
                                         genes_data=genes_data,
                                         coord_base='tsne',
                                         ncols=5,
                                         axes_size=3,
                                         dpi=100,
                                         hue_norm=(0,1),
                                        )

            plt.show()
            # Release the figure so a long cluster list does not accumulate open figures.
            plt.close(fig)


In [None]:
# One figure per cluster pair: overview panels plus the top pairwise marker genes.
# All warnings raised while plotting are silenced for this cell only.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    if pw_dmg_table is not None:
        for cluster, sub_dmg_table in pw_dmg_table.groupby('left-right'):
            # Highest-AUROC rows first, keep the top plot_top_n_markers.
            genes = sub_dmg_table.sort_values('AUROC', ascending=False)[:plot_top_n_markers]
            genes_data = submcds.sel(gene=genes.index)['gene_da_frac'].to_pandas()
            # Show gene names instead of gene ids in the panel titles.
            genes_data.columns = genes_data.columns.map(gene_meta['gene_name'])

            # The 'left-right' key is split on '-' into the cluster labels to highlight.
            # NOTE(review): a cluster label that itself contains '-' would be split too — confirm.
            fig = plot_cluster_and_genes(cluster=cluster.split('-'),
                                         cell_meta=sample_cell_meta,
                                         cluster_col=cluster_col,
                                         genes_data=genes_data,
                                         coord_base='tsne',
                                         ncols=5,
                                         axes_size=3,
                                         dpi=100,
                                         hue_norm=(0,1),
                                        )

            plt.show()
            # Release the figure so a long pair list does not accumulate open figures.
            plt.close(fig)


In [None]:
# Heatmaps of the per-cluster mean 'gene_da_frac' of the loaded marker genes.
if top_markers:
    tmp = (
        submcds['gene_da_frac']
        .to_pandas()
        # attach each cell's cluster label (aligned on the cell index)
        .assign(**{cluster_col: sample_cell_meta[cluster_col]})
        .groupby(cluster_col)
        .mean()
        .T  # transpose the per-cluster means
    )
    sns.clustermap(tmp, vmax=2, vmin=0, cmap='bwr')
    # Cluster-by-cluster correlation of those mean profiles.
    sns.clustermap(tmp.corr(), vmin=0.5)

In [None]:
# Metadata overview on the sampled cells (default coord_base 'tsne'); the old-annotation
# panels are drawn only when old_annot_path is not None.
plot_meta_and_old_annot(sample_cell_meta, old_annot_path)

In [None]:
# known_markers = ['GAD1','GAD2','SATB2','SLC17A3','SLC17A7']

# plot_genes(cell_meta=sample_cell_meta,
#            genes=known_markers,
#            coord_base='tsne',
#            ncols=5,
#            axes_size=3,
#            dpi=100,
#            hue_norm=(0,1),)
# plt.show()

In [None]:
# Print one "'<cluster>':''," line per cluster, sorted, as a template that can be pasted
# into cluster_name_map below and filled in by hand.
for clt in sorted(cell_meta[cluster_col].unique()):
    print(f"    '{clt}':'',")

In [None]:
# Column the annotation is derived from, and the name of the annotation column that is
# written into cell_meta.
annot_based_on = cluster_col
annot_to = f'{cluster_col}_annot'

========================== repeat below ===============================

In [None]:
# Manual cluster -> name mapping. Entries left as '' fall back to the cluster label itself
# in the next cell.
cluster_name_map = {
    # paste the template printed above here and fill in the names
}
# Nothing filled in: auto-name every cluster as '<auto_annot_prefix><cluster_col>_<cluster>'.
if len(cluster_name_map)==0:
    cluster_name_map = {clt:f'{auto_annot_prefix}{cluster_col}_{clt}' for clt in sorted(cell_meta[cluster_col].unique())}

In [None]:
# Names that are blank (empty or whitespace-only) fall back to the cluster label itself.
cluster_name_map = {
    key: (key if name.strip() == '' else name)
    for key, name in cluster_name_map.items()
}

# Write the annotation column; labels missing from the map are kept unchanged.
cell_meta[annot_to] = cell_meta[annot_based_on].apply(
    lambda label: cluster_name_map.get(label, label))


In [None]:
# Side-by-side UMAP and t-SNE views, coloured and labelled by the new annotation column.
fig, axes = plt.subplots(figsize=(10, 5), dpi=100, ncols=2)
for panel, embedding in zip(axes, ('umap', 'tsne')):
    categorical_scatter(data=cell_meta,
                        coord_base=embedding,
                        ax=panel,
                        hue=annot_to,
                        palette='tab20',
                        text_anno_kws=dict(fontsize=10),
                        text_anno=annot_to)

plt.suptitle(annot_to)
plt.show()

In [None]:
# Draw the cluster dendrogram stored in adata.uns (if any): first with the original leaf
# labels, then with the leaves relabelled through cluster_name_map.
dendro_key = f'dendrogram_{cluster_col}'
if hasattr(adata, 'uns') and dendro_key in adata.uns:
    dendro_record = adata.uns[dendro_key]

    fig, ax = plt.subplots(figsize=(9, 1), dpi=80)
    _ = plot_dendrogram(dendro=dendro_record['dendrogram_info'],
                        linkage_df=pd.DataFrame(dendro_record['linkage']),
                        ax=ax,
                        plot_non_singleton=False,
                       )
    plt.show()

    try:
        # Copy so the stored dendrogram keeps its original 'ivl' labels.
        tmpdendro = dendro_record['dendrogram_info'].copy()
        # Labels that are not keys of cluster_name_map become NaN here.
        tmpdendro['ivl'] = pd.Series(tmpdendro['ivl']).map(cluster_name_map).tolist()
        fig, ax = plt.subplots(figsize=(9, 1), dpi=80)
        _ = plot_dendrogram(dendro=tmpdendro,
                            linkage_df=pd.DataFrame(dendro_record['linkage']),
                            ax=ax,
                            plot_non_singleton=False,
                           )
        plt.show()
    except NameError:
        # Skips the relabelled plot when a name used above (e.g. cluster_name_map) is undefined.
        pass


========================== repeat above ===============================

In [None]:
# Persist the full cell metadata, including the new annotation column, as
# '<cluster_col>.Annotation.pdpkl' (path relative to the working directory).
cell_meta.to_pickle(f'{cluster_col}.Annotation.pdpkl')

In [None]:
# dendro = joblib.load(f'./{cluster_col}.Dendrogram.lib')

# fig, ax = plt.subplots(figsize=(9, 1), dpi=200)
# _ = plot_dendrogram(dendro=dendro.dendrogram,
#                     linkage_df=dendro.linkage,
#                     ax=ax,
#                     plot_non_singleton=False,
#                     line_hue=dendro.edge_stats['au'], # au is the branch confidence score, see pvclust documentation
#                     line_hue_norm=(0.5, 1))
# plt.show()

# try:
#     dendro.dendrogram['ivl'] = pd.Series(dendro.dendrogram['ivl']).map(cluster_name_map).tolist()
#     fig, ax = plt.subplots(figsize=(9, 1), dpi=200)
#     _ = plot_dendrogram(dendro=dendro.dendrogram,
#                         linkage_df=dendro.linkage,
#                         ax=ax,
#                         plot_non_singleton=False,
#                         line_hue=dendro.edge_stats['au'], # au is the branch confidence score, see pvclust documentation
#                         line_hue_norm=(0.5, 1))
#     plt.show()
# except NameError:
#     pass
