In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
from torch import logsumexp
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
N = 25
K = 3
D = 2

## Model Parameters
NUM_SAMPLES = 1
NUM_HIDDEN = 64
NUM_STATS = K+D*K+D*K
NUM_LATENTS = D * K
NUM_OBS = D + K
BATCH_SIZE = 100
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4
CUDA = False

In [3]:
Xs = torch.from_numpy(np.load('gmm_dataset/sequences.npy')).float()
Zs = torch.from_numpy(np.load('gmm_dataset/states.npy')).float()
# mus_true = torch.from_numpy(np.load('gmm_dataset/means.npy')).float()
# covs = torch.from_numpy(np.load('gmm_dataset2/covariances.npy')).float()
Pi = torch.from_numpy(np.load('gmm_dataset/init.npy')).float()
num_seqs = Zs.shape[0]

In [4]:
def StatsGMM(Xs, Zs, K, D):
    """
    Xs is B * N * D
    Zs is B * N * K
    stat1 corresponds I[z_n=1], ..., I[z_n=K]
    stat2 corresponds I[z_n=1]x_n, ..., I[z_n=K]x_n
    stat3 corresponds I[z_n=1]x_n**2, ..., I[z_n=K]x_n**2
    return B * (K+D*K+D*K)
    """
    stat1 = Zs.sum(1)
    stat2 = torch.mul(Zs.unsqueeze(-1).repeat(1, 1, 1, D), Xs.unsqueeze(-1).repeat(1, 1, 1, K).transpose(-1, -2)).sum(1) 
    stat3 = torch.mul(Zs.unsqueeze(-1).repeat(1, 1, 1, D), torch.mul(Xs, Xs).unsqueeze(-1).repeat(1, 1, 1, K).transpose(-1, -2)).sum(1) 
    return stat1, stat2, stat3, torch.cat((stat1, stat2.view(-1, D*K), stat3.view(-1, D*K)), dim=-1)

In [5]:
class Encoder(nn.Module):
    def __init__(self, num_obs=NUM_OBS,
                       num_stats=NUM_STATS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        super(self.__class__, self).__init__()
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_stats),
            nn.ReLU())
#         self.enc_hidden = nn.Sequential(
#             nn.Linear(num_stats, num_hidden),
#             nn.ReLU())
#         self.mus_mean = nn.Sequential(
#             nn.Linear(num_hidden, num_latents))
#         self.mus_log_std = nn.Sequential(
#             nn.Linear(num_hidden, num_latents))
#         self.sigmas_log_alpha = nn.Sequential(
#             nn.Linear(num_hidden, num_latents))
#         self.sigmas_log_beta = nn.Sequential(
#             nn.Linear(num_hidden, num_latents))
        
    def forward(self, obs, K, D, num_samples, batch_size):
        stats = self.enc_stats(obs).view(batch_size, N, -1).sum(1) ** 2
#         hidden = self.enc_hidden(stats)
#         mean = self.mus_mean(hidden).view(-1, K, D)
#         std = torch.exp(self.mus_log_std(hidden).view(-1, K, D))
#         alpha = torch.exp(self.sigmas_log_alpha(hidden).view(-1, K, D))
#         beta = torch.exp(self.sigmas_log_beta(hidden).view(-1, K, D))
        
#         mus = Normal(mean, std).sample((num_samples, )) ## S * B * K * D
#         sigmas = Gamma(alpha, beta).sample((num_samples, )) ## S * B * K * D
        
        return stats

In [6]:
prior_alpha = 4.0
prior_beta = 4.0
prior_mus = torch.zeros((K, D))
prior_nu = 5.0

def log_joints_gmm(Z, Pi, means, stds, precisions, Xs, N, D, K, num_samples, batch_size):
    log_probs = torch.zeros((num_samples, batch_size))
    ## priors on mus and sigmas, S * B
    log_probs = log_probs + Normal(torch.zeros((num_samples, batch_size, K, D)), torch.sqrt(prior_nu / precisions)).log_prob(means).sum(-1).sum(-1)
    log_probs = log_probs + Gamma(torch.ones((batch_size, K, D)) * prior_alpha, torch.ones((batch_size, K, D)) * prior_beta).log_prob(precisions).sum(-1).sum(-1)
    ## Z B-by-T-by-K
#     log_probs = log_probs + cat(Pi).log_prob(Z).sum(-1)
    labels = Z.nonzero()
    log_probs = log_probs + Normal(means[:, labels[:, 0], labels[:, -1], :].view(-1, batch_size, N, D), stds[:, labels[:, 0], labels[:, -1], :].view(-1, batch_size, N, D)).log_prob(Xs).sum(-1).sum(-1)
    return log_probs

def conjugate_posterior(Xs, stats, N, K, D):
    stat1 = stats[:, :K] + 1e-3
    stat2 = stats[:, K:K+K*D]
    stat3 = stats[:, K+K*D:]
    x_mean = stat2.view(-1, K, D) / stat1.unsqueeze(-1).repeat(1, 1, D)
    ## every parameter is B * K * D
    post_alpha = (prior_alpha + stat1 / 2.).unsqueeze(-1).repeat(1, 1, D)
    post_nu = (prior_nu + stat1).unsqueeze(-1).repeat(1, 1, D)
    post_mus = (prior_nu * prior_mus + stat2.view(-1, K, D)) / (prior_nu + stat1.unsqueeze(-1).repeat(1, 1, D))
    post_beta = prior_beta + (1. / 2) * ((Xs - x_mean) ** 2) + (1. / 2) * (prior_nu * stat1 / (prior_nu + stat1)).unsqueeze(-1).repeat(1, 1, D) * ((prior_mus ** 2) + (x_mean ** 2) - 2 * prior_mus * x_mean)

    return post_alpha, post_beta, post_mus, post_nu

def sample_normal_gamma(alpha, beta, mus, nu, num_samples):
    precisions = Gamma(alpha, beta).sample((num_samples, )) + 0.01
    
    stds = torch.sqrt((1. / precisions) * nu)
    means = Normal(mus.unsqueeze(0).repeat(num_samples, 1, 1, 1), stds).sample()
    log_q = Normal(mus.unsqueeze(0).repeat(num_samples, 1, 1, 1), stds).log_prob(means).sum(-1).sum(-1) + Gamma(alpha, beta).log_prob(precisions).sum(-1).sum(-1)## S * B
    print(Gamma(alpha, beta).log_prob(precisions).sum(-1).sum(-1))
    return means, stds, log_q, precisions

def rws(Xs, Zs, Pi, N, K, D, num_samples, batch_size):
#     stat1, stat2, stats = StatsGMM(Xs, Zs, K, D)
    data = torch.cat((Xs, Zs), dim=-1).view(batch_size*N, -1)
    stats = enc(data, K, D, num_samples, batch_size)
    post_alpha, post_beta, post_mus, post_nu = conjugate_posterior(stats, N, K, D)
    means, stds, log_q, precisions = sample_normal_gamma(post_alpha, post_beta, post_mus, post_nu, num_samples)
    
    log_p = log_joints_gmm(Zs, Pi, means, stds, precisions, Xs, N, D, K, num_samples, batch_size)
#     print(log_p)
    log_weights = log_p - log_q
    weights = torch.exp(log_weights - logsumexp(log_weights, dim=0)).detach()
    eubo = torch.mul(weights, log_weights).sum(0).mean()
    elbo = log_weights.mean(0).mean()
    ess = (1. / (weights ** 2).sum(0)).mean()
#     posterior_mean, posterior_covs = conjugate_posterior(stat1, stat2, covs, K, D, batch_size)
#     MCKl, TrueKl = kls_gaussians(mus, mus_mean, mus_std, posterior_mean, posterior_covs, K, D)
    return eubo, elbo, ess

def shuffler(batch_Xs, batch_Zs, N, K, D, batch_size):
    indices = torch.cat([torch.randperm(N).unsqueeze(0) for b in range(batch_size)])
    indices_Xs = indices.unsqueeze(-1).repeat(1, 1, D)
    indices_Zs = indices.unsqueeze(-1).repeat(1, 1, K)
    return torch.gather(batch_Xs, 1, indices_Xs), torch.gather(batch_Zs, 1, indices_Zs)

In [7]:
def initialize():
    enc = Encoder()
    if CUDA:
        enc.cuda()
    optimizer =  torch.optim.Adam(list(enc.parameters()),lr=LEARNING_RATE)    
    return enc, optimizer
enc, optimizer = initialize()

In [8]:
EUBOs = []
ELBOs = []
ESSs = []
# MCKls = []
# TrueKls = []
num_batches = int((Xs.shape[0] / BATCH_SIZE))
for epoch in range(1):
    indices = torch.randperm(num_seqs)
    time_start = time.time()
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
#     MCKl = 0.0
#     TrueKl = 0.0
    for step in range(1):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_Xs = Xs[batch_indices]
        batch_Zs = Zs[batch_indices]
        batch_Xs, batch_Zs = shuffler(batch_Xs, batch_Zs, N, K, D, BATCH_SIZE)
        data = torch.cat((batch_Xs, batch_Zs), dim=-1).view(BATCH_SIZE*N, -1)
        stats = enc(data, K, D, NUM_SAMPLES, BATCH_SIZE)
        post_alpha, post_beta, post_mus, post_nu = conjugate_posterior(stats, N, K, D)
        means, stds, log_q, precisions = sample_normal_gamma(post_alpha, post_beta, post_mus, post_nu, NUM_SAMPLES)
#         log_p = log_joints_gmm(batch_Zs, Pi, means, stds, precisions, batch_Xs, N, D, K, NUM_SAMPLES, BATCH_SIZE)
#         print(post_beta)
#         eubo, elbo, ess = rws(batch_Xs, batch_Zs, Pi, N, K, D, NUM_SAMPLES, BATCH_SIZE)
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
#         MCKl += mckl.item()
#         TrueKl += truekl.item()
    EUBO /= num_batches
    ELBO /= num_batches
    ESS /= num_batches
#     MCKl /= num_batches
#     TrueKl /= num_batches
    
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)
    ESSs.append(ESS)
#     MCKls.append(MCKl)
#     TrueKls.append(TrueKl)
    
    time_end = time.time()
    print('epoch=%d, EUBO=%f, ELBO=%f, ESS=%.3f (%ds)' % (epoch, EUBO, ELBO, ESS, time_end - time_start))
#     print('epoch=%d, EUBO=%f, ELBO=%f, ESS=%.3f, MCKL=%f, TKL=%f (%ds)' % (epoch, EUBO, ELBO, ESS, MCKl, TrueKl, time_end - time_start))

tensor([[    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan, -3.1222,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan,     nan,     nan,     nan,     nan,     nan, -2.6373,
             nan,     nan,  

NameError: name 'eubo' is not defined

In [None]:
Gamma(post_alpha, post_beta).log_prob(precisions)

In [10]:
stat1 = stats[:, :K] + 1e-3
stat2 = stats[:, K:K+K*D]
stat3 = stats[:, K+K*D:]
(1. / 2) * (stat3.view(-1, K, D) - ((stat2.view(-1, K, D) ** 2) / stat1.unsqueeze(-1).repeat(1, 1, D)))

tensor([[[-6.6100e+02,  1.3365e+02],
         [-3.1387e+02, -1.7009e+01],
         [ 2.2881e+00, -8.2317e-02]],

        [[-2.4348e+02,  1.1507e+02],
         [-6.9891e+01, -9.4755e+00],
         [ 1.1921e+01, -4.9813e-02]],

        [[-4.8666e+02,  7.4931e+01],
         [-4.1334e+02, -1.4952e-02],
         [ 2.9633e-01, -3.4268e-04]],

        [[-3.9734e+05, -3.0691e+05],
         [-1.1885e+00, -2.5159e+02],
         [ 9.8242e+01,  0.0000e+00]],

        [[-6.5936e+03,  3.3732e+02],
         [-9.3699e+01, -3.2717e+02],
         [ 9.7354e+01, -1.7525e-05]],

        [[-9.4977e+03,  1.3549e+02],
         [-2.5436e+01, -5.4638e+01],
         [ 3.5837e+00, -2.3607e-07]],

        [[-3.6718e+02,  2.6077e+02],
         [-5.8269e+02, -1.2583e+01],
         [ 0.0000e+00, -2.2556e-06]],

        [[-1.2631e+05, -1.3507e+04],
         [-1.9851e-03, -1.2264e+01],
         [ 2.2666e+01, -5.3101e-06]],

        [[-3.7911e+03,  1.8345e+02],
         [-1.0023e+00, -3.5344e-01],
         [ 1.2164e-01,

In [None]:
def plot_results(EUBOs, ELBOs, ESS, num_samples, num_epochs, lr):
    fig = plt.figure(figsize=(20, 20))
#     fig.tight_layout()
    ax1 = fig.add_subplot(2, 1, 1)
    ax3 = fig.add_subplot(2, 1, 2)
    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')
    ax1.tick_params(labelsize=18)
    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')
    ax1.set_title('epoch=%d, lr=%.1E, samples=%d' % (num_epochs, lr, num_samples), fontsize=18)
    ax1.set_ylim([-150, -80])
    ax1.legend()
    ax3.legend()
    ax3.tick_params(labelsize=18)
#     plt.savefig('gmm_rws_datatodist_lr=%.1E_samples=%d.svg' % (lr, orch.ones((K, D)) * 0.3 num_samples))

In [None]:
# plot_results(EUBOs, ELBOs, ESS, NUM_SAMPLES, NUM_EPOCHS, LEARNING_RATE)