In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from smc import *
from plots import *
from torch.distributions.dirichlet import Dirichlet
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
# Problem sizes shared by the cells below.
T = 50  # sequence length; the encoder consumes the T-1 consecutive-state pairs of a sequence
K = 3   # size of the state vectors; transition matrices are K x K
D = 2   # forwarded to smc_hmm_batch; presumably the observation dimensionality — TODO confirm

## Model Parameters
rws_samples = 2       # outer samples per batch element in eubo_hmm_rws
smc_samples = 10      # forwarded to smc_hmm_batch; presumably the number of SMC particles — TODO confirm
steps = 5             # number of sweeps performed by eubo_hmm_rws
NUM_HIDDEN = 64       # width of the encoder's hidden layer
NUM_LATENTS = K*K     # encoder output size: one entry per transition-matrix cell
NUM_OBS =  2 * K      # encoder input size: a pair of K-dimensional state vectors
BATCH_SIZE = 3        # sequences per gradient step
NUM_EPOCHS = 1000     # passes over the dataset in the training loop
LEARNING_RATE = 1e-4  # Adam learning rate
CUDA = False          # when True, initialize() moves the encoder to the GPU

In [3]:
# Load the pre-generated HMM dataset from hmm_dataset/ as float32 tensors.
# NOTE(review): the array shapes are not visible here; the per-file descriptions
# below are inferred from the file names and from how the tensors are used.
Ys = torch.from_numpy(np.load('hmm_dataset/sequences.npy')).float()  # observed sequences (batched and passed to eubo_hmm_rws)
As_true = torch.from_numpy(np.load('hmm_dataset/transitions.npy')).float()  # presumably ground-truth transition matrices — TODO confirm
Zs = torch.from_numpy(np.load('hmm_dataset/states.npy')).float()  # presumably ground-truth latent states — TODO confirm
mus = torch.from_numpy(np.load('hmm_dataset/means.npy')).float()  # forwarded to smc_hmm_batch
covs = torch.from_numpy(np.load('hmm_dataset/covariances.npy')).float()  # forwarded to smc_hmm_batch
Pi = torch.from_numpy(np.load('hmm_dataset/init.npy')).float()  # forwarded to smc_hmm_batch
# Number of sequences, taken from the leading dimension of the states tensor.
num_seqs = Zs.shape[0]

In [4]:
class Encoder(nn.Module):
    """
    Amortized proposal over transition matrices.

    Each input row (one pair of consecutive state vectors, num_obs wide) is
    mapped through a tanh hidden layer to a softmax over the num_latents = K*K
    transition cells. The per-pair softmaxes of one sequence are summed and
    added to the prior to give Dirichlet concentrations, from which one
    K x K transition matrix per batch element is sampled.

    Args:
        num_obs: width of one input row.
        num_hidden: width of the hidden layer.
        num_latents: width of the output layer (K*K).
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: super(self.__class__, self) resolves to the
        # subclass when Encoder is subclassed and then recurses forever.
        super(Encoder, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh())
        self.latent_dir = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, prior, batch_size):
        """
        Args:
            obs: rows of consecutive-state pairs, batch-major, so that the rows
                belonging to one sequence are contiguous.
            prior: Dirichlet concentrations added to the pooled encoder output;
                must broadcast against (batch_size, K, K).
            batch_size: number of sequences represented in obs.

        Returns:
            alphas: (batch_size, K, K) Dirichlet concentrations.
            As: (batch_size, K, K) transition matrices sampled from
                Dirichlet(alphas); each row sums to one.
        """
        hidden = self.enc_hidden(obs)
        pair_probs = F.softmax(self.latent_dir(hidden), -1)
        # Pool over the pairs of each sequence. The pair count is inferred
        # from obs rather than read from the global T.
        pooled = pair_probs.view(batch_size, -1, K*K).sum(1)
        alphas = pooled.view(batch_size, K, K) + prior
        # The last axis is the Dirichlet event axis, so a single batched draw
        # samples every row of every matrix and keeps As on alphas' device.
        As = Dirichlet(alphas).sample()
        return alphas, As

In [5]:
def flatz(Z, T, K, batch_size):
    """
    Pair every state with its successor and flatten the pairs into rows.

    Z is indexed batch x time x K. Row (b * (T-1) + t) of the result is the
    concatenation [Z[b, t], Z[b, t+1]], giving a (batch_size * (T-1), 2*K) tensor.
    """
    current_states = Z[:, :T-1, :]
    next_states = Z[:, 1:, :]
    # Stacking on a new axis 2 yields batch x (T-1) x 2 x K.
    pairs = torch.stack((current_states, next_states), 2)
    return pairs.view(batch_size * (T-1), 2*K)

def initial_trans(prior, K, batch_size):
    """
    Sample batch_size transition matrices from the prior.

    Row k of every matrix is drawn from Dirichlet(prior[k]); the result is a
    float32 tensor of shape (batch_size, K, K).
    """
    transitions = torch.zeros((batch_size, K, K)).float()
    for row in range(K):
        row_prior = Dirichlet(prior[row])
        transitions[:, row, :] = row_prior.sample((batch_size,))
    return transitions

def adapt_resampling(Zs, log_weights, rws_samples):
    """
    Resample whole candidate trajectories in proportion to their weights.

    Zs S-B-T-K          candidate trajectories
    log_weights S-B     unnormalized log weight of each candidate
    reweights B-S       normalized weights (per batch element, over candidates)

    For every batch element, rws_samples ancestor indices are drawn from the
    categorical distribution over the S candidates, and the selected
    trajectories are returned as a rws_samples-B-T-K tensor.
    """
    num_steps, num_states = Zs.shape[2], Zs.shape[3]
    # Softmax over the candidate axis equals exp(log_w - logsumexp(log_w)).
    reweights = F.softmax(log_weights, dim=0).transpose(0, 1)
    ancestors = Categorical(reweights).sample((rws_samples,))  # rws_samples-B, values in [0, S)
    index = ancestors.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, num_steps, num_states)
    # Ancestors pick a candidate, so gather along the candidate axis (0).
    # Gathering along the time axis would copy one timestep across the whole
    # trajectory instead of selecting a trajectory.
    return torch.gather(Zs, 0, index)
    
def eubo_hmm_rws(prior, Pi, mus, covs, Ys, T, D, K, rws_samples, smc_samples, steps, batch_size):
    """
    Alternate between sampling transition matrices and running SMC over the
    latent states for `steps` sweeps, then return the training objectives.

    Sweep 0 draws transition matrices from the Dirichlet prior (initial_trans).
    Every later sweep draws them from the module-level encoder `enc`, fed with
    the consecutive-state pairs (flatz) of the trajectories resampled at the
    end of the previous sweep. Only the importance weights of the final sweep
    enter the returned objectives. With steps == 1 the single sweep is the
    prior sweep, so the final weights stay at zero.

    Args:
        prior: (K, K) Dirichlet concentrations for the transition rows.
        Pi, mus, covs, Ys, D, smc_samples: forwarded unchanged to smc_hmm_batch.
        T, K: trajectory length and state-vector size of the candidate buffer.
        rws_samples: number of outer samples per batch element.
        steps: number of sweeps.
        batch_size: number of sequences in Ys.

    Returns:
        (eubo, elbo, ess): scalar tensors. eubo is the self-normalized weighted
        sum of the final log weights, elbo is their plain mean, and ess is
        1 / sum(w^2) of the normalized weights, each averaged over the batch.
    """
    log_final_weights = torch.zeros((rws_samples, batch_size)).float()
    for m in range(steps):
        first_sweep = (m == 0)
        # The prior sweep takes precedence, exactly as in the if/elif ordering
        # this loop replaces: a lone sweep is never treated as the final one.
        last_sweep = (not first_sweep) and (m == steps - 1)
        log_increment_weights = torch.zeros((rws_samples, batch_size)).float()
        Zs_cand = torch.zeros((rws_samples, batch_size, T, K))
        for r in range(rws_samples):
            if first_sweep:
                As = initial_trans(prior, K, batch_size)
            else:
                Z_pairs = flatz(Zs_samples[r], T, K, batch_size)
                alphas, As = enc(Z_pairs, prior, batch_size)
            particles, log_weights, log_normalizers = smc_hmm_batch(Pi, As, mus, covs, Ys, T, D, K, smc_samples, batch_size)
            if first_sweep:
                # As came from the prior itself, so prior and proposal cancel.
                log_w = log_normalizers
            else:
                log_p_prior = Dirichlet(prior).log_prob(As).sum(-1)
                log_q_enc = Dirichlet(alphas).log_prob(As).sum(-1)
                log_w = log_p_prior - log_q_enc + log_normalizers
            if last_sweep:
                # No further sweep consumes a trajectory, so none is drawn here.
                log_final_weights[r] = log_w
            else:
                # Every non-final sweep must store its candidate trajectory;
                # otherwise adapt_resampling below would resample the zero
                # buffer and the next sweep's encoder would see all zeros.
                Zs_cand[r] = smc_resamplings(particles, log_weights, batch_size)
                log_increment_weights[r] = log_w
        if not last_sweep:
            Zs_samples = adapt_resampling(Zs_cand, log_increment_weights, rws_samples)

    # The objectives depend only on the final weights, so compute them once.
    # The normalized weights are detached and treated as constants.
    weights = torch.exp(log_final_weights - logsumexp(log_final_weights, dim=0)).detach()
    eubo = torch.mul(weights, log_final_weights).sum(0).mean()
    elbo = log_final_weights.mean(0).mean()
    ess = (1. / (weights ** 2).sum(0)).mean()

    return eubo, elbo, ess

In [6]:
def initialize():
    """Create a fresh Encoder (moved to the GPU when CUDA is set) and an Adam optimizer over its parameters."""
    encoder = Encoder()
    if CUDA:
        encoder.cuda()
    trainable = list(encoder.parameters())
    adam = torch.optim.Adam(trainable, lr=LEARNING_RATE)
    return encoder, adam
enc, optimizer = initialize()

In [7]:
# Per-epoch averages of the objectives, appended once per epoch.
ELBOs = []
EUBOs = []
ESSs = []
# Dirichlet concentrations for the transition rows: 1 everywhere, 2 on the diagonal.
prior = torch.ones((K, K)).float()
for k in range(K):
    prior[k, k] = 2.0

# Number of full batches per epoch; sequences that do not fill a batch are dropped.
Grad_Steps = int((Zs.shape[0] / BATCH_SIZE))
for epoch in range(NUM_EPOCHS):
    # Reshuffle the sequences every epoch.
    indices = torch.randperm(num_seqs)
    time_start = time.time()
    # Running sums over the epoch; divided by Grad_Steps below.
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(Grad_Steps):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        # NOTE(review): batch_zs is not used anywhere in the visible code.
        batch_zs = Zs[batch_indices]
        batch_ys = Ys[batch_indices]
        eubo, elbo, ess = eubo_hmm_rws(prior, Pi, mus, covs, batch_ys, T, D, K, rws_samples, smc_samples, steps, BATCH_SIZE)
        # The optimizer steps on the gradient of the EUBO; ELBO and ESS are only monitored.
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
    ESS /= Grad_Steps
    EUBO /= Grad_Steps
    ELBO /= Grad_Steps
    ESSs.append(ESS)
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)
    time_end = time.time()
    # ESS is recorded in ESSs but is not part of the progress line.
    print('epoch : %d, EUBO : %f, ELBO : %f (%ds)' % (epoch, EUBO, ELBO, time_end - time_start))

KeyboardInterrupt: 

In [None]:
def plot_results(EUBOs, ELBOs, path='results_VAE_hmm.svg'):
    """
    Plot the per-epoch EUBO and ELBO curves on one axes and save the figure.

    Args:
        EUBOs: sequence of per-epoch EUBO values (drawn in red).
        ELBOs: sequence of per-epoch ELBO values (drawn in blue).
        path: output file; defaults to the previously hard-coded name.
    """
    # Draw on the axes created by plt.subplots. Adding a second subplot on
    # top of it only left a redundant, tick-less axes underneath the curves.
    fig, ax = plt.subplots(figsize=(16, 16))
    ax.plot(ELBOs, 'b-', label='elbo')
    ax.plot(EUBOs, 'r-', label='eubo')
    ax.legend(fontsize=18)
    ax.set_xlabel('epoch', fontsize=18)
    ax.set_ylabel('EUBO and ELBO', fontsize=18)
    fig.savefig(path)

In [None]:
# Plot and save the per-epoch EUBO/ELBO averages collected by the training loop.
plot_results(EUBOs, ELBOs)

In [None]:
# NOTE(review): KLs and PATH_ENC are not defined in any cell visible here, and
# save_params is not defined in this notebook (presumably it comes from one of
# the star imports) — confirm all three exist before running this cell.
save_params(ELBOs, KLs, PATH_ENC)