In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import *
from vimco_ball import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Report the library versions and CUDA availability for this run.
# NOTE(review): `torch` is not imported explicitly in this cell; presumably it
# is re-exported by one of the star imports above — confirm.
version_report = (
    'probtorch:', probtorch.__version__,
    'torch:', torch.__version__,
    'cuda:', torch.cuda.is_available(),
)
print(*version_report)

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [3]:
## Problem dimensions
# T and K fix the tensor shapes used by the encoder (it views its output as
# (batch, T - 1, K * K)); D is only forwarded to ag_mcmc_vimco_ball.
# NOTE(review): T is presumably the sequence length, K the number of discrete
# states and D the observation dimensionality — confirm against the dataset.
T, K, D = 50, 4, 2

## Inference settings (all three are forwarded to ag_mcmc_vimco_ball)
num_samples = 20
num_particles_smc = 100
mcmc_steps = 5

## Encoder architecture
NUM_OBS = K * 2          # input width of the encoder's first linear layer
NUM_HIDDEN = 64          # width of the tanh hidden layer
NUM_LATENTS = K ** 2     # encoder output width; reshaped to K x K downstream

## Optimisation
BATCH_SIZE = 20          # sequences per gradient step
NUM_EPOCHS = 1000        # passes over the shuffled dataset
LEARNING_RATE = 1e-4     # learning rate handed to Adam
CUDA = False             # when True, the encoder is moved to the GPU

In [5]:
def _load_float_tensor(path):
    """Read a ``.npy`` file and return its contents as a float torch tensor."""
    return torch.from_numpy(np.load(path)).float()


# Ball dataset arrays (paths are relative to the notebook's working directory).
Ys = _load_float_tensor('ball_dataset/sequences.npy')
As_true = _load_float_tensor('ball_dataset/transitions.npy')
Zs_true = _load_float_tensor('ball_dataset/states.npy')
mus_true = _load_float_tensor('ball_dataset/means.npy')
covs_true = _load_float_tensor('ball_dataset/covariances.npy')
Pi = _load_float_tensor('ball_dataset/init.npy')

# initial_trans_prior is not defined in this notebook; presumably it comes
# from one of the star imports — confirm.
prior = initial_trans_prior(K)

# The leading axis of Ys indexes sequences.
num_seqs = Ys.size(0)

# Join the sequences and the true states along their last axis. The inputs
# are torch tensors handed to NumPy, so the result is a NumPy array.
Ys_Zs = np.concatenate([Ys, Zs_true], axis=-1)

In [6]:
class Encoder(nn.Module):
    """Encoder producing Dirichlet concentrations for the rows of a K x K matrix.

    A single tanh hidden layer feeds a linear layer of width ``num_latents``.
    ``forward`` reads the module-level constants ``T`` and ``K`` to reshape the
    output, so ``num_latents`` must equal ``K * K`` for the reshape to succeed.

    Args:
        num_obs: width of each input row.
        num_hidden: width of the hidden layer.
        num_latents: width of the output layer.
    """

    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed, because self.__class__ is then
        # the subclass rather than Encoder.
        super().__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """Compute Dirichlet concentrations and draw one matrix per batch item.

        Args:
            obs: input tensor whose last dimension is ``num_obs`` and which
                holds ``batch_size * (T - 1)`` rows in total, so the network
                output can be viewed as ``(batch_size, T - 1, K * K)``.
            prior: tensor added to the summed network output; it must
                broadcast against shape ``(batch_size, K, K)``.
            batch_size: number of sequences contained in ``obs``.

        Returns:
            ``(alphas, As)``, both of shape ``(batch_size, K, K)``. ``alphas``
            are the concentrations; ``As[b, k, :]`` is a draw from
            ``Dirichlet(alphas[b, k, :])``.
        """
        hidden = self.enc_hidden(obs)
        # Each row is softmax-normalised over the K*K entries, then the rows
        # belonging to one sequence are summed and offset by the prior.
        weights = F.softmax(self.latent_dir(hidden), -1)
        alphas = weights.view(batch_size, T-1, K*K).sum(1).view(batch_size, K, K) + prior
        # Dirichlet treats the last dimension as the event, so one call samples
        # every row of every matrix. The result inherits the device and dtype
        # of alphas, instead of being written into a buffer that was always
        # allocated on the CPU. sample() (not rsample()) is kept, as before.
        As = Dirichlet(alphas).sample()
        return alphas, As

In [7]:
def initialize():
    """Create a fresh Encoder together with an Adam optimizer for it.

    Returns:
        A pair ``(encoder, optimizer)``. The encoder is moved to the GPU when
        the module-level ``CUDA`` flag is set; the optimizer is
        ``torch.optim.Adam`` over all encoder parameters with learning rate
        ``LEARNING_RATE``.
    """
    encoder = Encoder()
    if CUDA:
        encoder.cuda()
    trainable = list(encoder.parameters())
    return encoder, torch.optim.Adam(trainable, lr=LEARNING_RATE)


enc, optimizer = initialize()

In [8]:
# Per-step training history (one entry appended per gradient step).
ELBOs, ESSs, KLs, Grads = [], [], [], []

# K x K matrix of ones with 2.0 on the diagonal.
prior_mcmc = torch.ones((K, K)).float()
for k in range(K):
    prior_mcmc[k, k] = 2.0

# Whole batches only: sequences that do not fill a final batch are skipped
# in every epoch.
Grad_Steps = Ys.shape[0] // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    # Fresh shuffle of the sequence order for every epoch.
    indices = torch.randperm(num_seqs)

    for step in range(Grad_Steps):
        time_start = time.time()
        optimizer.zero_grad()

        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_data, batch_zs = Ys[batch_indices], Zs_true[batch_indices]
        mu_ks, cov_ks = mus_true[batch_indices], covs_true[batch_indices]

        # NOTE(review): prior_mcmc is passed as both the second and the third
        # argument, and the `prior` built from initial_trans_prior(K) is not
        # used here — confirm this is intended.
        gradient, elbo, ess, kl = ag_mcmc_vimco_ball(
            enc, prior_mcmc, prior_mcmc, batch_zs, Pi, mu_ks, cov_ks, batch_data,
            T, D, K, num_samples, num_particles_smc, mcmc_steps, BATCH_SIZE)

        # The first return value is maximised, so its negation is the loss.
        loss = - gradient
        loss.backward()
        optimizer.step()

        # Record scalar copies of this step's statistics.
        for history, value in zip((ELBOs, ESSs, KLs, Grads), (elbo, ess, kl, loss)):
            history.append(value.item())

        time_end = time.time()
        print('epoch : %d, step : %d, ELBO : %f, KL : %f (%ds)' % (epoch, step, elbo, kl, time_end - time_start))

epoch : 0, step : 0, ELBO : -293.185608, KL : 35.251762 (188s)
epoch : 0, step : 1, ELBO : -293.719879, KL : 35.893490 (181s)
epoch : 0, step : 2, ELBO : -291.305176, KL : 37.740482 (183s)
epoch : 0, step : 3, ELBO : -295.787415, KL : 36.534245 (182s)
epoch : 0, step : 4, ELBO : -290.803009, KL : 35.129745 (182s)
epoch : 0, step : 5, ELBO : -292.610138, KL : 35.922729 (183s)
epoch : 0, step : 6, ELBO : -290.442688, KL : 35.606720 (186s)
epoch : 0, step : 7, ELBO : -293.126617, KL : 35.592064 (182s)
epoch : 0, step : 8, ELBO : -296.084015, KL : 35.446075 (183s)
epoch : 0, step : 9, ELBO : -293.879272, KL : 36.099197 (184s)
epoch : 1, step : 0, ELBO : -288.810883, KL : 36.504723 (185s)
epoch : 1, step : 1, ELBO : -289.901855, KL : 37.451214 (184s)
epoch : 1, step : 2, ELBO : -295.697784, KL : 35.532188 (185s)
epoch : 1, step : 3, ELBO : -297.655670, KL : 35.314476 (183s)
epoch : 1, step : 4, ELBO : -291.612244, KL : 36.604446 (183s)
epoch : 1, step : 5, ELBO : -292.039612, KL : 36.357136

KeyboardInterrupt: 