In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
!echo $CUDA_VISIBLE_DEVICES

In [None]:
!git checkout craft-update 

In [None]:
from annealed_flow_transport.train import run_experiment
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from annealed_flow_transport.many_well_plotting import plot, plot_marginal_pair
from annealed_flow_transport.densities import ManyWell
from annealed_flow_transport.resampling import log_effective_sample_size

import matplotlib as mpl

mpl.rcParams['figure.dpi'] = 100

In [4]:
from configs.many_well_original import get_config as many_well_original_get_config
from configs.many_well import get_config as many_well_get_config

In [6]:
def get_experiment_config(seed, experiment_type = "original", long=False):
    if experiment_type == "original":
        exp_config = many_well_original_get_config()
        exp_config.n_samples_plotting = 100
        if long:
            exp_config.craft_num_iters = int(1e10 / exp_config.mcmc_config.hmc_num_leapfrog_steps 
                                             / exp_config.batch_size / (exp_config.num_temps - 1))
            exp_config.report_step = int(exp_config.craft_num_iters // 6)
    else:
        assert experiment_type == "custom"
        exp_config = many_well_get_config()
#         exp_config.craft_num_iters = int(exp_config.craft_num_iters/10)
#         exp_config.report_step = int(exp_config.report_step/10)
        if long:
            raise Exception
    exp_config.seed = seed
    exp_config.params_filename = f"checkpoint_{experiment_type}_L{long}_seed{exp_config.seed}"
    
    n_mill_target_eval = exp_config.craft_batch_size*exp_config.craft_num_iters*exp_config.num_temps*exp_config.mcmc_config.hmc_num_leapfrog_steps/1e6
    print(f"experiment using {n_mill_target_eval} target evaluations")
    return exp_config

# Train using the exact config

In [None]:
exp_config = get_experiment_config(0, "custom")
results0 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(0, "original")
results0 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(1, "original")
results1 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(2, "original")
results2 = run_experiment(exp_config)

# Train for longer

In [None]:
exp_config = get_experiment_config(0, "original", long=True)
results0 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(1, "original", long=True)
results1 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(2, "original", long=True)
results2 = run_experiment(exp_config)

# Train using FAB config

In [None]:
exp_config = get_experiment_config(0, "custom")
results0 = run_experiment(exp_config)

training for 244140 iterations
experiment using 12499.968 target evaluations


  0%|                                     | 4/244140 [00:01<14:52:53,  4.56it/s]

effective sample size of 0.34112319350242615
Step 0: Free energy -147.48159790039062 Log Normalizer estimate 161.5111083984375


  0%|                                    | 455/244140 [00:15<2:05:08, 32.45it/s]

In [None]:
exp_config = get_experiment_config(1, "custom")
results1 = run_experiment(exp_config)

In [None]:
exp_config = get_experiment_config(2, "custom")
results2 = run_experiment(exp_config)

# Evaluation

In [None]:
# jax.config.update("jax_enable_x64", True)

In [None]:
from evaluation import evaluate_many_well, load_checkpoint, make_forward_pass_func, \
    make_get_ess, make_get_resample_info, get_flow_init_params

In [None]:
target_density = ManyWell(exp_config.final_config.density, (32,))

In [None]:
eval_batch_size = 1000

In [None]:
def evaluate_Z_estimation(seed, experiment_type = "original", long=False, 
                          eval_batch_size=1000, n_eval=50, flow_identity=False):
    exp_config = get_experiment_config(seed, experiment_type, long)
    if flow_identity:
        transition_params = get_flow_init_params(exp_config)
    else:
        transition_params = load_checkpoint(exp_config.params_filename)
    forward_pass_function = make_forward_pass_func(exp_config, transition_params=transition_params,
                                                  eval_batch_size=eval_batch_size)
    key = jax.random.PRNGKey(0)
    abs_errors = []
    for i in range(n_eval):
        key, subkey = jax.random.split(key)
        particle_state = forward_pass_function(subkey)
        log_Z_estimate = particle_state.log_normalizer_estimate
        relative_error = jnp.exp(log_Z_estimate - target_density.log_Z) - 1
        abs_error = jnp.abs(relative_error)
        abs_errors.append(abs_error)
    return jnp.stack(abs_errors)

#### Original

In [None]:
# standard
abs_errors0 = evaluate_Z_estimation(0, "original", False)
jnp.mean(abs_errors0)

In [None]:
# standard
abs_errors1 = evaluate_Z_estimation(1, "original", False)
jnp.mean(abs_errors1)

In [None]:
abs_errors2 = evaluate_Z_estimation(2, "original", False)
jnp.mean(abs_errors2)

#### Original Long

In [None]:
abs_errors1 = evaluate_Z_estimation(1, "original", True)
jnp.mean(abs_errors1)

In [None]:
abs_errors2 = evaluate_Z_estimation(2, "original", True)
jnp.mean(abs_errors2)

#### Custom

In [None]:
abs_errors_long = evaluate_Z_estimation(0, "custom", False)

In [None]:
jnp.mean(abs_errors_long)

In [None]:
jnp.mean(abs_errors_long)

In [None]:
abs_errors0 = evaluate_Z_estimation(0)

In [None]:
abs_errors1 = evaluate_Z_estimation(1)

In [None]:
abs_errors2 = evaluate_Z_estimation(2)

In [None]:
jnp.mean(abs_errors0), jnp.mean(abs_errors1), jnp.mean(abs_errors2)

In [None]:
abs_errors0

### Evalation of full craft