<a href="https://colab.research.google.com/github/lollcat/annealed_flow_transport/blob/craft-update/craft_gmm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Colab setup
If you are not running in Colab, skip this section. It clones the repository, changes into it, and installs the dependencies.

In [None]:
# Clone the project repository; the imports below (`annealed_flow_transport`, `evaluation`,
# `configs`) presumably resolve against this checkout.
# NOTE(review): this clones the default branch, while the Colab badge at the top links to the
# `craft-update` branch; confirm the default branch contains the code this notebook imports.
!git clone https://github.com/lollcat/annealed_flow_transport.git

In [None]:
import os

# Move into the cloned repo so imports such as `evaluation` and `configs.fab_mog` resolve
# against it (assumed layout). The guard makes the cell idempotent: re-running it would
# otherwise try to descend into a non-existent nested directory and raise.
REPO_DIR = "annealed_flow_transport"
if os.path.basename(os.getcwd()) != REPO_DIR:
    os.chdir(REPO_DIR)

In [None]:
!pip install chex==0.1.5 ml_collections optax dm-haiku distrax -q

## Imports

In [None]:
from annealed_flow_transport.train import run_experiment
from evaluation import make_forward_pass_func, load_checkpoint
from evaluation import evaluate_mog as evaluate
import numpy as np
import jax

In [None]:
# Enable 64-bit floating point globally. The JAX docs recommend setting this early, before any
# arrays are created. `from jax.config import config` is deprecated and no longer works in
# recent JAX releases; `jax.config.update` is the supported spelling.
jax.config.update("jax_enable_x64", True)

In [None]:
from configs.fab_mog import get_config

In [None]:
# Base experiment config; the seed and checkpoint filename are overridden for each run below.
exp_config = get_config()

# Train

### Run first seed

In [None]:
# Seed-1 run: a seed-specific checkpoint name, presumably where run_experiment saves params.
exp_config.params_filename = "checkpoint_seed1"
exp_config.seed = 1

In [None]:
# Train with the current (seed 1) config. The trained parameters are presumably written to
# `exp_config.params_filename`, which the next cell reloads. NOTE(review): `results` is not used later.
results = run_experiment(exp_config)

In [None]:
# Reload the checkpoint (presumably just written by training) and evaluate the resulting sampler.
# The filename is read from the config instead of repeating the literal, so the two cannot drift.
filename = exp_config.params_filename
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(exp_config, transition_params=transition_params)
eval_info = evaluate(forward_pass_function)
print(eval_info)

In [None]:
# Run one forward pass of the trained sampler with a fixed key and save its samples.
# numpy and jax are already imported at the top of the notebook.
particle_state = forward_pass_function(jax.random.PRNGKey(0))
samples = np.array(particle_state.samples)
# A context manager guarantees the file is flushed and closed. np.save appends ".npy" only when
# given a path, so passing a handle keeps the existing ".np" filename.
with open(f"samples_seed{exp_config.seed}.np", "wb") as samples_file:
    np.save(samples_file, samples)

### Run second seed

In [None]:
# Seed-2 run: a seed-specific checkpoint name, presumably where run_experiment saves params.
exp_config.params_filename = "checkpoint_seed2"
exp_config.seed = 2

In [None]:
# Train with the current (seed 2) config. The trained parameters are presumably written to
# `exp_config.params_filename`, which the next cell reloads. NOTE(review): `results` is not used later.
results = run_experiment(exp_config)

In [None]:
# Reload the checkpoint (presumably just written by training) and evaluate the resulting sampler.
# The filename is read from the config instead of repeating the literal, so the two cannot drift.
filename = exp_config.params_filename
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(exp_config, transition_params=transition_params)
eval_info = evaluate(forward_pass_function)
print(eval_info)

In [None]:
# Run one forward pass of the trained sampler with a fixed key and save its samples.
particle_state = forward_pass_function(jax.random.PRNGKey(0))
samples = np.array(particle_state.samples)
# A context manager guarantees the file is flushed and closed. np.save appends ".npy" only when
# given a path, so passing a handle keeps the existing ".np" filename.
with open(f"samples_seed{exp_config.seed}.np", "wb") as samples_file:
    np.save(samples_file, samples)

### Run third seed

In [None]:
# Seed-3 run: a seed-specific checkpoint name, presumably where run_experiment saves params.
exp_config.params_filename = "checkpoint_seed3"
exp_config.seed = 3

In [None]:
# Train with the current (seed 3) config. The trained parameters are presumably written to
# `exp_config.params_filename`, which the next cell reloads. NOTE(review): `results` is not used later.
results = run_experiment(exp_config)

In [None]:
# Reload the checkpoint (presumably just written by training) and evaluate the resulting sampler.
# The filename is read from the config instead of repeating the literal, so the two cannot drift.
filename = exp_config.params_filename
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(exp_config, transition_params=transition_params)
eval_info = evaluate(forward_pass_function)
print(eval_info)

In [None]:
# Run one forward pass of the trained sampler with a fixed key and save its samples.
particle_state = forward_pass_function(jax.random.PRNGKey(0))
samples = np.array(particle_state.samples)
# A context manager guarantees the file is flushed and closed. np.save appends ".npy" only when
# given a path, so passing a handle keeps the existing ".np" filename.
with open(f"samples_seed{exp_config.seed}.np", "wb") as samples_file:
    np.save(samples_file, samples)

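# Evaluation of trained models

Reload each seed's saved checkpoint, evaluate it, and report the per-metric mean and standard error across the three seeds.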

In [None]:
# Seed 1: reload the saved checkpoint from disk and evaluate the sampler it defines.
# NOTE(review): `exp_config` still holds the seed-3 settings at this point; presumably only
# model/architecture fields matter when building the forward pass. Confirm.
filename = "checkpoint_seed1"
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(
    exp_config,
    transition_params=transition_params,
)
eval_info_seed1 = evaluate(forward_pass_function)
print(eval_info_seed1)

In [None]:
# Seed 2: reload the saved checkpoint from disk and evaluate the sampler it defines.
# NOTE(review): `exp_config` still holds the seed-3 settings at this point; presumably only
# model/architecture fields matter when building the forward pass. Confirm.
filename = "checkpoint_seed2"
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(
    exp_config,
    transition_params=transition_params,
)
eval_info_seed2 = evaluate(forward_pass_function)
print(eval_info_seed2)

In [None]:
# Seed 3: reload the saved checkpoint from disk and evaluate the sampler it defines.
filename = "checkpoint_seed3"
transition_params = load_checkpoint(filename)
forward_pass_function = make_forward_pass_func(
    exp_config,
    transition_params=transition_params,
)
eval_info_seed3 = evaluate(forward_pass_function)
print(eval_info_seed3)

In [None]:
# Stack each metric across seeds into one array per key (keys taken from the seed-1 results).
eval_info = {
    metric: np.asarray([info[metric] for info in (eval_info_seed1, eval_info_seed2, eval_info_seed3)])
    for metric in eval_info_seed1
}

In [None]:
# Mean of each metric across the three seeds (displayed as the cell output).
{metric: values.mean() for metric, values in eval_info.items()}

In [None]:
# Import the `stats` submodule explicitly: a bare `import scipy` does not guarantee that
# `scipy.stats` is loaded on older SciPy versions, and the next cell uses `scipy.stats.sem`.
import scipy.stats

In [None]:
# Standard error of each metric across seeds (displayed as the cell output).
# NOTE(review): ddof=0 (population std) is used; scipy's default is ddof=1. Confirm this is intended.
dict((metric, scipy.stats.sem(values, ddof=0)) for metric, values in eval_info.items())