In [7]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from kls import *
from nats import *
from utils import *
from objectives import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical as rcat
from torch.distributions.gamma import Gamma
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
from probtorch.util import expand_inputs
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [8]:
# Problem dimensions. K*D values are later reshaped to (K, D); N is only
# forwarded to helpers (presumably the number of points per sequence).
N, K, D = 30, 3, 2

## Model Parameters
NUM_SAMPLES = 10
NUM_HIDDEN = 32
NUM_LATENTS = D * K
NUM_STATS = K + 2 * NUM_LATENTS
NUM_OBS_GLOBAL = D + K
NUM_OBS_LOCAL = D + 2 * NUM_LATENTS

## Training Parameters
# Axis layout handed to probtorch's log_joint: samples first, batch second.
SAMPLE_DIM, BATCH_DIM = 0, 1
BATCH_SIZE = 50
NUM_EPOCHS = 2000
LEARNING_RATE = 1e-3
# Run on the GPU whenever one is visible to torch.
CUDA = torch.cuda.is_available()
# Tag embedded in the file names of saved figures.
PATH = 'gibbs-natparam'

In [9]:
# Read the three dataset arrays from disk and convert each to a float32 tensor.
Xs, STATES, Pi = (
    torch.from_numpy(np.load(npy_path)).float()
    for npy_path in ('gmm_dataset/sequences.npy',
                     'gmm_dataset/states.npy',
                     'gmm_dataset/init.npy')
)
NUM_SEQS = Xs.shape[0]
# Only whole batches are used; a trailing partial batch is dropped.
NUM_BATCHES = NUM_SEQS // BATCH_SIZE

In [10]:
class Gibbs_global(nn.Module):
    """Amortized proposal over the global variables ('precisions' and 'means').

    Each row of the input is encoded into ``num_stats`` features which are
    summed over the second-to-last axis. Four MLP heads map the pooled
    features to the parameters of a Gamma proposal over the precisions and a
    Normal proposal over the means, whose scale depends on the sampled
    precisions.

    Args:
        num_obs: size of the last axis of the input fed to ``forward``.
        num_hidden: width of the hidden layer in every MLP.
        num_stats: number of pooled features produced by the encoder.
        num_latents: number of precision/mean entries that are proposed.
    """
    def __init__(self, num_obs=D+K,
                       num_hidden=NUM_HIDDEN,
                       num_stats=NUM_STATS,
                       num_latents=NUM_LATENTS):
        # Zero-argument super(): `super(self.__class__, self)` recurses
        # forever as soon as this class is subclassed.
        super().__init__()

        def make_head():
            # All four parameter heads share this architecture.
            return nn.Sequential(
                nn.Linear(num_stats, num_hidden),
                nn.Tanh(),
                nn.Linear(num_hidden, num_latents))

        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.Tanh(),
            nn.Linear(num_hidden, num_stats))

        # Heads are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.sigmas_log_alpha = make_head()
        self.sigmas_log_beta = make_head()
        self.mus_mu = make_head()
        self.mus_log_nu = make_head()

        # Fixed prior hyper-parameters, one entry per latent. They are plain
        # tensors (not parameters or buffers), so they are placed on the GPU
        # by hand according to the notebook-level CUDA flag.
        self.prior_mu = torch.zeros(num_latents)
        self.prior_nu = torch.ones(num_latents) * 0.3
        self.prior_alpha = torch.ones(num_latents) * 3.0
        self.prior_beta = torch.ones(num_latents) * 3.0
        if CUDA:
            self.prior_mu = self.prior_mu.cuda()
            self.prior_nu = self.prior_nu.cuda()
            self.prior_alpha = self.prior_alpha.cuda()
            self.prior_beta = self.prior_beta.cuda()

    def forward(self, data):
        """Sample the global variables and score them under proposal and prior.

        Args:
            data: tensor whose last axis has ``num_obs`` entries; rows along
                the second-to-last axis are pooled by summation.

        Returns:
            (q, p): two probtorch traces, each holding nodes named
            'precisions' and 'means'. ``q`` is the proposal; ``p`` evaluates
            the prior at the values sampled in ``q``.
        """
        q = probtorch.Trace()
        # Pool the per-row encodings into one statistics vector.
        stats = self.enc_stats(data).sum(-2)

        # Gamma proposal over the precisions. The heads output logs, hence
        # exp() to obtain positive concentration and rate.
        q_alpha = self.sigmas_log_alpha(stats).exp()
        q_beta = self.sigmas_log_beta(stats).exp()
        q_precisions = Gamma(q_alpha, q_beta)
        # .sample() (not .rsample()): no gradient flows through the draw.
        precisions = q_precisions.sample()
        q.gamma(q_alpha,
                q_beta,
                value=precisions,
                name='precisions')

        # Normal proposal over the means, conditioned on the precisions:
        # std = 1 / sqrt(nu * precision).
        q_mu = self.mus_mu(stats)
        q_nu = self.mus_log_nu(stats).exp()
        q_sigma = 1. / (q_nu * q['precisions'].value).sqrt()
        means = Normal(q_mu, q_sigma).sample()
        q.normal(q_mu,
                 q_sigma,
                 value=means,
                 name='means')

        # Prior trace, evaluated at the proposal's samples.
        p = probtorch.Trace()
        p.gamma(self.prior_alpha,
                self.prior_beta,
                value=q['precisions'],
                name='precisions')
        p.normal(self.prior_mu,
                 1. / (self.prior_nu * q['precisions'].value).sqrt(),
                 value=q['means'],
                 name='means')
        return q, p
        
def initialize():
    """Create a fresh Gibbs_global network and an Adam optimizer for it.

    The network is moved to the GPU when the notebook-level CUDA flag is set.

    Returns:
        (network, optimizer) tuple.
    """
    model = Gibbs_global()
    if CUDA:
        model.cuda()
    adam = torch.optim.Adam(model.parameters(),
                            lr=LEARNING_RATE,
                            betas=(0.9, 0.99))
    return model, adam

In [11]:
# Build the proposal network and its Adam optimizer (re-run to reset training).
gibbs_global, optimizer = initialize()

In [12]:
# Per-epoch averages of the training diagnostics.
EUBOs = []
ELBOs = []
ESSs = []

for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    # Fresh random ordering of the sequences for this epoch.
    indices = torch.randperm(NUM_SEQS)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        obs = Xs[batch_indices]
        states = STATES[batch_indices]
        # Join observations and states on the last axis, shuffle, then tile
        # along a new leading sample axis.
        data = torch.cat((obs, states), -1)
        data = shuffler(data).repeat(NUM_SAMPLES, 1, 1, 1)
        if CUDA:
            data = data.cuda()
        q, p = gibbs_global(data)
        log_p_eta = p.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)
        log_q_eta = q.log_joint(sample_dims=SAMPLE_DIM, batch_dim=BATCH_DIM)
        means = q['means'].value.view(NUM_SAMPLES, BATCH_SIZE, K, D)
        precisions = q['precisions'].value.view(NUM_SAMPLES, BATCH_SIZE, K, D)
        # Split the joined tensor back at column D (was a hard-coded 2);
        # assumes the observations occupy the first D columns.
        ll = loglikelihood(data[..., :D], data[..., D:], means, precisions, D)
        log_weights = ll + log_p_eta - log_q_eta
        # Self-normalised importance weights over the sample axis; detached
        # so gradients only flow through log_weights.
        weights = F.softmax(log_weights, 0).detach()
        eubo = (weights * log_weights).sum(0).mean()
        elbo = log_weights.mean()
        # Effective sample size, averaged over the batch.
        ess = (1. / (weights**2).sum(0)).mean()
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
    EUBO /= NUM_BATCHES
    ELBO /= NUM_BATCHES
    ESS /= NUM_BATCHES
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)
    ESSs.append(ESS)
    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f (%ds)'
            % (epoch, EUBO, ELBO, ESS, time_end - time_start))


epoch=0, EUBO=-238.875, ELBO=-487.937, ESS=1.046 (0s)
epoch=1, EUBO=-191.192, ELBO=-309.368, ESS=1.071 (0s)
epoch=2, EUBO=-183.372, ELBO=-275.149, ESS=1.099 (0s)
epoch=3, EUBO=-178.125, ELBO=-251.118, ESS=1.104 (0s)
epoch=4, EUBO=-175.569, ELBO=-239.810, ESS=1.123 (0s)
epoch=5, EUBO=-172.118, ELBO=-227.437, ESS=1.121 (0s)
epoch=6, EUBO=-170.521, ELBO=-221.514, ESS=1.136 (0s)
epoch=7, EUBO=-169.174, ELBO=-216.001, ESS=1.139 (0s)
epoch=8, EUBO=-167.842, ELBO=-211.414, ESS=1.143 (0s)
epoch=9, EUBO=-167.319, ELBO=-211.938, ESS=1.131 (0s)
epoch=10, EUBO=-166.022, ELBO=-208.309, ESS=1.130 (0s)
epoch=11, EUBO=-164.424, ELBO=-205.814, ESS=1.135 (0s)
epoch=12, EUBO=-163.076, ELBO=-203.041, ESS=1.149 (0s)
epoch=13, EUBO=-163.146, ELBO=-203.180, ESS=1.140 (0s)
epoch=14, EUBO=-162.620, ELBO=-201.560, ESS=1.150 (0s)
epoch=15, EUBO=-162.114, ELBO=-201.786, ESS=1.147 (0s)
epoch=16, EUBO=-161.697, ELBO=-200.899, ESS=1.146 (0s)
epoch=17, EUBO=-161.055, ELBO=-199.555, ESS=1.145 (0s)
epoch=18, EUBO=-160.

KeyboardInterrupt: 

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, num_samples, num_epochs, lr):
    """Plot the training curves and save them as 'train_<PATH>.svg'.

    The top panel shows the EUBO and ELBO per epoch; the bottom panel shows
    the effective sample size as a fraction of ``num_samples``.

    Args:
        EUBOs, ELBOs, ESSs: per-epoch sequences of the recorded diagnostics.
        num_samples: number of samples, used to normalise the ESS.
        num_epochs: shown in the title.
        lr: learning rate, shown in the title.
    """
    fig = plt.figure(figsize=(15, 15))
    # Two panels only: the former middle axis was never drawn on.
    ax1 = fig.add_subplot(2, 1, 1)
    ax3 = fig.add_subplot(2, 1, 2)

    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')
    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, BATCH_SIZE, lr, num_samples), fontsize=18)
    ax1.set_ylim([-300, -150])
    ax1.set_xlabel('epoch', fontsize=18)
    ax1.tick_params(labelsize=18)
    ax1.legend()

    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')
    ax3.set_xlabel('epoch', fontsize=18)
    ax3.tick_params(labelsize=18)
    ax3.legend()

    # tight_layout must run after the axes exist; on an empty figure it is a no-op.
    fig.tight_layout()
    fig.savefig('train_' + PATH + '.svg')

In [None]:
# Plot and save the curves collected by the training loop.
plot_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)

In [None]:
def sample_single_batch(num_seqs, N, K, D, batch_size):
    """Draw one random batch and prepare it the same way the training loop does.

    Observations and states are joined on the last axis, shuffled together
    (so they stay aligned), tiled NUM_SAMPLES times along a new leading axis
    and split again at column D.

    Args:
        num_seqs: number of sequences to draw indices from.
        N, K: unused; kept so existing calls keep working.
        D: number of leading columns that hold the observations.
        batch_size: number of sequences in the returned batch.

    Returns:
        (batch_Xs, batch_Zs): observation and state tensors, each with a
        leading sample axis of size NUM_SAMPLES.
    """
    indices = torch.randperm(num_seqs)
    batch_indices = indices[:batch_size]
    # STATES is the tensor loaded from 'states.npy'; `Zs_true` was never defined.
    data = torch.cat((Xs[batch_indices], STATES[batch_indices]), -1)
    # Single-argument shuffler call, matching its use in the training loop.
    # The sample axis is added on CPU and GPU alike.
    data = shuffler(data).repeat(NUM_SAMPLES, 1, 1, 1)
    if CUDA:
        data = data.cuda()
    return data[..., :D], data[..., D:]

obs, states = sample_single_batch(NUM_SEQS, N, K, D, batch_size=25)
# The network's forward() takes the joined (observations, states) tensor only.
q, p = gibbs_global(torch.cat((obs, states), -1))

In [None]:
def plot_samples(obs, states, q_eta, batch_size):
    """Scatter each sequence with its proposed cluster ellipses and save the figure.

    Only the first slice along the leading axis of every input is plotted.
    Covariances are taken as the reciprocal of the Gamma proposal's mean
    (concentration / rate); centres are the Normal proposal's ``loc``.

    Args:
        obs: observations; ``obs[0][b]`` holds the points of sequence ``b``
            and its first two columns are used as x/y coordinates.
        states: assignments; ``states[0][b].argmax(-1)`` gives the cluster
            index of every point.
        q_eta: proposal trace with 'precisions' and 'means' nodes.
        batch_size: number of sequences to draw, five panels per row.
    """
    import os

    colors = ['r', 'b', 'gold']
    ncols = 5
    # Round up so batch sizes that are not a multiple of 5 still fit.
    nrows = (batch_size + ncols - 1) // ncols
    fig = plt.figure(figsize=(25, 25))

    xs = obs[0].cpu()
    zs = states[0].cpu()
    precision_dist = q_eta['precisions'].dist
    alphas = precision_dist.concentration[0].cpu().detach().numpy()
    betas = precision_dist.rate[0].cpu().detach().numpy()
    precisions_mean = alphas / betas
    covs_mean = 1. / precisions_mean
    means_mean = q_eta['means'].dist.loc[0].cpu().detach().numpy()

    for b in range(batch_size):
        ax = fig.add_subplot(nrows, ncols, b + 1)
        x = xs[b].detach().numpy()
        z = zs[b].detach().numpy()
        mu = means_mean[b].reshape(K, D)
        cov = covs_mean[b].reshape(K, D)
        assignments = z.argmax(-1)
        for k in range(K):
            # Per-dimension variances form a diagonal covariance matrix.
            cov_k = np.diag(cov[k])
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=colors[k])
            plot_cov_ellipse(cov=cov_k, pos=mu[k], nstd=2, ax=ax, alpha=0.2, color=colors[k])
        ax.set_ylim([-10, 10])
        ax.set_xlim([-10, 10])

    # savefig does not create missing directories.
    os.makedirs('results', exist_ok=True)
    fig.savefig('results/modes' + PATH + '.svg')

In [None]:
# Visualise the proposal for the 25 sequences sampled above.
plot_samples(obs, states, q, batch_size=25)

In [None]:
# Inspect the argmax over the last axis of `states` at index [0, 0].
states[0,0].argmax(-1)