# scANVI Pyro Tutorial

This notebook follows the Pyro tutorial on implementing a bare-bones version of the scANVI model: https://pyro.ai/examples/scanvi.html

scANVI (https://doi.org/10.15252/msb.20209620) is a method based on conditional variational autoencoders (CVAEs) that learns cell state representations from single-cell RNA-seq data. It is a semi-supervised method that uses cell type labels when available. In this notebook, we will approximately reproduce Figure 6 in the scANVI paper.

### Setup and data preprocessing

We use Pyro's `get_data` helper, which relies on the scvi-tools package, to download and preprocess peripheral blood mononuclear cell (PBMC) scRNA-seq data.

In [1]:
# setup environment
import os
smoke_test = ('CI' in os.environ)  # for continuous integration tests

In [9]:
# various import statements
import numpy as np
import scanpy as sc

import torch
import torch.nn as nn
from torch.nn.functional import softplus, softmax
from torch.distributions import constraints
from torch.optim import Adam

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape
from pyro.optim import MultiStepLR
from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO
from pyro.contrib.examples.scanvi_data import get_data

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

In [4]:
# Download and pre-process data
batch_size = 100
if not smoke_test:
    # Pass cuda=True instead to place the data on a GPU
    dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset='pbmc', cuda=False, batch_size=batch_size)
else:
    dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset='mock')

INFO     File data/PurifiedPBMCDataset.h5ad already downloaded


In [5]:
# get some basic info about the data
print("Count data matrix shape:", dataloader.data_x.shape)
print("Mean counts per cell: {:.1f}".format(dataloader.data_x.sum(-1).mean().item()))
print("Number of labeled cells:", dataloader.num_labeled)

Count data matrix shape: torch.Size([20000, 21932])
Mean counts per cell: 1418.6
Number of labeled cells: 200
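
To make these summary statistics concrete, here is a small sketch on a made-up count matrix (the `toy_counts` tensor is hypothetical and not part of the dataset). It applies the same `sum(-1).mean()` reduction as the cell above: sum over genes within each cell, then average over cells. It also computes the labeled fraction implied by the printed output, 200 labeled cells out of 20,000.

```python
import torch

# Hypothetical toy count matrix: 3 cells (rows) x 4 genes (columns)
toy_counts = torch.tensor([[0., 3., 1., 0.],
                           [2., 0., 0., 0.],
                           [5., 1., 4., 0.]])

counts_per_cell = toy_counts.sum(-1)         # total counts in each cell
mean_counts = counts_per_cell.mean().item()  # average of those totals over cells

# Labeled fraction, using the values printed by the cell above
labeled_fraction = 200 / 20000

print("Counts per cell:", counts_per_cell.tolist())
print("Mean counts per cell: {:.1f}".format(mean_counts))
print("Labeled fraction: {:.1%}".format(labeled_fraction))
```

Only about 1% of the cells carry a label, which is the regime a semi-supervised method such as scANVI is designed for.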


In [6]:
# Helper functions for reshaping tensors and building fully-connected
# neural networks (see the scANVI paper for where these are used)

# Helper for making fully-connected neural networks
def make_fc(dims):
    layers = []
    for in_dim, out_dim in zip(dims, dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.BatchNorm1d(out_dim))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers[:-1])  # Exclude final ReLU non-linearity

# Splits a tensor in half along the final dimension
def split_in_half(t):
    return t.reshape(t.shape[:-1] + (2, -1)).unbind(-2)

# Helper for broadcasting inputs to neural net
def broadcast_inputs(input_args):
    shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
    input_args = [s.expand(shape) for s in input_args]
    return input_args
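
The following is a minimal sketch of how the three helpers behave on small random tensors. To keep it self-contained, the helpers are redefined here, and `torch.broadcast_shapes` stands in for Pyro's `broadcast_shape` (an assumption made only to avoid the Pyro import). The layer sizes and tensor shapes are arbitrary choices for illustration, not values used by the tutorial.

```python
import torch
import torch.nn as nn


def make_fc(dims):
    # One Linear -> BatchNorm1d -> ReLU group per consecutive pair of sizes
    layers = []
    for in_dim, out_dim in zip(dims, dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.BatchNorm1d(out_dim))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers[:-1])  # drop the final ReLU


def split_in_half(t):
    # Reshape (..., 2k) to (..., 2, k) and unbind into two (..., k) tensors
    return t.reshape(t.shape[:-1] + (2, -1)).unbind(-2)


def broadcast_inputs(input_args):
    # Assumption: torch.broadcast_shapes replaces pyro's broadcast_shape here
    shape = torch.broadcast_shapes(*[s.shape[:-1] for s in input_args]) + (-1,)
    return [s.expand(shape) for s in input_args]


torch.manual_seed(0)

# A two-layer network mapping 10 inputs to 4 outputs (arbitrary sizes)
net = make_fc([10, 8, 4])
net.eval()  # eval mode so BatchNorm1d uses its running statistics

x = torch.randn(5, 10)
out = net(x)                          # shape (5, 4)
loc, raw_scale = split_in_half(out)   # two tensors of shape (5, 2)

# Broadcast the leading (batch) dimensions; the last dimension is untouched
a = torch.randn(5, 3)
b = torch.randn(1, 2)
a_b, b_b = broadcast_inputs([a, b])   # shapes (5, 3) and (5, 2)

print(out.shape, loc.shape, raw_scale.shape, a_b.shape, b_b.shape)
```

A common use of `split_in_half` is to let a single network output both the location and the unconstrained scale of a distribution. `broadcast_inputs` aligns the batch dimensions of several tensors so that they can be concatenated along the last dimension before being passed to a network.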

### Encoder and Decoder networks
todo

### scANVI model
todo

### Training
todo

### Plotting / Visualizing Results
todo