# Visualise AIS

In [None]:
import sys
sys.path.insert(0, "../")

In [None]:
from fab.sampling_methods.annealed_importance_sampling import AnnealedImportanceSampler
from fab.target_distributions.gmm import GMM
from fab.types_ import HaikuDistribution
from fab.utils.plotting import plot_marginal_pair
from fab.sampling_methods.mcmc.hamiltonean_monte_carlo import HMCStatePAccept
import haiku as hk
import distrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [None]:
# Dimensionality of the problem; used as the learnt distribution's event size
# and passed to the target GMM.
dim = 2
# Passed to the AIS sampler as n_parallel_runs and reused as the number of
# samples drawn / resampled in later cells.
batch_size = 512
# Haiku PRNG key sequence seeded with 0; each next(rng) yields a fresh key.
rng = hk.PRNGSequence(0)

## Define Gaussian Learnt Distribution

In [None]:
def make_gaussian_base_dist(event_shape=(dim,), dtype=jnp.float32):
    """Build a diagonal Gaussian with learnable mean and log-scale.

    The mean ("loc") and log standard deviation ("log_scale") are registered
    as Haiku parameters of shape `event_shape`, both initialised to zeros, so
    the initial distribution has zero mean and unit scale. Because it calls
    `hk.get_parameter`, this must be invoked inside a Haiku transform.

    Args:
        event_shape: shape of a single event; every axis is treated as part
            of the event by the returned `distrax.Independent`.
        dtype: dtype of the parameters.

    Returns:
        A `distrax.Independent` wrapping a `distrax.Normal`.
    """
    mean = hk.get_parameter("loc", shape=event_shape, init=jnp.zeros, dtype=dtype)
    log_std = hk.get_parameter("log_scale", shape=event_shape, init=jnp.zeros, dtype=dtype)
    # Parameterising the log keeps the scale strictly positive.
    per_dim_normal = distrax.Normal(loc=mean, scale=jnp.exp(log_std))
    return distrax.Independent(
        per_dim_normal, reinterpreted_batch_ndims=len(event_shape))

def get_model():
    """Return the learnt distribution built with its default arguments."""
    return make_gaussian_base_dist()

In [None]:
@hk.without_apply_rng
@hk.transform
def log_prob(data):
    """Evaluate the model's log-density at `data` (apply takes no rng key)."""
    return get_model().log_prob(data)

@hk.transform
def sample_and_log_prob(sample_shape):
    """Draw samples of shape `sample_shape` together with their log-probs.

    The sampling key is taken from the rng passed to `apply`.
    """
    key = hk.next_rng_key()
    return get_model().sample_and_log_prob(seed=key, sample_shape=sample_shape)


@hk.transform
def sample(sample_shape):
    """Draw samples of shape `sample_shape` from the model.

    The sampling key is taken from the rng passed to `apply`.
    """
    key = hk.next_rng_key()
    return get_model().sample(seed=key, sample_shape=sample_shape)

# Bundle the dimensionality with the three Haiku-transformed functions.
learnt_distribution = HaikuDistribution(dim, log_prob, sample_and_log_prob, sample)
# Dummy input of shape (dim,), used here to initialise the parameters.
samples = jnp.ones(dim)
# Initialise the parameters through the log_prob transform.
learnt_distribution_params = learnt_distribution.log_prob.init(jax.random.PRNGKey(0), samples)

In [None]:
# Draw 500 samples from the freshly initialised learnt distribution and plot them.
samples = learnt_distribution.sample.apply(
    learnt_distribution_params,
    jax.random.PRNGKey(0), (500,))
plot_marginal_pair(samples)

## Define target distribution (GMM)

In [None]:
# Target distribution: a GMM over `dim` dimensions with n_mixes=5.
# NOTE(review): see fab.target_distributions.gmm.GMM for the exact meaning of
# loc_scaling and log_var_scaling.
target = GMM(dim, n_mixes=5, loc_scaling=2, log_var_scaling=-2.0)
# Log-density callable handed to the AIS sampler.
target_log_prob = target.log_prob

In [None]:
# Draw 500 samples from the target and plot them.
samples = target.sample(seed=jax.random.PRNGKey(0), sample_shape=(500,))
# bounds is presumably the plotted axis range -- confirm in plot_marginal_pair.
plot_marginal_pair(samples, bounds=(-10, 10))

## Get AIS up and running. Check that thinning samples by log weights works.
With more than 5 intermediate AIS distributions, the samples look great.

In [None]:
# Number of intermediate distributions given to the AIS sampler constructed next.
n_intermediate_distributions = 2

In [None]:
# AIS sampler configured with the learnt distribution, the target log-density,
# batch_size parallel runs and n_intermediate_distributions intermediate distributions.
AIS = AnnealedImportanceSampler(
             learnt_distribution=learnt_distribution,
             target_log_prob=target_log_prob,
             n_parallel_runs=batch_size,
             n_intermediate_distributions=n_intermediate_distributions)

In [None]:
# Initial state of the sampler's transition operator.
transition_operator_state = AIS.transition_operator_manager.get_init_state()

In [None]:
# Rebuild the state with every step-size parameter set to 0.1, carrying over
# no_grad_params from the initial state unchanged.
transition_operator_state = HMCStatePAccept(
    no_grad_params=transition_operator_state.no_grad_params,
    step_size_params=jnp.ones_like(transition_operator_state.step_size_params)*0.1)

In [None]:
# Run AIS with a fresh key. x_new is plotted as samples and log_w is used as
# resampling logits in later cells; the third output is presumably the updated
# transition-operator state (unused here), and info holds diagnostics.
x_new, log_w, _trans_state, info = AIS.run(next(rng), learnt_distribution_params, transition_operator_state)

In [None]:
# Draw batch_size samples directly from the learnt distribution for comparison
# with the AIS output.
base_samples = learnt_distribution.sample.apply(
    learnt_distribution_params,
    jax.random.PRNGKey(0), (batch_size,))
plot_marginal_pair(base_samples)
# AIS samples, before any resampling by weight.
plot_marginal_pair(x_new, bounds=(-10, 10))

In [None]:
# Resample batch_size indices with replacement, with probabilities given by
# softmax(log_w), i.e. the self-normalised importance weights.
indxs = jax.random.choice(next(rng), log_w.shape[0], shape=(batch_size,), replace=True, p=jax.nn.softmax(log_w))

In [None]:
# Plot the AIS samples after resampling according to their importance weights.
plot_marginal_pair(x_new[indxs], bounds=(-10, 10)) # looks good

## Visualise effective sample size trend for number of AIS distributions

In [None]:
def run(n_intermediate_distributions, step_size=0.1):
    """Run AIS once and return the effective sample size it reports.

    Uses the notebook-level `learnt_distribution`, `target_log_prob`,
    `batch_size`, `learnt_distribution_params` and `rng`; each call draws a
    fresh key from `rng`.

    Args:
        n_intermediate_distributions: number of intermediate distributions
            passed to `AnnealedImportanceSampler`.
        step_size: value every HMC step-size parameter is set to before the
            run. Defaults to 0.1, the value that was previously hard-coded.

    Returns:
        The "ess_ais" entry of the info dict returned by `AIS.run`.
    """
    sampler = AnnealedImportanceSampler(
        learnt_distribution=learnt_distribution,
        target_log_prob=target_log_prob,
        n_parallel_runs=batch_size,
        n_intermediate_distributions=n_intermediate_distributions)
    init_state = sampler.transition_operator_manager.get_init_state()
    # Keep no_grad_params from the initial state; only the step sizes are overridden.
    hmc_state = HMCStatePAccept(
        no_grad_params=init_state.no_grad_params,
        step_size_params=jnp.ones_like(init_state.step_size_params) * step_size)
    # Only the diagnostics are needed here; the other three outputs are discarded.
    _, _, _, info = sampler.run(next(rng), learnt_distribution_params, hmc_state)
    return info["ess_ais"]

In [None]:
# Numbers of intermediate AIS distributions to sweep over.
ais_dist_range = [1,2,4,8,16, 32]
# Effective sample size obtained for each entry of ais_dist_range, in order.
ess_hist = []
# NOTE: the loop variable rebinds the notebook-level n_intermediate_distributions.
for n_intermediate_distributions in ais_dist_range:
    ess = run(n_intermediate_distributions)
    ess_hist.append(ess)

In [None]:
# Effective sample size as a function of the number of intermediate AIS distributions.
plt.plot(ais_dist_range, ess_hist)
plt.xlabel("n ais distributions")
plt.ylabel("effective sample size")