In [None]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from kls import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.beta import Beta
from torch.distributions.uniform import Uniform
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical as rcat
from torch import logsumexp
import sys
import time
import datetime
import math

In [None]:
## Data
# Load the stored observations as float32 and swap the two axes, so that
# the first axis is the one later used as the batch dimension (size N).
Xs = torch.from_numpy(np.load('circles/obs.npy')).float().transpose(1, 0)

N = Xs.shape[0]   # size of the first axis of Xs
D = 2             # observation dimensionality (assumed to equal Xs.shape[1] -- TODO confirm)

## Model Parameters
NUM_SAMPLES = 10      # value passed as `num_samples` to oneshot()
NUM_HIDDEN = 64       # default hidden-layer width of Encoder / Decoder
NUM_STROKES = 1       # default `num_strokes` of Encoder / Decoder
NUM_LATENTS = 2       # default number of latent dimensions
NUM_OBS = D           # default observation size of the Decoder
SGD_STEPS = 25000     # number of iterations of the training loop
LEARNING_RATE = 1e-4  # learning rate handed to both Adam optimisers
CUDA = False          # not read by any cell visible in this notebook
PATH = 'circles'      # not read by any cell visible in this notebook

In [None]:
class Encoder(nn.Module):
    """Inference network mapping observations to a factorised Beta distribution.

    A shared tanh layer feeds two identically shaped heads. Their outputs are
    exponentiated to obtain the (positive) concentration parameters ``a`` and
    ``b`` of ``Beta(a, b)``, one pair per latent dimension.

    Args:
        num_obs: size of one observation vector.
        num_hidden: width of the shared hidden layer.
        num_strokes: only enters through the width of the heads' hidden
            layer, ``int(0.5*num_hidden + 0.5*num_strokes)``.
        num_latents: number of latent dimensions produced per observation.
    """
    def __init__(self, num_obs= D,
                       num_hidden=NUM_HIDDEN,
                       num_strokes=NUM_STROKES,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: `super(self.__class__, self)` resolves to
        # the *subclass* when this class is inherited from, which makes
        # __init__ call itself forever.
        super(Encoder, self).__init__()
        # Hidden width shared by both heads (previously recomputed four times).
        num_head_hidden = int(0.5*num_hidden+0.5*num_strokes)
        self.enc_h = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.enc_log_u_a = nn.Sequential(
            nn.Linear(num_hidden, num_head_hidden),
            nn.Tanh(),
            nn.Linear(num_head_hidden, num_latents))
        self.enc_log_u_b = nn.Sequential(
            nn.Linear(num_hidden, num_head_hidden),
            nn.Tanh(),
            nn.Linear(num_head_hidden, num_latents))

    def forward(self, obs, num_samples):
        """Sample latents for `obs` and score them under the Beta distribution.

        Args:
            obs: observations, (B, num_obs).
            num_samples: number of samples S drawn per observation.

        Returns:
            u: samples of shape (S, B, num_latents).
            log_q_u: log-density of `u`, summed over the last (latent) axis,
                shape (S, B).
        """
        h = self.enc_h(obs) # (B, H)
        # exp() turns the heads' unconstrained outputs into positive
        # concentration parameters.
        a = torch.exp(self.enc_log_u_a(h)) # (B, num_latents)
        b = torch.exp(self.enc_log_u_b(h)) # (B, num_latents)
        q_u = Beta(a, b)
        # `sample` (not `rsample`) is used, so `u` carries no gradient; the
        # encoder parameters receive gradients through `log_q_u` only.
        u = q_u.sample((num_samples,)) # (S, B, num_latents)
        log_q_u = q_u.log_prob(u).sum(-1) # (S, B)
        return u, log_q_u
    
    
class Decoder(nn.Module):
    """Generative network mapping latents to the mean of a Gaussian likelihood.

    The observation noise is a fixed standard deviation of 0.01, stored in
    `x_sigma` with shape (N, num_obs), where N is the module-level data-set
    size. The decoder therefore expects batches of exactly N points.

    Args:
        num_obs: size of one observation vector.
        num_hidden: width of the hidden layer.
        num_strokes: accepted for signature parity with the encoder; unused.
        num_latents: number of latent dimensions consumed.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_strokes=NUM_STROKES,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: `super(self.__class__, self)` resolves to
        # the *subclass* when this class is inherited from, which makes
        # __init__ call itself forever.
        super(Decoder, self).__init__()

        # Plain tensor attribute (not a parameter), so it is never optimised.
        self.x_sigma = 0.01 * torch.ones((N, num_obs))
        self.dec_mu = nn.Sequential(
            nn.Linear(num_latents, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_obs))

    def forward(self, u, obs, num_samples):
        """Decode `u` and score `obs` under the resulting Gaussian.

        Args:
            u: latents, (S, B, num_latents).
            obs: observations, broadcast against the decoded means.
            num_samples: size S of the leading sample axis of `u`.

        Returns:
            mu: decoded means, (S, B, num_obs).
            log_p_x: log-likelihood of `obs`, summed over the last axis, (S, B).
        """
        mu = self.dec_mu(u) # (S, B, num_latents) -> (S, B, O)
        # `expand` yields the same (S, N, O) shape as the former
        # `repeat(num_samples, 1, 1)` but as a view, avoiding a fresh copy of
        # the noise tensor on every forward pass.
        x_sigma = self.x_sigma.expand(num_samples, -1, -1) # (S, N, O)
        p_x = Normal(mu, x_sigma) # (S, B, O)
        log_p_x = p_x.log_prob(obs).sum(-1) # (S, B)
        return mu, log_p_x

In [None]:
def weights_init(m):
    """Re-draw the weights of linear layers from a zero-mean normal, std 1e-2.

    Meant to be passed to ``Module.apply``. A module is treated as linear when
    its class name contains the substring 'Linear'; every other module is left
    untouched. Only `weight` is modified, in place; nothing is returned.
    """
    if 'Linear' not in type(m).__name__:
        return
    m.weight.data.normal_(mean=0.0, std=1e-2)
        
def initialize():
    """Create a fresh encoder/decoder pair, each with its own Adam optimiser.

    Both optimisers use lr=LEARNING_RATE and betas=(0.9, 0.99).

    Returns:
        (enc, dec, opt1, opt2), where `opt1` optimises the encoder's
        parameters and `opt2` the decoder's.
    """
    adam_settings = dict(lr=LEARNING_RATE, betas=(0.9, 0.99))
    enc = Encoder()
    dec = Decoder()
    # Custom weight initialisation is currently switched off:
    #     enc.apply(weights_init)
    opt1 = torch.optim.Adam(list(enc.parameters()), **adam_settings)
    opt2 = torch.optim.Adam(list(dec.parameters()), **adam_settings)
    return enc, dec, opt1, opt2
# Build the module-level networks and optimisers; oneshot() reads `enc`/`dec`
# as globals and the training loop steps `opt1`/`opt2`.
enc, dec, opt1, opt2 = initialize()

In [None]:
def oneshot(x, N, D, num_samples, num_strokes):
    """Run one encode/decode pass and compute the EUBO and ELBO objectives.

    Uses the module-level `enc` and `dec` networks and a Uniform(0, 1) prior
    over the latents.

    Args:
        x: batch of observations with N rows.
        N: number of rows in `x`; sets the shape of the prior.
        D: unused; kept so existing call sites keep working.
        num_samples: number of latent samples drawn per observation.
        num_strokes: unused; kept so existing call sites keep working.

    Returns:
        eubo: scalar; importance-weighted sum of the log weights over the
            sample axis, averaged over observations. The normalised weights
            are detached, so gradients flow through the log weights only.
        elbo: scalar; plain average of the log weights.
        mu: decoder means as returned by `dec`.
        u: latent samples as returned by `enc`.
    """
    # The three scratch tensors previously allocated here (log_p_xs, log_prs,
    # X_mus) were never read, so they have been dropped.
    u, log_q = enc(x, num_samples)
    mu, log_p_x = dec(u, x, num_samples)
    pr_u = Uniform(torch.zeros(N, NUM_LATENTS), torch.ones(N, NUM_LATENTS))
    log_pr_u = pr_u.log_prob(u).sum(-1)
    # log w = log p(u) + log p(x | u) - log q(u | x)
    log_weights = (log_pr_u + log_p_x - log_q)
    # Self-normalise over the sample axis (dim 0); treated as constants.
    weights = torch.exp(log_weights - logsumexp(log_weights, 0)).detach()
    eubo = torch.mul(weights, log_weights).sum(0).mean()
    elbo = log_weights.mean(0).mean()

    return eubo, elbo, mu, u

In [None]:
# Per-step history of both objectives, taken from the second pass of each step.
ELBOs, EUBOs = [], []
time_start = time.time()

for step in range(SGD_STEPS):
    # Visit the whole data set in a fresh random order.
    indices = torch.randperm(N)
    Xs_shuffle = Xs[indices]

    # Pass 1: step `opt1` (the encoder's optimiser) downhill on the EUBO.
    opt1.zero_grad()
    eubo, elbo, mu, u = oneshot(Xs_shuffle, N, D, NUM_SAMPLES, NUM_STROKES)
    eubo.backward()
    opt1.step()

    # Pass 2: draw new samples and step `opt2` (the decoder's optimiser)
    # uphill on the ELBO by minimising its negative.
    opt2.zero_grad()
    eubo, elbo, mu, u = oneshot(Xs_shuffle, N, D, NUM_SAMPLES, NUM_STROKES)
    (-elbo).backward()
    opt2.step()

    ELBOs.append(elbo.item())
    EUBOs.append(eubo.item())

    # Report every 100 steps; the timer restarts after each report, so the
    # printed duration covers the steps since the previous report.
    if not step % 100:
        time_end = time.time()
        print('SGDstep=%d, EUBO=%.3f, ELBO=%.3f (%ds)' % (step, eubo, elbo, time_end - time_start))
        time_start = time.time()


In [None]:
# One more pass over the data in its original (unshuffled) order; the
# resulting `u` is what the plotting cell below visualises.
eubo, elbo, mu, u = oneshot(Xs, N, D, NUM_SAMPLES, NUM_STROKES)

In [None]:
# NOTE: this cell held only the unfinished expression `log_p_`, a name that no
# cell of this notebook defines; it raised NameError on "Restart & Run All"
# and has been removed.

In [None]:
T = 80  # only the first T points are drawn
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: the first T observations in data space.
points = Xs[:T].data.numpy()
ax1.scatter(points[:, 0], points[:, 1])
ax1.set(xlim=(-2, 2), ylim=(-2, 2), title='data')

# Right panel: latents averaged over the first (sample) axis, drawn against
# the point index.
# NOTE(review): both latent dimensions use the same 'go' marker, so they
# cannot be told apart in the figure.
uu = u.mean(0)
latents = uu[:T].data.numpy()
ax2.plot(latents[:, 0], 'go')
ax2.plot(latents[:, 1], 'go')