In [1]:
%matplotlib inline
import sys
sys.path.append("../")
sys.path.append('/home/hao/Research/probtorch/')
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from plots import *
from utils import *
from objectives import *
from model import *
import time
import probtorch
# Report library versions and GPU availability so results can be tied to an environment.
_version_report = [
    ('probtorch:', probtorch.__version__),
    ('torch:', torch.__version__),
    ('cuda:', torch.cuda.is_available()),
]
print(*[field for pair in _version_report for field in pair])

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
## Data dimensions
N = 300  # points per batch element (z is sampled with shape (S, B, N) in `test`)
K = 3    # number of clusters
D = 2    # observation dimensionality (points are plotted as x[:, 0], x[:, 1])
## Model Parameters
MCMC_SIZE = 2        # iterations passed to the objective during training
SAMPLE_SIZE = 10     # samples per batch element (the batch is repeated this many times)
NUM_HIDDEN_GLOBAL = 8
NUM_HIDDEN_LOCAL = 64
STAT_SIZE = 8
## Training Parameters
BATCH_SIZE = 50
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-3
# Seed every RNG used here (torch.randperm, model init, numpy) so runs are reproducible.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)
CUDA = torch.cuda.is_available()
# Fall back to CPU so the notebook still runs on a machine without a GPU;
# a hard-coded 'cuda:0' makes every `device=DEVICE` call fail there.
DEVICE = torch.device('cuda:0' if CUDA else 'cpu')
PATH = 'ag-sis-mu-rad-v1'  # tag used in the results/ log and figure file names

In [3]:
## Load the ring dataset and its per-sequence radii (both float32 tensors).
# NOTE(review): presumably OBS_RAD is indexed by sequence like Xs — confirm.
Xs = torch.from_numpy(np.load('rings_varying_radius/obs.npy')).float()
OBS_RAD = torch.from_numpy(np.load('rings_varying_radius/obs_rad.npy')).float()
NUM_SEQS = Xs.shape[0]
# Only full batches are used (any trailing remainder is dropped) because the
# objective is always called with a fixed BATCH_SIZE.
NUM_BATCHES = NUM_SEQS // BATCH_SIZE
# The training loop divides by NUM_BATCHES, so zero batches must fail loudly here.
assert NUM_BATCHES > 0, 'need at least BATCH_SIZE=%d sequences, found %d' % (BATCH_SIZE, NUM_SEQS)

In [4]:
# Build the two encoders and the optimizer via the star-imported `initialize` helper.
enc_mu_rad, enc_z, optimizer = initialize(
    NUM_HIDDEN_GLOBAL,
    STAT_SIZE,
    NUM_HIDDEN_LOCAL,
    K,
    D,
    CUDA,
    DEVICE,
    LEARNING_RATE,
)

In [5]:
## Per-epoch averages of the logged objectives (one entry per epoch).
EUBOs = []
ELBOs = []
ESSs = []
log_path = 'results/log-' + PATH + '.txt'

# Start a fresh tab-separated log with a header; one row is appended per epoch.
# `with` guarantees the file is closed even if training is interrupted.
with open(log_path, 'w+') as flog:
    flog.write('EUBO\tELBO\tESS\n')

for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    indices = torch.randperm(NUM_SEQS)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(NUM_BATCHES):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        obs = Xs[batch_indices]
        # Replicate the batch SAMPLE_SIZE times along a leading dimension
        # (presumably (B, N, D) -> (S, B, N, D) — confirm shuffler's output rank).
        obs = shuffler(obs).repeat(SAMPLE_SIZE, 1, 1, 1)
        if CUDA:
            obs = obs.to(DEVICE)
        eubos, elbos, esss = Eubo_ag_rad(enc_mu_rad, enc_z, obs, N, K, D, MCMC_SIZE, SAMPLE_SIZE, BATCH_SIZE, noise_sigma=0.05, device=DEVICE)
        ## gradient step: optimise the mean of all returned EUBO terms
        eubos.mean().backward()
        optimizer.step()
        # Only the last entry of each returned sequence is logged
        # (presumably the final MCMC iteration — confirm against Eubo_ag_rad).
        EUBO += eubos[-1].item()
        ELBO += elbos[-1].item()
        ESS += esss[-1].item()
    eubo_avg = EUBO / NUM_BATCHES
    elbo_avg = ELBO / NUM_BATCHES
    ess_avg = ESS / NUM_BATCHES
    EUBOs.append(eubo_avg)
    ELBOs.append(elbo_avg)
    ESSs.append(ess_avg)
    with open(log_path, 'a+') as flog:
        print('%.3f\t%.3f\t%.3f'
                % (eubo_avg, elbo_avg, ess_avg), file=flog)
    time_end = time.time()
    print('epoch=%d, EUBO=%.3f, ELBO=%.3f, ESS=%.3f (%ds)'
            % (epoch, eubo_avg, elbo_avg, ess_avg,
               time_end - time_start))


epoch=0, EUBO=-148730.119, ELBO=-655732.381, ESS=2.236 (3s)
epoch=1, EUBO=-92651.178, ELBO=-602439.098, ESS=2.291 (2s)
epoch=2, EUBO=-79593.036, ELBO=-516023.454, ESS=2.332 (2s)
epoch=3, EUBO=-71216.943, ELBO=-385058.565, ESS=2.401 (2s)
epoch=4, EUBO=-65700.604, ELBO=-275792.336, ESS=2.487 (2s)
epoch=5, EUBO=-61352.627, ELBO=-201508.343, ESS=2.576 (3s)
epoch=6, EUBO=-56456.803, ELBO=-156349.119, ESS=2.670 (2s)
epoch=7, EUBO=-51606.915, ELBO=-133100.483, ESS=2.748 (2s)
epoch=8, EUBO=-48208.921, ELBO=-119784.412, ESS=2.834 (2s)
epoch=9, EUBO=-46492.433, ELBO=-111494.640, ESS=2.948 (2s)
epoch=10, EUBO=-43899.554, ELBO=-102523.092, ESS=3.101 (3s)
epoch=11, EUBO=-42569.646, ELBO=-95795.923, ESS=3.239 (2s)
epoch=12, EUBO=-41580.292, ELBO=-89091.421, ESS=3.384 (2s)
epoch=13, EUBO=-41459.799, ELBO=-87644.613, ESS=3.518 (2s)
epoch=14, EUBO=-40819.312, ELBO=-83616.675, ESS=3.653 (2s)
epoch=15, EUBO=-40723.207, ELBO=-81089.498, ESS=3.756 (3s)
epoch=16, EUBO=-39528.583, ELBO=-77619.377, ESS=3.872 

KeyboardInterrupt: 

In [None]:
BATCH_SIZE_TEST = 50  # batch elements drawn for the qualitative evaluation below

def test(enc_mu, enc_z, obs, obs_rad, N, K, D, mcmc_size, sample_size, batch_size, gpu, noise_sigma=0.1):
    """Run `mcmc_size` alternating updates of cluster means and assignments.

    Assignments z start as a draw from the prior `cat(enc_z.prior_pi)` with
    shape (sample_size, batch_size, N). Each iteration:
      1. proposes means from `enc_mu` given (obs, z), importance-weights them
         across the sample dimension (dim 0) and resamples with `resample_mu`;
      2. proposes new assignments from `enc_z` given the resampled means.

    Args:
        enc_mu, enc_z: encoders returning (q, p) trace pairs.
        obs, obs_rad: observations and radii (presumably (S, B, N, D) and a
            matching radius tensor — confirm against sample_single_batch).
        N, K, D: points per batch element, number of clusters, data dimension.
        mcmc_size: number of iterations; must be >= 1.
        sample_size, batch_size: S and B.
        gpu: forwarded to True_Log_likelihood.
        noise_sigma: observation noise passed to True_Log_likelihood.

    Returns:
        (q_mu, q_z): the mean and assignment traces from the final iteration.
        q_mu holds the proposal before resampling.

    Raises:
        ValueError: if mcmc_size < 1 (no iteration would produce traces to return).
    """
    if mcmc_size < 1:
        raise ValueError('mcmc_size must be >= 1, got %s' % (mcmc_size,))
    p_init_z = cat(enc_z.prior_pi)
    states = p_init_z.sample((sample_size, batch_size, N,))
    for _ in range(mcmc_size):
        q_mu, p_mu = enc_mu(obs, states, sample_size, batch_size)
        log_q_mu = q_mu['means'].log_prob.sum(-1)
        log_p_mu = p_mu['means'].log_prob.sum(-1) # S * B * K
        obs_mu = q_mu['means'].value
        log_obs_k = True_Log_likelihood(obs, states, obs_mu, obs_rad, K, D, noise_sigma=noise_sigma, gpu=gpu, cluster_flag=True)
        log_weights_global = log_obs_k + log_p_mu - log_q_mu
        # Normalise importance weights over the sample dimension (dim 0).
        weights_global = torch.softmax(log_weights_global, 0).detach()
        ## resample mu
        obs_mu = resample_mu(obs_mu, weights_global)
        ## update z -- cluster assignments
        q_z, _ = enc_z(obs, obs_mu, obs_rad, sample_size, batch_size, decay_factor=0)
        # NOTE(review): the previous version also computed local importance
        # weights here but never used them, so z is not resampled. Add a
        # resampling step here if that was the intent.
        states = q_z['zs'].value

    return q_mu, q_z
def plot_samples(obs, q_eta, q_z, K, batch_size, PATH):
    """Plot inferred clusters for each batch element and save the figure as SVG.

    Uses sample index 0 of every tensor. For batch element b, points are coloured
    by the argmax of q_z's assignment probabilities. Each cluster mean's posterior
    is drawn as a 2-std ellipse from q_eta's loc/scale.

    Args:
        obs: observations whose first index is the sample dimension.
        q_eta: trace with a 'means' node exposing .dist.loc / .dist.scale.
        q_z: trace with a 'zs' node exposing .dist.probs.
        K: number of clusters to draw.
        batch_size: number of batch elements (one subplot each, 5 per row).
        PATH: tag for the output file results/modes-<PATH>.svg.
    """
    colors = ['r', 'b', 'g']
    ncols = 5
    # Ceiling division so a partial last row still gets axes (int(b/5) broke add_subplot).
    nrows = (batch_size + ncols - 1) // ncols
    # 5 inches per row: identical to the original 25x50 figure for batch_size=50.
    fig = plt.figure(figsize=(25, 5 * nrows))
    xs = obs[0].cpu().data.numpy()
    mu_mu = q_eta['means'].dist.loc[0].cpu().data.numpy()
    mu_sigma = q_eta['means'].dist.scale[0].cpu().data.numpy()
    zs = q_z['zs'].dist.probs[0].cpu().data.numpy()
    for b in range(batch_size):
        ax = fig.add_subplot(nrows, ncols, b + 1)
        x = xs[b]
        assignments = zs[b].argmax(-1)
        for k in range(K):
            # Cycle colours so K > len(colors) does not raise IndexError.
            color = colors[k % len(colors)]
            cov_k = np.diag(mu_sigma[b][k] ** 2)
            xk = x[assignments == k]
            ax.scatter(xk[:, 0], xk[:, 1], c=color, alpha=0.2)
            plot_cov_ellipse(cov=cov_k, pos=mu_mu[b][k], nstd=2, ax=ax, alpha=1.0, color=color)
        ax.set_ylim([-8, 8])
        ax.set_xlim([-8, 8])
    fig.savefig('results/modes-' + PATH + '.svg')
obs, obs_rad = sample_single_batch(NUM_SEQS, Xs, OBS_RAD, SAMPLE_SIZE, BATCH_SIZE_TEST, gpu)
q_eta, q_z = test(enc_mu, enc_z, obs, obs_rad, N, K, D, 20, SAMPLE_SIZE, BATCH_SIZE_TEST, gpu)
%time plot_samples(obs, q_eta, q_z, K, BATCH_SIZE_TEST, PATH)