In [1]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.normal import Normal
import numpy as np
from utils import *
from plots import *
# from objectives import *

In [15]:
# --- Experiment configuration ------------------------------------------------
# Everything a reader might want to tune lives in this cell.
SEED = 0                 # fixed seed so Restart & Run All reproduces the same trajectory
iterations = 40000       # number of optimizer steps
num_samples = 20         # importance samples drawn per step
p_mu = 0.0               # mean of the (unnormalized) Gaussian target
p_sigma2 = 1.0           # variance of the (unnormalized) Gaussian target

LEARNING_RATE = 5*1e-3
# Log normalizer of exp(-(x - p_mu)^2 / (2 * p_sigma2)), i.e. log sqrt(2*pi*p_sigma2).
# Derived from p_sigma2 so the reference value stays correct if the variance is changed.
log_Z = np.log(np.sqrt((2*np.pi*p_sigma2)))

# Seed before any sampling happens; the training loop draws from torch's global RNG.
torch.manual_seed(SEED)

# Variational parameters of q = Normal(q_mu, q_sigma), deliberately initialised
# away from the target. NOTE(review): q_sigma is optimized unconstrained, so a
# large enough step could push it non-positive.
q_mu = torch.tensor([4.0], requires_grad=True)
q_sigma = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([q_mu, q_sigma], lr=LEARNING_RATE)


In [16]:
def iwae(num_samples, q_mu, q_sigma, p_mu, p_sigma2, iterations, optimizer):
    """Fit q = Normal(q_mu, q_sigma) to an unnormalized Gaussian target with
    a self-normalized importance-weighted objective.

    Each step draws ``num_samples`` reparameterized samples from q, scores them
    under the unnormalized target ``exp(-(x - p_mu)**2 / (2 * p_sigma2))``,
    and takes one ``optimizer`` step on ``-sum(w_k * log_w_k)`` where ``w_k``
    are the normalized importance weights, detached from the graph.

    Args:
        num_samples: number of samples drawn from q per step.
        q_mu, q_sigma: tensors holding the mean and scale of q. They are read
            via ``.grad`` / ``.item()`` after every step, so they are expected
            to be single-element leaf tensors handled by ``optimizer``.
        p_mu, p_sigma2: mean and variance of the target.
        iterations: number of optimization steps.
        optimizer: object exposing ``zero_grad()`` and ``step()``.

    Returns:
        Tuple of six per-step lists ``(ELBO, Mu, Sigma, Grad_mu, Grad_sigma,
        ESS)``: the bound ``logsumexp(log_w) - log(num_samples)``, the
        parameter values after the step, the negated ``.grad`` of each
        parameter, and the effective sample size ``1 / sum(w_k ** 2)``.

    Prints a progress line every 1000 steps.
    """
    trace_names = ('elbo', 'mu', 'sigma', 'grad_mu', 'grad_sigma', 'ess')
    history = {name: [] for name in trace_names}

    # Constant across steps, so build it once.
    log_num_samples = torch.log(torch.FloatTensor([num_samples]))

    tick = time.time()
    for i in range(iterations):
        optimizer.zero_grad()

        proposal = Normal(q_mu, q_sigma)
        draws = proposal.rsample((num_samples,))
        log_target = (-1.0 / (p_sigma2 * 2.0)) * ((draws - p_mu) ** 2)
        log_w = log_target - proposal.log_prob(draws)

        # One log-normalizer serves both the normalized weights and the bound.
        log_w_total = logsumexp(log_w, dim=0)
        normalized_w = torch.exp(log_w - log_w_total).detach()
        ess = 1. / (normalized_w ** 2).sum()
        elbo = log_w_total - log_num_samples

        # Weights are constants here; gradients flow only through log_w.
        surrogate = (normalized_w * log_w).sum()
        (-surrogate).backward()
        optimizer.step()

        elbo_value = elbo.item()
        ess_value = ess.item()
        history['elbo'].append(elbo_value)
        history['mu'].append(q_mu.item())
        history['sigma'].append(q_sigma.item())
        # .grad belongs to the minimized loss; flip the sign to log the ascent direction.
        history['grad_mu'].append(-q_mu.grad.item())
        history['grad_sigma'].append(-q_sigma.grad.item())
        history['ess'].append(ess_value)

        if i % 1000 == 0:
            elapsed = time.time() - tick
            print('iteration:%d, ELBO:%.3f, ESS:%.3f (%ds)' % (i, elbo_value, ess_value, elapsed))
            tick = time.time()

    return tuple(history[name] for name in trace_names)

In [17]:
# Run the optimization; q_mu and q_sigma are updated in place by the optimizer.
ELBO, Mu, Sigma, Grad_mu, Grad_sigma, ESS = iwae(
    num_samples=num_samples,
    q_mu=q_mu,
    q_sigma=q_sigma,
    p_mu=p_mu,
    p_sigma2=p_sigma2,
    iterations=iterations,
    optimizer=optimizer,
)


iteration:0, ELBO:0.891, ESS:1.946 (0s)
iteration:1000, ELBO:1.297, ESS:7.175 (0s)
iteration:2000, ELBO:1.186, ESS:5.823 (0s)
iteration:3000, ELBO:1.022, ESS:5.477 (0s)
iteration:4000, ELBO:0.716, ESS:6.355 (0s)
iteration:5000, ELBO:0.994, ESS:6.360 (0s)
iteration:6000, ELBO:0.660, ESS:6.740 (0s)
iteration:7000, ELBO:0.444, ESS:4.530 (0s)
iteration:8000, ELBO:0.967, ESS:8.190 (0s)
iteration:9000, ELBO:1.083, ESS:9.850 (0s)
iteration:10000, ELBO:0.997, ESS:8.445 (0s)
iteration:11000, ELBO:0.822, ESS:8.977 (0s)
iteration:12000, ELBO:0.525, ESS:8.323 (0s)
iteration:13000, ELBO:0.448, ESS:6.612 (0s)
iteration:14000, ELBO:1.101, ESS:10.857 (0s)
iteration:15000, ELBO:0.908, ESS:9.743 (0s)
iteration:16000, ELBO:0.441, ESS:6.144 (0s)
iteration:17000, ELBO:0.791, ESS:8.408 (0s)
iteration:18000, ELBO:0.978, ESS:9.235 (0s)
iteration:19000, ELBO:1.015, ESS:11.278 (0s)
iteration:20000, ELBO:0.651, ESS:9.044 (0s)
iteration:21000, ELBO:1.149, ESS:14.466 (0s)
iteration:22000, ELBO:1.174, ESS:13.543 (0

In [18]:
# Gradient statistics for each variational parameter, using the same smoothing
# constants for both. SNR comes from the star-import of `utils`; presumably
# beta1/beta2 are moving-average decay rates — confirm against utils.SNR.
snr_betas = dict(beta1=0.99, beta2=0.999)
eg_mu, eg2_mu, var_mu, snr_mu = SNR(np.array(Grad_mu), iterations, **snr_betas)
eg_sigma, eg2_sigma, var_sigma, snr_sigma = SNR(np.array(Grad_sigma), iterations, **snr_betas)

KeyboardInterrupt: 

In [None]:
def plot_results(ELBO, ESS, num_samples, snr_mu, snr_sigma, window=10):
    """Draw a three-panel summary of a run: ELBO, gradient SNR and ESS ratio.

    Args:
        ELBO: per-step bound values, plotted as-is on the first panel.
        ESS: per-step effective sample sizes. They are divided by
            ``num_samples`` and averaged over consecutive windows before
            being plotted on the third panel.
        num_samples: number of importance samples per step (the ESS divisor).
        snr_mu, snr_sigma: per-step signal-to-noise traces for the two
            variational parameters, plotted on the second panel.
        window: number of consecutive steps averaged into one ESS point.
            Defaults to 10.

    Raises:
        ValueError: if ``window`` is smaller than 1.

    The length of ``ESS`` does not need to be a multiple of ``window``: a
    trailing partial window is averaged on its own and plotted as the last
    point.
    """
    window = int(window)
    if window < 1:
        raise ValueError('window must be a positive integer, got %r' % (window,))

    fig = plt.figure(figsize=(10,10))
    ax1, ax2, ax3 = fig.subplots(3, 1, sharex=True)

    ## ELBO
    ax1.plot(ELBO, 'b', label='ELBOs')
    ax1.set_ylabel('ELBO')
    ax1.legend()

    ## SNR
    ax2.plot(snr_sigma, label='SNR_sigma')
    ax2.plot(snr_mu, label='SNR_mu')
    ax2.set_ylabel('SNR')
    ax2.legend()

    ## ESS
    ess_ratio = np.asarray(ESS, dtype=float).ravel() / num_samples
    # Average whole windows first; an explicit row count (rather than -1)
    # keeps the reshape valid even when there are zero full windows.
    n_full = (ess_ratio.size // window) * window
    ave_ess = ess_ratio[:n_full].reshape(n_full // window, window).mean(-1)
    if n_full < ess_ratio.size:
        # Keep the leftover steps instead of crashing or silently dropping them.
        ave_ess = np.append(ave_ess, ess_ratio[n_full:].mean())
    N = ave_ess.shape[0]
    ax3.plot(np.arange(N) * window, ave_ess, 'go', label='ESS')
    ax3.set_ylim([0, 1])
    ax3.set_xlabel('iteration')
    ax3.set_ylabel('ESS / num_samples')
    ax3.legend()

    # Lay out after all artists exist so labels and legends are accounted for.
    fig.tight_layout()
# Summarise the run: ELBO trace, gradient SNR for both parameters, and ESS ratio.
plot_results(
    ELBO=ELBO,
    ESS=ESS,
    num_samples=num_samples,
    snr_mu=snr_mu,
    snr_sigma=snr_sigma,
)