In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from kls import *
from nats import *
from utils import *
from objectives import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
from probtorch.util import expand_inputs
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
## Problem size: N is used as the per-sequence point axis, K as the cluster
## axis and D as the feature axis throughout this notebook.
N, K, D = 60, 3, 2

## Model Parameters
MCMC_SIZE = 10                 # number of sweeps passed to Eubo_ag_sis
SAMPLE_SIZE = 10               # number of samples along SAMPLE_DIM
NUM_HIDDEN1 = 8                # width of the first hidden layer in both encoders
NUM_STATS = 2 * D + 1          # one scalar statistic plus two D-dimensional ones per cluster
NUM_LATENTS = D
## Training Parameters
SAMPLE_DIM, BATCH_DIM = 0, 1   # axis conventions handed to probtorch's log_joint
BATCH_SIZE = 20
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-3
CUDA = torch.cuda.is_available()
PATH = 'ag-sis-init-z-NG'      # suffix used in log / weight / figure file names

In [3]:
def _npy_as_float_tensor(npy_path):
    """Read a ``.npy`` file and return its contents as a float torch tensor."""
    return torch.from_numpy(np.load(npy_path)).float()

## Observations plus the generating quantities stored next to them.
Xs = _npy_as_float_tensor('gmm_dataset/obs.npy')
STATES = _npy_as_float_tensor('gmm_dataset/states.npy')
OBS_MU = _npy_as_float_tensor('gmm_dataset/obs_mu.npy')
OBS_SIGMA = _npy_as_float_tensor('gmm_dataset/obs_sigma.npy')
Pi = _npy_as_float_tensor('gmm_dataset/init.npy')
## The leading axis of Xs indexes sequences; a trailing partial batch is not counted.
NUM_SEQS = Xs.shape[0]
NUM_BATCHES = int(NUM_SEQS / BATCH_SIZE)

In [4]:
class Enc_eta(nn.Module):
    """Encoder proposing the per-cluster ``precisions`` and ``means``.

    Four MLP heads of identical shape map a ``num_stats``-dimensional
    statistics vector to ``num_latents`` outputs each: the Normal location,
    and the logs of ``nu``, the Gamma concentration and the Gamma rate.
    The same heads are applied to every cluster.

    The prior hyper-parameters (mu=0, nu=0.5, alpha=3, beta=3, each of shape
    ``(K, D)``) are plain tensor attributes, not registered buffers, so they
    are not part of ``state_dict()``. They are moved to the GPU by hand when
    the module-level ``CUDA`` flag is set.

    ``num_obs`` is accepted for interface compatibility but is not used.
    """
    def __init__(self, num_obs=D,
                       num_hidden1=NUM_HIDDEN1,
                       num_stats=NUM_STATS,
                       num_latents=NUM_LATENTS):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        def _make_head():
            # num_stats -> num_hidden1 -> num_hidden1/2 -> num_latents, Tanh in between.
            half_hidden = int(0.5*num_hidden1)
            return nn.Sequential(
                nn.Linear(num_stats, num_hidden1),
                nn.Tanh(),
                nn.Linear(num_hidden1, half_hidden),
                nn.Tanh(),
                nn.Linear(half_hidden, num_latents))

        # Creation order is kept (mus_mu, mus_log_nu, tau_log_alpha, tau_log_beta)
        # so parameter ordering and state_dict keys are unchanged.
        self.mus_mu = _make_head()
        self.mus_log_nu = _make_head()
        self.tau_log_alpha = _make_head()
        self.tau_log_beta = _make_head()

        self.prior_mu = torch.zeros((K, D))
        self.prior_nu = torch.ones((K, D)) * 0.5
        self.prior_alpha = torch.ones((K, D)) * 3
        self.prior_beta = torch.ones((K, D)) * 3
        if CUDA:
            self.prior_mu = self.prior_mu.cuda()
            self.prior_nu = self.prior_nu.cuda()
            self.prior_alpha = self.prior_alpha.cuda()
            self.prior_beta = self.prior_beta.cuda()

    def forward(self, stat1, stat2, stat3):
        """Sample precisions and means and score them under proposal and prior.

        Args:
            stat1: per-cluster scalar statistic, indexed as S * B * K.
            stat2, stat3: per-cluster vector statistics, indexed as S * B * K * D.

        Returns:
            ``(q, p, q_nu)``: the proposal trace, the prior trace evaluated at
            the same sampled values, and the proposal's ``nu`` (S * B * K * D).
        """
        q = probtorch.Trace()
        p = probtorch.Trace()
        # One statistics vector per cluster: S * B * K * (1+2*D). The heads act
        # on the last axis only, so all clusters go through in a single call
        # instead of one hand-written slice per cluster.
        stats = torch.cat((stat1.unsqueeze(-1), stat2, stat3), -1)
        ##
        q_alpha = self.tau_log_alpha(stats).exp()
        q_beta = self.tau_log_beta(stats).exp()
        precisions = Gamma(q_alpha, q_beta).sample()
        q.gamma(q_alpha,
                q_beta,
                value=precisions,
                name='precisions')
        p.gamma(self.prior_alpha,
                self.prior_beta,
                value=q['precisions'],
                name='precisions')
        ##
        q_mu = self.mus_mu(stats)
        q_nu = self.mus_log_nu(stats).exp()
        # The Normal scale is 1 / sqrt(nu * precision), for proposal and prior alike.
        q_sigma = 1. / (q_nu * q['precisions'].value).sqrt()
        means = Normal(q_mu, q_sigma).sample()
        q.normal(q_mu,
                 q_sigma,
                 value=means,
                 name='means')
        p.normal(self.prior_mu,
                 1. / (self.prior_nu * q['precisions'].value).sqrt(),
                 value=q['means'],
                 name='means')
        return q, p, q_nu

class Enc_z(nn.Module):
    """Encoder proposing the one-hot cluster assignment ``zs`` of every point.

    A single scoring network ``pi_prob`` maps the concatenation
    ``(point, cluster mean, cluster tau)`` (3*D features) to one scalar.
    Scores are soft-maxed across clusters to give the assignment probabilities.
    The prior is uniform over ``K`` clusters.

    ``num_latents`` is accepted for interface compatibility but is not used.
    NOTE(review): there is no nonlinearity between the last two Linear layers
    of ``pi_prob`` — confirm this is intended.
    """
    def __init__(self, num_obs=3*D,
                       num_hidden=NUM_HIDDEN1,
                       num_latents=K):
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()
        self.pi_prob = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, int(0.5*num_hidden)),
            nn.Linear(int(0.5*num_hidden), 1))

        self.prior_pi = torch.ones(K) * (1./ K)
        if CUDA:
            self.prior_pi = self.prior_pi.cuda()

    def forward(self, obs, obs_tau, obs_mu, sample_size, batch_size):
        """Sample assignments and score them under proposal and prior.

        Args:
            obs: points, indexed as S * B * N * D.
            obs_tau, obs_mu: per-cluster quantities, indexed as S * B * K * D.
            sample_size, batch_size: unused; kept for interface compatibility.

        Returns:
            ``(q, p)``: proposal and prior traces, each holding the same
            sampled ``zs`` (S * B * N * K, one-hot on the last axis).
        """
        q = probtorch.Trace()
        p = probtorch.Trace()
        # Sizes come from the inputs rather than the global N or a fixed
        # count of three clusters.
        num_points = obs.shape[-2]
        num_clusters = obs_mu.shape[-2]
        # Pair every point with every cluster: S * B * N * K * D each.
        obs_pairs = obs.unsqueeze(-2).expand(-1, -1, -1, num_clusters, -1)
        mu_pairs = obs_mu.unsqueeze(-3).expand(-1, -1, num_points, -1, -1)
        tau_pairs = obs_tau.unsqueeze(-3).expand(-1, -1, num_points, -1, -1)
        data = torch.cat((obs_pairs, mu_pairs, tau_pairs), -1) ## S * B * N * K * 3D

        scores = self.pi_prob(data).squeeze(-1) ## S * B * N * K
        z_pi = F.softmax(scores, -1)
        z = cat(z_pi).sample()
        _ = q.variable(cat, probs=z_pi, value=z, name='zs')
        _ = p.variable(cat, probs=self.prior_pi, value=z, name='zs')
        return q, p
    
def initialize():
    """Create both encoders and one Adam optimizer covering all their parameters.

    The encoders are moved to the GPU when the module-level ``CUDA`` flag is set.

    Returns:
        ``(enc_z, enc_eta, optimizer)``
    """
    eta_encoder, z_encoder = Enc_eta(), Enc_z()
    if CUDA:
        eta_encoder.cuda()
        z_encoder.cuda()
    # enc_z's parameters come first, then enc_eta's.
    trainable = [*z_encoder.parameters(), *eta_encoder.parameters()]
    adam = torch.optim.Adam(trainable, lr=LEARNING_RATE, betas=(0.9, 0.99))
    return z_encoder, eta_encoder, adam

In [5]:
# Build the two encoders and the Adam optimizer that trains both of them.
enc_z, enc_eta, optimizer = initialize()

In [6]:
def Eubo_ag_sis(enc_eta, enc_z, obs, N, K, D, mcmc_size, sample_size, batch_size):
    """
    initialize z
    incremental weight doesn't involve backward transition

    Runs ``mcmc_size`` sweeps. Sweep 0 draws the assignments from a uniform
    one-hot categorical. Every later sweep draws them from ``enc_z``,
    conditioned on the previous sweep's sampled means and sigmas. In each
    sweep ``enc_eta`` then proposes means and precisions from statistics of
    the current assignments, and a log weight is formed as
    ``log p(x, eta, z) - log q(eta, z)``.

    Per sweep, with weights soft-maxed over the sample axis (dim 0) and detached:
      * eubo = sum_s w_s * log_w_s, averaged over the batch
      * elbo = mean of log_w over samples and batch
      * ess  = 1 / sum_s w_s^2, averaged over the batch

    Args:
        enc_eta, enc_z: the two encoder modules.
        obs: observations, indexed as S * B * N * D. All tensors created here
            are placed on ``obs.device``.
        N, K, D: number of points, clusters and features.
        mcmc_size: number of sweeps; must be at least 2 because the KL
            diagnostics need one ``enc_z`` proposal.
        sample_size, batch_size: sizes of the sample and batch axes.

    Returns:
        Seven scalar tensors: mean EUBO, mean ELBO, mean ESS over sweeps, then
        the exclusive / inclusive KLs for eta and for z from the last sweep.

    Raises:
        ValueError: if ``mcmc_size`` is smaller than 2.
    """
    if mcmc_size < 2:
        # With a single sweep enc_z is never called, so the KL diagnostics
        # below would have no proposal to evaluate.
        raise ValueError('mcmc_size must be at least 2, got %d' % mcmc_size)
    # Follow the data's device rather than calling .cuda() unconditionally,
    # so the function also runs when CUDA is unavailable.
    device = obs.device
    eubos = torch.zeros(mcmc_size, device=device)
    elbos = torch.zeros(mcmc_size, device=device)
    esss = torch.zeros(mcmc_size, device=device)
    for m in range(mcmc_size):
        if m == 0:
            p_init_z = cat(torch.ones(K, device=device) * (1. / K))
            states = p_init_z.sample((sample_size, batch_size, N,))
            # Proposal and prior coincide on the first sweep; sum over N -> S * B.
            log_p_z = p_init_z.log_prob(states).sum(-1)
            log_q_z = log_p_z
        else:
            ## update z -- cluster assignments
            ## (obs_sigma / obs_mu are the values sampled in the previous sweep)
            q_z, p_z = enc_z(obs, obs_sigma, obs_mu, sample_size, batch_size)
            log_p_z = p_z.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)
            log_q_z = q_z.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)
            states = q_z['zs'].value ## S * B * N * K
        ## update tau and mu -- global variables
        stat1, stat2, stat3 = data_to_stats(obs, states, K, D)
        ##
        q_eta, p_eta, q_nu = enc_eta(stat1, stat2, stat3)
        log_p_eta = p_eta.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)
        log_q_eta = q_eta.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)

        obs_mu = q_eta['means'].value
        obs_tau = q_eta['precisions'].value
        obs_sigma = 1. / obs_tau.sqrt()
        ##
        log_obs = Log_likelihood(obs, states, obs_mu, obs_sigma, K, D, cluster_flag=False)
        log_weights = log_obs.sum(-1) + log_p_eta + log_p_z - log_q_eta - log_q_z
        # Normalised over the sample axis; detached so gradients flow only
        # through log_weights.
        weights = F.softmax(log_weights, 0).detach()
        eubos[m] = (weights * log_weights).sum(0).mean()
        elbos[m] = log_weights.mean(0).mean()
        esss[m] = (1. / (weights**2).sum(0)).mean()

    ## KLs for mu and sigma based on Normal-Gamma prior (last sweep only)
    q_mu = q_eta['means'].dist.loc
    q_alpha = q_eta['precisions'].dist.concentration
    q_beta = q_eta['precisions'].dist.rate
    q_logits = q_z['zs'].dist.probs.log()
    # NOTE(review): states has not changed since the last in-loop call, so this
    # presumably repeats that computation — confirm before reusing the result.
    stat1, stat2, stat3 = data_to_stats(obs, states, K, D)
    post_mu, post_nu, post_alpha, post_beta = Post_mu_tau(stat1, stat2, stat3, enc_eta.prior_mu, enc_eta.prior_nu, enc_eta.prior_alpha, enc_eta.prior_beta, D)
    kl_eta_ex, kl_eta_in = kls_NGs(q_mu, q_nu, q_alpha, q_beta, post_mu, post_nu, post_alpha, post_beta)
    ## KLs for cluster assignments
    post_logits = Post_z(obs, obs_sigma, obs_mu, N, K)
    kl_z_ex, kl_z_in = kls_cats(q_logits, post_logits)

    return eubos.mean(), elbos.mean(), esss.mean(), kl_eta_ex.sum(-1).mean(), kl_eta_in.sum(-1).mean(), kl_z_ex.sum(-1).mean(), kl_z_in.sum(-1).mean()


In [None]:
log_path = 'results/log-' + PATH + '.txt'
# Start a fresh log with a header row. The context manager closes the file
# even if a write fails.
with open(log_path, 'w+') as flog:
    flog.write('EUBO\tELBO\tESS\tKLs_eta_ex\tKLs_eta_in\tKLs_z_ex\tKLs_z_in\n')
for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    # New random visiting order of the sequences every epoch.
    indices = torch.randperm(NUM_SEQS)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    KL_eta_ex = 0.0
    KL_eta_in = 0.0
    KL_z_ex = 0.0
    KL_z_in = 0.0

    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        obs = Xs[batch_indices]
        # Add a leading sample axis by repeating the shuffled batch SAMPLE_SIZE times.
        obs = shuffler(obs).repeat(SAMPLE_SIZE, 1, 1, 1)
        if CUDA:
            obs = obs.cuda()
        eubo, elbo, ess, kl_eta_ex, kl_eta_in, kl_z_ex, kl_z_in = Eubo_ag_sis(enc_eta, enc_z, obs, N, K, D, MCMC_SIZE, SAMPLE_SIZE, BATCH_SIZE)
        ## gradient step: the EUBO is the quantity being minimised
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
        KL_eta_ex += kl_eta_ex.item()
        KL_eta_in += kl_eta_in.item()
        KL_z_ex += kl_z_ex.item()
        KL_z_in += kl_z_in.item()

    # Per-epoch averages, computed once and shared by the log file and the console line.
    epoch_means = (EUBO/NUM_BATCHES, ELBO/NUM_BATCHES, ESS/NUM_BATCHES,
                   KL_eta_ex/NUM_BATCHES, KL_eta_in/NUM_BATCHES, KL_z_ex/NUM_BATCHES, KL_z_in/NUM_BATCHES)
    with open(log_path, 'a+') as flog:
        print('%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f' % epoch_means, file=flog)

    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f, eta_ex=%.3f, eta_in=%.3f, z_ex=%.3f, z_in=%.3f (%ds)'
        % ((epoch,) + epoch_means + (time_end - time_start,)))

epoch=0, EUBO=-496.842, ELBO=-763.147, ESS=1.040, eta_ex=405.914, eta_in=24.738, z_ex=235.275, z_in=34.497 (48s)
epoch=1, EUBO=-409.921, ELBO=-477.926, ESS=1.096, eta_ex=120.261, eta_in=39.989, z_ex=54.982, z_in=19.264 (53s)
epoch=2, EUBO=-394.902, ELBO=-432.597, ESS=1.168, eta_ex=74.861, eta_in=43.814, z_ex=27.994, z_in=13.197 (61s)
epoch=3, EUBO=-385.120, ELBO=-411.533, ESS=1.230, eta_ex=53.988, eta_in=36.536, z_ex=20.405, z_in=11.140 (63s)
epoch=4, EUBO=-377.001, ELBO=-397.114, ESS=1.298, eta_ex=39.535, eta_in=30.197, z_ex=17.380, z_in=10.264 (64s)
epoch=5, EUBO=-372.973, ELBO=-389.796, ESS=1.341, eta_ex=32.243, eta_in=28.577, z_ex=16.428, z_in=9.941 (62s)
epoch=6, EUBO=-371.790, ELBO=-387.174, ESS=1.362, eta_ex=29.640, eta_in=29.097, z_ex=15.678, z_in=9.735 (61s)
epoch=7, EUBO=-371.310, ELBO=-385.915, ESS=1.374, eta_ex=28.368, eta_in=29.871, z_ex=15.247, z_in=9.585 (63s)
epoch=8, EUBO=-371.123, ELBO=-385.189, ESS=1.387, eta_ex=27.629, eta_in=31.199, z_ex=14.808, z_in=9.460 (56s)
ep

In [None]:
# The original expression concatenated PATH after the literal 'weights/enc-%s',
# so the '%s' placeholder was never filled and ended up in the file name.
torch.save(enc_eta.state_dict(), 'weights/enc-%s' % PATH)
# enc_z is trained by the same optimizer and is needed to run test(), so it is
# saved too, under its own file name.
torch.save(enc_z.state_dict(), 'weights/enc-z-%s' % PATH)

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, num_samples, num_epochs, lr, bound_ylim=(-240, -100)):
    """Plot the training curves and save them as ``train_<PATH>.svg``.

    The top panel shows EUBO (red) and ELBO (blue). The bottom panel shows
    ESS divided by ``num_samples`` on a fixed [0, 1] axis.

    Args:
        EUBOs, ELBOs, ESSs: sequences of per-epoch values.
        num_samples: sample count used to normalise the ESS and shown in the title.
        num_epochs, lr: shown in the title only.
        bound_ylim: ``(low, high)`` y-limits of the top panel. The default
            keeps the previous fixed range; pass ``None`` to let matplotlib
            autoscale to the data.
    """
    fig = plt.figure(figsize=(15, 15))
    # Two panels only: the former middle axes was created but never drawn on.
    ax1 = fig.add_subplot(2, 1, 1)
    ax3 = fig.add_subplot(2, 1, 2)
    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')

    ax1.tick_params(labelsize=18)

    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')
    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, BATCH_SIZE, lr, num_samples), fontsize=18)
    if bound_ylim is not None:
        ax1.set_ylim(list(bound_ylim))
    ax1.legend()
    ax3.legend()
    ax3.tick_params(labelsize=18)
    ax3.set_ylim([0,1])
    # tight_layout only has something to arrange once the axes exist.
    fig.tight_layout()
    plt.savefig('train_' + PATH + '.svg')

In [None]:
def sample_single_batch(num_seqs, N, K, D, batch_size):
    """Draw one random batch of sequences from ``Xs`` for testing.

    Args:
        num_seqs: number of sequences to permute over.
        N, K: unused; kept for interface compatibility.
        D: number of leading feature columns to keep.
        batch_size: number of sequences in the batch.

    Returns:
        Tensor with a leading sample axis of length ``SAMPLE_SIZE``, followed
        by the shuffled batch with its last axis cut to ``D`` columns. It is
        on the GPU when the module-level ``CUDA`` flag is set.
    """
    indices = torch.randperm(num_seqs)
    batch_indices = indices[:batch_size]
    obs = shuffler(Xs[batch_indices]).repeat(SAMPLE_SIZE, 1, 1, 1)
    # Keep the first D feature columns on every device; this was previously
    # done only on the CUDA path and with a literal 2.
    obs = obs[:, :, :, :D]
    if CUDA:
        obs = obs.cuda()
    return obs

def test(mcmc_size, sample_size, batch_size):
    """Run ``mcmc_size`` eta/z sweeps on the global test batch ``obs``.

    Reads the module-level ``obs``, ``enc_eta`` and ``enc_z``. Assignments
    start from a uniform one-hot draw. Each sweep proposes means and
    precisions from the current assignments and then proposes new
    assignments. It records the log joint of data, means, precisions and
    assignments under the prior, averaged over the sample axis.

    Args:
        mcmc_size: number of sweeps.
        sample_size, batch_size: sizes of the sample and batch axes of ``obs``.

    Returns:
        ``(LLs, states, q_eta)``: the recorded values stacked on a new leading
        sweep axis (on CPU), the last sampled assignments, and the last
        ``enc_eta`` proposal trace.
    """
    LLs = []
    # Evaluation only: no autograd graphs are needed for any of this.
    with torch.no_grad():
        # Drawn on the CPU as before, then moved to wherever obs lives instead
        # of an unconditional .cuda().
        states = cat(torch.ones(K)* (1. / K)).sample((sample_size, batch_size, N,)).to(obs.device)
        for m in range(mcmc_size):
            ## update tau and mu -- global variables
            stat1, stat2, stat3 = data_to_stats(obs, states, K, D)
            ##
            q_eta, p_eta, q_nu = enc_eta(stat1, stat2, stat3)

            obs_mu = q_eta['means'].value
            obs_tau = q_eta['precisions'].value
            obs_sigma = 1. / obs_tau.sqrt()
            ## update z -- cluster assignments
            ## Eubo_ag_sis feeds enc_z the sigmas during training, so the same
            ## quantity is passed here (this call used to pass obs_tau).
            q_z, p_z = enc_z(obs, obs_sigma, obs_mu, sample_size, batch_size)
            states = q_z['zs'].value ## S * B * N * K
            # Pick each point's cluster mean / sigma according to its sampled assignment.
            labels_flat = states.argmax(-1).unsqueeze(-1).repeat(1, 1, 1, D)
            obs_mu_expand = torch.gather(obs_mu, 2, labels_flat)
            obs_sigma_expand = torch.gather(obs_sigma, 2, labels_flat)
            log_obs = Normal(obs_mu_expand, obs_sigma_expand).log_prob(obs).sum(-1).sum(-1) ## S * B
            ll = (log_obs + p_eta['means'].log_prob.sum(-1).sum(-1) + p_eta['precisions'].log_prob.sum(-1).sum(-1) + p_z['zs'].log_prob.sum(-1)).mean(0).unsqueeze(0)
            LLs.append(ll)
    LLs = torch.cat(LLs, 0).cpu()
    return LLs, states, q_eta

def plot_ll(LLs, batch_size):
    """Plot one curve per batch element on a grid five panels wide.

    Args:
        LLs: 2-D tensor; column ``b`` is plotted in panel ``b``.
        batch_size: number of columns to plot.
    """
    fig = plt.figure(figsize=(15,15))
    # Round the row count up so a batch size that is not a multiple of 5
    # still gets a panel for every element.
    num_rows = (batch_size + 4) // 5
    for b in range(batch_size):
        ax = fig.add_subplot(num_rows, 5, b+1)
        ax.plot(LLs[:, b].data.numpy())
        
def plot_samples(obs, states, q, batch_size):
    """Scatter each test sequence coloured by assignment, with cluster ellipses.

    Only the first sample (index 0 on the leading axis) of ``obs``, ``states``
    and the proposal parameters is drawn. Cluster centres are the proposal's
    Normal locations. Cluster spreads come from the mean of the Gamma proposal
    over precisions (concentration / rate). The figure is saved as
    ``results/modes-<PATH>.svg``.

    Args:
        obs: points, indexed as S * B * N * D.
        states: one-hot assignments, indexed as S * B * N * K.
        q: trace holding the ``'means'`` and ``'precisions'`` proposal nodes.
        batch_size: number of sequences to draw, one panel each.
    """
    # Stray text from the test-run cell had been pasted into this body, which
    # made the definition a syntax error; it has been removed.
    colors = ['r', 'b', 'gold']
    fig = plt.figure(figsize=(15,15))
    xs = obs[0].cpu()
    zs = states[0].cpu()
    mu_means = q['means'].dist.loc[0].cpu().data.numpy()
    tau_means = (q['precisions'].dist.concentration[0] / q['precisions'].dist.rate[0]).cpu().data.numpy()
    # Round the row count up so every sequence gets a panel.
    num_rows = (batch_size + 4) // 5
    for b in range(batch_size):
        ax = fig.add_subplot(num_rows, 5, b+1)
        x = xs[b].data.numpy()
        z = zs[b].data.numpy()
        mu = mu_means[b].reshape(K, D)
        # 1 / precision is already the variance.
        sigma2 = 1. / tau_means[b]
        assignments = z.argmax(-1)
        for k in range(K):
            # Diagonal covariance = the variances themselves (squaring them
            # again, as before, gave sigma^4).
            cov_k = np.diag(sigma2[k])
            xk = x[np.where(assignments == k)]
            ax.scatter(xk[:, 0], xk[:, 1], c=colors[k])
            plot_cov_ellipse(cov=cov_k, pos=mu[k], nstd=2, ax=ax, alpha=0.2, color=colors[k])
        ax.set_ylim([-15, 15])
        ax.set_xlim([-15, 15])
    plt.savefig('results/modes-' + PATH + '.svg')

In [None]:
# Number of sequences in the held-out visualisation batch.
BATCH_SIZE_TEST = 25
# Draw one random batch; test() reads this module-level `obs` directly.
obs = sample_single_batch(NUM_SEQS, N, K, D, BATCH_SIZE_TEST)

In [None]:
# Test-only sweep count: a separate name keeps the training constant MCMC_SIZE
# from being overwritten if the training cell is run again later.
MCMC_SIZE_TEST = 20
# A trailing bare `%time` times nothing, so the whole cell is timed explicitly,
# in the same style as the training loop.
test_start = time.time()
LLs, states, q_eta = test(MCMC_SIZE_TEST, SAMPLE_SIZE, BATCH_SIZE_TEST)
plot_ll(LLs, BATCH_SIZE_TEST)
plot_samples(obs, states, q_eta, BATCH_SIZE_TEST)
print('test sweeps and plots took %ds' % (time.time() - test_start))