This tutorial demonstrates how to use scSpecies to perform latent alignments of three datasets spanning multiple species (mice, humans and hamsters).  
Let us start by specifying the path where the datasets are downloaded to.

In [None]:
import numpy as np
import muon as mu
import scanpy as sc
import os

path = os.path.abspath('').replace('\\', '/')+'/'

As a first step, we load the context and target datasets as `.h5ad` files.  
This file format stores annotated multidimensional data arrays and is widely used in genomics and bioinformatics.  
It can be read with the AnnData Python package; more information can be found here: https://anndata.readthedocs.io/en/latest/.  

We will use a subset of the mouse, hamster and human liver cell atlas.  
As context dataset we will use the `mouse_liver_filtered.h5ad` file, which contains mouse liver cell samples.   
As target datasets we will use the `human_liver_filtered.h5ad` file, which contains human liver cell samples and
the `hamster_liver_filtered.h5ad` file, which contains hamster liver cell samples.

The human and mouse datasets we use for this tutorial are preprocessed to speed up computations.  
First, the dimensionality of the gene sets was reduced to 4000 highly variable genes that are expressed in more than 2.5% of cells.  
Second, large cell types were randomly subsampled to around 1000 cells each.  
Third, unlabeled cells and cells with labeling conflicts between fine and coarse labels were removed.  
Lastly, we only included cells obtained via CITE-seq and scRNA-seq.  
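
The gene filter and the subsampling step described above can be sketched with NumPy as follows. This is a minimal illustration on toy data, not the script that produced the tutorial files: the function names, the toy cell type names and the use of a dense count matrix are assumptions, and the selection of highly variable genes is not shown.

```python
import numpy as np

def filter_genes_by_fraction(counts, min_fraction=0.025):
    """Return a boolean mask of genes expressed in more than min_fraction of cells."""
    # Assumes a dense (cells x genes) array; a gene counts as expressed if its count is > 0.
    expressed_fraction = (counts > 0).mean(axis=0)
    return expressed_fraction > min_fraction

def subsample_large_cell_types(labels, max_cells=1000, seed=0):
    """Return sorted cell indices that keep at most max_cells cells per label."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    keep = []
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        if idx.size > max_cells:
            # Draw without replacement so that no cell is kept twice.
            idx = rng.choice(idx, size=max_cells, replace=False)
        keep.append(idx)
    return np.sort(np.concatenate(keep))

# Toy data: 3000 cells x 50 genes with two cell types of very different size.
toy_rng = np.random.default_rng(0)
toy_counts = toy_rng.poisson(0.05, size=(3000, 50))
toy_labels = np.array(["Hepatocytes"] * 2500 + ["Kupffer cells"] * 500)

gene_mask = filter_genes_by_fraction(toy_counts)
cell_idx = subsample_large_cell_types(toy_labels, max_cells=1000)
filtered = toy_counts[cell_idx][:, gene_mask]
print(filtered.shape)
```

Only the large cell type is reduced; cell types that already have fewer than `max_cells` cells are kept in full.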

The full datasets can be downloaded at https://www.livercellatlas.org/.

In [None]:
context_adata = sc.read_h5ad(path+"dataset/mouse_liver_filtered.h5ad")
target_adata_human = sc.read_h5ad(path+"dataset/human_liver_filtered.h5ad")
target_adata_hamster = sc.read_h5ad(path+"dataset/hamster_liver.h5ad")

context_adata.X = context_adata.X.astype('float32')
target_adata_human.X = target_adata_human.X.astype('float32')
target_adata_hamster.X = target_adata_hamster.X.astype('float32')

Next, we specify the `.obs` keys under which the cell and batch labels for the context and target datasets are stored.  
The cell labels for the target dataset are used only for plotting and for computing performance metrics but are not needed during training.  
If the target cell labels are unknown this can be indicated by `target_cell_key = None`.  

For precise metric calculations and better visualization, it may be necessary to manually rename some cell type labels 
so that their naming convention is consistent across the datasets.  

We also have to specify the gene naming convention of the datasets.   
Gene names following the human naming convention are mostly written in uppercase letters,  
while gene names following the mouse naming convention are mostly written in mixed case.  

scSpecies can translate homologous genes between the human and mouse nomenclature.   
For other species it can be beneficial to manually translate gene names of homologs to one of these conventions.   
For the hamster dataset we will just assume the mouse nomenclature.  
The number of homologs identified in this way is large enough to perform a meaningful nearest neighbor search.  
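
To make the two conventions concrete, the following sketch guesses the convention of a list of gene names from their casing and applies a naive case-based conversion. Both helpers are hypothetical and are not part of scSpecies. Changing the case only recovers homologs whose symbols differ by casing alone, so a real translation needs a homology table.

```python
def guess_naming_convention(gene_names):
    """Guess 'human' if most names are fully uppercase, otherwise 'mouse'."""
    names = list(gene_names)
    n_upper = sum(name.isupper() for name in names)
    return "human" if n_upper > len(names) / 2 else "mouse"

def naive_human_to_mouse(gene_name):
    """Uppercase the first letter and lowercase the rest, e.g. 'ALB' becomes 'Alb'."""
    # Heuristic only: this does not look up true homologs.
    return gene_name.capitalize()

human_like = ["ALB", "APOA1", "CD74"]
mouse_like = ["Alb", "Apoa1", "Cd74"]

print(guess_naming_convention(human_like))
print(guess_naming_convention(mouse_like))
print([naive_human_to_mouse(name) for name in human_like])
```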

In [None]:
context_batch_key = 'batch'
context_cell_key = 'cell_type_fine'
print('\nMouse context gene names: ', context_adata.var_names[0:5])

human_batch_key = 'batch'
human_cell_key = 'cell_type_fine'
print('\nHuman target gene names: ', target_adata_human.var_names[0:5])

hamster_batch_key = 'batch'
hamster_cell_key = 'cell_type_coarse'
print('\nHamster target gene names: ', target_adata_hamster.var_names[0:5])

Next, we create a `muon.MuData` dataset (https://muon.readthedocs.io/en/latest/) which scSpecies uses during training.  
Muon lets us define a container for multimodal data.  
One modality will be the context mouse dataset and the other modalities are target datasets we want to align with the context.
In our case these are human and hamster datasets.  
We instantiate a preprocessing class and register context and target `anndata.AnnData` datasets.  

This class translates the gene names, performs the data-level nearest neighbor search on homologous genes,   
one-hot-encodes experimental batch effects, computes the library encoder prior parameters,   
and optionally filters cells with low gene expression counts and reduces the gene dimensionality to highly variable genes.   
We reduce the human and mouse datasets to their 2500 most highly variable genes and the hamster dataset to its 3000 most highly variable genes.  
This lets us check whether scSpecies can handle context and target datasets of different dimensionality.
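
The idea behind a data-level nearest neighbor search on homologous genes can be sketched as follows. This is an illustration under stated assumptions, not the scSpecies implementation: the function name, the `log1p` transform and the cosine similarity are choices made for this example, and gene names are assumed to be already translated to a common nomenclature.

```python
import numpy as np

def data_level_neighbors(context_counts, target_counts, context_genes, target_genes, k=3):
    """For every target cell, return the indices of its k most similar context cells."""
    # Restrict both datasets to the genes they share, in the same column order.
    shared = sorted(set(context_genes) & set(target_genes))
    c_pos = {gene: i for i, gene in enumerate(context_genes)}
    t_pos = {gene: i for i, gene in enumerate(target_genes)}
    c = np.asarray(context_counts, dtype=float)[:, [c_pos[g] for g in shared]]
    t = np.asarray(target_counts, dtype=float)[:, [t_pos[g] for g in shared]]

    def normalize(x):
        # Log-transform, then scale each cell to unit length for cosine similarity.
        x = np.log1p(x)
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        return x / np.maximum(norms, 1e-12)

    similarity = normalize(t) @ normalize(c).T  # shape: (n_target, n_context)
    neighbors = np.argsort(-similarity, axis=1)[:, :k]
    return shared, neighbors

# Toy example: the two datasets share only the genes Alb and Cd74.
toy_context_genes = ["Alb", "Cd74", "Lyz2"]
toy_target_genes = ["Cd74", "Alb", "Apoa1"]
toy_context = [[10, 0, 0],   # context cell 0 expresses Alb
               [0, 8, 1]]    # context cell 1 expresses Cd74
toy_target = [[9, 0, 3],     # target cell 0 expresses Cd74
              [0, 7, 5]]     # target cell 1 expresses Alb

shared_genes, nearest = data_level_neighbors(
    toy_context, toy_target, toy_context_genes, toy_target_genes, k=1
)
print(shared_genes, nearest.tolist())
```

Because only shared genes enter the comparison, the context and target datasets may have different gene dimensionalities, as in this tutorial.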

In [None]:
from preprocessing import create_mdata

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

preprocess = create_mdata(context_adata, context_batch_key, context_cell_key, context_dataset_name='mouse', context_gene_naming_convention='mouse', context_n_top_genes=2500)
preprocess.setup_target_adata(target_adata_human, human_batch_key, human_cell_key, target_dataset_name='human', target_gene_naming_convention='human', target_n_top_genes=2500)
preprocess.setup_target_adata(target_adata_hamster, hamster_batch_key, hamster_cell_key, target_dataset_name='hamster', target_gene_naming_convention='mouse', target_n_top_genes=3000)

preprocess.save_mdata(path, 'liver_atlas')
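
Two of the preprocessing steps mentioned above, the one-hot encoding of batch labels and the library size prior, can be illustrated in plain Python. This sketch assumes the common scVI-style convention of using the mean and variance of the log total counts per batch as prior parameters. The helper names are hypothetical, and the exact computation inside `create_mdata` may differ.

```python
import math
from statistics import fmean, pvariance

def one_hot_batches(batch_labels):
    """Return the sorted batch categories and one 0/1 row per cell."""
    categories = sorted(set(batch_labels))
    column = {batch: j for j, batch in enumerate(categories)}
    encoded = [
        [1 if column[batch] == j else 0 for j in range(len(categories))]
        for batch in batch_labels
    ]
    return categories, encoded

def library_size_prior(counts, batch_labels):
    """Return {batch: (mean, variance)} of the log total counts per cell."""
    # Assumes every cell has at least one count, otherwise log(0) is undefined.
    log_library = [math.log(sum(row)) for row in counts]
    prior = {}
    for batch in sorted(set(batch_labels)):
        values = [v for v, b in zip(log_library, batch_labels) if b == batch]
        prior[batch] = (fmean(values), pvariance(values))
    return prior

# Toy data: three cells with two genes each, from two batches.
toy_cell_counts = [[1, 1], [2, 2], [5, 5]]
toy_batches = ["a", "a", "b"]

batch_categories, batch_encoding = one_hot_batches(toy_batches)
batch_prior = library_size_prior(toy_cell_counts, toy_batches)
print(batch_categories, batch_encoding)
print(batch_prior)
```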

Let's look at the created mdata object:  
We have obtained a multimodal container that spans the three datasets.   
The `create_mdata` class has added additional keys corresponding to the preprocessing steps.  

![Created mdata file](figures/mdata.jpg)

In [None]:
mdata = mu.read_h5mu(path+"dataset/liver_atlas.h5mu")
print(mdata)

First, let's focus on the latent alignment of the mouse and human datasets. 
The alignment of the hamster dataset will be performed later.

We define the context and target scVI models by instantiating the scSpecies class.   
We recommend using an NVIDIA GPU during training.  
CPU training can be slow, and Apple Silicon runs into errors when trying to compute the log-gamma function for the scVI loss.   

In [None]:
from models import scSpecies
import torch

device = ("cuda" if torch.cuda.is_available() else "cpu")

model = scSpecies(device, 
                mdata, 
                path,
                context_dataset_key = 'mouse', 
                target_dataset_key = 'human',          
                alignment = 'inter', 
                )

We train and evaluate the context scVI model.  
The model parameters are automatically saved to the specified path, and the latent representations are stored in the `.obsm` attribute of the context modality of the `muon.MuData` object.

In [None]:
model.train_context(25, save_key='_for_the_human_dataset')
model.eval_context()

Next, we train and evaluate the target scVI model.
We also track the loss of prototype cells to monitor the alignment.

In [None]:
model.train_target(25, track_prototypes=True, save_key='_for_the_human_dataset')
model.eval_target()

After training, we can predict cell labels using the aligned representation.   
We can compare the quality of the predicted labels with that of the data-level nearest neighbor search.  
The function takes as input a list of tuples of cell label keys that should be compared, in our case coarse and fine cell labels.  
We see that the accuracy is higher for coarse cell labels, which is expected.  

In [None]:
model.eval_label_transfer(cell_keys = [('cell_type_coarse', 'cell_type_coarse'), ('cell_type_fine', 'cell_type_fine')])
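
The principle of neighbor-based label transfer, and why coarse labels are easier to predict than fine ones, can be illustrated with a small sketch. It assumes a majority vote among neighboring context cells, and the fine-to-coarse mapping is a toy example. The helper functions are hypothetical and do not reproduce what `eval_label_transfer` does internally.

```python
from collections import Counter

def transfer_labels(neighbor_indices, context_labels):
    """Predict one label per target cell by majority vote among its context neighbors."""
    predictions = []
    for neighbors in neighbor_indices:
        votes = Counter(context_labels[i] for i in neighbors)
        # On ties, the label that was encountered first wins.
        predictions.append(votes.most_common(1)[0][0])
    return predictions

def accuracy(predicted, true):
    """Fraction of cells whose predicted label equals the true label."""
    matches = sum(p == t for p, t in zip(predicted, true))
    return matches / len(true)

# Toy data: five context cells and three target cells with three neighbors each.
toy_context_labels = ["Hepatocytes", "Hepatocytes", "cDC1s", "cDC2s", "cDC2s"]
toy_neighbors = [[0, 1, 2], [3, 4, 2], [2, 3, 4]]
toy_true_fine = ["Hepatocytes", "cDC2s", "cDC1s"]

# Hypothetical mapping from fine to coarse labels.
fine_to_coarse = {"Hepatocytes": "Hepatocytes", "cDC1s": "cDCs", "cDC2s": "cDCs"}

predicted_fine = transfer_labels(toy_neighbors, toy_context_labels)
fine_accuracy = accuracy(predicted_fine, toy_true_fine)
coarse_accuracy = accuracy(
    [fine_to_coarse[label] for label in predicted_fine],
    [fine_to_coarse[label] for label in toy_true_fine],
)
print(predicted_fine, fine_accuracy, coarse_accuracy)
```

The third target cell is assigned the wrong fine label but the correct coarse label, so the coarse accuracy is higher in this toy example.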

We can plot the results with a provided plotting function.  
For other datasets, the color palette should be adjusted if a consistent color scheme is wanted.  
First, we generate bar plots that indicate the improvement over the data-level nearest neighbor search (NNS) for the coarse cell labels:

In [None]:
from plot_utils import bar_plot

bar_plot(model, save_path=path, cell_key= 'cell_type_coarse')

And then for the fine cell labels:

In [None]:
bar_plot(model, save_path=path, cell_key= 'cell_type_fine')

We can also visualize the aligned latent space:

In [None]:
from plot_utils import plot_umap

plot_umap(model, context_cell_key = 'cell_type_fine', target_cell_key = 'cell_type_fine')        

We can also visualize the likelihood of the cell prototypes during the alignment phase:

In [None]:
from plot_utils import plot_prototypes

plot_prototypes(model)        

Finally, the difference in modeled gene expression can be analyzed by comparing the log2-fold change in normalized gene expression.  
The x-axis shows the log2-fold change between context and target gene expression for all shared cell labels.  
The y-axis shows the probability that a gene is differentially expressed at level `lfc_delta`.

In [None]:
from plot_utils import plot_lfc

model.compute_logfold_change(lfc_delta=1, context_cell_key='cell_type_fine', target_cell_key='cell_type_fine')
plot_lfc(model, save_path=path, name='human')
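
The two quantities shown in the plot can be sketched with NumPy. This sketch assumes that paired samples of normalized expression are available for one cell type in both species, and that the probability of differential expression is the fraction of samples whose absolute log2-fold change exceeds `lfc_delta`. The function name, the sign convention (target relative to context) and the pseudocount `eps` are assumptions, and `compute_logfold_change` may work differently.

```python
import numpy as np

def lfc_probability(context_samples, target_samples, lfc_delta=1.0, eps=1e-6):
    """Return the median log2-fold change and the probability of |LFC| > lfc_delta per gene.

    Both inputs have shape (n_samples, n_genes) and contain normalized expression.
    """
    # The pseudocount eps avoids taking the logarithm of zero.
    lfc = np.log2(target_samples + eps) - np.log2(context_samples + eps)
    median_lfc = np.median(lfc, axis=0)
    prob_de = (np.abs(lfc) > lfc_delta).mean(axis=0)
    return median_lfc, prob_de

# Toy data: gene 0 is four times higher in the target, gene 1 is unchanged.
toy_context_expr = np.full((100, 2), 1.0)
toy_target_expr = np.tile(np.array([4.0, 1.0]), (100, 1))

toy_median_lfc, toy_prob_de = lfc_probability(toy_context_expr, toy_target_expr, lfc_delta=1.0)
print(toy_median_lfc, toy_prob_de)
```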

All results are written to the `model.mdata` object.  
Let us take a look at how the scSpecies workflow has modified the `MuData` object:

![](figures/mdata_after_training.jpg)

In [None]:
print(model.mdata)

Let's train a second scSpecies model to align the hamster dataset.  
We instantiate a second model and load the parameters of the context encoder.

If we want to identify differentially expressed genes, the context decoder should be retrained.  

In [None]:
model_hamster = scSpecies(device, 
                mdata, 
                path,
                context_dataset_key = 'mouse', 
                target_dataset_key = 'hamster',                                                   
                )

model_hamster.load_params('context_encoder', name='_for_the_human_dataset')

model_hamster.train_context(25, train_decoder_only=True, save_key='_for_the_hamster_dataset')
model_hamster.eval_context()

model_hamster.train_target(25, save_key='_for_the_hamster_dataset')
model_hamster.eval_target()

model_hamster.eval_label_transfer(cell_keys = [('cell_type_coarse', 'cell_type_coarse')])
model_hamster.compute_logfold_change(context_cell_key='cell_type_coarse', target_cell_key='cell_type_coarse')

Again, we can visualize the results.  
We can see, for example, that the hamster cDCs are aligned with the mouse and human cDC1s, cDC2s and pDCs.  
Since the fine hamster labels are unknown, this can help us infer the missing information.

In [None]:
from plot_utils import plot_umap_three_species

plot_umap_three_species(model, model_hamster, context_cell_key = 'cell_type_coarse', target_cell_key_1 = 'cell_type_coarse', target_cell_key_2 = 'cell_type_coarse', save_path=path)
bar_plot(model_hamster, save_path=path, cell_key= 'cell_type_coarse', name='_hamster')
plot_lfc(model_hamster, save_path=path, name='hamster')