In [None]:
import sys
# Put the parent directory at the front of the module search path.
# NOTE(review): presumably this lets the `fab` package imported in the next
# cell resolve to the local checkout — confirm the notebook's location.
sys.path.insert(0, "../")

In [None]:
from fab.learnt_distributions.real_nvp import make_realnvp_dist_funcs
from fab.target_distributions.gmm import GMM
from fab.agent.fab_agent_prioritised import PrioritisedAgentFAB, State
from fab.utils.prioritised_replay_buffer import PrioritisedReplayBuffer
from fab.utils.plotting import plot_history, plot_marginal_pair, plot_contours_2D
import matplotlib.pyplot as plt
import optax
import jax
import jax.numpy as jnp

In [None]:
# --- Target (Gaussian mixture) settings, passed to GMM(...) further down ---
loc_scaling = 40
n_mixes = 40
dim = 2

# --- Flow settings, passed to make_realnvp_dist_funcs(...) further down ---
flow_num_layers = 30
mlp_hidden_size_per_x_dim = 10
layer_norm, act_norm, lu_layer = False, False, False

In [None]:
# Learnt distribution: RealNVP flow built from the configuration constants.
real_nvp_flo = make_realnvp_dist_funcs(
    dim,
    flow_num_layers,
    mlp_hidden_size_per_x_dim=mlp_hidden_size_per_x_dim,
    layer_norm=layer_norm,
    act_norm=act_norm,
    lu_layer=lu_layer,
    use_exp=True,
)

# Target distribution and the log-density handed to the agent.
target = GMM(
    dim,
    n_mixes=n_mixes,
    loc_scaling=loc_scaling,
    log_var_scaling=1.0,
    seed=1,
)
target_log_prob = target.log_prob

# Training / evaluation hyperparameters.
batch_size = 64
eval_batch_size = 2 * batch_size
n_evals = 10
n_iter = int(3e4)
lr = 1e-4
n_plots = 6
n_buffer_updates_per_forward = 8
n_intermediate_distributions: int = 4

# Keyword arguments forwarded to the agent as `AIS_kwargs`.
AIS_kwargs = {
    "transition_operator_type": "hmc_tfp",
    "additional_transition_operator_kwargs": {"init_step_size": 1.0},
}

# Adam preceded by NaN-zeroing of the gradients. Global-norm gradient clipping
# is intentionally not part of the chain here; enabling it would require a
# `max_grad_norm` value, which is not defined in the cells shown.
optimizer = optax.chain(optax.zero_nans(), optax.adam(lr))

In [None]:
# Visualise the target's log-density contours.
# `bound` and `levels` are notebook-level settings that `plotter` also reads,
# so they are passed by name here instead of repeating the literal.
bound = int(loc_scaling * 1.4)
levels = 80
plot_contours_2D(target_log_prob, bound=bound, levels=levels)

In [None]:
# Bare expression: shows the target object's repr as this cell's output.
target

In [None]:
# Overlay samples drawn from the target on top of its log-density contours.
# `bound` and `levels` come from the preceding contour-plot cell, so the two
# figures always share the same extent and contour resolution.
fig, ax = plt.subplots()
samples = target.sample(jax.random.PRNGKey(0), (500,))
ax.plot(samples[:, 0], samples[:, 1], "o", alpha=0.1)
plot_contours_2D(target_log_prob, ax=ax, bound=bound, levels=levels)
plt.show()

In [None]:
# Both buffer sizes are multiples of (batch size x buffer updates per forward).
samples_per_forward = batch_size * n_buffer_updates_per_forward
buffer = PrioritisedReplayBuffer(
    dim=dim,
    max_length=samples_per_forward * 100,
    min_sample_length=samples_per_forward * 10,
)


In [None]:
def plotter(fab_agent, batch_size = 500):
    """Plot flow, AIS and replay-buffer samples over the target contours.

    Produces one figure with three panels. Every panel shows the contours of
    ``fab_agent.target_log_prob`` and, on top of them, ``batch_size`` points:

    1. samples from the agent's learnt distribution,
    2. the first output of ``annealed_importance_sampler.run`` started from
       those samples,
    3. samples drawn from the agent's replay buffer.

    The plot extent and contour count are read from the notebook-level
    variables ``bound`` and ``levels``, which must be defined before calling.

    Args:
        fab_agent: agent providing ``state``, ``target_log_prob``,
            ``learnt_distribution``, ``annealed_importance_sampler``,
            ``replay_buffer``, ``get_base_log_prob`` and
            ``get_ais_target_log_prob``.
        batch_size: number of samples drawn for each panel.

    Returns:
        A single-element list holding the matplotlib figure.
    """
    log_prob_2D = fab_agent.target_log_prob

    @jax.jit
    def get_info(state):
        # Derive an independent key per random operation: JAX keys must not be
        # reused, otherwise the three draws below would be correlated.
        key_flow, key_ais, key_buffer = jax.random.split(state.key, 3)

        params = state.learnt_distribution_params
        base_log_prob = fab_agent.get_base_log_prob(params)
        target_log_prob = fab_agent.get_ais_target_log_prob(params)

        x_base, log_q_x_base = fab_agent.learnt_distribution.sample_and_log_prob.apply(
            params, rng=key_flow, sample_shape=(batch_size,))

        # Only the first of the four values returned by the AIS run is plotted.
        x_ais_loss, _, _, _ = fab_agent.annealed_importance_sampler.run(
            x_base, log_q_x_base, key_ais,
            state.transition_operator_state,
            base_log_prob=base_log_prob,
            target_log_prob=target_log_prob,
        )

        # The buffer's sample call returns a sequence; its first entry is used.
        x_buffer = fab_agent.replay_buffer.sample(
            state.buffer_state, key_buffer, batch_size)[0]
        return x_base, x_ais_loss, x_buffer

    x_base, x_ais_loss, x_buffer = get_info(fab_agent.state)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        (x_base, "base samples"),
        (x_ais_loss, "p^2 / q samples"),
        (x_buffer, "buffer samples"),
    )
    for ax, (points, title) in zip(axs, panels):
        plot_contours_2D(log_prob_2D, ax=ax, bound=bound, levels=levels)
        plot_marginal_pair(points, ax=ax, bounds=(-bound, bound))
        ax.set_title(title)
    plt.show()
    return [fig]


In [None]:
# Assemble the prioritised-buffer FAB agent from the pieces defined above.
fab_agent = PrioritisedAgentFAB(
    learnt_distribution=real_nvp_flo,
    target_log_prob=target_log_prob,
    n_intermediate_distributions=n_intermediate_distributions,
    replay_buffer=buffer,
    n_buffer_updates_per_forward=n_buffer_updates_per_forward,
    AIS_kwargs=AIS_kwargs,
    optimizer=optimizer,
    plotter=plotter,
    max_w_adjust=10.0,
)

In [None]:
# Snapshot of the agent before `fab_agent.run` is called. The trailing `;`
# stops the notebook from echoing the returned `[fig]` list under the figure.
plotter(fab_agent);

In [None]:
# Run the agent with the hyperparameters configured earlier in the notebook.
fab_agent.run(
    n_iter=n_iter,
    batch_size=batch_size,
    n_plots=n_plots,
    n_evals=n_evals,
    eval_batch_size=eval_batch_size,
)

In [None]:
# One figure per logged training history series, titled with its history key.
for metric in ("ess_base", "ess_ais"):
    fig, ax = plt.subplots()
    ax.plot(fab_agent.logger.history[metric])
    ax.set_title(metric)
    plt.show()

In [None]:
# One figure per logged evaluation history series. Each title is the history
# key being plotted, so the eval curves are not mistaken for the training
# `ess_base` / `ess_ais` curves.
for metric in ("eval_ess_flow", "eval_ess_ais_p"):
    fig, ax = plt.subplots()
    ax.plot(fab_agent.logger.history[metric])
    ax.set_title(metric)
    plt.show()

In [None]:
# Snapshot of the agent after `fab_agent.run` has finished. The trailing `;`
# stops the notebook from echoing the returned `[fig]` list under the figure.
plotter(fab_agent);