In [None]:
## sys, time and torch are used directly in this notebook; import them explicitly
## instead of relying on whatever the star imports below happen to re-export
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from util_data import *
from util_hmm_variational_gibbs import *
from util_plots import *
from scipy.stats import invwishart, dirichlet
from torch.distributions.dirichlet import Dirichlet
sys.path.append('/home/hao/Research/probtorch/')
from probtorch.util import expand_inputs
import probtorch
## report the library versions in use and whether CUDA is available
_env_report = (
    'probtorch:', probtorch.__version__,
    'torch:', torch.__version__,
    'cuda:', torch.cuda.is_available(),
)
print(*_env_report)

In [None]:
## ---- Dataset parameters ----
num_series = 1
## T: sequence length handed to generate_seq; K: number of clusters / hidden states;
## D: forwarded to smc_hmm and rws (presumably the observation dimension -- confirm)
T, K, D = 50, 4, 2
## dt rescales the unit initial velocity; Boundary is forwarded to generate_seq
dt, Boundary = 10, 30
## multiplies the 2x2 identity to form the observation noise covariance
noise_ratio = 0.5

## ---- Model parameters ----
num_particles_rws = num_particles_smc = 50
mcmc_steps = 2
NUM_HIDDEN = 128
## one Dirichlet concentration per entry of the K x K transition matrix
NUM_LATENTS = K ** 2
NUM_OBS = K * 2

## ---- Optimization ----
NUM_EPOCHS, LEARNING_RATE = 80, 0.001
CUDA = False

In [None]:
def _float_tensor(array):
    """Wrap a numpy array as a float32 torch tensor."""
    return torch.from_numpy(array).float()

## isotropic observation noise
noise_cov = np.eye(2) * noise_ratio

## random direction with a random sign per coordinate, rescaled to length dt
init_v = np.random.random(2)
init_v = init_v * np.random.choice([-1, 1], size=2)
v_norm = np.sum(init_v ** 2) ** 0.5  ## Euclidean norm of the initial velocity
init_v /= v_norm
init_v *= dt  ## the velocity now lies on a circle of radius dt

STATE, Disp, A_true, Zs_true = generate_seq(T, dt, Boundary, init_v, noise_cov)

## true global variables: the same covariance for every cluster, and means equal to
## |init_v| with each of the four sign patterns
cov_true = np.tile(noise_cov, (K, 1, 1))
dirs = np.array([[1, 1], [1, -1], [-1, -1], [-1, 1]])
mu_true = np.tile(np.abs(init_v), (K, 1)) * dirs
Pi_true = np.full(K, 1 / K)  ## uniform initial state distribution

plot_clusters(Disp, mu_true, cov_true, K)

## torch copies of everything the samplers consume
Zs_true = _float_tensor(Zs_true)
cov_ks = _float_tensor(cov_true)
mu_ks = _float_tensor(mu_true)
Pi = _float_tensor(Pi_true)
A_init = _float_tensor(A_true)
## prior of A
alpha_trans_0 = initial_trans_prior(K)
## Y
Y = _float_tensor(Disp)
print(mu_true)

In [None]:
## run SMC with the current A_init as the transition matrix, then resample the
## particles and plot them against the true states
A_samples = A_init
Zs, log_weights, log_normalizer = smc_hmm(
    Pi, A_samples, mu_ks, cov_ks,
    Y, T, D, K,
    num_particles_smc,
)
Z_ret = resampling_smc(Zs, log_weights)
plot_smc_sample(Zs_true, Z_ret)

In [None]:
## encoder that returns both the variational Dirichlet parameters and samples of the
## transition matrix drawn from them
class Encoder(nn.Module):
    """Map observations to a row-wise Dirichlet over a square transition matrix.

    Args:
        num_obs: size of each input feature vector.
        num_hidden: width of the two hidden layers.
        num_latents: number of Dirichlet concentrations produced; must be a
            perfect square (one value per entry of a K x K matrix).

    Raises:
        ValueError: if ``num_latents`` is not a perfect square.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        ## name the class explicitly: super(self.__class__, self) recurses forever
        ## as soon as this class is subclassed
        super(Encoder, self).__init__()
        num_states = int(round(num_latents ** 0.5))
        if num_states * num_states != num_latents:
            raise ValueError('num_latents must be a perfect square, got %d' % num_latents)
        ## side length of the transition matrix, derived from the layer size so that
        ## forward() does not depend on a notebook-level global
        self.num_states = num_states
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior_sum, T):
        """Compute Dirichlet concentrations and sample one transition matrix.

        Args:
            obs: tensor whose last dimension is ``num_obs``; the first dimension is
                summed out, so a 2-D (N, num_obs) input is expected.
            prior_sum: added to ``T - 1`` to set the total concentration mass.
            T: sequence length (``T - 1`` is presumably the number of transitions).

        Returns:
            (latents_dirs_norm, A_samples): the (K, K) concentrations rescaled so that
            all entries sum to ``prior_sum + T - 1``, and a (K, K) matrix whose rows
            are drawn from Dirichlet(latents_dirs_norm[k]).
        """
        num_states = self.num_states
        hidden = self.enc_hidden(obs)
        ## exp keeps every concentration positive; contributions are pooled over rows of obs
        latents_dirs = torch.exp(self.latent_dir(hidden)).sum(0).view(num_states, num_states)
        latents_dirs_norm = latents_dirs / latents_dirs.sum() * (prior_sum + T-1)
        ## a (K, K) concentration is treated as K independent Dirichlets (one per row),
        ## so one batched draw replaces the per-row loop; .sample() carries no gradient
        A_samples = Dirichlet(latents_dirs_norm).sample()
        return latents_dirs_norm, A_samples

In [None]:
def initialize():
    """Create a fresh Encoder (moved to the GPU when CUDA is set) and its Adam optimizer.

    Returns:
        (encoder, optimizer) with the optimizer covering all encoder parameters at
        learning rate LEARNING_RATE.
    """
    model = Encoder()
    if CUDA:
        model.cuda()
    adam = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    return model, adam

enc, optimizer = initialize()

In [None]:
## per-epoch diagnostics, stored without autograd history so they stay cheap to keep
## and can be handed to numpy / matplotlib directly
KLs = []
EUBOs = []
log_p_conds = []
log_qs = []
ESSs = []

def _to_number(value):
    """Strip autograd history from a value before logging it.

    One-element tensors become Python numbers, larger tensors are detached,
    and non-tensor values are returned unchanged.
    """
    if torch.is_tensor(value):
        return value.item() if value.numel() == 1 else value.detach()
    return value

for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    optimizer.zero_grad()

    ## draw a fresh training sequence: random initial velocity rescaled to length dt
    init_v = np.random.random(2) * np.random.choice([-1,1], size=2)
    v_norm = ((init_v ** 2 ).sum()) ** 0.5 ## norm of the initial velocity
    init_v = init_v / v_norm * dt ## to make the velocity lie on a circle
    ## random length in [30, 49]; note this overwrites the notebook-level T
    T = np.random.randint(30, 50)
    STATE, Disp, A_true, Zs_true = generate_seq(T, dt, Boundary, init_v, noise_cov)

    ## true global variables
    cov_true = np.tile(noise_cov, (K, 1, 1))
    dirs = np.array([[1, 1], [1, -1], [-1, -1], [-1, 1]])
    mu_true = np.tile(np.absolute(init_v), (K, 1)) * dirs
    Pi_true = np.ones(K) * (1/K)
    cov_ks = torch.from_numpy(cov_true).float()
    mu_ks = torch.from_numpy(mu_true).float()
    Pi = torch.from_numpy(Pi_true).float()
    ## create the prior first so the initial transition matrix is built from the same
    ## prior that is passed to rws in this iteration
    alpha_trans_0 = initial_trans_prior(K)
    A_init = initial_trans(alpha_trans_0, K)
    Y = torch.from_numpy(Disp).float()

    enc, loss_infer, eubo, kl, ess, latents_dirs, Z_ret = rws(enc, A_init, alpha_trans_0, Pi, mu_ks, cov_ks, Y, T, D, K, num_particles_rws, num_particles_smc, mcmc_steps)
    eubo.backward()
    optimizer.step()

    ## log plain numbers rather than live tensors: appending eubo / -loss_infer
    ## directly would keep references into every epoch's autograd graph
    eubo_value = _to_number(eubo)
    kl_value = _to_number(kl)
    log_q = -_to_number(loss_infer)
    KLs.append(kl_value)
    EUBOs.append(eubo_value)
    ESSs.append(_to_number(ess))
    log_qs.append(log_q)

    time_end = time.time()
    print('epoch : %d, eubo : %f, log_q : %f, KL : %f (%ds)' % (epoch, eubo_value, log_q, kl_value, time_end - time_start))

In [None]:
## compare the Dirichlet parameters returned by the last rws call with the
## conjugate posterior: prior plus pairwise(...) summed over its first axis
## (presumably transition counts along the true state sequence -- confirm)
learned_dicichlet_post = latents_dirs
true_states = torch.from_numpy(Zs_true).float()
true_dirichlet_post = alpha_trans_0 + pairwise(true_states, T).sum(0)

print('variational : ')
print(learned_dicichlet_post)
print('conjugate posterior :')
print(true_dirichlet_post)

plot_dirs(
    learned_dicichlet_post.detach().numpy(),
    true_dirichlet_post.detach().numpy(),
    vmax=15,
)

In [None]:
## logged ESS divided by the number of RWS particles, one point per epoch
plt.plot(np.divide(np.array(ESSs), num_particles_rws))

In [None]:
## total concentration mass of the learned Dirichlet parameters
torch.sum(learned_dicichlet_post)

In [None]:
## total concentration mass of the conjugate posterior, for comparison with the learned one
torch.sum(true_dirichlet_post)