In [None]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
import numpy as np

In [None]:
# Optimisation schedule shared by every training run in this notebook.
iterations = 20000
lr = 1e-3

# Mean of the unnormalised target density gamma(x) = exp(-(x - p_mu)^2 / 2).
p_mu = 0.0

# Log normalising constant of that target: Z = sqrt(2 * pi).
log_Z = np.log(np.sqrt(2 * np.pi))

In [None]:
def train(num_samples, q_mu, q_sigma, lr, num_iterations=None, target_mu=None):
    """Fit a Normal proposal to an unnormalised Gaussian target by descending the EUBO.

    At every step ``num_samples`` points are drawn from ``Normal(q_mu, q_sigma)``
    (without reparameterisation), importance-weighted against the unnormalised
    target ``exp(-(x - target_mu)^2 / 2)``, and the proposal parameters take one
    plain gradient-descent step on the self-normalised EUBO estimate. The
    normalised weights are detached, so they act as constants in the gradient.

    Args:
        num_samples: Number of samples drawn per iteration (any integer-like value).
        q_mu: Initial proposal mean; a tensor that requires grad. Not modified.
        q_sigma: Initial proposal standard deviation; a tensor that requires grad.
            Not modified.
        lr: Step size of the gradient-descent update.
        num_iterations: Number of update steps. Defaults to the notebook-level
            ``iterations`` when omitted.
        target_mu: Mean of the unnormalised target. Defaults to the notebook-level
            ``p_mu`` when omitted.

    Returns:
        Tuple ``(EUBOs, ELBOs, Mus, Sigmas)`` of lists of Python floats, one entry
        per iteration. ``Mus`` and ``Sigmas`` hold the parameter values *after*
        that iteration's update.
    """
    # Fall back to the notebook globals so existing four-argument calls behave as before.
    if num_iterations is None:
        num_iterations = iterations
    if target_mu is None:
        target_mu = p_mu
    # Callers pass NumPy integers; use a plain int for the sample shape.
    num_samples = int(num_samples)

    EUBOs = []
    ELBOs = []
    Mus = []
    Sigmas = []
    for _ in range(num_iterations):
        proposal = Normal(q_mu, q_sigma)
        xs = proposal.sample((num_samples,))
        log_gammas = (-1.0 / 2.0) * ((xs - target_mu) ** 2)
        log_q = proposal.log_prob(xs)

        log_weights = log_gammas - log_q
        # Self-normalised importance weights, held constant for differentiation.
        weights = torch.exp(log_weights - logsumexp(log_weights, dim=0)).detach()
        eubo = torch.mul(weights, log_weights).sum()
        elbo = log_weights.mean()
        grad_mu, grad_sigma = torch.autograd.grad(eubo, [q_mu, q_sigma])

        # Detach the updated parameters so each iteration starts a fresh graph;
        # otherwise every previous iterate stays chained into the autograd history.
        # NOTE(review): nothing here keeps q_sigma positive -- a large enough step
        # would hand Normal a non-positive scale. Confirm whether that needs a guard.
        q_mu = (q_mu - lr * grad_mu).detach().requires_grad_(True)
        q_sigma = (q_sigma - lr * grad_sigma).detach().requires_grad_(True)

        EUBOs.append(eubo.item())
        ELBOs.append(elbo.item())
        Mus.append(q_mu.item())
        Sigmas.append(q_sigma.item())
    return EUBOs, ELBOs, Mus, Sigmas

In [None]:
# Sweep grid: starting proposal means / standard deviations and per-step sample counts.
init_mu = np.array([6, 8, 10])
init_sigma = np.array([1.0, 2.0, 4.0, 6.0])
NUM_SAMPLES = np.array([100, 1000])

# Each starting sigma takes two subplot rows (bounds on top, sigma trace below);
# each sample count takes one column.
n_rows = init_sigma.shape[0] * 2
n_cols = NUM_SAMPLES.shape[0]

# One figure (and one saved SVG) per starting mean.
for i, mu0 in enumerate(init_mu):
    fig = plt.figure(figsize=(30, 40))
    for j, sigma0 in enumerate(init_sigma):
        for k, n_draws in enumerate(NUM_SAMPLES):
            time_start = time.time()
            q_mu = torch.tensor([1.0], requires_grad=True) * mu0
            q_sigma = torch.tensor([1.0], requires_grad=True) * sigma0
            EUBOs, ELBOs, Mus, Sigmas = train(n_draws, q_mu, q_sigma, lr)
            time_end = time.time()
            elapsed = time_end - time_start
            print('init_mu=%.1f, init_sigma : %.1f, samples : %d (%ds)' % (mu0, sigma0, n_draws, elapsed))

            # Upper panel: both bounds against the true log normaliser.
            bounds_slot = (2 * j) * n_cols + k + 1
            ax = fig.add_subplot(n_rows, n_cols, bounds_slot)
            ax.plot(EUBOs, 'r', label='EUBOs')
            ax.plot(ELBOs, 'b', label='ELBOs')
            ax.plot(np.ones(iterations) * log_Z, 'k', label='log_Z')
            ax.tick_params(labelsize=18)
            ax.set_ylim([-50, 20])

            # Lower panel (one subplot row further down): the sigma trajectory.
            sigma_slot = (2 * j + 1) * n_cols + k + 1
            ax2 = fig.add_subplot(n_rows, n_cols, sigma_slot)
            ax2.plot(Sigmas, 'b', label='Sigmas')
            ax2.tick_params(labelsize=18)
            ax2.set_ylim([0, 6.5])

            # Legends only on the first pair of panels in each figure.
            if j == 0 and k == 0:
                ax.legend()
                ax2.legend()
            ax.set_title('mu=%d, sigma=%d, samples=%d' % (mu0, sigma0, n_draws), fontsize=18)
    # q_mu here is still the untouched starting tensor built above (train only
    # rebinds its own local name), so the filename carries the starting mean.
    plt.savefig('univariate_gaussian_rws_mu=%d.svg' % (q_mu.item()))

In [None]:
# Stand-alone view of the bounds from the most recent training run, i.e. the
# last configuration visited by the sweep in the previous cell.
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
ax.plot(EUBOs, 'r', label='EUBOs')
ax.plot(ELBOs, 'b', label='ELBOs')
ax.plot(np.ones(iterations) * log_Z, 'k', label='log_Z')
ax.legend()
# `num_samples` only exists as a parameter of train(), so referencing it here
# raised NameError on a fresh kernel. The curves plotted above were produced
# with the final entry of NUM_SAMPLES, so use that for the filename.
plt.savefig('rws-univariate-samples=%d-prior=2.png' % NUM_SAMPLES[-1])