<a href="https://colab.research.google.com/github/lollcat/fab-torch/blob/master/experiments/many_well/fab_many_well.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flow Annealed Importance Sampling Bootstrap: Many Well Problem
In this notebook we compare training a flow using FAB with a prioritised replay buffer against training a flow by reverse KL divergence minimisation. We train the models only briefly, to get an indication of how well each method works in a small amount of time; better results may be obtained simply by increasing the training time. We train the flow on a 6-dimensional version of the Many Well problem. The problem can be made harder by increasing its dimension (which must be a multiple of two).

A GPU is not required for this notebook. Each experiment runs on my laptop (CPU only) in under 10 minutes. Reducing the number of AIS intermediate distributions to 1 makes it even faster (e.g. about 2 minutes to run).

# Setup Repo

In [None]:
# If using colab then run this cell.
# Clone the fab-torch repository, change into it, and pip-install the repo so
# that the `fab` package imported in the next cell is available.
!git clone https://github.com/lollcat/fab-torch
    
import os
# Later cells import `experiments.make_flow`; presumably this relies on the
# working directory being the repo root — confirm if running elsewhere.
os.chdir("fab-torch")

!pip install --upgrade .

# Let's go!

## Imports

In [None]:
import normflows as nf
import matplotlib.pyplot as plt
import torch
import numpy as np

from fab import FABModel, HamiltonianMonteCarlo, Metropolis
from fab.utils.logging import ListLogger
from fab.utils.plotting import plot_history, plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from fab.utils.prioritised_replay_buffer import PrioritisedReplayBuffer
from fab import Trainer, PrioritisedBufferTrainer
from fab.utils.plotting import plot_contours, plot_marginal_pair


from experiments.make_flow import make_wrapped_normflow_realnvp

## Setup Target distribution

In [None]:
# Dimension of the Many Well target. It can be increased to higher values, but
# it must be a multiple of two: the plotting helpers in this notebook show the
# samples as dim // 2 consecutive 2D pairs (dims 0-1, 2-3, ...).
dim = 6
if dim % 2 != 0:
    raise ValueError(f"dim must be a multiple of two, got {dim}.")
seed = 0  # random seed for this experiment

In [None]:
# Seed torch with the `seed` hyper-parameter so that changing it actually
# changes the run (it was previously ignored in favour of a literal 0).
torch.manual_seed(seed)
# Many Well target of dimension `dim`; a and b are the well-shape parameters
# passed straight through to ManyWellEnergy.
target = ManyWellEnergy(dim, a=-0.5, b=-6, use_gpu=True)

In [None]:
# Contour plot of the target's 2D marginal over its first two dims (i.e. the
# Double Well problem). The target is moved to the CPU for contour evaluation
# and moved back to the GPU afterwards when one is available.
target.to("cpu")
plotting_bounds = (-3, 3)
fig, ax = plt.subplots()
plot_contours(
    target.log_prob_2D,
    bounds=plotting_bounds,
    n_contour_levels=40,
    ax=ax,
    grid_width_n_points=100,
)
gpu_available = torch.cuda.is_available()
if gpu_available:
    target.to("cuda")

## Create FAB model

In [None]:
# hyper-parameters

# Flow
n_flow_layers = 10
layer_nodes_per_dim = 40
lr = 2e-4
max_gradient_norm = 100.0
batch_size = 128
n_iterations = 500
n_eval = 10
eval_batch_size = batch_size * 10
n_plots = 10  # number of plots shown throughout training
use_64_bit = True
alpha = 2.0

# AIS
transition_operator_type = "hmc"  # "hmc" or "metropolis"
n_intermediate_distributions = 4
# Step size used (as both min and max step size) when
# transition_operator_type == "metropolis"; without it that branch raises NameError.
# NOTE(review): 1.0 is an untuned default — adjust if using Metropolis.
metropolis_step_size = 1.0

# buffer config
n_batches_buffer_sampling = 4
maximum_buffer_length = batch_size * n_batches_buffer_sampling * 100
min_buffer_length = batch_size * n_batches_buffer_sampling * 10

# Use p^\alpha q^{1-\alpha} as the AIS target (p^2/q for alpha = 2).
min_is_target = True
p_target = not min_is_target  # Whether to use p itself as the AIS target.

In [None]:
# Optionally switch torch's default dtype, and the target, to double precision.
if use_64_bit:
    torch.set_default_dtype(torch.float64)
    target = target.double()
    print("running with 64 bit")

### Setup flow

In [None]:
# RealNVP flow over `dim` dimensions, without act-norm layers.
flow = make_wrapped_normflow_realnvp(
    dim,
    n_flow_layers=n_flow_layers,
    layer_nodes_per_dim=layer_nodes_per_dim,
    act_norm=False,
)

### Setup Transition operator

In [None]:
# Build the AIS transition operator selected by `transition_operator_type`.
# Both options take the flow's log prob as the AIS base and the target's log
# prob, with `alpha` / `p_target` configuring the AIS target (see the
# hyper-parameter cell).
if transition_operator_type == "hmc":
    # Very lightweight HMC: n_outer=1 and L=5 (presumably 5 leapfrog steps).
    transition_operator = HamiltonianMonteCarlo(
        n_ais_intermediate_distributions=n_intermediate_distributions,
        dim=dim,
        base_log_prob=flow.log_prob,
        target_log_prob=target.log_prob,
        alpha=alpha,
        p_target=p_target,
        n_outer=1,
        L=5,
    )
elif transition_operator_type == "metropolis":
    # min and max step size are equal and adjust_step_size=False, so the step
    # size is intended to stay fixed at `metropolis_step_size`.
    transition_operator = Metropolis(
        n_ais_intermediate_distributions=n_intermediate_distributions,
        dim=dim,
        base_log_prob=flow.log_prob,
        target_log_prob=target.log_prob,
        alpha=alpha,
        p_target=p_target,
        n_updates=1,
        adjust_step_size=False,
        max_step_size=metropolis_step_size,  # the same for all metropolis steps
        min_step_size=metropolis_step_size,
        eval_mode=False,
    )
else:
    raise NotImplementedError(
        f"Unknown transition_operator_type {transition_operator_type!r}; "
        "expected 'hmc' or 'metropolis'."
    )

### Setup FAB model with prioritised replay buffer

In [None]:
# Move the flow and the transition operator onto the GPU when one exists.
gpu_available = torch.cuda.is_available()
if gpu_available:
    flow.cuda()
    transition_operator.cuda()
    print("Running with GPU")

In [None]:
# Combine flow, target and AIS transition operator into a FAB model, then set
# up an Adam optimiser over the flow parameters and a training-history logger.
fab_model = FABModel(
    flow=flow,
    target_distribution=target,
    n_intermediate_distributions=n_intermediate_distributions,
    transition_operator=transition_operator,
    alpha=alpha,
)
optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
# Records the training history (read later via logger.history); save=False
# presumably means it is not written to disk.
logger = ListLogger(save=False)

In [None]:
# Setup buffer.
def initial_sampler():
    """Draw one batch of AIS samples from the current model to fill the buffer.

    Returns:
        Tuple ``(x, log_w, log_q)``: sample positions, AIS log weights and the
        flow log-density recorded on each sample point.
    """
    sampler = fab_model.annealed_importance_sampler
    point, log_w = sampler.sample_and_log_weights(batch_size, logging=False)
    return point.x, log_w, point.log_q


buffer = PrioritisedReplayBuffer(
    dim=dim,
    max_length=maximum_buffer_length,
    min_sample_length=min_buffer_length,
    initial_sampler=initial_sampler,
)

In [None]:
def plot(fab_model, n_samples=batch_size, dim=dim):
    """Plot flow samples and AIS samples over the target contours.

    There is one row per consecutive pair of dimensions (0-1, 2-3, ...). The
    left column shows samples from the flow and the right column shows samples
    from AIS.

    Args:
        fab_model: model exposing ``flow.sample`` and
            ``annealed_importance_sampler.sample_and_log_weights``.
        n_samples: number of samples drawn from each of the flow and AIS.
        dim: dimensionality of the problem; ``dim // 2`` rows are plotted.

    Returns:
        A one-element list containing the matplotlib figure.
    """
    n_rows = dim // 2
    # squeeze=False keeps ``axs`` 2D even when n_rows == 1 (dim == 2), so the
    # ``axs[i, col]`` indexing below works for every even dim.
    fig, axs = plt.subplots(n_rows, 2, sharex=True, sharey=True,
                            figsize=(10, n_rows * 3), squeeze=False)

    samples_flow = fab_model.flow.sample((n_samples,))
    samples_ais = fab_model.annealed_importance_sampler.sample_and_log_weights(
        n_samples, logging=False)[0].x

    for i in range(n_rows):
        pair = (i * 2, i * 2 + 1)
        # Column 0: flow samples, column 1: AIS samples.
        for col, samples in enumerate((samples_flow, samples_ais)):
            ax = axs[i, col]
            plot_contours(target.log_prob_2D, bounds=plotting_bounds, ax=ax,
                          n_contour_levels=40)
            plot_marginal_pair(samples, ax=ax, bounds=plotting_bounds,
                               marginal_dims=pair)
            ax.set_xlabel(f"dim {pair[0]}")
            ax.set_ylabel(f"dim {pair[1]}")
    axs[0, 1].set_title("ais samples")
    axs[0, 0].set_title("flow samples")
    # Lay out once, after all labels and titles have been set.
    plt.tight_layout()
    plt.show()
    return [fig]

In [None]:
plot(fab_model=fab_model)  # Visualise the model at initialisation, before any training.

In [None]:
# Trainer for the FAB model using the prioritised replay buffer.
# NOTE(review): w_adjust_max_clip=None presumably disables clipping of the
# importance-weight adjustment — confirm in PrioritisedBufferTrainer.
trainer = PrioritisedBufferTrainer(
    model=fab_model,
    optimizer=optimizer,
    logger=logger,
    plot=plot,
    buffer=buffer,
    n_batches_buffer_sampling=n_batches_buffer_sampling,
    max_gradient_norm=max_gradient_norm,
    alpha=alpha,
    w_adjust_max_clip=None,
)

## Train model

Early in training it is quite common to see NaN losses. This is because HMC/AIS sometimes finds regions far outside the typical target range (e.g. a position of 100 in some dimension). Training stabilizes as it progresses, so this is not an issue later in training. I have made a PR to the code that lets the user define a filter criterion which automatically removes samples outside a reasonable bound; this should improve training stability further.

In [None]:
# Now run!
# Note: during training the progress bar reports the ESS w.r.t. p^2/q.
trainer.run(
    n_iterations=n_iterations,
    batch_size=batch_size,
    n_plot=n_plots,
    n_eval=n_eval,
    eval_batch_size=eval_batch_size,
    save=False,
)

In the plots below (samples from the flow against the target contours, and the test-set log prob throughout training) we see that the flow covers the target distribution quite well. It may be trained further to obtain even better results.

In [None]:
# Keys of the recorded training history. "_eval" means metrics calculated with eval_batch_size, "_p_target" means metrics calculated with AIS targeting p, and "p2overq_target" means calculated with AIS targeting p^2/q.
# For example 'eval_ess_flow_p2overq_target' is the effective sample size of the flow w.r.t. the target distribution p^2/q when sampling from AIS with p^2/q as the target.
logger.history.keys() 

In [None]:
# Mean log prob the flow assigns to a test set of samples from the target.
# The x positions assume the n_eval evaluations are evenly spaced over training.
eval_iters = np.linspace(0, n_iterations, n_eval)
plt.plot(
    eval_iters,
    logger.history["flow_test_set_exact_mean_log_prob_p_target"],
)
plt.ylabel("mean test set log prob")
plt.xlabel("training iteration")

In [None]:
# Effective sample size of the flow (w.r.t. p) at each evaluation point.
eval_iters = np.linspace(0, n_iterations, n_eval)
plt.plot(
    eval_iters,
    logger.history["eval_ess_flow_p_target"],
    label="flow",
)
plt.ylabel("Effective Sample Size")
plt.xlabel("training iteration")

In [None]:
# Mean log prob of the flow on a test set containing a point on each mode.
eval_iters = np.linspace(0, n_iterations, n_eval)
plt.plot(
    eval_iters,
    logger.history["flow_test_set_modes_mean_log_prob_p_target"],
)
plt.ylabel("Average log prob of modes test set")
plt.xlabel("training iteration")

In [None]:
# AIS (targeting p for evaluation) can further improve the accuracy of the
# normalisation-constant estimate; compare its MSE against the flow's.
history = logger.history
plt.plot(eval_iters, history["flow_MSE_log_Z_estimate_p_target"], label="flow")
plt.plot(eval_iters, history["ais_MSE_log_Z_estimate_p_target"], label="ais")
plt.ylabel("MSE in estimation of the normalizing constant")
plt.legend()
plt.xlabel("training iteration")

In [None]:
# Flow samples over the target's 2D marginal contours (plot_marginal_pair's
# default marginal dims — presumably dims 0 and 1).
fig, axs = plt.subplots(1, 1, figsize=(5, 5))
# Contours are evaluated with the target on the CPU, then it is moved back to
# the GPU when one is available.
target.to("cpu")
plot_contours(
    target.log_prob_2D,
    bounds=plotting_bounds,
    ax=axs,
    n_contour_levels=40,
    grid_width_n_points=200,
)
if torch.cuda.is_available():
    target.to("cuda")

n_samples = 1000
samples_flow = fab_model.flow.sample((n_samples,)).detach()
plot_marginal_pair(samples_flow, ax=axs, bounds=plotting_bounds)

# Training a flow by reverse KL divergence minimisation.

In [None]:
loss_type = "flow_reverse_kl" # can set to "target_forward_kl" for training by maximum likelihood of samples from the Many Well target (NOTE(review): check the exact name against FABModel's supported loss types).

In [None]:
# Fresh flow with the same architecture as the FAB flow, plus its own
# optimiser and a new training-history logger.
flow = make_wrapped_normflow_realnvp(
    dim,
    n_flow_layers=n_flow_layers,
    layer_nodes_per_dim=layer_nodes_per_dim,
    act_norm=False,
)
optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
logger = ListLogger(save=False)  # new, empty training history for this run

In [None]:
# Move the new flow, the transition operator and the target onto the GPU when
# one exists.
# NOTE(review): `transition_operator` is the one built for the FAB flow; its
# base_log_prob is still bound to the previous flow's log_prob.
gpu_available = torch.cuda.is_available()
if gpu_available:
    flow.cuda()
    transition_operator.cuda()
    print("Running with GPU")
    target.to("cuda")

In [None]:
n_iterations = int(n_iterations * 4)  # Training the flow by KL minimisation is cheaper per iteration, so we run it for 4x more iterations. NOTE: this overwrites the global, so re-running this cell multiplies it again.

In [None]:
# FAB model wrapper around the new flow; `loss_type` selects the training
# objective ("flow_reverse_kl" here).
# NOTE(review): `transition_operator` is reused from the FAB section and was
# built with base_log_prob bound to the previous flow, so any AIS this model
# runs (e.g. during evaluation) would use the old flow in its intermediate
# distributions — confirm whether a fresh operator should be built here.
reverse_kld_model = FABModel(
    flow=flow,
    target_distribution=target,
    n_intermediate_distributions=n_intermediate_distributions,
    transition_operator=transition_operator,
    loss_type=loss_type,
)

In [None]:
def plot_flow_reverse_kld(fab_model, n_samples=batch_size, dim=dim):
    """Plot flow samples over the target contours, one row per dimension pair.

    Rows correspond to consecutive pairs of dimensions (0-1, 2-3, ...).

    Args:
        fab_model: model exposing ``flow.sample``.
        n_samples: number of samples drawn from the flow.
        dim: dimensionality of the problem; ``dim // 2`` rows are plotted.

    Returns:
        A one-element list containing the matplotlib figure.
    """
    n_rows = dim // 2
    # squeeze=False keeps ``axs`` 2D; with a single column plt.subplots would
    # otherwise return a 1D array (or a bare Axes when dim == 2), breaking the
    # indexing below.
    fig, axs = plt.subplots(n_rows, 1, sharex=True, sharey=True,
                            figsize=(5, n_rows * 3), squeeze=False)

    samples_flow = fab_model.flow.sample((n_samples,))

    for i in range(n_rows):
        ax = axs[i, 0]
        plot_contours(target.log_prob_2D, bounds=plotting_bounds, ax=ax,
                      n_contour_levels=40)

        # plot flow samples
        plot_marginal_pair(samples_flow, ax=ax, bounds=plotting_bounds,
                           marginal_dims=(i * 2, i * 2 + 1))
        ax.set_xlabel(f"dim {i*2}")
        ax.set_ylabel(f"dim {i*2 + 1}")
    # Lay out once, after all labels have been set.
    plt.tight_layout()
    plt.show()
    return [fig]

In [None]:
trainer = Trainer(
    model=reverse_kld_model,
    optimizer=optimizer,
    logger=logger,
    plot=plot_flow_reverse_kld,
    max_gradient_norm=max_gradient_norm,
)  # standard trainer for the reverse-KL model (no replay buffer is passed)

In [None]:
# Now run! (n_iterations was scaled up for this cheaper objective above.)
trainer.run(
    n_iterations=n_iterations,
    batch_size=batch_size,
    n_plot=n_plots,
    n_eval=n_eval,
    eval_batch_size=eval_batch_size,
    save=False,
)

We evaluate the flow on samples from the target distribution. Because the flow trained by reverse KL divergence minimisation misses modes, its performance is very poor.

In [None]:
# Mean log prob of the reverse-KL flow on a test set of samples from the target.
# The x values are training iterations (evaluations assumed evenly spaced), so
# the axis is labelled accordingly, matching the FAB plots earlier.
eval_iters = np.linspace(0, n_iterations, n_eval)
plt.plot(eval_iters, logger.history["flow_test_set_exact_mean_log_prob"])
plt.ylabel("mean test set log prob")
plt.xlabel("training iteration")

In [None]:
# Mean log prob of the reverse-KL flow on the test set containing a point on
# each mode. The x values are training iterations (evaluations assumed evenly
# spaced), so the axis is labelled accordingly, matching the FAB plots earlier.
eval_iters = np.linspace(0, n_iterations, n_eval)
plt.plot(eval_iters, logger.history["flow_test_set_modes_mean_log_prob"])
plt.ylabel("mean mode-test set log prob")
plt.xlabel("training iteration")