# DIRECTi tutorial

In [None]:
import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import Cell_BLAST as cb

np.set_printoptions(threshold=200)
pd.set_option("display.max_rows", 20)
tf.logging.set_verbosity(0)
cb.config.N_JOBS = 4
cb.config.RANDOM_SEED = 0

## Load data

In this tutorial, we demonstrate how to train DIRECTi models used by Cell BLAST.

Let's first load a dataset (*Baron, M. et al., Cell Syst, 2016*), which profiles >8,000 human pancreatic islet cells.

In [None]:
baron_human = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Baron_human/data.h5")

The data object has a similar structure to `AnnData` objects.

The expression matrix (in the cell by gene orientation) is stored in the `exprs` slot (in this case a sparse matrix containing raw UMI counts):

In [None]:
baron_human.exprs

In [None]:
baron_human.exprs[0:10, 0:10].toarray()

Meta table for cells is stored in the `obs` slot:

In [None]:
baron_human.obs.head()

Meta table for genes is stored in the `var` slot (in this case it's empty, only containing row names):

In [None]:
baron_human.var.head()

Other unstructured data are stored in the `uns` slot (which is a Python dict). In this case, it holds prestored lists of genes selected by various methods.

In [None]:
baron_human.uns.keys()

In [None]:
baron_human.uns["seurat_genes"]

## Gene selection

For custom datasets, preselected genes may not be available. In such cases, gene selection can be manually performed using the [`find_variable_genes()`](../modules/Cell_BLAST.data.html#Cell_BLAST.data.ExprDataSet.find_variable_genes) method, which is a reimplementation of the `FindVariableGenes()` function in Seurat v2.

In addition to the basic functionality of the Seurat function, we also support selecting genes separately within different "batches" of the data and merging the per-batch selections via a voting strategy. This helps filter out genes related to batch effect. For example, to mitigate batch effect among the 4 donors in the pancreatic dataset, we specify the grouping variable "donor" (a column in the `obs` data frame) and leave the other parameters at their defaults:

In [None]:
%%capture
selected_genes, axes = baron_human.find_variable_genes(grouping="donor")

In [None]:
selected_genes

In this case, the selected genes should be identical to those prestored in the `uns` slot (both set differences below should be empty):

In [None]:
np.setdiff1d(selected_genes, baron_human.uns["seurat_genes"]).size, \
np.setdiff1d(baron_human.uns["seurat_genes"], selected_genes).size

By default, genes selected in at least 50% of the "batches" are kept. This cutoff is controlled by the `min_group_frac` parameter. When cell type composition varies considerably across "batches", it may help to lower the `min_group_frac` cutoff.
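The voting step can be illustrated with a small standalone sketch. This is not Cell BLAST's implementation: the helper `vote_genes`, the toy per-donor gene lists, and the inclusive (`>=`) cutoff are all assumptions made for illustration.

```python
from collections import Counter

def vote_genes(selected_per_batch, min_group_frac=0.5):
    """Keep genes selected in at least `min_group_frac` of the batches.

    Hypothetical helper; the inclusive cutoff is an assumption.
    """
    n_batches = len(selected_per_batch)
    # Each batch casts at most one vote per gene.
    votes = Counter(
        gene for genes in selected_per_batch.values() for gene in set(genes)
    )
    threshold = min_group_frac * n_batches
    return sorted(gene for gene, n in votes.items() if n >= threshold)

# Toy selections for four donors (made-up lists, not real results).
per_donor = {
    "donor1": ["INS", "GCG", "SST"],
    "donor2": ["INS", "GCG", "PPY"],
    "donor3": ["INS", "KRT19"],
    "donor4": ["INS", "GCG", "XIST"],
}

strict = vote_genes(per_donor, min_group_frac=0.5)    # needs >= 2 of 4 votes
lenient = vote_genes(per_donor, min_group_frac=0.25)  # needs >= 1 of 4 votes
print(strict)   # ['GCG', 'INS']
print(lenient)  # ['GCG', 'INS', 'KRT19', 'PPY', 'SST', 'XIST']
```

Lowering the cutoff keeps genes that are variable in only a few batches, which is why it can help when some cell types are present in only some batches.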

## Unsupervised dimension reduction

Now we build and fit a DIRECTi model (for Cell BLAST) with the one-step [`fit_DIRECTi()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.fit_DIRECTi) function.

Note that we pass the raw data and the previously selected genes directly. The function handles data normalization and gene subsetting internally. Performing data normalization or gene subsetting externally beforehand is **NOT** recommended.

Also, we set the cell embedding dimensionality to 10, and use a 20-dimensional categorical latent *c*.

> Note that although a 20-dimensional categorical latent is used, fewer clusters are formed in the embedding space. This is because the model is free to discard categories, or to use multiple categories to represent the same cluster, if a redundant number of categories is specified. If the data contains a single continuous spectrum of cell states, it can be more appropriate to set `cat_dim` to `None`.

In [None]:
%%capture
model = cb.directi.fit_DIRECTi(
    baron_human, genes=selected_genes,
    latent_dim=10, cat_dim=20
)

After model training, we can project cells into the cell embedding space using the [`inference()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.DIRECTi.inference) method.
It is recommended that you store cell embeddings into the `latent` slot of the original dataset object. This facilitates visualization via the [`visualize_latent()`](../modules/Cell_BLAST.data.html#Cell_BLAST.data.ExprDataSet.visualize_latent) method.

In [None]:
baron_human.latent = model.inference(baron_human)

In [None]:
ax = baron_human.visualize_latent("cell_ontology_class")

We see that different cell types can readily be distinguished.

This is a prerequisite for successful cell querying. If certain cell types are largely intermingled at this step, they are unlikely to be predicted unambiguously. In that case, it may be useful to reconsider the feature selection step or the model hyperparameter settings (see the function documentation of [`find_variable_genes()`](../modules/Cell_BLAST.data.html#Cell_BLAST.data.ExprDataSet.find_variable_genes) and [`fit_DIRECTi()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.fit_DIRECTi) for details).

You can also save the model for future use via the [`save()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.save) method. It is straightforward to load a saved model via the [`load()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.load) function.

In [None]:
model.save("./baron_human_model")
model.close()
del model
model = cb.directi.DIRECTi.load("./baron_human_model")

We can also project other datasets using the same model. Here we test with the Muraro dataset (*Muraro, M. et al., Cell Systems, 2016*), which also profiled human pancreatic islets.

Again, we do not normalize the dataset or subset genes beforehand, as these steps are handled internally by the [`inference()`](../modules/Cell_BLAST.directi.html#Cell_BLAST.directi.DIRECTi.inference) method.

A warning will report that some of the model's genes are missing from the new dataset. This is not a problem as long as the number of missing genes is small. Distinct cell types are still well separated.

In [None]:
muraro = cb.data.ExprDataSet.read_dataset("../../Datasets/data/Muraro/data.h5")
muraro.latent = model.inference(muraro)
ax = muraro.visualize_latent("cell_ontology_class")
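To make the missing-gene warning more concrete, here is a minimal sketch of one common way to align a new expression matrix to the gene list a model was trained on. Zero-filling absent genes is an assumption about the general approach, not a confirmed description of what `inference()` does internally. The helper `align_to_genes` and the toy data are hypothetical.

```python
import numpy as np

def align_to_genes(exprs, dataset_genes, model_genes):
    """Reorder columns to `model_genes`, zero-filling genes the dataset lacks.

    Hypothetical helper for illustration only.
    """
    exprs = np.asarray(exprs)
    column_of = {gene: i for i, gene in enumerate(dataset_genes)}
    aligned = np.zeros((exprs.shape[0], len(model_genes)), dtype=exprs.dtype)
    missing = []
    for j, gene in enumerate(model_genes):
        if gene in column_of:
            aligned[:, j] = exprs[:, column_of[gene]]
        else:
            missing.append(gene)  # column stays all-zero
    return aligned, missing

# Toy data: 2 cells x 3 genes; the model expects a gene ("PPY") that is absent.
toy_exprs = np.array([[5, 0, 2],
                      [1, 3, 0]])
toy_dataset_genes = ["GCG", "INS", "SST"]
toy_model_genes = ["INS", "PPY", "GCG"]

aligned, missing = align_to_genes(toy_exprs, toy_dataset_genes, toy_model_genes)
missing_frac = len(missing) / len(toy_model_genes)
print(missing, round(missing_frac, 2))
```

Checking a quantity like `missing_frac` before projecting a new dataset is one way to judge whether the number of missing genes is "small".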

## Batch effect correction

Now we demonstrate batch effect correction by training models on a "meta-dataset" merged from multiple datasets (all profiling human pancreatic islets), where we expect significant batch effect among the datasets.

We first merge different datasets using the [`merge_datasets()`](../modules/Cell_BLAST.data.html#Cell_BLAST.data.ExprDataSet.merge_datasets) function.

In [None]:
cb.utils.logger.setLevel(logging.WARNING)  # Suppress a long list of genes not shared by all datasets
combined_dataset = cb.data.ExprDataSet.merge_datasets({
    "Baron_human": cb.data.ExprDataSet.read_dataset("../../Datasets/data/Baron_human/data.h5"),
    "Segerstolpe": cb.data.ExprDataSet.read_dataset("../../Datasets/data/Segerstolpe/data.h5"),
    "Muraro": cb.data.ExprDataSet.read_dataset("../../Datasets/data/Muraro/data.h5"),
    "Xin": cb.data.ExprDataSet.read_dataset("../../Datasets/data/Xin_2016/data.h5"),
    "Lawlor": cb.data.ExprDataSet.read_dataset("../../Datasets/data/Lawlor/data.h5")
}, meta_col="study", merge_uns_slots=["seurat_genes"])
cb.utils.logger.setLevel(logging.INFO)

The `meta_col` argument specifies a column ("study" in this case) that is added to the `obs` data frame and records the dataset of origin of each cell. This column serves as the batch indicator for cross-dataset batch effect.

In [None]:
combined_dataset.obs["study"]

The `merge_uns_slots` argument is specified to merge preselected genes stored in the `uns["seurat_genes"]` slot in individual datasets.
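The effect of these two arguments can be sketched with plain pandas. This is only an illustration, not the `merge_datasets()` implementation: the toy tables are made up, and combining the gene lists as a union is an assumption.

```python
import pandas as pd

# Toy per-study cell metadata and preselected gene lists (made-up values).
obs_by_study = {
    "Baron_human": pd.DataFrame(
        {"cell_ontology_class": ["alpha", "beta"]}, index=["b1", "b2"]
    ),
    "Muraro": pd.DataFrame(
        {"cell_ontology_class": ["beta"]}, index=["m1"]
    ),
}
genes_by_study = {
    "Baron_human": ["INS", "GCG"],
    "Muraro": ["INS", "SST"],
}

# Analogue of meta_col="study": tag every cell with its dataset of origin.
merged_obs = pd.concat(
    [obs.assign(study=name) for name, obs in obs_by_study.items()]
)

# Analogue of merge_uns_slots=["seurat_genes"]: combine the gene lists
# (a union is assumed here).
merged_genes = sorted(set().union(*genes_by_study.values()))

print(merged_obs)
print(merged_genes)
```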

We first train a model on the "meta-dataset" without batch effect correction, and validate that significant batch effect exists among different datasets.

In [None]:
%%capture
model = cb.directi.fit_DIRECTi(
    combined_dataset, genes=combined_dataset.uns["seurat_genes"],
    latent_dim=10, cat_dim=20
)
combined_dataset.latent = model.inference(combined_dataset)

In [None]:
ax = combined_dataset.visualize_latent("cell_ontology_class")

In [None]:
ax = combined_dataset.visualize_latent("study")

To perform batch effect correction, specify `batch_effect` as a column in the `obs` slot corresponding to the batch indicator, which is "study" in this case.

In [None]:
%%capture
model_rmbatch = cb.directi.fit_DIRECTi(
    combined_dataset, genes=combined_dataset.uns["seurat_genes"],
    batch_effect="study", latent_dim=10, cat_dim=20
)
combined_dataset.latent = model_rmbatch.inference(combined_dataset)

We see that batch effect is largely removed in the embedding space. Cells of the same cell type from different studies are well aligned.

In [None]:
ax = combined_dataset.visualize_latent("study")

In [None]:
ax = combined_dataset.visualize_latent("cell_ontology_class")
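The visual impression that cells from different studies are well aligned can also be checked numerically. The sketch below is a simple, generic neighbor-mixing score written in plain NumPy, not a Cell BLAST function: for each cell it computes the fraction of its `k` nearest neighbors in the embedding that come from a different batch. The function name `neighbor_batch_mixing` and the toy embeddings are assumptions for illustration. The dense pairwise distance matrix makes it suitable only for small datasets or subsamples.

```python
import numpy as np

def neighbor_batch_mixing(latent, batch, k=10):
    """Mean fraction of each cell's k nearest neighbors from another batch.

    Hypothetical helper; builds a dense n x n distance matrix.
    """
    latent = np.asarray(latent, dtype=float)
    batch = np.asarray(batch)
    # Pairwise squared Euclidean distances.
    sq_norm = (latent ** 2).sum(axis=1)
    dist = sq_norm[:, None] + sq_norm[None, :] - 2.0 * latent @ latent.T
    np.fill_diagonal(dist, np.inf)  # a cell is not its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]
    return float((batch[neighbors] != batch[:, None]).mean())

# Toy embeddings: two batches far apart vs. two batches drawn from one cloud.
rng = np.random.default_rng(0)
toy_batch = np.repeat(["study_A", "study_B"], 50)
separated = np.vstack([
    rng.normal(0, 1, (50, 2)),
    rng.normal(100, 1, (50, 2)),
])
mixed = rng.normal(0, 1, (100, 2))

score_separated = neighbor_batch_mixing(separated, toy_batch)  # no mixing
score_mixed = neighbor_batch_mixing(mixed, toy_batch)          # well mixed
print(score_separated, score_mixed)
```

A score near zero indicates that neighborhoods are dominated by a single batch, while higher values indicate better mixing. On real data the score should be interpreted per cell type, since cell types present in only one study cannot mix.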