In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pyro
import pyro.contrib.autoguide as autoguide
import pyro.distributions as dist
import torch
import torch.nn as nn
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, SGD
from torch.distributions import constraints

import utils as u

We first create the dataset whose distribution we want to model: a two-class pinwheel. The generated points are plotted below.

In [None]:
# Seed the global RNGs so that Restart & Run All reproduces the same data and the same
# network initialisation.
# NOTE(review): assumes make_pinwheel_data draws from NumPy's global RNG - confirm in utils.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pinwheel parameters; they are passed positionally to u.make_pinwheel_data below.
radial_std = 0.32
tangential_std = 0.1
num_classes = 2
num_per_class = 1000
rate = 2
data = u.make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate)

#mean = -100 * torch.ones((num_per_class,2))
#variance = 0.1 * torch.ones((num_per_class,2 ))
#data = dist.Normal(mean,variance).sample()


In [None]:
# Scatter plot of the raw 2-D data points (column 0 against column 1).
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1])
ax.set(xlabel='$x_1$', ylabel='$x_2$', title='Training data');


In [None]:
class Decoder(nn.Module):
    """Three-hidden-layer MLP mapping a latent code ``z`` to ``(x_loc, x_scale)``.

    ``x_loc`` has ``x_dim`` columns; ``x_scale`` is a single strictly positive value per
    input row, obtained by exponentiating a one-unit linear head.
    """

    def __init__(self, z_dim, x_dim, hidden_dim):
        super().__init__()
        # The creation order of the linear layers determines how the RNG is consumed,
        # so it is kept stable to make seeded initialisation reproducible.
        self.fc1 = nn.Linear(z_dim, hidden_dim)          # input layer
        self.fc21 = nn.Linear(hidden_dim, x_dim)         # head producing x_loc
        self.fcad = nn.Linear(hidden_dim, hidden_dim)    # third hidden layer
        self.fc22 = nn.Linear(hidden_dim, 1)             # head producing log(x_scale)
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)    # second hidden layer
        # Despite the attribute name, the non-linearity is a Softplus.
        self.tanh = nn.Softplus()
        self.x_dim = x_dim
        self.z_dim = z_dim

    def forward(self, z):
        """Return ``(x_loc, x_scale)`` for a batch of latent codes ``z``."""
        activation = self.tanh
        features = z
        for layer in (self.fc1, self.fc11, self.fcad):
            features = activation(layer(features))
        x_loc = self.fc21(features)
        x_scale = self.fc22(features).exp()
        return x_loc, x_scale

class Encoder(nn.Module):
    """Two-hidden-layer MLP mapping an observation ``x`` to ``(z_loc, z_scale)``.

    Both outputs have ``z_dim`` columns; ``z_scale`` is strictly positive because it is
    the exponential of a linear head.
    """

    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        # The creation order of the linear layers determines how the RNG is consumed,
        # so it is kept stable to make seeded initialisation reproducible.
        self.fc1 = nn.Linear(x_dim, hidden_dim)          # input layer
        self.fc21 = nn.Linear(hidden_dim, z_dim)         # head producing z_loc
        self.fc22 = nn.Linear(hidden_dim, z_dim)         # head producing log(z_scale)
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)    # second hidden layer
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)`` for a batch of observations ``x``."""
        features = x
        for layer in (self.fc1, self.fc11):
            features = self.softplus(layer(features))
        z_loc = self.fc21(features)
        z_scale = self.fc22(features).exp()
        return z_loc, z_scale



In [None]:
class VAE(nn.Module):
    """Variational auto-encoder expressed as a Pyro model/guide pair.

    Generative model (``model``): ``z ~ Normal(0, I)`` and
    ``x ~ MultivariateNormal(decoder(z)[0], scale * I)``, where ``scale`` is a single
    learnable Pyro parameter named ``'scale'``, constrained to the interval (0.2, 10).
    The per-point scale that the decoder also returns is not used.

    Variational family (``guide``): ``z ~ Normal(*encoder(x))`` with a diagonal scale.

    Args:
        x_dim: dimensionality of an observation.
        z_dim: dimensionality of the latent code.
        hidden_dim: width of the hidden layers of the encoder and the decoder.
        use_cuda: when True, the encoder and decoder are moved to the GPU.
    """

    def __init__(self, x_dim=2, z_dim=2, hidden_dim=2, use_cuda=False):
        super().__init__()
        self.encoder = Encoder(x_dim, z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, x_dim, hidden_dim)
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.x_dim = x_dim
        if use_cuda:
            # The training helpers move minibatches to the GPU when use_cuda is set,
            # so the networks have to live there as well.
            self.cuda()

    def _obs_covariance(self, like):
        """Return the ``x_dim x x_dim`` observation covariance ``scale * I``.

        The shared ``'scale'`` parameter is fetched from Pyro's parameter store and is
        created with its initial value of 1 if it is not there yet, so the sampling
        helpers also work before ``model`` has been run. ``like`` supplies the dtype and
        device of the returned matrix.
        """
        scale = pyro.param('scale', torch.tensor(1.), constraints.interval(0.2, 10))
        identity = torch.eye(self.x_dim, dtype=like.dtype, device=like.device)
        return identity * scale

    def model(self, x):
        """Pyro model p(x, z): standard-normal prior on ``z``, Gaussian likelihood on ``x``."""
        pyro.module('decoder', self.decoder)
        # Fetch the parameter before entering the plate, as a global (not per-datum) quantity.
        cov = self._obs_covariance(x)
        with pyro.plate('data', x.shape[0]):
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            z = pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
            # The decoder's own scale output is deliberately ignored in favour of `cov`.
            x_loc, _ = self.decoder(z)
            pyro.sample('obs', dist.MultivariateNormal(x_loc, cov), obs=x)

    def guide(self, x):
        """Pyro guide q(z | x): diagonal Normal whose parameters come from the encoder."""
        pyro.module('encoder', self.encoder)
        with pyro.plate('data', x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))

    def sample_latent(self, num_samples):
        """Draw ``num_samples`` points from the model: z from the prior, then x given z.

        Returns a ``(num_samples, x_dim)`` tensor; call ``.detach()`` on it before
        converting to NumPy.
        """
        # Build the prior on the same device as the decoder weights.
        device = next(self.decoder.parameters()).device
        zeros = torch.zeros([num_samples, self.z_dim], device=device)
        ones = torch.ones([num_samples, self.z_dim], device=device)
        z = dist.Normal(zeros, ones).sample()
        x_loc, _ = self.decoder(z)
        cov = self._obs_covariance(x_loc)
        res = pyro.sample('results', dist.MultivariateNormal(x_loc, cov))
        return res

    def reconstruct(self, x):
        """Encode ``x``, sample a latent code, decode it and sample a reconstruction."""
        # The encoder returns a location and a standard deviation (not a variance).
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        x_loc, _ = self.decoder(z)
        cov = self._obs_covariance(x_loc)
        res = dist.MultivariateNormal(x_loc, cov).sample()
        return res





In [None]:
# Instantiate an untrained VAE with the constructor defaults (x_dim=2, z_dim=2, hidden_dim=2).
vae = VAE()

In [None]:
datat = torch.as_tensor(data, dtype=torch.float32)
# `vae.model` is a Pyro model: it registers parameters and sample sites and returns None,
# so its return value cannot be plotted. Run it once so the 'scale' parameter exists in the
# parameter store, then draw samples from the (still untrained) generative model.
vae.model(datat)
res = vae.sample_latent(len(datat)).detach().numpy()
plt.scatter(res[:, 0], res[:, 1]);

In [None]:
# Convert to a float32 tensor. `as_tensor` accepts both the original array and an
# already-converted tensor, so this cell can be re-run safely.
data = torch.as_tensor(data, dtype=torch.float32)

# 80/20 train/test split in the stored order. The split index is called `num_train`
# rather than `train` so that it cannot clash with the `train()` function.
num_train = int(len(data) * 0.8)
train_data, test_data = data[:num_train], data[num_train:]

# Full-batch training (one minibatch per epoch); the test set is read in batches of 200.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=len(train_data), shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=200)

In [None]:
def train(svi, train_loader, use_cuda=False):
    """Run one training epoch and return the average loss per data point.

    Args:
        svi: object exposing ``step(minibatch)``, which performs one optimisation step
            and returns the loss for that minibatch.
        train_loader: iterable of minibatches with a ``dataset`` attribute supporting ``len``.
        use_cuda: when True, each minibatch is moved to the GPU before the step.

    Returns:
        The sum of the minibatch losses divided by ``len(train_loader.dataset)``.
    """
    epoch_loss = 0.

    for minibatch in train_loader:
        if use_cuda:
            minibatch = minibatch.cuda()
        # Accumulate over all minibatches; plain assignment would keep only the last one.
        epoch_loss += svi.step(minibatch)

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train


In [None]:
def evaluate(svi, test_loader, use_cuda=False):
    """Return the average loss per data point over ``test_loader`` without optimising.

    Each batch is passed to ``svi.evaluate_loss`` (after being moved to the GPU when
    ``use_cuda`` is set); the summed loss is divided by ``len(test_loader.dataset)``.
    """
    running_loss = 0.
    for batch in test_loader:
        batch = batch.cuda() if use_cuda else batch
        running_loss += svi.evaluate_loss(batch)
    return running_loss / len(test_loader.dataset)

In [None]:
# Initial learning rate handed to the Adam optimizer in the training cell.
LR = 0.01
# Forwarded to train()/evaluate(), which move each minibatch to the GPU when True.
USE_CUDA = False

# Number of iterations of the training loop.
NUM_EPOCHS = 20000
# The test loss is computed and both losses are printed every TEST_FREQUENCY epochs.
TEST_FREQUENCY = 50

In [None]:
# Start from an empty parameter store so earlier runs do not leak into this one.
pyro.clear_param_store()
vae = VAE(hidden_dim=20, use_cuda=USE_CUDA)

# Adam wrapped in a step-wise learning-rate schedule: the rate is multiplied by `gamma`
# every `step_size` calls to scheduler.step().
adam_args = {'lr':LR}
optimizer = torch.optim.Adam
scheduler = pyro.optim.StepLR({ 'step_size' : 5000, 'gamma':0.01,'verbose':False, 'optimizer':optimizer,'optim_args':adam_args})

# NOTE(review): this autoguide is not used below; SVI is given the hand-written vae.guide.
guide = autoguide.AutoDiagonalNormal(vae.model)
svi = SVI(vae.model, vae.guide, scheduler, loss=Trace_ELBO())
train_elbo = []  # per-epoch training loss per data point
test_elbo = []   # negated test loss per data point, recorded every TEST_FREQUENCY epochs

for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(total_epoch_loss_train)
    # Advance the schedule once per epoch. Stepping only on reporting epochs would need
    # step_size * TEST_FREQUENCY epochs per decay, which never happens within NUM_EPOCHS.
    scheduler.step()

    if epoch % TEST_FREQUENCY == 0:
        # report test and training diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))


In [None]:
# `train_elbo` holds the value returned by train() for every epoch, i.e. the training
# loss per data point; it is shown on a log scale.
fig, ax = plt.subplots()
ax.plot(np.log(np.asarray(train_elbo)))
ax.set(xlabel='epoch', ylabel='log(training loss per data point)', title='Training curve');

In [None]:
# Overlay samples drawn from the trained generative model (red) on the data (blue).
real = data
sampled = vae.sample_latent(10000).detach().numpy()
fig, ax = plt.subplots()
ax.scatter(sampled[:, 0], sampled[:, 1], c='r', alpha=0.04)
ax.scatter(real[:, 0], real[:, 1], c='b', alpha=0.05)
ax.set(xlabel='$x_1$', ylabel='$x_2$', title='Model samples (red) vs. data (blue)');


In [None]:
# Encode and decode every data point, then overlay the reconstructions on the data.
reconstructed = vae.reconstruct(datat).detach().numpy()
# Sign of the second coordinate of each point.
# NOTE(review): `color` is computed but not passed to either scatter call below.
color = real[:,1]/np.abs(real[:,1])
fig, ax = plt.subplots()
ax.scatter(reconstructed[:, 0], reconstructed[:, 1], label='reconstruction')
ax.scatter(real[:, 0], real[:, 1], label='data')
ax.set(xlabel='$x_1$', ylabel='$x_2$', title='Reconstructions vs. data')
ax.legend();

# Second attempt: a fully Bayesian specification

In [None]:
# Look up the 'scale' parameter in Pyro's parameter store. It is registered in VAE.model,
# where it multiplies the identity matrix used as the covariance of the observations.
pyro.param('scale')

In [None]:
# NOTE(review): `a` is first assigned in the next cell, so on a fresh kernel run top to bottom
# this line raises NameError. Tensor.repeat also needs at least as many repeat sizes as the
# tensor has dimensions, so this call presumably targeted an earlier `a` with at most two
# dimensions rather than the 2x2x2 tensor created below - confirm or remove.
a.repeat(1,10)

In [None]:
# Random 2x2x2 tensor (uniform on [0, 1)) used for the indexing experiment that follows.
a = torch.rand((2, 2, 2))
print(a)

In [None]:
# Integer-array indexing: the two index tensors are paired element-wise, selecting a[0, 0]
# and a[0, 1]. With the 2x2x2 `a` from the previous cell the result has shape (2, 2).
a[torch.tensor([0,0]),torch.tensor([0,1])]