In [87]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms 
from torch.utils.data import DataLoader 
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam

## Classical VAE

In [90]:
# VAE encoder
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        z_loc = self.fc21(hidden) # mean vector
        z_scale = torch.exp(0.5*self.fc22(hidden)) # covariance vector
        return z_loc, z_scale

In [92]:
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)

    def forward(self, z):
        hidden = torch.relu(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

In [94]:
class VAE(nn.Module):
    def __init__(self, z_dim=10, hidden_dim=128):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # prior p(z) as standard normal
            z_loc = torch.zeros(x.shape[0], self.z_dim)
            z_scale = torch.ones(x.shape[0], self.z_dim)
            # sample from prior
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                obs=x.reshape(-1, 784),
            )
            
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image
        loc_img = self.decoder(z)
        return loc_img

In [96]:
#input_dim = 784
lr = 0.01
batch_size = 128
epochs = 10

train_data = datasets.MNIST(
    root='./data', 
    train=True, 
    transform=transforms.ToTensor(), 
    download=True)

train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True
)

vae = VAE()
optimizer = Adam({"lr": lr})
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())


In [98]:
#training 
for epoch in range(epochs):
    train_loss = 0
    for x, _ in train_dataloader:
        x = x.view(-1, 28*28)
        train_loss += svi.step(x)
        
    train_loss /= len(train_dataloader.dataset)
    print(f"Epoch {epoch + 1}, Loss: {train_loss}")

Epoch 1, Loss: 546.0059057373047
Epoch 2, Loss: 545.9868314229329
Epoch 3, Loss: 545.990926147461
Epoch 4, Loss: 546.0077961425782
Epoch 5, Loss: 545.9942545776368
Epoch 6, Loss: 546.0172881734212
Epoch 7, Loss: 545.9797055480957
Epoch 8, Loss: 545.9897087402344
Epoch 9, Loss: 545.9915152526855
Epoch 10, Loss: 545.9966286946615


## SS VAE (M2 model)

In [None]:
def model(self, xs, ys=None):
    pyro.module("ss_vae", self)
    batch_size = xs.size(0)

    with pyro.plate("data"):

        # sample z
        prior_loc = xs.new_zeros([batch_size, self.z_dim])
        prior_scale = xs.new_ones([batch_size, self.z_dim])
        zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is supervised, sample from the
        # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
        alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
        ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        loc = self.decoder([zs, ys])
        pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

def guide(self, xs, ys=None):
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            alpha = self.encoder_y(xs)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        loc, scale = self.encoder_z([xs, ys])
        pyro.sample("z", dist.Normal(loc, scale).to_event(1))

In [102]:
xxx = torch.tensor((), dtype=torch.int32)
xxx.new_ones([16, 10]) / (1.0 * 10)

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0