In [1]:
%matplotlib inline
import numpy as np
# Explicit: torch was previously only in scope via `from plots import *`.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch.distributions.gamma import Gamma
from torch import logsumexp
import sys
import time
import datetime
# NOTE(review): machine-specific absolute path; consider making it configurable.
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
# Record library versions and GPU availability so results can be tied to an environment.
print(*('probtorch:', probtorch.__version__,
        'torch:', torch.__version__,
        'cuda:', torch.cuda.is_available()),
      sep=' ')

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
## Data dimensions
N = 30  # points per sequence (Xs is B * N * D)
K = 3   # number of mixture components (Zs is one-hot over K)
D = 2   # observation dimension

## Model Parameters
NUM_SAMPLES = 10
NUM_HIDDEN = 64
NUM_STATS = K+D*K+D*K   # length of the StatsGMM-style statistic vector
NUM_LATENTS = D * K     # one precision per (component, dimension)
NUM_OBS = D + K         # per-point encoder input: x_n concatenated with z_n
BATCH_SIZE = 10
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-4
CUDA = False

## Reproducibility: batch order, point shuffling and Gamma sampling are all
## random, so seed every RNG up front.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Load the synthetic GMM dataset.
Xs = torch.from_numpy(np.load('gmm_dataset/sequences.npy')).float()    # observations
Zs = torch.from_numpy(np.load('gmm_dataset/states.npy')).float()       # one-hot component assignments
mus_true = torch.from_numpy(np.load('gmm_dataset/means.npy')).float()  # per-sequence component means
Pi = torch.from_numpy(np.load('gmm_dataset/init.npy')).float()         # categorical prior over z (unused by the current model)
num_seqs = Zs.shape[0]

# Fail fast if the files disagree with the configured N / K / D: downstream
# code reshapes with these constants and would otherwise fail (or misalign)
# deep inside the training loop.
assert Xs.shape[1:] == (N, D), Xs.shape
assert Zs.shape[1:] == (N, K), Zs.shape
assert mus_true.shape[1:] == (K, D), mus_true.shape
assert Xs.shape[0] == num_seqs and mus_true.shape[0] == num_seqs

In [4]:
def StatsGMM(Xs, Zs, K, D):
    """Per-sequence sufficient statistics of a GMM given the assignments.

    Args:
        Xs: observations, B * N * D.
        Zs: one-hot assignments, B * N * K.
        K, D: number of components and observation dimension.

    Returns:
        A 4-tuple ``(stat1, stat2, stat3, stats)``:
            stat1: B * K,      sum_n I[z_n=k]            (points per component)
            stat2: B * K * D,  sum_n I[z_n=k] x_n
            stat3: B * K * D,  sum_n I[z_n=k] x_n**2    (elementwise square)
            stats: B * (K+D*K+D*K), the three concatenated, with stat2 and
                   stat3 flattened row-major over (K, D).
    """
    stat1 = Zs.sum(1)
    # Broadcast Z (B, N, K, 1) against X (B, N, 1, D) instead of materialising
    # repeated copies; summing over N yields per-component sums (B, K, D).
    z_weights = Zs.unsqueeze(-1)
    stat2 = (z_weights * Xs.unsqueeze(-2)).sum(1)
    stat3 = (z_weights * (Xs * Xs).unsqueeze(-2)).sum(1)
    stats = torch.cat((stat1, stat2.reshape(-1, D*K), stat3.reshape(-1, D*K)), dim=-1)
    return stat1, stat2, stat3, stats

In [5]:
class Encoder(nn.Module):
    """Amortised Gamma posterior over the K*D precisions of a GMM.

    Each point's input vector is mapped to a feature vector and summed over
    the points of its sequence. The pooled features, concatenated with the
    flattened means, parameterise an independent Gamma(alpha, beta) for every
    (component, dimension) precision.
    """
    def __init__(self, num_obs=NUM_OBS,
                       num_stats=NUM_STATS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        # Zero-argument super() is subclass-safe; super(self.__class__, self)
        # recurses forever as soon as Encoder is subclassed.
        super().__init__()
        self.enc_stats = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_stats))
        # The means have one entry per (component, dimension), the same K*D
        # layout as the latents, so their width is num_latents rather than
        # the notebook-global K*D.
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_stats + num_latents, num_hidden),
            nn.ReLU())
        self.sigmas_log_alpha = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.sigmas_log_beta = nn.Sequential(
            nn.Linear(num_hidden, num_latents))

    def forward(self, obs, K, D, mus, num_samples, batch_size):
        """Encode a batch and sample precisions.

        Args:
            obs: per-point inputs with batch_size * N rows (as built by rws).
            K, D: number of components and observation dimension.
            mus: means with K*D entries per sequence.
            num_samples: number of precision samples to draw.
            batch_size: number of sequences in the batch.

        Returns:
            (alpha, beta, precisions): alpha and beta are (batch, K, D) Gamma
            shape/rate; precisions are (num_samples, batch, K, D), drawn with
            .sample() (no reparameterised gradient).
        """
        point_feats = self.enc_stats(obs)
        # Infer the number of points per sequence from the input instead of
        # reading the global N; summing makes the pooling order-invariant.
        stats = point_feats.view(batch_size, -1, point_feats.size(-1)).sum(1)
        stats_mus = torch.cat((stats, mus.reshape(-1, K*D)), dim=-1)
        hidden = self.enc_hidden(stats_mus)
        # exp keeps both Gamma parameters strictly positive.
        alpha = torch.exp(self.sigmas_log_alpha(hidden).view(-1, K, D))
        beta = torch.exp(self.sigmas_log_beta(hidden).view(-1, K, D))
        precisions = Gamma(alpha, beta).sample((num_samples,))
        return alpha, beta, precisions

In [6]:
def initialize():
    """Return a freshly constructed ``(Encoder, Adam optimizer)`` pair.

    The encoder is moved to the GPU when ``CUDA`` is set; the optimizer
    covers every encoder parameter with learning rate ``LEARNING_RATE``.
    """
    model = Encoder()
    if CUDA:
        model.cuda()
    params = list(model.parameters())
    return model, torch.optim.Adam(params, lr=LEARNING_RATE)

enc, optimizer = initialize()

In [7]:
def log_joints_gmm(Z, Pi, mus, precisions, Xs, N, D, K, num_samples, batch_size,
                   prior_alpha=2.0, prior_beta=2.0):
    """Log joint of the precisions and observations for every sample.

    The means and assignments are treated as given: there are no prior terms
    for ``mus`` or ``Z``.

    Args:
        Z: one-hot assignments, (batch_size, N, K).
        Pi: unused; kept for call-site compatibility.
        mus: component means, (batch_size, K, D).
        precisions: sampled precisions, (num_samples, batch_size, K, D).
        Xs: observations, (batch_size, N, D).
        N, D, K, num_samples, batch_size: sizes. Only D is needed; the rest
            are kept for the existing signature.
        prior_alpha, prior_beta: shape and rate of the Gamma prior placed on
            every precision.

    Returns:
        (num_samples, batch_size) tensor of log p(precisions, Xs | Z, mus).
    """
    # Gamma prior on every precision, summed over (K, D) -> (S, B).
    # new_tensor keeps dtype/device consistent with the samples (CPU or GPU).
    prior = Gamma(precisions.new_tensor(prior_alpha), precisions.new_tensor(prior_beta))
    log_prior = prior.log_prob(precisions).sum(-1).sum(-1)

    # Component index of each point (Z is one-hot), broadcast over D.
    gather_idx = Z.argmax(-1).unsqueeze(-1).expand(-1, -1, D)           # (B, N, D)
    point_mus = mus.gather(1, gather_idx)                                # (B, N, D)
    point_precisions = precisions.gather(
        2, gather_idx.unsqueeze(0).expand(precisions.size(0), -1, -1, -1))  # (S, B, N, D)
    point_sigmas = 1. / torch.sqrt(point_precisions)

    # Normal broadcasts the (B, N, D) means against the (S, B, N, D) scales.
    log_lik = Normal(point_mus, point_sigmas).log_prob(Xs).sum(-1).sum(-1)
    return log_prior + log_lik

def conjugate_posterior(stat1, stat2, stat3, mus, K, D, batch_size,
                        prior_alpha=2.0, prior_beta=2.0):
    """Exact Gamma posterior over each precision given known means.

    With a Gamma(prior_alpha, prior_beta) prior on precision tau_kd and a
    Normal(mu_kd, 1/sqrt(tau_kd)) likelihood:
        alpha_kd = prior_alpha + n_k / 2
        beta_kd  = prior_beta  + sum_n z_nk (x_nd - mu_kd)^2 / 2
    where the squared deviations are expanded through the sufficient
    statistics: stat3 - 2 mu stat2 + n_k mu^2.

    Args:
        stat1: (B, K) component counts.
        stat2, stat3: (B, K, D) first and second moments from StatsGMM.
        mus: (B, K, D) component means.
        K, D, batch_size: sizes. Only D is needed; the rest are kept for the
            existing signature.
        prior_alpha, prior_beta: shape and rate of the Gamma prior.

    Returns:
        (posterior_alpha, posterior_beta), each (B, K, D), on the same
        device and dtype as the statistics.
    """
    counts = stat1.unsqueeze(-1).expand(*stat1.shape, D)   # (B, K, D)
    sq_dev = stat3 + (counts * (mus ** 2)) - 2 * mus * stat2
    posterior_alpha = prior_alpha + (counts / 2.)
    posterior_beta = prior_beta + sq_dev / 2.
    return posterior_alpha, posterior_beta
    
def kl_gamma_gamma(p_alpha, p_beta, q_alpha, q_beta):
    """Closed-form KL( Gamma(p_alpha, p_beta) || Gamma(q_alpha, q_beta) ).

    Both distributions use the shape/rate parameterisation. The arguments
    are combined elementwise (with broadcasting), and the result is
    elementwise as well.
    """
    rate_term = q_alpha * torch.log(p_beta / q_beta)
    normaliser_term = torch.lgamma(q_alpha) - torch.lgamma(p_alpha)
    shape_term = (p_alpha - q_alpha) * torch.digamma(p_alpha)
    # p_alpha / p_beta is the mean of the first Gamma.
    mean_term = (q_beta - p_beta) * (p_alpha / p_beta)
    return rate_term + normaliser_term + shape_term + mean_term

def kls_gammas(weights, tau, q_alpha, q_beta, p_alpha, p_beta, K, D):
    """Monte Carlo and closed-form KLs between q = Gamma(q_alpha, q_beta) and
    p = Gamma(p_alpha, p_beta).

    ``tau`` holds the samples (sample dim first) at which both log densities
    are evaluated. ``weights`` are importance weights over that sample dim,
    used for the inclusive estimate. K and D are unused.

    Returns:
        (MC inclusive KL(p||q), exact inclusive KL(p||q),
         MC exclusive KL(q||p), exact exclusive KL(q||p)), each a scalar.
    """
    def summed_log_density(alpha, beta):
        # Log density of tau summed over the two trailing dims.
        return Gamma(alpha, beta).log_prob(tau).sum(-1).sum(-1)

    log_q = summed_log_density(q_alpha, q_beta)
    log_p = summed_log_density(p_alpha, p_beta)

    exclusive_mc = (log_q - log_p).mean(0).mean()
    exclusive_exact = kl_gamma_gamma(q_alpha, q_beta, p_alpha, p_beta).mean()

    inclusive_mc = (weights * (log_p - log_q)).sum(0).mean()
    inclusive_exact = kl_gamma_gamma(p_alpha, p_beta, q_alpha, q_beta).mean()
    return inclusive_mc, inclusive_exact, exclusive_mc, exclusive_exact

def rws(Xs, Zs, Pi, N, K, D, mus, num_samples, batch_size, encoder=None):
    """Evaluate the reweighted-wake-sleep objectives and diagnostics on a batch.

    Args:
        Xs: observations, (batch_size, N, D).
        Zs: one-hot assignments, (batch_size, N, K).
        Pi: passed through to log_joints_gmm.
        mus: component means, (batch_size, K, D).
        num_samples: importance samples per sequence.
        encoder: the Encoder to evaluate. Defaults to the notebook-global
            ``enc``, looked up at call time as before.

    Returns:
        (eubo, elbo, ess, MC inclusive KL, exact inclusive KL,
         MC exclusive KL, exact exclusive KL), all scalars.
    """
    if encoder is None:
        encoder = enc
    stat1, stat2, stat3, _ = StatsGMM(Xs, Zs, K, D)
    # Each point's encoder input is x_n concatenated with z_n.
    data = torch.cat((Xs, Zs), dim=-1).view(batch_size*N, -1)
    q_alpha, q_beta, precisions = encoder(data, K, D, mus, num_samples, batch_size)
    log_q = Gamma(q_alpha, q_beta).log_prob(precisions).sum(-1).sum(-1)  # (S, B)
    log_p = log_joints_gmm(Zs, Pi, mus, precisions, Xs, N, D, K, num_samples, batch_size)
    log_weights = log_p - log_q

    # Self-normalised importance weights over samples (dim 0). They are
    # detached, so the EUBO gradient flows only through log_q.
    weights = torch.exp(log_weights - logsumexp(log_weights, dim=0)).detach()
    eubo = torch.mul(weights, log_weights).sum(0).mean()
    elbo = log_weights.mean(0).mean()
    ess = (1. / (weights ** 2).sum(0)).mean()

    posterior_alpha, posterior_beta = conjugate_posterior(stat1, stat2, stat3, mus, K, D, batch_size)
    mckl_inc, truekl_inc, mckl_exc, truekl_exc = kls_gammas(
        weights, precisions, q_alpha, q_beta, posterior_alpha, posterior_beta, K, D)
    return eubo, elbo, ess, mckl_inc, truekl_inc, mckl_exc, truekl_exc

def shuffler(batch_Xs, batch_Zs, N, K, D, batch_size):
    """Independently permute the N points of every sequence in the batch.

    The same permutation is applied to a sequence's observations and its
    assignments, so (x_n, z_n) pairs stay aligned. One ``torch.randperm(N)``
    is drawn per sequence, in batch order.
    """
    perms = torch.stack([torch.randperm(N) for _ in range(batch_size)])  # (B, N)
    perm_idx = perms.unsqueeze(-1)
    shuffled_Xs = batch_Xs.gather(1, perm_idx.expand(-1, -1, D))
    shuffled_Zs = batch_Zs.gather(1, perm_idx.expand(-1, -1, K))
    return shuffled_Xs, shuffled_Zs

In [8]:
# Per-epoch histories, each entry averaged over that epoch's mini-batches
# (consumed by plot_results below).
EUBOs = []
ELBOs = []
ESSs = []
MCKls_inclusive = []
TrueKls_inclusive = []
MCKls_exclusive = []
TrueKls_exclusive = []
# Same order as the values returned by rws(): eubo, elbo, ess, MC inclusive KL,
# exact inclusive KL, MC exclusive KL, exact exclusive KL.
histories = (EUBOs, ELBOs, ESSs,
             MCKls_inclusive, TrueKls_inclusive,
             MCKls_exclusive, TrueKls_exclusive)

# Sequences that do not fill a whole batch are skipped each epoch.
num_batches = Xs.shape[0] // BATCH_SIZE
time_start = time.time()
for epoch in range(NUM_EPOCHS):
    indices = torch.randperm(num_seqs)
    totals = [0.0] * len(histories)
    for step in range(num_batches):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_Xs = Xs[batch_indices]
        batch_Zs = Zs[batch_indices]
        batch_Mus = mus_true[batch_indices]
        # NOTE(review): the encoder, StatsGMM and the likelihood all sum over
        # points, so this within-sequence shuffle is mathematically a no-op
        # (it only changes float summation order and consumes RNG state).
        batch_Xs, batch_Zs = shuffler(batch_Xs, batch_Zs, N, K, D, BATCH_SIZE)
        metrics = rws(batch_Xs, batch_Zs, Pi, N, K, D, batch_Mus, NUM_SAMPLES, BATCH_SIZE)
        # Minimise the EUBO (first returned metric) w.r.t. the encoder.
        metrics[0].backward()
        optimizer.step()
        totals = [total + metric.item() for total, metric in zip(totals, metrics)]

    epoch_means = [total / num_batches for total in totals]
    for history, value in zip(histories, epoch_means):
        history.append(value)

    if epoch % 10 == 0:
        time_end = time.time()
        print('epoch=%d, EUBO=%f, ELBO=%f, ESS=%.3f, inc MCKL=%f, inc TKL=%f, exc MCKL=%f, exc TKL=%f (%ds)' % (epoch, *epoch_means, time_end - time_start))
        time_start = time.time()

epoch=0, EUBO=-561.683459, ELBO=-889.891391, ESS=1.031, inc MCKL=-476.432370, inc TKL=52.501810, exc MCKL=804.640382, exc TKL=349.170189 (1s)
epoch=10, EUBO=-87.928332, ELBO=-100.395998, ESS=1.638, inc MCKL=-2.677206, inc TKL=1.190143, exc MCKL=15.144872, exc TKL=2.540500 (14s)
epoch=20, EUBO=-87.234352, ELBO=-97.429085, ESS=1.750, inc MCKL=-1.983223, inc TKL=0.933414, exc MCKL=12.177959, exc TKL=2.042764 (12s)
epoch=30, EUBO=-86.861317, ELBO=-96.613957, ESS=1.830, inc MCKL=-1.610202, inc TKL=0.833192, exc MCKL=11.362832, exc TKL=1.908248 (11s)
epoch=40, EUBO=-86.467625, ELBO=-95.825248, ESS=1.822, inc MCKL=-1.216491, inc TKL=0.805451, exc MCKL=10.574121, exc TKL=1.807884 (12s)
epoch=50, EUBO=-86.494432, ELBO=-96.118696, ESS=1.824, inc MCKL=-1.243309, inc TKL=0.769535, exc MCKL=10.867569, exc TKL=1.814002 (11s)
epoch=60, EUBO=-86.496709, ELBO=-95.612619, ESS=1.882, inc MCKL=-1.245585, inc TKL=0.765315, exc MCKL=10.361492, exc TKL=1.772094 (6s)
epoch=70, EUBO=-86.278066, ELBO=-95.700177

KeyboardInterrupt: 

In [None]:
def plot_results(EUBOs, ELBOs, ESSs, MCKls_exclusive, TrueKls_exclusive, MCKls_inclusive, TrueKls_inclusive, num_samples, num_epochs, lr, batch_size=None):
    """Plot per-epoch training curves and save them as an SVG.

    Panels, top to bottom:
      1. EUBO and ELBO.
      2. Estimated vs. exact inclusive/exclusive KLs.
      3. ESS as a fraction of ``num_samples``.

    Args:
        EUBOs ... TrueKls_inclusive: per-epoch histories (one value per epoch).
        num_samples: importance samples per sequence; normalises the ESS.
        num_epochs: epoch count shown in the title.
        lr: learning rate, shown in the title and the file name.
        batch_size: batch size shown in the title. Defaults to the global
            BATCH_SIZE, read at call time.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(30, 30))

    ax1.plot(EUBOs, 'r', label='EUBOs')
    ax1.plot(ELBOs, 'b', label='ELBOs')
    ax1.set_title('epoch=%d, batch_size=%d, lr=%.1E, samples=%d' % (num_epochs, batch_size, lr, num_samples), fontsize=18)
    ax1.set_ylim([-300, -80])

    ax2.plot(TrueKls_exclusive, '#66b3ff', label='true exclusive KL')
    ax2.plot(MCKls_exclusive, '#ff9999', label='est exclusive KL')
    ax2.plot(TrueKls_inclusive, '#99ff99', label='true inclusive KL')
    ax2.plot(MCKls_inclusive, 'gold', label='est inclusive KL')
    ax2.set_ylim([-50, 50])

    ax3.plot(np.array(ESSs) / num_samples, 'm', label='ESS')

    for ax in (ax1, ax2, ax3):
        ax.legend()
        ax.tick_params(labelsize=18)
        ax.set_xlabel('epoch', fontsize=18)

    # tight_layout must run after the axes exist; calling it on an empty
    # figure (as before) had no effect.
    fig.tight_layout()
    fig.savefig('gmm_rws_datatodist_lr=%.1E_samples=%d.svg' % (lr, num_samples))

In [None]:
# Report the number of epochs actually completed (training may be interrupted
# before NUM_EPOCHS), not the configured maximum.
plot_results(EUBOs, ELBOs, ESSs, MCKls_exclusive, TrueKls_exclusive,
             MCKls_inclusive, TrueKls_inclusive,
             NUM_SAMPLES, len(EUBOs), LEARNING_RATE)