In [12]:
import os, sys
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
proj = Path.home() / "papercode/variational_sampling_methods"
sys.path.insert(0, str(proj))
# os.environ['PYTHONPATH'] = os.environ.get("PYTHONPATH","") + "~/papercode/variational_sampling_methods"
from algorithms.common import flow_transport, markov_kernel
import hydra
import jax
import matplotlib
import distrax
import jax.numpy as jnp
import numpy as np
import time
import matplotlib.pyplot as plt
from targets.funnel import Funnel
from targets.gaussian_mixture import GaussianMixtureModel
from omegaconf import OmegaConf
os.environ['HYDRA_FULL_ERROR'] = '1'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
alg_cfg = OmegaConf.create({
    "num_temps": 128,
    "resample_threshold": 0.3,
    "init_mean": 0.0,
    "init_std":1.0,
    "use_resampling": False,
    "use_markov": True,
})

mcmc_cfg = SimpleNamespace()
mcmc_cfg.mcmc_kernel='hmc'
mcmc_cfg.hmc_step_times=[0., 0.5]
mcmc_cfg.hmc_steps_per_iter=1
mcmc_cfg.hmc_num_leapfrog_steps=10
mcmc_cfg.hmc_step_sizes = [0.001, 0.1]
dim=2
seed =0

In [None]:
from targets.gmm40 import GMM40


target = Funnel(dim=dim, sample_bounds=[-30., 30.])
target = GMM40(dim=dim, sample_bounds=[-30., 30.])
final_log_density = target.log_prob
key = jax.random.PRNGKey(seed)

initial_density = distrax.MultivariateNormalDiag(jnp.ones(dim) * alg_cfg.init_mean, jnp.ones(dim) * alg_cfg.init_std)
log_density_initial = initial_density.log_prob
initial_sampler = initial_density.sample
density_by_step = flow_transport.GeometricAnnealingSchedule(log_density_initial, final_log_density, alg_cfg.num_temps)
markov_kernel_by_step = markov_kernel.MarkovTransitionKernel(mcmc_cfg, density_by_step, alg_cfg.num_temps)

logger = {
    'KL/elbo': [],
    'KL/eubo': [],
    'logZ/delta_forward': [],
    'logZ/forward': [],
    'logZ/delta_reverse': [],
    'logZ/reverse': [],
    'ESS/forward': [],
    'ESS/reverse': [],
    'discrepancies/mmd': [],
    'discrepancies/sd': [],
    'other/target_log_prob': [],
    'other/EMC': [],
    "stats/step": [],
    "stats/wallclock": [],
    "stats/nfe": [],
}

n_eval_samples = 1000
batch_size= 2000
def eval_fn(samples, elbo, rev_lnz, eubo, fwd_lnz):

    if target.log_Z is not None:
        logger['logZ/delta_reverse'].append(jnp.abs(rev_lnz - target.log_Z))

    logger['logZ/reverse'].append(rev_lnz)
    logger['KL/elbo'].append(elbo)
    logger['other/target_log_prob'].append(jnp.mean(target.log_prob(samples)))

    if eubo is not None:
        if target.log_Z is not None:
            logger['logZ/delta_forward'].append(jnp.abs(fwd_lnz - target.log_Z))
        logger['logZ/forward'].append(fwd_lnz)
        logger['KL/eubo'].append(eubo)

    logger.update(target.visualise(samples=samples, show=True))
    return logger

In [14]:
from algorithms.common.eval_methods.sis_methods import get_eval_fn
from algorithms.smc.smc import get_short_inner_loop, get_short_reverse_inner_loop


key, subkey = jax.random.split(key)
num_temps = alg_cfg.num_temps
target_samples = target.sample(jax.random.PRNGKey(0), (n_eval_samples,))
alg_cfg.use_resampling=True
samples = initial_sampler(seed=jax.random.PRNGKey(0), sample_shape=(batch_size,))
log_weights = -jnp.log(batch_size) * jnp.ones(batch_size)
inner_loop_jit = jax.jit(get_short_inner_loop(markov_kernel_by_step, density_by_step, alg_cfg))
reverse_inner_loop_jit = jax.jit(get_short_reverse_inner_loop(markov_kernel_by_step, density_by_step, alg_cfg))


ln_z = 0.
elbo = 0.
start_time = time.time()

acceptance_hmc = []
acceptance_rwm = []

for step in range(1, num_temps):
    subkey, key = jax.random.split(key)
    samples, log_weights, ln_z_inc, acceptance = inner_loop_jit(
        subkey, samples, log_weights, step)
    ln_z_inc, elbo_inc = ln_z_inc
    acceptance_hmc.append(float(np.asarray(acceptance[0])))
    acceptance_rwm.append(float(np.asarray(acceptance[1])))
    ln_z += ln_z_inc
    elbo += elbo_inc


logger = eval_fn(samples, elbo, ln_z, None, None)

finish_time = time.time()
delta_time = finish_time - start_time

  plt.show()


In [15]:
wb_img = logger['figures/vis'][0]
pil = wb_img.image
plt.figure()
plt.imshow(pil)
plt.show()
wb_img.image.save("smc_result.png")
logger

  plt.show()


{'KL/elbo': [Array(-2.3806105, dtype=float32)],
 'KL/eubo': [],
 'logZ/delta_forward': [],
 'logZ/forward': [],
 'logZ/delta_reverse': [Array(2.3005915, dtype=float32)],
 'logZ/reverse': [Array(-2.3005915, dtype=float32)],
 'ESS/forward': [],
 'ESS/reverse': [],
 'discrepancies/mmd': [],
 'discrepancies/sd': [],
 'other/target_log_prob': [Array(-6.529519, dtype=float32)],
 'other/EMC': [],
 'stats/step': [],
 'stats/wallclock': [],
 'stats/nfe': [],
 'figures/vis': [<wandb.sdk.data_types.image.Image at 0x737f346c5a50>]}

Use AIS Sampling Methods

In [20]:
alg_cfg=SimpleNamespace()
alg_cfg.alpha=2
smc_cfg=SimpleNamespace()
smc_cfg.batch_size=batch_size
smc_cfg.use_resampling=False
smc_cfg.n_intermediate_distributions=12
smc_cfg.spacing_type="linear"
smc_cfg.transition_operator="hmc"  # [hmc or metropolis]
smc_cfg.point_is_valid_fn_type="default"   # [default, in_bounds]

hmc_cfg=SimpleNamespace()
hmc_cfg.n_outer_steps=1
hmc_cfg.n_inner_steps=10
hmc_cfg.init_step_size=1e-1
hmc_cfg.target_p_accept=0.65
hmc_cfg.tune_step_size=True
seed=0

In [21]:
from algorithms.fab.flow.simple_flow import make_realnvp_flow_networks
from algorithms.fab.sampling.point_is_valid import default_point_is_valid_fn
from algorithms.fab.sampling.mcmc.hmc import build_blackjax_hmc
from algorithms.fab.sampling.smc import build_smc


transition_operator = build_blackjax_hmc(
    dim=dim,
    n_outer_steps=hmc_cfg.n_outer_steps,
    init_step_size=hmc_cfg.init_step_size,
    target_p_accept=hmc_cfg.target_p_accept,
    adapt_step_size=hmc_cfg.tune_step_size,
    n_inner_steps=hmc_cfg.n_inner_steps)
smc = build_smc(transition_operator=transition_operator,
            n_intermediate_distributions=smc_cfg.n_intermediate_distributions,
            spacing_type=smc_cfg.spacing_type, alpha=alg_cfg.alpha,
            use_resampling=smc_cfg.use_resampling, point_is_valid_fn=default_point_is_valid_fn)
rng = jax.random.PRNGKey(seed)
rng, key, rng_params = jax.random.split(rng,3)
smc_state = smc.init(key)
nfm = make_realnvp_flow_networks(num_blocks=6, in_channels=2, channels=256)
flow_params = nfm.init({"params": rng_params}, batch_size=5)  
# samples, loq_p = nfm.apply(flow_params, mode='sample', low=[-20,-20], high=[20,20], rng=key, n_samples=batch_size)
low_bound = [-20, -20]
high_bound = [20, 20]
def log_q_fn(x):
    return nfm.apply(flow_params, mode='log_prob', low=low_bound, high=high_bound, x=x)

def body_fn(carry, xs):
    """Generate samples with AIS/SMC."""
    smc_state = carry
    key = xs
    x0, _= nfm.apply(flow_params, mode='sample', low=low_bound, high=high_bound, rng=key, n_samples=batch_size)
    
    point, log_w, smc_state, smc_info = smc.step(x0, smc_state, log_q_fn, target.log_prob)# compute metrics
    ln_z = jax.nn.logsumexp(log_w, axis=-1) - jnp.log(batch_size)
    elbo = jnp.mean(log_w)
    return smc_state, (point.x, log_w, point.log_q, elbo, ln_z)

smc_state, (x, log_w, log_q, elbo, ln_z) = jax.lax.scan(body_fn, init=smc_state,
                                            xs=jax.random.split(key, 10))
logger = eval_fn(x[-1], elbo[-1], ln_z[-1], None, None)

In [23]:
from IPython.display import display
wb_img = logger['figures/vis'][0]
pil = wb_img.image
plt.figure()
plt.imshow(pil)
plt.show()
wb_img.image.save("ais_result.png")
# display(logger['figures/vis'][0])

  plt.show()


In [24]:
logger

{'KL/elbo': [Array(-2.3806105, dtype=float32),
  Array(-3.7049894, dtype=float32),
  Array(-2.3510933, dtype=float32)],
 'KL/eubo': [],
 'logZ/delta_forward': [],
 'logZ/forward': [],
 'logZ/delta_reverse': [Array(2.3005915, dtype=float32),
  Array(1.2598205, dtype=float32),
  Array(120.710464, dtype=float32)],
 'logZ/reverse': [Array(-2.3005915, dtype=float32),
  Array(-1.2598205, dtype=float32),
  Array(120.710464, dtype=float32)],
 'ESS/forward': [],
 'ESS/reverse': [],
 'discrepancies/mmd': [],
 'discrepancies/sd': [],
 'other/target_log_prob': [Array(-6.529519, dtype=float32),
  Array(-6.72477, dtype=float32),
  Array(-6.1614933, dtype=float32)],
 'other/EMC': [],
 'stats/step': [],
 'stats/wallclock': [],
 'stats/nfe': [],
 'figures/vis': [<wandb.sdk.data_types.image.Image at 0x737f2c7de380>]}

In [28]:
alg_cfg=SimpleNamespace()
alg_cfg.alpha=2
smc_cfg.batch_size=batch_size
smc_cfg.use_resampling=False
smc_cfg.n_intermediate_distributions=12
smc_cfg.spacing_type="linear"
smc_cfg.transition_operator="hmc"  # [hmc or metropolis]
smc_cfg.point_is_valid_fn_type="default"   # [default, in_bounds]

mh_cfg=SimpleNamespace()
mh_cfg.n_outer_steps=1
mh_cfg.init_step_size=10.
mh_cfg.target_p_accept=0.65
mh_cfg.tune_step_size=True
seed=0

In [29]:
from algorithms.fab.sampling.mcmc.metropolis import build_metropolis


transition_operator = build_metropolis(dim, mh_cfg.n_outer_steps,
                                               mh_cfg.init_step_size,
                                               target_p_accept=mh_cfg.target_p_accept,
                                               tune_step_size=mh_cfg.tune_step_size)
smc = build_smc(transition_operator=transition_operator,
            n_intermediate_distributions=smc_cfg.n_intermediate_distributions,
            spacing_type=smc_cfg.spacing_type, alpha=alg_cfg.alpha,
            use_resampling=smc_cfg.use_resampling, point_is_valid_fn=default_point_is_valid_fn)
rng = jax.random.PRNGKey(seed)
rng, key, rng_params = jax.random.split(rng,3)
smc_state = smc.init(key)
nfm = make_realnvp_flow_networks(num_blocks=6, in_channels=2, channels=256)
flow_params = nfm.init({"params": rng_params}, batch_size=5)  
# samples, loq_p = nfm.apply(flow_params, mode='sample', low=[-20,-20], high=[20,20], rng=key, n_samples=batch_size)
low_bound = [-20, -20]
high_bound = [20, 20]
def log_q_fn(x):
    return nfm.apply(flow_params, mode='log_prob', low=low_bound, high=high_bound, x=x)

def body_fn(carry, xs):
    """Generate samples with AIS/SMC."""
    smc_state = carry
    key = xs
    x0, _= nfm.apply(flow_params, mode='sample', low=low_bound, high=high_bound, rng=key, n_samples=batch_size)
    
    point, log_w, smc_state, smc_info = smc.step(x0, smc_state, log_q_fn, target.log_prob)# compute metrics
    ln_z = jax.nn.logsumexp(log_w, axis=-1) - jnp.log(batch_size)
    elbo = jnp.mean(log_w)
    return smc_state, (point.x, log_w, point.log_q, elbo, ln_z)

smc_state, (x, log_w, log_q, elbo, ln_z) = jax.lax.scan(body_fn, init=smc_state,
                                            xs=jax.random.split(key, 10))
logger = eval_fn(x[-1], elbo[-1], ln_z[-1], None, None)

  plt.show()


In [30]:
from IPython.display import display
wb_img = logger['figures/vis'][0]
pil = wb_img.image
plt.figure()
plt.imshow(pil)
plt.show()
wb_img.image.save("ais_mh_result.png")
# display(logger['figures/vis'][0])

  plt.show()


logger