# Joint analysis of paired and unpaired multiomic data with MultiVI

MultiVI is used for the joint analysis of scRNA and scATAC-seq datasets that were jointly profiled (multiomic / paired) and single-modality datasets (only scRNA or only scATAC). MultiVI uses the paired data as an anchor to align and merge the latent spaces learned from each individual modality.

This notebook is based on the [scvi-tools MultiVI tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/multimodal/MultiVI_tutorial.html), which walks through how to read multiomic data, create a joint object with paired and unpaired data, set up and train a MultiVI model, visualize the resulting latent space, and run differential analyses.

## this notebook is modified directly from the scvi-tools tutorial

[MultiVI tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/MultiVI_tutorial.html)

<div class="alert alert-info">
Important

MultiVI requires the datasets to use shared features. scATAC-seq datasets need to be processed to use a shared set of peaks.

</div>

In [None]:
# shell out to `date` to log the wall-clock time of this run
!date

#### import libraries

In [None]:
import scvi
from numpy import where
import scanpy as sc
from anndata import AnnData
from anndata import concat as ad_concat
from pandas import read_csv, concat
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context

import random
# seed Python's built-in RNG; random.sample is used by random_cells_subset below
random.seed(42)

import warnings
# NOTE: this silences every warning category for the rest of the notebook
warnings.filterwarnings('ignore')

# seed setting of scvi-tools
scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# naming
project = 'aging_phase2'

# directories
# NOTE(review): absolute path is specific to one machine; adjust wrk_dir when running elsewhere
wrk_dir = '/home/jupyter/brain_aging_phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
# scanpy writes figures requested via `save=` into this directory
sc.settings.figdir = f'{figures_dir}/'
public_dir = f'{wrk_dir}/public'

# in files
arc_file = f'{quants_dir}/{project}_ARC.raw.h5ad'
arc_scrublet_file = f'{quants_dir}/{project}_ARC.scrublet_scores.csv'
gex_file = f'{quants_dir}/{project}_GEX.raw.h5ad'
gex_scrublet_file = f'{quants_dir}/{project}_GEX.scrublet_scores.csv'
atac_file = f'{quants_dir}/{project}_ATAC.raw.h5ad'

# out files
raw_anndata_file =f'{quants_dir}/{project}.raw.h5ad'
out_h5ad_file = f'{quants_dir}/{project}.multivi.h5ad'
trained_model_path = f'{models_dir}/{project}_trained_multivi'
mvi_normalized_exp_file = f'{quants_dir}/{project}.multivi_norm_exp.parquet'
mvi_peak_est_file = f'{quants_dir}/{project}.multivi_peak_est.parquet'
mvi_df_results_file = f'{quants_dir}/{project}.multivi_diff_exp_by_cluster.parquet'

# variables and constants
# when True, cells display extra samples and summary tables
DEBUG = False
# used as a fraction of the cell count (0.005 == 0.5%) for the min_cells feature filter
MIN_CELL_PERCENT = 0.005
# cells whose pct_counts_mt is not below this value are removed
MAX_MITO_PERCENT = 10
# when True, each dataset is randomly subset before integration
TESTING = False
# intended number of cells per dataset for the TESTING subset
testing_cell_size = 5000
# when True, sc.pp.highly_variable_genes is run
DETECT_HV_FEATURES = True
# passed as `subset=` to sc.pp.highly_variable_genes
FILTER_HV_FEATURES = False
# fraction of all features requested as n_top_genes
TOP_FEATURES_PERCENT = 0.15
# resolution passed to sc.tl.leiden
leiden_res = 1.0
# when False, training and saving are skipped and the model is loaded from trained_model_path
RUN_TRAINING = False

### load data

#### load ARC data

In [None]:
%%time
# read the ARC AnnData; it is later passed as the paired (multiome) input
arc_data = sc.read_h5ad(arc_file)
# discovery data from Duffy
arc_data.obs['Study'] = 'LNG'
arc_data.obs['Study_type'] = 'discovery'
print(arc_data)
if DEBUG:
    display(arc_data.obs.head())

In [None]:
# scrublet results for the ARC cells, indexed by the first CSV column
arc_srublet_df = read_csv(arc_scrublet_file, index_col=0)
print(f'shape of ARC scrublet df {arc_srublet_df.shape}')
if DEBUG:
    display(arc_srublet_df.predicted_doublet.value_counts())    
    display(arc_srublet_df.sample(5))

#### load GEX data

In [None]:
%%time
gex_data = sc.read_h5ad(gex_file)
print(gex_data)
# filter non-demultiplex cells
gex_data = gex_data[~gex_data.obs['donor_id'].isna()]
# discovery data from Duffy
gex_data.obs['Study'] = 'LNG'
gex_data.obs['Study_type'] = 'discovery'
print(gex_data)
if DEBUG:
    display(gex_data.obs.sample(5))
    display(gex_data.var.sample(5))    

In [None]:
# scrublet results for the GEX cells, indexed by the first CSV column
gex_srublet_df = read_csv(gex_scrublet_file, index_col=0)
print(f'shape of GEX scrublet df {gex_srublet_df.shape}')
if DEBUG:
    display(gex_srublet_df.predicted_doublet.value_counts())    
    display(gex_srublet_df.sample(5))

#### filter any of the additional non-genotype predicted doublets

the doublets detected based on genotype using demux to demultiplex the GEX pools have already been removed

In [None]:
# keep only the barcodes that scrublet did not flag as predicted doublets
# ARC
arc_singlets = arc_srublet_df.index[~arc_srublet_df.predicted_doublet]
arc_keep = arc_data.obs.index.isin(arc_singlets)
arc_data = arc_data[arc_keep]
# GEX
gex_singlets = gex_srublet_df.index[~gex_srublet_df.predicted_doublet]
gex_keep = gex_data.obs.index.isin(gex_singlets)
gex_data = gex_data[gex_keep]
print(f'ARC: {arc_data}')
print(f'GEX: {gex_data}')
if DEBUG:
    for filtered in (arc_data, gex_data):
        display(filtered.obs.sample(5))

##### load public reference GEX data
Leng et al., entorhinal cortex samples; keep only the Braak Stage 0 samples (n=3).

In [None]:
%%time
ec_file = f'{public_dir}/cellxgene_collections/Leng_entorhinal_cortex.h5ad'
adata_ref_ec = sc.read_h5ad(ec_file)
# retain original barcode
adata_ref_ec.obs['Barcode'] = adata_ref_ec.obs.index.astype('category')
# filter by Braak Stage
adata_ref_ec = adata_ref_ec[adata_ref_ec.obs.BraakStage == '0']
# reference data from Leng
adata_ref_ec.obs['Study'] = 'Leng'
adata_ref_ec.obs['Study_type'] = 'reference'
print(adata_ref_ec)
if DEBUG:
    display(adata_ref_ec.obs.sample(5))
    display(adata_ref_ec.var.sample(5))   

##### switch var attribute to use gene name instead of ID

In [None]:
# keep the original var index (the gene ID) as a column before replacing the index
adata_ref_ec.var['gene_id'] = adata_ref_ec.var.index.astype('category')
# NOTE(review): feature_name is assumed to be unique per feature — duplicated
# names would give a non-unique var index; confirm against the reference file
adata_ref_ec.var.index = adata_ref_ec.var.feature_name
if DEBUG:
    display(adata_ref_ec.var.head(10))

##### harmonize Leng et al reference data obs

In [None]:
# strip the development_stage suffixes so only the age text remains;
# regex=False because both patterns are literal strings
adata_ref_ec.obs['age'] = (adata_ref_ec.obs.development_stage
                           .str.replace('-year-old human stage', '', regex=False)
                           .str.replace(' year-old and over human stage', '', regex=False))
# obs columns of the reference that are not carried into the combined object
drop_cols = ['donor_id', 'BraakStage', 'nUMI', 'nGene', 'initialClusterAssignments', 
             'cell_type', 'tissue_ontology_term_id', 'cell_type_ontology_term_id', 
             'assay_ontology_term_id', 'disease_ontology_term_id', 
             'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 
             'is_primary_data', 'organism_ontology_term_id', 'suspension_type', 
             'assay', 'disease', 'organism', 'self_reported_ethnicity', 
             'seurat.clusters', 'sex_ontology_term_id', 'development_stage', 'tissue', 'Barcode']
adata_ref_ec.obs = adata_ref_ec.obs.drop(columns=drop_cols)
# rename the remaining reference columns to the names used by the discovery data
adata_ref_ec.obs = adata_ref_ec.obs.rename(columns={'SampleID': 'sample_id', 
                                                'SampleBatch': 'gex_pool', 
                                                'clusterAssignment': 'phase1_cluster', 
                                                'clusterCellType': 'phase1_celltype'})
drop_cols = ['feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype']
adata_ref_ec.var = adata_ref_ec.var.drop(columns=drop_cols)
# drop any genes that aren't in our GEX; copy so the result is a real AnnData
# rather than a view before obsm is reset
intersect_features = gex_data.var.index.intersection(adata_ref_ec.var.index)
adata_ref_ec = adata_ref_ec[:, intersect_features].copy()
adata_ref_ec.obsm = None

if DEBUG:
    display(adata_ref_ec.obs.sample(10))
    display(adata_ref_ec.var.sample(10))

##### combine the discovery and reference GEX data

In [None]:
# outer join keeps the union of features from the discovery and reference GEX data
adata_gex = ad_concat([gex_data, adata_ref_ec], join='outer')
# barcodes may repeat across the two sources, so make the obs names unique
adata_gex.obs_names_make_unique()
print(adata_gex)
if DEBUG:
    display(adata_gex.obs.sample(5))
    display(adata_gex.var.sample(5))

#### load ATAC data

In [None]:
%%time
atac_data = sc.read_h5ad(atac_file)
print(atac_data)
# filter non-demultiplex cells
atac_data = atac_data[~atac_data.obs['donor_id'].isna()]
# discovery data from Duffy
atac_data.obs['Study'] = 'LNG'
atac_data.obs['Study_type'] = 'discovery'
print(atac_data)
if DEBUG:
    display(atac_data.obs.sample(5))

#### if testing notebook subset the cells

In [None]:
def random_cells_subset(adata: AnnData, num_cells: int=5000, 
                        verbose: bool=False) -> AnnData:
    """Return a random subset of the cells (observations) of `adata`.

    Cells are drawn without replacement with `random.sample`, so the result
    follows the state of Python's global `random` generator.

    Args:
        adata: object to subset; indexed by a list of obs names.
        num_cells: number of cells to draw. When `adata` holds fewer cells,
            all of them are returned (in random order) instead of raising.
        verbose: when True, print the subset and display a few obs rows.

    Returns:
        The result of indexing `adata` with the sampled obs names.
    """
    barcodes = list(adata.obs.index.values)
    # random.sample raises ValueError when asked for more items than exist
    sample_size = min(num_cells, len(barcodes))
    cells_subset = random.sample(barcodes, sample_size)
    adata = adata[cells_subset]
    if verbose:
        print(adata)
        # never ask for more preview rows than the subset holds
        display(adata.obs.sample(min(5, len(adata.obs))))
    return adata

if TESTING:
    # use the configured testing size rather than the function default
    arc_data = random_cells_subset(arc_data, num_cells=testing_cell_size, verbose=DEBUG)
    adata_gex = random_cells_subset(adata_gex, num_cells=testing_cell_size, verbose=DEBUG)
    atac_data = random_cells_subset(atac_data, num_cells=testing_cell_size, verbose=DEBUG)

We can then use the `organize_multiome_anndatas` function to organize these three datasets into a single Multiome dataset.
This function sorts and orders the data from the multi-modal and modality-specific AnnDatas into a single AnnData (aligning the features, padding missing modalities with 0s, etc). 

In [None]:
%%time
# We can now use the organizing method from scvi to concatenate these anndata
# argument order: paired (ARC) data, expression-only data, accessibility-only data
adata_mvi = scvi.data.organize_multiome_anndatas(arc_data, adata_gex, atac_data)

Note that `organize_multiome_anndatas` adds an annotation to the cells to indicate which modality they originate from:

In [None]:
# summarize the combined object and the number of cells per modality label
print(adata_mvi)
display(adata_mvi.obs.modality.value_counts())
if DEBUG:
    display(adata_mvi.obs.sample(5))

#### clean up the obs as needed
The ARC samples are pooled, but pool info will be used as a categorical covariate, so fill in a placeholder pool value (`non`) for cells that have none.
Also drop the batch ID.

In [None]:
if DEBUG:
    display(adata_mvi.obs.gex_pool.value_counts())
    display(adata_mvi.obs.atac_pool.value_counts())

# convert the age obs attribute from string to float
adata_mvi.obs['age'] = adata_mvi.obs['age'].astype('float')

# cells without a pool get the placeholder 'non'; go through object dtype so
# fillna also works when the column is categorical without a 'non' category
for pool_col in ('gex_pool', 'atac_pool'):
    adata_mvi.obs[pool_col] = (adata_mvi.obs[pool_col].astype('object')
                               .fillna('non').astype('str'))
# errors='ignore' keeps this cell re-runnable once batch_id is already gone
adata_mvi.obs = adata_mvi.obs.drop(columns=['batch_id'], errors='ignore')

# cells without a phase1 label (missing value or the literal 'NA') are
# labelled 'phase2'; no `.cat` accessor is used so the cell can be re-run
for label_col in ('phase1_cluster', 'phase1_celltype'):
    labels = adata_mvi.obs[label_col]
    unlabeled = (labels.isna() | (labels == 'NA')).to_numpy()
    adata_mvi.obs[label_col] = where(unlabeled, 'phase2', labels)

if DEBUG:
    display(adata_mvi.obs.sample(10))
    display(adata_mvi.obs.gex_pool.value_counts())
    display(adata_mvi.obs.atac_pool.value_counts())
    display(adata_mvi.obs.phase1_cluster.value_counts())
    display(adata_mvi.obs.phase1_celltype.value_counts())

<div class="alert alert-info">
Important

MultiVI requires the features to be ordered so that genes appear before genomic regions. This must be enforced by the user.

</div>

MultiVI requires the features to be ordered such that genes appear before genomic regions. Here the features are already in that order, but it's always good to verify:

In [None]:
# reorder the features by the argsort of their var "modality" label
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()
if DEBUG:
    display(adata_mvi.var)

#### save the MultiVi organized but unprocessed anndata object note that the subject is in the obs
only the discovery not the public reference

In [None]:
%%time
# write only the cells marked as discovery (the public reference is excluded)
adata_disc_mvi = adata_mvi[adata_mvi.obs.Study_type == 'discovery']
print(adata_disc_mvi)
adata_disc_mvi.write(raw_anndata_file)

In [None]:
# summary of the combined object before QC filtering
print(adata_mvi)

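We also filter out cells whose mitochondrial count percentage is not below `MAX_MITO_PERCENT`, and remove features that appear in fewer than `MIN_CELL_PERCENT` (a fraction) of the cells.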

In [None]:
print(adata_mvi.shape)
# annotate the group of mitochondrial genes as 'mt'
adata_mvi.var['mt'] = adata_mvi.var_names.str.startswith('MT-')
# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
sc.pp.calculate_qc_metrics(adata_mvi, qc_vars=['mt'], percent_top=None, 
                           log1p=False, inplace=True)
# copy after the cell filter: filter_genes below modifies the object in place,
# which should not happen on a view
adata_mvi = adata_mvi[adata_mvi.obs.pct_counts_mt < MAX_MITO_PERCENT, :].copy()
# keep features detected in at least MIN_CELL_PERCENT (a fraction) of the remaining cells
sc.pp.filter_genes(adata_mvi, min_cells=int(adata_mvi.shape[0] * MIN_CELL_PERCENT))

print(adata_mvi)

if DEBUG:
    display(adata_mvi.obs.sample(5))
    display(adata_mvi.var.sample(5))

#### if flag set then subset to highest variance features

MultiVI tutorial doesn't suggest this so probably typically will set to false

In [None]:
if DETECT_HV_FEATURES:
    # request TOP_FEATURES_PERCENT of all features currently in var
    n_top_genes = int(adata_mvi.var.shape[0] * TOP_FEATURES_PERCENT)
    # NOTE(review): the call runs over every feature in adata_mvi (var holds both
    # 'Gene Expression' and 'Peaks' entries) with batches defined by atac_pool —
    # confirm this is the intended feature set and batch key
    sc.pp.highly_variable_genes(adata_mvi, n_top_genes=n_top_genes, 
                                batch_key='atac_pool',flavor='seurat_v3', 
                                subset=FILTER_HV_FEATURES)
    print(adata_mvi)
    print(adata_mvi.obs.modality.value_counts())
    print(adata_mvi.var.modality.value_counts())

## Setup and Training MultiVI
We can now set up and train the MultiVI model!

First, we need to setup the Anndata object using the `setup_anndata` function. At this point we specify any batch annotation that the model would account for.
**Importantly**, the main batch annotation, specified by `batch_key`, should correspond to the modality of the cells.

Other batch annotations (e.g if there are multiple ATAC batches) should be provided using the `categorical_covariate_keys`.

The actual values of categorical covariates (include `batch_key`) are not important, as long as they are different for different samples.
I.e it is not important to call the expression-only samples "expression", as long as they are called something different than the multi-modal and accessibility-only samples.

<div class="alert alert-info">
Important

MultiVI requires the main batch annotation to correspond to the modality of the samples. Other batch annotation, such as in the case of multiple RNA-only batches, can be specified using `categorical_covariate_keys`.

</div>

In [None]:
# modality is the main batch key; sample and pool labels are extra categorical covariates
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality', 
                                 categorical_covariate_keys = ['sample_id', 'gex_pool', 'atac_pool']) 

When creating the object, we need to specify how many of the features are genes, and how many are genomic regions. This is so MultiVI can determine the exact architecture for each modality.

In [None]:
# feature counts per modality are taken from the var "modality" labels
mvi = scvi.model.MULTIVI(
    adata_mvi, 
    n_genes=(adata_mvi.var['modality']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
)
mvi.view_anndata_setup()

In [None]:
%%time
# training only runs when RUN_TRAINING is set; otherwise a saved model is loaded below
if RUN_TRAINING:
    mvi.train()

## Save and Load MultiVI models

Saving and loading models is similar to all other scvi-tools models, and is very straight forward:

In [None]:
# only a freshly trained model is written; an existing save at this path is overwritten
if RUN_TRAINING:
    mvi.save(trained_model_path, overwrite=True)

In [None]:
# always (re)load the model from trained_model_path and attach adata_mvi
# NOTE(review): accelerator is hard-coded to 'gpu' — presumably this fails on a CPU-only machine; confirm
mvi = scvi.model.MULTIVI.load(trained_model_path, adata=adata_mvi, accelerator='gpu')

## Extracting and visualizing the latent space

We can now use the `get_latent_representation` to get the latent space from the trained model, and visualize it using scanpy functions:

In [None]:
# store the model's latent representation under obsm for the neighbors graph below
adata_mvi.obsm['MultiVI_latent'] = mvi.get_latent_representation()

#### embed the graph based on latent representation

In [None]:
%%time
# build the neighbor graph on the MultiVI latent space rather than on X
sc.pp.neighbors(adata_mvi, use_rep='MultiVI_latent')
# sc.tl.umap(adata_mvi, min_dist=0.3)
# UMAP with scanpy's default parameters
sc.tl.umap(adata_mvi)

#### visualize the latent representation

In [None]:
figure_file = f'_{project}.umap.samples.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['sample_id'], save=figure_file)

In [None]:
figure_file = f'_{project}.umap.studies.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['Study'], save=figure_file)

### Clustering on the MultiVI latent space
The user will note that we imported curated labels from the original publication. Our interface with scanpy makes it easy to cluster the data with scanpy from MultiVI's latent space and then reinject them into MultiVI (e.g., for differential expression).

In [None]:
%%time
# the neighbor graph was already computed on the MultiVI latent representation
sc.tl.leiden(adata_mvi, key_added='leiden_MultiVI', resolution=leiden_res)

In [None]:
figure_file = f'_{project}.umap.leiden_on.png'
# Leiden clusters with the labels drawn on the embedding
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_mvi, color=['leiden_MultiVI'], 
               frameon=False, legend_loc='on data', save=figure_file)

In [None]:
figure_file = f'_{project}.umap.leiden_off.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['leiden_MultiVI'], 
                   frameon=False, save=figure_file)

In [None]:
figure_file = f'_{project}.umap.age.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['age'], 
                   frameon=False, save=figure_file)

In [None]:
# phase1 reference cluster labels on the embedding (figure is not saved)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_mvi, color=['phase1_cluster'], 
               frameon=False)

In [None]:
# phase1 cell-type labels on the embedding (figure is not saved)
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['phase1_celltype'], 
                   frameon=False)

In [None]:
# same plot with the labels drawn on the embedding (figure is not saved)
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['phase1_celltype'], 
                   frameon=False, legend_loc='on data')

In [None]:
# restrict to cells that carry a phase1 cell-type label
phase1_data = adata_mvi[adata_mvi.obs.phase1_celltype != 'phase2']
figure_file = f'_{project}.umap.phase1_celltype.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(phase1_data, color=['phase1_celltype'], 
                   frameon=False, legend_loc='on data', save=figure_file)

In [None]:
# summary of the subset of cells with a phase1 cell-type label
print(phase1_data)

### save quantification layers as needed

In a well-mixed space, MultiVI can seamlessly impute the missing modalities for single-modality cells.
First, imputing expression and accessibility is done with `get_normalized_expression` and `get_accessibility_estimates`, respectively.

We'll demonstrate this by imputing gene expression for all cells in the dataset (including those that are ATAC-only cells):

Here we save the accessibility estimates and null the object right away to save some memory.

In [None]:
%%time
# get accessibility estimates from the model and save them to parquet
accessibility = mvi.get_accessibility_estimates()
accessibility.to_parquet(mvi_peak_est_file)
print(f'shape of accessibility estimates {accessibility.shape}')
if DEBUG:
    display(accessibility.sample(5))    
# release the reference so the estimates can be garbage-collected
accessibility = None

# get normalized expression values from the model and save them to parquet;
# `expression` is kept because a later cell adds it as the 'X_mvi' layer
expression = mvi.get_normalized_expression()
expression.to_parquet(mvi_normalized_exp_file)
print(f'shape of normalized expression {expression.shape}')
if DEBUG:
    display(expression.sample(5))    

In [None]:
# summary of the object after the latent space, UMAP and clustering were added
print(adata_mvi)

### transfer cell types to replication data
split data set

In [None]:
# cells whose phase1_celltype is anything other than the 'phase2' placeholder
adata_np2 = adata_mvi[adata_mvi.obs.phase1_celltype != 'phase2']
print('#### non-phase2 data ####')
print(adata_np2)

#### for the non-phase2 data what is the likely cell-type per cluster
per leiden cluster which labeled cell-type is most frequent

In [None]:
# for every Leiden cluster seen in the non-phase2 cells, record the most
# frequent phase1 cell type and the most frequent phase1 cluster
cluster_to_celltype = {}
cluster_to_refcluster = {}
for cluster_num in adata_np2.obs.leiden_MultiVI.unique():
    in_cluster = adata_np2.obs.leiden_MultiVI == cluster_num
    members = adata_np2.obs.loc[in_cluster]
    celltype_counts = members.phase1_celltype.value_counts()
    refcluster_counts = members.phase1_cluster.value_counts()
    cluster_to_celltype[cluster_num] = celltype_counts.idxmax()
    cluster_to_refcluster[cluster_num] = refcluster_counts.idxmax()
    if DEBUG:
        display(celltype_counts.head())
        display(refcluster_counts.head())
display(cluster_to_celltype)
display(cluster_to_refcluster)

#### assign the labels

In [None]:
# give every cell the majority reference labels of its Leiden cluster
for cluster_num in adata_mvi.obs.leiden_MultiVI.unique():
    # .get returns None for a cluster that is missing from the lookup dicts
    cell_type = cluster_to_celltype.get(cluster_num)
    ref_cluster = cluster_to_refcluster.get(cluster_num)
    print(cluster_num, cell_type, ref_cluster)
    adata_mvi.obs.loc[adata_mvi.obs.leiden_MultiVI == cluster_num, 'Cell_type'] = cell_type
    adata_mvi.obs.loc[adata_mvi.obs.leiden_MultiVI == cluster_num, 'RefCluster'] = ref_cluster    
if DEBUG:
    display(adata_mvi.obs.Cell_type.value_counts())
    display(adata_mvi.obs.RefCluster.value_counts())

### visualize likely cell-types

In [None]:
figure_file = f'_{project}.umap.likely_celltype.png'
# enter the style sheet first: the talk style defines its own figure.figsize,
# so the explicit size/dpi must be applied after it to take effect
with plt.style.context('seaborn-v0_8-talk'):
    with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
        sc.pl.umap(adata_mvi, color=['Cell_type'], 
                   frameon=False, legend_loc='on data', save=figure_file)

### save the modified anndata object

In [None]:
# write the annotated object (latent space, embedding, clusters, transferred labels)
adata_mvi.write(out_h5ad_file)

### now that the mvi adata object has already been saved, add the normalized expression as a layer for visualization of known marker genes

In [None]:
# subset to the genes returned by the model; copy so the layer is added to a
# real AnnData instead of a view of adata_mvi
adata_exp = adata_mvi[:, expression.columns.to_list()].copy()
# store the values as a plain array rather than a DataFrame
adata_exp.layers['X_mvi'] = expression.to_numpy()

We can demonstrate this on some known marker genes:

In [None]:
def plot_gene_in_umap(adata: AnnData, gene: str, layer: str='X_mvi'):
    """Plot one gene on the UMAP embedding twice.

    The first plot calls `sc.pl.umap` without a layer argument; the second
    colors by the values stored in `adata.layers[layer]`.

    Args:
        adata: object with a computed UMAP and the requested layer.
        gene: feature name to color by; must be present in `adata.var.index`.
        layer: name of the layer used for the second plot.

    Returns:
        None. When `gene` is missing, a message is printed and nothing is plotted.
    """
    if gene not in adata.var.index:
        print(f'{gene} not present')
        return
    # enter the style sheet first: the talk style defines its own figure.figsize,
    # so the explicit size/dpi must be applied after it to take effect
    with plt.style.context('seaborn-v0_8-talk'):
        with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
            sc.pl.umap(adata, color=gene)
            sc.pl.umap(adata, color=gene, layer=layer)

Neuron, SNAP25:

In [None]:
# neuron marker
plot_gene_in_umap(adata_exp, 'SNAP25')

GABAergic, GAD1:

In [None]:
# GABAergic marker
plot_gene_in_umap(adata_exp, 'GAD1')

Glutamatergic, GRIN1:

In [None]:
# glutamatergic marker
plot_gene_in_umap(adata_exp, 'GRIN1')

Microglia, CSF1R:

In [None]:
# microglia marker
plot_gene_in_umap(adata_exp, 'CSF1R')        

Astrocyte, GFAP:

In [None]:
# astrocyte marker
plot_gene_in_umap(adata_exp, 'GFAP')           

Oligodendrocyte, PLP1:

In [None]:
# oligodendrocyte marker
plot_gene_in_umap(adata_exp, 'PLP1')    

These marker genes clearly identify their respective populations. Importantly, the imputed gene expression profiles are stable and consistent within each population, **even though only the ATAC profile was measured for many of those cells**.

### Differential expression of Leiden clusters

In [None]:
%%time
# differential expression from the model, grouped by the Leiden cluster labels
de_df = mvi.differential_expression(groupby='leiden_MultiVI',)
if DEBUG:
    display(de_df.sample(10))

#### save the differential expression results

In [None]:
# persist the per-cluster differential expression table
de_df.to_parquet(mvi_df_results_file)

#### We now extract top markers for each cluster using the DE results.

In [None]:
# collect up to `number_of_top_markers` genes per Leiden cluster from the
# "<cluster> vs Rest" comparison; every cluster keeps an entry, even when no
# gene passes the filters
markers = {}
number_of_top_markers = 5
cats = adata_exp.obs['leiden_MultiVI'].cat.categories
for c in cats:
    cluster_de = de_df.loc[de_df.comparison == f"{c} vs Rest"]
    # positive effect, bayes factor above 3, expressed in >10% of the cluster's cells
    passes = (
        (cluster_de.lfc_mean > 0)
        & (cluster_de["bayes_factor"] > 3)
        & (cluster_de["non_zeros_proportion1"] > 0.1)
    )
    markers[c] = cluster_de.index[passes].tolist()[:number_of_top_markers]

In [None]:
# cluster dendrogram computed on the MultiVI latent space; the plots below request dendrogram=True
sc.tl.dendrogram(adata_exp, groupby='leiden_MultiVI', use_rep='MultiVI_latent')

In [None]:
figure_file = f'{project}.cluster_markers.png'
# dot plot of the per-cluster marker genes, each gene scaled across clusters (standard_scale='var')
with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 200}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.dotplot(adata_exp, markers, groupby='leiden_MultiVI', dendrogram=True,
                  color_map='Blues', swap_axes=True, use_raw=False,
                  standard_scale='var', save=figure_file)

#### We can also visualize the MultiVI normalized gene expression values with the layer option.

In [None]:
# NOTE(review): unlike the umap figure names, figure_file has no leading underscore —
# check the file name scanpy produces for the saved heatmap
figure_file = f'{project}.cluster_markers_heatmap.png'
# heatmap of the same markers drawn from the 'X_mvi' layer; figsize is passed explicitly
with rc_context({'figure.figsize': (15, 15), 'figure.dpi': 200, 'font.size': 6}):
    plt.style.use('seaborn-v0_8-talk')
    sc.pl.heatmap(adata_exp, markers, groupby='leiden_MultiVI', layer='X_mvi', 
                  standard_scale='var', dendrogram=True, figsize=(8, 12),
                  show_gene_labels=True, save=figure_file)

In [None]:
# shell out to `date` to log the wall-clock time at the end of the run
!date