# Setup Repository (if running in Colab, otherwise skip)
This notebook requires quite a bit of compute and time to run. We recommend running it on
Colab with a GPU.

In [None]:
# Clone the FAB-TORCH repository into the current working directory (Colab only).
!git clone https://github.com/lollcat/FAB-TORCH/

In [None]:
import os

# Move into the cloned repository directory.
# Guarded so that re-running this cell while already inside the repository
# does not raise FileNotFoundError.
if os.path.basename(os.getcwd()) != "FAB-TORCH":
    os.chdir("FAB-TORCH")


In [None]:
!pip install --upgrade . # install FAB library

# Non-Colab
If running locally (not in Colab), simply run the cell below to get set up.

In [None]:
import sys

# Prepend the parent directory to the module search path
# (presumably the repository root when running locally -- confirm).
sys.path[:0] = ["../"]

# Let's go

## Setup imports

In [None]:
import normflow as nf
import matplotlib.pyplot as plt
import torch

from fab import FABModel, HamiltoneanMonteCarlo, Trainer, Metropolis
from fab.utils.logging import ListLogger
from fab.utils.plotting import plot_history, plot_contours, plot_marginal_pair
from examples.make_flow import make_wrapped_normflowdist

## Configure Training

In [None]:
# Problem size and flow architecture
dim: int = 16
n_flow_layers: int = 10
layer_nodes_per_dim: int = 10

# Annealed importance sampling
n_intermediate_distributions: int = 2
transition_operator_type: str = "hmc"  # "metropolis" or "hmc"

# Optimisation
batch_size: int = 1024
n_iterations: int = 40_000
lr: float = 5e-4

# Evaluation and plotting during training
n_eval: int = 100
eval_batch_size: int = 10 * batch_size
n_plots: int = 20  # number of plots shown throughout training

# Reproducibility
seed: int = 0
# Optional: torch.set_default_dtype(torch.float64) -- 32 bit also works
torch.manual_seed(seed)

## Setup ManyWell target distribution

In [None]:
from fab.target_distributions.many_well import ManyWellEnergy

# The plotting code in this notebook walks over the dimensions two at a time,
# so an odd `dim` is rejected up front with an explanatory message.
assert dim % 2 == 0, f"dim must be even, got {dim}"
target = ManyWellEnergy(dim, a=-0.5, b=-6)
# Bounds handed to the plotting helpers (presumably lower/upper axis limits -- confirm)
plotting_bounds = (-3, 3)

In [None]:
# Plot contours of the target's 2D function (target.log_prob_2D) using plotting_bounds
plot_contours(target.log_prob_2D, bounds=plotting_bounds)

In [None]:
# Build the flow with the project helper, passing through `dim`, the number of flow
# layers and the per-dimension layer width from the configuration cell.
flow = make_wrapped_normflowdist(dim, n_flow_layers=n_flow_layers, layer_nodes_per_dim=layer_nodes_per_dim)

In [None]:
# Sanity check: draw 3 samples and display their shape
flow.sample((3,)).shape # the shape should reflect the configured dim

## Setup transition operator

In [None]:
# Build the transition operator selected by `transition_operator_type`.
if transition_operator_type == "hmc":
    # very lightweight HMC.
    transition_operator = HamiltoneanMonteCarlo(
        n_ais_intermediate_distributions=n_intermediate_distributions,
        n_outer=1,
        epsilon=1.0, L=5, dim=dim,
        step_tuning_method="p_accept")
elif transition_operator_type == "metropolis":
    transition_operator = Metropolis(n_transitions=n_intermediate_distributions,
                                     n_updates=5, adjust_step_size=True)
else:
    # Name the offending value so that a configuration typo is obvious from the traceback.
    raise NotImplementedError(
        f"Unknown transition_operator_type {transition_operator_type!r}; "
        "expected 'hmc' or 'metropolis'.")

## Define model, trainer and plotter

In [None]:
# Use the GPU if available: .cuda() is called on both the flow and the transition operator.
if torch.cuda.is_available():
  flow.cuda()
  transition_operator.cuda()
  print("utilising GPU")
# Display the device on which a flow sample is produced, as a check of the above
flow.sample((1,)).device

In [None]:
# Combine the flow, the target and the transition operator into the FAB model
fab_model = FABModel(flow=flow,
                     target_distribution=target,
                     n_intermediate_distributions=n_intermediate_distributions,
                     transition_operator=transition_operator)
# Only the flow's parameters are handed to the optimizer
optimizer = torch.optim.AdamW(flow.parameters(), lr=lr)
# No learning-rate schedule is used; an exponential decay is left here as an option:
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
scheduler = None
# The training history is read back later through logger.history; save=False
# presumably disables writing it to disk (confirm against ListLogger)
logger = ListLogger(save=False)

In [None]:
def plot(fab_model, n_samples = batch_size, dim=dim):
    """Plot flow samples and AIS samples over the target contours, one pair of dims per row.

    Draws a grid with one row per consecutive pair of dimensions (0-1, 2-3, ...).
    The left column shows samples from ``fab_model.flow`` and the right column
    shows samples from ``fab_model.annealed_importance_sampler``. Both are overlaid
    on contours of ``target.log_prob_2D``, and the figure is displayed with
    ``plt.show()``.

    Args:
        fab_model: model exposing ``flow.sample`` and
            ``annealed_importance_sampler.sample_and_log_weights``.
        n_samples: number of samples requested from each sampler. Defaults to the
            value ``batch_size`` had when this function was defined.
        dim: number of dimensions to plot; ``dim // 2`` rows are drawn. Defaults to
            the value ``dim`` had when this function was defined.

    Relies on the notebook-level names ``target`` and ``plotting_bounds``.
    """
    n_rows = dim // 2
    # squeeze=False keeps `axs` two-dimensional even when n_rows == 1 (dim == 2),
    # so the axs[row, col] indexing below is always valid.
    fig, axs = plt.subplots(n_rows, 2, sharex=True, sharey=True,
                            figsize=(10, n_rows * 3), squeeze=False)

    samples_flow = fab_model.flow.sample((n_samples,))
    # sample_and_log_weights returns a sequence; only its first element is plotted here.
    samples_ais = fab_model.annealed_importance_sampler.sample_and_log_weights(
        n_samples, logging=False)[0]

    for i in range(n_rows):
        dim_x, dim_y = i * 2, i * 2 + 1
        # column 0: flow samples, column 1: ais samples
        for col, samples in enumerate((samples_flow, samples_ais)):
            ax = axs[i, col]
            plot_contours(target.log_prob_2D, bounds=plotting_bounds, ax=ax)
            plot_marginal_pair(samples, ax=ax, bounds=plotting_bounds,
                               marginal_dims=(dim_x, dim_y))
            ax.set_xlabel(f"dim {dim_x}")
            ax.set_ylabel(f"dim {dim_y}")

    axs[0, 1].set_title("ais samples")
    axs[0, 0].set_title("flow samples")
    # Lay out once, after all content (including the titles) has been added,
    # instead of once per row.
    plt.tight_layout()
    plt.show()

In [None]:
# Create trainer
# The notebook's `plot` function is passed as the plotting callback, and `scheduler`
# is forwarded under the keyword spelling used by this call (`optim_schedular`).
trainer = Trainer(model=fab_model, optimizer=optimizer, logger=logger, plot=plot,
                  optim_schedular=scheduler)

# Run with visualisation

Note: The NaNs that pop up during training occur because the flow produces some extreme samples that give NaN under the target. This does not harm long-term training, and I will simplify the error message to make it prettier.

In [None]:
# Run training with plotting and evaluation arguments supplied; save=False is passed.
# n_plot / n_eval presumably set how often plotting and evaluation happen during
# the run (confirm against Trainer.run).
trainer.run(n_iterations=n_iterations, batch_size=batch_size, n_plot=n_plots,
            n_eval=n_eval, eval_batch_size=eval_batch_size, save=False)

In [None]:
# forgot to drop NaN IS weights in the eval step, which makes the eval_ess_ais graph not work.
# This is fixed in more recent versions of the code so should look fine when this function is re-run.
plot_history(logger.history)

In [None]:
# Call trainer.run again for another n_iterations, this time without the evaluation
# arguments (presumably continuing from the current model state -- confirm).
trainer.run(n_iterations=n_iterations, batch_size=batch_size, n_plot=n_plots)

In [None]:
# Plot the history accumulated in the logger so far
plot_history(logger.history)

## Visualise Trained Flow Model

In [None]:
# Plot flow samples and AIS samples from the trained model against the target contours
plot(fab_model)