In [None]:
# Prepend the parent directory to the module search path. Presumably this is the
# repository root, so that the `fab` and `examples` packages imported in the
# next cell resolve without installing the project — TODO confirm.
import sys
sys.path.insert(0, "../")

In [None]:
import normflow as nf
import matplotlib.pyplot as plt
import torch

from fab import FABModel, HamiltoneanMonteCarlo, Metropolis
from fab.train_with_prioritised_buffer import PrioritisedBufferTrainer
from fab.utils.logging import ListLogger
from fab.utils.plotting import plot_history, plot_contours, plot_marginal_pair
from fab.utils.prioritised_replay_buffer import PrioritisedReplayBuffer
from examples.make_flow import make_wrapped_normflowdist

## Configure Training

In [None]:
# Training objective handed to FABModel: "forward_kl" or "alpha_2_div".
loss_type: str = "alpha_2_div"

# Problem size and model capacity.
dim: int = 2
n_intermediate_distributions: int = 2
layer_nodes_per_dim: int = 5
n_flow_layers: int = 10

# Optimisation settings.
batch_size: int = 32
n_iterations: int = 500
lr: float = 2e-4

# Evaluation / plotting settings forwarded to `trainer.run`.
n_eval: int = 20
eval_batch_size: int = 10 * batch_size
n_plots: int = 5  # number of plots shown throughout training

# Which transition operator to build: "metropolis" or "hmc".
transition_operator_type: str = "hmc"

# Reproducibility.
seed: int = 0
# torch.set_default_dtype(torch.float64) # works with 32 bit precision
torch.manual_seed(seed)

# Forwarded as `clip_ais_weights_frac` to BufferTrainer further down.
log_w_clip_frac: float = 0.0

In [None]:
# Replay-buffer sizing, expressed as multiples of the training batch size.
n_batches_buffer_sampling = 5
# Passed as `max_length` when the buffers are constructed.
maximum_buffer_length = 100 * batch_size
# Passed as `min_sample_length`: 5 x n_batches_buffer_sampling x batch_size samples.
min_buffer_length = 5 * n_batches_buffer_sampling * batch_size

## Setup Double Well target distribution

In [None]:
from fab.target_distributions.many_well import ManyWellEnergy
# Guard: only even dimensionalities are accepted here (presumably ManyWellEnergy
# groups dimensions in pairs — TODO confirm against its implementation).
assert dim % 2 == 0
# Target distribution; `a` and `b` are passed straight through to ManyWellEnergy.
target = ManyWellEnergy(dim, a=-0.5, b=-6)
# Bounds shared by the plot_contours / plot_marginal_pair calls in this notebook.
plotting_bounds = (-3, 3)

In [None]:
# Plot the target: hand its `log_prob` to plot_contours with bounds=plotting_bounds.
plot_contours(target.log_prob, bounds=plotting_bounds)

## Setup Flow
The flow is built by wrapping the [normflow library](https://github.com/VincentStimper/normalizing-flows).

In [None]:
# Construct the wrapped normflow distribution from the sizes set in the config cell.
flow = make_wrapped_normflowdist(
    dim,
    n_flow_layers=n_flow_layers,
    layer_nodes_per_dim=layer_nodes_per_dim,
    act_norm=True,
)

## Setup transition operator

In [None]:
# Build the transition operator selected by `transition_operator_type`.
if transition_operator_type == "hmc":
    # very lightweight HMC.
    transition_operator = HamiltoneanMonteCarlo(
        n_ais_intermediate_distributions=n_intermediate_distributions,
        n_outer=1,
        epsilon=1.0,
        L=2,
        dim=dim,
        step_tuning_method="p_accept",
    )
elif transition_operator_type == "metropolis":
    transition_operator = Metropolis(
        n_transitions=n_intermediate_distributions,
        n_updates=5,
        adjust_step_size=True,
    )
else:
    # Report the offending value so a typo in the config cell is obvious.
    raise NotImplementedError(
        f"Unknown transition_operator_type {transition_operator_type!r}; "
        'expected "hmc" or "metropolis".'
    )

## Define model, trainer and plotter

In [None]:
fab_model = FABModel(
    flow=flow,
    target_distribution=target,
    n_intermediate_distributions=n_intermediate_distributions,
    transition_operator=transition_operator,
    loss_type=loss_type,
)
# Only the flow's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
# No learning-rate schedule is used; to try one, swap in e.g.
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
scheduler = None
# Collects the training history that later cells read from `logger.history`.
logger = ListLogger(save=False)

In [None]:
def plot(fab_model, n_samples=300):
    """Show flow samples and AIS samples side by side over the target contours.

    Reads the notebook-level `target` and `plotting_bounds`.

    Args:
        fab_model: model exposing `flow.sample` and
            `annealed_importance_sampler.sample_and_log_weights`.
        n_samples: number of samples drawn for each of the two panels.
    """
    _, (ax_flow, ax_ais) = plt.subplots(1, 2)

    # The same target contours are drawn on both panels.
    for panel in (ax_flow, ax_ais):
        plot_contours(target.log_prob, bounds=plotting_bounds, ax=panel)

    # Left panel: samples drawn directly from the flow.
    flow_draws = fab_model.flow.sample((n_samples,))
    plot_marginal_pair(flow_draws, ax=ax_flow, bounds=plotting_bounds)

    # Right panel: AIS samples; only the first returned element is plotted.
    ais_output = fab_model.annealed_importance_sampler.sample_and_log_weights(
        n_samples, logging=False)
    ais_draws = ais_output[0]
    plot_marginal_pair(ais_draws, ax=ax_ais, bounds=plotting_bounds)

    ax_flow.set_title("flow samples")
    ax_ais.set_title("ais samples")
    plt.show()

In [None]:
# buffer
def initial_sampler():
    """Draw one batch from the model's annealed importance sampler.

    Passed to PrioritisedReplayBuffer as its `initial_sampler`.

    Returns:
        Tuple `(x, log_w, log_q_x)`: the sampled points, their log weights, and
        the flow's log-probability of those points.
    """
    samples, log_weights = fab_model.annealed_importance_sampler.sample_and_log_weights(
        batch_size, logging=False)
    flow_log_prob = fab_model.flow.log_prob(samples)
    return samples, log_weights, flow_log_prob


buffer = PrioritisedReplayBuffer(
    dim=dim,
    max_length=maximum_buffer_length,
    min_sample_length=min_buffer_length,
    initial_sampler=initial_sampler,
)

In [None]:
# Wire model, optimizer, logger, plotting callback and buffer into the trainer.
trainer = PrioritisedBufferTrainer(
    model=fab_model,
    optimizer=optimizer,
    logger=logger,
    plot=plot,
    optim_schedular=scheduler,
    buffer=buffer,
    n_batches_buffer_sampling=n_batches_buffer_sampling,
)

In [None]:
# Baseline figure, produced before `trainer.run` is called below.
plot(fab_model)

## Run with visualisation

In [None]:
# Train with the prioritised buffer; nothing is saved (save=False).
trainer.run(
    n_iterations=n_iterations,
    batch_size=batch_size,
    n_plot=n_plots,
    n_eval=n_eval,
    eval_batch_size=eval_batch_size,
    save=False,
)

In the plot below:

- `ess` = effective sample size.
- "Distance" refers to the distance moved during each intermediate transition.

In [None]:
# Plot the history held by `logger`, the logger that was passed to the trainer above.
plot_history(logger.history)

In [None]:
# Two AIS effective-sample-size series from the run above, labelled by their
# history keys so the curves can be told apart.
plt.plot(logger.history["eval_ess_ais_p_target"], label="eval_ess_ais_p_target")
plt.plot(logger.history["eval_ess_ais_p2overq_target"], label="eval_ess_ais_p2overq_target")
plt.ylabel("effective sample size")
plt.legend()
plt.show()

In [None]:
# Logged `w_adjust_max` values, with the y-axis limited to [0, 10].
plt.plot(logger.history["w_adjust_max"])
plt.ylim([0, 10.0])
plt.title("w_adjust_max")  # name the plotted quantity on the figure itself
plt.show()

In [None]:
# Logged `w_adjust_mean` values, with the y-axis limited to [0, 2].
plt.plot(logger.history["w_adjust_mean"])
plt.ylim([0, 2.0])
plt.title("w_adjust_mean")  # name the plotted quantity on the figure itself
plt.show()

In [None]:
# Keep a handle on the prioritised-buffer run's history: `logger` is rebound to a
# fresh ListLogger in the next section, and the final comparison plots read
# this variable alongside the new logger's history.
history_with_priority = logger.history

## With normal buffer and current fab loss

In [None]:
from fab.utils.replay_buffer import ReplayBuffer
from fab import BufferTrainer

In [None]:
# Re-assigns the value already set in the configuration cell; the FABModel built
# in the next cell reads `loss_type` from here.
loss_type = "alpha_2_div" # "forward_kl" or "alpha_2_div"

In [None]:
# setup: build a fresh flow, transition operator, model, optimizer and logger
# for the second experiment (plain replay buffer).
flow = make_wrapped_normflowdist(
    dim,
    n_flow_layers=n_flow_layers,
    layer_nodes_per_dim=layer_nodes_per_dim,
    act_norm=True,
)

if transition_operator_type == "hmc":
    # very lightweight HMC.
    transition_operator = HamiltoneanMonteCarlo(
        n_ais_intermediate_distributions=n_intermediate_distributions,
        n_outer=1,
        epsilon=1.0,
        L=2,
        dim=dim,
        step_tuning_method="p_accept",
    )
elif transition_operator_type == "metropolis":
    transition_operator = Metropolis(
        n_transitions=n_intermediate_distributions,
        n_updates=5,
        adjust_step_size=True,
    )
else:
    # Report the offending value so a typo in the config cell is obvious.
    raise NotImplementedError(
        f"Unknown transition_operator_type {transition_operator_type!r}; "
        'expected "hmc" or "metropolis".'
    )

fab_model = FABModel(
    flow=flow,
    target_distribution=target,
    n_intermediate_distributions=n_intermediate_distributions,
    transition_operator=transition_operator,
    loss_type=loss_type,
)
# Only the new flow's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.995)
scheduler = None
# New logger: its `history` holds this second run only.
logger = ListLogger(save=False)


# buffer
def initial_sampler():
    """Draw one batch from the model's annealed importance sampler.

    Passed to ReplayBuffer as its `initial_sampler`. This definition replaces
    any earlier `initial_sampler` in the notebook namespace and returns two
    values only.

    Returns:
        Tuple `(x, log_w)`: the sampled points and their log weights.
    """
    samples, log_weights = fab_model.annealed_importance_sampler.sample_and_log_weights(
        batch_size, logging=False)
    return samples, log_weights


buffer = ReplayBuffer(
    dim=dim,
    max_length=maximum_buffer_length,
    min_sample_length=min_buffer_length,
    initial_sampler=initial_sampler,
)


# Wire model, optimizer, logger, plotting callback and buffer into the trainer;
# `log_w_clip_frac` from the config cell is passed as `clip_ais_weights_frac`.
trainer = BufferTrainer(
    model=fab_model,
    optimizer=optimizer,
    logger=logger,
    plot=plot,
    optim_schedular=scheduler,
    buffer=buffer,
    n_batches_buffer_sampling=n_batches_buffer_sampling,
    clip_ais_weights_frac=log_w_clip_frac,
)

In [None]:
# run: same schedule and evaluation settings as the prioritised-buffer run
trainer.run(
    n_iterations=n_iterations,
    batch_size=batch_size,
    n_plot=n_plots,
    n_eval=n_eval,
    eval_batch_size=eval_batch_size,
    save=False,
)

In [None]:
# Plot the history held by the current `logger` (the one given to BufferTrainer).
plot_history(logger.history)

In [None]:
# flow vs ais: evaluation ESS series from the current logger
plt.plot(logger.history["eval_ess_flow"], label="flow")
plt.plot(logger.history["eval_ess_ais"], label="ais")
plt.title("flow vs ais")
plt.ylabel("effective sample size")
plt.legend()
plt.show()

In [None]:
# p^2/q vs fab, after ais: prioritised-buffer run (saved history) against the
# current run's AIS evaluation ESS
plt.plot(history_with_priority["eval_ess_ais_p_target"], label="p^2/q")
plt.plot(logger.history["eval_ess_ais"], label="current fab")
plt.title("p^2/q vs fab, after ais")
plt.ylabel("effective sample size")
plt.legend()
plt.show()

In [None]:
# p^2/q vs fab, for flow: prioritised-buffer run (saved history) against the
# current run's flow evaluation ESS
plt.plot(history_with_priority["eval_ess_flow_p_target"], label="p^2/q")
plt.plot(logger.history["eval_ess_flow"], label="current fab")
plt.title("p^2/q vs fab, for flow")
plt.ylabel("effective sample size")
plt.legend()
plt.show()