# Joint analysis of paired and unpaired multiomic data with MultiVI

MultiVI is used for the joint analysis of scRNA and scATAC-seq datasets that were jointly profiled (multiomic / paired) and single-modality datasets (only scRNA or only scATAC). MultiVI uses the paired data as an anchor to align and merge the latent spaces learned from each individual modality.

Based on the [scvi-tools MultiVI tutorial](https://docs.scvi-tools.org/en/stable/tutorials/notebooks/multimodal/MultiVI_tutorial.html) that walks through how to read multiomic data, create a joint object with paired and unpaired data, set-up and train a MultiVI model, visualize the resulting latent space, and run differential analyses. 

## Perform a joint clustering of the snRNA and snATAC modalities combined, using the cell-type annotations from clustering and review of the snRNA data

In [None]:
# shell escape: print the current date/time into the cell output
!date

#### import libraries

In [None]:
import scvi
from numpy import where
import scanpy as sc
from anndata import AnnData
from anndata import concat as ad_concat
from pandas import read_csv, concat, DataFrame, Series
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import torch
from seaborn import lineplot
from sklearn.metrics import silhouette_score
from numpy import arange, mean, percentile

import random
# seed Python's built-in RNG (random.sample is used further down to subset cells)
random.seed(42)

import warnings
# NOTE(review): this silences every warning category for the rest of the notebook
warnings.filterwarnings('ignore')

# seed scvi-tools with the same value
scvi.settings.seed = 42

%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# variables and constants
project = 'aging_phase2'
# DEBUG gates the extra print/display output throughout the notebook
DEBUG = True
# NOTE(review): MIN_CELL_PERCENT is not referenced anywhere else in the visible
# notebook; the feature filter below uses a hard-coded min_cells instead
MIN_CELL_PERCENT = 0.005
# cells at or above this pct_counts_mt are dropped during QC
MAX_MITO_PERCENT = 20
# when TESTING is true each input is randomly down-sampled before integration
TESTING = False
# NOTE(review): testing_cell_size is not referenced in the visible notebook; the
# subsetting helper falls back to its own default of 5000
testing_cell_size = 5000
DETECT_HV_FEATURES = True
FILTER_HV_FEATURES = True
# fraction of features kept as highly variable (multiplied by the feature count)
TOP_FEATURES_PERCENT = 0.3 # 0.075
# NOTE(review): leiden_res is overwritten by the loop variable of the resolution sweep
leiden_res = 1.0
RUN_TRAINING = True
BATCH_SIZE = 10000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# directories
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = f'{wrk_dir}/quants'
models_dir = f'{wrk_dir}/models'
figures_dir = f'{wrk_dir}/figures'
sc.settings.figdir = f'{figures_dir}/'
resolution_dir = f'{quants_dir}/resolution_selection'

# in files
arc_file = f'{quants_dir}/{project}_ARC.raw.h5ad'
arc_scrublet_file = f'{quants_dir}/{project}_ARC.scrublet_scores.csv'
gex_file = f'{quants_dir}/{project}_GEX.raw.h5ad'
gex_scrublet_file = f'{quants_dir}/{project}_GEX.scrublet_scores.csv'
atac_file = f'{quants_dir}/{project}_ATAC.raw.h5ad'
exn_cell_assignments = f'{resolution_dir}/ExN_res_obs_annotated_p1only_res_0.4.csv'
inh_cell_assignments = f'{resolution_dir}/InN_res_obs_annotated_p1only_res_0.4.csv'
nonneuronal_h5ad = f'{quants_dir}/{project}_NonNeuronal.dev.rna.scvi.h5ad'

# out files
raw_anndata_file =f'{quants_dir}/{project}.dev.raw.multivi_prep.h5ad'
out_h5ad_file = f'{quants_dir}/{project}.dev.multivi.h5ad'
trained_model_path = f'{models_dir}/{project}_dev_trained_multivi'

# echo the resolved input paths and the selected device
if DEBUG:
    print(f'{arc_file=}')
    print(f'{arc_scrublet_file=}')
    print(f'{gex_file=}')
    print(f'{gex_scrublet_file=}')
    print(f'{atac_file=}')
    print(f'{exn_cell_assignments=}')
    print(f'{inh_cell_assignments=}')
    print(f'{nonneuronal_h5ad=}')
    print(f'{device=}')

#### functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a short summary of an AnnData object.

    Args:
        adata: object to summarize; it is handed to ``print``.
        message: optional heading printed first. ``None`` or an empty string
            prints no heading.
        verbose: when true, also display the first rows of ``adata.obs`` and
            ``adata.var``.
    """
    # truthiness covers both the None and the empty-string case
    if message:
        print(message)
    print(adata)
    if verbose:
        display(adata.obs.head())
        display(adata.var.head())

def peek_dataframe(df: DataFrame, message: str=None, verbose: bool=False):
    """Print the shape of a DataFrame, optionally with a heading and a preview.

    Args:
        df: table to summarize.
        message: optional heading printed first. ``None`` or an empty string
            prints no heading.
        verbose: when true, also display the first rows of ``df``.
    """
    # truthiness covers both the None and the empty-string case
    if message:
        print(message)
    print(f'{df.shape=}')
    if verbose:
        display(df.head())

def heatmap_compare(adata: AnnData, set1: str, set2: str):
    """Draw a heatmap of how two ``obs`` annotations overlap.

    Cells are counted for every (set1, set2) combination; each column (one
    value of ``set2``) is then divided by its column total before plotting.

    Args:
        adata: object whose ``obs`` table holds both annotations.
        set1: ``obs`` column shown on the y axis.
        set2: ``obs`` column shown on the x axis.
    """
    counts = adata.obs.groupby([set1, set2]).size().unstack(fill_value=0)
    # column-wise scaling: every set2 value sums to 1 over the set1 values
    col_fractions = counts / counts.sum(axis=0)
    col_labels = counts.columns
    row_labels = counts.index

    with rc_context({'figure.figsize': (12, 12), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-bright')
        plt.pcolor(col_fractions, edgecolor='black')
        # ticks are offset by 0.5 so the labels sit at the centre of each tile
        plt.xticks(arange(0.5, len(col_labels), 1), col_labels, rotation=90)
        plt.yticks(arange(0.5, len(row_labels), 1), row_labels)
        plt.xlabel(set2)
        plt.ylabel(set1)
        plt.show()

## load data

### load ARC data

In [None]:
%%time
# read the raw ARC AnnData from disk
arc_data = sc.read_h5ad(arc_file)
# discovery data from Duffy
# tag every cell with constant study annotations
arc_data.obs['Study'] = 'LNG'
arc_data.obs['Study_type'] = 'discovery'
peek_anndata(arc_data, 'ARC anndata loaded', DEBUG)

In [None]:
# scrublet scores for the ARC cells; the first CSV column becomes the index
arc_srublet_df = read_csv(arc_scrublet_file, index_col=0)
peek_dataframe(arc_srublet_df, 'ARC scrublet info loaded', DEBUG)
if DEBUG:
    # how many cells were / were not called doublets
    display(arc_srublet_df.predicted_doublet.value_counts())

### load GEX data

In [None]:
%%time
gex_data = sc.read_h5ad(gex_file)
print(gex_data)
# filter non-demultiplex cells
gex_data = gex_data[~gex_data.obs['donor_id'].isna()]
# discovery data from Duffy
gex_data.obs['Study'] = 'LNG'
gex_data.obs['Study_type'] = 'discovery'
peek_anndata(gex_data, 'GEX anndata loaded', DEBUG)   

In [None]:
# scrublet scores for the GEX cells; the first CSV column becomes the index
gex_srublet_df = read_csv(gex_scrublet_file, index_col=0)
peek_dataframe(gex_srublet_df, 'GEX scrublet info loaded', DEBUG)
if DEBUG:
    # how many cells were / were not called doublets
    display(gex_srublet_df.predicted_doublet.value_counts())

#### filter any of the additional non-genotype predicted doublets

the doublets detected based on genotype using demux to demultiplex the GEX pools have already been removed

In [None]:
# Keep only the cells that scrublet did not call a doublet. Each result is
# copied so that later edits to .obs act on a standalone AnnData rather than on
# a view of the unfiltered object.
# ARC
non_doublet_barcodes = arc_srublet_df.loc[~arc_srublet_df.predicted_doublet].index.values
arc_data = arc_data[arc_data.obs.index.isin(non_doublet_barcodes)].copy()
# GEX
non_doublet_barcodes = gex_srublet_df.loc[~gex_srublet_df.predicted_doublet].index.values
gex_data = gex_data[gex_data.obs.index.isin(non_doublet_barcodes)].copy()
peek_anndata(arc_data, 'filtered ARC anndata', DEBUG)
peek_anndata(gex_data, 'filtered GEX anndata', DEBUG)

## load the annotated cell-types for the RNA cells

In [None]:
%%time
# per-cell annotation tables for the excitatory and inhibitory neurons, indexed
# by the CSV column named 'X'
exn_cells = read_csv(exn_cell_assignments, index_col='X')
inh_cells = read_csv(inh_cell_assignments, index_col='X')
# processed non-neuronal AnnData; only its obs table is used further down
adata_nonneur = sc.read_h5ad(nonneuronal_h5ad)
peek_dataframe(exn_cells, 'ExN cell types', DEBUG)
peek_dataframe(inh_cells, 'Inh cell types', DEBUG)
peek_anndata(adata_nonneur, 'processed Non-neuronal anndata loaded', DEBUG)

# combine the neuronal cells
# only the 'final_anno' column is kept from each table
neuronal_cells = concat([exn_cells[['final_anno']], inh_cells[['final_anno']]])
peek_dataframe(neuronal_cells, 'neuronal cell IDs to cell-type annotations', DEBUG)

if DEBUG:
    display(exn_cells.final_anno.value_counts())
    display(inh_cells.final_anno.value_counts())
    display(neuronal_cells.final_anno.value_counts())

### add the annotations for the non-neuronals that did not have the additional SAHA processing

In [None]:
# remove any cells already annotated as neuronal from the non-neuronal
nonner_cells = adata_nonneur.obs.copy()
print(len(set(neuronal_cells.index) & set(nonner_cells.index)))
nonner_cells = nonner_cells[~nonner_cells.index.isin(neuronal_cells.index)]
peek_dataframe(nonner_cells, 'only Non-neuronal cell types', DEBUG)

In [None]:
# split each 'liams_label' on '-' into separate columns (expand=True)
nonneur_labels = nonner_cells.liams_label.str.split('-', expand=True)
# just the cell type name is now first column of nonneur_labels
nonner_cells['final_anno'] = nonneur_labels[0]
if DEBUG:
    display(nonner_cells.final_anno.value_counts())

### combine the neuronal and non-neuronal cell ID to cell-type labels

In [None]:
# stack the neuronal and non-neuronal annotations into one cell ID -> label table
cell_labels = concat([neuronal_cells , nonner_cells[['final_anno']]])
cell_labels = cell_labels.rename(columns={'final_anno': 'cell_label'})
# make sure there are no duplicates
# for a repeated cell ID the first occurrence (neuronal rows come first) is kept
cell_labels = cell_labels[~cell_labels.index.duplicated(keep='first')]
peek_dataframe(cell_labels, 'cell IDs to cell-type annotations', DEBUG)

if DEBUG:
    display(cell_labels.cell_label.value_counts())

### apply the cell-type labels to the ARC and GEX cells

In [None]:
%%time
# left join on the obs index: every cell is kept and cells without an entry in
# cell_labels get a missing cell_label
arc_data.obs = arc_data.obs.merge(cell_labels, how='left', 
                                  left_index=True, right_index=True)
gex_data.obs = gex_data.obs.merge(cell_labels, how='left', 
                                  left_index=True, right_index=True)
peek_anndata(arc_data, 'ARC anndata with annotated cell-type labels', DEBUG)
peek_anndata(gex_data, 'GEX anndata with annotated cell-type labels', DEBUG)

if DEBUG:
    # label distribution and the number of unlabelled cells per dataset
    display(arc_data.obs.cell_label.value_counts())
    display(arc_data.obs.cell_label.isna().value_counts())
    display(gex_data.obs.cell_label.value_counts())
    display(gex_data.obs.cell_label.isna().value_counts())

## load ATAC data

In [None]:
%%time
atac_data = sc.read_h5ad(atac_file)
print(atac_data)
# filter non-demultiplex cells
atac_data = atac_data[~atac_data.obs['donor_id'].isna()]
# discovery data from Duffy
atac_data.obs['Study'] = 'LNG'
atac_data.obs['Study_type'] = 'discovery'
atac_data.obs['cell_label'] = 'Unknown'
peek_anndata(atac_data, 'ATAC anndata loaded', DEBUG)

## if testing notebook subset the cells

In [None]:
def random_cells_subset(adata: AnnData, num_cells: int=5000, 
                        verbose: bool=False) -> AnnData:
    """Return a random subset of the cells in ``adata``.

    Cells are drawn without replacement with ``random.sample``, so the result
    depends on the state of Python's ``random`` module.

    Args:
        adata: object to subset; cells are selected by their ``obs`` index.
        num_cells: number of cells to keep. When it exceeds the number of
            available cells, all cells are returned (in random order).
        verbose: when true, print the subset and display up to five randomly
            chosen ``obs`` rows.

    Returns:
        ``adata`` indexed by the sampled cell IDs.
    """
    barcodes = list(adata.obs.index.values)
    # random.sample raises ValueError when asked for more items than exist
    sample_size = min(num_cells, len(barcodes))
    cells_subset = random.sample(barcodes, sample_size)
    adata = adata[cells_subset]
    # honour the caller's flag instead of the notebook-level DEBUG constant
    if verbose:
        print(adata)
        display(adata.obs.sample(min(5, sample_size)))
    return adata

# when testing, shrink every input to a random subset (helper default size)
if TESTING:
    arc_data = random_cells_subset(arc_data, verbose=DEBUG)
    gex_data = random_cells_subset(gex_data, verbose=DEBUG)
    atac_data = random_cells_subset(atac_data, verbose=DEBUG)

## organize the multiome anndata

We can then use the `organize_multiome_anndatas` function to organize these three datasets into a single Multiome dataset.
This function sorts and orders the data from the multi-modal and modality-specific AnnDatas into a single AnnData (aligning the features, padding missing modalities with 0s, etc). 

In [None]:
%%time
# We can now use the organizing method from scvi to concatenate these anndata
# argument order here: paired (ARC), expression-only (GEX), accessibility-only (ATAC)
adata_mvi = scvi.data.organize_multiome_anndatas(arc_data, gex_data, atac_data)

Note that `organize_multiome_anndatas` adds an annotation to the cells to indicate which modality they originate from:

In [None]:
# summarize the combined object and count the cells per modality
peek_anndata(adata_mvi, 'multiVI anndata created', DEBUG)
display(adata_mvi.obs.modality.value_counts())

#### what percentage of cells are accessibility

In [None]:
# number of cells per modality
modality_counts = adata_mvi.obs.modality.value_counts()
display(modality_counts)
# the same counts expressed as a percentage of all cells
modality_percentages = modality_counts / modality_counts.sum() * 100
display(modality_percentages.round(2))

#### clean up the obs as needed
the ARC samples are pooled but will use pool info as categorical variable so make one for non-pooled samples
drop the batch ID

In [None]:
# pool membership before clean-up
if DEBUG:
    display(adata_mvi.obs.gex_pool.value_counts())
    display(adata_mvi.obs.atac_pool.value_counts())
    
# convert the age obs attribute feature to float from string
adata_mvi.obs.age = adata_mvi.obs.age.astype('float')
    
# cells without a pool get the placeholder 'non'; both pool columns end up as str
adata_mvi.obs.gex_pool = adata_mvi.obs.gex_pool.fillna('non')
adata_mvi.obs.atac_pool = adata_mvi.obs.atac_pool.fillna('non')
adata_mvi.obs.gex_pool = adata_mvi.obs.gex_pool.astype('str')
adata_mvi.obs.atac_pool = adata_mvi.obs.atac_pool.astype('str')
adata_mvi.obs.drop(columns=['batch_id'], inplace=True)

# replace the literal string 'NA' in the phase1 annotations with 'phase2'
# NOTE(review): numpy's where returns a plain array, so the categorical dtype
# (and the category added just above) is presumably not retained -- confirm
adata_mvi.obs.phase1_cluster = adata_mvi.obs.phase1_cluster.cat.add_categories(['phase2'])
adata_mvi.obs.phase1_celltype = adata_mvi.obs.phase1_celltype.cat.add_categories(['phase2'])
adata_mvi.obs.phase1_cluster = where(adata_mvi.obs.phase1_cluster == 'NA', 
                                     'phase2', adata_mvi.obs.phase1_cluster)
adata_mvi.obs.phase1_celltype = where(adata_mvi.obs.phase1_celltype == 'NA', 
                                     'phase2', adata_mvi.obs.phase1_celltype)
# missing labels and the 'Not Present' label are both folded into 'Unknown'
adata_mvi.obs.cell_label = adata_mvi.obs.cell_label.fillna('Unknown')
adata_mvi.obs.cell_label = where(adata_mvi.obs.cell_label == 'Not Present', 
                                 'Unknown', adata_mvi.obs.cell_label)

# inspect the cleaned columns
if DEBUG:
    display(adata_mvi.obs.sample(10))
    display(adata_mvi.obs.gex_pool.value_counts())
    display(adata_mvi.obs.atac_pool.value_counts())
    display(adata_mvi.obs.phase1_cluster.value_counts())
    display(adata_mvi.obs.phase1_celltype.value_counts())
    display(adata_mvi.obs.cell_label.value_counts())

<div class="alert alert-info">
Important

MultiVI requires the features to be ordered so that genes appear before genomic regions. This must be enforced by the user.

</div>

MultiVI requires the features to be ordered, such that genes appear before genomic regions. In this case this is already the case, but it's always good to verify:

In [None]:
# reorder the features by their var['modality'] value and materialize the result
# NOTE(review): presumably 'Gene Expression' sorts ahead of 'Peaks' -- confirm
# with the table displayed below
adata_mvi = adata_mvi[:, adata_mvi.var["modality"].argsort()].copy()
if DEBUG:
    display(adata_mvi.var)

## save the MultiVI organized but unprocessed anndata object; note that the subject is in the obs

In [None]:
%%time
# persist the organized, not yet filtered object
adata_mvi.write(raw_anndata_file)

In [None]:
# show what was written
peek_anndata(adata_mvi, f'multiVI anndata that was just saved to {raw_anndata_file}', DEBUG)

## perform some typical pre-processing

We also filter features to remove those that appear in fewer than a minimum number of cells (currently a fixed count of 10 cells; `MIN_CELL_PERCENT` is not applied).

In [None]:
%%time
print(adata_mvi.shape)
# the three flags below are derived from feature-name patterns only
# annotate the group of mitochondrial genes as 'mt'
adata_mvi.var['mt'] = adata_mvi.var_names.str.startswith('MT-')  
# ribosomal genes
adata_mvi.var['ribo'] = adata_mvi.var_names.str.startswith(('RPS', 'RPL'))
# hemoglobin genes
adata_mvi.var['hb'] = adata_mvi.var_names.str.contains('^HB[^(P)]')

# With pp.calculate_qc_metrics, we can compute many metrics very efficiently.
# NOTE(review): adata_mvi holds both genes and peaks at this point, so the
# per-cell metrics presumably count peaks as well -- confirm the thresholds
# below are intended for the mixed feature matrix
sc.pp.calculate_qc_metrics(adata_mvi, qc_vars=['mt', 'ribo', 'hb'], 
                           inplace=True, log1p=True)
adata_mvi = adata_mvi[adata_mvi.obs.pct_counts_mt < MAX_MITO_PERCENT, :]
# Basic filtering:
# the 90th percentile of n_genes_by_counts is printed as "Suggested" but is
# applied as the max_genes cut-off two lines further down
max_genes_threshold = percentile(adata_mvi.obs['n_genes_by_counts'], 90)
print(f"Suggested max_genes threshold: {max_genes_threshold:.0f}")
sc.pp.filter_cells(adata_mvi, min_genes=100)
sc.pp.filter_cells(adata_mvi, max_genes=max_genes_threshold)
# NOTE(review): min_cells is hard-coded; MIN_CELL_PERCENT is not used here
sc.pp.filter_genes(adata_mvi, min_cells=10)

print(adata_mvi)

if DEBUG:
    display(adata_mvi.obs.sample(5))
    display(adata_mvi.var.sample(5))

#### if flag set then subset to highest variance features

MultiVI tutorial doesn't suggest this so probably typically will set to false

In [None]:
# flag highly variable features across the whole (gene + peak) matrix
if DETECT_HV_FEATURES:
    # keep TOP_FEATURES_PERCENT of all features
    n_top_genes = int(adata_mvi.n_vars * TOP_FEATURES_PERCENT)
    # subset=False only annotates var; no feature is removed here
    sc.pp.highly_variable_genes(adata_mvi, n_top_genes=n_top_genes, 
                                batch_key='atac_pool',flavor='seurat_v3', 
                                subset=False)                                
    print(adata_mvi)
    print(adata_mvi.obs.modality.value_counts())
    print(adata_mvi.var.modality.value_counts())
    # how the flagged features split between genes and peaks
    display(adata_mvi.var.groupby('modality')['highly_variable'].value_counts())

In [None]:
# overall number of flagged vs. unflagged features
# NOTE(review): 'highly_variable' is only created above when DETECT_HV_FEATURES is true
Series(adata_mvi.var.highly_variable).value_counts()

#### there will probably be an imbalance in the HVG towards peaks; try rebalancing

In [None]:
%%time
# find HVG based on just gene data
# gene_features = adata_mvi.var.loc[adata_mvi.var.modality == 'Gene Expression', ['ID']]
# gene_adata = adata_mvi[:, gene_features.index].copy()
gene_adata = adata_mvi[adata_mvi.obs.modality.isin(['expression', 'paired']), 
                       adata_mvi.var.modality == 'Gene Expression'].copy()
num_genes = gene_adata.n_vars
n_top_genes = int(num_genes * TOP_FEATURES_PERCENT)
sc.pp.highly_variable_genes(gene_adata, n_top_genes=n_top_genes, 
                            batch_key='gex_pool', flavor='seurat_v3')
# find HVG based on just peak and adjust for number imbalance
# peak_features = adata_mvi.var.loc[adata_mvi.var.modality == 'Peaks', ['ID']]
# peak_adata = adata_mvi[:, peak_features.index].copy()
peak_adata = adata_mvi[adata_mvi.obs.modality == 'accessibility', 
                       adata_mvi.var.modality == 'Peaks'].copy()
# adjust number top peaks relative to peak/gene composistion of features
num_peaks = peak_adata.n_vars
# n_top_genes = int((num_peaks * TOP_FEATURES_PERCENT)/(num_peaks/num_genes))*2
# n_top_genes = int((num_peaks * TOP_FEATURES_PERCENT)/(num_peaks/num_genes))/2
sc.pp.highly_variable_genes(peak_adata, n_top_genes=n_top_genes, 
                            batch_key='atac_pool', flavor='seurat_v3')
# get the HVG feature set
gene_hvg_index = gene_adata.var.loc[gene_adata.var.highly_variable == True].index
peak_hvg_index = peak_adata.var.loc[peak_adata.var.highly_variable == True].index
hvg_features = set(gene_hvg_index) | set(peak_hvg_index)

# update the HVG in the main anndata object
# clear the flag
adata_mvi.var.highly_variable = False
# and reset to desired features
adata_mvi.var.loc[adata_mvi.var.index.isin(list(hvg_features)), 'highly_variable'] = True

if FILTER_HV_FEATURES:
    adata_mvi = adata_mvi[:, adata_mvi.var.highly_variable]

print(adata_mvi)
if DEBUG:
    display(adata_mvi.var.groupby('modality')['highly_variable'].value_counts())    

## Setup and Training MultiVI
We can now set up and train the MultiVI model!

First, we need to setup the Anndata object using the `setup_anndata` function. At this point we specify any batch annotation that the model would account for.
**Importantly**, the main batch annotation, specified by `batch_key`, should correspond to the modality of the cells.

Other batch annotations (e.g if there are multiple ATAC batches) should be provided using the `categorical_covariate_keys`.

The actual values of categorical covariates (including `batch_key`) are not important, as long as they are different for different samples.
I.e it is not important to call the expression-only samples "expression", as long as they are called something different than the multi-modal and accessibility-only samples.

<div class="alert alert-info">
Important

MultiVI requires the main batch annotation to correspond to the modality of the samples. Other batch annotation, such as in the case of multiple RNA-only batches, can be specified using `categorical_covariate_keys`.

</div>

In [None]:
# register adata_mvi with MultiVI: modality is the batch key, sample and pool
# IDs are categorical covariates, and the mito/ribo percentages are continuous
# covariates
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality', 
                                 categorical_covariate_keys = ['sample_id', 'gex_pool', 'atac_pool'],
                                 continuous_covariate_keys=['pct_counts_mt', 'pct_counts_ribo'],) 

When creating the object, we need to specify how many of the features are genes, and how many are genomic regions. This is so MultiVI can determine the exact architecture for each modality.

In [None]:
# build the model; the gene and region counts are taken from var['modality']
mvi = scvi.model.MULTIVI(
    adata_mvi, 
    n_genes=(adata_mvi.var['modality']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
)
mvi.view_anndata_setup()
print(mvi)

In [None]:
%%time
# training is skipped when RUN_TRAINING is false; a saved model is loaded below
if RUN_TRAINING:
    mvi.train(batch_size=BATCH_SIZE)

## Save and Load MultiVI models

Saving and loading models is similar to all other scvi-tools models, and is very straight forward:

In [None]:
# only a freshly trained model is written; an existing directory is replaced
if RUN_TRAINING:
    mvi.save(trained_model_path, overwrite=True)

In [None]:
# reload the saved model on the hardware detected at the top of the notebook
# instead of assuming that a GPU is present
mvi = scvi.model.MULTIVI.load(trained_model_path, adata=adata_mvi,
                              accelerator='gpu' if device == 'cuda' else 'cpu')
print(mvi)

## Extracting and visualizing the latent space

We can now use the `get_latent_representation` to get the latent space from the trained model, and visualize it using scanpy functions:

In [None]:
# store the model's latent representation under obsm['MultiVI_latent']
adata_mvi.obsm['MultiVI_latent'] = mvi.get_latent_representation()

#### embed the graph based on latent representation

In [None]:
%%time
# neighbour graph on the MultiVI latent space, then a UMAP embedding of it
sc.pp.neighbors(adata_mvi, use_rep='MultiVI_latent')
# sc.tl.umap(adata_mvi, min_dist=0.3)
sc.tl.umap(adata_mvi)

#### visualize the latent representation

In [None]:
# UMAP overview plots, all drawn with the same figure settings
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    # technical annotations: default legend and frame
    for obs_key in ('sample_id', 'Study', 'modality'):
        sc.pl.umap(adata_mvi, color=[obs_key])
    # cell-type annotations: no frame, labels drawn on the embedding
    for obs_key in ('cell_label', 'phase1_celltype'):
        sc.pl.umap(adata_mvi, color=[obs_key], frameon=False,
                   legend_loc='on data')

### Clustering on the MultiVI latent space
The user will note that we imported curated labels from the original publication. Our interface with scanpy makes it easy to cluster the data with scanpy from MultiVI's latent space and then reinject them into MultiVI (e.g., for differential expression).

In [None]:
%%time
# neighbors were already computed using scVI
# first-pass clustering at a fixed resolution; results go to obs['leiden_MultiVI']
sc.tl.leiden(adata_mvi, key_added='leiden_MultiVI', resolution=0.6, flavor='igraph', n_iterations=2)

#### check range of Leiden resolutions for clustering

In [None]:
%%time
resolutions_to_try = arange(0.1, 1.0, 0.1)
print(resolutions_to_try)
mean_scores = {}
largest_score = 0
best_res = 0
new_leiden_key = 'leiden_MultiVI'
for leiden_res in resolutions_to_try:
    # use only 2 decimals
    leiden_res = round(leiden_res, 2)    
    print(f'### using Leiden resolution of {leiden_res}')
    # neighbors were already computed using scVI
    sc.tl.leiden(adata_mvi, key_added=new_leiden_key, resolution=leiden_res, flavor='igraph', n_iterations=2)
    # silhouette_avg = silhouette_score(adata_mvi.obsm['MultiVI_latent'], adata_mvi.obs[new_leiden_key])
    adata_temp = adata_mvi[adata_mvi.obs.cell_label != 'Unknown']
    silhouette_avg = silhouette_score(adata_temp.obs[new_leiden_key].to_frame(), adata_temp.obs.cell_label)
    print((f'For res = {leiden_res:.2f}, average silhouette: {silhouette_avg:.3f} '
           f'for {adata_mvi.obs[new_leiden_key].nunique()} clusters'))    
    # mean sample count per cluster
    df_grouped = adata_mvi.obs.groupby(new_leiden_key)['sample_id'].count()
    df_grouped = df_grouped[df_grouped >= 30].to_frame().reset_index()
    df_grouped = df_grouped.groupby(new_leiden_key)['sample_id'].nunique()    
    mean_sample_per_cluster = df_grouped.mean()
    less_than_half = df_grouped[df_grouped < adata_mvi.obs.sample_id.nunique()/3].shape[0]
    # mean cell count per cluster
    df_grouped = adata_mvi.obs[new_leiden_key].value_counts()
    mean_cell_per_cluster = df_grouped.mean()        
    mean_scores[leiden_res] = [silhouette_avg, adata_mvi.obs[new_leiden_key].nunique(), 
                               mean_sample_per_cluster, mean_cell_per_cluster, less_than_half]
    
    # update best resolution info
    if silhouette_avg > largest_score:
        largest_score = silhouette_avg
        best_res = leiden_res

In [None]:
# one row per tested resolution, one column per recorded metric
scores_df = DataFrame(index=mean_scores.keys(), data=mean_scores.values())
scores_df.columns = ['score', 'num_clusters', 'mean_samples', 'mean_cells', 'less_than_half']

# report the resolution(s) with the highest silhouette score
print('max score at')
top_score = scores_df.score.max()
best_result = scores_df.loc[scores_df.score == top_score]
display(best_result)
best_resolution = best_result.index.values[0]
print(f'best resolution found at {best_resolution}')
if DEBUG:
    display(scores_df)

# the silhouette curve is the only plot that is styled and written to disk
fig_filename = f'{figures_dir}/leiden_resolution_silhouette_score.png'
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    lineplot(x=scores_df.index, y='score', data=scores_df)
    plt.xlabel('resolution')
    plt.savefig(fig_filename)
    plt.show()

# remaining metrics: one line plot each; None keeps the default y label
for metric, y_label in (('num_clusters', None),
                        ('mean_samples', None),
                        ('mean_cells', None),
                        ('less_than_half', 'number clusters with less than 1/3 of donors')):
    lineplot(x=scores_df.index, y=metric, data=scores_df)
    if y_label is not None:
        plt.ylabel(y_label)
    plt.xlabel('resolution')
    plt.show()

In [None]:
best_resolution = round(best_resolution, 2)
# NOTE(review): the value selected from the sweep is overwritten by this
# hard-coded resolution on the next line
best_resolution = 0.6
print(f'{best_resolution=}')
# final clustering; replaces the labels left in obs['leiden_MultiVI'] by the sweep
sc.tl.leiden(adata_mvi, key_added='leiden_MultiVI', resolution=best_resolution, flavor='igraph', n_iterations=2)

### visualize the clusters

In [None]:
# figure_file = f'_{project}.umap.leiden_on.png'
# high-resolution UMAPs coloured by cluster, modality, study, label, age and
# phase1 cell type
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 400}):
    plt.style.use('seaborn-v0_8-paper')
    sc.pl.umap(adata_mvi, color=['leiden_MultiVI'], frameon=False, 
               legend_loc='on data', legend_fontsize=6)
    sc.pl.umap(adata_mvi, color=['modality'], frameon=False)
    sc.pl.umap(adata_mvi, color=['Study'], frameon=False)
    sc.pl.umap(adata_mvi, color=['cell_label'], 
               frameon=False, legend_loc='on data', legend_fontsize=6)
    sc.pl.umap(adata_mvi, color=['age'], frameon=False)
    sc.pl.umap(adata_mvi, color=['phase1_celltype'], 
               frameon=False, legend_loc='on data', legend_fontsize=6)

In [None]:
# obs rows of the cells that carry a curated cell-type label
temp = adata_mvi.obs[adata_mvi.obs.cell_label != 'Unknown']
display(temp.cell_label.value_counts())
print(temp.cell_label.nunique())

In [None]:
# for every curated cell type: how many clusters its cells fall into, and the
# cell count per cluster
for cell_type in temp.cell_label.unique():
    print(f'## {cell_type}', end=':  ')
    cell_data = temp.loc[temp.cell_label == cell_type]
    print(cell_data.leiden_MultiVI.nunique())
    print(f'{cell_data.shape=}')
    display(cell_data.leiden_MultiVI.value_counts())

In [None]:
# for every cluster: how many curated cell types it contains, and the cell
# count per type
for cluster in temp.leiden_MultiVI.unique():
    print(f'##{cluster}', end=':  ')
    cluster_data = temp.loc[temp.leiden_MultiVI == cluster]
    print(cluster_data.cell_label.nunique())
    print(f'{cluster_data.shape=}')
    display(cluster_data.cell_label.value_counts())

## are any of the clusters primarily ATAC

In [None]:
%%time
# UMAP density per modality, then one density panel per modality
sc.tl.embedding_density(adata_mvi, basis='umap', groupby='modality')
sc.pl.embedding_density(adata_mvi, basis='umap', key='umap_density_modality', ncols=3)

In [None]:
# collect the clusters in which more than 80% of the cells are accessibility-only
atac_clusters = []
for cluster in adata_mvi.obs.leiden_MultiVI.unique():
    this_cluster = adata_mvi.obs.loc[adata_mvi.obs.leiden_MultiVI == cluster]
    cluster_cnts = this_cluster.modality.value_counts()
    percentages = (cluster_cnts / cluster_cnts.sum()) * 100
    # .get() treats a cluster without an 'accessibility' entry as 0% instead of
    # raising on the missing label
    if percentages.get('accessibility', 0) > 80:
        print(f'{cluster=}')
        atac_clusters.append(cluster)
        # show the counts next to their percentages
        cluster_cnts = cluster_cnts.to_frame()
        cluster_cnts['percentages'] = percentages.values
        display(cluster_cnts)
print(atac_clusters)

In [None]:
# only when ATAC-dominated clusters were found: show their labels and UMAP
if len(atac_clusters) > 0:
    adata_mvi_sub = adata_mvi[adata_mvi.obs.leiden_MultiVI.isin(atac_clusters)]
    print(adata_mvi_sub)
    display(adata_mvi_sub.obs.cell_label.value_counts())
    with rc_context({'figure.figsize': (9, 9), 'figure.dpi': 100}):
        plt.style.use('seaborn-v0_8-talk')
        sc.pl.umap(adata_mvi_sub, color=['leiden_MultiVI'], frameon=False, legend_loc='on data')

## save to processed anndata object

In [None]:
%%time
# persist the processed object, including the latent space, UMAP and clusters
adata_mvi.write(out_h5ad_file)

In [None]:
# shell escape: print the current date/time into the cell output
!date