In [None]:
import os 

import numpy as np 
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt 
import utils as u


from torch.distributions import constraints 
import pyro.contrib.autoguide as autoguide
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, Trace_ELBO, config_enumerate, infer_discrete
from pyro.optim import Adam, SGD
from pyro import poutine
import pyro

We first create the (pinwheel) dataset whose distribution the model will learn to reproduce; a scatter plot of it is shown below.

In [None]:
# Pinwheel dataset settings. They are passed positionally to
# u.make_pinwheel_data; see that helper for the exact meaning of each value.
radial_std = 0.3
tangential_std = 0.05
num_classes = 5
num_per_class = 300
rate = 0.2
data = u.make_pinwheel_data(
    radial_std,
    tangential_std,
    num_classes,
    num_per_class,
    rate,
)


In [None]:
# Raw training data: one point per row of `data`.
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1])
ax.set(title='Pinwheel training data', xlabel='x[0]', ylabel='x[1]')
ax.set_aspect('equal');


In [None]:
def set_neural_net_idenity(neural_net, low_lim=-10, up_lim=10, num_points=15, epochs=25000, lr=0.0001, frequency=200, train_fraction=0.8):
    """Pre-train ``neural_net`` so its first output approximates the identity map.

    The network is fit with full-batch SGD and an MSE loss on a shuffled
    ``num_points x num_points`` grid covering ``[low_lim, up_lim]^2``.

    Args:
        neural_net: module whose ``forward`` returns an indexable result;
            element 0 is compared against the 2-column input, so it must also
            have 2 columns.
        low_lim, up_lim: bounds of the square training region.
        num_points: number of grid points per axis.
        epochs: number of full-batch SGD steps.
        lr: SGD learning rate.
        frequency: print the latest loss every ``frequency`` epochs.
        train_fraction: share of the shuffled grid used for fitting; the
            remaining points are not used.

    Returns:
        List of training losses, one entry per epoch (full-batch training).

    Raises:
        ValueError: if the training split would be empty.
    """
    axis = torch.linspace(low_lim, up_lim, num_points)
    # Same row order as meshgrid(axis, axis) flattened row-major: (a0,b0), (a0,b1), ...
    grid = torch.cartesian_prod(axis, axis)
    grid = grid[torch.randperm(len(grid))]

    n_train = int(len(grid) * train_fraction)
    if n_train < 1:
        raise ValueError("train_fraction * num_points**2 must be at least 1")
    # batch_size == n_train -> a single full batch per epoch.
    train_loader = torch.utils.data.DataLoader(grid[:n_train], batch_size=n_train, shuffle=True)

    opt = torch.optim.SGD(neural_net.parameters(), lr=lr)
    mse = torch.nn.MSELoss()

    losses = []
    for epoch in range(1, epochs + 1):
        for batch in train_loader:
            opt.zero_grad()
            out = mse(neural_net(batch)[0], batch)
            out.backward()
            opt.step()
            losses.append(out.item())

        if epoch % frequency == 0:
            print("[{}] Loss:{:.2f}".format(epoch, losses[-1]))

    return losses










In [None]:
class Decoder(nn.Module):
    """Maps latent codes ``z`` to an observation location and a positive scalar scale.

    Two ReLU hidden layers (``fc1`` then ``fc11``) feed two heads: ``fc21``
    produces the location (``x_dim`` outputs) and ``fc22`` a single value that
    is exponentiated into a positive scale.
    """

    def __init__(self, z_dim, x_dim, hidden_dim):
        super().__init__()
        # Creation order is kept unchanged: it fixes both the default random
        # initialisation under a seed and the parameter names seen by pyro.module.
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, x_dim)
        # NOTE(review): fcad is never used in forward(); kept so parameter
        # names and initialisation stay compatible.
        self.fcad = nn.Linear(hidden_dim, hidden_dim)
        # Single-output head whose exp() is returned as the scale.
        self.fc22 = nn.Linear(hidden_dim, 1)
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)
        self.non_linearity = nn.ReLU()
        self.x_dim, self.z_dim = x_dim, z_dim

    def forward(self, z):
        features = z
        for layer in (self.fc1, self.fc11):
            features = self.non_linearity(layer(features))
        log_scale = self.fc22(features)
        return self.fc21(features), log_scale.exp()

class Encoder(nn.Module):
    """Amortised encoder mapping ``x`` to ``(z_loc, z_scale)`` of a diagonal Normal.

    Two Softplus-activated hidden layers (``fc1`` then ``fc11``) feed two linear
    heads; the scale head is exponentiated so ``z_scale`` is positive.
    """

    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        # Keep this creation order: it determines the default initialisation
        # under a fixed seed and the parameter names registered with Pyro.
        self.fc1 = nn.Linear(x_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.fc11 = nn.Linear(hidden_dim, hidden_dim)
        self.softplus = nn.Softplus()

    def forward(self, x):
        h = self.softplus(self.fc11(self.softplus(self.fc1(x))))
        return self.fc21(h), torch.exp(self.fc22(h))



In [None]:
class VAE(nn.Module):
    """VAE with a Gaussian-mixture prior in latent space, trained with Pyro SVI.

    Generative model: K cluster centres ``locs`` in latent space, mixture
    ``weights``, a shared ``latent_scale``; each observation gets a discrete
    component (site ``'assigment'``, enumerated because ``model`` is decorated
    with ``config_enumerate``), a latent ``z`` around its centre, and an
    observation drawn from ``MultivariateNormal(decoder(z)[0], scale * I)``.

    Guide: independent factors for the global sites plus an amortised
    ``Encoder`` for the per-datapoint ``latent`` site.
    """

    def __init__(self, x_dim=2, z_dim=2, hidden_dim=2, K=3, use_cuda=False, init_lr=0.001, init_epochs=20000):
        super().__init__()
        self.encoder = Encoder(x_dim, z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, x_dim, hidden_dim)
        self.use_cuda = use_cuda  # stored only; nothing in this class moves modules to the GPU
        self.z_dim = z_dim
        self.x_dim = x_dim
        self.K = K
        # Start both networks' location outputs near the identity map.
        # NOTE(review): set_neural_net_idenity trains on a 2-column grid, so
        # this presumably requires x_dim == z_dim == 2.
        set_neural_net_idenity(self.encoder, epochs=init_epochs, lr=init_lr)
        set_neural_net_idenity(self.decoder, epochs=init_epochs, lr=init_lr)

    @staticmethod
    def _obs_covariance(scale, dim, like):
        """Isotropic observation covariance ``scale * I_dim`` on ``like``'s device/dtype."""
        return torch.eye(dim, dtype=like.dtype, device=like.device) * scale

    @config_enumerate
    def model(self, x):
        """Generative model; ``x`` is a batch whose rows form the ``data`` plate."""
        pyro.module('decoder', self.decoder)

        # Observation variance multiplier (covariance is scale * I), kept in [0.2, 10].
        scale = pyro.param('scale', torch.tensor(1.), constraints.interval(0.2, 10))
        latent_scale = pyro.sample('latent_scale', dist.LogNormal(-1, 1.0))
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(self.K)))

        with pyro.plate('components', self.K):
            m_loc = x.new_zeros(torch.Size((self.K, self.z_dim)))
            m_scale = 2 * x.new_ones(torch.Size((self.K, self.z_dim)))
            locs = pyro.sample('locs', dist.Normal(m_loc, m_scale).to_event(1))

        with pyro.plate('data', x.shape[0]):
            # The site name 'assigment' (sic) is read by other cells; keep it as is.
            assignment = pyro.sample('assigment', dist.Categorical(weights))
            z = pyro.sample('latent', dist.Normal(locs[assignment], latent_scale).to_event(1))
            # The decoder's scale head is ignored; the likelihood uses scale * I.
            loc_img, _ = self.decoder(z)
            cov = self._obs_covariance(scale, x.shape[1], loc_img)
            pyro.sample('obs', dist.MultivariateNormal(loc_img, cov), obs=x)

    def guide(self, x):
        """Variational family for every non-enumerated latent site of ``model``."""
        pyro.module('encoder', self.encoder)

        # Note: 'scale_var' is the LogNormal *scale* (sigma) parameter, not a variance.
        scale_mean = pyro.param('scale_mean', torch.ones(1))
        scale_variance = pyro.param('scale_var', torch.ones(1), constraint=constraints.positive)
        concentration = pyro.param('concentration', torch.ones(1), constraint=constraints.positive)
        location_means = pyro.param('location_means', torch.ones(self.K, self.z_dim))
        location_scale = pyro.param('location_scale', torch.ones(1), constraint=constraints.positive)

        pyro.sample('weights', dist.Dirichlet(concentration * torch.ones(self.K)))
        pyro.sample('latent_scale', dist.LogNormal(scale_mean, scale_variance))
        with pyro.plate('components', self.K):
            pyro.sample('locs', dist.Normal(location_means, location_scale).to_event(1))

        with pyro.plate('data', x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))

    def _point_latent_scale(self):
        """Mean of the guide's LogNormal over ``latent_scale``: exp(mu + sigma^2 / 2)."""
        return dist.LogNormal(pyro.param('scale_mean'), pyro.param('scale_var')).mean

    def _decode_samples(self, z):
        """Decode latent points and draw observations at site ``'obser'``."""
        loc_img, _ = self.decoder(z)
        cov = self._obs_covariance(pyro.param('scale'), self.x_dim, loc_img)
        return pyro.sample('obser', dist.MultivariateNormal(loc_img, cov))

    def sample_latent(self, num_samples):
        """Draw ``num_samples`` observations from the fitted mixture.

        Component centres are the guide's ``location_means``, the latent scale
        is the LogNormal mean, and mixture weights are drawn once from the
        guide's Dirichlet.

        Returns:
            (assignment, samples): the component index of each draw and the
            decoded observation samples.
        """
        location_means = pyro.param('location_means')
        concentration = pyro.param('concentration')
        latent_scale = self._point_latent_scale()
        weights = pyro.sample('weights', dist.Dirichlet(concentration * torch.ones(self.K)))

        with pyro.plate('data', num_samples):
            assignment = pyro.sample('assigment', dist.Categorical(weights))
            z = pyro.sample('latent', dist.Normal(location_means[assignment], latent_scale).to_event(1))
            res = self._decode_samples(z)
        return assignment, res

    def sample_cluster(self, num_samples, location):
        """Draw ``num_samples`` observations from component ``location`` (0 <= location < K)."""
        location_means = pyro.param('location_means')
        latent_scale = self._point_latent_scale()

        with pyro.plate('data', num_samples):
            z = pyro.sample('latent', dist.Normal(location_means[location], latent_scale).to_event(1))
            res = self._decode_samples(z)
        return res

    def reconstruct(self, x):
        """Encode ``x``, sample a latent, decode it and draw one reconstruction per row."""
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        x_loc, _ = self.decoder(z)
        cov = self._obs_covariance(pyro.param('scale'), self.x_dim, x_loc)
        return dist.MultivariateNormal(x_loc, cov).sample()





In [None]:
# Convert to float32. as_tensor returns an existing float32 tensor unchanged,
# so re-running this cell is safe.
data = torch.as_tensor(data, dtype=torch.float32)
# Named n_train (not `train`) so it does not shadow the train() function below.
# NOTE(review): the split is positional; this assumes u.make_pinwheel_data
# returns rows in random order. Confirm, or shuffle first.
n_train = int(len(data) * 0.8)
train_loader = torch.utils.data.DataLoader(data[:n_train], batch_size=n_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(data[n_train:], batch_size=200)

In [None]:
def train(svi, train_loader, use_cuda=False):
    """Run one SVI epoch and return the average loss per training datapoint.

    Args:
        svi: object exposing ``step(batch) -> float`` (e.g. ``pyro.infer.SVI``).
        train_loader: iterable of mini-batches with a ``.dataset`` attribute
            (e.g. a ``torch.utils.data.DataLoader``).
        use_cuda: move each mini-batch to the GPU before stepping.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader.dataset)``.
    """
    epoch_loss = 0.
    for minibatch in train_loader:
        if use_cuda:
            minibatch = minibatch.cuda()
        # Accumulate: assigning here would keep only the last mini-batch's loss.
        epoch_loss += svi.step(minibatch)

    return epoch_loss / len(train_loader.dataset)


In [None]:
def evaluate(svi, test_loader, use_cuda=False):
    """Return the summed ``svi.evaluate_loss`` over ``test_loader`` per test datapoint.

    Args:
        svi: object exposing ``evaluate_loss(batch) -> float``.
        test_loader: iterable of mini-batches with a ``.dataset`` attribute.
        use_cuda: move each mini-batch to the GPU before evaluating it.
    """
    batches = (x.cuda() if use_cuda else x for x in test_loader)
    total = sum((svi.evaluate_loss(batch) for batch in batches), 0.)
    return total / len(test_loader.dataset)

In [None]:
# Training configuration.
LR = 0.01                # Adam learning rate (before StepLR decay)
USE_CUDA = False         # move mini-batches to the GPU in train()/evaluate()
K = 5                    # number of mixture components in the VAE prior
NUM_EPOCHS = 20_000      # passes over train_loader
TEST_FREQUENCY = 50      # evaluate and report every this many epochs

In [None]:
pyro.clear_param_store()
vae = VAE(hidden_dim=20, init_epochs=2000, K=K)

# StepLR wraps torch Adam. Its step_size counts scheduler.step() calls, which
# happen once per epoch below, so the LR is multiplied by gamma at epoch 15000.
adam_args = {'lr': LR}
optimizer = torch.optim.Adam
scheduler = pyro.optim.StepLR({'step_size': 15000, 'gamma': 0.01, 'verbose': False,
                               'optimizer': optimizer, 'optim_args': adam_args})

# The hand-written vae.guide is the variational family.
svi = SVI(vae.model, vae.guide, scheduler, loss=TraceEnum_ELBO(max_plate_nesting=1))
train_elbo = []  # training loss per datapoint, one entry per epoch
test_elbo = []   # negated test loss (ELBO) per datapoint, one entry per TEST_FREQUENCY epochs

for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(total_epoch_loss_train)
    scheduler.step()

    if epoch % TEST_FREQUENCY == 0:
        # report test and training diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))


In [None]:
# Training-loss curve. The loss of a continuous density model can be negative,
# where np.log would give NaN, so only use a log axis when every value is positive.
fig, ax = plt.subplots()
ax.plot(train_elbo)
if train_elbo and min(train_elbo) > 0:
    ax.set_yscale('log')
ax.set(xlabel='epoch', ylabel='training loss per datapoint', title='SVI training loss');

In [None]:
# Samples from a single fitted mixture component, overlaid on the data.
cluster = 4  # component index; must be < K
real = data
sampled = vae.sample_cluster(10000, cluster).detach().numpy()
fig, ax = plt.subplots()
ax.scatter(sampled[:, 0], sampled[:, 1], c='r', alpha=0.1, label='samples from component {}'.format(cluster))
ax.scatter(real[:, 0], real[:, 1], c='b', alpha=0.5, label='data')
ax.set(title='Single-component samples vs data', xlabel='x[0]', ylabel='x[1]')
ax.legend();


In [None]:
# Samples from the full fitted mixture, coloured by their sampled component.
# (`real` is defined here; the original typo `read` made this depend on an earlier cell.)
real = data
assignment, sampled = vae.sample_latent(10000)
sampled = sampled.detach().numpy()
fig, ax = plt.subplots()
ax.scatter(sampled[:, 0], sampled[:, 1], c=assignment.numpy(), cmap='tab10', alpha=0.1)
ax.scatter(real[:, 0], real[:, 1], c='b', alpha=0.1)
ax.set(title='Mixture samples (coloured by component) vs data', xlabel='x[0]', ylabel='x[1]');


In [None]:
# Encode -> sample -> decode every datapoint and overlay reconstructions on the data.
reconstructed = vae.reconstruct(data).detach().numpy()
fig, ax = plt.subplots()
ax.scatter(reconstructed[:, 0], reconstructed[:, 1], label='reconstruction')
ax.scatter(data[:, 0], data[:, 1], label='data')
ax.set(title='VAE reconstructions vs data', xlabel='x[0]', ylabel='x[1]')
ax.legend();

In [None]:
# Predict mixture membership for each datapoint by replaying the trained guide
# into the model and letting infer_discrete resolve the enumerated assignment site.
guide_trace = poutine.trace(vae.guide).get_trace(data)
trained_model = poutine.replay(vae.model, trace=guide_trace)


def classifier(data, temperature=0):
    """Return the component assignment of each row of ``data``.

    ``temperature=0`` gives the MAP assignment; ``temperature=1`` draws from
    the posterior (see ``pyro.infer.infer_discrete``).

    NOTE(review): ``trained_model`` replays a guide trace computed on the
    global ``data``, so ``data`` passed here presumably must be that same batch.
    """
    inferred_model = infer_discrete(trained_model, temperature=temperature,
                                    first_available_dim=-3)  # avoid conflict with data plate
    trace = poutine.trace(inferred_model).get_trace(data)
    # The model names this site 'assigment' (sic); also accept the corrected spelling.
    site = "assigment" if "assigment" in trace.nodes else "assignment"
    return trace.nodes[site]["value"]


print(classifier(data))


# Second attempt: using a fully Bayesian specification.

In [None]:
# Inspect the learned variational parameters of the global latent variables.
for param_name in ('location_means', 'location_scale', 'scale_var', 'scale_mean'):
    print(pyro.param(param_name))