# Visualise AIS Many Well
In this notebook we visualise the performance of the annealed importance sampling (AIS) algorithm, for example how AIS scales with the number of intermediate distributions, on the Many Well problem with 32 dimensions.

In [None]:
import sys
sys.path.insert(0, "../")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from fab.sampling_methods import AnnealedImportanceSampler, Metropolis, HamiltoneanMonteCarlo
from fab.utils.logging import ListLogger
from fab import FABModel
from fab.target_distributions import TargetDistribution
from fab.target_distributions.many_well import ManyWellEnergy
from examples.make_flow import make_wrapped_normflowdist
from fab.utils.plotting import plot_history, plot_contours, plot_marginal_pair
from fab.utils.numerical import effective_sample_size

## Setup Target Distribution & AIS based distribution

In [None]:
# Global configuration shared by the cells below.
dim: int = 32  # dimensionality passed to the target, the flow and the HMC operator
seed: int = 1  # handed to torch.manual_seed below
n_flow_layers = 10  # forwarded to make_wrapped_normflowdist
layer_nodes_per_dim = 10  # forwarded to make_wrapped_normflowdist
torch.manual_seed(seed)  # seed torch's global RNG
batch_size = 1000  # default n_samples of plot(); assigned the same value again further down
plotting_bounds = (-3, 3)  # bounds given to plot_contours / plot_marginal_pair inside plot()

In [None]:
# setup target
target = ManyWellEnergy(dim, a=-0.5, b=-6)

In [None]:
# setup flow spec
flow = make_wrapped_normflowdist(dim, n_flow_layers=n_flow_layers,
                                     layer_nodes_per_dim=layer_nodes_per_dim)

In [None]:
n_intermediate_distributions = 4
n_inner_steps = 5

In [None]:
# setup transition_operator spec
transition_operator = HamiltoneanMonteCarlo(
            n_ais_intermediate_distributions=n_intermediate_distributions,
            n_outer=1,
            epsilon=1.0, L=n_inner_steps, dim=dim,
            step_tuning_method="p_accept")

In [None]:
# setup full model
fab_model = FABModel(flow=flow,
                         target_distribution=target,
                         n_intermediate_distributions=n_intermediate_distributions,
                         transition_operator=transition_operator)

In [None]:
# load trained model
fab_model.load("models/many_well_32/model.pt", "cpu")

In [None]:
def plot(fab_model, n_samples: int = batch_size, dim: int = dim):
    """Plot flow samples next to AIS samples for each consecutive pair of dimensions.

    One row is drawn per pair ``(2i, 2i + 1)``; the left column shows samples from
    ``fab_model.flow`` and the right column shows samples from
    ``fab_model.annealed_importance_sampler``. Both are drawn over the contours of
    ``target.log_prob_2D``. If ``dim`` is odd, the last dimension is not plotted.

    The defaults of ``n_samples`` and ``dim`` are the module-level ``batch_size``
    and ``dim`` as they were when this function was defined.

    Args:
        fab_model: model exposing ``flow.sample`` and
            ``annealed_importance_sampler.sample_and_log_weights``.
        n_samples: number of samples drawn from each of the flow and AIS.
        dim: number of dimensions to visualise (in pairs).

    Returns:
        A single-element list containing the created figure.

    Raises:
        ValueError: if ``dim`` is smaller than 2, since no pair can be plotted.
    """
    n_rows = dim // 2
    if n_rows < 1:
        raise ValueError(f"dim must be at least 2 to plot a marginal pair, got {dim}")

    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # `axs[i, col]` indexing below also works when dim is 2 or 3.
    fig, axs = plt.subplots(n_rows, 2, sharex=True, sharey=True,
                            figsize=(10, n_rows * 3), squeeze=False)

    samples_flow = fab_model.flow.sample((n_samples,))
    # sample_and_log_weights returns a sequence; only its first entry (the samples) is used.
    samples_ais = fab_model.annealed_importance_sampler.sample_and_log_weights(
        n_samples, logging=False)[0]

    for i in range(n_rows):
        marginal_dims = (i * 2, i * 2 + 1)
        plot_contours(target.log_prob_2D, bounds=plotting_bounds, ax=axs[i, 0])
        plot_contours(target.log_prob_2D, bounds=plotting_bounds, ax=axs[i, 1])

        # plot flow samples
        plot_marginal_pair(samples_flow, ax=axs[i, 0], bounds=plotting_bounds,
                           marginal_dims=marginal_dims)
        axs[i, 0].set_xlabel(f"dim {marginal_dims[0]}")
        axs[i, 0].set_ylabel(f"dim {marginal_dims[1]}")

        # plot ais samples
        plot_marginal_pair(samples_ais, ax=axs[i, 1], bounds=plotting_bounds,
                           marginal_dims=marginal_dims)
        axs[i, 1].set_xlabel(f"dim {marginal_dims[0]}")
        axs[i, 1].set_ylabel(f"dim {marginal_dims[1]}")

    axs[0, 1].set_title("ais samples")
    axs[0, 0].set_title("flow samples")
    # Lay out once, on this figure, after all labels and titles exist.
    fig.tight_layout()
    return [fig]

In [None]:
plot(fab_model)

## Now for varying number of AIS distributions

In [None]:
# fab_model.annealed_importance_sampler.transition_operator.epsilons

In [None]:
# fab_model.annealed_importance_sampler.transition_operator.common_epsilon

In [None]:
n_ais_dist = 5  # rebound by the sweep loop and by the comparison cell further down
batch_size = 1000  # same value as the earlier assignment; used by the sampling / evaluation calls below

In [None]:
def setup_ais(n_ais_intermediate_distributions):
    """Create an AIS sampler with a fixed-step HMC transition, based on the loaded flow.

    The HMC operator is built in eval mode with ``L=1`` and ``n_outer=1``, and its
    ``common_epsilon`` / ``epsilons`` attributes are then overwritten with constant
    values. The sampler uses ``fab_model.flow`` as base distribution and
    ``target.log_prob`` as target.

    Args:
        n_ais_intermediate_distributions: number of intermediate distributions,
            passed to both the HMC operator and the sampler.

    Returns:
        The configured ``AnnealedImportanceSampler``.
    """
    hmc = HamiltoneanMonteCarlo(
        n_ais_intermediate_distributions=n_ais_intermediate_distributions,
        n_outer=1,
        epsilon=1.0,
        L=1,
        dim=dim,
        step_tuning_method="p_accept",
        eval_mode=True,
    )
    # Replace the initial step sizes with constants of the same shape;
    # the 0.0065 value is noted as coming from a trained HMC.
    hmc.common_epsilon = torch.ones_like(hmc.common_epsilon) * 0.0065
    hmc.epsilons = torch.ones_like(hmc.epsilons) * 0.15

    return AnnealedImportanceSampler(
        base_distribution=fab_model.flow,
        target_log_prob=target.log_prob,
        transition_operator=hmc,
        n_intermediate_distributions=n_ais_intermediate_distributions,
    )

In [None]:
ais = setup_ais(3)

In [None]:
samples, log_w = ais.sample_and_log_weights(batch_size)

In [None]:
fig, ax = plt.subplots()
plot_contours(target.log_prob_2D, ax=ax, bounds=[-3, 3], n_contour_levels=50)
plot_marginal_pair(samples, ax=ax, bounds=[-3, 3])

In [None]:
n, bins, patches = plt.hist(log_w.detach().numpy(), density=True, alpha=0.75, bins=20)
plt.show()

## Visualise the effect of the number of AIS distributions
We see that as the number of AIS distributions increases, the effective sample size increases, and the variance in the importance log weights decreases.

In [None]:
range_n_distributions = [1, 2, 4, 8, 16, 32, 64, 128]

In [None]:
logger = ListLogger()
log_weight_hist = [] # listlogger is meant for scalars so we store the log weight history separately. 

In [None]:
ess_hist = []
for n_ais_dist in tqdm(range_n_distributions):
    ais = setup_ais(n_ais_dist)
    # Step size tuning stays off: the initial step size is reasonable and the only
    # thing we want to visualise is the effect of the number of ais distributions.
    ais.transition_operator.set_eval_mode(True)

    # Detach every returned tensor straight away so only detached tensors stay bound.
    base_samples, base_log_w, ais_samples, ais_log_w = (
        tensor.detach()
        for tensor in ais.generate_eval_data(50 * batch_size, batch_size)
    )

    info = dict(
        eval_ess_ais=effective_sample_size(log_w=ais_log_w, normalised=False).item(),
        log_w_var=torch.var(ais_log_w).item(),
    )
    logger.write(info)
    # The logger only receives scalars, so the full log weights are kept in a separate list.
    log_weight_hist.append(ais_log_w)

In [None]:
# Two stacked panels over the swept numbers of AIS distributions:
# effective sample size on top, variance of the log weights (log y-scale) below.
fig, axs = plt.subplots(2)
ess_ax, var_ax = axs
x_label = "number of intermediate ais distributions"

ess_ax.plot(range_n_distributions, logger.history["eval_ess_ais"])
ess_ax.set_xlabel(x_label)
ess_ax.set_ylabel("effective sample size")

var_ax.plot(range_n_distributions, logger.history["log_w_var"])
var_ax.set_xlabel(x_label)
var_ax.set_ylabel("var log w")
var_ax.set_yscale("log")

plt.show()

In [None]:
logger.history['log_w_var']

In [None]:
# Let's look at samples after the maximum number of AIS steps.
# `ais_samples` is whatever the sweep loop left bound on its final iteration, i.e. the run
# for the last entry of range_n_distributions; only its first 1000 rows are plotted.
fig, axs = plt.subplots(1, figsize=(15, 5))
plot_contours(target.log_prob_2D, ax=axs, bounds=[-3, 3], n_contour_levels=50)
plot_marginal_pair(ais_samples[:1000], ax=axs, bounds=[-3, 3])
axs.set_title("samples (ais) vs target contours")

In [None]:
# and for comparison with only a few intermediate distributions
n_ais_dist = 1 # change this number to see how the number of distributions affects the samples from AIS.
ais_2_dist = setup_ais(n_ais_dist)

fig, axs = plt.subplots(1, figsize=(15, 5))
plot_contours(target.log_prob_2D, ax=axs, bounds=[-3, 3], n_contour_levels=50)
# NOTE(review): the sample bounds here are [-30, 30] while the contours use [-3, 3];
# confirm whether the wider bounds are intentional.
plot_marginal_pair(ais_2_dist.sample_and_log_weights(1000)[0], ax=axs, bounds=[-30, 30])
axs.set_title("samples (ais) vs target contours")

Plot the distribution of the log weights for a relatively low number of AIS distributions versus a high number of AIS distributions.

In [None]:
# Overlay the log-weight histograms of the first and last entries of the sweep.
# The printed message and the slice share one constant so they cannot disagree
# (the message previously said 100 while 1000 values were plotted).
n_hist_samples = 1000

iter_n_low = 0
print(f"plotting log_w for {range_n_distributions[iter_n_low]} AIS distributions for first {n_hist_samples} samples")
log_w_low = log_weight_hist[iter_n_low][:n_hist_samples].numpy()
n, bins, patches = plt.hist(log_w_low, density=True, alpha=0.75, bins=40, color="green",
                            label=f"{range_n_distributions[iter_n_low]} ais dist")

iter_n_high = -1
print(f"plotting log_w for {range_n_distributions[iter_n_high]} AIS distributions for first {n_hist_samples} samples")
log_w_high = log_weight_hist[iter_n_high][:n_hist_samples].numpy()
n, bins, patches = plt.hist(log_w_high, density=True, alpha=0.75, bins=40, color="blue",
                            label=f"{range_n_distributions[iter_n_high]} ais dist")

plt.xscale("symlog") # use log x scale so we can see both on the same plot
plt.legend()
plt.xlabel("log_w")
plt.ylabel("density")
plt.show()

The same plot, but with very low log-weight values dropped so that the x-axis does not need a log scale.

In [None]:
# Keep only log weights above -10 for each of the two runs, then take at most the
# first 1000 of the survivors and convert them to numpy for plotting.
log_w_low, log_w_high = (
    run_log_w[run_log_w > -10][:1000].numpy()
    for run_log_w in (log_weight_hist[iter_n_low], log_weight_hist[iter_n_high])
)

# Draw the low-count run first and the high-count run second, as two overlaid histograms.
for hist_values, hist_colour, run_index in (
    (log_w_low, "green", iter_n_low),
    (log_w_high, "blue", iter_n_high),
):
    n, bins, patches = plt.hist(hist_values, density=True, alpha=0.75, bins=40,
                                color=hist_colour,
                                label=f"{range_n_distributions[run_index]} ais dist")

plt.legend()
plt.xlabel("log_w")
plt.ylabel("density")
plt.show()

# Visualise sample changes within a long chain
It is useful to run this section both with and without the trained model loaded.

In [None]:
def run(n_ais_dist, batch_size = 64):
    """Run one AIS chain and record ``target log prob - base log prob`` along it.

    A sampler is built with ``setup_ais(n_ais_dist)``, a batch is drawn from its base
    distribution, and the chain is stepped through every intermediate distribution.
    The log-probability difference of the current points is recorded once before any
    transition and once after each transition.

    Args:
        n_ais_dist: number of intermediate distributions for the sampler.
        batch_size: number of chains (samples drawn from the base distribution).

    Returns:
        A list of numpy arrays, one per recorded chain position.
    """
    ais = setup_ais(n_ais_dist)

    def log_ratio(points):
        # Target log prob minus base log prob at `points`, detached and converted to numpy.
        difference = ais.target_log_prob(points) - ais.base_distribution.log_prob(points)
        return np.asarray(difference.detach())

    # Initialise AIS with samples from the base distribution.
    x, log_prob_p0 = ais.base_distribution.sample_and_log_prob((batch_size,))
    x, log_prob_p0 = ais._remove_nan_and_infs(x, log_prob_p0, descriptor="chain init")
    target_minus_base_hist = [log_ratio(x)]

    log_w = ais.intermediate_unnormalised_log_prob(x, 1) - log_prob_p0
    # Move through sequence of intermediate distributions via MCMC.
    for step in range(1, ais.n_intermediate_distributions + 1):
        x, log_w = ais.perform_transition(x, log_w, step)
        target_minus_base_hist.append(log_ratio(x))

    return target_minus_base_hist

In [None]:
target_minus_base_hist = run(n_ais_dist=200)

In [None]:
# Plot the values recorded by run(): the list holds one array per chain position
# (the initial samples, then one entry per transition).
# presumably np.asarray stacks these to (n_positions, batch) so the x-axis is the chain position — TODO confirm
plt.plot(np.asarray(target_minus_base_hist), "ob", alpha=0.05)
plt.ylabel("log prob target - log prob base")
plt.show()