In [1]:
%matplotlib inline
import os
import sys
sys.path.append("../")
sys.path.append('/home/hao/Research/probtorch/')
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from kls import *
from NG_nats import *
from utils import *
from objectives import *
from torch.distributions.normal import Normal
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
import time
import probtorch
# Record library versions and GPU visibility alongside the results, for reproducibility.
print(*('probtorch:', probtorch.__version__,
        'torch:', torch.__version__,
        'cuda:', torch.cuda.is_available()))

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
N = 60  # presumably the number of points per dataset (passed as N to the sampling helpers)
K = 3   # number of clusters: width of the one-hot state part of each row
D = 2   # observation dimensionality: width of the observation part of each row

## Model Parameters
SAMPLE_SIZE = 10
# Later cells refer to the per-batch sample count under this name as well.
NUM_SAMPLES = SAMPLE_SIZE
NUM_HIDDEN1 = 8
NUM_STATS = 1 + 2 * D
NUM_LATENTS = D
## Training Parameters
SAMPLE_DIM = 0
BATCH_DIM = 1
BATCH_SIZE = 20
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-4
CUDA = torch.cuda.is_available()
PATH = 'gibbs-NG-2k-learn-stat'

# Device for the model and every batch. Prefer the second GPU (as originally configured),
# but fall back to the first GPU when only one exists, and to the CPU without CUDA,
# so the notebook still runs on other machines.
if CUDA and torch.cuda.device_count() > 1:
    gpu2 = torch.device('cuda:1')
elif CUDA:
    gpu2 = torch.device('cuda:0')
else:
    gpu2 = torch.device('cpu')
data_path = "../gmm_dataset_conjugate2k"

In [3]:
# Load the pre-generated dataset from `data_path`: observations, state assignments
# (concatenated with the observations before being fed to the encoder) and `init.npy`
# (not used in the visible cells). Each array becomes a float32 tensor.
Xs, STATES, Pi = (torch.from_numpy(np.load(data_path + suffix)).float()
                  for suffix in ('/obs.npy', '/states.npy', '/init.npy'))
NUM_SEQS = Xs.shape[0]
# Trailing sequences that do not fill a whole batch are skipped each epoch.
NUM_BATCHES = int(NUM_SEQS / BATCH_SIZE)

In [4]:
class Enc_eta(nn.Module):
    """Amortised encoder for a Normal-Gamma posterior over cluster means and precisions.

    Each input row concatenates features that feed two linear heads: one produces soft
    weights over the K clusters, the other a D-dimensional per-point feature. Both are
    passed to `post_param` (defined elsewhere) together with the prior hyper-parameters.
    It returns posterior (alpha, beta, mu, nu). Precisions are then sampled from
    Gamma(alpha, beta), and means from Normal(mu, 1 / sqrt(nu * precision)). Both are
    scored under the proposal trace `q` and under the prior trace `p`.

    The constructor arguments are accepted for interface compatibility but are currently
    unused; layer sizes come from the module-level K and D.
    """
    def __init__(self, num_obs=D,
                       num_hidden1=NUM_HIDDEN1,
                       num_stats=NUM_STATS,
                       num_latents=NUM_LATENTS):
        # Name the class explicitly: `super(self.__class__, self)` recurses forever
        # as soon as Enc_eta is subclassed.
        super(Enc_eta, self).__init__()

        # Soft weights over the K clusters for every row (softmax over the last axis).
        self.gamma = nn.Sequential(
            nn.Linear(K+D, K),
            nn.Softmax(-1))

        # Learned D-dimensional per-point feature handed to post_param.
        self.ob = nn.Sequential(
            nn.Linear(K+D, D))

        # Normal-Gamma prior hyper-parameters, one per (cluster, dimension).
        # These are plain attributes, not registered buffers, so Module.to()/.cuda()
        # does not move them and they are not part of state_dict(); move them here.
        self.prior_mu = torch.zeros((K, D))
        self.prior_nu = torch.ones((K, D)) * 0.3
        self.prior_alpha = torch.ones((K, D)) * 4
        self.prior_beta = torch.ones((K, D)) * 4
        if CUDA:
            self.prior_mu = self.prior_mu.to(gpu2)
            self.prior_nu = self.prior_nu.to(gpu2)
            self.prior_alpha = self.prior_alpha.to(gpu2)
            self.prior_beta = self.prior_beta.to(gpu2)

    def forward(self, data):
        """Sample (precisions, means) from the amortised posterior and score them.

        Args:
            data: tensor whose last dimension has K + D features (the shape comments
                suggest S * B * N * (K + D) — TODO confirm against post_param).

        Returns:
            (q, p, q_nu): the proposal trace holding 'precisions' and 'means', the prior
            trace evaluated at the same values, and the posterior `nu` parameter.
        """
        gammas = self.gamma(data)  # S * B * N * K --> S * B * N * K
        xs = self.ob(data)  # S * B * N * D --> S * B * N * D
        q_alpha, q_beta, q_mu, q_nu = post_param(xs, gammas,
                                                 self.prior_alpha, self.prior_beta, self.prior_mu, self.prior_nu, K, D)

        q = probtorch.Trace()
        p = probtorch.Trace()

        # Precisions: sampled with .sample() (not rsample), so no gradient flows through
        # the sampled values; they are scored under both the proposal and the prior.
        precisions = Gamma(q_alpha, q_beta).sample()
        q.gamma(q_alpha,
                q_beta,
                value=precisions,
                name='precisions')
        p.gamma(self.prior_alpha,
                self.prior_beta,
                value=q['precisions'],
                name='precisions')

        # Means are conditionally Normal with std 1 / sqrt(nu * precision); compute the
        # proposal std once and reuse it for sampling and scoring.
        q_std = 1. / (q_nu * q['precisions'].value).sqrt()
        means = Normal(q_mu, q_std).sample()
        q.normal(q_mu,
                 q_std,
                 value=means,
                 name='means')
        p.normal(self.prior_mu,
                 1. / (self.prior_nu * q['precisions'].value).sqrt(),
                 value=q['means'],
                 name='means')
        return q, p, q_nu
        
def initialize():
    """Build a fresh encoder and the Adam optimizer that trains it.

    Returns:
        (enc_eta, optimizer): a new Enc_eta (moved to `gpu2` when CUDA is available)
        and an Adam optimizer over all of its parameters.
    """
    model = Enc_eta()
    if CUDA:
        model.cuda().to(gpu2)
    params = list(model.parameters())
    adam = torch.optim.Adam(params, lr=LEARNING_RATE, betas=(0.9, 0.99))
    return model, adam

In [5]:
# Create the encoder and its optimizer. Re-running this cell replaces enc_eta with a
# freshly initialised model, discarding any trained weights.
enc_eta, optimizer = initialize()

In [6]:
# Per-epoch averages of each monitored quantity (one entry per completed epoch).
EUBOs = []
ELBOs = []
ESSs = []
KLs_ex = []
KLs_in = []

LOG_FILE = 'results/log-' + PATH + '.txt'
# open(..., 'w+') fails if the results directory does not exist yet.
os.makedirs('results', exist_ok=True)
with open(LOG_FILE, 'w+') as flog:
    flog.write('EUBO\tELBO\tESS\tAKLs_ex\tAKLs_in\n')

time_start = time.time()
for epoch in range(NUM_EPOCHS):
    indices = torch.randperm(NUM_SEQS)
    EUBO = ELBO = ESS = AKL_ex = AKL_in = 0.0
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        obs = Xs[batch_indices]
        states = STATES[batch_indices]
        # Observations and states side by side on the last axis, tiled SAMPLE_SIZE times
        # along the leading (sample) dimension.
        data = shuffler(torch.cat((obs, states), -1)).repeat(SAMPLE_SIZE, 1, 1, 1)
        if CUDA:
            data = data.cuda().to(gpu2)
        eubo, elbo, ess, akl_ex, akl_in = Eubo_eta_ng_stat(enc_eta, data, K, D, SAMPLE_SIZE, BATCH_SIZE)
        ## gradient step: only the EUBO is optimised; the other values are monitored
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
        AKL_ex += akl_ex.item()
        AKL_in += akl_in.item()

    epoch_means = (EUBO / NUM_BATCHES, ELBO / NUM_BATCHES, ESS / NUM_BATCHES,
                   AKL_ex / NUM_BATCHES, AKL_in / NUM_BATCHES)
    # Keep every history list in sync (the KL lists were previously never filled).
    for history, value in zip((EUBOs, ELBOs, ESSs, KLs_ex, KLs_in), epoch_means):
        history.append(value)

    with open(LOG_FILE, 'a+') as flog:
        print('%.3f\t%.3f\t%.3f\t%.3f\t%.3f' % epoch_means, file=flog)

    if epoch % 10 == 0:
        time_end = time.time()
        print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f, KL_ex=%.3f, KL_in=%.3f (%ds)'
              % ((epoch,) + epoch_means + (time_end - time_start,)))
        time_start = time.time()

epoch=0, EUBO=-1004.481, ELBO=-1460.280, ESS=1.114, KL_ex=1260.018, KL_in=536.449 (0s)
epoch=10, EUBO=-640.462, ELBO=-901.980, ESS=1.180, KL_ex=702.595, KL_in=522.360 (9s)
epoch=20, EUBO=-449.476, ELBO=-586.835, ESS=1.280, KL_ex=384.861, KL_in=509.080 (9s)
epoch=30, EUBO=-366.685, ELBO=-453.201, ESS=1.374, KL_ex=251.687, KL_in=487.793 (8s)
epoch=40, EUBO=-324.389, ELBO=-386.350, ESS=1.432, KL_ex=185.360, KL_in=444.873 (9s)
epoch=50, EUBO=-292.813, ELBO=-338.102, ESS=1.527, KL_ex=136.823, KL_in=386.960 (9s)
epoch=60, EUBO=-269.232, ELBO=-301.377, ESS=1.621, KL_ex=100.201, KL_in=329.269 (9s)
epoch=70, EUBO=-253.182, ELBO=-277.657, ESS=1.689, KL_ex=76.465, KL_in=278.715 (9s)
epoch=80, EUBO=-243.284, ELBO=-263.431, ESS=1.801, KL_ex=62.406, KL_in=235.378 (8s)
epoch=90, EUBO=-237.312, ELBO=-255.314, ESS=1.906, KL_ex=54.176, KL_in=196.562 (9s)
epoch=100, EUBO=-231.829, ELBO=-248.677, ESS=1.988, KL_ex=47.570, KL_in=162.299 (9s)
epoch=110, EUBO=-227.003, ELBO=-242.654, ESS=2.086, KL_ex=41.528, 

KeyboardInterrupt: 

In [None]:
# Persist the trained encoder weights to weights/enc-<PATH>. The '%s' placeholder must be
# %-formatted; concatenating PATH onto it left a literal '%s' in the file name.
os.makedirs('weights', exist_ok=True)
torch.save(enc_eta.state_dict(), 'weights/enc-%s' % PATH)

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, num_samples, num_epochs, lr,
                 KLs_ex=None, KLs_in=None, ylim=(-240, -100)):
    """Plot training curves and save them to 'train_<PATH>.svg'.

    Args:
        EUBOs, ELBOs: per-epoch average bounds, drawn in the top panel.
        ESSs: per-epoch average effective sample size; divided by num_samples and
            drawn in the bottom panel with y-limits [0, 1].
        num_samples: samples per batch (normaliser for ESS, shown in the title).
        num_epochs: epoch count shown in the title.
        lr: learning rate shown in the title.
        KLs_ex, KLs_in: optional per-epoch KL averages, drawn in the middle panel
            (which stays empty when neither is given).
        ylim: y-limits of the bound panel; None lets matplotlib autoscale.
    """
    fig = plt.figure(figsize=(15, 15))
    ax1 = fig.add_subplot(3, 1, 1)
    ax2 = fig.add_subplot(3, 1, 2)
    ax3 = fig.add_subplot(3, 1, 3)

    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')
    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, BATCH_SIZE, lr, num_samples), fontsize=18)
    if ylim is not None:
        ax1.set_ylim(list(ylim))
    ax1.legend()
    ax1.tick_params(labelsize=18)

    if KLs_ex is not None:
        ax2.plot(KLs_ex, 'g', label='KL_ex')
    if KLs_in is not None:
        ax2.plot(KLs_in, 'c', label='KL_in')
    if KLs_ex is not None or KLs_in is not None:
        ax2.legend()
    ax2.tick_params(labelsize=18)

    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')
    ax3.set_ylim([0, 1])
    ax3.legend()
    ax3.tick_params(labelsize=18)
    ax3.set_xlabel('epoch', fontsize=18)

    # tight_layout only has an effect once the axes exist.
    fig.tight_layout()
    plt.savefig('train_' + PATH + '.svg')

In [None]:
# SAMPLE_SIZE is the per-batch sample count used in training; len(EUBOs) is the number
# of epochs actually completed (training may have been interrupted before NUM_EPOCHS).
plot_results(EUBOs, ELBOs, ESSs, SAMPLE_SIZE, len(EUBOs), LEARNING_RATE)

In [None]:
def sample_single_batch(num_seqs, N, K, D, batch_size, num_samples=None):
    """Draw one random batch of sequences, prepared the same way as in training.

    Args:
        num_seqs: number of sequences to draw the batch from.
        N, K: unused; kept for call compatibility.
        D: observation dimensionality. The first D features of each row are
            observations and the remaining ones are the states.
        batch_size: number of sequences in the batch.
        num_samples: how many times the batch is tiled along the leading sample
            dimension; defaults to SAMPLE_SIZE.

    Returns:
        (obs, states): the shuffled, tiled observations and states, placed on `gpu2`
        when CUDA is available (the model's device), otherwise on the CPU.
    """
    if num_samples is None:
        num_samples = SAMPLE_SIZE
    batch_indices = torch.randperm(num_seqs)[:batch_size]
    data = shuffler(torch.cat((Xs[batch_indices], STATES[batch_indices]), -1)).repeat(num_samples, 1, 1, 1)
    if CUDA:
        # Same device as the model; plain .cuda() would land on the default GPU instead.
        data = data.to(gpu2)
    obs = data[..., :D]
    states = data[..., D:]
    return obs, states

obs, states = sample_single_batch(NUM_SEQS, N, K, D, batch_size=25)
# Enc_eta.forward takes one tensor: observations and states concatenated on the last
# axis, as in the training loop (it does not accept separate statistics).
eval_data = torch.cat((obs, states), -1)
if CUDA:
    eval_data = eval_data.to(gpu2)
# Inference only; the samples are used for plotting, so no autograd graph is needed.
with torch.no_grad():
    q, p, _ = enc_eta(eval_data)

In [None]:
def plot_samples(obs, states, q, batch_size):
    """Scatter every dataset in the batch, coloured by assignment, with inferred cluster ellipses.

    Uses the first sample (index 0 of the leading dimension) of obs, states and of q's
    'means' / 'precisions' distributions. Cluster k is drawn as a 2-std ellipse centred at
    the Normal's loc, with diagonal covariance 1 / E[precision], where
    E[precision] = concentration / rate. Panels are laid out 5 per row, and the figure is
    saved to 'results/modes-<PATH>.svg'.
    """
    colors = ['r', 'b', 'gold']
    ncols = 5
    nrows = (batch_size + ncols - 1) // ncols  # ceil, so every panel gets a grid slot
    fig = plt.figure(figsize=(25, 25))
    xs = obs[0].cpu()
    zs = states[0].cpu()
    mu_means = q['means'].dist.loc[0].cpu().data.numpy()
    precision_dist = q['precisions'].dist
    tau_means = (precision_dist.concentration[0] / precision_dist.rate[0]).cpu().data.numpy()
    for b in range(batch_size):
        ax = fig.add_subplot(nrows, ncols, b + 1)
        x = xs[b].data.numpy()
        assignments = zs[b].data.numpy().argmax(-1)
        mu = mu_means[b].reshape(K, D)
        variances = 1. / tau_means[b]  # variance = 1 / precision, per cluster and dimension
        for k in range(K):
            color = colors[k % len(colors)]
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=color)
            # The covariance diagonal holds variances; squaring them would draw sigma^4.
            plot_cov_ellipse(cov=np.diag(variances[k]), pos=mu[k], nstd=2, ax=ax, alpha=0.2, color=color)
        ax.set_ylim([-15, 15])
        ax.set_xlim([-15, 15])
    os.makedirs('results', exist_ok=True)
    plt.savefig('results/modes-' + PATH + '.svg')

In [None]:
# Plot each sampled dataset. batch_size should equal the batch_size passed to
# sample_single_batch, because plot_samples indexes obs/states/q by batch position.
plot_samples(obs, states, q, batch_size=25)