In [None]:
# If running locally
import sys
sys.path.insert(0, "../")

In [None]:
# If using Colab
!pip install git+https://github.com/lollcat/FAB-JAX.git

In [None]:
from fab.learnt_distributions.real_nvp import make_realnvp_dist_funcs
from fab.target_distributions.many_well import ManyWellEnergy
from fab.agent.fab_agent import AgentFAB
from fab.utils.plotting import plot_history, plot_marginal_pair, plot_contours_2D
import matplotlib.pyplot as plt
import optax
import jax
import jax.numpy as jnp

In [None]:
# Dimensionality of the Many Well target (and therefore of the flow).
dim = 8

# RealNVP flow architecture, consumed by make_realnvp_dist_funcs below.
flow_num_layers = 10
mlp_hidden_size_per_x_dim = 5

# Normalisation switches forwarded to the flow constructor.
act_norm = True
layer_norm = False

In [None]:
# Training loop hyperparameters.
batch_size = 128
n_iter = int(2e4)
n_plots = 10
lr = 5e-4
max_grad_norm = 1.0

# Annealed importance sampling (AIS) settings handed to the agent.
n_intermediate_distributions: int = 2
AIS_kwargs = {"additional_transition_operator_kwargs": {"step_tuning_method": "p_accept"}}

# Learnt distribution: a RealNVP flow over `dim` dimensions.
real_nvp_flo = make_realnvp_dist_funcs(
    dim,
    flow_num_layers,
    mlp_hidden_size_per_x_dim=mlp_hidden_size_per_x_dim,
    use_exp=True,
    layer_norm=layer_norm,
    act_norm=act_norm,
)

# Unnormalised target log-density.
target_log_prob = ManyWellEnergy(dim=dim).log_prob

# Gradient pipeline: zero out NaN gradients, clip by global norm, then AdamW.
optimizer = optax.chain(
    optax.zero_nans(),
    optax.clip_by_global_norm(max_grad_norm),
    optax.adamw(lr),
)

In [None]:
def plotter(fab_agent, n_samples=batch_size, dim=dim, key=jax.random.PRNGKey(0)):
    """Plot 2D marginals of flow and AIS samples over the Many Well target's contours.

    Draws one row per consecutive pair of dimensions ``(2i, 2i + 1)``. The left column
    shows samples drawn from the learnt flow. The right column shows the first element
    returned by ``fab_agent.annealed_importance_sampler.run``. Each panel is drawn over
    the contours of ``ManyWellEnergy(dim).log_prob_2D``.

    Args:
        fab_agent: Agent exposing ``learnt_distribution``, ``learnt_distribution_params``
            and ``annealed_importance_sampler``.
        n_samples: Number of flow samples to draw. The number of AIS samples is decided
            by the sampler itself, not by this argument.
        dim: Dimensionality of the target. If odd, the last dimension is not plotted.
        key: JAX PRNG key. It is split so flow sampling and the AIS run use
            independent randomness.

    Returns:
        A one-element list containing the matplotlib figure.

    Raises:
        ValueError: If ``dim < 2``, since there is no pair of dimensions to plot.
    """
    plotting_bounds = 3
    bounds = (-plotting_bounds, plotting_bounds)
    n_rows = dim // 2
    if n_rows == 0:
        raise ValueError(f"plotter needs dim >= 2 to draw a 2D marginal, got dim={dim}")
    target = ManyWellEnergy(dim=dim)
    # squeeze=False keeps `axs` 2D even when n_rows == 1, so `axs[i, j]` indexing always works.
    fig, axs = plt.subplots(n_rows, 2, sharex=True, sharey=True,
                            figsize=(10, n_rows * 3), squeeze=False)

    # Independent keys: reusing one key for both samplers would correlate their randomness.
    key_flow, key_ais = jax.random.split(key)
    samples_flow = fab_agent.learnt_distribution.sample.apply(
        fab_agent.learnt_distribution_params, key_flow, (n_samples,))
    samples_ais = fab_agent.annealed_importance_sampler.run(
        key_ais, fab_agent.learnt_distribution_params)[0]

    # Column j of the grid shows columns[j][0] and is titled columns[j][1].
    columns = ((samples_flow, "flow samples"), (samples_ais, "ais samples"))
    for i in range(n_rows):
        marginal_dims = (2 * i, 2 * i + 1)
        for j, (samples, _) in enumerate(columns):
            ax = axs[i, j]
            plot_contours_2D(target.log_prob_2D, bound=plotting_bounds, ax=ax)
            plot_marginal_pair(samples, ax=ax, bounds=bounds, marginal_dims=marginal_dims)
            ax.set_xlabel(f"dim {marginal_dims[0]}")
            ax.set_ylabel(f"dim {marginal_dims[1]}")
    for j, (_, title) in enumerate(columns):
        axs[0, j].set_title(title)

    # Lay out once, after every axis (including titles) is populated.
    fig.tight_layout()
    plt.show()
    return [fig]

In [None]:
# Assemble the FAB agent from the flow, target, AIS settings, optimizer and plotting callback.
fab_agent = AgentFAB(
    learnt_distribution=real_nvp_flo,
    target_log_prob=target_log_prob,
    n_intermediate_distributions=n_intermediate_distributions,
    AIS_kwargs=AIS_kwargs,
    optimizer=optimizer,
    plotter=plotter,
)

In [None]:
# Train the agent. n_plots presumably sets how many times `plotter` is called during training (confirm in AgentFAB.run).
fab_agent.run(n_plots=n_plots, n_iter=n_iter)

In [None]:
# Plot the metrics the agent recorded in `fab_agent.history` during training.
plot_history(fab_agent.history)
plt.show()

In [None]:
# Summarise how often AIS produced non-finite samples, rather than echoing the raw
# history list, which may hold one entry per logged training step.
# NOTE(review): assumes this history entry is a flat sequence of numeric counts — confirm.
n_non_finite_ais = jnp.asarray(fab_agent.history["n_non_finite_ais_x_samples"])
{
    "n_logged_steps": int(n_non_finite_ais.size),
    "total_non_finite": int(n_non_finite_ais.sum()),
    "steps_with_any_non_finite": int((n_non_finite_ais > 0).sum()),
}

In [None]:
# plotter already displays the figure via plt.show(); the trailing ';' suppresses the echo of its returned [fig] list.
plotter(fab_agent);