In [2]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.normal import Normal
import numpy as np

In [3]:
iterations = 100
p_mu = 0.0

lr = 1e-3

log_Z = np.log(np.sqrt((2*np.pi)))

In [4]:
def train(num_samples, q_mu, lr):
    EUBOs = []
    ELBOs = []
    AGrads = []
    Grads = []
    for i in range(iterations):
        proposal = Normal(q_mu, q_sigma)
        xs = proposal.sample((num_samples,))
        log_gammas = (-1.0 / 2.0) * ((xs - p_mu) ** 2)
        log_q = proposal.log_prob(xs)

        log_weights = log_gammas - log_q
        weights = torch.exp(log_weights - logsumexp(log_weights, dim=0)).detach()
        eubo = torch.mul(weights, log_weights).sum()
        elbo = log_weights.mean()
        gradient = torch.autograd.grad(eubo, q_mu)
        analytical_gradient = q_mu
        q_mu = q_mu - lr * gradient[0]
        EUBOs.append(eubo.item())
        ELBOs.append(elbo.item())
        AGrads.append(analytical_gradient.item())
        Grads.append(gradient[0].item())
        
    return EUBOs, ELBOs, AGrads, Grads

In [5]:
def smooth_exponential(values, beta=0.9):
    v = np.asarray(values)
    t = np.arange(len(values))
    dt = t[None, :] - t[:,None] 
    w = np.exp((1-beta) * dt + np.log(t[None,:] <= t[:, None]))
    w /= w.sum(1)[:, None]
    v_smooth = np.dot(w, v[:, None])
    return v_smooth

In [7]:
# T = 1001
# t = np.arange(T)
# x = t/T
# y = 1 - x + 0.5 * np.random.randn(T)
# y_avg = smooth_exponential(y, 0.9)
# y2_avg = smooth_exponential(y, 0.999)
# y_adam = smooth_exponential(y, 0.9) / y2_avg**0.5
# plt.figure()
# plt.plot(t, y, label='noisy')
# plt.plot(t, y_adam, label='adam')
# plt.legend()

In [8]:
def plot_results(EUBOs, ELBOs, Grads, AGrads, num_mus, num_sigmas, num_samples):
    for i in range(num_mus):
        fig = plt.figure(figsize=(30, 40))
        for j in range(num_sigmas):
            for k in range(num_samples):
                ax = fig.add_subplot(num_sigmas, num_samples, j*num_samples + k + 1)
                ax.plot(EUBOs[i*()], 'r', label='EUBOs')
                ax.plot(ELBOs[], 'b', label='ELBOs')
                ax.plot(Grads, 'g', label='estimated grad')
                #AdamGrads = smooth_exponential(Grads, adam_b1) / smooth_exponential(Grads, adam_b2) 
                #ax.plot(AdamGrads, 'g', label='adam grad')
                ax.plot(AGrads, 'gray', label='true grad' )
                ax.plot(np.ones(iterations) * log_Z, 'k', label='log_Z')
                ax.tick_params(labelsize=18)
                ax.set_ylim([-50, 20])
            if i == 0 and j == 0:
                ax.legend()
                    ax.set_title('mu=%d, sigma=%d, samples=%d' % (init_mu[i], init_sigma[j], NUM_SAMPLES[k]), fontsize=18)
    plt.savefig('univariate_gaussian_rws_mu=%d.svg' % (init_mu[i]))

SyntaxError: invalid syntax (<ipython-input-8-87b4093d10c9>, line 7)

In [None]:
init_mus = np.array([6, 8, 10])
init_sigmas = np.array([1.0, 2.0, 4.0, 6.0])
SAMPLES = np.array([100, 1000])
num_mus = init_mus.shape[0]
num_sigmas = init_sigmas.shape[0]
num_samples = SAMPLES.shape[0]

adam_b1 = 0.9
adam_b2 = 0.999

EUBOs = []
ELBOs = []
AGrads = []
Grads = []

for i in range(num_mus):
    for j in range(num_sigmas):
        for k in range(num_samples):
            time_start = time.time()
            q_mu = torch.tensor([1.0], requires_grad=True) * init_mus[i]
            q_sigma = torch.tensor([1.0]) * init_sigmas[j]
            EUBO, ELBO, AGrad, Grad = train(SAMPLES[k], q_mu, lr)
            EUBOs.append(EUBO)
            ELBOs.append(ELBO)
            AGrads.append(AGrad)
            Grads.append(Grad)
            time_end = time.time()
            print('init_mu=%d, init_sigma=%d, samples : %d (%ds)' % (init_mus[i], init_sigmas[j], SAMPLES[k], time_end - time_start))    

In [None]:
fig = plt.figure(figsize=(8,8))
# plt.plot(ELBOs)
ax = fig.add_subplot(111)
ax.plot(EUBOs, 'r', label='EUBOs')
ax.plot(ELBOs, 'b', label='ELBOs')
ax.plot(np.ones(iterations) * log_Z, 'k', label='log_Z')
ax.legend()
plt.savefig('rws-univariate-samples=%d-prior=2.png' % num_samples)

In [None]:
log_Z