# Introduction to stochastic variational inference in pyro

## Inferring coin bias with the beta-binomial model

In [1]:
import numpy as np
import pyro.distributions as dist
import pyro
from torch.distributions import constraints
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import torch

In [2]:
from tqdm import tqdm

In [3]:
import arviz as az

Specify the joint density 

In [4]:
def model(data):
    """Joint density of the coin-flip model.

    A Beta(10, 10) prior is placed on the latent coin fairness, and every
    entry of ``data`` is observed as a Bernoulli draw with that fairness.
    """
    prior_alpha = torch.tensor(10.0)
    prior_beta = torch.tensor(10.0)
    fairness = pyro.sample("latent_fairness", dist.Beta(prior_alpha, prior_beta))
    # One sample site per observation -- don't worry, we will vectorize this later
    for i, flip in enumerate(data):
        pyro.sample("obs_{}".format(i), dist.Bernoulli(fairness), obs=flip)

Specify the variational family

In [5]:
def guide(data):
    """Variational family: a Beta distribution over ``latent_fairness``.

    ``data`` is not used here, but the guide and the model must have the
    same signature.
    """
    positive = constraints.positive  # constrained optimization
    # The tensors passed to pyro.param are the optimization initialization;
    # requires_grad is automatically set to True.
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0), constraint=positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0), constraint=positive)
    variational_beta = dist.Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", variational_beta)

Specify optimizer

In [6]:
# Parameters registered through pyro.param live in Pyro's global parameter
# store. Clear it first so that re-running this cell starts the fit from the
# guide's initial values instead of resuming from an earlier run.
pyro.clear_param_store()

# Adam step size and moment decay rates.
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
# SVI bundles the model, the guide and the optimizer; Trace_ELBO is the loss.
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

In [7]:
# Simulate 10 flips of a coin that lands heads with probability 0.55, then
# hand them to torch as a float tensor of 0s and 1s.
simulated_flips = np.random.binomial(n=1, p=0.55, size=10)
d = torch.tensor(simulated_flips, dtype=float)

In [8]:
n_steps = 1000

# Call svi.step on the data tensor `d` once per iteration; tqdm only wraps
# the range to display a progress bar.
step_iter = tqdm(range(n_steps))
for step in step_iter:
    svi.step(d)

100%|██████████| 1000/1000 [00:04<00:00, 221.39it/s]


Grab the learned variational parameters

In [9]:
# Look up both variational parameters by name and convert each one from a
# tensor to a plain Python number.
alpha_q, beta_q = (pyro.param(name).item() for name in ("alpha_q", "beta_q"))

In [10]:
def summarize_beta(alpha, beta):
    """Return the ``(mean, standard deviation)`` of a Beta(alpha, beta) distribution.

    The standard deviation is written as ``mean * sqrt(beta / (alpha * (1 + alpha + beta)))``,
    which is algebraically the usual
    ``sqrt(alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1)))``.
    """
    mean = alpha / (alpha + beta)
    std = mean * np.sqrt(beta / (alpha * (1.0 + alpha + beta)))
    return mean, std

In [11]:
# Mean and standard deviation of the Beta distribution defined by the learned alpha_q and beta_q.
summarize_beta(alpha_q, beta_q)

(0.4988920100706238, 0.08993969349039682)

## Conditional independence and subsampling

The objective is to not have to touch every data point during inference, but rather approximate the log likelihood with mini-batches. Let $\boldsymbol{x}$ denote a data vector of observations, and $\boldsymbol{z}$ denote a vector of latent random variables. Then

$$\sum_{i=1}^N \log p(\boldsymbol{x}_i | \boldsymbol{z}) \approx \frac{N}{M} \sum_{i \in \mathcal{I}_M} \log p(\boldsymbol{x}_i | \boldsymbol{z})$$

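where $\mathcal{I}_M$ is a mini-batch of indices of size $M$. To do this, we require the observations to be conditionally independent given the latent variables, so that the log likelihood decomposes into a sum over data points; [Blei's review](https://arxiv.org/pdf/1601.00670.pdf) develops stochastic variational inference for the class of **conditionally conjugate models**.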

The `pyro.plate` allows us to encode conditional independence in the model. Let's do that:

In [12]:
def model_vec(data):
    """Coin-flip model with the observation loop wrapped in ``pyro.plate``.

    Same Beta(10, 10) prior and Bernoulli likelihood as the plain loop
    version; only the loop construct differs.
    """
    prior_alpha = torch.tensor(10.0)
    prior_beta = torch.tensor(10.0)
    fairness = pyro.sample("latent_fairness", dist.Beta(prior_alpha, prior_beta))
    n_obs = len(data)
    # Looping over a plate allows us to leverage conditional independence of
    # the observations given the latent variables.
    for i in pyro.plate("data_loop", n_obs):
        pyro.sample("obs_{}".format(i), dist.Bernoulli(fairness), obs=data[i])

Let's make this more efficient by:
- Vectorizing 
- Subsampling, so we can mini-batch

In [13]:
def model_vec_subsampled(data):
    """Vectorized coin-flip model that scores a mini-batch of 5 observations.

    Notes on the ``pyro.plate`` arguments:

    * ``size`` is required so that the correct scaling factor can be computed.
    * ``subsample_size=5`` means the log likelihood is only evaluated for
      5 randomly chosen datapoints, and it is automatically scaled by N/M.
    * ``device`` (not set here) can be used to run on a GPU.
    * ``subsample`` (not set here) accepts a custom subsampling scheme. A
      stateful scheme may be necessary -- it is possible to never touch some
      data points if the dataset is sufficiently large.
    """
    prior_alpha = torch.tensor(10.0)
    prior_beta = torch.tensor(10.0)
    fairness = pyro.sample("latent_fairness", dist.Beta(prior_alpha, prior_beta))
    with pyro.plate("observe_data", size=len(data), subsample_size=5) as batch_idx:
        # Select the mini-batch rows; this will be a tensor of length 5.
        minibatch = data.index_select(0, batch_idx)
        pyro.sample("obs", dist.Bernoulli(fairness), obs=minibatch)

In [14]:
# "alpha_q" and "beta_q" from the previous fit are still in Pyro's global
# parameter store. Clear it so this run starts from the guide's initial
# values and its result reflects the subsampled model alone.
pyro.clear_param_store()
svi = SVI(model_vec_subsampled, guide, optimizer, loss=Trace_ELBO())

In [15]:
n_steps = 1000

# Call svi.step on the data tensor `d` once per iteration; tqdm only wraps
# the range to display a progress bar.
step_iter = tqdm(range(n_steps))
for step in step_iter:
    svi.step(d)

100%|██████████| 1000/1000 [00:01<00:00, 502.58it/s]


In [16]:
# Look up both variational parameters by name and convert each one from a
# tensor to a plain Python number.
alpha_q, beta_q = (pyro.param(name).item() for name in ("alpha_q", "beta_q"))

In [17]:
# Mean and standard deviation of the Beta distribution defined by the learned alpha_q and beta_q.
summarize_beta(alpha_q, beta_q)

(0.5049274781342739, 0.09002792446530777)

The variational distribution (the `guide`) may exhibit conditional independence too.

Let $\beta$ be a vector of global latent variables, which potentially govern any of the data. Let $\boldsymbol{z}$ be a vector of local latent variables, whose $i$th component only governs data in the $i$th "context". The joint density of a conditionally conjugate model is:

$$p(\beta, \boldsymbol{z}, \boldsymbol{x}) = p(\beta) \prod_{i=1}^n p(z_i, x_i | \beta)$$

The variational family (according to the pyro docs, though I don't yet get how this gels with the review) should factorize like

$$q(\beta, \boldsymbol{z}) = q(\beta) \prod_{i=1}^n q(z_i | \beta, \lambda_i)$$

where $\lambda_i$ are local variational parameters (other variational parameters are left implicit). 

In [None]:
# TODO

## Amortization

In [None]:
# TODO