# Cluster each modality separately and then integrate across the modalities
- cluster RNA with [scVI](https://docs.scvi-tools.org/en/1.3.0/tutorials/notebooks/quick_start/api_overview.html)
- cluster ATAC with [peakVI](https://docs.scvi-tools.org/en/1.3.0/tutorials/notebooks/atac/PeakVI.html)
- integrate RNA and ATAC clusters with [MOSCOT TranslationProblem](https://moscot.readthedocs.io/en/latest/notebooks/tutorials/600_tutorial_translation.html)

In [None]:
# shell escape: print the current date/time into the notebook output
!date

#### import libraries

In [None]:
import warnings

import matplotlib.pyplot as plt
import moscot.plotting as mtp
import numpy as np
import scanpy as sc
import scvi
from anndata import AnnData
from matplotlib.pyplot import rc_context
from moscot.problems.cross_modality import TranslationProblem
from numpy import arange, mean
from pandas import DataFrame
from scipy import sparse
from scipy.spatial import distance_matrix
from seaborn import lineplot

# NOTE: this silences every warning category for the rest of the session
warnings.filterwarnings('ignore')

# fixed seed for scvi-tools
scvi.settings.seed = 42

# render matplotlib figures inline in the notebook
%matplotlib inline
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
# higher-resolution ("retina") inline figures
%config InlineBackend.figure_format='retina'

#### set notebook variables

In [None]:
# --- run-time switches and identifiers ---
DEBUG = True
project = 'aging_phase2'

# --- QC / feature-selection thresholds ---
# cells must have pct_counts_mt strictly below this value
MAX_MITO_PERCENT = 10
# fraction of features kept as highly variable
TOP_FEATURES_PERCENT = 0.10

# --- .obsm keys for the per-modality latent spaces ---
RNA_LATENT_KEY = 'X_scVI'
ATAC_LATENT_KEY = 'X_peakVI'

# --- input locations ---
wrk_dir = '/labshare/raph/datasets/adrd_neuro/brain_aging/phase2'
quants_dir = '/'.join((wrk_dir, 'quants'))
raw_anndata_file = '/'.join((quants_dir, f'{project}.raw.h5ad'))

if DEBUG:
    print(f'{raw_anndata_file=}')

## functions

In [None]:
def peek_anndata(adata: AnnData, message: str=None, verbose: bool=False):
    """Print a short summary of an AnnData object.

    Args:
        adata: object to summarise; it is passed to ``print``.
        message: optional header line, printed first when it is not None
            and not empty.
        verbose: when True, also ``display`` the first rows of ``adata.obs``
            and ``adata.var``.
    """
    has_message = message is not None and len(message) > 0
    if has_message:
        print(message)
    print(adata)
    if not verbose:
        return
    # obs table first, then var table
    display(adata.obs.head())
    display(adata.var.head())

def foscttm(
    x: np.ndarray,
    y: np.ndarray,
) -> float:
    """Average FOSCTTM score between two row-matched embeddings.

    Row ``i`` of ``x`` and row ``i`` of ``y`` are treated as the true match.
    For every row, the fraction of candidates in the other embedding that lie
    strictly closer than the true match is computed in both directions
    (x -> y and y -> x); the two fractions are averaged per row and then over
    all rows. 0 means every true match is the nearest neighbour.

    Args:
        x: 2-D array (or scipy sparse matrix) with one row per observation.
        y: 2-D array (or scipy sparse matrix) with the same number of rows
            and the same number of columns as ``x``.

    Returns:
        The mean score rounded to 4 decimals.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same number of rows.
    """
    # embeddings may be stored as scipy sparse matrices; distance_matrix needs dense input
    x = x.toarray() if sparse.issparse(x) else np.asarray(x)
    y = y.toarray() if sparse.issparse(y) else np.asarray(y)
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            f'foscttm expects row-matched inputs, got {x.shape[0]} and {y.shape[0]} rows'
        )
    d = distance_matrix(x, y)
    # d[i, i] is the distance between observation i and its true match
    true_match = np.diag(d)
    foscttm_x = (d < true_match[:, np.newaxis]).mean(axis=1)
    foscttm_y = (d < true_match[np.newaxis, :]).mean(axis=0)
    return float(np.round(np.mean((foscttm_x + foscttm_y) / 2), 4))

## load the raw multiome data, multiVI anndata

In [None]:
%%time
# read the combined multiome AnnData from disk
adata_multi = sc.read_h5ad(raw_anndata_file)
peek_anndata(adata_multi, 'raw multiVI anndata', DEBUG)
if DEBUG:
    # modality breakdown of the cells (obs) and then of the features (var)
    for annotations in (adata_multi.obs, adata_multi.var):
        display(annotations.modality.value_counts())

## split the RNA and ATAC modalities

In [None]:
%%time
# RNA: cells labelled 'paired' or 'expression', features labelled 'Gene Expression'
rna_cells = adata_multi.obs.modality.isin(['paired', 'expression'])
rna_features = adata_multi.var.modality == 'Gene Expression'
adata_rna = adata_multi[rna_cells, rna_features].copy()

# ATAC: cells labelled 'paired' or 'accessibility', features labelled 'Peaks'
atac_cells = adata_multi.obs.modality.isin(['paired', 'accessibility'])
atac_features = adata_multi.var.modality == 'Peaks'
adata_atac = adata_multi[atac_cells, atac_features].copy()

peek_anndata(adata_rna, 'raw RNA anndata', DEBUG)
peek_anndata(adata_atac, 'raw ATAC anndata', DEBUG)
if DEBUG:
    # per-object modality counts: cells first, then features
    for modality_adata in (adata_rna, adata_atac):
        display(modality_adata.obs.modality.value_counts())
        display(modality_adata.var.modality.value_counts())

## cluster the RNA cells

### prep the data, typical preprocessing

In [None]:
# Flag gene families so calculate_qc_metrics reports a pct_counts_* column for each.
# mitochondrial genes, "MT-" for human, "Mt-" for mouse
adata_rna.var['mt'] = adata_rna.var_names.str.startswith('MT-')
# ribosomal genes
adata_rna.var['ribo'] = adata_rna.var_names.str.startswith(('RPS', 'RPL'))
# hemoglobin genes: names starting with "HB" whose next character is not "P", "(" or ")"
adata_rna.var['hb'] = adata_rna.var_names.str.contains('^HB[^(P)]')
sc.pp.calculate_qc_metrics(adata_rna, qc_vars=['mt', 'ribo', 'hb'],
                           inplace=True, log1p=True)
# Basic filtering:
# .copy() materializes the subset; without it adata_rna is a view and the
# in-place filters below would modify a view of the unfiltered object
adata_rna = adata_rna[adata_rna.obs.pct_counts_mt < MAX_MITO_PERCENT, :].copy()
sc.pp.filter_cells(adata_rna, min_genes=200)
sc.pp.filter_genes(adata_rna, min_cells=3)

# keep the top TOP_FEATURES_PERCENT fraction of the remaining genes as highly variable
n_top_genes = int(adata_rna.n_vars * TOP_FEATURES_PERCENT)
sc.pp.highly_variable_genes(adata_rna, n_top_genes=n_top_genes,
                            batch_key='gex_pool', flavor='seurat_v3',
                            subset=True)

peek_anndata(adata_rna, 'prepped RNA anndata', DEBUG)
if DEBUG:
    display(adata_rna.obs.modality.value_counts())

In [None]:
# keep a copy of .X under layers['counts']; this layer is what scVI is set up on
# NOTE(review): assumes .X still holds raw counts here (no normalization call precedes this) - confirm
adata_rna.layers['counts'] = adata_rna.X.copy()  # preserve counts

### setup the RNA anndata for scVI

In [None]:
# covariates registered with scVI for the RNA model
rna_categorical_covariates = ['gex_pool', 'sample_id']
rna_continuous_covariates = ['pct_counts_mt', 'pct_counts_ribo']

# register the counts layer and the covariates on adata_rna
scvi.model.SCVI.setup_anndata(
    adata_rna,
    layer="counts",
    categorical_covariate_keys=rna_categorical_covariates,
    continuous_covariate_keys=rna_continuous_covariates,
)

### create and train the model

In [None]:
%%time
# build the scVI model on the AnnData that was registered with setup_anndata
rna_model = scvi.model.SCVI(adata_rna)
print(rna_model)
# train() is called without arguments, i.e. with the library's default settings
rna_model.train()

### get and store the model output

In [None]:
# Convert an already-stored latent to CSR only when it exists. On a fresh
# top-to-bottom run RNA_LATENT_KEY has not been written to .obsm yet, and the
# unguarded lookup raised KeyError.
if RNA_LATENT_KEY in adata_rna.obsm:
    adata_rna.obsm[RNA_LATENT_KEY] = sparse.csr_matrix(adata_rna.obsm[RNA_LATENT_KEY])

In [None]:
# pull the latent representation from the trained model and store it as CSR in .obsm
rna_latent = rna_model.get_latent_representation()
adata_rna.obsm[RNA_LATENT_KEY] = sparse.csr_matrix(rna_latent)
peek_anndata(adata_rna, 'post latent RNA anndata', DEBUG)
if DEBUG:
    display(adata_rna.obsm[RNA_LATENT_KEY].shape)

### embed the graph based on latent representation

In [None]:
%%time
# neighbor graph is built on the stored scVI latent (use_rep), not on .X
sc.pp.neighbors(adata_rna, use_rep=RNA_LATENT_KEY)
# UMAP embedding for plotting
sc.tl.umap(adata_rna)
peek_anndata(adata_rna, 'embedded latent RNA anndata', DEBUG)

In [None]:
# re-print the RNA AnnData summary (same message string as the preceding cell)
peek_anndata(adata_rna, 'embedded latent RNA anndata', DEBUG)

### check range of Leiden resolutions for clustering

In [None]:
%%time
# Sweep Leiden resolutions on the RNA neighbor graph and record, per resolution:
# silhouette score, number of clusters, mean distinct samples per cluster,
# and mean cells per cluster.
resolutions_to_try = arange(0.3, 1.05, 0.05)
print(resolutions_to_try)
mean_scores = {}
# silhouette scores can be negative, so start below any possible score
largest_score = float('-inf')
best_res = 0
new_leiden_key = 'leiden_VI'
for leiden_res in resolutions_to_try:
    # use only 2 decimals
    leiden_res = round(leiden_res, 2)
    print(f'### using Leiden resolution of {leiden_res}')
    # neighbors were already computed using scVI
    sc.tl.leiden(adata_rna, key_added=new_leiden_key, resolution=leiden_res,
                 flavor='igraph', n_iterations=2)
    cluster_labels = adata_rna.obs[new_leiden_key]
    num_clusters = cluster_labels.nunique()
    silhouette_avg = silhouette_score(adata_rna.obsm[RNA_LATENT_KEY], cluster_labels)
    print((f'For res = {leiden_res:.2f}, average silhouette: {silhouette_avg:.3f} '
           f'for {num_clusters} clusters'))
    # mean number of distinct samples per cluster; .count() here would count
    # cells and merely duplicate mean_cell_per_cluster
    mean_sample_per_cluster = (
        adata_rna.obs.groupby(new_leiden_key, observed=True)['sample_id']
        .nunique()
        .mean()
    )
    # mean cell count per cluster
    mean_cell_per_cluster = cluster_labels.value_counts().mean()
    mean_scores[leiden_res] = [silhouette_avg, num_clusters,
                               mean_sample_per_cluster, mean_cell_per_cluster]
    # update best resolution info
    if silhouette_avg > largest_score:
        largest_score = silhouette_avg
        best_res = leiden_res

In [None]:
# one row per tested resolution, one column per recorded metric
scores_df = DataFrame(data=list(mean_scores.values()), index=list(mean_scores.keys()))
scores_df.columns = ['score', 'num_clusters', 'mean_samples', 'mean_cells']

# row(s) holding the maximum silhouette score; the first one is taken as best
print('max score at')
is_top_score = scores_df['score'] == scores_df['score'].max()
best_result = scores_df.loc[is_top_score]
display(best_result)
best_resolution = best_result.index.values[0]
print(f'best resolution found at {best_resolution}')
if DEBUG:
    display(scores_df)

# silhouette curve is drawn inside the rc_context (size, dpi, 'talk' style)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    lineplot(x=scores_df.index, y='score', data=scores_df)
    plt.xlabel('resolution')
    plt.show()

# remaining metrics are drawn outside the rc_context
for metric in ('num_clusters', 'mean_samples', 'mean_cells'):
    lineplot(x=scores_df.index, y=metric, data=scores_df)
    plt.xlabel('resolution')
    plt.show()

### re-cluster at the best resolution found based on Silhouette score

In [None]:
# final RNA clustering: overwrite obs['leiden_VI'] using the resolution picked from scores_df
sc.tl.leiden(adata_rna, key_added='leiden_VI', resolution=best_resolution, 
             flavor='igraph', n_iterations=2)
peek_anndata(adata_rna, 'embedded latent RNA anndata', DEBUG)

In [None]:
# RNA UMAP colored by the Leiden cluster labels
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_rna, color=['leiden_VI'], 
               frameon=False)

In [None]:
# RNA UMAP colored by gex_pool (also used as a batch/covariate key earlier in this notebook)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_rna, color=['gex_pool'], 
               frameon=False)

## cluster the ATAC cells

### prep the data, limited ATAC steps

In [None]:
# drop cells with fewer than 200 detected features (features are peaks in this object)
sc.pp.filter_cells(adata_atac, min_genes=200)
# compute the threshold: 5% of the cells
min_cells = int(adata_atac.n_obs * 0.05)
# in-place filtering of regions
sc.pp.filter_genes(adata_atac, min_cells=min_cells)

# keep the top TOP_FEATURES_PERCENT fraction of the remaining peaks
# NOTE(review): reuses the gene-oriented 'seurat_v3' HVG selection on peak data - confirm this is intended
n_top_genes = int(adata_atac.n_vars * TOP_FEATURES_PERCENT)
sc.pp.highly_variable_genes(adata_atac, n_top_genes=n_top_genes, 
                            batch_key='atac_pool',flavor='seurat_v3', 
                            subset=True)

peek_anndata(adata_atac, 'prepped ATAC anndata', DEBUG)
if DEBUG:
    display(adata_atac.obs.modality.value_counts())

In [None]:
# keep a copy of .X under layers['counts']; this layer is what PeakVI is set up on
adata_atac.layers['counts'] = adata_atac.X.copy()  # preserve counts

### setup the ATAC anndata for PeakVI

In [None]:
# covariates registered with PeakVI for the ATAC model
atac_categorical_covariates = ['atac_pool', 'sample_id']

# register the counts layer and the covariates on adata_atac
scvi.model.PEAKVI.setup_anndata(
    adata_atac,
    layer="counts",
    categorical_covariate_keys=atac_categorical_covariates,
)

### create and train the model

In [None]:
%%time
# build the PeakVI model on the AnnData that was registered with setup_anndata
atac_model = scvi.model.PEAKVI(adata_atac)
print(atac_model)
# train() is called without arguments, i.e. with the library's default settings
atac_model.train()

In [None]:
# print the model summary again, now after training
print(atac_model)

### get and store the model output

In [None]:
# pull the latent representation from the trained model and store it as CSR in .obsm
atac_latent = atac_model.get_latent_representation()
adata_atac.obsm[ATAC_LATENT_KEY] = sparse.csr_matrix(atac_latent)
peek_anndata(adata_atac, 'post latent ATAC anndata', DEBUG)
if DEBUG:
    display(adata_atac.obsm[ATAC_LATENT_KEY].shape)

### embed the graph based on latent representation

In [None]:
%%time
# neighbor graph is built on the stored PeakVI latent (use_rep), not on .X
sc.pp.neighbors(adata_atac, use_rep=ATAC_LATENT_KEY)
# UMAP embedding for plotting
sc.tl.umap(adata_atac)
peek_anndata(adata_atac, 'embedded latent ATAC anndata', DEBUG)

In [None]:
# re-print the ATAC AnnData summary (same message string as the preceding cell)
peek_anndata(adata_atac, 'embedded latent ATAC anndata', DEBUG)

### check range of Leiden resolutions for clustering

In [None]:
%%time
# Sweep Leiden resolutions on the ATAC neighbor graph and record, per resolution:
# silhouette score, number of clusters, mean distinct samples per cluster,
# and mean cells per cluster.
resolutions_to_try = arange(0.3, 1.05, 0.05)
print(resolutions_to_try)
mean_scores = {}
# silhouette scores can be negative, so start below any possible score
largest_score = float('-inf')
best_res = 0
new_leiden_key = 'leiden_VI'
for leiden_res in resolutions_to_try:
    # use only 2 decimals
    leiden_res = round(leiden_res, 2)
    print(f'### using Leiden resolution of {leiden_res}')
    # neighbors were already computed using the peakVI latent
    sc.tl.leiden(adata_atac, key_added=new_leiden_key, resolution=leiden_res,
                 flavor='igraph', n_iterations=2)
    cluster_labels = adata_atac.obs[new_leiden_key]
    num_clusters = cluster_labels.nunique()
    silhouette_avg = silhouette_score(adata_atac.obsm[ATAC_LATENT_KEY], cluster_labels)
    print((f'For res = {leiden_res:.2f}, average silhouette: {silhouette_avg:.3f} '
           f'for {num_clusters} clusters'))
    # mean number of distinct samples per cluster; .count() here would count
    # cells and merely duplicate mean_cell_per_cluster
    mean_sample_per_cluster = (
        adata_atac.obs.groupby(new_leiden_key, observed=True)['sample_id']
        .nunique()
        .mean()
    )
    # mean cell count per cluster
    mean_cell_per_cluster = cluster_labels.value_counts().mean()
    mean_scores[leiden_res] = [silhouette_avg, num_clusters,
                               mean_sample_per_cluster, mean_cell_per_cluster]
    # update best resolution info
    if silhouette_avg > largest_score:
        largest_score = silhouette_avg
        best_res = leiden_res

In [None]:
# one row per tested resolution, one column per recorded metric
scores_df = DataFrame(data=list(mean_scores.values()), index=list(mean_scores.keys()))
scores_df.columns = ['score', 'num_clusters', 'mean_samples', 'mean_cells']

# row(s) holding the maximum silhouette score; the first one is taken as best
print('max score at')
is_top_score = scores_df['score'] == scores_df['score'].max()
best_result = scores_df.loc[is_top_score]
display(best_result)
best_resolution = best_result.index.values[0]
print(f'best resolution found at {best_resolution}')
if DEBUG:
    display(scores_df)

# silhouette curve is drawn inside the rc_context (size, dpi, 'talk' style)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-talk')
    lineplot(x=scores_df.index, y='score', data=scores_df)
    plt.xlabel('resolution')
    plt.show()

# remaining metrics are drawn outside the rc_context
for metric in ('num_clusters', 'mean_samples', 'mean_cells'):
    lineplot(x=scores_df.index, y=metric, data=scores_df)
    plt.xlabel('resolution')
    plt.show()

### re-cluster at the best resolution found based on Silhouette score

In [None]:
# final ATAC clustering: overwrite obs['leiden_VI'] using the resolution picked from scores_df
sc.tl.leiden(adata_atac, key_added='leiden_VI', resolution=best_resolution, 
             flavor='igraph', n_iterations=2)
peek_anndata(adata_atac, 'embedded latent ATAC anndata', DEBUG)

In [None]:
# ATAC UMAP colored by the Leiden cluster labels
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_atac, color=['leiden_VI'], 
               frameon=False)

In [None]:
# ATAC UMAP colored by atac_pool (also used as a batch/covariate key earlier in this notebook)
with rc_context({'figure.figsize': (8, 8), 'figure.dpi': 100}):
    plt.style.use('seaborn-v0_8-bright')
    sc.pl.umap(adata_atac, color=['atac_pool'], 
               frameon=False)

## integrate the modalities

### Prepare the TranslationProblem
We need to initialize the TranslationProblem by passing the source and target AnnData objects. After initialization, we need to prepare() the problem. In this particular case, we need to pay attention to 3 parameters:

src_attr: specifies the attribute in AnnData that contains the source distribution. In our case it refers to the key in obsm that stores the ATAC PeakVI latent representation (`X_peakVI`).

tgt_attr: specifies the attribute in AnnData that contains the target distribution. In our case it refers to the key in obsm that stores the RNA scVI latent representation (`X_scVI`).

joint_attr[optional]: specifies a joint attribute over a common feature space to incorporate a linear term into the quadratic optimization problem. Initially, we consider the pure Gromov-Wasserstein setting and subsequently explore the fused problem.

In [None]:
# ATAC is the source domain, RNA is the target domain
tp = TranslationProblem(adata_src=adata_atac, adata_tgt=adata_rna)
# use the notebook's latent-key constants so these stay in sync with the
# .obsm keys written by the scVI / PeakVI cells
tp = tp.prepare(src_attr=ATAC_LATENT_KEY, tgt_attr=RNA_LATENT_KEY, batch_key='sample_id')

### Solve the TranslationProblem
In fused quadratic problems, the alpha parameter defines the convex combination between the quadratic and linear terms. By default, alpha = 1, that is, we only consider the quadratic problem, ignoring the joint_attr. We choose a small value for epsilon to obtain a sparse transport map.

In [None]:
%%time
# alpha=1.0: quadratic term only (no joint_attr was supplied); small epsilon for a sparse transport map
tp = tp.solve(alpha=1.0, epsilon=1e-3)

### Translate the TranslationProblem
We can now project one domain onto the other. The boolean parameter forward determines the direction of the barycentric projection. In our case, we project the source distribution AnnData (ATAC) onto the target distribution AnnData (RNA), therefore we use forward = True. The function translate() returns the translated object in the target space (or source space respectively).

In [None]:
# forward=True: project the source (ATAC) onto the target (RNA) space
# NOTE(review): prepare() was given batch_key='sample_id'; confirm that "src" is still a valid source key in that case
translated = tp.translate(source="src", target="tgt", forward=True)

### Analyzing the translation
We will use the average FOSCTTM metric implemented above to analyze the alignment performance.

In [None]:
# `translated` is expressed in the target space passed as tgt_attr (the scVI
# latent), so it must be compared against that same embedding. The previous
# key "GEX_X_pca" is never written to adata_rna.obsm in this notebook.
rna_latent = adata_rna.obsm[RNA_LATENT_KEY]
# the latent is stored as a scipy sparse matrix; the distance computation needs a dense array
if sparse.issparse(rna_latent):
    rna_latent = rna_latent.toarray()
# NOTE(review): FOSCTTM treats row i of both inputs as the same cell - confirm
# the RNA and ATAC objects still hold the same cells in the same order after
# their independent filtering steps.
print(
    "Average FOSCTTM score of translating ATAC onto RNA: ",
    foscttm(rna_latent, translated),
)