## Batch correction using scVI

scVI is a deep generative model developed for the probabilistic representation of scRNA-seq data. It performs well both in harmonization and in harmonization-based annotation, so it goes beyond correcting batch effects alone.

Inference uses neural networks, stochastic optimization, and variational inference, and it scales to millions of cells and multiple datasets. scVI also provides a complete probabilistic representation of the data. This representation non-linearly controls not only for sample-to-sample bias but also for other technical sources of variation, such as over-dispersion, variable library size, and zero inflation.
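scVI handles over-dispersion and zero inflation through a zero-inflated negative binomial (ZINB) likelihood for each count. The plain-Python sketch below evaluates that log-likelihood for a single count. The function names and the parameterization are illustrative assumptions, not scVI's own code. The parameters are a mean `mu`, an inverse dispersion `theta`, and a dropout probability `pi`. In scVI, the mean is the cell's library size times a decoder-predicted normalized expression. With `dispersion='gene'` there is one `theta` per gene.

```python
import math


def nb_log_prob(x, mu, theta):
    """Log-probability of count x under a negative binomial with mean mu and
    inverse dispersion theta (variance = mu + mu**2 / theta)."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def zinb_log_prob(x, mu, theta, pi):
    """Zero-inflated NB (illustrative): with probability pi the count is an
    extra zero, otherwise it is drawn from NB(mu, theta)."""
    if x == 0:
        # A zero can come either from the zero-inflation or from the NB itself
        return math.log(pi + (1.0 - pi) * math.exp(nb_log_prob(0, mu, theta)))
    return math.log(1.0 - pi) + nb_log_prob(x, mu, theta)


# Example: library size 1000 times normalized expression 0.002 -> mean of 2 counts
mu = 1000 * 0.002
theta = 1.5
pi = 0.1
probs = [math.exp(zinb_log_prob(x, mu, theta, pi)) for x in range(200)]
print(round(sum(probs), 6))  # probabilities over 0..199 sum to ~1
print(round(probs[0], 4))    # P(x = 0) includes the extra zero mass
```

Setting `pi = 0` recovers the plain negative binomial. A larger `theta` shrinks the extra variance term `mu**2 / theta` toward a Poisson.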

Goals:

* Setting up and loading the datasets
* Performing data harmonization with scVI

In [1]:
import os
import numpy as np
import numpy.random as random
import pandas as pd

from scvi.dataset.dataset import GeneExpressionDataset
from scvi.dataset.csv import CsvDataset
from scvi.inference import UnsupervisedTrainer
from scvi.models import SCANVI, VAE
from scvi.inference.autotune import auto_tune_scvi_model

from umap import UMAP

import torch
import scanpy as sc
import louvain

import logging
import pickle
from hyperopt import hp


# %matplotlib inline

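use_cuda = True        # whether to request GPU training
n_epochs_all = None    # set to an integer to override the default of 100 training epochs
save_path = ''
show_plot = True
# Project directory; the results/scvi/ paths below are relative to it
os.chdir("/Users/janihuuh/Dropbox/gvhd_scrnaseq/")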

[2019-11-18 17:57:41,763] INFO - scvi._settings | Added StreamHandler with custom formatter to 'scvi' logger.


## Load the data

In [2]:
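# Load the two datasets from local CSV files; new_n_genes=False keeps every gene (no gene subsampling)
gvhd_2013 = CsvDataset(filename='results/scvi/gvhd_2013.csv', save_path='', sep=',', new_n_genes=False)
gvhd_2017 = CsvDataset(filename='results/scvi/gvhd_2017.csv', save_path='', sep=',', new_n_genes=False)

## Combine: each matrix in Xs becomes its own batch (2013 -> batch 0, 2017 -> batch 1).
## Both matrices must contain the same genes in the same column order.
all_dataset = GeneExpressionDataset()
all_dataset.populate_from_per_batch_list(Xs=[gvhd_2013.X, gvhd_2017.X])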

[2019-11-18 17:58:32,809] INFO - scvi.dataset.csv | Preprocessing dataset
[2019-11-18 17:59:05,095] INFO - scvi.dataset.csv | Finished preprocessing dataset
[2019-11-18 17:59:05,961] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2019-11-18 17:59:05,962] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2019-11-18 17:59:06,387] INFO - scvi.dataset.dataset | Computing the library size for the new data
[2019-11-18 17:59:06,812] INFO - scvi.dataset.dataset | Downsampled from 10589 to 10589 cells
[2019-11-18 17:59:06,895] INFO - scvi.dataset.dataset | Not subsampling. Expecting: 1 < (new_n_genes=False) <= self.nb_genes
[2019-11-18 17:59:06,896] INFO - scvi.dataset.csv | Preprocessing dataset
[2019-11-18 17:59:15,645] INFO - scvi.dataset.csv | Finished preprocessing dataset
[2019-11-18 17:59:16,031] INFO - scvi.dataset.dataset | Remapping labels to [0,N]
[2019-11-18 17:59:16,032] INFO - scvi.dataset.dataset | Remapping batch_indices to [0,N]
[2019-11-18 17:59:16,131]
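Conceptually, `populate_from_per_batch_list` stacks the per-batch count matrices row-wise and records which batch each cell came from. The "Computing the library size" step then summarizes each cell's total counts. The NumPy sketch below mimics that bookkeeping on toy data. `combine_batches` and `log_library_stats` are hypothetical helpers. The per-batch mean and variance of log library size reflect our understanding of scVI's library-size prior; they are not a copy of its implementation.

```python
import numpy as np


def combine_batches(xs):
    """Hypothetical helper: stack per-batch cell-by-gene matrices and
    return the combined matrix plus a per-cell batch index."""
    n_genes = {x.shape[1] for x in xs}
    if len(n_genes) != 1:
        raise ValueError("all batches must have the same genes (columns)")
    X = np.vstack(xs)
    batch = np.concatenate(
        [np.full(x.shape[0], i, dtype=np.int64) for i, x in enumerate(xs)]
    )
    return X, batch


def log_library_stats(X, batch):
    """Mean and variance of log total counts per cell, computed within each batch.
    Assumes every cell has at least one count (log(0) would be -inf)."""
    log_lib = np.log(X.sum(axis=1))
    stats = {}
    for b in np.unique(batch):
        vals = log_lib[batch == b]
        stats[int(b)] = (vals.mean(), vals.var())
    return stats


# Toy counts: 3 cells from batch 0 and 2 cells from batch 1, 4 genes each
batch0 = np.array([[1, 0, 3, 0], [2, 2, 0, 1], [0, 5, 1, 0]])
batch1 = np.array([[4, 0, 0, 4], [1, 1, 1, 1]])
X_toy, batch_toy = combine_batches([batch0, batch1])
print(X_toy.shape)  # (5, 4)
print(batch_toy)    # [0 0 0 1 1]
print(log_library_stats(X_toy, batch_toy))
```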

## Define and build the model

In [3]:
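# VAE with a 30-dimensional latent space, 2 hidden layers of 128 units in the
# encoder and decoder, and one dispersion parameter per gene shared across cells
vae = VAE(all_dataset.nb_genes,
          n_batch=all_dataset.n_batches,   # condition on batch so it can be corrected for
          n_labels=all_dataset.n_labels,
          n_hidden=128,
          n_latent=30,
          n_layers=2,
          dispersion='gene')

# train_size=1.0: train on all cells (no held-out test set)
trainer = UnsupervisedTrainer(vae, all_dataset, train_size=1.0, use_cuda=use_cuda)
n_epochs = 100 if n_epochs_all is None else n_epochs_all
trainer.train(n_epochs=n_epochs)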

training: 100%|██████████| 100/100 [1:50:53<00:00, 66.54s/it]
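The trainer fits the VAE by maximizing an evidence lower bound (ELBO). This is the expected log-likelihood of the counts minus a KL penalty. The penalty keeps the encoder's Gaussian posterior over the latent `z` (here `n_latent=30` dimensions) close to a standard normal prior. The minimal sketch below shows that KL term and the reparameterization trick used to sample `z`. It assumes a diagonal Gaussian posterior, and the function names are hypothetical. scVI's full objective also includes a KL term for the library size, which is omitted here.

```python
import numpy as np


def kl_to_standard_normal(mu, var):
    """Closed-form KL( N(mu, diag(var)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * np.sum(var + mu ** 2 - 1.0 - np.log(var), axis=-1)


def sample_latent(mu, var, rng):
    """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
    so z is a differentiable function of the encoder outputs."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.sqrt(var) * eps


def negative_elbo(log_likelihood, mu, var):
    """Loss to minimize: -(E_q[log p(x | z)] - KL(q(z | x) || p(z)))."""
    return -(log_likelihood - kl_to_standard_normal(mu, var))


rng = np.random.default_rng(0)
n_latent = 30
mu = np.zeros(n_latent)
var = np.ones(n_latent)
print(kl_to_standard_normal(mu, var))  # 0.0: the posterior equals the prior
z = sample_latent(mu, var, rng)
print(z.shape)  # (30,)
```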


In [4]:
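# Posterior over all cells, in their original order
full = trainer.create_posterior(trainer.model, all_dataset, indices=np.arange(len(all_dataset)))
# sequential() iterates without shuffling, so rows of `latent` line up with the cells
latent, batch_indices, labels = full.sequential().get_latent()
batch_indices = batch_indices.ravel()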
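One quick check of how well the two datasets mix in `latent` is to examine each cell's nearest neighbours. The score measures how evenly those neighbours are split between batches. The sketch below computes a simple entropy-of-batch-mixing score using brute-force Euclidean k-nearest neighbours. `batch_mixing_entropy`, `k`, and `n_query` are illustrative choices. This simplified version is not the exact procedure used in the scVI paper. With two roughly balanced batches, values near log(2) ≈ 0.69 suggest good mixing and values near 0 suggest separated batches.

```python
import numpy as np


def batch_mixing_entropy(latent, batch_indices, k=20, n_query=1000, seed=0, chunk=500):
    """Average entropy of the batch labels among each query cell's k nearest
    neighbours (Euclidean, brute force, the cell itself excluded)."""
    rng = np.random.default_rng(seed)
    n = latent.shape[0]
    query = rng.choice(n, size=min(n_query, n), replace=False)
    batches = np.unique(batch_indices)
    sq_norms = np.sum(latent ** 2, axis=1)
    entropies = []
    for start in range(0, len(query), chunk):
        idx = query[start:start + chunk]
        # Squared Euclidean distances from this chunk of query cells to all cells
        d2 = sq_norms[idx, None] - 2.0 * latent[idx] @ latent.T + sq_norms[None, :]
        d2[np.arange(len(idx)), idx] = np.inf  # exclude each cell from its own neighbours
        nn = np.argpartition(d2, k, axis=1)[:, :k]  # indices of the k closest cells
        for row in batch_indices[nn]:
            p = np.array([np.mean(row == b) for b in batches])
            p = p[p > 0]
            entropies.append(float(np.sum(p * np.log(1.0 / p))))
    return float(np.mean(entropies))


# Toy data: two well-separated clusters, one per batch
rng = np.random.default_rng(1)
separated = np.vstack([rng.normal(0, 1, (200, 2)), rng.normal(100, 1, (200, 2))])
labels_sep = np.repeat([0, 1], 200)
print(batch_mixing_entropy(separated, labels_sep, k=10))  # 0.0 (batches fully separated)

# The same points with batch labels assigned at random: close to log(2)
labels_mixed = rng.integers(0, 2, size=400)
print(batch_mixing_entropy(separated, labels_mixed, k=10))
```

On this notebook's variables, the call would be `batch_mixing_entropy(latent, batch_indices)`. Query cells are subsampled and distances are computed in chunks, so memory stays bounded for tens of thousands of cells.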

In [5]:
# Save the trained weights and the latent representation (one row per cell)
torch.save(trainer.model.state_dict(), "results/scvi/harmonization.vae.allgenes.30.model.pkl")
np.savetxt("results/scvi/batch_latent_best.csv", latent, delimiter=",")
np.savetxt("results/scvi/batch_indices_best.csv", batch_indices, delimiter=",")
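`np.savetxt` writes every value in a float format by default, so the batch labels come back as floats when the CSV is read again. The sketch below reloads the saved files with `np.loadtxt` and casts the labels back to integers. To run on its own, it writes toy arrays to a temporary directory. The file names mirror the ones above, and the toy arrays are stand-ins, not the real results.

```python
import os
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # Toy stand-ins for the latent matrix and the per-cell batch labels
    toy_latent = np.random.default_rng(0).normal(size=(6, 30))
    toy_batches = np.array([0, 0, 0, 1, 1, 1])

    latent_path = os.path.join(tmp, "batch_latent_best.csv")
    batch_path = os.path.join(tmp, "batch_indices_best.csv")
    np.savetxt(latent_path, toy_latent, delimiter=",")
    np.savetxt(batch_path, toy_batches, delimiter=",")

    # Read the files back; savetxt used its default '%.18e' float format
    latent_loaded = np.loadtxt(latent_path, delimiter=",")
    batch_loaded = np.loadtxt(batch_path, delimiter=",").astype(np.int64)

print(latent_loaded.shape)    # (6, 30)
print(batch_loaded.tolist())  # [0, 0, 0, 1, 1, 1]
```

The saved weights can be restored with PyTorch's `model.load_state_dict(torch.load(path))`. The target must be a `VAE` constructed with the same arguments.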

## Visualize the latent space

In [6]:
latent_u = UMAP(spread=2).fit_transform(latent)

The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../opt/anaconda3/lib/python3.7/site-packages/umap/rp_tree.py", line 135:
@numba.njit(fastmath=True, nogil=True, parallel=True)
def euclidean_random_projection_split(data, indices, rng_state):
^

The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see http://numba.pydata.org/numba-doc/latest/user/parallel.html#diagnostics for help.

File "../../opt/anaconda3/lib/python3.7/site-packages/umap/utils.py", line 409:
@numba.njit(parallel=True)
def build_candidates(current_graph, n_vertices, n_neighbors, max_candidates, rng_state):
^

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
args

File "../../opt/anaconda3/lib/python3.7/site-packages/umap/umap_.py", line 331:
def compute_membership_strengths(knn_indices, knn_dists, sigmas, rhos):
    <source elided>
    rows = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    cols = np.zeros((n_samples * n_neighbors), dtype=np.int64)
    ^

[1] During: lowering "id=1[LoopNest(index_variable = parfor_index.294, range = (0, $0.22, 1))]{281: <ir.Block at /Users/janihuuh/opt/anaconda3/lib/python3.7/site-packages/umap/umap_.py (331)>}Var(parfor_index.294, /Users/janihuuh/opt/anaconda3/lib/python3.7/site-packages/umap/umap_.py (331))" at /Users/janihuuh/opt/anaconda3/lib/python3.7/site-packages/umap/umap_.py (331)

-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.
You are currently using Numba version 0.46.0.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!



In [None]:
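import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Two-colour map: one colour per batch (batch 0 = 2013, batch 1 = 2017)
cm = LinearSegmentedColormap.from_list(
        'my_cm', ['deepskyblue', 'hotpink'], N=2)
fig, ax = plt.subplots(figsize=(5, 5))
# Shuffle the drawing order so neither batch is plotted entirely on top of the other
order = np.arange(latent.shape[0])
random.shuffle(order)
ax.scatter(latent_u[order, 0], latent_u[order, 1],
           c=all_dataset.batch_indices.ravel()[order],
           cmap=cm, edgecolors='none', s=5)
plt.axis("off")
fig.set_tight_layout(True)
if show_plot:
    plt.show()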