In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import *
from objectives import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Report the probtorch / torch versions and whether CUDA is available.
# NOTE(review): `torch` is never imported by name in this cell — presumably it
# arrives through one of the star imports above; confirm, or import it explicitly.
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 0.5.0a0+3bb8c5e cuda: True


In [9]:
# --- Problem dimensions (passed as T, K, D to the inference helpers) ---
T, K, D = 30, 3, 2

# --- Inference settings ---
num_particles_rws = 10
mcmc_steps = 5
num_particles_smc = 60

# --- Encoder architecture ---
NUM_HIDDEN = 64
NUM_LATENTS = K ** 2   # one output per entry of a K x K matrix
NUM_OBS = K * 2

# --- Optimisation ---
BATCH_SIZE = 50
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4

# --- Runtime switches ---
CUDA = False
RESTORE = False

# Checkpoint name: particle / step counts followed by the wall-clock time at
# which this cell was executed.
_timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S')
PATH_ENC = "stepwise_enc-%drws-%dmcmc-%dsmc-enc-%s" % (
    num_particles_rws,
    mcmc_steps,
    num_particles_smc,
    _timestamp,
)

In [10]:
def _load_float_tensor(path):
    """Read a .npy file and return its contents as a float torch tensor."""
    return torch.from_numpy(np.load(path)).float()


# Kept as numpy arrays: observed sequences and the matching true states.
Ys = np.load('dataset2/sequences.npy')
Zs_true = np.load('dataset2/states.npy')

# Converted to float tensors: true transitions, emission parameters, initial distribution.
As_true = _load_float_tensor('dataset2/transitions.npy')
mu_ks = _load_float_tensor('dataset2/means.npy')
cov_ks = _load_float_tensor('dataset2/covariances.npy')
Pi = _load_float_tensor('dataset2/init.npy')

# Prior over transitions, built by a helper imported from the project modules.
prior = initial_trans_prior(K)

# Observations and states joined along the last axis so they can be shuffled together.
Ys_Zs = np.concatenate([Ys, Zs_true], axis=-1)

In [11]:
class Encoder(nn.Module):
    """Encoder producing Dirichlet concentrations for the rows of a K x K matrix.

    Each input row is mapped through a tanh hidden layer and a linear layer to
    K*K scores, soft-maxed, summed over the T-1 rows belonging to each batch
    element, and added to `prior`. One Dirichlet sample is then drawn per row
    of the resulting (batch_size, K, K) concentration tensor.

    Relies on the module-level constants `T` and `K`.
    """

    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        """
        Args:
            num_obs: width of each input row.
            num_hidden: width of the tanh hidden layer.
            num_latents: number of output scores per input row (viewed as K*K in `forward`).
        """
        # Name the class explicitly: super(self.__class__, self) resolves to the
        # subclass when Encoder is subclassed and recurses without end.
        super(Encoder, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """Compute Dirichlet concentrations and sample one matrix per batch element.

        Args:
            obs: input whose rows can be viewed as (batch_size, T-1, num_obs)
                after the linear layers.
            prior: tensor broadcastable against (batch_size, K, K), added to the
                summed soft-max outputs.
            batch_size: number of sequences represented in `obs`.

        Returns:
            alphas: (batch_size, K, K) Dirichlet concentrations.
            As: (batch_size, K, K) sample; each As[b, k, :] is a draw from
                Dirichlet(alphas[b, k, :]).
        """
        hidden = self.enc_hidden(obs)
        pair_probs = F.softmax(self.latent_dir(hidden), -1)
        alphas = pair_probs.view(batch_size, T-1, K*K).sum(1).view(batch_size, K, K) + prior
        # The last dimension is the Dirichlet event dimension, so a single call
        # samples every row of every batch element and keeps the result on the
        # same device as `alphas` (no separately allocated CPU buffer).
        As = Dirichlet(alphas).sample()
        return alphas, As

In [12]:
def initialize():
    """Create a fresh Encoder and an Adam optimizer over its parameters.

    The encoder is moved to the GPU when the module-level `CUDA` flag is set.

    Returns:
        (encoder, optimizer) tuple.
    """
    encoder = Encoder()
    if CUDA:
        encoder.cuda()
    adam = torch.optim.Adam(params=list(encoder.parameters()), lr=LEARNING_RATE)
    return encoder, adam


enc, optimizer = initialize()

In [None]:
# Per-gradient-step diagnostics collected during training.
EUBOs = []
ELBOs = []
ELBOs2= [] 
ESSs = []
KLs = []
LOSSs = []

# K x K prior passed to the objective: ones everywhere, 2.0 on the diagonal.
prior_mcmc = torch.ones((K, K)).float()
for k in range(K):
    prior_mcmc[k, k] = 2.0

# Number of full batches per epoch (any remainder is dropped).
Grad_Steps = Ys.shape[0] // BATCH_SIZE
for epoch in range(NUM_EPOCHS):
    # Draw one permutation per epoch and apply it to every per-sequence array,
    # so observations, states and true transition matrices stay aligned.
    # Indexing (rather than np.random.shuffle) also leaves Ys_Zs itself in its
    # original order for any later cell that reads it.
    # Assumes As_true has one K x K matrix per sequence — verify against the dataset.
    perm = np.random.permutation(Ys_Zs.shape[0])
    Ys_Zs_epoch = Ys_Zs[perm]
    Ys_shuffled = torch.from_numpy(Ys_Zs_epoch[:, :, :2]).float()
    Zs_true_shuffled = torch.from_numpy(Ys_Zs_epoch[:, :, 2:]).float()
    As_true_shuffled = As_true[torch.from_numpy(perm).long()]
    for step in range(Grad_Steps):
        time_start = time.time()
        optimizer.zero_grad()
        batch_data = Ys_shuffled[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_zs = Zs_true_shuffled[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        # Take the transition matrices of the *same* sequences as batch_data.
        batch_As = As_true_shuffled[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        loss, eubo, elbo, ess, kl = ag_sis_stepwise4(enc, prior, prior_mcmc, batch_As, batch_zs, Pi, mu_ks, cov_ks, batch_data, T, D, K, num_particles_rws, num_particles_smc, mcmc_steps, BATCH_SIZE)

        loss.backward()
        optimizer.step()
        EUBOs.append(eubo.item())
        ELBOs.append(elbo.item())
        ESSs.append(ess.item())
        KLs.append(kl.item())
        LOSSs.append(loss.item())
        time_end = time.time()
        print('epoch : %d, step : %d, EUBO : %f, ELBO : %f, KL : %f (%ds)' % (epoch, step, eubo, elbo, kl, time_end - time_start))

epoch : 0, step : 0, EUBO : -398.835114, ELBO : -422.378296, KL : 8.359016 (130s)
epoch : 0, step : 1, EUBO : -400.512695, ELBO : -424.214447, KL : 9.648344 (130s)
epoch : 1, step : 0, EUBO : -399.881439, ELBO : -423.125397, KL : 9.109833 (130s)
epoch : 1, step : 1, EUBO : -398.757721, ELBO : -421.653748, KL : 9.855011 (135s)
epoch : 2, step : 0, EUBO : -395.613678, ELBO : -419.229767, KL : 9.609931 (141s)
epoch : 2, step : 1, EUBO : -404.466400, ELBO : -425.395050, KL : 8.391374 (138s)
epoch : 3, step : 0, EUBO : -400.587646, ELBO : -422.363708, KL : 8.334251 (139s)
epoch : 3, step : 1, EUBO : -397.603333, ELBO : -419.403992, KL : 8.980739 (145s)
epoch : 4, step : 0, EUBO : -400.409454, ELBO : -423.052612, KL : 8.053574 (139s)
epoch : 4, step : 1, EUBO : -397.236633, ELBO : -420.494843, KL : 8.539365 (137s)
epoch : 5, step : 0, EUBO : -399.882446, ELBO : -423.988800, KL : 7.195759 (141s)
epoch : 5, step : 1, EUBO : -399.749268, ELBO : -420.888855, KL : 9.944623 (140s)
epoch : 6, step 

In [None]:
# Plot the training curves and save them to 'results_sis.png'.
# ELBOs is the list the training loop appends to; ELBOs2 is created empty and
# never filled, so passing it would plot an empty ELBO series.
plot_results(EUBOs, ELBOs, ESSs, KLs, 'results_sis.png')

In [None]:
# Single evaluation example (batch dimension of 1 added in front).
# Ys_Zs may have been shuffled in place by the training loop, which would break
# the pairing between row 0 and As_true[0] used by the evaluation cells. Rebuild
# row 0 from the untouched source arrays instead.
first_seq = np.concatenate((Ys[0], Zs_true[0]), axis=-1)
Ys_test = torch.from_numpy(first_seq[:, :2]).float().unsqueeze(0)
Zs_true_test = torch.from_numpy(first_seq[:, 2:]).float().unsqueeze(0)

In [None]:
# Evaluation on the single test sequence (batch size 1): run `mcmc_steps`
# sweeps over `num_particles_rws` candidates, recording a log weight for each
# (sweep, candidate) pair, then derive ESS / EUBO / ELBO summaries from them.
# NOTE(review): this re-assigns the global `mcmc_steps` from the config cell
# (same value there).
mcmc_steps = 5

# Weight buffers indexed as [batch, sweep, candidate].
log_uptonow_weights = torch.zeros((1, mcmc_steps, num_particles_rws))
log_increment_weights = torch.zeros((1, mcmc_steps, num_particles_rws))
# Current state trajectory kept for each candidate: [candidate, batch, T, K].
Zs_candidates = torch.zeros((num_particles_rws, 1, T, K))
# NOTE(review): allocated here but never written or read in this cell.
log_normalizers_candidates = torch.zeros((num_particles_rws, 1))

# Reference quantities computed from the true states of the test sequence.
# conj_posterior / flatz are project helpers; presumably the conjugate posterior
# parameters and the flattened state pairs fed to the encoder — confirm.
conj_posts = conj_posterior(prior.unsqueeze(0), Zs_true_test, T, K, 1)
Z_pairs_true = flatz(Zs_true_test, T, K, 1)

# NOTE(review): assigned but not used later in this cell.
batch_As = As_true[:1]


for m in range(mcmc_steps):
    if m == 0:
        # First sweep: transitions come from initial_trans(prior, ...), not the encoder.
        for l in range(num_particles_rws):
            ## As B * K * K
            As = initial_trans(prior, K, 1)
            Zs, log_weights, log_normalizers = smc_hmm_batch(Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, num_particles_smc, 1)
            ## Z B * T * K
            Z = smc_resamplings(Zs, log_weights, 1)
            Zs_candidates[l] = Z
            # At the first sweep the increment and the cumulative weight are both
            # the value returned by smc_hmm_batch as `log_normalizers`.
            log_increment_weights[:, m, l] = log_normalizers
            log_uptonow_weights[:, m, l] = log_normalizers
            
    else:
        # Later sweeps: the encoder proposes As from the candidate's previous
        # trajectory, then SMC proposes a new trajectory given those As.
        for l in range(num_particles_rws):
            ## z_pairs (B * T-1)-by-(2K)
            Z_pairs = flatz(Zs_candidates[l], T, K, 1)
            variational, As = enc(Z_pairs, prior, 1)

            Zs, log_weights, log_normalizers = smc_hmm_batch(Pi,As, mu_ks, cov_ks, Ys_test, T, D, K, num_particles_smc, 1)
            Z = smc_resamplings(Zs, log_weights, 1)
            Zs_candidates[l] = Z
            
            log_ps = log_joints(prior, Z, Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, 1)
            log_ps_smc = smc_log_joints(Z, Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, 1)
            # Increment = log_joints + SMC normalizer - encoder log-density of As
            #             - smc_log_joints; the two joint terms are detached.
            log_increment_weights[:, m, l] =  log_ps.detach() + log_normalizers - log_qs(variational, As) - log_ps_smc.detach()
            # Cumulative weight = previous sweep's cumulative weight + this increment.
            log_uptonow_weights[:, m, l] = log_uptonow_weights[:, m-1, l] + log_increment_weights[:, m, l]
            
# Gap between log_qs under `conj_posts` and under the encoder's output for the
# true state pairs, both evaluated at the true transition matrix As_true[0].
# NOTE(review): the same two statements are repeated in a later cell.
variational_true, As_notusing = enc(Z_pairs_true, prior, 1)
kls = log_qs(conj_posts, As_true[0].unsqueeze(0)) - log_qs(variational_true, As_true[0].unsqueeze(0))

# Self-normalized weights over candidates, using only the last sweep's increments.
log_final_weights = log_increment_weights[:, -1, :]
weights_rws = torch.exp(log_final_weights - logsumexp(log_final_weights, dim=1).unsqueeze(1)).detach()
# Cumulative weights normalized over candidates at every sweep (not used below).
uptonow_weights = torch.exp(log_uptonow_weights - logsumexp(log_uptonow_weights, dim=-1).unsqueeze(-1)).detach()

# Cumulative log weights at the last sweep (only referenced by the commented-out line below).
log_overall_weights = log_uptonow_weights[:, -1, :]

# Effective sample size 1 / sum(w^2) per batch element, averaged over the batch.
ess = (1. / (weights_rws ** 2 ).sum(1)).mean()
# Weighted sum of the final log weights under the normalized weights.
eubos = torch.mul(weights_rws, log_final_weights).sum(1)
#     elbos =  log_overall_weights.mean(-1)
# Unweighted mean of the final log weights over candidates.
elbos2 = log_final_weights.mean(-1)

In [None]:
# Run SMC with 100 particles under the true transition matrix of the test
# sequence, resample one trajectory, and show its total absolute deviation
# from the true states.
Zs, log_weights, log_normalizers = smc_hmm_batch(
    Pi, As_true[0].unsqueeze(0), mu_ks, cov_ks, Ys_test, T, D, K, 100, 1
)
## Z B * T * K
Z = smc_resamplings(Zs, log_weights, 1)
(Z - Zs_true_test).abs().sum()

In [None]:
# Total absolute deviation from the true states, one value per candidate.
(Zs_candidates.squeeze(1) - Zs_true_test).abs().sum(1).sum(1)

In [None]:
# Encoder output for the true state pairs; the sampled matrix is not used.
variational_true, As_notusing = enc(Z_pairs_true, prior, 1)
# log_qs at the true transition matrix: value under `conj_posts` minus value
# under the encoder's output.
kls = (
    log_qs(conj_posts, As_true[0].unsqueeze(0))
    - log_qs(variational_true, As_true[0].unsqueeze(0))
)

In [None]:
# Display the log_qs gap computed in the previous cell.
kls