In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import matplotlib as mpl
# Apply the seaborn style FIRST: set_style() pushes a full style dict into
# rcParams, including font-related keys such as 'font.family'. Calling it after
# the update below would silently discard the serif font requested here.
import seaborn as sns
sns.set_style("darkgrid")

# Publication settings for the pgf/LaTeX backend.
params = {
    'axes.labelsize': 8,
    'font.size': 8,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'figure.figsize': [4.5, 4.5],
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    # Listed exactly once: a repeated key in a dict literal is silently
    # resolved to its last value, which hides the effective setting.
    'text.usetex': True,
    'pgf.rcfonts': True,
}
mpl.rcParams['axes.unicode_minus'] = False
mpl.rcParams.update(params)
mpl.use("pgf")

# Define a fat 2D pancake

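Consider a fat-tailed 2D pancake defined by the product distribution
$$\mathrm{StudentT}(\nu=1) \otimes \mathcal{N}(0,1),$$
i.e. the first coordinate is heavy-tailed (a Student-T with $\nu=1$ is the Cauchy distribution) and the second coordinate is standard normal.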

In [3]:
import torch
import torch.distributions as dist

cmap = 'YlGnBu'

class Pancake(dist.Distribution):
    """Product of independent Student-T ("fat") and standard-normal ("thin") coordinates.

    An event is a vector of length ``fat_dims + thin_dims``: the first
    ``fat_dims`` entries are i.i.d. ``StudentT(df)`` and the remaining
    ``thin_dims`` entries are i.i.d. ``Normal(0, 1)``.

    Args:
        fat_dims: number of heavy-tailed coordinates.
        thin_dims: number of Gaussian coordinates.
        df: degrees of freedom of the Student-T coordinates; a scalar or a
            tensor broadcastable to ``(fat_dims,)``.
    """
    support = dist.constraints.real
    # No constructor argument is subject to constraint validation.
    arg_constraints = {}

    def __init__(self, fat_dims=1, thin_dims=1, df=torch.ones(1)):
        # Register the shapes with the base class so `batch_shape` and
        # `event_shape` behave like those of any other torch distribution.
        super().__init__(
            batch_shape=torch.Size(),
            event_shape=torch.Size([fat_dims + thin_dims]),
            validate_args=False,
        )
        # Honour the caller-supplied degrees of freedom (the default keeps nu = 1).
        df = torch.as_tensor(df, dtype=torch.get_default_dtype()).expand(fat_dims)
        self.fat_dim_dist = dist.StudentT(df=df)
        self.thin_dim_dist = dist.Normal(loc=torch.zeros(thin_dims), scale=torch.ones(thin_dims))
        self.fat_dims = fat_dims
        self.thin_dims = thin_dims

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples of shape ``sample_shape + (fat_dims + thin_dims,)``."""
        fat = self.fat_dim_dist.sample(sample_shape)
        thin = self.thin_dim_dist.sample(sample_shape)
        # Concatenate along the event axis. Unlike stack(...).squeeze() this
        # keeps a size-1 sample axis and supports fat_dims != thin_dims.
        return torch.cat([fat, thin], dim=-1)

    def log_prob(self, x):
        """Joint log-density of ``x`` (last axis = event).

        Returns a tensor of shape ``x.shape[:-1] + (1,)``; the trailing
        singleton axis is kept so the result has the same shape as the
        original two-coordinate implementation.
        """
        # Explicit section sizes: an int argument would cut equal-size chunks
        # and mis-split whenever fat_dims != thin_dims.
        fat, thin = x.split([self.fat_dims, self.thin_dims], dim=-1)
        fat_lp = self.fat_dim_dist.log_prob(fat).sum(dim=-1, keepdim=True)
        thin_lp = self.thin_dim_dist.log_prob(thin).sum(dim=-1, keepdim=True)
        return fat_lp + thin_lp
target = Pancake()

In [4]:
import pandas as pd
import math
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

def estimate_density(d, num_samples=1000):
    """Plot a Gaussian kernel density estimate of 2D samples drawn from ``d``.

    Args:
        d: object exposing ``sample(sample_shape)`` that returns an array or
            tensor whose first two columns are the x / y coordinates.
        num_samples: number of draws used to fit the KDE.

    Returns:
        The matplotlib figure containing the density image and its colorbar.
    """
    # Sample from the distribution that was passed in rather than from the
    # module-level `target`, so approximations can be visualised as well.
    samples = d.sample((num_samples,))
    if hasattr(samples, "detach"):
        # Torch tensors (possibly attached to an autograd graph) -> NumPy.
        samples = samples.detach().cpu().numpy()
    samples = np.asarray(samples)
    x = samples[:, 0]
    y = samples[:, 1]
    xmin, xmax = -10, 10
    ymin, ymax = -5, 5

    # Perform the kernel density estimate on a 100x100 evaluation grid
    xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    values = np.vstack([x, y])
    kernel = scipy.stats.gaussian_kde(values)
    f = np.reshape(kernel(positions), xx.shape)

    fig = plt.figure()
    ax = fig.gca()
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    # rot90 turns the (x, y)-indexed grid into image layout (rows = y, top = ymax)
    im = ax.imshow(np.rot90(f), cmap=cmap, extent=[xmin, xmax, ymin, ymax])
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('StudentT(df=1)')
    ax.set_ylabel('Normal(0,1)')
    return fig

def plot_density(d, ax=None, bounds=(-10, 10, -5, 5), norm=None, exp=False):
    """Draw the (log-)density of ``d`` on a 100x100 grid as an image with contours.

    Args:
        d: object whose ``log_prob`` is called with an ``(N, 2)`` tensor of
            grid points; the result is reshaped to ``(100, 100)``.
        ax: axes to draw on; a new figure and axes are created when omitted.
        bounds: ``(xmin, xmax, ymin, ymax)`` of the plotted region.
        norm: optional matplotlib normalisation for the colour scale.
        exp: plot the density itself instead of the log-density.

    Returns:
        The axes that were drawn on.
    """
    # Resolve the axes before touching them; calling ax.grid() first would
    # raise AttributeError whenever the default ax=None is used.
    if ax is None:
        fig = plt.figure()
        ax = fig.gca()
    ax.grid(False)

    xmin, xmax, ymin, ymax = bounds
    xx, yy = torch.meshgrid(
        torch.linspace(xmin, xmax, 100),
        torch.linspace(ymin, ymax, 100),
    )
    grid_points = torch.stack((xx, yy), dim=-1).reshape(-1, 2)
    f = d.log_prob(grid_points).reshape((100, 100)).detach().numpy()
    if exp:
        f = np.exp(f)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    image_kwargs = dict(cmap=cmap, extent=[xmin, xmax, ymin, ymax], aspect="auto")
    if norm:
        image_kwargs["norm"] = norm
    # rot90 turns the (x, y)-indexed grid into image layout (rows = y, top = ymax)
    ax.imshow(np.rot90(f), **image_kwargs)
    ax.contour(xx.numpy(), yy.numpy(), f, colors='w', linestyles='dashed')
    return ax

In [5]:
# 2x2 comparison grid; the top-left panel shows the target density and the
# remaining panels are filled in by later cells.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(5.5, 3))

target_ax = axs[0][0]
plot_density(target, ax=target_ax, bounds=[-10, 10, -5, 5])
target_ax.set_title('Target')
target_ax.set_xlabel('StudentT($\\nu=1$)')
target_ax.set_ylabel('Normal($\\mu=0,\\sigma^2=1$)')
fig.tight_layout()

## Approximate it using ADVI

First consider the variational family $q_{\mu \in \mathbb{R}^2,\sigma \in \mathbb{R}^2_{>0}} = \mu + \sigma \odot N(0, I_2)$

In [6]:
import beanmachine.ppl as bm
import flowtorch.bijectors
import flowtorch.params
from tqdm.auto import tqdm

@bm.random_variable
def pancake():
    """Random variable distributed as the default ``Pancake()`` target."""
    pancake_dist = Pancake()
    return pancake_dist

from beanmachine.ppl.experimental.vi.variational_infer import (
    MeanFieldVariationalInference,
)

def on_iter(it, loss, vi_dicts):
    """Progress callback: write the current loss on every 10th iteration."""
    if it % 10 != 0:
        return
    tqdm.write(f"Loss: {loss}", end='')

def _advi_flow():
    """Build the affine autoregressive flow (two hidden layers of width 8)."""
    conditioner = flowtorch.params.DenseAutoregressive(hidden_dims=(8, 8))
    return flowtorch.bijectors.AffineAutoregressive(conditioner)


# Fit the approximation to the pancake target (there are no observations).
vi_dicts = MeanFieldVariationalInference().infer(
    queries=[pancake()],
    observations={},
    num_iter=1000,
    lr=1e-2,
    flow=_advi_flow,
    on_iter=on_iter,
    progress_bar=True,
)
vapprox_advi = vi_dicts(pancake())

HBox(children=(HTML(value='Training iterations'), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Loss: tensor([0.2093], grad_fn=<SubBackward0>)


In [7]:
# Top-right panel of the comparison grid: the ADVI approximation.
advi_ax = axs[0][1]
plot_density(vapprox_advi, ax=advi_ax, bounds=[-10, 10, -5, 5])
advi_ax.set_title('ADVI')

Text(0.5, 1.0, 'ADVI')

The resulting Normal approximation has mean $\approx (0,0)$

In [8]:
# Innermost base distribution of the fitted approximation.
advi_base = vapprox_advi.new_dist.base_dist.base_dist
print(f"""
Mean: {advi_base.mean}
Var: {advi_base.variance}
""")


Mean: Parameter containing:
tensor([ 0.0544, -0.0957], requires_grad=True)
Var: tensor([3.0773, 4.8643], grad_fn=<PowBackward0>)



Its variance is higher in the heavy-tailed direction, but it does not capture the tail decay

## Approximate with ATAF

Let us now use FTVI, which uses $q_{\nu \in \mathbb{R}^2_{>0},\, \mu \in \mathbb{R}^2,\, \sigma \in \mathbb{R}^2_{>0}} = \mu + \sigma \odot \mathrm{StudentT}(\nu)$

In [9]:
import beanmachine.ppl as bm
import flowtorch.bijectors
import torch.nn as nn
from tqdm.auto import tqdm

from beanmachine.ppl.experimental.vi.variational_infer import (
    MeanFieldVariationalInference,
)

def on_iter(it, loss, vi_dicts):
    """Progress callback: every 10th iteration, write the loss and the current df."""
    if it % 10 != 0:
        return
    tqdm.write(f"Loss: {loss}, DF: {vi_dicts(pancake()).new_dist.base_dist.base_dist.df}", end='')

# Learnable Student-T base parameters; `df` is initialised to log(5).
student_t_init = {
    'df': nn.Parameter(torch.tensor([5.0]).log()),
    'loc': nn.Parameter(torch.tensor([0.0])),
    'scale': nn.Parameter(torch.tensor([1.0])),
}

# Empty flow composition, so the approximation is the Student-T base itself.
vi_dicts = MeanFieldVariationalInference().infer(
    queries=[pancake()],
    observations={},
    num_iter=1000,
    lr=1e-2,
    flow=lambda: flowtorch.bijectors.Compose([]),
    base_dist=dist.StudentT,
    base_args=student_t_init,
    on_iter=on_iter,
    progress_bar=True,
)
vapprox_ataf = vi_dicts(pancake())

HBox(children=(HTML(value='Training iterations'), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Loss: tensor([0.0003], grad_fn=<SubBackward0>), DF: tensor([ 1.0060, 12.3989], grad_fn=<AddBackward0>))


In [10]:
# Bottom-right panel of the comparison grid: the ATAF approximation.
ataf_ax = axs[1][1]
plot_density(vapprox_ataf, ax=ataf_ax, bounds=[-10, 10, -5, 5])
ataf_ax.set_title('ATAF')

Text(0.5, 1.0, 'ATAF')

## Approximate with TAF

The previous SOTA for heavy-tailed density estimation uses a single tail index shared across all dimensions of a high-dimensional distribution, that is, $\nu \in \mathbb{R}_{>0}$ rather than $\nu \in \mathbb{R}^2_{>0}$.
Everything else is identical to FTVI.

This requires a bit of hacking to get working; here we patch `vapprox` to have a scalar `df` and hard-code $\mu = 0$ and $\sigma = 1$:

In [11]:
import torch.nn as nn

# initialize a MFVApprox
# A single iteration is enough here: the call only serves to construct the
# approximation object, which is patched afterwards.
taf_init_args = {
    'df': nn.Parameter(torch.tensor([5.0]).log()),
    'loc': torch.zeros(2),
    'scale': torch.ones(2),
}
vi_dicts = MeanFieldVariationalInference().infer(
    queries=[pancake()],
    observations={},
    num_iter=1,
    lr=1e-2,
    flow=lambda: flowtorch.bijectors.Compose([]),
    base_dist=dist.StudentT,
    base_args=taf_init_args,
    progress_bar=False,
)

# monkey patch it to do TAF
# Monkey patch the approximation to do TAF: replace its base arguments so that
# `df` is a single scalar parameter while `loc` and `scale` are plain tensors.
vapprox_taf_fixed = vi_dicts(pancake())
vapprox_taf_fixed.base_dist = dist.StudentT
vapprox_taf_fixed.base_args = {
    # Collapse the existing `df` entry to one scalar via its mean.
    'df': nn.Parameter(vapprox_taf_fixed.base_args['df'].mean()),
    'loc': torch.zeros(2),
    # NOTE(review): log(1) = 0 here, whereas the infer() call above passed
    # torch.ones(2) for 'scale'. Presumably base_args are stored in an
    # unconstrained (log) space; confirm against the beanmachine implementation.
    'scale': torch.ones(2).log(),
}
# Presumably rebuilds the distribution from the patched arguments; the printed
# `df` lets us check that both dimensions now share one value.
print(vapprox_taf_fixed.recompute_transformed_distribution().df)
#print("Params: ", list(vapprox_taf_fixed.parameters()))
#print("Params: ", list(vapprox_taf_fixed.parameters()))

tensor([5., 5.], grad_fn=<ExpandBackward>)


Next, we manually write the ELBO training loop (beanmachine doesn't propagate gradients for some reason):

In [12]:
import torch.optim

# Loop-invariant pieces are created once instead of on every iteration.
M = 100                   # Monte Carlo samples per ELBO estimate
taf_target = Pancake()    # target density the approximation is fitted to

optimizer = torch.optim.Adam(vapprox_taf_fixed.parameters(), lr=1e-3)
for _ in tqdm(range(1000)):
    optimizer.zero_grad()

    # Monte Carlo ELBO estimate: mean over M reparameterised samples of
    # log p(z) - log q(z).
    zk = vapprox_taf_fixed.rsample((M,))
    elbo = taf_target.log_prob(zk).sum()
    elbo -= vapprox_taf_fixed.log_prob(zk).sum()
    elbo /= M

    loss = -elbo
    # NOTE(review): retain_graph=True is kept from the original loop; it is
    # presumably needed because the distribution is only rebuilt after the
    # optimiser step below. Confirm before removing.
    loss.backward(retain_graph=True)
    tqdm.write(f"Loss: {loss.item()}, DF: {vapprox_taf_fixed.new_dist.base_dist.df}", end='')
    optimizer.step()
    # Refresh the distribution so the next iteration sees the updated parameters.
    vapprox_taf_fixed.recompute_transformed_distribution()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

Loss: 0.27453795075416565, DF: tensor([6.0400, 6.0400], grad_fn=<ExpandBackward>)>)


In [13]:
# Bottom-left panel of the comparison grid: the TAF approximation.
ax = axs[1, 0]
plot_density(vapprox_taf_fixed, ax=ax, bounds=[-10, 10, -5, 5])
ax.set_title('TAF')

Text(0.5, 1.0, 'TAF')

In [14]:
# tight_layout() was last run when only the 'Target' panel had a title; the
# other three titles were added afterwards, so recompute the spacing before
# writing the figure to disk.
fig.tight_layout()
fig.savefig('pancake.pdf')