<a href="https://colab.research.google.com/github/lollcat/fab-jax/blob/eval/experiments/fabjax_quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install fab-jax library

In [None]:
# Clone the fab-jax repository into the current working directory (IPython shell magic).
!git clone https://github.com/lollcat/fab-jax.git

In [None]:
import os

# Move into the freshly cloned repository so the editable install below runs
# from the repo root. The guard makes the cell safe to re-run: a second
# `os.chdir("fab-jax")` from inside the repo would raise FileNotFoundError.
if os.path.basename(os.getcwd()) != "fab-jax":
    os.chdir("fab-jax")
os.listdir()  # Display the repo contents as a quick sanity check.

In [None]:
# Install the package in editable (-e) mode from the current directory, i.e. the
# cloned fab-jax repo that the previous cell changed into (IPython shell magic).
!pip install -e .

# Run code
To run the experiments, I recommend using the commands in the repo's README together with the Wandb logger. In this notebook, however, I show how to run the experiments with the list logger; this also visualizes the performance of the flow and AIS during the training run.

In [None]:
# Restart the runtime after the install, then run the code from here.
import os

# A runtime restart typically resets the working directory, so move back into
# the repo. The guard keeps this cell safe to re-run (or to run without a
# restart, when we are already inside `fab-jax`).
if os.path.basename(os.getcwd()) != "fab-jax":
    os.chdir("fab-jax")
os.listdir()  # Display the repo contents as a quick sanity check.

In [None]:
from hydra import compose, initialize
import jax
import jax.numpy as jnp
import chex
import matplotlib.pyplot as plt

from fabjax.targets.gmm_v1 import GaussianMixture2D as GMMV1
from fabjax.targets.gmm_v0 import GMM as GMMV0
from fabjax.targets.many_well import ManyWellEnergy
from fabjax.targets.cox import CoxDist
from fabjax.targets.funnel import FunnelSet
from fabjax.sampling.resampling import log_effective_sample_size
from fabjax.train.evaluate import calculate_log_forward_ess
from fabjax.train.generic_training_loop import train

from experiments.setup_training import setup_fab_config, setup_general_train_config

In [None]:
# Benchmark problems; each name is also the name of a config in experiments/config.
problem_names = [
    "gmm_v0",
    "gmm_v1",
    "many_well",
    "cox",
    "funnel",
]
problem_name = problem_names[1]  # -> "gmm_v1"

# Load the Hydra config for the selected problem.
with initialize(version_base=None, config_path="experiments/config", job_name="colab_app"):
    cfg = compose(config_name=problem_name)

# Drop the Wandb logger entry so the list logger is used instead.
if "logger" in cfg.keys():
    del cfg.logger

In [None]:
# Set up the target distribution that matches `problem_name`.
if problem_name == "gmm_v0":
    target = GMMV0()
elif problem_name == "gmm_v1":
    target = GMMV1(width_in_n_modes=cfg.target.width_in_n_modes)
elif problem_name == "many_well":
    # By default cfg.target.dim = 32. Can manually override this to make the problem easier/more challenging
    target = ManyWellEnergy(dim=cfg.target.dim)
elif problem_name == "cox":
    # By default cfg.target.num_grid_per_dim = 40.
    target = CoxDist(num_grid_per_dim=cfg.target.num_grid_per_dim)
elif problem_name == "funnel":
    target = FunnelSet()
else:
    # Fail here rather than with a confusing NameError later on, or (worse)
    # silently reusing a stale `target` from a previous run of this cell.
    raise ValueError(
        f"Unknown problem_name {problem_name!r}; expected one of {problem_names}."
    )

In [None]:
# Set to False for a quicker run with 10x fewer training epochs.
full_run = True
if not full_run:
    # Floor division keeps the epoch count an integer (`/` always yields a
    # float), and max(1, ...) guarantees at least one epoch is run.
    cfg.training.n_epoch = max(1, cfg.training.n_epoch // 10)

In [None]:
# Build the FAB setup from `cfg` and `target` (it exposes e.g. `fab_config.flow`,
# `fab_config.ais_eval` and `fab_config.log_p_fn`, used below), wrap it in a
# general training config, and train.
# `logger` records metrics (its `history` is inspected in the next cell);
# `state` holds the trained state (e.g. `state.flow_params`, `state.smc_state`).
fab_config = setup_fab_config(cfg, target)
experiment_config = setup_general_train_config(fab_config)
logger, state = train(experiment_config)

In [None]:
# Display the logged absolute error of the log_Z estimate obtained by importance sampling with the flow.
# See the final plot in the training cell above, and logger.history.keys() for more evaluation metrics (such as forward and reverse effective sample size).
logger.history['mean_abs_err_log_z_flow']

# Manually perform forward pass of flow and AIS.

The code below is a simple example of running inference and evaluation with the trained model. See [here](https://github.com/lollcat/fab-jax/blob/32d4d6521203e39384bdea674b19ea2d58455446/experiments/setup_training.py#L235) for the proper evaluation code used throughout training.

In [None]:
# Sample and log prob from flow, then run AIS starting from the flow samples.
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)  # key1: flow sampling here; key2: target sampling in a later cell.
n_samples = 128

def log_q_fn(x: chex.Array) -> chex.Array:
    """Return the flow's log-density log q(x) under the trained flow parameters."""
    return fab_config.flow.log_prob_apply(state.flow_params, x)

# Sample from flow.
x_flow, log_q_flow = fab_config.flow.sample_and_log_prob_apply(state.flow_params, key1, (n_samples,))
log_w_flow = fab_config.log_p_fn(x_flow) - log_q_flow # Log importance weights: log p(x) - log q(x).

# Run AIS from the flow samples; `log_w_ais` are the AIS log weights (used for log_Z / ESS below).
# NOTE(review): no PRNG key is passed here; presumably AIS randomness comes from state.smc_state — confirm.
point, log_w_ais, smc_state, smc_info = fab_config.ais_eval.step(x_flow, state.smc_state, log_q_fn, fab_config.log_p_fn)
x_ais = point.x

In [None]:
# Sample from target (in this case we have access to ground truth samples).
# Uses key2, split off in the flow-sampling cell, so it is independent of the flow samples.
x_target = target.sample(key2, (n_samples,))

In [None]:
# Scatter the first two coordinates of each sample set on one figure.
sample_sets = (
    (x_flow, "flow samples"),
    (x_ais, "AIS samples"),
    (x_target, "target samples"),
)
for samples, label in sample_sets:
    plt.plot(samples[:, 0], samples[:, 1], "o", label=label)
plt.legend()
plt.show()

In [None]:


# Accuracy of the log_Z estimate (with n_samples importance samples):
# log_Z ~= logsumexp(log_w) - log(n_samples), i.e. the log of the mean weight.
log_z_flow = jax.nn.logsumexp(log_w_flow, axis=-1) - jnp.log(n_samples)
log_z_ais = jax.nn.logsumexp(log_w_ais, axis=-1) - jnp.log(n_samples)
abs_err_log_z_flow = jnp.abs(log_z_flow - target.log_Z)
abs_err_log_z_ais = jnp.abs(log_z_ais - target.log_Z)

# Reverse ESS: computed from the weights of samples drawn from the flow / AIS.
reverse_ess_flow = jnp.exp(log_effective_sample_size(log_w_flow))
reverse_ess_ais = jnp.exp(log_effective_sample_size(log_w_ais))

# Forward ESS: computed from log p(x) - log q(x) at ground-truth target samples.
log_w_fwd = target.log_prob(x_target) - fab_config.flow.log_prob_apply(state.flow_params, x_target)
fwd_ess = jnp.exp(calculate_log_forward_ess(log_w_fwd))

# The log_Z errors are absolute differences in log space, not percentages,
# so they are printed without a "%" suffix.
print(f"log_Z abs error flow: {abs_err_log_z_flow:.3f}")
print(f"log_Z abs error ais: {abs_err_log_z_ais:.3f}")
# ESS values are shown as percentages (* 100); this assumes the ESS helpers
# return a normalized ESS in [0, 1] — NOTE(review): confirm.
print(f"Reverse ESS flow: {reverse_ess_flow * 100:.2f} %")
print(f"Reverse ESS AIS: {reverse_ess_ais * 100:.2f} %")
print(f"Forward ESS flow: {fwd_ess * 100:.2f} %")