In [None]:
%matplotlib inline
from bivariate import *
from utils import *
from enc import *
from plots import *
import torch
import numpy as np
import time
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
# from torch.distributions.one_hot_categorical import OneHotCategorical as cat

### Gibbs Sampling

In [None]:
# ## Define a target bivariate Gaussian for unit test
# mu1 = torch.ones(1) * 5.0
# mu2 = torch.ones(1) * 8.0
# sigma1 = torch.ones(1) * 1.0
# sigma2 = torch.ones(1) * 2.5
# rho = torch.ones(1) * 0.6

# bg = Bi_Gaussian(mu1, mu2, sigma1, sigma2, rho, CUDA=False, device=None)

# STEPS = 500

# updates = Gibbs(bg, STEPS, sampling=True)

In [None]:
# Plot_updates(updates, bg, sigma_factor=5, pts=1000, fs=10, levels=5, back_to_cpu=False)

### Amortized Gibbs Sampling

In [None]:
# --- Optimisation settings ---
GRAD_STEPS = 1000      # number of optimizer updates performed by the training loop
LEARNING_RATE = 1e-3   # passed to Adam as `lr`
# --- Sampler settings ---
NUM_SAMPLES = 10       # forwarded to ag() as `num_samples`
MCMC_STEPS = 10        # forwarded to ag() as `mcmc_steps` (sweeps per gradient step)
NUM_HIDDENS = 8        # first argument of Kernel (presumably the hidden width -- confirm in enc.py)

CUDA = torch.cuda.is_available()
# Keep using GPU 1 when it exists, but fall back to GPU 0 otherwise: a hard-coded
# 'cuda:1' is an invalid device ordinal on single-GPU machines.
DEVICE = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')

In [None]:
## Define a target bivariate Gaussian for unit test
mu1 = torch.ones(1) * 5.0
mu2 = torch.ones(1) * 8.0
sigma1 = torch.ones(1) * 1.0
sigma2 = torch.ones(1) * 2.5
rho = torch.ones(1) * 0.6

## One proposal kernel per coordinate
q_x1 = Kernel(NUM_HIDDENS, mu1, sigma1, CUDA, DEVICE)
q_x2 = Kernel(NUM_HIDDENS, mu2, sigma2, CUDA, DEVICE)
## Follow the detected CUDA flag instead of forcing CUDA=True, so the target ends up
## on the same side (CPU/GPU) as the kernels. On the CPU path pass device=None, which
## is how the CPU-only Gibbs cell earlier in this notebook constructs Bi_Gaussian.
bg = Bi_Gaussian(mu1, mu2, sigma1, sigma2, rho, CUDA=CUDA, device=DEVICE if CUDA else None)

if CUDA:
    with torch.cuda.device(DEVICE):
        q_x1.cuda()
        q_x2.cuda()
## A single optimizer trains both kernels jointly
optimizer = torch.optim.Adam(list(q_x1.parameters()) + list(q_x2.parameters()), lr=LEARNING_RATE)

In [None]:
def fb(q, bg, x_cond, x_cond_name, x_old, sampling=True):
    """Propose a new value for one coordinate and weight it with forward/backward ratios.

    Args:
        q: proposal kernel; must provide ``forward(x_cond)`` returning
            ``(sample, log_q, mu, sigma)`` and ``backward(x_cond, x_old)``
            returning a log-density, as used below.
        bg: target; must provide ``conditional(x_cond, cond=...)`` returning
            ``(mu, sigma)`` and ``log_pdf_gamma(x, mu, sigma)``.
        x_cond: value of the coordinate being conditioned on.
        x_cond_name: forwarded to ``bg.conditional`` as ``cond``
            (``'x1'`` or ``'x2'`` at the call sites in this notebook).
        x_old: previous value of the coordinate being updated.
        sampling: when True use the sample drawn by ``q.forward``; when False
            use the proposal mean instead.

    Returns:
        Tuple ``(loss, ess, x_new, log_w, w, kls)`` where ``w`` is the detached
        softmax (over dim 0) of ``log_w``, ``loss`` is the ``w``-weighted sum of
        the forward log-weights, ``ess`` is ``1 / sum(w**2)`` and ``kls`` is the
        mean of ``kl_normal_normal`` between the target conditional and the proposal.
    """
    r_mu, r_sigma = bg.conditional(x_cond, cond=x_cond_name)
    x_new, log_q_f, q_mu, q_sigma = q.forward(x_cond)
    if not sampling:
        # Deterministic update: move to the proposal mean and score the proposal
        # density at that point, so log_q_f is defined on this path as well and
        # refers to the value actually returned.
        # assumes q_sigma is a standard deviation (Normal's `scale`) -- TODO confirm against Kernel.forward
        x_new = q_mu
        log_q_f = Normal(q_mu, q_sigma).log_prob(x_new)

    kls = kl_normal_normal(r_mu, r_sigma, q_mu, q_sigma).mean()
    # Forward direction: new value under target conditional vs. proposal.
    log_p_f = bg.log_pdf_gamma(x_new, r_mu, r_sigma)
    # Backward direction: old value under target conditional vs. proposal.
    log_q_b = q.backward(x_cond, x_old)
    log_p_b = bg.log_pdf_gamma(x_old, r_mu, r_sigma)
    log_w_f = log_p_f.sum(-1) - log_q_f.sum(-1)
    log_w_b = log_p_b.sum(-1) - log_q_b.sum(-1)
    log_w = log_w_f - log_w_b
    # Normalised weights are treated as constants in the loss.
    w = F.softmax(log_w, 0).detach()
    loss = (w * log_w_f).sum(0).mean().unsqueeze(0)
    ess = (1. / (w ** 2).sum(0)).mean().unsqueeze(0)
    return loss, ess, x_new, log_w, w, kls

def ag(q_x1, q_x2, bg, mcmc_steps, num_samples):
    """Run one amortized-Gibbs chain and return its training objective.

    The chain starts from ``q_x1.sample_prior(num_samples)``, proposes ``x2``
    with ``q_x2.forward`` and weights it against ``bg``'s conditional given
    ``x1``. Each of the ``mcmc_steps`` sweeps then resamples ``x2``, updates
    ``x1`` through ``fb``, resamples ``x1`` and updates ``x2`` through ``fb``.

    Returns:
        ``(loss, ess)``: the sum of all per-step losses and the mean of all
        per-step effective sample sizes (the initial step counts as one entry).
    """
    # --- initial step: x1 from the prior, x2 proposed given x1 ---
    x1, _ = q_x1.sample_prior(num_samples)
    x2, log_q_init, _, _ = q_x2.forward(x1)
    cond_mu, cond_sigma = bg.conditional(x1, cond='x1')
    log_p_init = bg.log_pdf_gamma(x2, cond_mu, cond_sigma)
    # Only the proposal term carries gradient here; the target term is detached.
    log_w_init = log_p_init.sum(-1).detach() - log_q_init.sum(-1)
    w_x2 = F.softmax(log_w_init, 0).detach()

    step_losses = [(w_x2 * log_w_init).sum(0).mean().unsqueeze(0)]
    step_ess = [(1. / (w_x2 ** 2).sum(0)).mean().unsqueeze(0)]

    # --- Gibbs sweeps ---
    for _step in range(mcmc_steps):
        # resample x2 with the weights from the previous x2 update
        x2 = resample(x2, w_x2)
        # update x1 given x2, then resample the proposal with its own weights
        loss_a, ess_a, x1_prop, _, w_x1, _ = fb(q_x1, bg, x2, 'x2', x1)
        x1 = resample(x1_prop, w_x1)
        # update x2 given x1; the right-hand side sees the pre-update x2
        loss_b, ess_b, x2, _, w_x2, _ = fb(q_x2, bg, x1, 'x1', x2)
        step_losses.append(loss_a + loss_b)
        step_ess.append((ess_a + ess_b) / 2)

    total_loss = torch.cat(step_losses, 0).sum()
    mean_ess = torch.cat(step_ess, 0).mean()
    return total_loss, mean_ess

In [None]:
# Per-step training history (one entry per gradient step).
LOSSS = []
ESSS = []

time_start = time.time()
for i in range(GRAD_STEPS):

    optimizer.zero_grad()
    loss, ess = ag(q_x1, q_x2, bg, MCMC_STEPS, NUM_SAMPLES)
    ##
    loss.backward()
    optimizer.step()
    # Store graph-free copies: appending `loss` itself would keep every step's
    # autograd graph reachable for as long as the history lists are alive.
    LOSSS.append(loss.detach())
    ESSS.append(ess.detach())
    if i % 100 == 0:
        # Report the wall-clock time elapsed since the previous report.
        time_end = time.time()
        print('Step=%d, Loss=%.4f, ESS=%.4f (%ds)' % (i, loss, ess, time_end - time_start))
        time_start = time.time()

In [None]:
def test(q_x1, q_x2, bg, mcmc_steps, num_samples, sampling=True):
    updates = []
    DBs = []
    KLs = []
    ## start with sampling x1 from the prior
    x1, log_p_x1_f = q_x1.sample_prior(num_samples)
    if sampling:
        x2, log_q_x2, q_mu, q_sigma = q_x2.forward(x1)
    else:
        _, _, q_mu, q_sigma = q_x2.forward(x1)
        x2 = q_mu
    r_mu, r_sigma = bg.conditional(x1, cond='x1')
    kls = kl_normal_normal(r_mu, r_sigma, q_mu, q_sigma).mean()
    KLs.append(kls.unsqueeze(0))
    log_p_x2_f = bg.log_pdf_gamma(x2, r_mu, r_sigma)
    log_w = log_p_x2_f.sum(-1).detach() - log_q_x2.sum(-1)
    w_x2 = F.softmax(log_w, 0).detach()
    updates.append(torch.cat((x1, x2), -1))
    for m in range(mcmc_steps):
#         x2 = resample(x2, w_x2) ## resample x2
        x1_old = x1
        loss_x1, ess_x1, x1, log_w_x1, w_x1, kls_x1 = fb(q_x1, bg, x2, 'x2', x1_old) ## update x1
        DBs.append((w_x1 * log_w_x1).sum(0).mean().unsqueeze(0))
#         x1 = resample(x1, w_x1) ## resample x1
        x2_old = x2
        loss_x2, ess_x2, x2, log_w_x2, w_x2, kls_x2 = fb(q_x2, bg, x1, 'x1', x2_old) ## update x2
        KLs.append((kls_x1.unsqueeze(0) + kls_x2.unsqueeze(0)) / 2)

        updates.append(torch.cat((x1, x2), -1))
        DBs.append((w_x2 * log_w_x2).sum(0).mean().unsqueeze(0))
    return torch.cat(updates, 0), torch.cat(DBs, 0), torch.cat(KLs, 0)

In [None]:
# Evaluate the trained kernels: a single chain (one sample) run for 500 sweeps.
updates, DBs, KLs = test(q_x1, q_x2, bg, mcmc_steps=500, num_samples=1)
# Hand the diagnostics to the plotting helper as graph-free NumPy arrays on the host.
plot_kls(DBs.detach().cpu().numpy(), KLs.detach().cpu().numpy())

In [None]:
# Visualise the chain returned by test() together with the target `bg`.
# NOTE(review): Plot_updates comes from the star-imported helper modules; presumably
# sigma_factor/pts/levels control the plotted range, grid resolution and contour
# count, and fs the figure size -- confirm against its definition.
Plot_updates(updates, bg, sigma_factor=5, pts=100, fs=10, levels=5)