# Visualise AIS
In this notebook we visualise the performance of the annealed importance sampling (AIS) algorithm, for example how AIS scales with the number of intermediate distributions.

In [None]:
import sys
sys.path.insert(0, "../")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from fab.sampling_methods import AnnealedImportanceSampler, Metropolis, HamiltoneanMonteCarlo
from fab.utils.logging import ListLogger
from fab.target_distributions import TargetDistribution
from fab.target_distributions.gmm import GMM
from fab.wrappers.torch import WrappedTorchDist
from fab.utils.plotting import plot_history, plot_contours, plot_marginal_pair
from fab.utils.numerical import effective_sample_size

## Setup Target Distribution & AIS based distribution

In [None]:
dim: int = 2  # dimensionality of the target and base distributions
seed: int = 1
# NOTE(review): this setting is not read anywhere below - the cells that build AIS
# pass "hmc" to setup_ais explicitly.
transition_operator_type: str = "hmc"
# Seed torch's global RNG for reproducibility. The trailing `;` stops Jupyter from
# echoing the Generator object that torch.manual_seed returns.
torch.manual_seed(seed);

In [None]:
# Target: Gaussian mixture in `dim` dimensions with n_mixes=4 components
# (loc_scaling=8 presumably scales the component means - see GMM).
target = GMM(dim=dim, n_mixes=4, loc_scaling=8)
# Base distribution for AIS: zero-mean Gaussian whose Cholesky factor is 15 * I,
# i.e. a standard deviation of 15 along every coordinate.
base_dist = WrappedTorchDist(
    torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim),
        scale_tril=15 * torch.eye(dim),
    )
)

In [None]:
# Plot the target: contours of its log-density (50 levels) within bounds [-20, 20].
fig, ax = plt.subplots()
plot_contours(target.log_prob, ax=ax, bounds=[-20, 20], n_contour_levels=50)
ax.set_title("target log-density contours")
plt.show()

In [None]:
# Scatter 500 samples from the base distribution (note the wider bounds, [-40, 40]).
fig, ax = plt.subplots()
base_samples = base_dist.sample((500,))
plot_marginal_pair(base_samples, ax=ax, bounds=[-40, 40])
ax.set_title("samples from the base distribution")
plt.show()

## Setup example of AIS
First we look at the effect of tuning the step size for a fixed number of intermediate distributions.

In [None]:
# Number of intermediate AIS distributions, and the batch size passed to each AIS sampling call.
n_ais_dist, batch_size = 5, 1000

In [None]:
def setup_ais(n_ais_intermediate_distributions, transition_operator_type,
              step_size_init=1.0, n_outer=5):
    """Build an AnnealedImportanceSampler from the global `base_dist` to `target`.

    Args:
        n_ais_intermediate_distributions: Number of intermediate distributions for
            the sampler; also forwarded to the transition operator.
        transition_operator_type: "hmc" (HamiltoneanMonteCarlo) or "metropolis"
            (Metropolis).
        step_size_init: Initial HMC step size (passed as ``epsilon``); ignored for
            "metropolis".
        n_outer: Forwarded to HamiltoneanMonteCarlo as ``n_outer``; ignored for
            "metropolis".

    Returns:
        The configured AnnealedImportanceSampler. It also reads the notebook
        globals `dim`, `base_dist` and `target`.

    Raises:
        ValueError: If `transition_operator_type` is neither "hmc" nor "metropolis".
    """
    if transition_operator_type == "hmc":
        transition_operator = HamiltoneanMonteCarlo(
            n_ais_intermediate_distributions=n_ais_intermediate_distributions,
            n_outer=n_outer,
            epsilon=step_size_init, L=5, dim=dim,
            step_tuning_method="p_accept")  # other tuning options include "No-U" and "Expected_target_prob"
    elif transition_operator_type == "metropolis":
        transition_operator = Metropolis(n_transitions=n_ais_intermediate_distributions,
                                         n_updates=5)
    else:
        # Without this branch an unknown type left `transition_operator` unbound and
        # failed later with a confusing UnboundLocalError.
        raise ValueError(
            f"Unknown transition_operator_type {transition_operator_type!r}; "
            f"expected 'hmc' or 'metropolis'.")
    return AnnealedImportanceSampler(base_distribution=base_dist,
                                     target_log_prob=target.log_prob,
                                     transition_operator=transition_operator,
                                     n_intermediate_distributions=n_ais_intermediate_distributions,
                                     )

In [None]:
# Deliberately over-large initial step size (epsilon = 10) so that the effect of
# step-size tuning is easy to see in the plots that follow.
ais = setup_ais(n_ais_intermediate_distributions=n_ais_dist,
                transition_operator_type="hmc",
                step_size_init=10.0)

### Visualise samples before HMC has been tuned. 
Note that we have given epsilon a poor initialisation (too big).

In [None]:
# Freeze step-size tuning so the samples below reflect the deliberately poor initial step size.
ais.transition_operator.set_eval_mode(True) # turn off tuning

In [None]:
# Draw `batch_size` AIS samples together with their log importance weights.
samples, log_w = ais.sample_and_log_weights(batch_size)

Plot the AIS samples against the target probability density contours.

In [None]:
# Overlay the untuned AIS samples on the target log-density contours.
fig, ax = plt.subplots()
plot_contours(target.log_prob, ax=ax, n_contour_levels=50, bounds=[-30, 30])
plot_marginal_pair(samples, bounds=[-30, 30], ax=ax)

Histogram of the log weights

In [None]:
# Distribution of all AIS log weights for the untuned sampler.
n, bins, patches = plt.hist(log_w.numpy(), density=True, alpha=0.75, bins=20)
plt.xlabel("log_w")
plt.ylabel("density")
plt.title("AIS log weights (before tuning)")
plt.show()

Histogram of the largest log weights (those above -10)

In [None]:
# Zoom in on the largest log weights by keeping only log_w > -10.
n, bins, patches = plt.hist(log_w[log_w > -10].numpy(), density=True, alpha=0.75, bins=20)
plt.xlabel("log_w")
plt.ylabel("density")
plt.title("AIS log weights > -10 (before tuning)")
plt.show()

### Tune HMC and visualise again.
We see that the effective sample size (after AIS) goes up during tuning, and that the samples match the target more closely.

In [None]:
# Fresh logger that collects the scalar metrics written during the tuning loop.
logger = ListLogger()

In [None]:
# Re-enable step-size tuning before running the tuning loop.
ais.transition_operator.set_eval_mode(False) # turn on tuning

In [None]:
def eval(ais, outer_batch_size, inner_batch_size):
    """Compute effective sample sizes with step-size tuning temporarily frozen.

    NOTE: this name shadows Python's built-in ``eval``. It is kept because the
    notebook calls the function by this name.

    Args:
        ais: The AnnealedImportanceSampler to evaluate.
        outer_batch_size: Forwarded to ``ais.generate_eval_data``.
        inner_batch_size: Forwarded to ``ais.generate_eval_data``.

    Returns:
        dict with float entries "eval_ess_base" (ESS of the base-distribution log
        weights) and "eval_ess_ais" (ESS of the AIS log weights).
    """
    ais.transition_operator.set_eval_mode(True)  # turn off tuning for evaluation.
    try:
        _, base_log_w, _, ais_log_w = \
            ais.generate_eval_data(outer_batch_size, inner_batch_size)
        info = {"eval_ess_base": effective_sample_size(log_w=base_log_w, normalised=False).item(),
                "eval_ess_ais": effective_sample_size(log_w=ais_log_w, normalised=False).item()}
    finally:
        # Always turn tuning back on, even if evaluation raised, so the caller's
        # sampler is not left silently frozen.
        ais.transition_operator.set_eval_mode(False)
    return info

In [None]:
n_tuning_iterations = 100
eval_interval = 10
for i in tqdm(range(n_tuning_iterations)):
    # One AIS pass with tuning enabled; the transition operator may adapt its step size.
    samples, log_w = ais.sample_and_log_weights(batch_size)
    logger.write(ais.get_logging_info())
    # Evaluate periodically with tuning frozen, and also after the final iteration
    # so the history includes the fully tuned sampler.
    if i % eval_interval == 0 or i == n_tuning_iterations - 1:
        logger.write(eval(ais, 20*batch_size, batch_size))

In the plot below (compared with the plot from before HMC was tuned), we see that the points generated by AIS are much closer to the target distribution.

In [None]:
# After tuning: overlay the samples from the last tuning iteration on the target contours.
fig, ax = plt.subplots()
plot_contours(target.log_prob, ax=ax, n_contour_levels=50, bounds=[-30, 30])
plot_marginal_pair(samples, bounds=[-30, 30], ax=ax)

In the history, we see that the step size is decreased to increase the number of accepted HMC trajectories; on aggregate, this increases the effective sample size.

In [None]:
# Plot every metric the logger recorded during the tuning loop.
plot_history(logger.history)

Comparing the plot of the log weights below with the earlier one, we see that the distribution is narrower and there is less mass in its tails.

In [None]:
# Distribution of all AIS log weights from the last tuning iteration.
n, bins, patches = plt.hist(log_w.numpy(), density=True, alpha=0.75, bins=40)
plt.xlabel("log_w")
plt.ylabel("density")
plt.title("AIS log weights (after tuning)")
plt.show()

In [None]:
# Zoom in on the largest log weights by keeping only log_w > -10.
n, bins, patches = plt.hist(log_w[log_w > -10].numpy(), density=True, alpha=0.75, bins=20)
plt.xlabel("log_w")
plt.ylabel("density")
plt.title("AIS log weights > -10 (after tuning)")
plt.show()

## Visualise the effect of the number of AIS distributions
We see that as the number of AIS distributions increases, the effective sample size increases, and the variance in the importance log weights decreases.

In [None]:
# Numbers of intermediate AIS distributions to sweep over: powers of two, 2 .. 128.
range_n_distributions = [2 ** k for k in range(1, 8)]

In [None]:
# Fresh logger for the sweep over the number of AIS distributions.
logger = ListLogger()
log_weight_hist = [] # ListLogger is meant for scalars, so the full log-weight tensors are stored separately here.

In [None]:
for n_ais_dist in tqdm(range_n_distributions):
    # Turn off step-size tuning: the initial step size is reasonable and we only want
    # to visualise the effect of the number of AIS distributions.
    ais = setup_ais(n_ais_dist, "hmc", step_size_init=1.0, n_outer=1)
    ais.transition_operator.set_eval_mode(True)
    # `ais_samples` from the last (largest) run is reused by a later plotting cell.
    base_samples, base_log_w, ais_samples, ais_log_w = \
        ais.generate_eval_data(50*batch_size, batch_size)
    info = {"eval_ess_ais": effective_sample_size(log_w=ais_log_w, normalised=False).item(),
            "log_w_var": torch.var(ais_log_w).item()}
    logger.write(info)
    log_weight_hist.append(ais_log_w)

In [None]:
fig, axs = plt.subplots(2)
axs[0].plot(range_n_distributions, logger.history["eval_ess_ais"])
axs[0].set_ylabel("effective sample size")

axs[1].plot(range_n_distributions, logger.history["log_w_var"])
axs[1].set_ylabel("var log w")
axs[1].set_yscale("log")

for ax in axs:
    ax.set_xlabel("number of intermediate ais distributions")
    # The swept counts are powers of two, so a log x-axis spaces them evenly; together
    # with the log y-axis on the variance panel, a 1/n trend shows as a straight line.
    ax.set_xscale("log")
fig.tight_layout()  # keep the top panel's x-label from overlapping the bottom panel
plt.show()

The variance of the log weights initially decreases by a huge amount; however, as the number of AIS distributions increases further, it decreases at a rate closer to 1/n.

In [None]:
# Pair each number of AIS distributions with the variance of its log weights, so the
# values can be read without cross-referencing `range_n_distributions`.
dict(zip(range_n_distributions, logger.history["log_w_var"]))

In [None]:
# Let's look at samples after the max number of AIS steps: `ais_samples` is left over
# from the last (largest) run of the sweep over the number of distributions.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
plot_contours(target.log_prob, ax=axs[0], bounds=[-30, 30], n_contour_levels=50)
plot_marginal_pair(ais_samples[:1000], ax=axs[0], bounds=[-30, 30])
axs[0].set_title("samples (ais) vs target contours")

plot_contours(target.log_prob, ax=axs[1], bounds=[-30, 30], n_contour_levels=50)
plot_marginal_pair(target.sample((1000,)), ax=axs[1], bounds=[-30, 30])
axs[1].set_title("samples (target) vs target contours")
plt.show()  # render the figure instead of echoing the last set_title return value

In [None]:
# And for comparison, the same plot with only a few intermediate distributions.
n_ais_dist = 4  # change this number to see how the number of distributions affects the samples from AIS.
ais_few_dist = setup_ais(n_ais_dist, "hmc", step_size_init=1.0, n_outer=1)
# Freeze step-size tuning, matching how the sweep runs were sampled, so the comparison is fair.
ais_few_dist.transition_operator.set_eval_mode(True)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
plot_contours(target.log_prob, ax=axs[0], bounds=[-30, 30], n_contour_levels=50)
plot_marginal_pair(ais_few_dist.sample_and_log_weights(batch_size)[0], ax=axs[0], bounds=[-30, 30])
axs[0].set_title("samples (ais) vs target contours")

plot_contours(target.log_prob, ax=axs[1], bounds=[-30, 30], n_contour_levels=50)
plot_marginal_pair(target.sample((1000,)), ax=axs[1], bounds=[-30, 30])
axs[1].set_title("samples (target) vs target contours")
plt.show()  # render the figure instead of echoing the last set_title return value

Plot log weight distribution for a relatively low number of AIS distributions vs a high number of AIS distributions

In [None]:
n_plot_samples = 1000  # number of log weights taken from each run for the histograms

iter_n_low = 2  # index into range_n_distributions for the "few distributions" run
print(f"plotting log_w for {range_n_distributions[iter_n_low]} AIS distributions "
      f"for first {n_plot_samples} samples")
log_w_low = log_weight_hist[iter_n_low][:n_plot_samples].numpy()
n, bins, patches = plt.hist(log_w_low, density=True, alpha=0.75, bins=40, color="green",
                            label=f"{range_n_distributions[iter_n_low]} ais dist")

iter_n_high = -1  # the last (largest) entry of range_n_distributions
print(f"plotting log_w for {range_n_distributions[iter_n_high]} AIS distributions "
      f"for first {n_plot_samples} samples")
log_w_high = log_weight_hist[iter_n_high][:n_plot_samples].numpy()
n, bins, patches = plt.hist(log_w_high, density=True, alpha=0.75, bins=40, color="blue",
                            label=f"{range_n_distributions[iter_n_high]} ais dist")

plt.xscale("symlog")  # symmetric-log x scale so both histograms are visible on one plot
plt.legend()
plt.xlabel("log_w")
plt.ylabel("density")
plt.show()

The same comparison, but dropping very low log weights (≤ -10) so we don't need a log-scaled x-axis.

In [None]:
# Keep only log weights above -10 (first 1000 survivors of each run) so both
# histograms fit on an ordinary linear x-axis.
log_w_low = log_weight_hist[iter_n_low][log_weight_hist[iter_n_low] > -10][:1000].numpy()
n, bins, patches = plt.hist(
    log_w_low, bins=40, density=True, alpha=0.75, color="green",
    label=f"{range_n_distributions[iter_n_low]} ais dist",
)

log_w_high = log_weight_hist[iter_n_high][log_weight_hist[iter_n_high] > -10][:1000].numpy()
n, bins, patches = plt.hist(
    log_w_high, bins=40, density=True, alpha=0.75, color="blue",
    label=f"{range_n_distributions[iter_n_high]} ais dist",
)

plt.xlabel("log_w")
plt.ylabel("density")
plt.legend()
plt.show()