In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import *
from objectives import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Report the library versions in use and whether CUDA is available.
# NOTE(review): `torch` is not imported by name in this cell; presumably it is
# brought into scope by one of the star imports above -- confirm, or add an
# explicit `import torch` so the cell does not depend on that side effect.
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 0.5.0a0+3bb8c5e cuda: True


In [2]:
# Problem dimensions. Encoder.forward reshapes its output using T-1 steps per
# sequence and K x K entries per transition matrix; D is passed on to the
# objective (presumably the observation dimensionality -- confirm).
T, K, D = 30, 3, 2

## Model Parameters
# Particle counts / MCMC steps handed to the objective.
num_particles_rws, mcmc_steps, num_particles_smc = 15, 4, 30

# Encoder layer sizes.
NUM_HIDDEN = 64
NUM_LATENTS = K * K
NUM_OBS = 2 * K

# Optimisation settings.
BATCH_SIZE = 50
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4

# Runtime switches.
CUDA = False
RESTORE = False

# Run tag: the three sampler settings followed by the wall-clock time at which
# this cell was executed.
_run_stamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
PATH_ENC = "stepwise_enc-%drws-%dmcmc-%dsmc-enc-%s" % (
    num_particles_rws,
    mcmc_steps,
    num_particles_smc,
    _run_stamp,
)

In [3]:
def _load_float_tensor(path):
    """Read a .npy file and return its contents as a float tensor."""
    return torch.from_numpy(np.load(path)).float()


# Sequences and their state assignments stay as numpy arrays; they are
# shuffled with numpy before being converted to tensors in the training loop.
Ys = np.load('dataset/sequences.npy')
Zs_true = np.load('dataset/states.npy')

# Reference model parameters stored alongside the sequences.
As_true = _load_float_tensor('dataset/transitions.npy')
mu_ks = _load_float_tensor('dataset/means.npy')
cov_ks = _load_float_tensor('dataset/covariances.npy')
Pi = _load_float_tensor('dataset/init.npy')

prior = initial_trans_prior(K)

# Glue observations and states together along the last axis so that a single
# shuffle keeps each sequence paired with its states.
Ys_Zs = np.concatenate([Ys, Zs_true], axis=-1)

In [4]:
class Encoder(nn.Module):
    """Amortized encoder producing Dirichlet parameters for K x K transition matrices.

    A one-hidden-layer network (Linear + Tanh, then Linear) maps each input row
    to ``num_latents`` logits. The logits are soft-maxed, summed over the
    ``T-1`` rows belonging to each sequence, and reshaped to ``(batch, K, K)``
    concentrations, from which one transition matrix per sequence is sampled.

    Args:
        num_obs: size of the last dimension of the input.
        num_hidden: width of the hidden layer.
        num_latents: number of output logits; ``forward`` reshapes them with
            the module-level ``K``, so this must equal ``K*K``.
    """

    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as Encoder is subclassed.
        super(Encoder, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """Compute Dirichlet concentrations and sample transition matrices.

        Args:
            obs: tensor whose last dimension is ``num_obs`` and which holds
                ``batch_size * (T-1)`` rows in total (required by the reshape
                below).
            prior: unused here; kept so existing callers keep working.
            batch_size: number of sequences represented in ``obs``.

        Returns:
            ``(alphas, As)``: both of shape ``(batch_size, K, K)``. ``alphas``
            are the concentrations (differentiable w.r.t. the encoder
            weights); ``As`` are samples whose rows are drawn from
            ``Dirichlet(alphas[b, k, :])`` without gradient.
        """
        hidden = self.enc_hidden(obs)
        alphas = F.softmax(self.latent_dir(hidden), -1).view(batch_size, T-1, K*K).sum(1).view(batch_size, K, K)
        # Dirichlet treats every leading dimension as batch, so a single call
        # draws all K rows of all matrices. The result lives on the same
        # device as `alphas`, unlike a CPU-allocated torch.zeros buffer.
        As = Dirichlet(alphas).sample()
        return alphas, As

In [5]:
def initialize():
    """Build a fresh Encoder and an Adam optimizer over its parameters.

    The encoder is moved to the GPU when the module-level ``CUDA`` flag is
    set; the optimizer uses ``LEARNING_RATE``.

    Returns:
        ``(encoder, optimizer)``.
    """
    encoder = Encoder()
    if CUDA:
        encoder.cuda()
    trainable = list(encoder.parameters())
    adam = torch.optim.Adam(trainable, lr=LEARNING_RATE)
    return encoder, adam


enc, optimizer = initialize()

In [None]:
# Per-step training histories (one entry per gradient step).
EUBOs = []
ELBOs = []
ELBOs2 = []
ESSs = []
KLs = []
LOSSs = []

weights = []

num_seqs = Ys.shape[0]
# Ys_Zs was built by concatenating Ys and Zs_true on the last axis, so the
# width of Ys is exactly where observations end and states begin.
obs_dim = Ys.shape[-1]
# Full batches only; a trailing partial batch is dropped.
Grad_Steps = num_seqs // BATCH_SIZE

# NOTE(review): assumes As_true stores one transition matrix per sequence
# (leading dimension == number of sequences) -- confirm against the dataset.
# When that holds, the matrices are permuted together with the data so each
# batch sees the matrices of its own sequences. Otherwise the original
# `As_true[:BATCH_SIZE]` slice is used unchanged.
As_per_sequence = As_true.shape[0] == num_seqs

for epoch in range(NUM_EPOCHS):
    # One permutation per epoch, applied to every per-sequence array.
    perm = np.random.permutation(num_seqs)
    shuffled = Ys_Zs[perm]
    Ys_shuffled = torch.from_numpy(shuffled[:, :, :obs_dim]).float()
    Zs_true_shuffled = torch.from_numpy(shuffled[:, :, obs_dim:]).float()
    if As_per_sequence:
        As_shuffled = As_true[torch.from_numpy(perm).long()]
    for step in range(Grad_Steps):
        time_start = time.time()
        optimizer.zero_grad()
        batch = slice(step*BATCH_SIZE, (step+1)*BATCH_SIZE)
        batch_data = Ys_shuffled[batch]
        batch_zs = Zs_true_shuffled[batch]
        if As_per_sequence:
            batch_As = As_shuffled[batch]
        else:
            batch_As = As_true[:BATCH_SIZE]

        loss, eubo, elbo2, elbo, ess, kl = ag_sis_stepwise3(enc, prior, batch_As, batch_zs, Pi, mu_ks, cov_ks, batch_data, T, D, K, num_particles_rws, num_particles_smc, mcmc_steps, BATCH_SIZE)

        loss.backward()
        optimizer.step()

        # Convert to Python floats once; reused for both logging and history.
        eubo_val = eubo.item()
        elbo_val = elbo.item()
        elbo2_val = elbo2.item()
        kl_val = kl.item()
        EUBOs.append(eubo_val)
        ELBOs.append(elbo_val)
        ELBOs2.append(elbo2_val)
        ESSs.append(ess.item())
        KLs.append(kl_val)
        LOSSs.append(loss.item())
        time_end = time.time()
        print('epoch : %d, step : %d, EUBO : %f, ELBO2 : %f, ELBO : %f, KL : %f (%ds)' % (epoch, step, eubo_val, elbo2_val, elbo_val, kl_val, time_end - time_start))

epoch : 0, step : 0, EUBO : -137.135300, ELBO2 : -155.538727, ELBO : -633.028809, KL : 8.143072 (80s)
epoch : 0, step : 1, EUBO : -136.387650, ELBO2 : -156.219391, ELBO : -636.452454, KL : 8.167732 (83s)
epoch : 0, step : 2, EUBO : -136.308090, ELBO2 : -155.527924, ELBO : -633.579590, KL : 8.070945 (80s)
epoch : 0, step : 3, EUBO : -138.355804, ELBO2 : -157.004990, ELBO : -636.964844, KL : 7.968321 (85s)
epoch : 0, step : 4, EUBO : -137.809540, ELBO2 : -156.701752, ELBO : -636.592529, KL : 7.749455 (82s)
epoch : 0, step : 5, EUBO : -139.476822, ELBO2 : -157.243500, ELBO : -641.338989, KL : 7.569567 (81s)
epoch : 0, step : 6, EUBO : -137.626892, ELBO2 : -156.656677, ELBO : -636.660522, KL : 7.481657 (82s)
epoch : 0, step : 7, EUBO : -138.609818, ELBO2 : -157.624344, ELBO : -642.637634, KL : 7.698624 (88s)
epoch : 0, step : 8, EUBO : -137.310837, ELBO2 : -155.132568, ELBO : -633.225037, KL : 7.712729 (83s)
epoch : 0, step : 9, EUBO : -139.823578, ELBO2 : -158.250488, ELBO : -642.751953, 

In [None]:
# Hand the recorded EUBO, ELBO, ESS and KL histories to plot_results together
# with 'results_sis.png' (presumably the output file name -- confirm in plots).
# NOTE(review): ELBOs2 and LOSSs are also collected during training but are
# not passed here.
plot_results(EUBOs, ELBOs, ESSs, KLs, 'results_sis.png')