<a href="https://colab.research.google.com/github/lollcat/fab-jax/blob/main/experiments/fabjax_quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install fab-jax library

In [None]:
# Clone the fab-jax repository into ./fab-jax (the next cell changes into that directory).
!git clone https://github.com/lollcat/fab-jax.git

In [None]:
import os

# Move into the cloned repository so that `pip install -e .` below installs it.
# Guarded so that re-running this cell (when the kernel is already inside fab-jax)
# does not fail with FileNotFoundError.
if os.path.basename(os.getcwd()) != "fab-jax":
    os.chdir("fab-jax")
os.listdir()

In [None]:
!pip install -e .

# Run code
To run the experiments, I recommend using the commands in the repo's README with the Wandb logger. In this notebook, however, I show how to run the experiments with the list logger; the notebook also visualizes the performance of the flow/AIS during the training run.

In [None]:
# Restart the runtime after the install above, then run the cells from here on.
# A restart resets the kernel, so re-import `os` and move back into the repository:
# the Hydra config path and the `experiments` package below are resolved relative to it.
import os

# Guarded so that re-running this cell without a restart does not fail with FileNotFoundError.
if os.path.basename(os.getcwd()) != "fab-jax":
    os.chdir("fab-jax")
os.listdir()

In [None]:
from hydra import compose, initialize
import jax

from fabjax.train.generic_training_loop import train

from experiments.setup_training import setup_fab_config, setup_general_train_config

from fabjax.targets.gmm_v1 import GaussianMixture2D as GMMV1
from fabjax.targets.gmm_v0 import GMM as GMMV0
from fabjax.targets.many_well import ManyWellEnergy
from fabjax.targets.cox import CoxDist
from fabjax.targets.funnel import FunnelSet

In [None]:
# Problems with a config in experiments/config; change the index to pick a different one.
problem_names = ["gmm_v0", "gmm_v1", "many_well", "cox", "funnel"]
problem_name = problem_names[1]  # -> "gmm_v1"

# Compose the Hydra config named after the chosen problem from experiments/config.
with initialize(version_base=None, config_path="experiments/config", job_name="colab_app"):
    cfg = compose(config_name=problem_name)

# Remove the Wandb logger config so training uses the in-memory list logger instead
# (its metrics are read back from `logger.history` at the end of the notebook).
# NOTE(review): Hydra's compose() may return the config in struct mode, where `del`
# is rejected; if so, wrap the deletion in `omegaconf.open_dict(cfg)` — confirm.
has_logger_cfg = "logger" in cfg.keys()
if has_logger_cfg:
    del cfg.logger

In [None]:
# Setup target distribution.
if problem_name == "gmm_v0":
  target = GMMV0()
elif problem_name == "gmm_v1":
  target = GMMV1(width_in_n_modes=cfg.target.width_in_n_modes)
elif problem_name == "many_well":
  # By default cfg.target.dim = 32. Can manually override this to make the problem easier/more challenging
  target = ManyWellEnergy(dim=cfg.target.dim)
elif problem_name == "cox":
  # By default cfg.target.num_grid_per_dim = 40.
  target = CoxDist(num_grid_per_dim=cfg.target.num_grid_per_dim)
elif problem_name == "funnel":
  target = FunnelSet()

In [None]:
# Set to False for a quick run with ~10x fewer training epochs.
# NOTE: this mutates `cfg` in place; re-run the config cell before re-running this one,
# otherwise the epoch count is reduced again.
full_run = True
if not full_run:
    # Integer division keeps the epoch count an int (`/` would make it a float), and
    # the clamp ensures at least one epoch is trained even for small n_epoch values.
    cfg.training.n_epoch = max(1, int(cfg.training.n_epoch) // 10)

In [None]:
# Build the FAB config from the Hydra config and the target, wrap it in the general
# training config, and run training (the long-running cell of this notebook).
# `train` returns a logger, whose `.history` holds the recorded metrics (inspected in the
# next cell), and `state` — presumably the final training state; confirm in fabjax.train.
fab_config = setup_fab_config(cfg, target)
experiment_config = setup_general_train_config(fab_config)
logger, state = train(experiment_config)

In [None]:
# Absolute error of the log Z estimate obtained by importance sampling with the flow,
# as recorded by the logger during training.
# For other evaluation metrics (e.g. forward and reverse effective sample size), see
# logger.history.keys() and the final plot produced by the training cell.
log_z_flow_abs_err = logger.history["mean_abs_err_log_z_flow"]
log_z_flow_abs_err