# Layered Adaptive Importance Sampling

https://arxiv.org/pdf/1505.04732.pdf

This notebook implements, in JAX, the Parallel Interacting
Markov Adaptive Importance Sampling (PI-MAIS) algorithm described in the paper. Implementing the other variants from the paper should also be straightforward.

In [None]:
# Detect whether the notebook runs on Google Colab by probing for the
# `google.colab` package.
try:
    import google.colab  # noqa: F401  (imported only to test availability)
except ImportError:
    # Only a missing module means "not on Colab"; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit and unrelated errors.
    IN_COLAB = False
else:
    IN_COLAB = True
GIT_TOKEN = ""

if IN_COLAB:
    !pip install -U pip setuptools wheel
    if GIT_TOKEN:
        !pip install git+https://{GIT_TOKEN}@github.com/ethanluoyc/sympais.git#egg=sympais
    else:
        !pip install git+https://github.com/ethanluoyc/sympais.git#egg=sympais

if IN_COLAB:
    !curl -L "https://drive.google.com/uc?export=download&id=1_Im0Ot5TjkzaWfid657AV_gyMpnPuVRa" -o realpaver
    !chmod u+x realpaver
    !cp realpaver /usr/local/bin

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import time

import jax
from jax import tree_util
from jax import random
import jax.numpy as np
import matplotlib.pyplot as plt
from numpyro import distributions
from jax import random
import seaborn as sns
from numpyro import distributions
from matplotlib import animation, rc
from IPython.display import HTML
import functools

from sympais.infer import mcmc
from sympais import distributions as D
from sympais.infer import importance
from sympais.infer.mcmc import metropolis
# sns.set_context("paper")

## Mixture of Gaussians problem

In [None]:
# Target: an equally weighted mixture of N_MIXTURES Gaussians in INPUT_DIM
# dimensions.
N_MIXTURES = 5
INPUT_DIM = 2

# Component means, one row per mixture component.
MU = np.array([[-10, -10], [0, 16], [13, 8], [-9, 7], [14, -14]])

# Component covariances, one flattened (row-major) matrix per row; reshaped
# below into a stack of INPUT_DIM x INPUT_DIM matrices.
_FLAT_COVARIANCES = [
    [2, 0.6, 0.6, 1],
    [2, -0.4, -0.4, 2],
    [2, 0.8, 0.8, 2],
    [3, 0, 0, 0.5],
    [2, -0.1, -0.1, 2],
]
COV = np.array(_FLAT_COVARIANCES).reshape(-1, INPUT_DIM, INPUT_DIM)

_mixing_dist = distributions.Categorical(
    probs=np.ones(N_MIXTURES) / N_MIXTURES)
_component_dist = distributions.MultivariateNormal(
    loc=MU, covariance_matrix=COV)
target_dist = D.Mixture(_mixing_dist, _component_dist)

In [None]:
# Evaluate the target log-density on a size x size grid covering
# [-20, 20] x [-20, 20]; later cells use it as the contour background.
size = 100
x, y = (np.linspace(-20, 20, size) for _ in range(2))
xx, yy = np.meshgrid(x, y)

grid_points = np.stack((xx, yy), axis=-1)
log_probs = target_dist.log_prob(grid_points)

In [None]:
# Draw 100 samples from the target and overlay them on its density contours.
rng = random.PRNGKey(42)
samples = target_dist.sample(rng, (100,))

fig, ax = plt.subplots()
ax.set_aspect(1.)
ax.contourf(xx, yy, np.exp(log_probs), cmap='Blues')

sample_style = dict(marker='.', color='white', edgecolor='black', alpha=0.8)
ax.scatter(samples[:, 0], samples[:, 1], **sample_style)

axis_limits = (-20, 20)
ax.set(xlim=axis_limits, ylim=axis_limits, xlabel='$x_1$', ylabel='$x_2$');

In [None]:
def metropolis_hasting_step(rng,
                            proposal_state,
                            target_log_prob_fn,
                            metropolis_proposal_scale=5):
    """Take one random-walk Metropolis-Hastings step from `proposal_state`.

    Args:
      rng: PRNG key forwarded to the underlying step function.
      proposal_state: current state to move from.
      target_log_prob_fn: callable giving the target log-density.
      metropolis_proposal_scale: `scale` passed to the random-walk proposal.

    Returns:
      A `(next_proposal_state, extra)` tuple as produced by
      `metropolis.random_walk_metropolis_hasting_step`.
    """
    # Bind the scale so the step function receives a ready-to-use proposal.
    scaled_proposal_fn = functools.partial(
        metropolis.random_walk_proposal_fn,
        scale=metropolis_proposal_scale,
    )
    step_output = metropolis.random_walk_metropolis_hasting_step(
        rng, proposal_state, target_log_prob_fn, scaled_proposal_fn)
    new_state, step_info = step_output
    return new_state, step_info

def hmc_step(rng, proposal_state, target_log_prob_fn):
    """Take one Hamiltonian Monte Carlo step from `proposal_state`.

    Uses a step size of 0.2 and a path length of 0.2 * 15, with a Gaussian
    kinetic energy configured with `chain_ndims=1`.

    NOTE(review): `hmc` is not imported anywhere in the visible notebook, so
    this function raises NameError when called until the HMC module is
    imported - confirm the intended import path.

    Args:
      rng: PRNG key forwarded to the HMC step.
      proposal_state: current state to move from.
      target_log_prob_fn: callable giving the target log-density.

    Returns:
      A `(next_state, extra)` tuple as produced by
      `hmc.hamiltonian_monte_carlo_step`.
    """
    step_size = 0.2
    # `jax.partial` was a deprecated alias of `functools.partial` and has been
    # removed from JAX, so bind the argument with the stdlib function instead.
    kinetic_energy_fn = functools.partial(
        hmc.gaussian_kinetic_energy_fn, chain_ndims=1)
    next_state, extra = hmc.hamiltonian_monte_carlo_step(
        rng, target_log_prob_fn, proposal_state,
        path_length=step_size * 15,
        step_size=step_size,
        kinetic_energy_fn=kinetic_energy_fn,
    )
    return next_state, extra


In [None]:
import functools
from sympais.infer import mcmc

# Configuration of the PI-MAIS run on the Gaussian-mixture target.
num_proposals = 100              # number of rows in `proposal_state`
num_iterations = 100             # passed to PIMAIS as `num_iters`
num_samples_per_iteration = 200  # passed to PIMAIS as `num_samples`
importance_proposal_scale = 0.5  # std. dev. of the lower-layer Gaussians
metropolis_proposal_scale = 5.   # scale of the random-walk MH proposal

# Split the key: one half draws the initial proposal locations, the other is
# handed to the sampler. Reusing a single key for both would correlate the
# two random streams.
rng, init_rng = random.split(random.PRNGKey(0))

# Initialize the proposal state uniformly in [-4, 4]^2.
proposal_state = jax.random.uniform(
    init_rng,
    shape=(num_proposals, 2),
    minval=-4,
    maxval=4,
)

# Upper-layer MCMC kernel; the configured proposal scale is passed explicitly
# so that changing `metropolis_proposal_scale` above actually takes effect.
kernel = functools.partial(
    metropolis_hasting_step,
    target_log_prob_fn=target_dist.log_prob,
    metropolis_proposal_scale=metropolis_proposal_scale,
)
# proposal_state = np.zeros((num_proposals, 2))

def get_proposal(proposal_state):
    """Return the lower-layer Gaussian proposals located at `proposal_state`.

    All proposals share the isotropic covariance
    `importance_proposal_scale**2 * I` (2x2); `importance_proposal_scale` is
    read from module scope at call time.
    """
    shared_covariance = np.square(importance_proposal_scale) * np.eye(2)
    return distributions.MultivariateNormal(proposal_state, shared_covariance)

In [None]:
# Run PI-MAIS on the Gaussian-mixture target, starting from `proposal_state`,
# with the Metropolis-Hastings `kernel` and the `get_proposal` factory defined
# in the previous cells.
# Later cells read `state.proposal_state` (used as the final proposal
# locations) and the per-iteration traces in `extra[0]`
# (`proposal_state`, `Zpart`, `Ztot`, `Ipart`, `Itot`).
state, extra = importance.PIMAIS(
    rng,
    target_dist.log_prob,
    proposal_state,
    kernel,
    get_proposal,
    num_iters=num_iterations,
    num_samples=num_samples_per_iteration,
)

In [None]:

def _proposal_mixture_density(locations):
    """Equal-weight mixture density of the proposals at `locations` on the grid.

    The grid gets an extra axis so that `log_prob` broadcasts against the
    proposals; averaging the per-proposal densities over the last axis gives
    the mixture density. (The previous code divided by the number of
    proposals a second time and then plotted `exp` of the density.)
    """
    grid_points = np.expand_dims(np.stack([xx, yy], -1), 2)
    per_proposal_density = np.exp(get_proposal(locations).log_prob(grid_points))
    return np.mean(per_proposal_density, axis=-1)


# 2x2 overview: target density (top left), initial vs. final proposal
# locations on the target (top right), and the proposal mixture density
# before (bottom left) and after (bottom right) adaptation.
fig, ax = plt.subplots(2, 2, figsize=(6, 6))
ax = ax.flat
ax[0].contourf(xx, yy, np.exp(log_probs), cmap='Blues');
ax[0].set(xlim=(-20, 20), ylim=(-20, 20), ylabel='$x_2$');
ax[1].contourf(xx, yy, np.exp(log_probs), cmap='Blues');
ax[1].set(xlim=(-20, 20), ylim=(-20, 20));

ax[1].scatter(
    proposal_state[:, 0],
    proposal_state[:, 1],
    color='C0',
    edgecolor='black',
    label='$t = 0$',
)
# Last entry of the stored trajectory, i.e. the locations at t = T.
ax[1].scatter(
    extra[0].proposal_state[-1, :, 0],
    extra[0].proposal_state[-1, :, 1],
    color='C1',
    edgecolor='black',
    label='$t = T$',
);
ax[1].legend(loc='upper left')

initial_density = _proposal_mixture_density(proposal_state)
ax[2].contourf(xx, yy, initial_density, cmap='Blues');
ax[2].scatter(
    proposal_state[:, 0],
    proposal_state[:, 1],
    marker='.', color='white', edgecolor='black',
    alpha=0.8,
);
ax[2].set(xlim=(-20, 20), ylim=(-20, 20), xlabel='$x_1$', ylabel='$x_2$');

final_state = state.proposal_state
final_density = _proposal_mixture_density(final_state)
ax[3].contourf(xx, yy, final_density, cmap='Blues');
ax[3].scatter(
    final_state[:, 0],
    final_state[:, 1],
    linewidths=.5,
    marker='.', color='white', edgecolor='black',
    alpha=0.8,
)
ax[3].set(xlim=(-20, 20), ylim=(-20, 20),
          xlabel='$x_1$', );

In [None]:
# Animate the proposal locations over frames 0-19 of the trajectory stored in
# `extra[0].proposal_state`, drawn on top of the target density.
fig, ax = plt.subplots(figsize=(6, 6))
ax.contourf(xx, yy, np.exp(log_probs), cmap='Blues');
ax.set(xlim=(-20, 20), ylim=(-20, 20), xlabel='$x_1$', ylabel='$x_2$');
sc = ax.scatter(proposal_state[:, 0], proposal_state[:, 1])


def _show_locations(locations):
    """Move the scatter points to `locations`; return the redrawn artists."""
    sc.set_offsets(locations)
    return (sc,)


def init():
    """Background frame: points at the initial proposal locations."""
    return _show_locations(proposal_state)


def animate(i):
    """Frame `i`: points at the i-th entry of the stored trajectory."""
    return _show_locations(extra[0].proposal_state[i])


anim = animation.FuncAnimation(
    fig, animate, init_func=init, frames=20, interval=100, blit=True);
# Close the figure here; the animation itself is rendered in the next cell.
plt.close()

In [None]:
# Convert the `anim` animation to an HTML5 video and display it inline.
# NOTE(review): matplotlib's HTML5 export presumably needs a video encoder
# such as ffmpeg on the machine - confirm for the target environment.
HTML(anim.to_html5_video())

In [None]:
# Per-iteration diagnostics of the mixture run: the `part` and `total`
# estimates of Z (left) and the error of the mean estimates (right).
fig, ax = plt.subplots(1, 2)
ax[0].plot(extra[0].Zpart, label='part');
ax[0].plot(extra[0].Ztot, label='total');
ax[0].legend()
ax[0].set(xlabel='Number of iterations', ylabel='Z')

# True mean of the target: the average of the rows of MU, since all
# components have equal weight.
true_mean = np.array([1.6, 1.4])
# Average (not sum) the squared errors over the two coordinates so that the
# curve matches its 'MSE' label and the banana experiment's plot.
ax[1].plot(
    np.mean(np.square(extra[0].Ipart - true_mean), axis=-1),
    label='part',
);
ax[1].plot(
    np.mean(np.square(extra[0].Itot - true_mean), axis=-1),
    label='total',
);
ax[1].set(xlabel='Number of iterations', ylabel='MSE')
ax[1].legend();
plt.tight_layout()

## Banana-shaped target distribution

This is the second experiment in the paper; see Section 6.2.

In [None]:
def banana_log_joint_prob(x):
    """Unnormalised log-density of the banana-shaped target.

    Computes

        -(4 - B*x1 - x2**2)**2 / (2*nu1**2)
        - x1**2 / (2*nu2**2)
        - x2**2 / (2*nu3**2)

    with B = 10, nu1 = 4, nu2 = 5 and nu3 = 5.

    Args:
      x: array whose last axis holds `(x1, x2)`; any leading axes are treated
        as batch dimensions.

    Returns:
      Array of log-density values with shape `x.shape[:-1]`.
    """
    x1 = x[..., 0]
    x2 = x[..., 1]
    nu1 = 4.
    nu2 = 5.
    nu3 = 5.
    B = 10.

    return (
        -np.square(4 - B * x1 - x2 * x2) / (2 * nu1 * nu1)
        # The x1 term is scaled by nu2; it previously reused nu1, which left
        # nu2 unused and made the target narrower in x1 than specified.
        - x1 * x1 / (2 * nu2 * nu2)
        - x2 * x2 / (2 * nu3 * nu3)
    )

In [None]:
# Smoke-test the log-density on a small batch of zeros (result unused).
banana_log_joint_prob(np.zeros((10, 2), dtype=np.float32))

# Evaluate the banana log-density on a size x size grid covering
# [-20, 20] x [-20, 20]; later cells use it as the contour background.
size = 100
x, y = (np.linspace(-20, 20, size) for _ in range(2))
xx, yy = np.meshgrid(x, y)

grid_points = np.stack((xx, yy), axis=-1)
log_probs = banana_log_joint_prob(grid_points)

In [None]:
# Contour plot of the (unnormalised) banana density, cropped to the region
# x1 in [-8, 4], x2 in [-10, 10].
fig, ax = plt.subplots(figsize=(3, 3))

banana_density = np.exp(log_probs)
ct = ax.contourf(xx, yy, banana_density, levels=15, cmap='Blues')
ax.set(xlim=(-8, 4), ylim=(-10, 10), xlabel='$x_1$', ylabel='$x_2$')
plt.tight_layout()
# savefig(fig, 'pimais/banana_groundtruth.pdf', transparent=True)

In [None]:
# Configuration of the PI-MAIS run on the banana-shaped target.
num_proposals = 100              # number of rows in `proposal_state`
num_iterations = 500             # passed to PIMAIS as `num_iters`
num_samples_per_iteration = 200  # passed to PIMAIS as `num_samples`
importance_proposal_scale = 0.5  # std. dev. of the lower-layer Gaussians
metropolis_proposal_scale = 5.   # scale of the random-walk MH proposal

rng = random.PRNGKey(0)

# Initialize the proposal state: x1 ~ U(-6, -3) and x2 ~ U(-4, 4), each
# coordinate drawn with its own key. (An earlier uniform draw on [-6, 6]^2
# was overwritten immediately by this one and has been removed.)
proposal_state = np.stack(
    [
        jax.random.uniform(random.PRNGKey(1), shape=(num_proposals,), minval=-6, maxval=-3),
        jax.random.uniform(random.PRNGKey(2), shape=(num_proposals,), minval=-4, maxval=4),
    ],
    axis=-1)

def get_proposal(proposal_state):
    """Return the lower-layer Gaussian proposals located at `proposal_state`.

    All proposals share the isotropic covariance
    `importance_proposal_scale**2 * I` (2x2); `importance_proposal_scale` is
    read from module scope at call time.
    """
    shared_covariance = np.square(importance_proposal_scale) * np.eye(2)
    return distributions.MultivariateNormal(proposal_state, shared_covariance)

# Random-walk proposal with the configured scale. It is not used by `kernel`
# below (which builds its own inside `metropolis_hasting_step`) and is kept
# only for backward compatibility. It now references the `metropolis` module,
# the same place `metropolis_hasting_step` takes the function from.
adapt_proposal_fn = functools.partial(
    metropolis.random_walk_proposal_fn,
    scale=metropolis_proposal_scale,
)

# Upper-layer MCMC kernel for the banana target; the configured proposal scale
# is passed explicitly so that changing `metropolis_proposal_scale` actually
# takes effect.
kernel = functools.partial(
    metropolis_hasting_step,
    target_log_prob_fn=banana_log_joint_prob,
    metropolis_proposal_scale=metropolis_proposal_scale,
)

In [None]:
# Run PI-MAIS on the banana-shaped target, starting from `proposal_state`,
# with the Metropolis-Hastings `kernel` and the `get_proposal` factory defined
# in the previous cell.
# Later cells read `state.proposal_state` (used as the final proposal
# locations) and the per-iteration traces `Zpart`, `Ztot`, `Ipart` and `Itot`
# in `extra[0]`.
state, extra = importance.PIMAIS(
    rng,
    banana_log_joint_prob,
    proposal_state,
    kernel,
    get_proposal,
    num_iters=num_iterations,
    num_samples=num_samples_per_iteration,
)

In [None]:
def _proposal_mixture_density(locations):
    """Equal-weight mixture density of the proposals at `locations` on the grid.

    The grid gets an extra axis so that `log_prob` broadcasts against the
    proposals; averaging the per-proposal densities over the last axis gives
    the mixture density. (The previous code divided by the number of
    proposals a second time and then plotted `exp` of the density.)
    """
    grid_points = np.expand_dims(np.stack([xx, yy], -1), 2)
    per_proposal_density = np.exp(get_proposal(locations).log_prob(grid_points))
    return np.mean(per_proposal_density, axis=-1)


# 2x2 overview: target density (top left), initial vs. final proposal
# locations on the target (top right), and the proposal mixture density
# before (bottom left) and after (bottom right) adaptation.
fig, ax = plt.subplots(2, 2, figsize=(6, 6))
ax = ax.flat
ax[0].contourf(xx, yy, np.exp(log_probs), levels=15, cmap='Blues');
ax[0].set(xlim=(-8, 4), ylim=(-10, 10), xlabel='$x_1$', ylabel='$x_2$');
ax[1].contourf(xx, yy, np.exp(log_probs), levels=15, cmap='Blues');
ax[1].set(xlim=(-8, 4), ylim=(-10, 10), xlabel='$x_1$', ylabel='$x_2$');

ax[1].scatter(
    proposal_state[:, 0],
    proposal_state[:, 1],
    color='C0',
    edgecolor='black', label='$t = 0$',
)
ax[1].scatter(
    state.proposal_state[:, 0],
    state.proposal_state[:, 1],
    color='C1',
    edgecolor='black', label='$t = T$',
);
ax[1].legend()

initial_density = _proposal_mixture_density(proposal_state)
ax[2].contourf(xx, yy, initial_density, cmap='Blues');
ax[2].scatter(
    proposal_state[:, 0],
    proposal_state[:, 1],
    marker='.', color='white', edgecolor='black',
    alpha=0.8,
);
ax[2].set(xlim=(-8, 4), ylim=(-10, 10), xlabel='$x_1$', ylabel='$x_2$');

final_state = state.proposal_state
final_density = _proposal_mixture_density(final_state)
ax[3].contourf(xx, yy, final_density, cmap='Blues');
ax[3].scatter(
    final_state[:, 0],
    final_state[:, 1],
    marker='.', color='white', edgecolor='black',
    alpha=0.8,
)
ax[3].set(xlim=(-8, 4), ylim=(-10, 10),
          xlabel='$x_1$', ylabel='$x_2$');
plt.tight_layout()
# savefig(fig, 'pimais/banana_mixture.pdf', transparent=True)

In [None]:
# Per-iteration diagnostics of the banana run: the `part` and `total`
# estimates of Z (left) and the mean squared error of the mean estimates
# against the reference value (-0.4845, 0) (right).
fig, ax = plt.subplots(1, 2, figsize=(6, 2))
z_axis, mse_axis = ax[0], ax[1]

z_axis.plot(extra[0].Zpart, label='part',);
z_axis.plot(extra[0].Ztot, label='total', linewidth=2);
z_axis.legend()
z_axis.set(xlabel='Number of iterations', ylabel='Z')

reference_mean = np.array([-0.4845, 0])
mse_part = np.mean(np.square(extra[0].Ipart - reference_mean), axis=-1)
mse_total = np.mean(np.square(extra[0].Itot - reference_mean), axis=-1)
mse_axis.plot(mse_part, label='part');
mse_axis.plot(mse_total, label='total', linewidth=2);
mse_axis.legend();
mse_axis.set(ylabel='MSE', xlabel='Number of iterations')
plt.tight_layout()
# savefig(fig, 'pimais/banana_pimais_estimates.pdf', transparent=True)
