In [1]:
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
from torch import logsumexp
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Record the library versions and CUDA availability this run was produced with.
version_report = (
    'probtorch:', probtorch.__version__,
    'torch:', torch.__version__,
    'cuda:', torch.cuda.is_available(),
)
print(*version_report)

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
## Dataset dimensions
N = 30  # data points per sequence
K = 3   # mixture components
D = 2   # dimensions of each data point

## Sampler settings
NUM_SAMPLES = 10  # samples drawn per sweep
STEPS = 10        # sweeps (the first one proposes from the prior)

## Network sizes
NUM_HIDDEN = 64
# K counts + K*D weighted sums + K*D weighted sums of squares
NUM_STATS = K * (1 + 2 * D)
# one value per (component, dimension)
NUM_LATENTS = K * D
# a data point concatenated with its one-hot assignment
NUM_OBS = K + D

## Optimisation
BATCH_SIZE = 10
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-4
CUDA = False

In [3]:
def _load_float_tensor(npy_path):
    """Read a ``.npy`` file and return its contents as a float tensor."""
    return torch.from_numpy(np.load(npy_path)).float()


# Sequences, their state assignments and the initial state distribution.
Xs = _load_float_tensor('gmm_dataset/sequences.npy')
Zs = _load_float_tensor('gmm_dataset/states.npy')
# Ground-truth parameters are currently not loaded:
# mus_true = torch.from_numpy(np.load('gmm_dataset/means.npy')).float()
# covs = torch.from_numpy(np.load('gmm_dataset2/covariances.npy')).float()
Pi = _load_float_tensor('gmm_dataset/init.npy')
num_seqs = Zs.shape[0]

In [4]:
def StatsGMM(Xs, Zs, K, D):
    """
    Per-sequence sufficient statistics of a GMM given hard assignments.

    Xs is B * N * D (data points)
    Zs is B * N * K (one-hot assignments)

    Returns a 4-tuple:
      counts           B * K      sum_n I[z_n=k]
      weighted_sum     B * K * D  sum_n I[z_n=k] x_n
      weighted_sq_sum  B * K * D  sum_n I[z_n=k] x_n**2
      flat             B * (K+D*K+D*K), the three above concatenated
    """
    # Bring both inputs to a common B * N * K * D layout (views, no copies).
    assignments = Zs.unsqueeze(-1).expand(-1, -1, -1, D)
    points = Xs.unsqueeze(-2).expand(-1, -1, K, -1)

    # Reduce over the N data points of every sequence.
    counts = Zs.sum(1)
    weighted_sum = (assignments * points).sum(1)
    weighted_sq_sum = (assignments * (points * points)).sum(1)

    flat = torch.cat(
        (counts, weighted_sum.view(-1, D * K), weighted_sq_sum.view(-1, D * K)),
        dim=-1)
    return counts, weighted_sum, weighted_sq_sum, flat

In [5]:
class Encoder(nn.Module):
    """
    Amortized proposal for the global GMM parameters.

    Each input row is mapped to a vector of ``num_stats`` features; the
    features are summed over the rows belonging to the same sequence and the
    pooled vector parameterizes
      * a Gamma(alpha, beta) over the per-component, per-dimension precisions
      * a Normal(mean, sigma) over the per-component, per-dimension means
    through two separate hidden layers.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_stats=NUM_STATS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super(): `super(self.__class__, self)` recurses
        # forever as soon as this class is subclassed.
        super().__init__()
        # Per-row feature extractor (pooled over rows in forward()).
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_stats))
        # Head for the Gamma proposal over precisions.
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_stats, num_hidden),
            nn.ReLU())
        self.sigmas_log_alpha = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.sigmas_log_beta = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

        # Head for the Normal proposal over means.
        self.enc_hidden2 = nn.Sequential(
            nn.Linear(num_stats, num_hidden),
            nn.ReLU())
        self.mus_mean = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.mus_log_std = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, K, D, num_samples, batch_size):
        """
        Propose precisions and means for a batch of sequences.

        obs         : 2-D tensor with num_obs columns whose rows are grouped
                      by sequence (batch_size equally sized groups)
        K, D        : number of components / data dimensions; K * D must
                      equal num_latents
        num_samples : number of draws S from each proposal
        batch_size  : number of sequences B contained in obs

        Returns (alpha, beta, precisions, mus_mean, mus_sigma, mus) where the
        distribution parameters are B * K * D and the samples are
        S * B * K * D. Samples are drawn with .sample(), so no gradient flows
        through them.
        """
        row_stats = self.enc_stats(obs)
        # Pool over the rows of each sequence. The sequence length is inferred
        # from the input instead of being read from the notebook-global N.
        stats = row_stats.view(batch_size, -1, row_stats.size(-1)).sum(1)

        hidden = self.enc_hidden(stats)
        # exp() keeps the Gamma parameters strictly positive
        alpha = torch.exp(self.sigmas_log_alpha(hidden)).view(-1, K, D) ## B * K * D
        beta = torch.exp(self.sigmas_log_beta(hidden)).view(-1, K, D) ## B * K * D
        precisions = Gamma(alpha, beta).sample((num_samples,)) ## S * B * K * D

        hidden2 = self.enc_hidden2(stats)
        mus_mean = self.mus_mean(hidden2).view(-1, K, D)
        mus_sigma = torch.exp(self.mus_log_std(hidden2).view(-1, K, D))
        mus = Normal(mus_mean, mus_sigma).sample((num_samples,)) ## S * B * K * D
        return alpha, beta, precisions, mus_mean, mus_sigma, mus

In [6]:
def initialize():
    """Create a fresh Encoder (moved to the GPU when CUDA is set) and its Adam optimizer."""
    encoder = Encoder()
    encoder = encoder.cuda() if CUDA else encoder
    adam = torch.optim.Adam(list(encoder.parameters()),
                            lr=LEARNING_RATE,
                            betas=(0.9, 0.99))
    return encoder, adam


enc, optimizer = initialize()

In [7]:
# Prior hyper-parameters, one entry per (sequence in batch, component, dimension):
# means ~ Normal(prior_mean, prior_sigma), precisions ~ Gamma(prior_alpha, prior_beta)
_prior_shape = (BATCH_SIZE, K, D)
prior_mean = torch.zeros(_prior_shape)
prior_sigma = torch.ones(_prior_shape)
prior_alpha = 2.0 * torch.ones(_prior_shape)
prior_beta = 2.0 * torch.ones(_prior_shape)

def log_joints_gmm(X, Z, Pi, mus, precisions, N, D, K, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size):
    """
    Log joint density log p(x, z, mus, precisions) of the GMM, one value per sequence.

    X is B * N * D, Z is B * N * K (one-hot), mus and precisions are B * K * D,
    Pi holds the assignment probabilities. Returns a tensor of size B.
    """
    ## priors on the global variables
    log_prior_mus = Normal(prior_mean, prior_sigma).log_prob(mus).sum(-1).sum(-1)
    log_prior_precisions = Gamma(prior_alpha, prior_beta).log_prob(precisions).sum(-1).sum(-1)
    ## prior on the assignments, summed over the N points
    log_prior_z = cat(Pi).log_prob(Z).sum(-1)

    ## likelihood: pick, for every point, the parameters of its assigned component
    hot = Z.nonzero()  ## one row (sequence, point, component) per point
    seq_idx, comp_idx = hot[:, 0], hot[:, -1]
    stds = 1. / torch.sqrt(precisions)
    assigned_mus = mus[seq_idx, comp_idx, :].view(batch_size, N, D)
    assigned_stds = stds[seq_idx, comp_idx, :].view(batch_size, N, D)
    log_lik = Normal(assigned_mus, assigned_stds).log_prob(X).sum(-1).sum(-1)

    total = torch.zeros(batch_size).float()
    for term in (log_prior_mus, log_prior_precisions, log_prior_z, log_lik):
        total = total + term
    return total

def inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size):
    """
    Draw the global variables from their priors.

    Returns (mus, precisions, log_p): one Normal draw and one Gamma draw with
    the shape of the prior parameters, plus their joint log prior density
    summed over the last two dimensions (size B).
    """
    mu_prior = Normal(prior_mean, prior_sigma)
    precision_prior = Gamma(prior_alpha, prior_beta)

    # means first, precisions second
    mus = mu_prior.sample()
    precisions = precision_prior.sample()

    log_prior_mus = mu_prior.log_prob(mus).sum(-1).sum(-1)
    log_prior_precisions = precision_prior.log_prob(precisions).sum(-1).sum(-1)
    return mus, precisions, log_prior_mus + log_prior_precisions

def E_step(X, mus, precisions, N, D, K):
    """
    Sample the assignments given the global variables.

    X is B * N * D, mus and precisions are B * K * D. Each point's assignment
    is drawn from the component likelihoods normalized over the K components.
    Returns (Z, log_q_z): one-hot assignments B * N * K and their log
    probability summed over the N points (size B).
    """
    stds = 1. / torch.sqrt(precisions)
    ## B * K * D -> K * B * N * D, so that X broadcasts against every component
    mus_kbnd = mus.transpose(0, 1).unsqueeze(2).expand(-1, -1, N, -1)
    stds_kbnd = stds.transpose(0, 1).unsqueeze(2).expand(-1, -1, N, -1)

    log_lik = Normal(mus_kbnd, stds_kbnd).log_prob(X).sum(-1)  ## K * B * N
    log_gammas = log_lik.permute(1, 2, 0)  ## B * N * K
    gammas = torch.exp(log_gammas - logsumexp(log_gammas, -1, keepdim=True))

    assignment_dist = cat(gammas)
    Z = assignment_dist.sample()  ## B * N * K
    log_q_z = assignment_dist.log_prob(Z).sum(-1)  ## B
    return Z, log_q_z

def rws(Xs, Pi, N, K, D, num_samples, steps, batch_size):
    """
    rws gradient estimator
    sis sampling scheme
    no resampling

    Runs `steps` sweeps with `num_samples` independent samples each:
      * sweep 0 draws the global variables from the prior (inti_global) and
        the assignments with E_step; because the globals come from the prior,
        the log weight is log p(x | z, eta) + log p(z) - log q(z)
      * every later sweep feeds the data together with the previous
        assignments to the encoder, samples new globals from it, re-samples
        the assignments, and uses log p(x, z, eta) - log q(z) - log q(eta)

    Xs is B * N * D, Pi holds the assignment probabilities.
    Relies on the module-level `enc` and `prior_mean`, `prior_sigma`,
    `prior_alpha`, `prior_beta`.

    Returns (euboave, elboave, ess), all scalars:
      euboave  self-normalized-weight average of the incremental log weights,
               averaged over sweeps and sequences
      elboave  plain average of the incremental log weights
      ess      effective sample size 1 / sum_s w_s**2, averaged over sweeps
               and sequences
    """
    log_increment_weights = torch.zeros((steps, num_samples, batch_size))
    log_uptonow_weights = torch.zeros((steps, num_samples, batch_size))
    Z_samples = torch.zeros((num_samples, batch_size, N, K))
    for m in range(steps):
        for l in range(num_samples):
            if m == 0:
                ## the log prior returned by inti_global is not needed for the weight
                mus, precisions, _ = inti_global(K, D, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size)
                Z, log_q_z = E_step(Xs, mus, precisions, N, D, K)
                Z_samples[l] = Z
                labels = Z.nonzero()
                log_likelihood_z = cat(Pi).log_prob(Z).sum(-1)
                sigmas = 1. / torch.sqrt(precisions)
                log_likelihood_x = Normal(mus[labels[:, 0], labels[:, -1], :].view(batch_size, N, D),
                                          sigmas[labels[:, 0], labels[:, -1], :].view(batch_size, N, D)).log_prob(Xs).sum(-1).sum(-1)
                log_weight = log_likelihood_x + log_likelihood_z - log_q_z
                log_increment_weights[m, l] = log_weight
                log_uptonow_weights[m, l] = log_weight
            else:
                ## condition the encoder on the data and the previous assignments
                Z = Z_samples[l]
                data = torch.cat((Xs, Z), dim=-1).view(batch_size*N, -1)
                alpha, beta, precisions, mus_mean, mus_sigma, mus = enc(data, K, D, 1, batch_size)
                log_q_eta = Normal(mus_mean, mus_sigma).log_prob(mus[0]).sum(-1).sum(-1) + Gamma(alpha, beta).log_prob(precisions[0]).sum(-1).sum(-1) ## B
                Z, log_q_z = E_step(Xs, mus[0], precisions[0], N, D, K)
                Z_samples[l] = Z
                log_p = log_joints_gmm(Xs, Z, Pi, mus[0], precisions[0], N, D, K, prior_mean, prior_sigma, prior_alpha, prior_beta, batch_size)
                log_increment_weights[m, l] = log_p - log_q_z - log_q_eta
                log_uptonow_weights[m, l] = log_increment_weights[m, l] + log_uptonow_weights[m-1, l]

    ## self-normalize over the sample dimension; detached so the weights act as constants
    overall_weights = torch.exp(log_uptonow_weights - logsumexp(log_uptonow_weights, 1, keepdim=True)).detach()

    ess = (1. / (overall_weights ** 2).sum(1)).mean(0).mean()

    euboave = torch.mul(overall_weights, log_increment_weights).sum(1).mean(0).mean()
    elboave = log_increment_weights.mean(1).mean(0).mean()

    return euboave, elboave, ess

def shuffler(batch_Xs, N, K, D, batch_size):
    """
    Reorder the N points of every sequence with its own random permutation.

    batch_Xs is B * N * D; the result has the same shape.
    """
    perms = torch.stack([torch.randperm(N) for _ in range(batch_size)])  ## B * N
    gather_index = perms.unsqueeze(-1).repeat(1, 1, D)  ## B * N * D
    return batch_Xs.gather(1, gather_index)

In [8]:
# Per-epoch averages of the objectives and of the effective sample size.
EUBOs = []
ELBOs = []
ESSs = []
# Sequences that do not fill a whole batch are dropped.
num_batches = Xs.shape[0] // BATCH_SIZE

for epoch in range(NUM_EPOCHS):
    time_start = time.time()
    # Visit the sequences in a new random order every epoch.
    indices = torch.randperm(num_seqs)
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    for step in range(num_batches):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_Xs = Xs[batch_indices]
        # Also permute the points inside every sequence.
        batch_Xs = shuffler(batch_Xs, N, K, D, BATCH_SIZE)
        # rws returns exactly three scalars; the sweep-averaged EUBO is the
        # quantity that is differentiated.
        euboave, elboave, ess = rws(batch_Xs, Pi, N, K, D, NUM_SAMPLES, STEPS, BATCH_SIZE)
        euboave.backward()
        optimizer.step()
        EUBO += euboave.item()
        ELBO += elboave.item()
        ESS += ess.item()
    EUBO /= num_batches
    ELBO /= num_batches
    ESS /= num_batches
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)
    ESSs.append(ESS)

    time_end = time.time()
    print('epoch=%d, EUBO=%f, ELBO=%f, ESS=%.3f (%ds)' % (epoch, EUBO, ELBO, ESS, time_end - time_start))

epoch=0, EUBO=-839450335.525469, ELBO=-170514700.580781, ESS=1.001 (43s)
epoch=1, EUBO=-29183862.976523, ELBO=-8770010.245623, ESS=1.001 (43s)
epoch=2, EUBO=-134353.828242, ELBO=-15572.746381, ESS=1.001 (41s)
epoch=3, EUBO=-43321.954048, ELBO=-4718.822240, ESS=1.003 (42s)
epoch=4, EUBO=-15661.659045, ELBO=-1754.890521, ESS=1.007 (43s)
epoch=5, EUBO=-7801.204644, ELBO=-917.346822, ESS=1.014 (43s)
epoch=6, EUBO=-6326.921257, ELBO=-783.293335, ESS=1.013 (43s)
epoch=7, EUBO=-4437.965813, ELBO=-528.924717, ESS=1.015 (45s)
epoch=8, EUBO=-3751.567554, ELBO=-438.761576, ESS=1.029 (44s)
epoch=9, EUBO=-3357.231270, ELBO=-377.666663, ESS=1.024 (45s)
epoch=10, EUBO=-3139.800137, ELBO=-346.696580, ESS=1.022 (45s)
epoch=11, EUBO=-4662.925083, ELBO=-547.092782, ESS=1.022 (44s)


KeyboardInterrupt: 

In [None]:
def save_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE):
    """
    Write the per-epoch training curves to a comma-separated text file.

    The file name encodes the module-level STEPS together with NUM_SAMPLES and
    LEARNING_RATE. NUM_EPOCHS is accepted for interface compatibility but not
    used.
    """
    # '%d' would render a learning rate such as 1e-4 as 'lr=0', so that runs
    # with different rates overwrite each other; use scientific notation.
    out_path = 'ave_amorgibbs-steps=%d-samples=%d-lr=%.1E.txt' % (STEPS, NUM_SAMPLES, LEARNING_RATE)
    # The context manager closes the file even if a write fails.
    with open(out_path, 'w+') as fout:
        fout.write('EUBOs, ELBOs, ESSs\n')
        for eubo, elbo, ess in zip(EUBOs, ELBOs, ESSs):
            fout.write(str(eubo) + ', ' + str(elbo) + ', ' + str(ess) + '\n')


# Persist the encoder weights next to the curves, using the same naming scheme.
torch.save(enc.state_dict(), 'models/ave_amorgibbs-steps=%d-samples=%d-lr=%.1E' % (STEPS, NUM_SAMPLES, LEARNING_RATE))
save_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, num_samples, num_epochs, lr):
    """
    Plot the training curves in three stacked panels and save them as SVG.

    Panel 1: EUBO and ELBO per epoch.
    Panel 2: KL curves, read from the module-level lists TrueKls_exclusive,
             MCKls_exclusive, TrueKls_inclusive and MCKls_inclusive; a curve
             is skipped when its list is not defined.
    Panel 3: effective sample size divided by num_samples.

    The title also reports the module-level BATCH_SIZE.
    """
    fig = plt.figure(figsize=(30, 30))
    ax1 = fig.add_subplot(3, 1, 1)
    ax2 = fig.add_subplot(3, 1, 2)
    ax3 = fig.add_subplot(3, 1, 3)

    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')

    # The KL lists are not produced by the training loop in this notebook;
    # look them up defensively instead of failing with a NameError.
    kl_curves = (
        ('TrueKls_exclusive', '#66b3ff', 'true exclusive KL'),
        ('MCKls_exclusive', '#ff9999', 'est exclusive KL'),
        ('TrueKls_inclusive', '#99ff99', 'true inclusive KL'),
        ('MCKls_inclusive', 'gold', 'est inclusive KL'),
    )
    has_kl_curve = False
    for var_name, color, label in kl_curves:
        series = globals().get(var_name)
        if series is not None:
            ax2.plot(series, color, label=label)
            has_kl_curve = True

    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')

    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, BATCH_SIZE, lr, num_samples), fontsize=18)
    ax1.set_ylim([-200, -150])
    ax1.legend()
    ax2.set_ylim([-50, 50])
    if has_kl_curve:
        # a legend without any labelled artist only triggers a warning
        ax2.legend()
    ax3.legend()
    for ax in (ax1, ax2, ax3):
        ax.tick_params(labelsize=18)

    # tight_layout only has an effect once the axes exist
    fig.tight_layout()
    plt.savefig('gibbs_results_learn_both_lr=%.1E_samples=%d.svg' % (lr, num_samples))

In [None]:
# Plot the curves collected by the training loop and save the figure to disk.
plot_results(EUBOs, ELBOs, ESSs, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)