In [None]:
import matplotlib.pyplot as plt
import time
import torch
from torch import logsumexp
import torch.nn.functional as F
from torch.distributions.normal import Normal

import numpy as np
from utils_v1 import *
from plots import *
from objectives_v1 import *

In [None]:
## optimisation schedule
STEPS = 10000
LEARNING_RATE = 5*1e-3
## sampling sizes: draws from q per objective evaluation, and the batch count
## forwarded to the objective / SNR helpers as `num_batches`
NUM_SAMPLES = 5
NUM_BATCHES = 100
## fixed target parameters (no gradient is tracked for these)
p_mu, p_sigma = torch.tensor([0.0]), torch.tensor([1.0])
## learnable proposal parameters; both are leaf tensors tracked by autograd
q_mu = torch.tensor([4.0]).requires_grad_(True)
q_sigma = torch.tensor([2.0]).requires_grad_(True)
## plain SGD over the two proposal parameters only
optimizer = torch.optim.SGD(params=[q_mu, q_sigma], lr=LEARNING_RATE)

In [None]:
# Evaluate the objective once at the initial parameters and keep only the
# first returned value (the loss); the other four return values are discarded.
# NOTE(review): `rws` is defined in a later cell of this notebook and may also
# be provided by `from objectives_v1 import *` — on a fresh kernel the imported
# version (if any) is the one used here; confirm which is intended.
loss, _, _, _, _ = rws(q_mu, q_sigma, p_mu, p_sigma, NUM_SAMPLES, alpha=None, num_batches=NUM_BATCHES)

In [None]:
# Inspect the shape of the loss tensor computed by the previous cell.
loss.shape

In [None]:
## proposal distribution built from the current variational parameters
q = Normal(loc=q_mu, scale=q_sigma)
## draw a (NUM_SAMPLES, NUM_BATCHES) grid of samples; .sample() does not
## propagate gradients back to q_mu / q_sigma
xs = q.sample(sample_shape=(NUM_SAMPLES, NUM_BATCHES))

In [None]:
# Gradients of `loss` with respect to the two proposal parameters, returned
# directly as a tuple rather than accumulated into `.grad`.
# NOTE(review): torch.autograd.grad needs `grad_outputs` when `loss` is not a
# scalar — confirm the shape of `loss` produced by the objective in scope.
torch.autograd.grad(loss, [q_mu, q_sigma])

In [None]:
def SNR(obj, q_mu, q_sigma, p_mu, p_sigma, num_samples, optimizer, alpha, num_batches):
    """Estimate gradient signal-to-noise ratio and variance for an objective.

    The objective is evaluated ``num_batches`` times at the current parameters;
    each evaluation is back-propagated independently and the negated gradients
    of ``q_mu`` and ``q_sigma`` are recorded. The recorded gradients are then
    summarised with ``stats`` and averaged over the two parameters.

    Args:
        obj: callable ``obj(q_mu, q_sigma, p_mu, p_sigma, num_samples,
            alpha=..., num_batches=...)`` returning a 5-tuple whose first
            element is the loss. The loss must be a scalar so that
            ``loss.backward()`` can be called without arguments.
        q_mu, q_sigma: single-element tensors with ``requires_grad=True``
            (``.grad.item()`` is called on each).
        p_mu, p_sigma: forwarded to ``obj`` unchanged.
        num_samples: forwarded to ``obj`` unchanged.
        optimizer: optimizer holding ``q_mu`` and ``q_sigma``; only its
            ``zero_grad()`` is used, no step is taken.
        alpha: forwarded to ``obj`` unchanged.
        num_batches: number of independent gradient evaluations; also
            forwarded to ``obj``.

    Returns:
        Tuple ``(snr, var)``: the two values returned by ``stats`` for the
        ``q_mu`` gradients and for the ``q_sigma`` gradients, each averaged
        over the two parameters.
    """
    Grad_mu = []
    Grad_sigma = []
    for _batch in range(num_batches):
        # Clear gradients on every repetition so each recorded value is an
        # independent estimate rather than a running sum.
        optimizer.zero_grad()
        loss, _, _, _, _ = obj(q_mu, q_sigma, p_mu, p_sigma, num_samples, alpha=alpha, num_batches=num_batches)
        loss.backward()
        # Record the negated gradient, as in the original cell.
        Grad_mu.append(- q_mu.grad.item())
        Grad_sigma.append(- q_sigma.grad.item())

    # NOTE(review): `stats` is not defined in this notebook's visible cells
    # (presumably from utils_v1); it is assumed to return (snr, variance) for a
    # 1-D array of gradient samples — confirm.
    snr_mu, var_mu = stats(np.array(Grad_mu))
    snr_sigma, var_sigma = stats(np.array(Grad_sigma))
    # Leave no stale gradients behind for the caller's next optimisation step.
    optimizer.zero_grad()
    return (snr_mu + snr_sigma) / 2, (var_mu + var_sigma) / 2

In [None]:
def rws(q_mu, q_sigma, p_mu, p_sigma, num_samples, num_batches=None, alpha=None):
    """Importance-weighted objectives for a Normal proposal against a Gaussian-shaped target.

    Draws ``num_samples`` values from ``Normal(q_mu, q_sigma)`` without
    reparameterisation, scores them under the target's quadratic log-density
    term (no normalising constant is included) and under the proposal, and
    derives several bounds from the resulting log importance weights.

    ``num_batches`` and ``alpha`` are accepted but not used by this function.

    Returns:
        Tuple ``(loss, eubo, elbo, iwelbo, ess)`` where ``loss`` is the same
        tensor as ``eubo``; the self-normalised weights used in ``eubo`` and
        ``ess`` are detached from the autograd graph.
    """
    proposal = Normal(q_mu, q_sigma)
    ## .sample() (as opposed to .rsample()) blocks gradients through the draws
    draws = proposal.sample((num_samples,))

    ## log importance weights: quadratic target term minus proposal log-density
    quad_coeff = -1.0 / ((p_sigma**2) * 2.0)
    target_term = quad_coeff * ((draws - p_mu) ** 2)
    proposal_term = proposal.log_prob(draws)
    log_w = target_term - proposal_term

    ## self-normalised weights over the sample dimension, held constant for autograd
    norm_w = F.softmax(log_w, dim=0).detach()

    eubo = (norm_w * log_w).sum()
    elbo = log_w.mean()
    log_count = torch.log(torch.FloatTensor([num_samples]))
    iwelbo = logsumexp(log_w, 0) - log_count
    ess = 1. / (norm_w ** 2).sum()

    return eubo, eubo, elbo, iwelbo, ess

In [None]:
# Run the training loop with the `dreg` objective and unpack the six returned
# traces (named here after the bounds, ESS, gradient SNR and gradient variance).
# NOTE(review): `train`, `dreg` and NUM_SAMPLES_SNR are not defined in any
# visible cell — `train`/`dreg` presumably come from the star imports, but
# NUM_SAMPLES_SNR would raise NameError on a fresh kernel unless one of those
# modules exports it; confirm.
# NOTE(review): the objective passed is `dreg` while filename='rws' — verify
# this label is intentional.
EUBOs, ELBOs, IWELBOs, ESSs, SNRs, VARs = train(dreg, q_mu, q_sigma, p_mu, p_sigma, STEPS, NUM_SAMPLES, NUM_SAMPLES_SNR, optimizer, filename='rws', alpha=0.0)

In [None]:
def plot_results(EUBO, ELBO, ESS, num_samples, snr_mu, snr_sigma, window=10):
    """Plot training diagnostics in three stacked panels sharing the x axis.

    Panel 1 shows the EUBO and ELBO traces, panel 2 the gradient SNR traces on
    a log scale clipped to [1e-2, 1e2], and panel 3 the effective sample size
    as a fraction of ``num_samples``, averaged over consecutive windows.

    Args:
        EUBO, ELBO: per-step sequences plotted against their index.
        ESS: per-step effective sample sizes (numeric sequence).
        num_samples: number of samples per step, used to normalise ``ESS``.
        snr_mu, snr_sigma: per-step SNR sequences.
        window: number of consecutive steps averaged into one ESS point.
            Defaults to 10, the value previously hard-coded. A trailing
            partial window is averaged over the steps it contains instead of
            raising, so ``len(ESS)`` need not be a multiple of ``window``.
    """
    fig = plt.figure(figsize=(10,10))
    ax1, ax2, ax3 = fig.subplots(3, 1, sharex=True)
    ## bounds
    ax1.plot(EUBO, 'r', label='EUBOs')
    ax1.plot(ELBO, 'b', label='ELBOs')
    ax1.legend()
    ## SNR
    ax2.set_yscale('log')
    ax2.plot(snr_sigma, label='SNR_sigma')
    ax2.plot(snr_mu, label='SNR_mu')
    ax2.legend()
    ax2.set_ylim([1e-2,1e2])
    ## ESS
    ess_ratio = (np.array(ESS) / num_samples).ravel()
    total = ess_ratio.shape[0]
    n_full = total // window
    # Average complete windows in one vectorised pass ...
    ave_ess = ess_ratio[:n_full * window].reshape(n_full, window).mean(-1)
    # ... and keep the leftover steps as one final, shorter window.
    if total % window:
        ave_ess = np.append(ave_ess, ess_ratio[n_full * window:].mean())
    N = ave_ess.shape[0]
    ax3.plot(np.arange(N) * window, ave_ess, 'go', label='ESS')
    ax3.set_ylim([0, 1])
    # The series is labelled, so show the legend like the other panels.
    ax3.legend()
    # Lay out after all artists exist so legends and tick labels are accounted for.
    fig.tight_layout()
# plot_results(EUBO, ELBO, ESS, num_samples, snr_mu, snr_sigma)