<a href="https://colab.research.google.com/github/lollcat/fab-torch/blob/hug_face_many_well/demo/many_well.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install fab-torch repo

In [None]:
# If using colab then run this cell.
!git clone https://github.com/lollcat/fab-torch

import os
os.chdir("fab-torch")

!pip install --upgrade .

# Download weights from Hugging Face and run an example of inference
We can just use the CPU, as the model is not expensive to run.

In [None]:
# Restart after install, then run the below code
import os
import urllib
import urllib.request

import matplotlib.pyplot as plt
from matplotlib import rc
import matplotlib as mpl
from hydra import compose, initialize
import torch

from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from experiments.setup_run import setup_model
from experiments.many_well.many_well_visualise_all_marginal_pairs import get_target_log_prob_marginal_pair

In [None]:
# Compose the `many_well` experiment config with Hydra's compose API.
# NOTE(review): CONFIG_PATH is relative and assumes the notebook runs from the directory
# that contains the fab-torch checkout (e.g. after the runtime restart requested above) — confirm.
CONFIG_PATH = "fab-torch/experiments/config/"
CONFIG_NAME = "many_well"
with initialize(version_base=None, config_path=CONFIG_PATH, job_name="colab_app"):
    cfg = compose(config_name=CONFIG_NAME)

In [None]:
# Build the Many Well target on CPU, then the model described by `cfg` around it.
# NOTE(review): a and b are hard-coded here rather than read from cfg; presumably they
# match the target the downloaded checkpoint was trained against — confirm.
target = ManyWellEnergy(
    cfg.target.dim,
    a=-0.5,
    b=-6,
    use_gpu=False,
)
model = setup_model(cfg, target)

In [None]:
# Download the pretrained weights from Hugging Face and load them into the model on CPU.
# The download is skipped when the file already exists, so re-running this cell is cheap.
MODEL_URL = "https://huggingface.co/VincentStimper/fab/resolve/main/many_well/model.pt"
MODEL_PATH = "model.pt"
if not os.path.exists(MODEL_PATH):
    # Download to a temporary name and rename afterwards, so an interrupted download
    # never leaves a truncated MODEL_PATH that later runs would mistake for a full one.
    partial_path, _ = urllib.request.urlretrieve(MODEL_URL, MODEL_PATH + ".part")
    os.replace(partial_path, MODEL_PATH)
model.load(MODEL_PATH, map_location="cpu")

In [None]:
# Sample from the trained flow.
# Seeding torch's global RNG makes the figure reproducible across Restart & Run All
# (assumes the flow draws its base samples from torch's global generator — confirm).
SEED = 0
n_samples: int = 200
torch.manual_seed(SEED)
samples_flow = model.flow.sample((n_samples,)).detach()

In [None]:
# Visualise flow samples (scatter) over contours of the target's 2D marginals.
# Row i shows x_{i+1} on the vertical axis; column j shows x_{j+3} on the horizontal axis.
alpha = 0.3
plotting_bounds = (-3, 3)
dim = cfg.target.dim
n_rows, n_cols = 2, 2
col_dim_offset = n_rows  # column variables are the dimensions after the row variables
# The vertical variable is fixed along a row and the horizontal one along a column,
# so y is shared per row and x per column.
fig, axs = plt.subplots(n_rows, n_cols, sharex="col", sharey="row", squeeze=False)

for i in range(n_rows):
    for j in range(n_cols):
        ax = axs[i, j]
        x_dim = j + col_dim_offset  # horizontal axis: the column's variable
        y_dim = i                   # vertical axis: the row's variable
        # NOTE(review): assumes plot_contours / plot_marginal_pair put the first dimension of
        # the pair on the horizontal axis — confirm against fab.utils.plotting.
        target_log_prob = get_target_log_prob_marginal_pair(target.log_prob, x_dim, y_dim, dim)
        plot_contours(target_log_prob, bounds=plotting_bounds, ax=ax,
                      n_contour_levels=20, grid_width_n_points=100)
        plot_marginal_pair(samples_flow, marginal_dims=(x_dim, y_dim),
                           ax=ax, bounds=plotting_bounds, alpha=alpha)
        # Doubled braces emit a literal LaTeX group, so multi-digit indices render fully.
        if j == 0:
            ax.set_ylabel(f"$x_{{{y_dim + 1}}}$")
        if i == n_rows - 1:
            ax.set_xlabel(f"$x_{{{x_dim + 1}}}$")

fig.tight_layout()
plt.show()