# CellRank 

[Lange et al., 2022](https://www.nature.com/articles/s41592-021-01346-6)

In [None]:
import scvelo as scv
import cellrank as cr
import scanpy as sc
import scanorama
import scipy
import anndata as ad
import numpy as np
import pandas as pd

import os

In [None]:
# Record the CellRank version used for this analysis
print(cr.__version__)

In [None]:
# Suppress all Python warnings for the remainder of the session
import warnings

warnings.filterwarnings(action='ignore')

In [None]:
# rpy2: point R_HOME at the R installation of the conda environment
os.environ.update({
    'R_HOME': '/home/fdeckert/bin/miniconda3/envs/p.3.8.12-FD20200109SPLENO/lib/R',
})

In [None]:
# scanpy: switch off the vector-friendly output mode
sc.settings.vector_friendly = False

# scanpy figure defaults and output directory
sc.set_figure_params(
    figsize=(2, 3),
    dpi_save=1200,
    fontsize=8,
    frameon=False,
)
sc.settings.figdir = 'result/figures/'

# scVelo figure defaults and output directory (same values as scanpy)
scv.set_figure_params(
    figsize=(2, 3),
    dpi_save=1200,
    fontsize=8,
    frameon=False,
)
scv.settings.figdir = 'result/figures/'

# CellRank output directory
cr.settings.figdir = 'result/figures/'

In [None]:
# All relative paths below (data/, result/, plotting_global.R) are resolved from the project root
os.chdir('/research/peer/fdeckert/FD20200109SPLENO')

In [None]:
# Plotting: source the R script and convert its first return value
# (a named list of named lists) into a nested dict {category: {level: value}}
import rpy2.robjects as robjects

color_load = robjects.r.source('plotting_global.R')
color = {
    category: {
        key: color_load[0][i].rx2(key)[0]
        for key in color_load[0][i].names
    }
    for i, category in enumerate(color_load[0].names)
}

# Import data 

In [None]:
from cellrank.tl.kernels import ConnectivityKernel
from cellrank.tl.kernels import VelocityKernel

# Progenitor kernels
vk_prog_nacl = VelocityKernel.read('data/object/cellrank/kernel/vk_prog_nacl.pickle')
vk_prog_cpg = VelocityKernel.read('data/object/cellrank/kernel/vk_prog_cpg.pickle')
ck_prog_nacl = ConnectivityKernel.read('data/object/cellrank/kernel/ck_prog_nacl.pickle')
ck_prog_cpg = ConnectivityKernel.read('data/object/cellrank/kernel/ck_prog_cpg.pickle')

# Myeloid kernels
vk_m_nacl = VelocityKernel.read('data/object/cellrank/kernel/vk_m_nacl.pickle')
vk_m_cpg = VelocityKernel.read('data/object/cellrank/kernel/vk_m_cpg.pickle')
ck_m_nacl = ConnectivityKernel.read('data/object/cellrank/kernel/ck_m_nacl.pickle')
ck_m_cpg = ConnectivityKernel.read('data/object/cellrank/kernel/ck_m_cpg.pickle')

# Monocyte kernels
vk_mo_nacl = VelocityKernel.read('data/object/cellrank/kernel/vk_mo_nacl.pickle')
vk_mo_cpg = VelocityKernel.read('data/object/cellrank/kernel/vk_mo_cpg.pickle')
ck_mo_nacl = ConnectivityKernel.read('data/object/cellrank/kernel/ck_mo_nacl.pickle')
ck_mo_cpg = ConnectivityKernel.read('data/object/cellrank/kernel/ck_mo_cpg.pickle')

# Workflow functions 

In [None]:
def cr_workflow(vk, ck, vk_ratio, ck_ratio, n_components):
    """Build a weighted kernel combination and return a GPCCA estimator for it.

    Parameters
    ----------
    vk, ck
        Kernel objects exposing ``compute_transition_matrix()``; the callers
        in this notebook pass a velocity and a connectivity kernel.
    vk_ratio, ck_ratio
        Weights multiplied onto ``vk`` and ``ck`` before they are summed.
    n_components
        Forwarded to ``GPCCA.compute_schur``.

    Returns
    -------
    The GPCCA estimator after ``compute_schur`` has been called.
    """
    from cellrank.tl.estimators import GPCCA

    # Transition matrices
    velocity_kernel = vk.compute_transition_matrix()
    connectivity_kernel = ck.compute_transition_matrix()

    # Weighted sum of both kernels feeds the estimator
    estimator = GPCCA(vk_ratio*velocity_kernel + ck_ratio*connectivity_kernel)
    estimator.compute_schur(n_components=n_components)

    return estimator

In [None]:
def umap_workflow(adata, suffix='', n_neighbor=30,  n_pcs=50, n_top_genes=2000):
    """Recompute a UMAP from ``adata.raw`` and write it to CSV.

    Parameters
    ----------
    adata
        AnnData object whose ``.raw`` slot is used as the starting point.
    suffix
        Inserted into the output file name ``result/cellrank/umap<suffix>.csv``.
    n_neighbor, n_pcs
        Forwarded to ``sc.pp.neighbors`` as ``n_neighbors`` and ``n_pcs``.
    n_top_genes
        Forwarded to ``sc.pp.highly_variable_genes``.

    Returns
    -------
    pandas.DataFrame with the UMAP coordinates, indexed by cell name.
    """
    raw_adata = adata.raw.to_adata()

    # Gene filter, then highly variable gene selection
    sc.pp.filter_genes(raw_adata, min_counts=10)
    sc.pp.highly_variable_genes(raw_adata, flavor='seurat_v3', n_top_genes=n_top_genes)

    # Normalize and log-transform
    sc.pp.normalize_total(raw_adata)
    sc.pp.log1p(raw_adata)

    # PCA -> neighbor graph -> UMAP
    sc.pp.pca(raw_adata)
    sc.pp.neighbors(raw_adata, n_neighbors=n_neighbor, n_pcs=n_pcs, use_rep='X_pca')
    sc.tl.umap(raw_adata)

    # Save the embedding and hand back an equivalent table
    umap_path = 'result/cellrank/umap'+suffix+'.csv'
    pd.DataFrame(raw_adata.obsm['X_umap'], index=raw_adata.obs_names).to_csv(umap_path)

    umap_table = pd.DataFrame(raw_adata.obsm['X_umap'], index=raw_adata.obs_names)
    return umap_table

In [None]:
def _root_cell_position(adata, absorption_probabilities, lineage):
    """Return the position in ``adata.obs_names`` of the pseudotime root cell.

    The root is the cell with the lowest ``lineage`` value in
    ``absorption_probabilities``. The position is looked up in ``adata`` itself
    because ``adata`` may contain cells (and an ordering) that differ from the
    absorption-probability table.
    """
    root_cell = absorption_probabilities[lineage].idxmin()
    positions = np.flatnonzero(adata.obs_names == root_cell)
    if positions.size == 0:
        raise ValueError('Root cell {!r} is not present in adata.obs_names'.format(root_cell))
    return positions[0]


def _pseudotime_and_save(adata, absorption_probabilities, lineage, n_dcs, plot_color, file_tag):
    """Run diffusion map, DPT and FA layout on ``adata``, plot and write the CSV outputs.

    Expects the neighbor graph to be present already. Writes
    ``result/cellrank/dpt_pseudotime<file_tag>.csv`` and ``result/cellrank/fag<file_tag>.csv``.
    """
    # Diffusion map and pseudotime
    sc.tl.diffmap(adata, n_comps=n_dcs)
    adata.uns['iroot'] = _root_cell_position(adata, absorption_probabilities, lineage)
    sc.tl.dpt(adata, n_branchings=0, n_dcs=n_dcs)

    # Draw FA graph
    sc.tl.draw_graph(adata, layout='fa')
    sc.pl.draw_graph(adata, color=plot_color, wspace=0.5, ncols=3, size=20, vmin=0, vmax=1)

    # Save output
    adata.obs['dpt_pseudotime'].to_csv('result/cellrank/dpt_pseudotime'+file_tag+'.csv')
    pd.DataFrame(adata.obsm['X_draw_graph_fa'], index=adata.obs_names, columns=['FAG_1', 'FAG_2']).to_csv('result/cellrank/fag'+file_tag+'.csv')


def _scanorama_embedding(adata):
    """Integrate ``adata`` per ``sample_group`` with Scanorama.

    Returns the ``X_scanorama`` matrix with rows in the order of ``adata.obs_names``.
    """
    batches = [
        adata[adata.obs['sample_group']==sample_group].copy()
        for sample_group in adata.obs['sample_group'].unique()
    ]
    for batch in batches:
        sc.pp.filter_genes(batch, min_counts=10)

    # Run Scanorama (adds obsm['X_scanorama'] to every batch in place)
    scanorama.integrate_scanpy(batches, verbose=True)

    # Concatenate scanorama output
    embedding = np.concatenate([batch.obsm['X_scanorama'] for batch in batches])
    batch_obs_names = np.concatenate([batch.obs_names for batch in batches])

    # Batches are grouped by sample_group, so their concatenated cell order need
    # not match adata; re-order by cell name instead of assuming it does.
    return pd.DataFrame(embedding, index=batch_obs_names).loc[adata.obs_names].to_numpy()


def dpt_workflow(adata, absorption_probabilities, suffix='', n_neighbor=30,  n_pcs=50, n_dcs=10, lineage=None, cluster=None, absorption_probabilities_thr=0, compute=False):
    """Compute diffusion pseudotime for one lineage, per treatment and on the integrated data.

    Cells are kept when ``lineage`` is their highest absorption probability and
    that probability is at least ``absorption_probabilities_thr``, or when their
    ``cell_type_fine`` is listed in ``cluster``. Pseudotime and a force-directed
    layout are then computed separately for the "NaCl" and "CpG" cells (PCA based)
    and once for all selected cells (Scanorama integration across ``sample_group``).
    Results are plotted and written to ``result/cellrank/``.

    Parameters
    ----------
    adata
        AnnData object holding counts in ``X`` and the obs columns
        ``cell_type_fine``, ``treatment``, ``sample_group`` and ``cc_phase_class``.
    absorption_probabilities
        DataFrame indexed by cell name with one column per lineage.
    suffix
        Inserted into all output file names.
    n_neighbor, n_pcs
        Forwarded to ``sc.pp.neighbors`` as ``n_neighbors`` and ``n_pcs``.
    n_dcs
        Number of diffusion components for ``sc.tl.diffmap`` and ``sc.tl.dpt``.
    lineage
        Column of ``absorption_probabilities`` to follow; required when ``compute=True``.
    cluster
        ``cell_type_fine`` labels that are always kept; ``None`` means no extra labels.
    absorption_probabilities_thr
        Minimum ``lineage`` probability for a cell to be kept.
    compute
        When False nothing is computed; a message is printed and ``adata`` is returned as passed.

    Returns
    -------
    The processed copy of the selected cells (``compute=True``) or the unchanged input.
    """
    if not compute:
        print('Not implemented. Set compute=True')
        return adata

    cluster = [] if cluster is None else cluster

    absorption_probabilities = absorption_probabilities[absorption_probabilities.index.isin(adata.obs_names)]

    # Subset absorption probability and cluster
    absorption_probabilities = absorption_probabilities[absorption_probabilities.idxmax(axis=1)==lineage]
    absorption_probabilities = absorption_probabilities[absorption_probabilities[lineage] >= absorption_probabilities_thr]
    # Explicit copy: the subset is modified below and must not be a view of the caller's object
    adata = adata[(adata.obs_names.isin(absorption_probabilities.index)) | (adata.obs['cell_type_fine'].isin(cluster))].copy()

    for treatment in ["NaCl", "CpG"]:

        # Subset adata by treatment
        adata_tmp = adata[adata.obs['treatment']==treatment].copy()

        # Subset absorption probability
        absorption_probabilities_tmp = absorption_probabilities[absorption_probabilities.index.isin(adata_tmp.obs_names)]

        # Set raw data as default
        adata_tmp.X = adata_tmp.X.astype(int)

        # Filter genes
        sc.pp.filter_genes(adata_tmp, min_counts=10)

        print('Number of cells:', adata_tmp.n_obs)
        print('Number of genes:', adata_tmp.n_vars)

        # Normalize
        sc.pp.normalize_total(adata_tmp)
        sc.pp.log1p(adata_tmp)
        sc.pp.pca(adata_tmp)
        sc.pp.neighbors(adata_tmp, n_neighbors=n_neighbor, n_pcs=n_pcs, use_rep='X_pca')

        _pseudotime_and_save(
            adata_tmp,
            absorption_probabilities_tmp,
            lineage,
            n_dcs,
            ['cell_type_fine', 'cc_phase_class', 'dpt_pseudotime'],
            suffix+'_'+treatment.lower(),
        )

    # Set raw data as default
    adata.X = adata.X.astype(int)

    # Filter genes
    sc.pp.filter_genes(adata, min_counts=10)

    print('Number of cells:', adata.n_obs)
    print('Number of genes:', adata.n_vars)

    # Normalize
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

    # Add X_scanorama integration to adata
    adata.obsm["X_scanorama"] = _scanorama_embedding(adata)

    sc.pp.neighbors(adata, n_neighbors=n_neighbor, n_pcs=n_pcs, use_rep='X_scanorama')

    _pseudotime_and_save(
        adata,
        absorption_probabilities,
        lineage,
        n_dcs,
        ['cell_type_fine', 'treatment', 'dpt_pseudotime'],
        suffix,
    )

    return adata

# CellRank progenitor

In [None]:
# Progenitors: 40 % velocity kernel, 60 % connectivity kernel
g_prog_nacl = cr_workflow(vk=vk_prog_nacl, ck=ck_prog_nacl, vk_ratio=0.4, ck_ratio=0.6, n_components=10)
g_prog_cpg = cr_workflow(vk=vk_prog_cpg, ck=ck_prog_cpg, vk_ratio=0.4, ck_ratio=0.6, n_components=10)

## Steady state (NaCl)

In [None]:
# Velocity stream plot for the NaCl progenitors (title matches the plotted condition)
scv.pl.velocity_embedding_stream(
    g_prog_nacl.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    title='NaCl',
    density=1.5,
    linewidth=1,
    arrow_size=1,
    save='stream_nacl.svg',
)

In [None]:
# Compute 10 macrostates for the NaCl progenitors and plot each in its own panel
g_prog_nacl.compute_macrostates(
    n_states=10,
    cluster_key='cell_type_fine',
    n_cells=5,
)
g_prog_nacl.plot_macrostates(
    same_plot=False,
    ncols=5,
    figsize=(3, 3),
)

In [None]:
# Declare the terminal macrostates of the NaCl progenitors and save the overview plot
g_prog_nacl.set_terminal_states_from_macrostates(
    ['MastP', 'MegP', 'EB (5)'],
    n_cells=15,
)
g_prog_nacl.plot_terminal_states(
    same_plot=True,
    ncols=5,
    legend_loc='none',
    save='terminal_states_prog_nacl.svg',
)

In [None]:
# Absorption probabilities towards the terminal states (NaCl progenitors)
g_prog_nacl.compute_absorption_probabilities(solver='gmres')
g_prog_nacl.plot_absorption_probabilities(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='none',
    title='NaCl',
    save='absorption_probabilities_prog_nacl.png',
)

In [None]:
# Circular projection of the NaCl progenitor fate probabilities
cr.pl.circular_projection(
    g_prog_nacl.adata,
    keys='cell_type_fine',
    lineage_order='default',
    figsize=(10, 10),
    legend_loc='none',
    title='',
    save='circular_projection_prog_nacl.pdf',
)

In [None]:
# MastP lineage drivers (NaCl), restricted to the listed cell_type_fine clusters
drivers_1 = g_prog_nacl.compute_lineage_drivers(
    lineages='MastP',
    cluster_key='cell_type_fine',
    cluster=['MastP', 'MegP', 'MEP (1)', 'MEP (2)', 'MEP (3)', 'MEP (4)'],
    return_drivers=True,
)
g_prog_nacl.plot_lineage_drivers('MastP', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# MegP lineage drivers (NaCl), restricted to the listed cell_type_fine clusters
drivers_2 = g_prog_nacl.compute_lineage_drivers(
    lineages='MegP',
    cluster_key='cell_type_fine',
    cluster=['MastP', 'MegP', 'MEP (1)', 'MEP (2)', 'MEP (3)', 'MEP (4)'],
    return_drivers=True,
)
g_prog_nacl.plot_lineage_drivers('MegP', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# EB (5) lineage drivers (NaCl), restricted to the ProEB/EB cell_type_main clusters
drivers_3 = g_prog_nacl.compute_lineage_drivers(
    lineages='EB (5)',
    cluster_key='cell_type_main',
    cluster=['ProEB', 'EB'],
    return_drivers=True,
)
g_prog_nacl.plot_lineage_drivers('EB (5)', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# Save results (NaCl progenitors): lineage drivers, absorption probabilities,
# macrostate memberships and the dense transition matrix
pd.concat([drivers_1, drivers_2, drivers_3]).to_csv(
    'result/cellrank/compute_lineage_drivers_prog_nacl.csv'
)
pd.DataFrame(
    g_prog_nacl.absorption_probabilities.X,
    columns=g_prog_nacl.absorption_probabilities.names.tolist(),
    index=g_prog_nacl.adata.obs_names,
).to_csv('result/cellrank/absorption_probabilities_prog_nacl.csv')
pd.DataFrame(
    g_prog_nacl.macrostates_memberships,
    columns=g_prog_nacl.macrostates_memberships.names.tolist(),
    index=g_prog_nacl.adata.obs_names,
).to_csv('result/cellrank/macrostates_memberships_prog_nacl.csv')
pd.DataFrame(
    scipy.sparse.csr_matrix.toarray(g_prog_nacl.transition_matrix),
    columns=g_prog_nacl.adata.obs_names,
    index=g_prog_nacl.adata.obs_names,
).to_csv('result/cellrank/transition_matrix_nacl.csv')

## Stress (CpG)

In [None]:
# Velocity stream plot for the CpG progenitors
scv.pl.velocity_embedding_stream(
    g_prog_cpg.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    title='CpG',
    density=1.5,
    linewidth=1,
    arrow_size=1,
    save='stream_cpg.svg',
)

In [None]:
# Compute 10 macrostates for the CpG progenitors and plot each in its own panel
g_prog_cpg.compute_macrostates(
    n_states=10,
    cluster_key='cell_type_fine',
    n_cells=5,
)
g_prog_cpg.plot_macrostates(
    same_plot=False,
    ncols=5,
    figsize=(3, 3),
)

In [None]:
# Declare the terminal macrostates of the CpG progenitors and plot them
g_prog_cpg.set_terminal_states_from_macrostates(
    ['MastP', 'MegP', 'EB (5)'],
    n_cells=15,
)
g_prog_cpg.plot_terminal_states(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='right',
)

In [None]:
# Absorption probabilities towards the terminal states (CpG progenitors)
g_prog_cpg.compute_absorption_probabilities(solver='gmres')
g_prog_cpg.plot_absorption_probabilities(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='none',
    title='CpG',
    save='absorption_probabilities_prog_cpg.png',
)

In [None]:
# Circular projection of the CpG progenitor fate probabilities
cr.pl.circular_projection(
    g_prog_cpg.adata,
    keys='cell_type_fine',
    figsize=(7.5, 7.5),
    legend_loc='none',
    title='',
    save='circular_projection_prog_cpg.pdf',
)

In [None]:
# MastP lineage drivers (CpG), restricted to the listed cell_type_fine clusters.
# return_drivers must be True: drivers_1 is concatenated with drivers_2/drivers_3
# and written to CSV in the "Save results" cell; with False it would be None and
# the MastP drivers would be missing from that file.
drivers_1 = g_prog_cpg.compute_lineage_drivers(
    lineages='MastP',
    cluster_key='cell_type_fine',
    cluster=['MastP', 'MegP', 'MEP (1)', 'MEP (2)', 'MEP (3)', 'MEP (4)'],
    return_drivers=True,
)
g_prog_cpg.plot_lineage_drivers('MastP', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# MegP lineage drivers (CpG), restricted to the listed cell_type_fine clusters
drivers_2 = g_prog_cpg.compute_lineage_drivers(
    lineages='MegP',
    cluster_key='cell_type_fine',
    cluster=['MastP', 'MegP', 'MEP (1)', 'MEP (2)', 'MEP (3)', 'MEP (4)'],
    return_drivers=True,
)
g_prog_cpg.plot_lineage_drivers('MegP', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# EB (5) lineage drivers (CpG), restricted to the ProEB/EB cell_type_main clusters
drivers_3 = g_prog_cpg.compute_lineage_drivers(
    lineages='EB (5)',
    cluster_key='cell_type_main',
    cluster=['ProEB', 'EB'],
    return_drivers=True,
)
g_prog_cpg.plot_lineage_drivers('EB (5)', n_genes=5, ncols=5, figsize=(15, 3))

In [None]:
# Save results (CpG progenitors): lineage drivers, absorption probabilities,
# macrostate memberships and the dense transition matrix
pd.concat([drivers_1, drivers_2, drivers_3]).to_csv(
    'result/cellrank/compute_lineage_drivers_prog_cpg.csv'
)
pd.DataFrame(
    g_prog_cpg.absorption_probabilities.X,
    columns=g_prog_cpg.absorption_probabilities.names.tolist(),
    index=g_prog_cpg.adata.obs_names,
).to_csv('result/cellrank/absorption_probabilities_prog_cpg.csv')
pd.DataFrame(
    g_prog_cpg.macrostates_memberships,
    columns=g_prog_cpg.macrostates_memberships.names.tolist(),
    index=g_prog_cpg.adata.obs_names,
).to_csv('result/cellrank/macrostates_memberships_prog_cpg.csv')
pd.DataFrame(
    scipy.sparse.csr_matrix.toarray(g_prog_cpg.transition_matrix),
    columns=g_prog_cpg.adata.obs_names,
    index=g_prog_cpg.adata.obs_names,
).to_csv('result/cellrank/transition_matrix_cpg.csv')

# CellRank Myeloid

In [None]:
# Myeloid cells: 20 % velocity kernel, 80 % connectivity kernel
g_m_nacl = cr_workflow(vk=vk_m_nacl, ck=ck_m_nacl, vk_ratio=0.2, ck_ratio=0.8, n_components=10)
g_m_cpg = cr_workflow(vk=vk_m_cpg, ck=ck_m_cpg, vk_ratio=0.2, ck_ratio=0.8, n_components=10)

In [None]:
# Recompute the myeloid UMAPs on copies so the estimator objects stay untouched
umap_m_nacl = umap_workflow(
    g_m_nacl.adata.copy(),
    suffix='_m_nacl',
    n_neighbor=30,
    n_pcs=50,
    n_top_genes=8000,
)
umap_m_cpg = umap_workflow(
    g_m_cpg.adata.copy(),
    suffix='_m_cpg',
    n_neighbor=30,
    n_pcs=50,
    n_top_genes=8000,
)

In [None]:
# Attach the recomputed UMAPs, rows selected by cell name in estimator order
g_m_nacl.adata.obsm['X_umap'] = umap_m_nacl.loc[g_m_nacl.adata.obs_names, :].to_numpy()
g_m_cpg.adata.obsm['X_umap'] = umap_m_cpg.loc[g_m_cpg.adata.obs_names, :].to_numpy()

## Steady state (NaCl)

In [None]:
# Velocity stream plot for the NaCl myeloid cells
scv.pl.velocity_embedding_stream(
    g_m_nacl.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    density=1.5,
    linewidth=1,
    arrow_size=1,
)

In [None]:
# Compute 5 macrostates for the NaCl myeloid cells and plot each in its own panel
g_m_nacl.compute_macrostates(
    n_states=5,
    cluster_key='cell_type_fine',
    n_cells=30,
)
g_m_nacl.plot_macrostates(
    same_plot=False,
    ncols=5,
    figsize=(3, 3),
)

In [None]:
# Declare the terminal macrostates of the NaCl myeloid cells and plot them
g_m_nacl.set_terminal_states_from_macrostates(
    ['RPM', 'cDC2 (2)', 'cDC1 (2)', 'ncMo (1)'],
    n_cells=30,
)
g_m_nacl.plot_terminal_states(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='right',
)

In [None]:
# Absorption probabilities towards the terminal states (NaCl myeloid cells)
g_m_nacl.compute_absorption_probabilities(solver='gmres')
g_m_nacl.plot_absorption_probabilities(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='none',
    title='NaCl',
    save='absorption_probabilities_m_nacl.png',
)

In [None]:
# Circular projection of the NaCl myeloid fate probabilities
cr.pl.circular_projection(
    g_m_nacl.adata,
    keys='cell_type_fine',
    figsize=(7.5, 7.5),
    legend_loc='none',
    lineages=['RPM', 'cDC2 (2)', 'cDC1 (2)', 'ncMo (1)'],
    lineage_order='default',
    title='',
)

## Stress (CpG)

In [None]:
# Velocity stream plot for the CpG myeloid cells
scv.pl.velocity_embedding_stream(
    g_m_cpg.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    density=1.5,
    linewidth=1,
    arrow_size=1,
)

In [None]:
# Compute 5 macrostates for the CpG myeloid cells and plot each in its own panel
g_m_cpg.compute_macrostates(
    n_states=5,
    cluster_key='cell_type_fine',
    n_cells=30,
)
g_m_cpg.plot_macrostates(
    same_plot=False,
    ncols=5,
    figsize=(3, 3),
)

In [None]:
# Declare the terminal macrostates of the CpG myeloid cells and plot them
g_m_cpg.set_terminal_states_from_macrostates(
    ['RPM', 'cDC2 (3)', 'cDC1 (1)', 'ncMo (1)'],
    n_cells=30,
)
g_m_cpg.plot_terminal_states(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='right',
)

In [None]:
# Absorption probabilities towards the terminal states (CpG myeloid cells).
# The figure title names the plotted condition (CpG), matching the output file name.
g_m_cpg.compute_absorption_probabilities(solver='gmres')
g_m_cpg.plot_absorption_probabilities(
    same_plot=True,
    ncols=5,
    figsize=(2, 3),
    legend_loc='none',
    title='CpG',
    save='absorption_probabilities_m_cpg.png',
)

In [None]:
# Circular projection of the CpG myeloid fate probabilities
cr.pl.circular_projection(
    g_m_cpg.adata,
    keys='cell_type_fine',
    figsize=(7.5, 7.5),
    legend_loc='none',
    lineages=['RPM', 'cDC2 (3)', 'cDC1 (1)', 'ncMo (1)'],
    lineage_order='default',
    title='',
)

# CellRank monocytes

In [None]:
# Monocytes: 40 % velocity kernel, 60 % connectivity kernel
g_mo_nacl = cr_workflow(vk=vk_mo_nacl, ck=ck_mo_nacl, vk_ratio=0.4, ck_ratio=0.6, n_components=10)
g_mo_cpg = cr_workflow(vk=vk_mo_cpg, ck=ck_mo_cpg, vk_ratio=0.4, ck_ratio=0.6, n_components=10)

In [None]:
# Recompute the monocyte UMAPs on copies so the estimator objects stay untouched
umap_mo_nacl = umap_workflow(
    g_mo_nacl.adata.copy(),
    suffix='_mo_nacl',
    n_neighbor=30,
    n_pcs=35,
    n_top_genes=2000,
)
umap_mo_cpg = umap_workflow(
    g_mo_cpg.adata.copy(),
    suffix='_mo_cpg',
    n_neighbor=30,
    n_pcs=35,
    n_top_genes=2000,
)

In [None]:
# Attach the recomputed UMAPs, rows selected by cell name in estimator order
g_mo_nacl.adata.obsm['X_umap'] = umap_mo_nacl.loc[g_mo_nacl.adata.obs_names, :].to_numpy()
g_mo_cpg.adata.obsm['X_umap'] = umap_mo_cpg.loc[g_mo_cpg.adata.obs_names, :].to_numpy()

## CellRank monocytes (NaCl)

In [None]:
# Velocity stream plot for the NaCl monocytes (title matches the plotted condition)
scv.pl.velocity_embedding_stream(
    g_mo_nacl.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    title='NaCl',
    density=1.5,
    linewidth=1,
    arrow_size=1,
    save='stream_mo_nacl.svg',
)

## CellRank monocytes (CpG)

In [None]:
# Velocity stream plot for the CpG monocytes
scv.pl.velocity_embedding_stream(
    g_mo_cpg.adata,
    vkey='velocity',
    basis='X_umap',
    color=['cell_type_fine'],
    size=50,
    alpha=1,
    legend_loc='none',
    title='CpG',
    density=1.5,
    linewidth=1,
    arrow_size=1,
    save='stream_mo_cpg.svg',
)

# Lineage diffusion pseudotime progenitors

In [None]:
# Progenitor absorption probabilities as cell-by-lineage tables, NaCl stacked on top of CpG
absorption_probabilities_nacl = pd.DataFrame(
    g_prog_nacl.absorption_probabilities.X,
    columns=g_prog_nacl.absorption_probabilities.names.tolist(),
    index=g_prog_nacl.adata.obs_names,
)
absorption_probabilities_cpg = pd.DataFrame(
    g_prog_cpg.absorption_probabilities.X,
    columns=g_prog_cpg.absorption_probabilities.names.tolist(),
    index=g_prog_cpg.adata.obs_names,
)
absorption_probabilities = pd.concat(
    [absorption_probabilities_nacl, absorption_probabilities_cpg],
    axis=0,
)

In [None]:
# Rebuild AnnData objects from the raw slot of each progenitor estimator and stack them
adata_nacl = ad.AnnData(
    X=g_prog_nacl.adata.raw.X,
    obs=g_prog_nacl.adata.obs,
    var=g_prog_nacl.adata.raw.var,
)
adata_cpg = ad.AnnData(
    X=g_prog_cpg.adata.raw.X,
    obs=g_prog_cpg.adata.obs,
    var=g_prog_cpg.adata.raw.var,
)

adata = ad.concat([adata_nacl, adata_cpg])

In [None]:
# Carry the per-condition UMAP coordinates over to the combined object (NaCl rows first)
umap_nacl = pd.DataFrame(
    g_prog_nacl.adata.obsm['X_umap'],
    index=g_prog_nacl.adata.obs_names,
)
umap_cpg = pd.DataFrame(
    g_prog_cpg.adata.obsm['X_umap'],
    index=g_prog_cpg.adata.obs_names,
)

adata.obsm['X_umap_seurat'] = pd.concat([umap_nacl, umap_cpg], axis=0).to_numpy()

In [None]:
# Cells annotated as one of the terminal cell types get probability 1.0 for their own lineage.
# Written with .loc on the DataFrame itself: an in-place mask on a selected column
# is a chained update that pandas does not guarantee to write back to the frame.
absorption_probabilities.loc[adata.obs['cell_type_fine']=='MastP', 'MastP'] = 1.0
absorption_probabilities.loc[adata.obs['cell_type_fine']=='MegP', 'MegP'] = 1.0
absorption_probabilities.loc[adata.obs['cell_type_fine']=='EB (5)', 'EB (5)'] = 1.0

In [None]:
# Where a cell has probability >= 1 for one lineage, set the other lineages to 0.
# The three updates run in this order on purpose: each condition reads the columns
# as left by the previous line. Written with .loc on the DataFrame itself, because an
# in-place mask on a selected column is not guaranteed to write back to the frame.
absorption_probabilities.loc[(absorption_probabilities['MegP']>=1) | (absorption_probabilities['EB (5)']>=1), 'MastP'] = 0
absorption_probabilities.loc[(absorption_probabilities['MastP']>=1) | (absorption_probabilities['EB (5)']>=1), 'MegP'] = 0
absorption_probabilities.loc[(absorption_probabilities['MegP']>=1) | (absorption_probabilities['MastP']>=1), 'EB (5)'] = 0

In [None]:
# Plotting: source the R script and convert its first return value
# (a named list of named lists) into a nested dict {category: {level: value}}
import rpy2.robjects as robjects

color_load = robjects.r.source('plotting_global.R')
color = {
    category: {
        key: color_load[0][i].rx2(key)[0]
        for key in color_load[0][i].names
    }
    for i, category in enumerate(color_load[0].names)
}

In [None]:
def set_color(categories):
    """Order the categories of obs columns and register their colors on the global ``adata``.

    For every entry of ``categories`` that is a column of ``adata.obs``, the column
    is converted to a categorical, its categories are reordered to follow the key
    order of ``color[<column>]`` (keys not occurring in the column are dropped),
    and the matching values are stored in ``adata.uns['<column>_colors']``.
    Reads the module-level ``adata`` and ``color`` objects.
    """
    obs_columns = list(adata.obs.columns)

    for category in categories:
        if category not in obs_columns:
            continue

        adata.obs[category] = pd.Series(adata.obs[category], dtype='category')

        # Keep only the color keys that actually occur in this column
        observed = list(adata.obs[category])
        keys = [key for key in color[category].keys() if key in observed]

        adata.obs[category] = adata.obs[category].cat.reorder_categories(keys)
        adata.uns[category+'_colors'] = np.array(
            [color[category].get(key) for key in keys],
            dtype=object,
        )

# Set colors
set_color(list(color))

In [None]:
# Diffusion pseudotime along the MastP lineage
adata_mastp = dpt_workflow(
    adata,
    absorption_probabilities,
    suffix='_mastp',
    lineage='MastP',
    cluster=['MastP'],
    compute=True,
)

In [None]:
# Diffusion pseudotime along the MegP lineage
adata_megp = dpt_workflow(
    adata,
    absorption_probabilities,
    suffix='_megp',
    lineage='MegP',
    cluster=['MegP'],
    compute=True,
)

In [None]:
# Diffusion pseudotime along the EB (5) lineage, always keeping all ProEB/EB clusters
adata_eb = dpt_workflow(
    adata,
    absorption_probabilities,
    suffix='_eb',
    lineage='EB (5)',
    cluster=[
        'ProEB (1)', 'ProEB (2)', 'ProEB (3)', 'ProEB (4)',
        'EB (1)', 'EB (2)', 'EB (3)', 'EB (4)', 'EB (5)',
    ],
    compute=True,
)