In [1]:
%matplotlib inline
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from plots import *
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
from torch.distributions.one_hot_categorical import OneHotCategorical as cat
from torch import logsumexp
import sys
import time
import datetime
sys.path.append('/home/hao/Research/probtorch/')
import probtorch
print('probtorch:', probtorch.__version__, 
      'torch:', torch.__version__, 
      'cuda:', torch.cuda.is_available())

probtorch: 0.0+5a2c637 torch: 1.0.0 cuda: True


In [2]:
N = 20
K = 3
D = 2

## Model Parameters
NUM_SAMPLES = 100
NUM_HIDDEN = 128
NUM_LATENTS = D * K
NUM_OBS =  K+D*K
BATCH_SIZE = 100
NUM_EPOCHS = 10000
LEARNING_RATE = 1e-3
CUDA = False

In [3]:
Xs = torch.from_numpy(np.load('gmm_dataset/sequences.npy')).float()
Zs = torch.from_numpy(np.load('gmm_dataset/states.npy')).float()
# mus_true = torch.from_numpy(np.load('gmm_dataset/means.npy')).float()
covs = torch.from_numpy(np.load('gmm_dataset/covariances.npy')).float()
Pi = torch.from_numpy(np.load('gmm_dataset/init.npy')).float()
num_seqs = Zs.shape[0]

In [4]:
def StatsGMM(Xs, Zs, K, D):
    """
    Xs is B * N * D
    Zs is B * N * K
    return B * (K+D*K)
    """
    stat1 = Zs.sum(1)
    stat2 = torch.mul(Zs.unsqueeze(-1).repeat(1, 1, 1, D), Xs.unsqueeze(-1).repeat(1, 1, 1, K).transpose(-1, -2)).sum(1) 
    return stat1, stat2, torch.cat((stat1, stat2.view(-1, D*K)), dim=-1)

In [5]:
class Encoder(nn.Module):
    def __init__(self, num_obs=NUM_OBS,
                       num_hidden=NUM_HIDDEN,
                       num_latents=NUM_LATENTS):
        super(self.__class__, self).__init__()
        self.enc_hidden = nn.Sequential(
            nn.Linear(num_obs, num_hidden),
            nn.ReLU())
        self.mean = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        self.log_std = nn.Sequential(
            nn.Linear(num_hidden, num_latents))
        
    def forward(self, obs, K, D, num_samples):
        hidden = self.enc_hidden(obs)
        mean = self.mean(hidden).view(-1, K, D)
        std = torch.exp(self.log_std(hidden).view(-1, K, D))
        mus = Normal(mean, std).sample((num_samples, )) ## S * B * K * D
        return mean, std, mus

In [6]:
def initialize():
    enc = Encoder()
    if CUDA:
        enc.cuda()
    optimizer =  torch.optim.Adam(list(enc.parameters()),lr=LEARNING_RATE)    
    return enc, optimizer
enc, optimizer = initialize()

In [7]:
def log_joints_gmm(Z, Pi, mus, covs, Xs, N, D, K, num_samples, batch_size):
    log_probs = torch.zeros((num_samples, batch_size)).float()
    ## S * B
    log_probs = log_probs + Normal(torch.zeros((batch_size, K, D)), torch.ones((batch_size, K, D))).log_prob(mus).sum(-1).sum(-1)
    ## Z B-by-T-by-K
    log_probs = log_probs + cat(Pi).log_prob(Z).sum(-1)
    labels = Z.nonzero()
    covs_expand = covs.unsqueeze(0).unsqueeze(0).repeat(num_samples, batch_size, 1, 1, 1)
    log_probs = log_probs + mvn(mus[:, labels[:, 0], labels[:, -1], :].view(-1, batch_size, N, D), covs_expand[:, labels[:, 0], labels[:, -1], :].view(-1, batch_size, N, D, D)).log_prob(Xs).sum(-1)
    return log_probs

def conjugate_posterior(stat1, stat2, covs, K, D, batch_size):
    prior_covs_inv = torch.ones(K, D)
    covs_flat = torch.diagonal(covs, 0, -2, -1).unsqueeze(0).repeat(BATCH_SIZE, 1, 1) #B * K * D
    posterior_covs = 1. / (prior_covs_inv + torch.mul(stat1.unsqueeze(-1).repeat(1, 1, D), 1. / covs_flat))
    posterior_mean = torch.mul(posterior_covs, torch.mul(stat2, 1. / covs_flat))
    return posterior_mean, posterior_covs

def kl_normal_normal(p_mean, p_std, q_mean, q_std):
    var_ratio = (p_std / q_std).pow(2)
    t1 = ((p_mean - q_mean) / q_std).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

def kls_gaussians(mus, mus_mean, mus_std, posterior_mean, posterior_covs, K, D):
    log_q = Normal(mus_mean, mus_std).log_prob(mus).sum(-1).sum(-1)
    log_p = Normal(posterior_mean, torch.sqrt(posterior_covs)).log_prob(mus).sum(-1).sum(-1)
    MCKl = (log_q - log_p).mean(0).mean()
    TrueKl = kl_normal_normal(mus_mean, mus_std, posterior_mean, torch.sqrt(posterior_covs)).mean()
    return MCKl, TrueKl
    
def rws(Xs, Zs, Pi, covs, N, K, D, num_samples, batch_size):
    stat1, stat2, stats = StatsGMM(Xs, Zs, K, D)
    mus_mean, mus_std, mus = enc(stats, K, D, num_samples)
    log_q = Normal(mus_mean, mus_std).log_prob(mus).sum(-1).sum(-1) ## S * B
    log_p = log_joints_gmm(Zs, Pi, mus, covs, Xs, N, D, K, num_samples, batch_size)
    log_weights = log_p - log_q
    weights = torch.exp(log_weights - logsumexp(log_weights, dim=0)).detach()
    eubo = torch.mul(weights, log_weights).sum(0).mean()
    elbo = log_weights.mean(0).mean()
    ess = (1. / (weights ** 2).sum(0)).mean()
    posterior_mean, posterior_covs = conjugate_posterior(stat1, stat2, covs, K, D, batch_size)
    MCKl, TrueKl = kls_gaussians(mus, mus_mean, mus_std, posterior_mean, posterior_covs, K, D)
    return eubo, elbo, ess, MCKl, TrueKl

In [None]:
EUBOs = []
ELBOs = []
ESSs = []
MCKls = []
TrueKls = []
num_batches = int((Xs.shape[0] / BATCH_SIZE))
for epoch in range(NUM_EPOCHS):
    indices = torch.randperm(num_seqs)
    time_start = time.time()
    EUBO = 0.0
    ELBO = 0.0
    ESS = 0.0
    MCKl = 0.0
    TrueKl = 0.0
    for step in range(num_batches):
        optimizer.zero_grad()
        batch_indices = indices[step*BATCH_SIZE : (step+1)*BATCH_SIZE]
        batch_Xs = Xs[batch_indices]
        batch_Zs = Zs[batch_indices]
        eubo, elbo, ess, mckl, truekl = rws(batch_Xs, batch_Zs, Pi, covs, N, K, D, NUM_SAMPLES, BATCH_SIZE)
        eubo.backward()
        optimizer.step()
        EUBO += eubo.item()
        ELBO += elbo.item()
        ESS += ess.item()
        MCKl += mckl.item()
        TrueKl += truekl.item()
#         print('iteration=%d, EUBO=%f, ELBO=%f, MCKL=%f, TKL=%f (%ds)' % (epoch*num_batches+step, eubo, elbo, mckl, truekl, time_end - time_start))
    EUBO /= num_batches
    ELBO /= num_batches
    ESS /= num_batches
    MCKl /= num_batches
    TrueKl /= num_batches
    
    EUBOs.append(EUBO)
    ELBOs.append(ELBO)
    ESSs.append(ESS)
    MCKls.append(MCKl)
    TrueKls.append(TrueKl)
    
    time_end = time.time()
    print('epoch=%d, EUBO=%f, ELBO=%f, ESS=%.3f, MCKL=%f, TKL=%f (%ds)' % (epoch, EUBO, ELBO, ESS, MCKl, TrueKl, time_end - time_start))

epoch=0, EUBO=-541.480426, ELBO=-41169.352271, ESS=1.037, MCKL=41055.172083, TKL=7294.927390 (9s)
epoch=1, EUBO=-264.807706, ELBO=-13678.076050, ESS=1.145, MCKL=13563.888428, TKL=2271.473578 (9s)
epoch=2, EUBO=-484.764951, ELBO=-1630198.919336, ESS=1.103, MCKL=1630084.903125, TKL=263811.261768 (9s)
epoch=3, EUBO=-23034.768167, ELBO=-20689252.562500, ESS=1.138, MCKL=20689136.318750, TKL=3511928.544531 (9s)
epoch=4, EUBO=-2189.177905, ELBO=-10653304.450000, ESS=1.188, MCKL=10653189.237500, TKL=1731425.625000 (9s)
epoch=5, EUBO=-191.425504, ELBO=-236366.583984, ESS=1.351, MCKL=236252.378418, TKL=42041.966235 (9s)
epoch=6, EUBO=-172.564819, ELBO=-4601.851099, ESS=1.279, MCKL=4487.662695, TKL=736.113797 (9s)
epoch=7, EUBO=-171.314444, ELBO=-3856.414429, ESS=1.281, MCKL=3742.225342, TKL=629.367114 (9s)
epoch=8, EUBO=-142.171478, ELBO=-1506.252612, ESS=1.440, MCKL=1392.063843, TKL=231.500211 (9s)
epoch=9, EUBO=-129.488235, ELBO=-450.551349, ESS=1.475, MCKL=336.362466, TKL=55.505237 (9s)
epoch