<a href="https://colab.research.google.com/github/cablanc/dlps/blob/master/pyro_ppca.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[TF intro to probabilistic programming](https://blog.tensorflow.org/2018/12/an-introduction-to-probabilistic.html)
[Probabilistic Programming and Bayesian Methods for Hackers](https://github.com/CamDavidsonPilon/Probabilistic-Programming-and-Bayesian-Methods-for-Hackers/tree/master/Chapter2_MorePyMC)

In [1]:
!pip3 install pyro-ppl 



In [2]:
!pip install visdom



In [3]:
mkdir vae_results

mkdir: cannot create directory ‘vae_results’: File exists


For a model with 𝑁 observations, running the model and guide and constructing the ELBO involves evaluating log pdf’s whose complexity scales badly with 𝑁. This is a problem if we want to scale to large datasets. Luckily, the ELBO objective naturally supports subsampling provided that our model/guide have some conditional independence structure that we can take advantage of. For example, in the case that the observations are conditionally independent given the latents, the log likelihood term in the ELBO can be approximated (see [Autoencoding Variational Bayes](https://arxiv.org/abs/1312.6114))

In Pyro, the model corresponds to p(x,z;θ) and the guide corresponds to q(z|x;ϕ)

Pyro calculates p(x, z) = p(x|z)p(z), where z is a sample from q(z)

In [4]:
import argparse
import errno
import os
from functools import reduce
import visdom

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from pyro.contrib.examples.util import MNIST, get_data_directory

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam

assert pyro.__version__.startswith('1.6.0')
# parse command line arguments
import sys

In [5]:
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

def plot_conditional_samples_ssvae(ssvae, visdom_session):
    """
    This is a method to do conditional sampling in visdom
    """
    vis = visdom_session
    ys = {}
    for i in range(10):
        ys[i] = torch.zeros(1, 10)
        ys[i][0, i] = 1
    xs = torch.zeros(1, 784)

    for i in range(10):
        images = []
        for rr in range(100):
            # get the loc from the model
            sample_loc_i = ssvae.model(xs, ys[i])
            img = sample_loc_i[0].view(1, 28, 28).cpu().data.numpy()
            images.append(img)
        vis.images(images, 10, 2)


def plot_llk(train_elbo, test_elbo):
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import scipy as sp
    import seaborn as sns
    plt.figure(figsize=(30, 10))
    sns.set_style("whitegrid")
    data = np.concatenate([np.arange(len(test_elbo))[:, sp.newaxis], -test_elbo[:, sp.newaxis]], axis=1)
    df = pd.DataFrame(data=data, columns=['Training Epoch', 'Test ELBO'])
    g = sns.FacetGrid(df, size=10, aspect=1.5)
    g.map(plt.scatter, "Training Epoch", "Test ELBO")
    g.map(plt.plot, "Training Epoch", "Test ELBO")
    plt.savefig('./vae_results/test_elbo_vae.png')
    plt.close('all')


def plot_vae_samples(vae, visdom_session):
    vis = visdom_session
    x = torch.zeros([1, 784])
    for i in range(10):
        images = []
        for rr in range(100):
            # get loc from the model
            sample_loc_i = vae.model(x)
            img = sample_loc_i[0].view(1, 28, 28).cpu().data.numpy()
            images.append(img)
        vis.images(images, 10, 2)


def mnist_test_tsne(vae=None, test_loader=None):
    """
    This is used to generate a t-sne embedding of the vae
    """
    name = 'VAE'
    data = test_loader.dataset.test_data.float()
    mnist_labels = test_loader.dataset.test_labels
    z_loc, z_scale = vae.encoder(data)
    plot_tsne(z_loc, mnist_labels, name)


def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None):
    """
    This is used to generate a t-sne embedding of the ss-vae
    """
    if name is None:
        name = 'SS-VAE'
    data = test_loader.dataset.test_data.float()
    mnist_labels = test_loader.dataset.test_labels
    z_loc, z_scale = ssvae.encoder_z([data, mnist_labels])
    plot_tsne(z_loc, mnist_labels, name)


def plot_tsne(z_loc, classes, name):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.manifold import TSNE
    model_tsne = TSNE(n_components=2, random_state=0)
    z_states = z_loc.detach().cpu().numpy()
    z_embed = model_tsne.fit_transform(z_states)
    classes = classes.detach().cpu().numpy()
    fig = plt.figure()
    for ic in range(10):
        ind_vec = np.zeros_like(classes)
        ind_vec[:, ic] = 1
        ind_class = classes[:, ic] == 1
        color = plt.cm.Set1(ic)
        plt.scatter(z_embed[ind_class, 0], z_embed[ind_class, 1], s=10, color=color)
        plt.title("Latent Variable T-SNE per Class")
        fig.savefig('./vae_results/'+str(name)+'_embedding_'+str(ic)+'.png')
    fig.savefig('./vae_results/'+str(name)+'_embedding.png')


In [6]:
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.contrib.examples.util import MNIST

# This file contains utilities for caching, transforming and splitting MNIST data
# efficiently. By default, a PyTorch DataLoader will apply the transform every epoch
# we avoid this by caching the data early on in MNISTCached class
# https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/utils/mnist_cached.py


# transformations for MNIST data
def fn_x_mnist(x, use_cuda):
    # normalize pixel values of the image to be in [0,1] instead of [0,255]
    xp = x * (1. / 255)

    # transform x to a linear tensor from bx * a1 * a2 * ... --> bs * A
    xp_1d_size = reduce(lambda a, b: a * b, xp.size()[1:])
    xp = xp.view(-1, xp_1d_size)

    # send the data to GPU(s)
    if use_cuda:
        xp = xp.cuda()

    return xp


def fn_y_mnist(y, use_cuda):
    yp = torch.zeros(y.size(0), 10)

    # send the data to GPU(s)
    if use_cuda:
        yp = yp.cuda()
        y = y.cuda()

    # transform the label y (integer between 0 and 9) to a one-hot
    yp = yp.scatter_(1, y.view(-1, 1), 1.0)
    return yp


def get_ss_indices_per_class(y, sup_per_class):
    # number of indices to consider
    n_idxs = y.size()[0]

    # calculate the indices per class
    idxs_per_class = {j: [] for j in range(10)}

    # for each index identify the class and add the index to the right class
    for i in range(n_idxs):
        curr_y = y[i]
        for j in range(10):
            if curr_y[j] == 1:
                idxs_per_class[j].append(i)
                break

    idxs_sup = []
    idxs_unsup = []
    for j in range(10):
        np.random.shuffle(idxs_per_class[j])
        idxs_sup.extend(idxs_per_class[j][:sup_per_class])
        idxs_unsup.extend(idxs_per_class[j][sup_per_class:len(idxs_per_class[j])])

    return idxs_sup, idxs_unsup


def split_sup_unsup_valid(X, y, sup_num, validation_num=10000):
    """
    helper function for splitting the data into supervised, un-supervised and validation parts
    :param X: images
    :param y: labels (digits)
    :param sup_num: what number of examples is supervised
    :param validation_num: what number of last examples to use for validation
    :return: splits of data by sup_num number of supervised examples
    """

    # validation set is the last 10,000 examples
    X_valid = X[-validation_num:]
    y_valid = y[-validation_num:]

    X = X[0:-validation_num]
    y = y[0:-validation_num]

    assert sup_num % 10 == 0, "unable to have equal number of images per class"

    # number of supervised examples per class
    sup_per_class = int(sup_num / 10)

    idxs_sup, idxs_unsup = get_ss_indices_per_class(y, sup_per_class)
    X_sup = X[idxs_sup]
    y_sup = y[idxs_sup]
    X_unsup = X[idxs_unsup]
    y_unsup = y[idxs_unsup]

    return X_sup, y_sup, X_unsup, y_unsup, X_valid, y_valid


def print_distribution_labels(y):
    """
    helper function for printing the distribution of class labels in a dataset
    :param y: tensor of class labels given as one-hots
    :return: a dictionary of counts for each label from y
    """
    counts = {j: 0 for j in range(10)}
    for i in range(y.size()[0]):
        for j in range(10):
            if y[i][j] == 1:
                counts[j] += 1
                break
    print(counts)


class MNISTCached(MNIST):
    """
    a wrapper around MNIST to load and cache the transformed data
    once at the beginning of the inference
    """

    # static class variables for caching training data
    train_data_size = 50000
    train_data_sup, train_labels_sup = None, None
    train_data_unsup, train_labels_unsup = None, None
    validation_size = 10000
    data_valid, labels_valid = None, None
    test_size = 10000

    def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
        super().__init__(train=mode in ["sup", "unsup", "valid"], *args, **kwargs)

        # transformations on MNIST data (normalization and one-hot conversion for labels)
        def transform(x):
            return fn_x_mnist(x, use_cuda)

        def target_transform(y):
            return fn_y_mnist(y, use_cuda)

        self.mode = mode

        assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values"

        if mode in ["sup", "unsup", "valid"]:

            # transform the training data if transformations are provided
            if transform is not None:
                self.data = (transform(self.data.float()))
            if target_transform is not None:
                self.targets = (target_transform(self.targets))

            if MNISTCached.train_data_sup is None:
                if sup_num is None:
                    assert mode == "unsup"
                    MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = \
                        self.data, self.targets
                else:
                    MNISTCached.train_data_sup, MNISTCached.train_labels_sup, \
                        MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup, \
                        MNISTCached.data_valid, MNISTCached.labels_valid = \
                        split_sup_unsup_valid(self.data, self.targets, sup_num)

            if mode == "sup":
                self.data, self.targets = MNISTCached.train_data_sup, MNISTCached.train_labels_sup
            elif mode == "unsup":
                self.data = MNISTCached.train_data_unsup

                # making sure that the unsupervised labels are not available to inference
                self.targets = (torch.Tensor(
                    MNISTCached.train_labels_unsup.shape[0]).view(-1, 1)) * np.nan
            else:
                self.data, self.targets = MNISTCached.data_valid, MNISTCached.labels_valid

        else:
            # transform the testing data if transformations are provided
            if transform is not None:
                self.data = (transform(self.data.float()))
            if target_transform is not None:
                self.targets = (target_transform(self.targets))

    def __getitem__(self, index):
        """
        :param index: Index or slice object
        :returns tuple: (image, target) where target is index of the target class.
        """
        if self.mode in ["sup", "unsup", "valid"]:
            img, target = self.data[index], self.targets[index]
        elif self.mode == "test":
            img, target = self.data[index], self.targets[index]
        else:
            assert False, "invalid mode: {}".format(self.mode)
        return img, target


def setup_data_loaders(dataset, use_cuda, batch_size, sup_num=None, root=None, download=True, **kwargs):
    """
        helper function for setting up pytorch data loaders for a semi-supervised dataset
    :param dataset: the data to use
    :param use_cuda: use GPU(s) for training
    :param batch_size: size of a batch of data to output when iterating over the data loaders
    :param sup_num: number of supervised data examples
    :param download: download the dataset (if it doesn't exist already)
    :param kwargs: other params for the pytorch data loader
    :return: three data loaders: (supervised data for training, un-supervised data for training,
                                  supervised data for testing)
    """
    # instantiate the dataset as training/testing sets
    if root is None:
        root = get_data_directory('./')
    if 'num_workers' not in kwargs:
        kwargs = {'num_workers': 0, 'pin_memory': False}

    cached_data = {}
    loaders = {}
    for mode in ["unsup", "test", "sup", "valid"]:
        if sup_num is None and mode == "sup":
            # in this special case, we do not want "sup" and "valid" data loaders
            return loaders["unsup"], loaders["test"]
        cached_data[mode] = dataset(root=root, mode=mode, download=download,
                                    sup_num=sup_num, use_cuda=use_cuda)
        loaders[mode] = DataLoader(cached_data[mode], batch_size=batch_size, shuffle=True, **kwargs)

    return loaders


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise


EXAMPLE_DIR = os.path.dirname(os.path.abspath(os.path.join('./', os.pardir)))
DATA_DIR = os.path.join(EXAMPLE_DIR, 'data')
RESULTS_DIR = os.path.join(EXAMPLE_DIR, 'results')


In [7]:
# define the PyTorch module that parameterizes the
# diagonal gaussian distribution q(z|x)
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()


        # multivariate normal has unexpected behavior with
        # pyro.sample so we learn a "good enough" diagonal covariance matrix
        # as a function of self.sigma
        # softplus ensures the learned covariance is PSD
        self.Tau = nn.Linear(1, z_dim) # z_scale = self.softplus(self.Tau(self.sigma**2))
        self.softplus = nn.Softplus()
        # self.M_inv_tilde = nn.Linear(z_dim, z_dim)

        self.z_dim = z_dim
        
    def compute_M_inv(self, W, sigma):
        # compute true precision matrix M_inv
        # which is used to compute z_loc = M_inv W.T (x - mu)
        # (bottom) (p. 573) http://users.isr.ist.utl.pt/~wurmd/Livros/school/Bishop%20-%20Pattern%20Recognition%20And%20Machine%20Learning%20-%20Springer%20%202006.pdf

        # M = W.T W + sigma**2 I
        M = torch.einsum('ij,jk->ik', W.T, W) + sigma**2 * torch.eye(self.z_dim)
        # print(M.shape) # 50 x 50

        # The inverse can be computed using the following identity
        # (A + BD**(-1)C)**(-1) = A**(-1) − A**(-1)B(D + CA**(-1)B)**(-1)CA**(-1)
        # setting
        # B = W.T, D = I, C = W, A = sigma**2 I
        # we have
        # M_inv = (W.T W + sigma**2 I)**(-1) 
        # = (sigma**-2 I) − (sigma**-2 I)W.T(I + W(sigma**-2 I)W.T)**(-1) W(sigma**-2 I)

        first_term = sigma**(-2) * torch.eye(self.z_dim)
        # print(first_term.shape) # 50 x 50
        sigma_neg2_WT = torch.einsum('ij,jk->ik', first_term, W.T)
        # print(sigma_neg2_WT.shape) # 50 x 784
        W_sigma_neg2_WT = torch.einsum('ij,jk->ik', W, sigma_neg2_WT)
        # print(W_sigma_neg2_WT.shape) # 784 x 784
        I_plus_W_sigma_neg2_WT_inv = torch.linalg.inv(torch.eye(784) + W_sigma_neg2_WT)
        # print(I_plus_W_sigma_neg2_WT_inv.shape) # 784 x 784
        second_term_aux = torch.einsum('ij,jk->ik', sigma_neg2_WT, I_plus_W_sigma_neg2_WT_inv)
        # print(second_term_aux.shape) # 50 x 784
        second_term = torch.einsum('ij,jk->ik', second_term_aux, sigma_neg2_WT.T)
        # print(second_term.shape) # 50 x 50
        M_inv = first_term - second_term
        # print(self.M_inv.shape) # 50 x 50

        return M, M_inv
        
    def forward(self, x, W, mu, sigma):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = x.reshape(-1, 784)

        # compute z_loc 
        # (bottom) (p. 573) http://users.isr.ist.utl.pt/~wurmd/Livros/school/Bishop%20-%20Pattern%20Recognition%20And%20Machine%20Learning%20-%20Springer%20%202006.pdf
        x_minus_mu = x - mu
        # print(W.T.shape, x_minus_mu.shape) # 50x784, 256x784
        WT_x_minus_mu = torch.einsum('ij,bj->bi', W.T, x_minus_mu)
        # print(WT_x_minus_mu.shape) # 256x50
        M, M_inv = self.compute_M_inv(W, sigma)
        z_loc = torch.einsum('ij,bj->bi', M_inv, WT_x_minus_mu)
        # z_loc = self.M_inv_tilde(WT_x_minus_mu)

        # approximate covariance with learned diagonal covariance
        z_scale = sigma**(-2) * M
        # print(self.sigma.shape, sigma2_M.reshape(-1).shape)
        # z_scale = self.softplus(self.Tau(self.sigma**(-2)))

        return z_loc, z_scale

In [8]:
# define the PyTorch module that parameterizes the
# diagonal gaussian distribution q(z|x)
class Encode_nn(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # W : projection matrix
        # mu : mean of x (is this needed? images are normalized)
        # sigma : noise parameter
        self.W = torch.rand(784, z_dim, requires_grad=True)
        self.mu = torch.rand(784, requires_grad=True)
        self.sigma = torch.rand(1, requires_grad=True)

        # multivariate normal has unexpected behavior with
        # pyro.sample so we learn a "good enough" diagonal covariance matrix
        # as a function of self.sigma
        # softplus ensures the learned covariance is PSD
        self.Tau = nn.Linear(1, z_dim) # z_scale = self.softplus(self.Tau(self.sigma**2))
        self.softplus = nn.Softplus()
        # also approximate M_inv
        self.M_inv_tilde = nn.Linear(z_dim, z_dim)
        
    def forward(self, x):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = x.reshape(-1, 784)

        # compute z_loc 
        # (bottom) (p. 573) http://users.isr.ist.utl.pt/~wurmd/Livros/school/Bishop%20-%20Pattern%20Recognition%20And%20Machine%20Learning%20-%20Springer%20%202006.pdf
        x_minus_mu = x - self.mu
        # print(self.W.T.shape, x_minus_mu.shape) # 50x784, 256x784
        WT_x_minus_mu = torch.einsum('ij,bj->bi', self.W.T, x_minus_mu)
        # print(WT_x_minus_mu.shape) # 256x50
        z_loc = self.M_inv_tilde(WT_x_minus_mu)

        # approximate covariance with learned diagonal covariance
        z_scale = self.softplus(self.Tau(self.sigma**2))

        return z_loc, z_scale

In [9]:
# define the PyTorch module that parameterizes the
# observation likelihood p(x|z)
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # setup the two linear transformations used
        # self.W = torch.rand(784, z_dim, requires_grad=True)
        # self.mu = torch.rand(784, requires_grad=True)

        # self.fc21 = nn.Linear(hidden_dim, 784)
        # setup the non-linearities
        # self.softplus = nn.Softplus()

    def forward(self, z, W, mu):
        # define the forward computation on the latent z
        # first compute the hidden units
        # hidden = self.softplus(self.fc1(z))
        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        # loc_img = torch.sigmoid(self.fc21(hidden))
        x_minus_mu = torch.einsum('ij,bj->bi', W, z)
        loc_img = torch.sigmoid(x_minus_mu + mu)
        return loc_img

In [10]:
# define a PyTorch module for the VAE
class VAE(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        # W : projection matrix
        # mu : mean of x (is this needed? images are normalized)
        # sigma : noise parameter
        self.W = torch.randn(784, z_dim, requires_grad=True)
        self.mu = torch.zeros(784)
        self.sigma = torch.randn(1, requires_grad=True)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        pyro.param("W", self.W)
        # pyro.param("mu", self.mu)
        pyro.param("sigma", self.sigma)
        # For a model with 𝑁 observations, running the model and guide and constructing the ELBO 
        # involves evaluating log pdf’s whose complexity scales badly with 𝑁. 
        # This is a problem if we want to scale to large datasets. 
        # Luckily, the ELBO objective naturally supports subsampling provided that our 
        # model/guide have some conditional independence structure that we can take advantage of. 
        # For example, in the case that the observations are conditionally independent given the latents, 
        # the log likelihood term in the ELBO can be approximated with mini-batch
        # when there are only local random variables the scale factor introduced by subsampling scales
        # all the terms in the ELBO by the same amount. This explains why for the VAE it’s permissible
        # for the user to take complete control over subsampling and pass mini-batches directly to the model and guide
        # ; plate is still used, but subsample_size and subsample are not.
        with pyro.plate("data", x.shape[0]): 
            # setup hyperparameters for prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.eye(self.z_dim, dtype=x.dtype, device=x.device)
            # sample from prior (value will be sampled by guide when computing the ELBO)
            # normal = dist.Normal(z_loc, z_scale).to_event(1)
            # print(normal.event_shape, normal.batch_shape) # 50, 256
            # z = pyro.sample("latent", normal)
            # print(z.shape)
            mvn = dist.MultivariateNormal(z_loc, z_scale)
            # print(mvn.event_shape, mvn.batch_shape) # 50, 256
            z = pyro.sample("latent", mvn)
            # print(z_mvn.shape)

            # decode the latent code z
            loc_img = self.decoder.forward(z, self.W, self.mu)
            # score against actual images (with relaxed Bernoulli values)
            # the use of .to_event(1) when sampling from the latent z ensures 
            # that instead of treating our sample as being generated from a univariate normal 
            # with batch_size = z_dim, we treat them as being generated from a 
            # multivariate normal distribution with diagonal covariance.
            # As such, the log probabilities along each dimension is summed 
            # out when we evaluate .log_prob for a “latent” sample.
            pyro.sample("obs",
                        dist.Bernoulli(loc_img, validate_args=False)
                            .to_event(1),
                        obs=x.reshape(-1, 784))
            # return the loc so we can visualize it later
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x, self.W, self.mu, self.sigma)
            # print(z_loc.shape, z_scale.shape)
            # sample the latent code z
            # normal = dist.Normal(z_loc, z_scale).to_event(1)
            # pyro.sample("latent", normal)
            # print(normal.event_shape, normal.batch_shape) # 50, 256
            # z = pyro.sample("latent2", normal)
            # print(z.shape) # 256 x 50
            mvn = dist.MultivariateNormal(z_loc, z_scale)
            # print(mvn.event_shape, mvn.batch_shape) # 50, 256
            pyro.sample("latent", mvn)
            # print(z.shape)

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x, self.W, self.mu, self.sigma)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z, self.W, self.mu)
        return loc_img

In [11]:
# define a PyTorch module for the VAE
class VAE_normal(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        # For a model with 𝑁 observations, running the model and guide and constructing the ELBO 
        # involves evaluating log pdf’s whose complexity scales badly with 𝑁. 
        # This is a problem if we want to scale to large datasets. 
        # Luckily, the ELBO objective naturally supports subsampling provided that our 
        # model/guide have some conditional independence structure that we can take advantage of. 
        # For example, in the case that the observations are conditionally independent given the latents, 
        # the log likelihood term in the ELBO can be approximated with mini-batch
        # when there are only local random variables the scale factor introduced by subsampling scales
        # all the terms in the ELBO by the same amount. This explains why for the VAE it’s permissible
        # for the user to take complete control over subsampling and pass mini-batches directly to the model and guide
        # ; plate is still used, but subsample_size and subsample are not.
        with pyro.plate("data", x.shape[0]): 
            # setup hyperparameters for prior p(z)
            z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images (with relaxed Bernoulli values)
            # the use of .to_event(1) when sampling from the latent z ensures 
            # that instead of treating our sample as being generated from a univariate normal 
            # with batch_size = z_dim, we treat them as being generated from a 
            # multivariate normal distribution with diagonal covariance.
            # As such, the log probabilities along each dimension is summed 
            # out when we evaluate .log_prob for a “latent” sample.
            pyro.sample("obs",
                        dist.Bernoulli(loc_img, validate_args=False)
                            .to_event(1),
                        obs=x.reshape(-1, 784))
            # return the loc so we can visualize it later
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img

In [None]:
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
pyro.distributions.enable_validation(False)
pyro.set_rng_seed(0)

def main(args):
    # clear param store
    pyro.clear_param_store()

    # setup MNIST data loaders
    # train_loader, test_loader
    train_loader, test_loader = setup_data_loaders(MNISTCached, use_cuda=args.cuda, batch_size=256)

    # setup the VAE
    vae = VAE(use_cuda=args.cuda)

    # setup the optimizer
    adam_args = {"lr": args.learning_rate}
    optimizer = Adam(adam_args)

    # setup the inference algorithm
    elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
    svi = SVI(vae.model, vae.guide, optimizer, loss=elbo)

    # training loop
    train_test_loop(args, train_loader, test_loader, svi, vae)

    return vae


def train_test_loop(args, train_loader, test_loader, svi, vae):
    # setup visdom for visualization
    if args.visdom_flag:
        vis = visdom.Visdom()
    else:
        vis = None

    train_elbo = []
    test_elbo = []

    for epoch in range(args.num_epochs):
        total_epoch_loss_train = train(train_loader, svi)
        train_elbo.append(total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % args.test_frequency == 0:
            total_epoch_loss_test = evaluate(args, test_loader, svi, vae, vis)
            test_elbo.append(total_epoch_loss_test)
            print("[epoch %03d]  average test loss: %.4f" % (epoch, total_epoch_loss_test))

        if epoch == args.tsne_iter:
            mnist_test_tsne(vae=vae, test_loader=test_loader)
            plot_llk(np.array(train_elbo), np.array(test_elbo))


def train(train_loader, svi):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if args.cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    # report training diagnostics
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train

    return total_epoch_loss_train


def evaluate(args, test_loader, svi, vae, vis):
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over the entire test set
    for i, (x, _) in enumerate(test_loader):
        # if on GPU put mini-batch into CUDA memory
        if args.cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)

        # pick three random test images from the first mini-batch and
        # visualize how well we're reconstructing them
        if i == 0:
            if args.visdom_flag:
              plot_vae_samples(vae, vis)
              reco_indices = np.random.randint(0, x.shape[0], 3)
              for index in reco_indices:
                  test_img = x[index, :]
                  reco_img = vae.reconstruct_img(test_img)
                  vis.image(test_img.reshape(28, 28).detach().cpu().numpy(),
                            opts={'caption': 'test image'})
                  vis.image(reco_img.reshape(28, 28).detach().cpu().numpy(),
                            opts={'caption': 'reconstructed image'})

    # report test diagnostics
    normalizer_test = len(test_loader.dataset)
    total_epoch_loss_test = test_loss / normalizer_test

    return total_epoch_loss_test
    


sys.argv = ['foo']

parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=101, type=int, help='number of training epochs')
parser.add_argument('-tf', '--test-frequency', default=10, type=int, help='how often we evaluate the test set')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
parser.add_argument('--cuda', action='store_true', default=False, help='whether to use cuda')
parser.add_argument('--jit', action='store_true', default=False, help='whether to use PyTorch jit')
parser.add_argument('-visdom', '--visdom_flag', action="store_true", help='Whether plotting in visdom is desired')
parser.add_argument('-i-tsne', '--tsne_iter', default=100, type=int, help='epoch when tsne visualization runs')
args = parser.parse_args()

model = main(args)


[epoch 000]  average training loss: 38471.2846
[epoch 000]  average test loss: 27884.1285
[epoch 001]  average training loss: 22786.5611
[epoch 002]  average training loss: 16355.0928
[epoch 003]  average training loss: 12781.8993
[epoch 004]  average training loss: 10461.0182
[epoch 005]  average training loss: 8726.2500
[epoch 006]  average training loss: 7265.8326
[epoch 007]  average training loss: 5765.6455
[epoch 008]  average training loss: 4016.2855
[epoch 009]  average training loss: 2318.3348
[epoch 010]  average training loss: 1312.0207
[epoch 010]  average test loss: 1028.9420


In [None]:
from IPython.display import Image
Image('vae_results/VAE_embedding_8.png')

In [None]:
!ls vae_results