In [1]:
# Reload edited modules automatically before each cell executes, so changes to
# locally developed library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib as mpl
import seaborn as sns

# Apply seaborn's grid style FIRST: sns.set_style() also writes 'font.family'
# (to sans-serif), so calling it after the rcParams update below silently
# discarded the serif font requested there.
sns.set_style("darkgrid")

# Global figure styling for the paper figures. Named rc_params (not `params`)
# so it is not confused with the flow parameters bound to `params` later on.
rc_params = {
    'axes.labelsize': 8,
    'font.size': 8,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.figsize': [4.5, 4.5],
    'font.family': 'serif',
    # Render text through LaTeX. (The dict previously also contained
    # 'text.usetex': False, which was dead: the later duplicate key wins.)
    'text.usetex': True,
    "pgf.texsystem": "pdflatex",
    'pgf.rcfonts': True,
    'axes.unicode_minus': False,
}
mpl.rcParams.update(rc_params)

# Switch to the non-interactive PGF backend, which typesets via the
# pgf.texsystem configured above (pdflatex).
mpl.use("pgf")

In [3]:
import beanmachine.ppl as bm
import torch.distributions as dist
import torch
import torch.optim

import flowtorch.bijectors
import flowtorch.params
from tqdm.auto import tqdm

# Seed torch's global RNG so the flow initialisation and the Monte Carlo draws
# (and hence the plotted fit) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

HIDDEN_DIMS = (32, 32, 32)  # hidden-layer widths of the DenseAutoregressive network
LEARNING_RATE = 1e-3
NUM_STEPS = 1_000
NUM_MC_SAMPLES = 1_000      # Monte Carlo samples per ELBO estimate

# A light-tailed standard-normal base pushed through an affine autoregressive
# flow, fitted to a heavy-tailed Cauchy target.
base_dist = dist.Normal(torch.zeros(1), torch.ones(1))
target = dist.Cauchy(0, 1)

flow = flowtorch.bijectors.AffineAutoregressive(
    flowtorch.params.DenseAutoregressive(hidden_dims=HIDDEN_DIMS)
)
new_dist, params = flow(base_dist)

optimizer = torch.optim.Adam(params.parameters(), lr=LEARNING_RATE)

progress = tqdm(range(NUM_STEPS))
for _ in progress:
    optimizer.zero_grad()
    # Reparameterised samples so gradients flow back into the flow parameters.
    samples = new_dist.rsample((NUM_MC_SAMPLES,))
    log_q = new_dist.log_prob(samples)
    log_p = target.log_prob(samples)
    # ELBO = E_q[log p - log q]. Averaging each term separately equals
    # (log_p - log_q).mean(), but never materialises a broadcast N x N
    # difference if the two log-prob tensors come back with different shapes.
    elbo = log_p.mean() - log_q.mean()
    loss = -elbo
    loss.backward()
    optimizer.step()
    # One live-updating readout on the bar instead of a write per iteration.
    progress.set_postfix(loss=f"{loss.item():.4f}")



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Loss: 0.1873


In [4]:
import beanmachine.ppl as bm
import torch.nn as nn

from beanmachine.ppl.experimental.vi.variational_infer import (
    MeanFieldVariationalInference,
)

VI_NUM_ITER = 500
VI_LR = 1e-3
VI_NUM_MC_SAMPLES = 1000  # Monte Carlo samples per ELBO estimate
LOG_EVERY = 10            # report the loss every this many iterations


@bm.random_variable
def tgt():
    """Random variable distributed as `target`.

    It is queried below with no observations, so the distribution being
    approximated is `target` itself.
    """
    return target


def on_iter(it, loss, vi_dicts):
    """Progress callback passed to `infer`; logs the loss every LOG_EVERY iterations.

    `vi_dicts` is part of the callback signature but unused here. Inside this
    function it shadows the module-level result of the same name.
    """
    if it % LOG_EVERY == 0:
        # float() unwraps the one-element loss tensor, so the log reads
        # "Loss: 0.0824" instead of the tensor repr with its grad_fn.
        tqdm.write(f"Loss: {float(loss):.4f}", end='')


vi_dicts = MeanFieldVariationalInference().infer(
    queries=[tgt()],
    observations={},
    num_iter=VI_NUM_ITER,
    progress_bar=True,
    # Factory: a fresh flow is built per call.
    flow=lambda: flowtorch.bijectors.AffineAutoregressive(
        flowtorch.params.DenseAutoregressive(hidden_dims=(32, 32, 32))
    ),
    lr=VI_LR,
    on_iter=on_iter,
    # Heavy-tailed, learnable Student-t base distribution.
    base_dist=dist.StudentT,
    base_args={
        # NOTE(review): df is initialised to log(5). Presumably beanmachine maps
        # it back through a positivity transform — confirm; otherwise the
        # starting df is ~1.61 rather than 5.
        'df': nn.Parameter(torch.tensor([5.0]).log()),
        'loc': nn.Parameter(torch.tensor([0.0])),
        'scale': nn.Parameter(torch.tensor([1.0])),
    },
    num_elbo_mc_samples=VI_NUM_MC_SAMPLES,
)

HBox(children=(HTML(value='Training iterations'), FloatProgress(value=0.0, max=500.0), HTML(value='')))

Loss: tensor([0.0824], grad_fn=<SubBackward0>)


In [7]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from cycler import cycler

# Pair each colour with a distinct linestyle (presumably so the curves remain
# distinguishable without colour).
default_cycler = (cycler(color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']) +
                  cycler(linestyle=['-', '--', ':', '-.']))
plt.rc('axes', prop_cycle=default_cycler)

N_GRID = 100          # points per density curve (torch.linspace's former implicit default)
N_KS_REPEATS = 1000   # independent K-S tests per model
N_KS_SAMPLES = 100    # samples drawn for each K-S test

# Fitted approximation from the VI cell, evaluated once instead of per use.
ataf_dist = vi_dicts(tgt())
# Order fixes both the colour/linestyle pairing and the legend labels below.
models = [('ADVI', new_dist), ('ATAF', ataf_dist)]


def _plot_densities(axis, xs):
    """Plot each model's density, then the target density, on `axis` over grid `xs`."""
    for _, model in models:
        axis.plot(xs, model.log_prob(xs).exp().detach())
    axis.plot(xs, target.log_prob(xs).exp())


def _target_cdf(x):
    """Target CDF on a NumPy array, the form scipy.stats.kstest passes in."""
    return target.cdf(torch.tensor(x)).numpy()


def _ks_pvalues(model):
    """p-values of N_KS_REPEATS K-S tests, each on N_KS_SAMPLES fresh draws from `model`."""
    def draw(size):
        return model.sample((size,)).squeeze().numpy()

    return [
        scipy.stats.kstest(draw, _target_cdf, N=N_KS_SAMPLES).pvalue
        for _ in range(N_KS_REPEATS)
    ]


fig, ax = plt.subplots(1, 3, figsize=(5.5, 2))

# Panel 1: densities on a linear scale around the mode.
_plot_densities(ax[0], torch.linspace(-5, 5, N_GRID).unsqueeze(1))
ax[0].set_xlabel('$x$')
ax[0].set_title('$p(x)$')

# Panel 2: right tail on a log scale, where heavy-tail mismatch shows up.
_plot_densities(ax[1], torch.linspace(0, 10, N_GRID).unsqueeze(1))
ax[1].set_yscale('log')
ax[1].set_xlabel('$x$')
ax[1].set_title(r'$\log p(x)$')  # raw string: '\l' is an invalid escape in a plain literal

# Panel 3: distribution of K-S p-values against the target for each model.
for _, model in models:
    ax[2].hist(_ks_pvalues(model), alpha=0.5)
ax[2].set_title('K-S p-values')
ax[2].set_ylabel('Count')
ax[2].set_xlabel('$p$-value')

fig.tight_layout()
# Explicit handles: the three curves of panel 1, in plotting order.
fig.legend(ax[0].get_lines(), [name for name, _ in models] + ['Target'], loc="center right")
fig.subplots_adjust(right=0.8)
# NOTE(review): with the non-interactive pgf backend selected earlier this may
# only emit a warning; the figure is written to disk by savefig in the next cell.
fig.show()

In [8]:
# Write the three-panel fat-tail comparison figure as a PDF, relative to the
# current working directory.
fig.savefig(fname='fat_tail_ks.pdf')