In [None]:
import sys
sys.path.insert(0, "../")

In [None]:
from fab.learnt_distributions.real_nvp import make_realnvp_dist_funcs
from fab.target_distributions.many_well import DoubleWellEnergy
from fab.agent.fab_agent import AgentFAB
from fab.utils.plotting import plot_history, plot_marginal_pair, plot_contours_2D
import matplotlib.pyplot as plt
import optax
import jax
import jax.numpy as jnp

In [None]:
# Problem and flow-architecture settings shared by both experiments below.
dim = 2  # dimensionality of the double-well target and of the flow
flow_num_layers = 10
mlp_hidden_size_per_x_dim = 5
# Normalisation flags forwarded to make_realnvp_dist_funcs.
act_norm = True
layer_norm = False

In this notebook we visualise two versions of the bootstrapped flow-AIS training procedure, which differ in the distribution targeted by the AIS chain ($p^2/q$ vs. $p$), and hence in the form of the loss. Both result in the flow fitting the target well.

## AIS chain targeting $p^2/q$
i.e. `style = "proptoloss"`

In [None]:
# Flow model (built from the configuration cell above) and the double-well target.
real_nvp_flo = make_realnvp_dist_funcs(
    dim,
    flow_num_layers,
    mlp_hidden_size_per_x_dim=mlp_hidden_size_per_x_dim,
    layer_norm=layer_norm,
    act_norm=act_norm,
)
target_log_prob = DoubleWellEnergy(dim=dim).log_prob

# Training schedule.
batch_size = 64
eval_batch_size = batch_size  # evaluate with the same batch size as training
n_evals = 10
n_iter = 3_000
lr = 1e-3

# AIS configuration.
n_intermediate_distributions: int = 2
AIS_kwargs = {"transition_operator_type": "hmc_tfp"}

# Adam, preceded by optax.zero_nans() so NaN updates are zeroed rather than propagated.
optimizer = optax.chain(optax.zero_nans(), optax.adam(learning_rate=lr))

# Loss configuration for this section. This notebook compares two settings that
# share loss_type="alpha_2_div" (another accepted value is "forward_kl"):
#   style="proptoloss" -> AIS chain targets p^2/q (this section)
#   style="vanilla"    -> AIS chain targets p      (second section)
loss_type = "alpha_2_div"
style = "proptoloss"

In [None]:
def plotter(fab_agent, log_prob_2D=target_log_prob, n_samples=100):
    """Plot flow (base) samples and AIS samples over contours of ``log_prob_2D``.

    Used as the ``plotter`` callback of ``AgentFAB`` and also called directly in
    this notebook.

    Args:
        fab_agent: the ``AgentFAB`` whose current ``state`` is visualised.
        log_prob_2D: log-density whose 2D contours are drawn behind the samples;
            defaults to the notebook's double-well ``target_log_prob``.
        n_samples: number of flow samples drawn (and then passed through AIS).

    Returns:
        A one-element list containing the matplotlib figure.
    """
    # NOTE(review): prefer the agent's own style (assumes AgentFAB keeps its
    # `style` constructor argument as `.style`); otherwise fall back to the
    # notebook-global `style`, which ties the title to whichever cell ran last.
    agent_style = getattr(fab_agent, "style", style)
    target_name = "p^2/q" if agent_style == "proptoloss" else "p"

    @jax.jit
    def get_info(state):
        params = state.learnt_distribution_params
        base_log_prob = fab_agent.get_base_log_prob(params)
        # Named to avoid shadowing the notebook-global `target_log_prob`.
        ais_target_log_prob = fab_agent.get_target_log_prob(params)
        # Independent keys for flow sampling and the AIS transitions: reusing
        # state.key for both would correlate the two sources of randomness.
        flow_key, ais_key = jax.random.split(state.key)
        x_base, log_q_x_base = fab_agent.learnt_distribution.sample_and_log_prob.apply(
            params, rng=flow_key, sample_shape=(n_samples,))
        x_ais, _, _, _ = fab_agent.annealed_importance_sampler.run(
            x_base, log_q_x_base, ais_key,
            state.transition_operator_state,
            base_log_prob=base_log_prob,
            target_log_prob=ais_target_log_prob,
        )
        return x_base, x_ais

    x_base, x_ais = get_info(fab_agent.state)

    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    panels = [
        (x_base, "base samples"),
        (x_ais, f"ais samples with target of: {target_name}"),
    ]
    for ax, (samples, title) in zip(axs, panels):
        plot_contours_2D(log_prob_2D, ax=ax, bound=3, levels=20)
        plot_marginal_pair(samples, ax=ax)
        ax.set_title(title)
    plt.show()
    return [fig]

In [None]:
# First agent: AIS chain targets p^2/q (style="proptoloss").
fab_agent = AgentFAB(
    learnt_distribution=real_nvp_flo,
    target_log_prob=target_log_prob,
    n_intermediate_distributions=n_intermediate_distributions,
    AIS_kwargs=AIS_kwargs,
    optimizer=optimizer,
    plotter=plotter,
    loss_type=loss_type,
    style=style,
)

In [None]:
# Samples before training. The trailing ';' suppresses the echoed `[fig]` return
# value; plotter already displays the figure via plt.show().
plotter(fab_agent);

In [None]:
# Train, plotting 5 times and running n_evals evaluation passes along the way.
fab_agent.run(
    n_iter=n_iter,
    batch_size=batch_size,
    n_plots=5,
    n_evals=n_evals,
    eval_batch_size=eval_batch_size,
)

In [None]:
# ESS traces logged by the agent. The "ess_*" traces are measured with the AIS
# chain targeting p^2/q; the "eval_ess_*" traces measure ESS against p itself
# (cf. n_evals).
ess_histories = [
    ("ess_base", "ess_base p^2/q"),
    ("ess_ais", "ess_ais p^2/q"),
    ("eval_ess_ais", "ess_ais over p"),
    ("eval_ess_flow", "ess_base over p"),
]
for history_key, title in ess_histories:
    ess_fig, ess_ax = plt.subplots()
    ess_ax.plot(fab_agent.logger.history[history_key])
    ess_ax.set(title=title, xlabel="logged step", ylabel="ESS")
    plt.show()

In [None]:
# Samples after training. The trailing ';' suppresses the echoed `[fig]` return
# value; plotter already displays the figure via plt.show().
plotter(fab_agent);

## AIS chain targeting the target $p$
i.e. `style = "vanilla"`

In [None]:
# No separate evaluation pass is needed: the AIS chain already targets p, so the
# ESS logged during training is measured against p.
n_evals = None

# Same loss as before; only the AIS target changes, to p ("vanilla" style).
loss_type = "alpha_2_div"  # alternative: "forward_kl"
style = "vanilla"

In [None]:
# Rebuild the flow for the second experiment with the same configuration.
real_nvp_flo = make_realnvp_dist_funcs(
    dim,
    flow_num_layers,
    mlp_hidden_size_per_x_dim=mlp_hidden_size_per_x_dim,
    layer_norm=layer_norm,
    act_norm=act_norm,
)

In [None]:
# Second agent: same settings, but the AIS chain targets p (style="vanilla").
fab_agent = AgentFAB(
    learnt_distribution=real_nvp_flo,
    target_log_prob=target_log_prob,
    n_intermediate_distributions=n_intermediate_distributions,
    AIS_kwargs=AIS_kwargs,
    optimizer=optimizer,
    plotter=plotter,
    loss_type=loss_type,
    style=style,
)

In [None]:
# Train the second agent (n_evals is None here, so no separate evaluation pass).
fab_agent.run(
    n_iter=n_iter,
    batch_size=batch_size,
    n_plots=5,
    n_evals=n_evals,
    eval_batch_size=eval_batch_size,
)

In [None]:
# ESS traces logged during training; with the AIS chain targeting p these are
# already measured against p.
for history_key in ("ess_base", "ess_ais"):
    ess_fig, ess_ax = plt.subplots()
    ess_ax.plot(fab_agent.logger.history[history_key])
    ess_ax.set(title=history_key, xlabel="logged step", ylabel="ESS")
    plt.show()