In [None]:
import os 

import numpy as np 
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt 
import utils as u

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, SGD

We first create the dataset whose distribution we want to model. The dataset is plotted below.

In [None]:
# Seed the random number generators up front so that "Restart & Run All"
# reproduces the same weights and ELBO samples. `pyro.set_rng_seed` seeds
# torch, numpy and Python's `random` in one call.
# NOTE(review): this only makes the dataset itself reproducible if
# `u.make_pinwheel_data` draws from one of those global RNGs — confirm in utils.
SEED = 0
pyro.set_rng_seed(SEED)

# Parameters of the pinwheel dataset (passed positionally to the helper in utils).
radial_std = 0.32
tangential_std = 0.1
num_classes = 1
num_per_class = 1000
rate = 0.5
data = u.make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate)


In [None]:
# Scatter the first two coordinates of the raw data before any modelling.
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1])
ax.set(title="Training data", xlabel="x[0]", ylabel="x[1]")

# Convert to a float32 tensor for PyTorch/Pyro. `torch.as_tensor` replaces the
# legacy `torch.Tensor(...)` constructor and is a no-op when the cell is re-run
# on data that is already a float32 tensor.
data = torch.as_tensor(data, dtype=torch.float32)

In [None]:
class Decoder(nn.Module):
    """MLP that maps a latent code ``z`` to the parameters of a Normal over ``x``.

    Two tanh hidden layers of width ``hidden_dim`` feed two linear heads:
    one produces the mean, the other a log-scale that is exponentiated so the
    returned scale is strictly positive.
    """

    def __init__(self, z_dim, x_dim, hidden_dim):
        super().__init__()
        # The layers are created in this exact order: it determines both the
        # order of the registered parameters and the order in which the RNG is
        # consumed for weight initialisation.
        self.fc1 = nn.Linear(z_dim, hidden_dim)        # latent -> hidden
        self.fc21 = nn.Linear(hidden_dim, x_dim)       # hidden -> mean head
        self.fc22 = nn.Linear(hidden_dim, x_dim)       # hidden -> log-scale head
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)  # hidden -> hidden
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Return ``(x_loc, x_scale)`` for a batch of latent codes ``z``."""
        first_layer = self.tanh(self.fc1(z))
        features = self.tanh(self.fc11(first_layer))
        mean = self.fc21(features)
        log_scale = self.fc22(features)
        return mean, torch.exp(log_scale)

class Encoder(nn.Module):
    """MLP that maps an observation ``x`` to the parameters of a Normal over ``z``.

    Mirrors ``Decoder``: two tanh hidden layers of width ``hidden_dim`` followed
    by a mean head and a log-scale head whose output is exponentiated, so the
    returned scale is strictly positive.
    """

    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        # Creation order is kept as-is so parameter order and weight
        # initialisation under a fixed seed do not change.
        self.fc1 = nn.Linear(x_dim, hidden_dim)        # data -> hidden
        self.fc21 = nn.Linear(hidden_dim, z_dim)       # hidden -> mean head
        self.fc22 = nn.Linear(hidden_dim, z_dim)       # hidden -> log-scale head
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)  # hidden -> hidden
        # The activation is a tanh; it used to be stored under the misleading
        # attribute name `softplus`. Named `tanh` to match Decoder.
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return ``(z_loc, z_scale)`` for a batch of observations ``x``."""
        hidden = self.tanh(self.fc1(x))
        hidden = self.tanh(self.fc11(hidden))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale



In [None]:
class VAE(nn.Module):
    """Variational auto-encoder expressed as a Pyro model/guide pair.

    Generative model: ``z ~ Normal(0, z_prior_scale)``, then
    ``x ~ Normal(decoder(z))``. The guide is ``z ~ Normal(encoder(x))``.

    Args:
        x_dim: dimensionality of one observation.
        z_dim: dimensionality of the latent code.
        hidden_dim: width of the hidden layers in encoder and decoder.
        use_cuda: if True, move encoder and decoder to the GPU.
        z_prior_scale: standard deviation of the isotropic Normal prior on z.
            Used by both ``model`` and ``sample_latent`` so that generated
            samples come from the same prior the model was trained with.
    """

    def __init__(self, x_dim=2, z_dim=2, hidden_dim=2, use_cuda=False, z_prior_scale=5.0):
        super().__init__()
        self.encoder = Encoder(x_dim, z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, x_dim, hidden_dim)
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.z_prior_scale = z_prior_scale
        # The training helpers move minibatches to the GPU when use_cuda is
        # set, so the networks have to live there as well.
        if use_cuda:
            self.cuda()

    def model(self, x):
        """Pyro model p(x|z)p(z) for a minibatch ``x`` (one row per data point).

        Returns the decoder mean for the sampled latent codes; SVI ignores the
        return value, it is only there for direct inspection.
        """
        pyro.module('decoder', self.decoder)
        with pyro.plate('data', x.shape[0]):
            # Prior parameters are built from `x` so they share its dtype/device.
            z_loc = x.new_zeros((x.shape[0], self.z_dim))
            z_scale = self.z_prior_scale * x.new_ones((x.shape[0], self.z_dim))
            z = pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
            x_loc, x_scale = self.decoder(z)
            pyro.sample('obs', dist.Normal(x_loc, x_scale).to_event(1), obs=x)
        return x_loc

    def guide(self, x):
        """Pyro guide q(z|x): a diagonal Normal parameterised by the encoder."""
        pyro.module('encoder', self.encoder)
        with pyro.plate('data', x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))

    def sample_latent(self, num_samples):
        """Generate ``num_samples`` observations by decoding draws from the prior.

        Returns a tensor with one generated observation per row.
        """
        # Borrow dtype/device from the decoder weights so this also works
        # when the networks have been moved to the GPU.
        reference = next(self.decoder.parameters())
        z_loc = reference.new_zeros((num_samples, self.z_dim))
        z_scale = self.z_prior_scale * reference.new_ones((num_samples, self.z_dim))
        z = dist.Normal(z_loc, z_scale).sample()
        x_loc, x_scale = self.decoder(z)
        res = pyro.sample('results', dist.Normal(x_loc, x_scale))
        return res

    def reconstruct(self, x):
        """Encode ``x``, sample a latent code, and return the decoder mean for it."""
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        # Only the mean is returned; the decoder scale is not needed here.
        x_loc, _ = self.decoder(z)
        return x_loc





In [None]:
# Untrained VAE with the default dimensions; it is replaced by a fresh
# instance in the training cell further down.
vae = VAE()

In [None]:
# Quick look at what the untrained network generates: one generated point per
# data point, drawn through the dedicated sampling helper rather than by
# calling the Pyro model directly.
res = vae.sample_latent(len(data)).detach().numpy()
plt.scatter(res[:, 0], res[:, 1])

In [None]:
# 80/20 train/test split. `data` was already converted to a tensor above, so it
# is sliced directly instead of being re-wrapped with `torch.tensor`, which
# copies the data and emits a UserWarning when given a tensor.
# The split index is called `n_train` so that it does not share a name with the
# `train()` function defined below (re-running this cell would otherwise
# overwrite that function with an int).
# NOTE(review): the split is positional; if the dataset helper returns ordered
# points the test set may not be representative — consider shuffling first.
n_train = int(len(data) * 0.8)
train_loader = torch.utils.data.DataLoader(data[:n_train], batch_size=400, shuffle=True)
test_loader = torch.utils.data.DataLoader(data[n_train:], batch_size=200)

In [None]:
def train(svi, train_loader, use_cuda=False):
    """Run one epoch of SVI updates and return the average loss per data point.

    Args:
        svi: object whose ``step(minibatch)`` performs one gradient step and
            returns the loss for that minibatch.
        train_loader: iterable of minibatches exposing a ``dataset`` attribute
            with a length.
        use_cuda: if True, move each minibatch to the GPU before the step.

    Returns:
        The sum of the minibatch losses divided by ``len(train_loader.dataset)``.
    """
    epoch_loss = 0.

    for minibatch in train_loader:
        if use_cuda:
            minibatch = minibatch.cuda()
        # Accumulate over all minibatches; assigning here would keep only the
        # last minibatch's loss and under-report the epoch loss.
        epoch_loss += svi.step(minibatch)

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train


In [None]:
def evaluate(svi, test_loader, use_cuda=False):
    """Return the average loss per data point over ``test_loader``.

    Each minibatch is passed to ``svi.evaluate_loss`` (moved to the GPU first
    when ``use_cuda`` is set); the summed loss is divided by
    ``len(test_loader.dataset)``. No parameters are updated by this function
    itself.
    """
    # Per-minibatch loss estimates, in loader order.
    batch_losses = [
        svi.evaluate_loss(batch.cuda() if use_cuda else batch)
        for batch in test_loader
    ]

    # Plain left-to-right accumulation starting from a float zero.
    summed_loss = 0.
    for loss in batch_losses:
        summed_loss += loss

    return summed_loss / len(test_loader.dataset)

In [None]:
# Hardware
USE_CUDA = False        # move minibatches to the GPU when True

# Optimisation
LR = 5e-4               # learning rate handed to the optimizer
NUM_EPOCHS = 10_000     # number of passes over the training loader
TEST_FREQUENCY = 20     # evaluate/report every this many epochs

In [None]:
# Start from an empty parameter store so that re-running this cell trains from
# scratch instead of continuing from previously registered parameters.
pyro.clear_param_store()
# Pass USE_CUDA through so the model is configured consistently with the
# minibatches that train()/evaluate() move to the GPU.
vae = VAE(z_dim=20, use_cuda=USE_CUDA)

adam_args = {'lr': LR}
optimizer = Adam(adam_args)

svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

# Both histories store the ELBO per data point, i.e. the negated loss, so the
# two curves share one sign convention.
train_elbo = []
test_elbo = []

for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)

    if epoch % TEST_FREQUENCY == 0:
        # report test and training diagnostics together
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))


In [None]:
# Training history, one value per epoch.
fig, ax = plt.subplots()
ax.plot(train_elbo)
ax.set(title="Training curve", xlabel="epoch", ylabel="train_elbo (per data point)")

In [None]:
# Generate 1000 new points with the trained VAE.
# `.cpu()` keeps the conversion to numpy working when the model runs on the GPU.
sampled = vae.sample_latent(1000).detach().cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(sampled[:, 0], sampled[:, 1])
ax.set(title="Samples generated by the VAE", xlabel="x[0]", ylabel="x[1]")

In [None]:
# Push the whole dataset through encoder and decoder and plot the result.
# `.cpu()` keeps the conversion to numpy working when the model runs on the GPU.
reconstructed = vae.reconstruct(data).detach().cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(reconstructed[:, 0], reconstructed[:, 1])
ax.set(title="Reconstruction of the dataset", xlabel="x[0]", ylabel="x[1]")

# Second attempt using a fully Bayesian specification