In [1]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# NOTE: the star imports stay before the explicit imports below so that, on a
# name collision, the explicitly imported names win (as in the original order).
from plots import *
from kls import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.beta import Beta
from torch.distributions.uniform import Uniform
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical as rcat
from torch import logsumexp
import os
import sys
import time
import datetime
import math

In [2]:
## Data & configuration
PATH = 'circles'  # directory that holds obs.npy and mus.npy
SEED = 0          # seeds weight init, randperm shuffling and Beta sampling

# Seed before the model is built so initialisation and training are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

Xs = torch.from_numpy(np.load(os.path.join(PATH, 'obs.npy'))).float()
Mus = torch.from_numpy(np.load(os.path.join(PATH, 'mus.npy'))).float()
# Swap the last two axes. NOTE(review): this assumes obs.npy is stored as
# (B, D, N) so that Xs ends up (B, N, D) — confirm against the data generator.
Xs = Xs.transpose(1, 2)

B, N, D = Xs.shape  # B groups of N points, each D-dimensional

## Model Parameters
NUM_SAMPLES = 10      # samples drawn from q(u | x) per point
NUM_HIDDEN = 32
NUM_LATENTS = 2
NUM_OBS = D
NUM_EPOCHS = 25000
LEARNING_RATE = 1e-4
CUDA = False          # NOTE(review): not read anywhere in the visible cells

In [3]:
class Encoder(nn.Module):
    """Amortised Beta posterior q(u | x) over the latents of each observed point.

    Every row of ``obs`` is encoded independently: a shared tanh layer feeds two
    heads whose outputs are exponentiated into the two concentration
    parameters (a, b) of a Beta distribution per latent dimension.
    """

    def __init__(self, num_obs= NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super() is subclass-safe; super(self.__class__, self)
        # recurses forever as soon as Encoder is subclassed.
        super().__init__()
        self.enc_h = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        # Head producing log a (first Beta concentration) per latent dim.
        self.enc_log_u_a = nn.Sequential(
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))
        # Head producing log b (second Beta concentration) per latent dim.
        self.enc_log_u_b = nn.Sequential(
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Tanh(),
            nn.Linear(int(0.5*num_hidden), num_latents))

    def forward(self, obs, num_samples):
        """Sample latents for every row of ``obs`` and score them under q.

        Args:
            obs: observations whose last dim is ``num_obs``; in this notebook
                an (N, D) frame of centred points.
            num_samples: number of samples S drawn per row.

        Returns:
            u: Beta samples of shape (S, *obs.shape[:-1], num_latents).
            log_q_u: log q(u | obs) summed over the latent dim,
                shape (S, *obs.shape[:-1]).
        """
        h = self.enc_h(obs)  # (N, H) for an (N, D) input
        # exp maps the unconstrained head outputs to positive concentrations.
        a = torch.exp(self.enc_log_u_a(h))  # (N, num_latents)
        b = torch.exp(self.enc_log_u_b(h))  # (N, num_latents)
        q_u = Beta(a, b)
        # Non-reparameterised draw: no gradient flows through u itself; the
        # encoder only receives gradient via log_q_u's dependence on (a, b).
        u = q_u.sample((num_samples,))  # (S, N, num_latents)
        log_q_u = q_u.log_prob(u).sum(-1)  # (S, N)
        return u, log_q_u
    
    
class Decoder(nn.Module):
    """Gaussian likelihood p(x | u): an MLP mean with a fixed std of 0.01."""

    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super() is subclass-safe; super(self.__class__, self)
        # recurses forever as soon as Decoder is subclassed.
        super().__init__()

        # Fixed observation noise. It is a plain tensor attribute (not a
        # Parameter or buffer), so it is not trained and not in state_dict.
        # Its (N, num_obs) shape uses the notebook-global N, so observations
        # passed to forward must broadcast against it.
        self.x_sigma = 0.01 * torch.ones((N, num_obs))
        self.dec_mu = nn.Sequential(
            nn.Linear(num_latents, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_obs))

    def forward(self, u, obs, num_samples):
        """Decode latents and score the observations.

        Args:
            u: latent samples, (S, N, num_latents).
            obs: observations, (N, num_obs); broadcast over the sample dim.
            num_samples: S. Kept for interface compatibility; broadcasting
                makes an explicit repeat unnecessary.

        Returns:
            mu: decoder means, (S, N, num_obs).
            log_p_x: log p(obs | u) summed over the observation dim, (S, N).
        """
        mu = self.dec_mu(u)  # (S, N, K) -> (S, N, O)
        # Match mu's device/dtype (x_sigma is not moved by .to()/.cuda() on
        # the module) and let Normal broadcast (N, O) against (S, N, O)
        # instead of materialising a repeated copy.
        sigma = self.x_sigma.to(mu)
        p_x = Normal(mu, sigma)  # batch shape (S, N, O)
        log_p_x = p_x.log_prob(obs).sum(-1)  # (S, N)
        return mu, log_p_x

In [4]:
def weights_init(m):
    """Re-initialise the weight matrix of every ``nn.Linear`` module.

    Meant to be used as ``module.apply(weights_init)``. Linear weights are
    drawn from N(0, 0.01**2); biases and all other module types are left
    untouched.
    """
    # isinstance (rather than matching 'Linear' in the class name) skips
    # unrelated modules whose name merely contains 'Linear' and may lack
    # a ``weight`` attribute.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=1e-2)
        
def initialize(init_weights=False, lr=None):
    """Build a fresh encoder/decoder pair and one Adam optimiser for each.

    Args:
        init_weights: if True, re-initialise the encoder's Linear weights with
            ``weights_init`` (previously a commented-out line).
        lr: learning rate for both optimisers; defaults to the current value
            of ``LEARNING_RATE`` at call time.

    Returns:
        (enc, dec, opt1, opt2), where opt1 updates only the encoder and opt2
        only the decoder.
    """
    enc = Encoder()
    dec = Decoder()
    if init_weights:
        enc.apply(weights_init)
    lr = LEARNING_RATE if lr is None else lr
    opt1 = torch.optim.Adam(enc.parameters(), lr=lr, betas=(0.9, 0.99))
    opt2 = torch.optim.Adam(dec.parameters(), lr=lr, betas=(0.9, 0.99))
    return enc, dec, opt1, opt2


enc, dec, opt1, opt2 = initialize()

In [5]:
def oneshot(x, mean, N, D, num_samples, encoder=None, decoder=None):
    """Importance-weighted EUBO and ELBO estimates for one group of points.

    Args:
        x: observed points; in this notebook an (N, D) frame ``Xs[b]``.
        mean: offset subtracted from ``x`` before encoding (``Mus[b]``).
        N, D: kept for backward compatibility only; shapes are now taken
            from the tensors themselves.
        num_samples: number of samples S drawn from q(u | x) per point.
        encoder, decoder: models to use; default to the notebook globals
            ``enc`` and ``dec``.

    Returns:
        eubo: sum over samples of w * log_w, averaged over points, where w are
            the self-normalised (detached) importance weights.
        elbo: mean log importance weight over samples and points.
        mu: decoder means returned by the decoder.
        u: latent samples returned by the encoder, (S, N, latents).
    """
    encoder = enc if encoder is None else encoder
    decoder = dec if decoder is None else decoder

    data = x - mean

    u, log_q = encoder(data, num_samples)
    mu, log_p_x = decoder(u, data, num_samples)
    # Uniform(0, 1) prior on every latent coordinate. Built from u so its
    # shape, dtype and device always match the samples (no reliance on the
    # caller's N or on NUM_LATENTS, and no CPU/GPU mismatch).
    pr_u = Uniform(torch.zeros_like(u), torch.ones_like(u))
    log_pr_u = pr_u.log_prob(u).sum(-1)  # (S, N)
    log_weights = log_pr_u + log_p_x - log_q  # (S, N)
    # Normalise over the sample dim; detached so gradients flow only through
    # log_weights in the EUBO.
    weights = torch.softmax(log_weights, dim=0).detach()
    eubo = (weights * log_weights).sum(0).mean()
    elbo = log_weights.mean()

    return eubo, elbo, mu, u

In [6]:
ELBOs = []
EUBOs = []
time_start = time.time()

# Per batch, two updates on the same shuffled points, each with a fresh
# forward pass: the encoder optimiser (opt1) minimises the EUBO, then the
# decoder optimiser (opt2) minimises -ELBO, i.e. maximises the ELBO.
update_steps = (
    (opt1, lambda eubo_, elbo_: eubo_),
    (opt2, lambda eubo_, elbo_: -elbo_),
)

for epoch in range(NUM_EPOCHS):
    for b in range(B):
        X, mean = Xs[b], Mus[b]
        indices = torch.randperm(N)
        X_shuffle = X[indices]

        for optimizer, objective in update_steps:
            optimizer.zero_grad()
            eubo, elbo, mu, u = oneshot(X_shuffle, mean, N, D, NUM_SAMPLES)
            objective(eubo, elbo).backward()
            optimizer.step()

        # Bounds recorded come from the second (decoder-step) forward pass.
        EUBOs.append(eubo.item())
        ELBOs.append(elbo.item())

    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f (%ds)' % (epoch, eubo, elbo, time_end - time_start))
    time_start = time.time()


epoch=0, EUBO=-3885.971, ELBO=-4992.250 (1s)
epoch=1, EUBO=-3645.053, ELBO=-4650.609 (1s)
epoch=2, EUBO=-3162.197, ELBO=-4262.119 (1s)
epoch=3, EUBO=-2660.216, ELBO=-3816.616 (1s)
epoch=4, EUBO=-2188.037, ELBO=-3321.881 (1s)
epoch=5, EUBO=-1823.715, ELBO=-2888.209 (1s)
epoch=6, EUBO=-1481.981, ELBO=-2435.780 (1s)
epoch=7, EUBO=-1180.779, ELBO=-2102.136 (1s)
epoch=8, EUBO=-943.233, ELBO=-1795.091 (1s)
epoch=9, EUBO=-727.630, ELBO=-1463.161 (1s)
epoch=10, EUBO=-543.582, ELBO=-1159.913 (1s)
epoch=11, EUBO=-413.598, ELBO=-972.828 (1s)
epoch=12, EUBO=-312.466, ELBO=-866.138 (1s)
epoch=13, EUBO=-238.926, ELBO=-716.899 (0s)
epoch=14, EUBO=-188.392, ELBO=-639.786 (0s)
epoch=15, EUBO=-142.875, ELBO=-576.557 (0s)
epoch=16, EUBO=-122.317, ELBO=-486.011 (0s)
epoch=17, EUBO=-107.185, ELBO=-466.349 (0s)
epoch=18, EUBO=-86.590, ELBO=-378.207 (0s)
epoch=19, EUBO=-76.600, ELBO=-333.947 (0s)
epoch=20, EUBO=-67.908, ELBO=-297.793 (0s)
epoch=21, EUBO=-58.715, ELBO=-270.757 (1s)
epoch=22, EUBO=-51.276, ELB

KeyboardInterrupt: 

In [7]:
# Persist the trained weights; torch.save does not create missing directories.
os.makedirs('VAE', exist_ok=True)
torch.save(enc.state_dict(), 'VAE/enc-circle')
torch.save(dec.state_dict(), 'VAE/dec-circle')

In [None]:
# Run the model on one group of points for inspection (index must be < B).
P = 2
# Pure inference: no autograd graph is needed for the plots below.
with torch.no_grad():
    eubo, elbo, mu, u = oneshot(Xs[P], Mus[P], N, D, NUM_SAMPLES)

In [None]:
T = 80  # number of points of group P shown in each panel
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

points = Xs[P, :T].detach().numpy()
ax1.scatter(points[:, 0], points[:, 1])
ax1.set_xlim([-2, 2])
ax1.set_ylim([-2, 2])
ax1.set_title('data')

# Average the S posterior samples per point: (S, N, 2) -> (N, 2).
uu = u.mean(0)
uu_np = uu.detach().numpy()

ax2.plot(uu_np[:T, 0], 'go', label='u[0]')
ax2.plot(uu_np[:T, 1], 'ro', label='u[1]')
ax2.set_title('sample-mean latent per point')
ax2.set_xlabel('point index')
ax2.set_ylabel('u')
ax2.legend()

ax3.scatter(uu_np[:T, 0], uu_np[:T, 1])
ax3.set_title('sample-mean latent: u[0] vs u[1]')
ax3.set_xlabel('u[0]')
ax3.set_ylabel('u[1]')

plt.show()