In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import *
from vimco import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Record the library versions and CUDA availability this notebook was run with.
print(
    'probtorch:', probtorch.__version__,
    'torch:', torch.__version__,
    'cuda:', torch.cuda.is_available(),
)

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
# Problem sizes.
# T and K are read as globals by Encoder.forward (it reshapes to
# (batch, T-1, K*K) and then (batch, K, K)); D only enters through NUM_OBS.
T, K, D = 50, 4, 2

## Model Parameters
num_samples = 20
num_particles_smc = 100

# Encoder layer widths (these are the default arguments of Encoder.__init__).
NUM_HIDDEN = 64
NUM_LATENTS = K ** 2   # one entry per cell of a K x K matrix
NUM_OBS = D * 2

# Optimisation settings.
BATCH_SIZE = 20
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-3

# Run-time switches.
CUDA = False
RESTORE = False

# NOTE(review): the disabled line below uses `num_particles_rws` and
# `mcmc_steps`, neither of which is defined in the cells visible here;
# define them before re-enabling it.
# PATH_ENC = "stepwise_enc-%drws-%dmcmc-%dsmc-enc-%s" % (num_particles_rws, mcmc_steps, num_particles_smc, datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S'))

In [3]:
def _load_float_tensor(path):
    """Read a ``.npy`` file and return its contents as a float32 torch tensor."""
    return torch.from_numpy(np.load(path)).float()


# Ys and Zs_true stay NumPy arrays; the remaining files become float tensors.
# (What each array holds is only indicated by its file name here.)
Ys = np.load('dataset/sequences.npy')
As_true = _load_float_tensor('dataset/transitions.npy')
Zs_true = np.load('dataset/states.npy')
mu_ks = _load_float_tensor('dataset/means.npy')
cov_ks = _load_float_tensor('dataset/covariances.npy')
Pi = _load_float_tensor('dataset/init.npy')

# `initial_trans_prior` is not defined in this notebook; it presumably comes
# from one of the star imports at the top.
prior = initial_trans_prior(K)

# Glue Ys and Zs_true together along the last axis so that a single in-place
# shuffle of Ys_Zs keeps each sequence paired with its states.
Ys_Zs = np.concatenate((Ys, Zs_true), axis=-1)

In [4]:
class Encoder(nn.Module):
    """Maps inputs to Dirichlet concentrations for the rows of a K x K matrix.

    The network is ``Linear(num_obs, num_hidden) -> Tanh -> Linear(num_hidden,
    num_latents)``. ``forward`` additionally reads the module-level globals
    ``T`` and ``K`` when reshaping, so ``num_latents`` has to equal ``K * K``
    for the reshape to succeed.
    """

    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super(): the previous ``super(self.__class__, self)``
        # recurses forever as soon as Encoder is subclassed, because
        # ``self.__class__`` is then the subclass.
        super().__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """Compute Dirichlet concentrations and draw one matrix per batch item.

        Args:
            obs: input tensor whose last dimension is ``num_obs`` and whose
                encoded form can be viewed as ``(batch_size, T-1, K*K)``.
            prior: tensor added to the pooled concentrations; it must
                broadcast against ``(batch_size, K, K)``.
            batch_size: number of sequences in ``obs``.

        Returns:
            ``(alphas, As)``: ``alphas`` is the ``(batch_size, K, K)``
            concentration tensor; ``As`` has the same shape and holds one
            Dirichlet sample per row of ``alphas`` (drawn without gradient).
        """
        hidden = self.enc_hidden(obs)
        # Softmax over the K*K outputs, summed over the T-1 steps of each
        # sequence, then laid out as a K x K matrix and offset by the prior.
        weights = F.softmax(self.latent_dir(hidden), -1)
        alphas = weights.view(batch_size, T-1, K*K).sum(1).view(batch_size, K, K) + prior
        # Dirichlet treats the last dimension as the event dimension, so a
        # single call samples every row of every matrix. This replaces the
        # per-row loop that wrote into a CPU-allocated buffer, and keeps the
        # samples on the same device and dtype as ``alphas``.
        As = Dirichlet(alphas).sample()
        return alphas, As

In [8]:
def initialize():
    """Build a fresh Encoder and an Adam optimizer over its parameters.

    The encoder is moved to the GPU when the module-level ``CUDA`` flag is
    set; the optimizer uses ``LEARNING_RATE``.

    Returns:
        ``(encoder, optimizer)``.
    """
    model = Encoder()
    if CUDA:
        model.cuda()
    adam = torch.optim.Adam(list(model.parameters()), lr=LEARNING_RATE)
    return model, adam


enc, optimizer = initialize()

In [9]:
# Per-step training diagnostics; only filled once the disabled update step
# at the bottom of the loop is re-enabled.
ELBOs = []
ESSs = []
KLs = []
Grads = []

# K x K tensor with 1 everywhere and 2 on the diagonal.
prior_mcmc = torch.ones((K, K)).float() + torch.eye(K)

# Ys_Zs was built as np.concatenate((Ys, Zs_true), axis=-1), so the boundary
# between observations and states is the width of Ys. Deriving it here
# (instead of hard-coding 2) keeps the split correct for any observation size.
obs_dim = Ys.shape[-1]

# Number of full mini-batches per epoch; a trailing partial batch is dropped.
Grad_Steps = Ys.shape[0] // BATCH_SIZE
for epoch in range(NUM_EPOCHS):
    # In-place shuffle along the first axis; observations and states move
    # together because they live in the same array.
    np.random.shuffle(Ys_Zs)
    Ys_shuffled = torch.from_numpy(Ys_Zs[:, :, :obs_dim]).float()
    Zs_true_shuffled = torch.from_numpy(Ys_Zs[:, :, obs_dim:]).float()
    for step in range(Grad_Steps):
        time_start = time.time()
        optimizer.zero_grad()
#         batch_data = Ys_shuffled
#         batch_zs = Zs_true_shuffled
        batch = slice(step * BATCH_SIZE, (step + 1) * BATCH_SIZE)
        batch_data = Ys_shuffled[batch]
        batch_zs = Zs_true_shuffled[batch]
        # NOTE(review): the update step below is disabled, so this loop
        # currently only shuffles and slices; afterwards `batch_data` and
        # `batch_zs` hold the last mini-batch, which later cells reuse.
        # The disabled call also needs `mcmc_steps`, which is not defined in
        # the cells visible here.
#         batch_As = As_true[:BATCH_SIZE]
#         gradient, elbo, ess, kl = ag_mcmc_vimco(enc, prior, prior_mcmc, batch_zs, Pi, mu_ks, cov_ks, batch_data, T, D, K, num_samples, num_particles_smc, mcmc_steps, BATCH_SIZE)
#         loss = - gradient
#         loss.backward()
#         optimizer.step()
#         ELBOs.append(elbo.item())
#         ESSs.append(ess.item())
#         KLs.append(kl.item())
#         Grads.append(loss.item())
#         time_end = time.time()
#         print('epoch : %d, step : %d, ELBO : %f, KL : %f (%ds)' % (epoch, step, elbo, kl, time_end - time_start))

In [11]:
# Build the encoder input from `batch_zs`, i.e. the last mini-batch of states
# left behind by the training-loop cell.
# NOTE(review): `flatz` is not defined in this notebook (presumably it comes
# from a star import), so its output shape cannot be checked here. To be
# accepted by `enc`, the result needs a last dimension of NUM_OBS and must be
# viewable as (BATCH_SIZE, T-1, K*K) after encoding -- TODO confirm.
Z_pairs_true = flatz(batch_zs, T, K, BATCH_SIZE)

In [13]:
# Run the encoder once: `variational_new` receives the (BATCH_SIZE, K, K)
# concentrations, `As_notusing` the sampled matrices (unused afterwards, as
# its name says).
variational_new, As_notusing = enc(
    obs=Z_pairs_true,
    prior=prior_mcmc,
    batch_size=BATCH_SIZE,
)

In [18]:
# Manual replay of the body of Encoder.forward, step by step, for inspection.
# NOTE(review): this adds `prior`, whereas the encoder call in the previous
# cell was given `prior_mcmc`, so `alphas` need not equal `variational_new`
# -- confirm that this is intended.
hidden = enc.enc_hidden(Z_pairs_true)
alphas = (
    F.softmax(enc.latent_dir(hidden), dim=-1)
    .view(BATCH_SIZE, T - 1, K * K)
    .sum(dim=1)
    .view(BATCH_SIZE, K, K)
    + prior
)

In [20]:
# Draw one sample of row 0 of the K x K matrix for every batch item.
# With the current config the result has shape (BATCH_SIZE, K) = (20, 4).
# NOTE(review): the output saved under this cell has 3 columns and more than
# 20 rows, so it came from a different K / BATCH_SIZE -- re-run the notebook
# from a fresh kernel to refresh it.
Dirichlet(alphas[:, 0]).sample()

tensor([[0.5570, 0.1388, 0.3042],
        [0.6438, 0.1909, 0.1652],
        [0.6278, 0.2709, 0.1013],
        [0.4039, 0.3735, 0.2226],
        [0.3093, 0.4136, 0.2770],
        [0.2902, 0.1720, 0.5378],
        [0.5810, 0.3837, 0.0353],
        [0.3962, 0.3567, 0.2472],
        [0.2537, 0.6450, 0.1013],
        [0.5471, 0.3459, 0.1070],
        [0.4156, 0.3854, 0.1989],
        [0.3869, 0.2875, 0.3256],
        [0.3625, 0.4143, 0.2232],
        [0.5628, 0.1495, 0.2878],
        [0.3259, 0.4018, 0.2723],
        [0.5278, 0.2365, 0.2357],
        [0.5634, 0.2044, 0.2322],
        [0.4361, 0.1766, 0.3873],
        [0.6356, 0.1236, 0.2409],
        [0.3140, 0.2250, 0.4610],
        [0.6709, 0.1881, 0.1410],
        [0.3744, 0.4301, 0.1954],
        [0.6564, 0.0670, 0.2766],
        [0.5846, 0.3144, 0.1010],
        [0.6007, 0.1655, 0.2338],
        [0.4873, 0.4506, 0.0621],
        [0.3577, 0.3285, 0.3138],
        [0.5828, 0.3269, 0.0903],
        [0.5137, 0.1154, 0.3709],
        [0.493