# scANVI Pyro Tutorial

This notebook follows the Pyro tutorial on implementing a bare-bones version of the scANVI model: https://pyro.ai/examples/scanvi.html

scANVI (https://doi.org/10.15252/msb.20209620) is a method based on conditional variational autoencoders (CVAEs) that learns cell state representations from single-cell RNA-seq data. It is a semi-supervised method that uses cell type labels when available. In this notebook, we will approximately reproduce Figure 6 in the scANVI paper.

Note that the notation used below differs somewhat from that of the scANVI paper. Below, "y" represents the cell class label and "z1" is the latent variable describing cell characteristics within that label; the corresponding variables are called "c" and "u" respectively in scANVI. "z2" below is "z" in scANVI. Rather than elaborating the rest of the scANVI variables, we draw the expression level "x" (same in both) directly from a zero-inflated negative binomial (ZINB) distribution. (scANVI spells out the underlying gamma and Poisson variables and what they represent.)
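
In the notation of this notebook, the generative process that `model_sketch` encodes later on can be summarized roughly as follows. This is a sketch only: the symbols $f_{\mathrm{loc}}$, $f_{\mathrm{scale}}$, $\mu$, and $g$ for the decoder network outputs are introduced here for illustration and are not taken from the scANVI paper.

$$
\begin{aligned}
z_1 &\sim \mathcal{N}(0, I) \\
y &\sim \mathrm{OneHotCategorical}(\text{uniform over cell types}) \\
z_2 \mid z_1, y &\sim \mathcal{N}\big(f_{\mathrm{loc}}(z_1, y),\ \mathrm{diag}(f_{\mathrm{scale}}(z_1, y)^2)\big) \\
\ell &\sim \mathrm{LogNormal}(\ell_{\mathrm{loc}}, \ell_{\mathrm{scale}}) \\
x \mid z_2, \ell &\sim \mathrm{ZINB}\big(\text{mean} = \ell \cdot \mu(z_2),\ \text{inverse dispersion} = \theta,\ \text{gate logits} = g(z_2)\big)
\end{aligned}
$$

Here "mean" refers to the mean of the negative binomial component before zero inflation is applied.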

### Setup and data preprocessing

We use the scvi-tools package to download some PBMC scRNA-seq data.

In [1]:
# setup environment
import os
smoke_test = ('CI' in os.environ)  # for continuous integration tests

In [2]:
# various import statements
import numpy as np
import scanpy as sc

import torch
import torch.nn as nn
from torch.nn.functional import softplus, softmax
from torch.distributions import constraints
from torch.optim import Adam

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.util import broadcast_shape
from pyro.optim import MultiStepLR
from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO
from pyro.contrib.examples.scanvi_data import get_data

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

In [3]:
# Download and pre-process data
batch_size = 100
if not smoke_test:
    # dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset='pbmc', cuda=True, batch_size=batch_size)
    dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset='pbmc', cuda=False, batch_size=batch_size)
else:
    dataloader, num_genes, l_mean, l_scale, anndata = get_data(dataset='mock')

INFO     Downloading file at data/PurifiedPBMCDataset.h5ad
Downloading...: 157054it [00:04, 35066.22it/s]


In [4]:
# get some basic info about the data
print("Count data matrix shape:", dataloader.data_x.shape)
print("Mean counts per cell: {:.1f}".format(dataloader.data_x.sum(-1).mean().item()))
print("Number of labeled cells:", dataloader.num_labeled)

Count data matrix shape: torch.Size([20000, 21932])
Mean counts per cell: 1418.6
Number of labeled cells: 200


In [5]:
# define some helper functions for reshaping tensors and 
#   making fully-connected neural networks (see scANVI paper to see where these are used)

# Helper for making fully-connected neural networks
def make_fc(dims):
    layers = []
    for in_dim, out_dim in zip(dims, dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.BatchNorm1d(out_dim))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers[:-1])  # Exclude final ReLU non-linearity

# Splits a tensor in half along the final dimension
def split_in_half(t):
    return t.reshape(t.shape[:-1] + (2, -1)).unbind(-2)

# Helper for broadcasting inputs to neural net
def broadcast_inputs(input_args):
    shape = broadcast_shape(*[s.shape[:-1] for s in input_args]) + (-1,)
    input_args = [s.expand(shape) for s in input_args]
    return input_args
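
The following is a small usage sketch of the three helpers on toy tensors, showing the shapes they produce. The helpers are repeated so the snippet runs on its own. The sizes (`batch`, `latent_dim`, `num_labels`) are made up for illustration, and `torch.broadcast_shapes` is used here as a stand-in for Pyro's `broadcast_shape`, which is an assumption of this sketch rather than what the notebook uses.

```python
import torch
import torch.nn as nn

def make_fc(dims):
    layers = []
    for in_dim, out_dim in zip(dims, dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.BatchNorm1d(out_dim))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers[:-1])  # drops only the final ReLU

def split_in_half(t):
    return t.reshape(t.shape[:-1] + (2, -1)).unbind(-2)

def broadcast_inputs(input_args):
    # Assumption: torch.broadcast_shapes replaces pyro's broadcast_shape here
    batch_shapes = [s.shape[:-1] for s in input_args]
    shape = tuple(torch.broadcast_shapes(*batch_shapes)) + (-1,)
    return [s.expand(shape) for s in input_args]

# Toy sizes, chosen only for illustration
batch, latent_dim, num_labels = 5, 3, 4

# A network that maps (z1, y) to a concatenated (loc, scale) vector
net = make_fc([latent_dim + num_labels, 8, 2 * latent_dim])

z1 = torch.randn(batch, latent_dim)   # shape (5, 3)
y = torch.eye(num_labels)[:1]         # one one-hot label, shape (1, 4)

# Expand all inputs to a common batch shape; the last dimension is left alone
z1_b, y_b = broadcast_inputs([z1, y])  # shapes (5, 3) and (5, 4)

out = net(torch.cat([z1_b, y_b], dim=-1))  # shape (5, 6)
loc, scale_raw = split_in_half(out)        # two tensors of shape (5, 3)
print(z1_b.shape, y_b.shape, out.shape, loc.shape, scale_raw.shape)
```

Note that `make_fc` removes only the trailing ReLU, so the last module in the returned network is a `BatchNorm1d` layer. `split_in_half` returns the first and second halves of the final dimension, which is convenient when one network emits both a location and a (pre-activation) scale.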

### Model and guide sketches

Before specifying the full model, we write some code to illustrate its high-level structure. Refer to the Pyro tutorial linked at the top of this notebook for more details.

We also give a sketch of the guide, i.e. the variational distribution over the latent variables.

In [7]:
# Note that this is only a sketch and will not run, since names such as z2_decoder, x_decoder, latent_dim, num_labels, and l_loc are not defined yet
def model_sketch(x, y=None):
    # This gene-level parameter (theta) modulates the variance of the observation distribution for our vector of counts x.
    # It is the inverse dispersion of the negative binomial: smaller values mean more over-dispersion in the ZINB distribution of x.
    theta = pyro.param("inverse_dispersion", 10.0 * torch.ones(num_genes), constraint=constraints.positive)

    # This plate statement encodes that each datapoint (i.e. cell count vector x_i)
    #   is conditionally independent given its own latent variables.
    with pyro.plate("batch", len(x)):
        # Define a unit normal prior for z1 (aka "u", the cell specific params conditional on cell type y)
        # Remember, to_event(1) causes these to be sampled all at once as an MVN with identity covariance
        z1 = pyro.sample("z1", dist.Normal(0, torch.ones(latent_dim)).to_event(1))

        # Define a uniform categorical prior for y (cell type).
        # Note that if y is None (i.e. y is unobserved) then y will be sampled; otherwise y will be treated as observed (via obs=y).
        y = pyro.sample("y", dist.OneHotCategorical(logits=torch.zeros(num_labels)), obs=y)

        # pass z1 and y to the z2 decoder neural network, which "decodes" these latents to generate z2, 
        #   the cell-specific params (z in original scANVI)
        z2_loc, z2_scale = z2_decoder(z1, y)
        # Define the prior distribution for z2. The parameters of this distribution depend on both z1 and y.
        z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

        # Define a LogNormal prior distribution for the count scale variable l (capturing library size, capture efficiency, etc)
        l = pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))

        # We now construct the observation distribution. To do this we first pass z2 to the x decoder neural network,
        #   which "decodes" the latent cell state z2 into the ZINB-distributed observed counts x.
        gate_logits, mu = x_decoder(z2)
        # Using the outputs of the neural network we can define the parameters
        # of our ZINB observation distribution.
        # Note that by construction mu is normalized (i.e. mu.sum(-1) == 1) and the
        # total scale of counts for each cell is determined by the latent variable ℓ.
        # That is, `l * mu` is a G-dimensional vector of mean gene counts.
        nb_logits = (l * mu).log() - theta.log()
        x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, total_count=theta, logits=nb_logits)

        # Observe the datapoint x using the observation distribution x_dist
        pyro.sample("x", x_dist.to_event(1), obs=x)
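
To see why `nb_logits = (l * mu).log() - theta.log()` gives a negative binomial whose mean is `l * mu`, the snippet below checks the identity numerically with `torch.distributions.NegativeBinomial`. It is a minimal sketch on made-up values (three cells, six genes) and covers only the negative binomial component, not the zero-inflation gate.

```python
import torch
from torch.distributions import NegativeBinomial

torch.manual_seed(0)

# Toy sizes and values, chosen only for illustration
num_cells, num_genes = 3, 6
theta = 10.0 * torch.ones(num_genes)                         # inverse dispersion
mu = torch.softmax(torch.randn(num_cells, num_genes), -1)    # each row sums to 1
l = torch.tensor([[500.0], [1000.0], [2000.0]])              # per-cell count scale

# Same parameterization as in the model sketch
nb_logits = (l * mu).log() - theta.log()
nb = NegativeBinomial(total_count=theta, logits=nb_logits)

# The NB mean is total_count * exp(logits) = l * mu
expected_mean = l * mu
# The NB variance is mean + mean^2 / theta, so smaller theta means more over-dispersion
expected_var = expected_mean + expected_mean ** 2 / theta

print(torch.allclose(nb.mean, expected_mean, rtol=1e-4))
print(torch.allclose(nb.variance, expected_var, rtol=1e-4))
```

Because each row of `mu` sums to one, the expected total count per cell under the negative binomial component equals `l`.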


In [9]:
# This is a sketch of the guide specifying the variational distribution; again, it will not run since some things are undefined
def guide_sketch(x, y=None):
    # This plate statement matches the plate in the model
    with pyro.plate("batch", len(x)):
        # We pass the observed count vector x to an encoder network that generates the parameters we use to define
        #   the variational distributions for the latent variables z2 (cell state parameters) and l (size factors).
        z2_loc, z2_scale, l_loc, l_scale = z2l_encoder(x)
        pyro.sample("l", dist.LogNormal(l_loc, l_scale).to_event(1))
        z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

        # We only need to specify a variational distribution over y if y is unobserved
        if y is None:
            # We use the "classifier" neural network to turn the latent "code" z2 into
            #   logits that we can use to specify a distribution over y.
            y_logits = classifier(z2)
            y_dist = dist.OneHotCategorical(logits=y_logits)
            y = pyro.sample("y", y_dist)
        
        # Finally we generate the parameters for the z1 distribution by passing z2 and y through an encoder neural network z1_encoder.
        z1_loc, z1_scale = z1_encoder(z2, y)
        pyro.sample("z1", dist.Normal(z1_loc, z1_scale).to_event(1))
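
The semi-supervised branch of the guide can be illustrated in isolation with plain PyTorch. In the minimal sketch below, `infer_labels` is a hypothetical helper (not part of the notebook or of Pyro) that samples a one-hot label from classifier logits only when no label is supplied. The logits are random stand-ins for the output of `classifier(z2)`.

```python
import torch
from torch.distributions import OneHotCategorical

def infer_labels(y_logits, y=None):
    """Hypothetical helper: keep y if it is observed, otherwise sample it."""
    if y is None:
        y = OneHotCategorical(logits=y_logits).sample()
    return y

torch.manual_seed(0)

# Random stand-in for classifier(z2): 5 cells, 4 candidate cell types
y_logits = torch.randn(5, 4)

# Unlabeled cells: labels are drawn from the classifier distribution
y_sampled = infer_labels(y_logits)

# Labeled cells: the observed one-hot labels pass through unchanged
y_observed = torch.eye(4)[[0, 1, 2, 3, 0]]
y_kept = infer_labels(y_logits, y_observed)

print(y_sampled.shape, y_sampled.sum(-1))
print(torch.equal(y_kept, y_observed))
```

In both cases the result is a one-hot matrix with one row per cell, which is the form the `z1_encoder` call in the guide sketch expects for `y`.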



### Encoder and Decoder networks
todo

### scANVI model
todo

### Training
todo

### Plotting / Visualizing Results
todo