In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
from torch import logsumexp
import sys
import os
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [15]:
K = 3
D = 2

## Model Parameters
NUM_SAMPLES_1 = 11
NUM_SAMPLES_2 = 10
NUM_HIDDEN_PIXELS = 256
NUM_HIDDEN = 64
STEPS = 10
NUM_STATS = K+D*K+D*K
NUM_LATENTS = D * K
NUM_OBS_GLOBAL = D + K
NUM_OBS_LOCAL = D + K*D + K*D
NUM_PIXELS = 28*28
BATCH_SIZE = 99
N = BATCH_SIZE
NUM_EPOCHS = 5000
LEARNING_RATE = 1e-4
CUDA = False

In [16]:
from torchvision import datasets, transforms
DATA_PATH = 'MNIST/'
if not os.path.isdir(DATA_PATH):
    os.makedirs(DATA_PATH)
train_data = torch.utils.data.DataLoader(
                datasets.MNIST(DATA_PATH, train=True, download=True,
                               transform=transforms.ToTensor()),
                batch_size=BATCH_SIZE, shuffle=True) 
# test_data = torch.utils.data.DataLoader(
#                 datasets.MNIST(DATA_PATH, train=False, download=True,
#                                transform=transforms.ToTensor()),
#                 batch_size=BATCH_SIZE, shuffle=True)

indices_0 = train_data.dataset.train_labels.eq(0).nonzero().squeeze()
indices_1 = train_data.dataset.train_labels.eq(1).nonzero().squeeze()
indices_2 = train_data.dataset.train_labels.eq(2).nonzero().squeeze()
truncate_index = int(torch.FloatTensor([indices_0.shape[0], indices_1.shape[0], indices_2.shape[0]]).min().item())
images_0 = train_data.dataset.train_data[indices_0][:truncate_index]
images_1 = train_data.dataset.train_data[indices_1][:truncate_index]
images_2 = train_data.dataset.train_data[indices_2][:truncate_index]

In [17]:
class Encoder_latent(nn.Module):
    def __init__(self, num_obs=NUM_PIXELS,
                       num_hidden=NUM_HIDDEN_PIXELS,
                       num_latents=2):
        super(self.__class__, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU())
        self.latent_mean = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.latent_log_std = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        
    def forward(self, images, num_samples):
        hidden = self.enc_hidden(images)
        latent_mean = self.latent_mean(hidden)
        latent_std = torch.exp(self.latent_log_std(hidden))
        xs = Normal(latent_mean, latent_std).sample((num_samples,))
        log_q_x = Normal(latent_mean, latent_std).log_prob(xs).sum(-1) # S * B
        return latent_mean, latent_std, xs, log_q_x
    
class Decoder_latent(nn.Module):
    def __init__(self, num_obs=NUM_PIXELS,
                       num_hidden=NUM_HIDDEN_PIXELS,
                       num_latents=2, 
                       batch_size=BATCH_SIZE):
        super(self.__class__, self).__init__()
        self.latent_mean = torch.zeros((batch_size, num_latents))
        self.latent_std = torch.ones((batch_size, num_latents))
        self.dec_image = nn.Sequential(
            nn.Linear(num_latents, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_obs),
            nn.Sigmoid())
        
    def forward(self, xs, images, num_samples):
        log_prior_x = Normal(self.latent_mean, self.latent_std).log_prob(xs).sum(-1) # S * B
        recon = self.dec_image(xs)
        log_likelihood = - BCE(images, recon)
        log_p_joint = log_prior_x + log_likelihood
        return recon, log_likelihood, log_p_joint
    
    
def BCE(x, x_hat, EPS=1e-8):
    return -(torch.log(x_hat + EPS) * x 
             + torch.log(1 - x_hat + EPS) * (1-x)).sum(-1)
    
class Encoder_global(nn.Module):
    def __init__(self, num_obs=NUM_OBS_GLOBAL,
                       num_stats=NUM_STATS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        super(self.__class__, self).__init__()
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_stats))
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_stats, num_hidden),
            nn.ReLU())
        self.sigmas_log_alpha = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.sigmas_log_beta = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        
        self.enc_hidden2 = nn.Sequential(
            nn.Linear(num_stats, num_hidden),
            nn.ReLU())
        self.mus_mean = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.mus_log_std = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
                
    def forward(self, obs, K, D, num_samples, batch_size):
        stats = self.enc_stats(obs).view(batch_size, N, -1).sum(1)
        hidden = self.enc_hidden(stats)
        alpha = torch.exp(self.sigmas_log_alpha(hidden)).view(-1, K, D) ## B * K * D
        beta = torch.exp(self.sigmas_log_beta(hidden)).view(-1, K, D) ## B * K * D
        precisions = Gamma(alpha, beta).sample((num_samples,)) ## S * B * K * D
        
        hidden2 = self.enc_hidden2(stats)                 
        mus_mean = self.mus_mean(hidden2).view(-1, K, D)
        mus_sigma = torch.exp(self.mus_log_std(hidden2).view(-1, K, D))
        mus = Normal(mus_mean, mus_sigma).sample((num_samples,))  
        return alpha, beta, precisions, mus_mean, mus_sigma, mus
    
class Encoder_local(nn.Module):
    def __init__(self, num_obs=NUM_OBS_LOCAL,
                       num_hidden=NUM_HIDDEN,
                       num_latents=K):
        super(self.__class__, self).__init__()
        self.enc_onehot = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_latents),
            nn.Softmax(-1))
        
    def forward(self, obs, N, K, D, num_samples, batch_size):
        zs_pi = self.enc_onehot(obs).view(batch_size, N, K)
        zs = cat(zs_pi).sample((num_samples,))
        log_qz = cat(zs_pi).log_prob(zs).view(num_samples, batch_size, -1).sum(-1) ## S * B
        zs = zs.view(num_samples, batch_size, -1, K) ## S * B * N * K
        return zs_pi, zs, log_qz

In [27]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 1e-2)

def initialize():
    enc_latent = Encoder_latent()
    dec_latent = Decoder_latent()
    enc_global = Encoder_global()
    enc_local = Encoder_local()
#     enc_latent.apply(weights_init)
#     dec_latent.apply(weights_init)
    enc_global.apply(weights_init)
#     enc_local.apply(weights_init)
    optimizer1 =  torch.optim.Adam(list(enc_latent.parameters()) + list(enc_global.parameters()) + list(enc_local.parameters()),lr=LEARNING_RATE, betas=(0.9, 0.99))    
    optimizer2 = torch.optim.Adam(list(dec_latent.parameters()), lr=LEARNING_RATE, betas=(0.9, 0.99))
    return enc_latent, dec_latent, enc_global, enc_local, optimizer1, optimizer2

enc_latent, dec_latent, enc_global, enc_local, optimizer1, optimizer2 = initialize()

In [28]:
prior_mean = torch.zeros((NUM_SAMPLES_2, K, D))
prior_sigma = torch.ones((NUM_SAMPLES_2, K, D))
prior_alpha = torch.ones((NUM_SAMPLES_2, K, D)) * 2.0
prior_beta = torch.ones((NUM_SAMPLES_2, K, D)) * 2.0
Pi = torch.FloatTensor([1./ 3, 1./3, 1./3])
def log_joints_gmm(X, Z, Pi, mus, precisions, N, D, K, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size):
    log_probs = torch.zeros(batch_size).float()
    ## priors on mus and sigmas, S * B
    log_probs = log_probs + Normal(prior_mean, prior_sigma).log_prob(mus).sum(-1).sum(-1)
    log_probs = log_probs + Gamma(prior_alpha, prior_beta).log_prob(precisions).sum(-1).sum(-1)
    ## Z B-by-T-by-K
    log_probs = log_probs + cat(Pi).log_prob(Z).sum(-1)
    labels = Z.nonzero()
    sigmas = 1. / torch.sqrt(precisions)
    log_probs = log_probs + Normal(mus[labels[:, 0], labels[:, -1], :].view(batch_size, N, D), 
                                   sigmas[labels[:, 0], labels[:, -1], :].view(batch_size, N, D)).log_prob(X).sum(-1).sum(-1)
    return log_probs

def inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size):
    mus = Normal(prior_mean, prior_sigma).sample()
    precisions = Gamma(prior_alpha, prior_beta).sample()
    ## log prior size B
    log_p =  Normal(prior_mean, prior_sigma).log_prob(mus).sum(-1).sum(-1) + Gamma(prior_alpha, prior_beta).log_prob(precisions).sum(-1).sum(-1)
    return mus, precisions, log_p

def E_step(X, mus, precisions, N, D, K, batch_size):
    mus_flat = mus.view(-1, K*D).unsqueeze(1).repeat(1, N, 1)
    covs = 1. / torch.sqrt(precisions)
    covs_flat = covs.view(-1, K*D).unsqueeze(1).repeat(1, N, 1)
    data = torch.cat((X, mus_flat, covs_flat), -1).view(batch_size*N, -1)
    zs_pi, zs, log_q_z = enc_local(data, N, K, D, 1, batch_size)
    return zs_pi, zs[0], log_q_z[0]

def M_step(X, z, N, D, K, batch_size):
    data = torch.cat((X, z), dim=-1).view(batch_size*N, -1)
    alpha, beta, precisions, mus_mean, mus_sigma, mus = enc_global(data, K, D, 1, batch_size)            
    log_q_eta =  Normal(mus_mean, mus_sigma).log_prob(mus[0]).sum(-1).sum(-1) + Gamma(alpha, beta).log_prob(precisions[0]).sum(-1).sum(-1)## B
    return mus[0], precisions[0], log_q_eta, alpha, beta, mus_mean, mus_sigma
    
    
def rws(Ys, Pi, N, K, D, num_samples, steps, batch_size):
    """
    train both encoders
    rws gradient estimator
    sis sampling scheme
    no resampling
    """
    batch_size_gmm = num_samples
    latent_mean, latent_std, Xs, log_q_x = enc_latent(Ys, num_samples)
    recon, log_likelihood, log_p_joint = dec_latent(Xs, batch_images, num_samples)
    log_outer_weights = (log_p_joint - log_q_x).sum(-1)

    log_outer_ratio = (log_likelihood - log_q_x).sum(-1)
    log_outer_weights_normalized = log_outer_weights - logsumexp(log_outer_weights, 0)
    outer_weights = torch.exp(log_outer_weights_normalized).detach()
    ## gmm part
    log_increment_weights = torch.zeros((steps, num_samples, batch_size_gmm))
    Z_samples = torch.zeros((num_samples, batch_size_gmm, N, K))
    for m in range(steps):
        if m == 0:
            for l in range(num_samples):
                mus, precisions, log_p_eta = inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size_gmm)
                zs_pi, z, log_q_z = E_step(Xs, mus, precisions, N, D, K, batch_size_gmm)
                Z_samples[l] = z
                labels = z.nonzero()
                log_p_z = cat(Pi).log_prob(z).sum(-1)
                sigmas = 1. / torch.sqrt(precisions)
                log_p_x = Normal(mus[labels[:, 0], labels[:, -1], :].view(batch_size, N, D), sigmas[labels[:, 0], labels[:, -1], :].view(batch_size, N, D)).log_prob(Xs).sum(-1).sum(-1)
                log_increment_weights[m, l] = log_p_x + log_p_z - log_q_z     
        else:
            for l in range(num_samples):
                z_prev = Z_samples[l]
                mus, precisions, log_q_eta, alpha, beta, mus_mean, mus_sigma = M_step(Xs, z_prev, N, D, K, batch_size_gmm)
                zs_pi, z, log_q_z = E_step(Xs, mus, precisions, N, D, K, batch_size_gmm)
                Z_samples[l] = z
                log_p_joint = log_joints_gmm(Xs, z, Pi, mus, precisions, N, D, K, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size_gmm)
                log_increment_weights[m, l] = log_p_joint - log_q_z - log_q_eta

    increment_weights = torch.exp(log_increment_weights - logsumexp(log_increment_weights, 1).unsqueeze(1).repeat(1, num_samples, 1)).detach()
#     ess = (1./ (increment_weights ** 2).sum(1)).mean(0).mean()
    
    eubo = torch.mul(torch.mul(increment_weights, log_increment_weights).sum(1).mean(0), outer_weights) + torch.mul(outer_weights, log_outer_ratio)
    elbo = log_increment_weights.mean(1).mean(0) + log_outer_ratio
    
    return eubo.mean(), elbo.mean()

def shuffler(images_0, images_1, images_2, K, batch_size):
    data_size = int(images_0.shape[0])
    indices_0 = torch.randperm(data_size)
    indices_1 = torch.randperm(data_size)
    indices_2 = torch.randperm(data_size)
    images_0_shuffle = images_0[indices_0]
    images_1_shuffle = images_1[indices_1]
    images_2_shuffle = images_2[indices_2]
    images_concat = torch.cat((images_0_shuffle.unsqueeze(0), images_1_shuffle.unsqueeze(0), images_2_shuffle.unsqueeze(0)), 0)
    size_per_class = int(batch_size / K)
    return images_concat, size_per_class

In [29]:
images, size_per_class = shuffler(images_0, images_1, images_2, K, BATCH_SIZE)
batch_images = images[:, 0*size_per_class : (0+1)*size_per_class].contiguous().view(size_per_class*K, 28, 28).view(-1, NUM_PIXELS)
batch_images = batch_images[torch.randperm(size_per_class*K)].float()
# rws(batch_images, Pi, N, K, D, NUM_SAMPLES, STEPS, NUM_SAMPLES)
latent_mean, latent_std, Xs, log_q_x = enc_latent(batch_images, NUM_SAMPLES_1)
log_q_xs = log_q_x.sum(-1)
recon, log_likelihood, log_p_joint = dec_latent(Xs, batch_images, NUM_SAMPLES_1)
log_likelihoods = log_likelihood.sum(-1) ## of size S1
log_p_joints = log_p_joint.sum(-1)
# gmm part, we reshape the x,  S1 * B * D, can be used as a B*N*D 'batch' input to GMM, simply for the purpose of reducing computation complexity
N = BATCH_SIZE
log_increment_weights = torch.zeros((STEPS, NUM_SAMPLES_2, NUM_SAMPLES_1))
Z_samples = torch.zeros((NUM_SAMPLES_2, NUM_SAMPLES_1, N, K))
for m in range(STEPS):
    if m == 0:
        for l in range(NUM_SAMPLES_2):
            mus, precisions, log_p_eta = inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, NUM_SAMPLES_1)
            zs_pi, z, log_q_z = E_step(Xs, mus, precisions, N, D, K, NUM_SAMPLES_1)
            Z_samples[l] = z
            labels = z.nonzero()
            log_p_z = cat(Pi).log_prob(z).sum(-1)
            sigmas = 1. / torch.sqrt(precisions)
            log_p_x = Normal(mus[labels[:, 0], labels[:, -1], :].view(NUM_SAMPLES_1, N, D), sigmas[labels[:, 0], labels[:, -1], :].view(NUM_SAMPLES_1, N, D)).log_prob(Xs).sum(-1).sum(-1)
            log_increment_weights[m, l] = log_p_x + log_p_z - log_q_z     
#     else:
#         for l in range(num_samples):
#             z_prev = Z_samples[l]
#             mus, precisions, log_q_eta, alpha, beta, mus_mean, mus_sigma = M_step(Xs, z_prev, N, D, K, batch_size_gmm)
#             zs_pi, z, log_q_z = E_step(Xs, mus, precisions, N, D, K, batch_size_gmm)
#             Z_samples[l] = z
#             log_p_joint = log_joints_gmm(Xs, z, Pi, mus, precisions, N, D, K, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size_gmm)
#             log_increment_weights[m, l] = log_p_joint - log_q_z - log_q_eta


RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 2. Got 11 and 10 in dimension 0 at /pytorch/aten/src/TH/generic/THTensorMoreMath.cpp:1333

In [30]:
%debug

> [0;32m<ipython-input-28-ff9a8137e962>[0m(30)[0;36mE_step[0;34m()[0m
[0;32m     28 [0;31m    [0mcovs[0m [0;34m=[0m [0;36m1.[0m [0;34m/[0m [0mtorch[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0mprecisions[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0mcovs_flat[0m [0;34m=[0m [0mcovs[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mK[0m[0;34m*[0m[0mD[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mrepeat[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mN[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m    [0mdata[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m([0m[0mX[0m[0;34m,[0m [0mmus_flat[0m[0;34m,[0m [0mcovs_flat[0m[0;34m)[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbatch_size[0m[0;34m*[0m[0mN[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;

ipdb>  mus_flat.shape


torch.Size([10, 99, 6])


ipdb>  u


> [0;32m<ipython-input-29-faeaa1aed50c>[0m(18)[0;36m<module>[0;34m()[0m
[0;32m     16 [0;31m        [0;32mfor[0m [0ml[0m [0;32min[0m [0mrange[0m[0;34m([0m[0mNUM_SAMPLES_2[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m            [0mmus[0m[0;34m,[0m [0mprecisions[0m[0;34m,[0m [0mlog_p_eta[0m [0;34m=[0m [0minti_global[0m[0;34m([0m[0mK[0m[0;34m,[0m [0mD[0m[0;34m,[0m [0mprior_mean[0m[0;34m,[0m [0mprior_sigma[0m[0;34m,[0m [0mprior_alpha[0m[0;34m,[0m [0mprior_beta[0m[0;34m,[0m [0mNUM_SAMPLES_1[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 18 [0;31m            [0mzs_pi[0m[0;34m,[0m [0mz[0m[0;34m,[0m [0mlog_q_z[0m [0;34m=[0m [0mE_step[0m[0;34m([0m[0mXs[0m[0;34m,[0m [0mmus[0m[0;34m,[0m [0mprecisions[0m[0;34m,[0m [0mN[0m[0;34m,[0m [0mD[0m[0;34m,[0m [0mK[0m[0;34m,[0m [0mNUM_SAMPLES_1[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m            [0mZ_samples[0m[0;3

ipdb>  d


> [0;32m<ipython-input-28-ff9a8137e962>[0m(30)[0;36mE_step[0;34m()[0m
[0;32m     28 [0;31m    [0mcovs[0m [0;34m=[0m [0;36m1.[0m [0;34m/[0m [0mtorch[0m[0;34m.[0m[0msqrt[0m[0;34m([0m[0mprecisions[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m    [0mcovs_flat[0m [0;34m=[0m [0mcovs[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mK[0m[0;34m*[0m[0mD[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mrepeat[0m[0;34m([0m[0;36m1[0m[0;34m,[0m [0mN[0m[0;34m,[0m [0;36m1[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m    [0mdata[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m([0m[0mX[0m[0;34m,[0m [0mmus_flat[0m[0;34m,[0m [0mcovs_flat[0m[0;34m)[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbatch_size[0m[0;34m*[0m[0mN[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;

ipdb>  q


In [22]:
x.shape

torch.Size([1, 99, 2])

In [None]:
EUBOs = []
ELBOs = []
num_batches = int(truncate_index / BATCH_SIZE)

for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    images, size_per_class = shuffler(images_0, images_1, images_2, K, BATCH_SIZE)
    EUBO = 0.0
    ELBO = 0.0
    for step in range(num_batches):
        batch_images = images[:, step*size_per_class : (step+1)*size_per_class].contiguous().view(size_per_class*K, 28, 28).view(-1, NUM_PIXELS)
        batch_images = batch_images[torch.randperm(size_per_class*K)].float()
        optimizer1.zero_grad()
        eubo, elbo = rws(batch_images, Pi, N, K, D, NUM_SAMPLES, STEPS, NUM_SAMPLES)
        eubo.backward()
        optimizer1.step()
        optimizer2.zero_grad()
        eubo, elbo = rws(batch_images, Pi, N, K, D, NUM_SAMPLES, STEPS, NUM_SAMPLES)
        loss = - elbo
        loss.backward()
        optimizer2.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
    EUBO /= num_batches
    ELBO /= num_batches
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)

    time_end = time.time()
    print('epoch=%d, EUBO=%f, ELBO=%f (%ds)' % (epoch, EUBO, ELBO, time_end - time_start))

In [None]:
def save_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE):
    fout = open('local_amorgibbs-steps=%d-samples=%d-lr=%d.txt' % (STEPS, NUM_SAMPLES, LEARNING_RATE), 'w+')
    fout.write('EUBOs, ELBOs, ESSs\n')
    for i in range(len(EUBOs)):
        fout.write(str(EUBOs[i]) + ', ' + str(ELBOs[i]) + ', ' + str(ESSs[i]) + '\n')
    fout.close()
# torch.save(enc.state_dict(), 'models/local_amorgibbs-steps=%d-samples=%d-lr=%d' % (STEPS, NUM_SAMPLES, LEARNING_RATE))
save_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, num_samples, num_epochs, lr):
    fig = plt.figure(figsize=(30, 30))
    fig.tight_layout()
    ax1 = fig.add_subplot(2, 1, 1)
    ax2 = fig.add_subplot(2, 1, 2)
    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')
    ax1.tick_params(labelsize=18)
    ax2.plot(np.array(ESSs) / num_samples, 'm', label='ESS')
    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, BATCH_SIZE, lr, num_samples), fontsize=18)
    ax1.set_ylim([-200, -150])
    ax1.legend()
    ax2.legend()
    ax2.tick_params(labelsize=18)
    plt.savefig('local_gibbs_results_learn_both_lr=%.1E_samples=%d.svg' % (lr, num_samples))

In [None]:
plot_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)

In [None]:
torch.save(enc_global.state_dict(), 'models/enc-global-steps=%d-samples=%d-lr=%d' % (STEPS, NUM_SAMPLES, LEARNING_RATE))
torch.save(enc_local.state_dict(), 'models/enc-local-steps=%d-samples=%d-lr=%d' % (STEPS, NUM_SAMPLES, LEARNING_RATE))

In [None]:
STEPS = 10
BATCH_SIZE = 50

prior_mean = torch.zeros((BATCH_SIZE, K, D))
prior_sigma = torch.ones((BATCH_SIZE, K, D))
prior_alpha = torch.ones((BATCH_SIZE, K, D)) * 2.0
prior_beta = torch.ones((BATCH_SIZE, K, D)) * 2.0


def sample_single_batch(num_seqs, N, K, D, batch_size):
    indices = torch.randperm(num_seqs)
    batch_indices = indices[0*batch_size : (0+1)*batch_size]
    batch_Xs = Xs[batch_indices]
    batch_Xs = shuffler(batch_Xs, N, K, D, batch_size)
    return batch_Xs

def test(num_seqs, Pi, N, K, D, steps, batch_size):
    LLs = [] 
    x = sample_single_batch(num_seqs, N, K, D, batch_size)
    for m in range(steps):
        if m == 0:
            mus, precisions, log_p_eta = inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size)
            zs_pi, z, log_q_z = E_step(x, mus, precisions, N, D, K, batch_size)
        else:
            mus, precisions, log_q_eta, alpha, beta, mus_mean, mus_sigma = M_step(x, z, N, D, K, batch_size)
            zs_pi, z, log_q_z = E_step(x, mus, precisions, N, D, K, batch_size)
            labels = z.nonzero()
            sigmas = 1. / torch.sqrt(precisions)
            ll = Normal(mus[labels[:, 0], labels[:, -1], :].view(batch_size, N, D), sigmas[labels[:, 0], labels[:, -1], :].view(batch_size, N, D)).log_prob(x).sum(-1).sum(-1).mean()
            LLs.append(ll.item())
    E_precisions = alpha / beta
    E_mus = mus_mean
    E_z = torch.argmax(zs_pi, dim=-1)
    return x, z, mus, precisions, LLs, E_mus, E_precisions, E_z

In [None]:
def plot_final_samples(Xs, Zs, mus, precisions, steps, batch_size):
    colors = ['r', 'b', 'gold']
    fig = plt.figure(figsize=(25,50))
    for b in range(batch_size):
        ax = fig.add_subplot(int(batch_size / 5), 5, b+1)
        x = Xs[b].data.numpy()
        z = Zs[b].data.numpy()
        mu = mus[b].data.numpy()
        precision = precisions[b].data.numpy()

        covs = np.zeros((K, D, D))
        assignments = np.nonzero(z)[1]
        for k in range(K):
            covs[k] = np.diag(1. / precision[k])
            xk = x[np.where(assignments == k)]
            ax.scatter(xk[:, 0], xk[:, 1], c=colors[k])
            plot_cov_ellipse(cov=covs[k], pos=mu[k], nstd=2, ax=ax, alpha=0.2, color=colors[k])
        ax.set_ylim([-10, 10])
        ax.set_xlim([-10, 10])
    plt.savefig('samples_steps=%d.svg' % (steps))

In [None]:
x, z_samples, mus_samples, precisions_samples, LLS, E_mus, E_precisions, E_z = test(num_seqs, Pi, N, K, D, STEPS, BATCH_SIZE)

In [None]:
plot_final_samples(x, z_samples, mus_samples, precisions_samples, STEPS, BATCH_SIZE)

In [None]:
def plot_final_samples(Xs, Zs, mus, precisions, steps, batch_size):
    colors = ['r', 'b', 'gold']
    fig = plt.figure(figsize=(25,50))
    for b in range(batch_size):
        ax = fig.add_subplot(int(batch_size / 5), 5, b+1)
        x = Xs[b].data.numpy()
        z = Zs[b].data.numpy()
        mu = mus[b].data.numpy()
        precision = precisions[b].data.numpy()

        covs = np.zeros((K, D, D))
        assignments = z
        for k in range(K):
            covs[k] = np.diag(1. / precision[k])
            xk = x[np.where(assignments == k)]
            ax.scatter(xk[:, 0], xk[:, 1], c=colors[k])
            plot_cov_ellipse(cov=covs[k], pos=mu[k], nstd=2, ax=ax, alpha=0.2, color=colors[k])
        ax.set_ylim([-10, 10])
        ax.set_xlim([-10, 10])
    plt.savefig('modes_steps=%d.svg' % (steps))

In [None]:
plot_final_samples(x, E_z, E_mus, E_precisions, STEPS, BATCH_SIZE)

In [None]:
plt.plot(LLS[:15])

In [None]:
plt.plot(LLS)

In [None]:
plt.plot(LLS[:80])

In [None]:
E_mus.shape