In [None]:
import sys
sys.path.insert(0, "../")

In [None]:
from fab.learnt_distributions.real_nvp import make_realnvp_dist_funcs
from fab.target_distributions.gmm import GMM
from fab.agent.fab_agent import AgentFAB
from fab.utils.plotting import plot_history, plot_marginal_pair, plot_contours_2D
import matplotlib.pyplot as plt
import optax
import jax
import jax.numpy as jnp

In [None]:
# Flow architecture / problem size. The plotting cells below draw 2-D contours
# and marginal pairs, so `dim` is expected to stay 2 for them to make sense.
dim, flow_num_layers = 2, 10       # target dimensionality; flow depth passed to make_realnvp_dist_funcs
mlp_hidden_size_per_x_dim = 5      # forwarded as make_realnvp_dist_funcs(mlp_hidden_size_per_x_dim=...)
layer_norm = act_norm = True       # both normalisation switches forwarded to make_realnvp_dist_funcs

In [None]:
# Training hyperparameters (plain values, no side effects).
batch_size = 128
n_iter = 10_000
lr = 2e-4
n_intermediate_distributions: int = 4
# Extra kwargs handed to AgentFAB for its AIS sampler; judging by the key names
# they select the "p_accept" step-size tuning method for the transition operator.
AIS_kwargs = dict(
    additional_transition_operator_kwargs=dict(step_tuning_method="p_accept"),
)
optimizer = optax.adamw(lr)

# Learnt distribution (RealNVP flow) and the target: a GMM with n_mixes=5 and a fixed seed.
real_nvp_flo = make_realnvp_dist_funcs(
    dim,
    flow_num_layers,
    mlp_hidden_size_per_x_dim=mlp_hidden_size_per_x_dim,
    layer_norm=layer_norm,
    act_norm=act_norm,
)
target = GMM(dim, n_mixes=5, loc_scaling=12, log_var_scaling=1.0, seed=1)
target_log_prob = target.log_prob

In [None]:
def plotter(fab_agent, batch_size=500, key=None, bound=20):
    """Plot flow samples and AIS samples over contours of the target log-density.

    Draws a 1x2 figure:
      * left: ``batch_size`` samples drawn directly from the learnt flow;
      * right: ``batch_size`` samples returned by the agent's annealed
        importance sampler.
    Both panels show contours of ``fab_agent.target_log_prob`` underneath.

    Args:
        fab_agent: the FAB agent; its ``state`` supplies the current flow
            parameters and transition-operator state.
        batch_size: number of samples drawn from each sampler.
        key: optional ``jax.random.PRNGKey``. Defaults to ``PRNGKey(0)`` so
            repeated calls (e.g. before/after training) are reproducible.
        bound: half-width of the square plotting window ``[-bound, bound]``.

    Returns:
        A one-element list containing the matplotlib figure.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    # JAX keys are not stateful: reusing one key for both sampling calls would
    # feed them identical randomness, so split it to get independent draws.
    key_flow, key_ais = jax.random.split(key)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax in axs:
        plot_contours_2D(fab_agent.target_log_prob, ax=ax, bound=bound, levels=30)

    # The sample shape tuple (arg 2) is static so it can determine output shape.
    samples_flow = jax.jit(fab_agent.learnt_distribution.sample.apply, static_argnums=2)(
        fab_agent.state.learnt_distribution_params, key_flow, (batch_size,))
    # `batch_size` (arg 0) is static for the same reason; only the samples are plotted.
    samples_ais, _, _, _ = jax.jit(
        fab_agent.annealed_importance_sampler.run, static_argnums=0)(
            batch_size, key_ais, fab_agent.state.learnt_distribution_params,
            fab_agent.state.transition_operator_state)

    plot_marginal_pair(samples_flow, ax=axs[0], bounds=(-bound, bound))
    plot_marginal_pair(samples_ais, ax=axs[1], bounds=(-bound, bound))
    axs[0].set_title("samples from flow")
    axs[1].set_title("samples from ais")
    plt.show()
    return [fig]

In [None]:
# Assemble the FAB agent from the flow, target, optimiser and AIS settings.
# `plotter` is handed over too; it is presumably what `run(..., n_plots=...)`
# calls to draw progress figures during training.
fab_agent = AgentFAB(
    learnt_distribution=real_nvp_flo,
    target_log_prob=target_log_prob,
    optimizer=optimizer,
    n_intermediate_distributions=n_intermediate_distributions,
    AIS_kwargs=AIS_kwargs,
    plotter=plotter,
)

In [None]:
# Samples before training. The trailing `;` hides the text repr of the returned
# `[fig]` list; the figure itself is already shown by plt.show() inside plotter.
plotter(fab_agent);

In [None]:
# Train the agent for `n_iter` steps of `batch_size`, requesting 5 intermediate plots.
fab_agent.run(batch_size=batch_size, n_iter=n_iter, n_plots=5)

In [None]:
# Plot the "ess_base" and "ess_ais" histories recorded by the agent's logger during
# training (presumably effective-sample-size metrics), one figure per metric.
for metric_name in ("ess_base", "ess_ais"):
    fig, ax = plt.subplots()
    ax.plot(fab_agent.logger.history[metric_name])
    ax.set_title(metric_name)
    ax.set_xlabel("logged step")
    ax.set_ylabel(metric_name)
    plt.show()

In [None]:
# Samples after training. The trailing `;` hides the text repr of the returned
# `[fig]` list; the figure itself is already shown by plt.show() inside plotter.
plotter(fab_agent);