In [None]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import *
from objectives import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
# NOTE(review): machine-specific absolute path, appended so the following
# `import probtorch` can be resolved — presumably from a local checkout.
# It will not exist on other machines; consider making it configurable.
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# NOTE(review): `torch` is never imported by name in this cell (`import torch.nn as nn`
# binds only `nn`), so the `torch` used below is presumably supplied by one of the
# star imports above — confirm, or add an explicit `import torch`.
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

In [None]:
## Problem size, shared by every cell below.
## K is the side of the K-by-K transition matrices built later in the notebook.
T, K, D = 30, 3, 2

## Model Parameters
num_particles_rws, mcmc_steps, num_particles_smc = 20, 3, 80
NUM_HIDDEN = 64
NUM_LATENTS = K ** 2   # one encoder output per entry of a K-by-K matrix
NUM_OBS = K * 2        # encoder input width
BATCH_SIZE = 1
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-3
CUDA = False
RESTORE = False

## Run tag: particle/step settings plus the local wall-clock time at which this cell ran.
_run_stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
PATH_ENC = "stepwise_enc-%drws-%dmcmc-%dsmc-enc-%s" % (num_particles_rws, mcmc_steps, num_particles_smc, _run_stamp)

In [None]:
# Calls generate_seq with the T and K set in the configuration cell.
# NOTE(review): dt, Boundary, init_v and noise_cov are not assigned in any
# visible cell of this notebook; presumably they (and generate_seq itself)
# come from the star imports — confirm, otherwise this cell raises NameError
# on a fresh kernel. The return value is discarded here.
generate_seq(T, K, dt, Boundary, init_v, noise_cov)

In [None]:
def _load_float_tensor(path):
    """Read a .npy file and return its contents as a float32 torch tensor."""
    return torch.from_numpy(np.load(path)).float()


# Sequences and true states are kept as numpy arrays (they are concatenated below);
# the remaining arrays are converted to float tensors right away.
Ys = np.load('ball_dataset/sequences.npy')
As_true = _load_float_tensor('ball_dataset/transitions.npy')
Zs_true = np.load('ball_dataset/states.npy')
mu_ks = _load_float_tensor('ball_dataset/means.npy')
cov_ks = _load_float_tensor('ball_dataset/covariances.npy')
Pi = _load_float_tensor('ball_dataset/init.npy')

prior = initial_trans_prior(K)

# Glue observations and true states together along the last axis so both can be
# sliced (or shuffled) with a single index.
Ys_Zs = np.concatenate([Ys, Zs_true], axis=-1)

In [None]:
# Ys_test = torch.from_numpy(Ys[0]).float().unsqueeze(0)
# Zs_test = torch.from_numpy(Zs_true[0]).float().unsqueeze(0)
# As = As_true[0].unsqueeze(0)
# Accus = []
# for i in range(100):
#     Zs, log_weights, log_normalizers = smc_hmm_batch(Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, num_particles_smc, 1)
#     ## Z B * T * K
#     Z = smc_resamplings(Zs, log_weights, 1)
#     ## compute accuracy
#     accuracy = torch.abs(Zs_test - Z).sum().item() / 2.0
#     Accus.append(accuracy)

In [None]:
# plt.plot(Accus)

In [None]:
class Encoder(nn.Module):
    """Amortized Dirichlet proposal over the rows of a K-by-K transition matrix.

    A one-hidden-layer network maps each input row to ``num_latents`` softmax
    scores. The scores are summed over the rows belonging to one sequence,
    reshaped to K-by-K and added to ``prior`` to give Dirichlet concentrations.

    Args:
        num_obs: width of each input row.
        num_hidden: width of the hidden layer.
        num_latents: number of outputs per row; must equal ``K * K`` for the
            reshape in ``forward`` to succeed.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as Encoder is subclassed, because self.__class__
        # is then the subclass rather than Encoder.
        super(Encoder, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """Compute Dirichlet concentrations and draw one transition matrix per batch item.

        Args:
            obs: 2-D tensor with ``num_obs`` columns whose rows are grouped by
                sequence, i.e. ``batch_size * steps`` rows in total.
            prior: tensor added to the ``(batch_size, K, K)`` concentrations;
                it must broadcast against that shape.
            batch_size: number of sequences packed into ``obs``.

        Returns:
            ``(alphas, As)``, both of shape ``(batch_size, K, K)``. ``alphas``
            are the concentrations (differentiable w.r.t. the encoder weights).
            ``As`` holds one non-reparameterized Dirichlet sample per row, so
            each row of ``As`` sums to one.
        """
        hidden = self.enc_hidden(obs)
        per_step = F.softmax(self.latent_dir(hidden), -1)
        # -1 infers the number of rows per sequence from obs, so the module no
        # longer depends on the notebook-level T (it equals T-1 for the inputs
        # built elsewhere in this notebook).
        alphas = per_step.view(batch_size, -1, K*K).sum(1).view(batch_size, K, K) + prior
        # The last axis is the Dirichlet event axis, so a single batched draw
        # samples every row of every matrix. The result is created on alphas'
        # device and dtype, unlike a torch.zeros buffer which always starts on
        # the CPU.
        As = Dirichlet(alphas).sample()
        return alphas, As

In [None]:
def initialize():
    """Create a fresh Encoder (moved to the GPU when CUDA is set) and its Adam optimizer.

    Returns:
        (encoder, optimizer) — the optimizer covers all encoder parameters and
        uses LEARNING_RATE as its step size.
    """
    encoder = Encoder()
    if CUDA:
        encoder.cuda()
    adam = torch.optim.Adam(encoder.parameters(), lr=LEARNING_RATE)
    return encoder, adam


enc, optimizer = initialize()

In [None]:
# Training loop: NUM_EPOCHS epochs of Grad_Steps Adam updates each. Every step
# calls ag_sis_adaptive, backpropagates the returned eubo, and records the four
# returned scalars in the lists below.
EUBOs = []
ELBOs = []
ESSs = []
KLs = []
# NOTE(review): LOSSs is never appended to (its only append below is commented out).
LOSSs = []

# K-by-K matrix of ones with 2.0 on the diagonal; handed to ag_sis_adaptive as prior_mcmc.
prior_mcmc = torch.ones((K, K)).float()
for k in range(K):
    prior_mcmc[k, k] = 2.0

# Steps per epoch: number of sequences in Ys divided by BATCH_SIZE, truncated to an int.
Grad_Steps = int((Ys.shape[0] / BATCH_SIZE))
for epoch in range(NUM_EPOCHS):
#     np.random.shuffle(Ys_Zs)
    # Ys_Zs is Ys followed by Zs_true on the last axis, so columns [:2] are the
    # observation part and [2:] the state part — this assumes Ys has exactly 2
    # columns (D = 2); TODO confirm.
    # NOTE(review): with the shuffle above and the per-step slicing below both
    # commented out, only the first BATCH_SIZE sequences are ever used: every
    # step of every epoch trains on the same batch.
    Ys_shuffled = torch.from_numpy(Ys_Zs[:BATCH_SIZE, :, :2]).float()
    Zs_true_shuffled = torch.from_numpy(Ys_Zs[:BATCH_SIZE, :, 2:]).float()
    for step in range(Grad_Steps):
        time_start = time.time()
        optimizer.zero_grad()
#         batch_data = Ys_shuffled[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
#         batch_zs = Zs_true_shuffled[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_data = Ys_shuffled
        batch_zs = Zs_true_shuffled
        # True transition matrices for the same first BATCH_SIZE sequences.
        batch_As = As_true[:BATCH_SIZE]
        eubo, elbo, ess, kl = ag_sis_adaptive(enc, prior, prior_mcmc, batch_As, batch_zs, Pi, mu_ks, cov_ks, batch_data, T, D, K, num_particles_rws, num_particles_smc, mcmc_steps, BATCH_SIZE)
       
        # Only eubo is backpropagated; elbo, ess and kl are recorded for monitoring.
        eubo.backward()
        optimizer.step()
        EUBOs.append(eubo.item())
        ELBOs.append(elbo.item())
        ESSs.append(ess.item())
        KLs.append(kl.item())
#         LOSSs.append(loss.item())
        time_end = time.time()
        # %d truncates the elapsed time to whole seconds.
        print('epoch : %d, step : %d, EUBO : %f, ELBO : %f, KL : %f, ESS : %f (%ds)' % (epoch, step, eubo, elbo, kl, ess, time_end - time_start))

In [None]:
# Plot the per-step traces collected by the training loop. 'results_sis.png' is
# the last argument — presumably the output file name; confirm in plots.
plot_results(EUBOs, ELBOs, ESSs, KLs, 'results_sis.png')

In [None]:
# Take the first sequence as the evaluation example. Its leading two columns go
# to Ys_test and the remaining columns to Zs_true_test; each gets a leading
# batch axis of size 1.
_first_seq = Ys_Zs[0]
Ys_test = torch.from_numpy(_first_seq[:, :2]).float().unsqueeze(0)
Zs_true_test = torch.from_numpy(_first_seq[:, 2:]).float().unsqueeze(0)

In [None]:
# Evaluation on the single test sequence: run mcmc_steps sweeps, each over
# num_particles_rws particles, tracking per-sweep and accumulated log weights.
# NOTE(review): this rebinds the notebook-level mcmc_steps (3 in the config cell).
# Re-running the training cell after this one would use 5, and PATH_ENC still
# encodes the old value.
mcmc_steps = 5

# Log-weight buffers indexed [batch (=1), sweep, particle].
log_uptonow_weights = torch.zeros((1, mcmc_steps, num_particles_rws))
log_increment_weights = torch.zeros((1, mcmc_steps, num_particles_rws))
# Current state trajectory of each particle, indexed [particle, batch (=1), T, K].
Zs_candidates = torch.zeros((num_particles_rws, 1, T, K))
# NOTE(review): allocated but never written or read in this cell.
log_normalizers_candidates = torch.zeros((num_particles_rws, 1))

# Quantities computed from the true states; they are used after the loop for `kls`.
conj_posts = conj_posterior(prior.unsqueeze(0), Zs_true_test, T, K, 1)
Z_pairs_true = flatz(Zs_true_test, T, K, 1)

# NOTE(review): batch_As is assigned here but not used in the rest of this cell.
batch_As = As_true[:1]


for m in range(mcmc_steps):
    if m == 0:
        # First sweep: transition matrices come from initial_trans(prior, ...),
        # not the encoder, and the weight is just the SMC log normalizer.
        for l in range(num_particles_rws):
            ## As B * K * K
            As = initial_trans(prior, K, 1)
            Zs, log_weights, log_normalizers = smc_hmm_batch(Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, num_particles_smc, 1)
            ## Z B * T * K
            Z = smc_resamplings(Zs, log_weights, 1)
            Zs_candidates[l] = Z
            log_increment_weights[:, m, l] = log_normalizers
            log_uptonow_weights[:, m, l] = log_normalizers
            
    else:
        # Later sweeps: the encoder proposes As from this particle's trajectory of
        # the previous sweep, then SMC draws a new trajectory given that As.
        for l in range(num_particles_rws):
            ## z_pairs (B * T-1)-by-(2K)
            Z_pairs = flatz(Zs_candidates[l], T, K, 1)
            variational, As = enc(Z_pairs, prior, 1)

            Zs, log_weights, log_normalizers = smc_hmm_batch(Pi,As, mu_ks, cov_ks, Ys_test, T, D, K, num_particles_smc, 1)
            Z = smc_resamplings(Zs, log_weights, 1)
            Zs_candidates[l] = Z
            
            log_ps = log_joints(prior, Z, Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, 1)
            log_ps_smc = smc_log_joints(Z, Pi, As, mu_ks, cov_ks, Ys_test, T, D, K, 1)
            # Increment = log_joints + SMC log normalizer - log_qs(encoder output, As)
            #             - smc_log_joints. The log_qs term is not detached here.
            log_increment_weights[:, m, l] =  log_ps.detach() + log_normalizers - log_qs(variational, As) - log_ps_smc.detach()
            # Accumulated weight = previous sweep's accumulated weight + this increment.
            log_uptonow_weights[:, m, l] = log_uptonow_weights[:, m-1, l] + log_increment_weights[:, m, l]
            
# Difference of log_qs under the conjugate posterior and under the encoder's
# output for the true states, both evaluated at the first true transition matrix.
# The As sampled by this encoder call is discarded.
variational_true, As_notusing = enc(Z_pairs_true, prior, 1)
kls = log_qs(conj_posts, As_true[0].unsqueeze(0)) - log_qs(variational_true, As_true[0].unsqueeze(0))

# Self-normalize the last sweep's increment weights over the particle axis
# (assuming logsumexp is the usual log-sum-exp reduction — TODO confirm).
log_final_weights = log_increment_weights[:, -1, :]
weights_rws = torch.exp(log_final_weights - logsumexp(log_final_weights, dim=1).unsqueeze(1)).detach()
# NOTE(review): uptonow_weights and log_overall_weights are computed but not read
# anywhere else in the visible notebook.
uptonow_weights = torch.exp(log_uptonow_weights - logsumexp(log_uptonow_weights, dim=-1).unsqueeze(-1)).detach()

log_overall_weights = log_uptonow_weights[:, -1, :]

# 1 / sum(w^2) over particles, averaged over the batch axis.
# NOTE(review): this rebinds `ess`, which the training loop also assigns.
ess = (1. / (weights_rws ** 2 ).sum(1)).mean()
# Weighted sum of the last-sweep log weights under the normalized weights.
eubos = torch.mul(weights_rws, log_final_weights).sum(1)
#     elbos =  log_overall_weights.mean(-1)
# Unweighted mean of the last-sweep log weights over particles.
elbos2 = log_final_weights.mean(-1)

In [None]:
# NOTE(review): same call and same inputs as the earlier plot_results cell — the
# evaluation cell does not append to EUBOs/ELBOs/ESSs/KLs — so this replots the
# training traces. Only the last argument differs, and 'filename.png' looks like
# a placeholder.
plot_results(EUBOs, ELBOs, ESSs, KLs, 'filename.png')

In [None]:
# Per-particle disagreement with the true states: drop the size-1 batch axis of
# Zs_candidates, subtract the true states (broadcast over particles), and sum the
# absolute differences over the two remaining axes, giving one value per particle.
# Presumably the states are one-hot, so each mismatched time step contributes 2 —
# TODO confirm.
torch.abs(Zs_candidates.squeeze(1) - Zs_true_test).sum(1).sum(1)

In [None]:
# NOTE(review): these two statements repeat, verbatim, the `kls` computation that
# already runs after the sweep loop in the evaluation cell. Re-running them calls
# the encoder again (drawing and discarding another As sample) and rebinds
# variational_true and kls.
variational_true, As_notusing = enc(Z_pairs_true, prior, 1)
kls = log_qs(conj_posts, As_true[0].unsqueeze(0)) - log_qs(variational_true, As_true[0].unsqueeze(0))

In [None]:
# Display the most recently computed kls as the cell output.
kls