Showing off Sequential Neural Likelihood/Likelihoord Ratio method on the example problem from https://arxiv.org/abs/1805.07226 detailed in A.1 with posteriors plotted in Figure 5a. 

SNLR is performing quite well. Still need to figure out why SNL isn't working as well

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import optax
from trax.jaxboard import SummaryWriter
from lbi.prior import SmoothedBoxPrior
from lbi.dataset import getDataLoaderBuilder
from lbi.diagnostics import MMD, ROC_AUC, LR_ROC_AUC
from lbi.sequential.sequential import sequential
from lbi.models.base import get_train_step, get_valid_step
from lbi.models.flows import InitializeFlow
from lbi.models.classifier import InitializeClassifier
from lbi.trainer import getTrainer
from lbi.sampler import hmc
from lbi.examples.TractableProblem.tractable_problem_functions import get_simulator

import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
import datetime

%load_ext autoreload
%autoreload 2

In [2]:
# remove top and right axis from plots
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

In [3]:
model_type = "classifier"  # "classifier" or "flow"

seed = 1234
rng, model_rng, hmc_rng = jax.random.split(jax.random.PRNGKey(seed), num=3)

# Model hyperparameters
num_layers = 5
hidden_dim = 512

# Optimizer hyperparmeters
max_norm = 1e-3
learning_rate = 3e-4
weight_decay = 1e-1
sync_period = 5
slow_step_size = 0.5

# Train hyperparameters
nsteps = 250000
patience = 500
eval_interval = 100

# Sequential hyperparameters
num_rounds = 1
num_initial_samples = 100000
num_samples_per_round = 1000
num_chains = 2

In [4]:
# set up simulation and observables
simulate, obs_dim, theta_dim = get_simulator()

# set up true model for posterior inference test
true_theta = np.array([0.7, -2.9, -1.0, -0.9, 0.6])
X_true = simulate(rng, true_theta, num_samples_per_theta=1)

In [5]:
data_loader_builder = getDataLoaderBuilder(
    sequential_mode=model_type,
    batch_size=128,
    train_split=0.95,
    num_workers=0,
    add_noise=False,
)

In [6]:
# set up prior
log_prior, sample_prior = SmoothedBoxPrior(
    theta_dim=theta_dim, lower=-3.0, upper=3.0, sigma=0.02
)

In [7]:
# Create model
if model_type == "classifier":
    model_params, loss, log_pdf = InitializeClassifier(
        model_rng=model_rng,
        obs_dim=obs_dim,
        theta_dim=theta_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    )
else:
    model_params, loss, (log_pdf, sample) = InitializeFlow(
        model_rng=model_rng,
        obs_dim=obs_dim,
        theta_dim=theta_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    )

# Create optimizer
optimizer = optax.chain(
    # Set the parameters of Adam optimizer
    optax.adamw(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
    ),
    optax.adaptive_grad_clip(max_norm),
)
optimizer = optax.lookahead(
    optimizer, sync_period=sync_period, slow_step_size=slow_step_size
)

model_params = optax.LookaheadParams.init_synced(model_params)
opt_state = optimizer.init(model_params)

# Create trainer
train_step = get_train_step(loss, optimizer)
valid_step = get_valid_step({"valid_loss": loss})

trainer = getTrainer(
    train_step,
    valid_step=valid_step,
    nsteps=nsteps,
    eval_interval=eval_interval,
    patience=patience,
    logger=None,
    train_kwargs=None,
    valid_kwargs=None,
)

In [8]:
# Train model sequentially
model_params, Theta_post = sequential(
    rng,
    X_true,
    model_params,
    log_pdf,
    log_prior,
    sample_prior,
    simulate,
    opt_state,
    trainer,
    data_loader_builder,
    num_rounds=num_rounds,
    num_initial_samples=num_initial_samples,
    num_samples_per_round=num_samples_per_round,
    num_samples_per_theta=1,
    num_chains=num_chains,
    logger=None,
)

STARTING ROUND 1


Valid loss: 0.0025:   2%|▏         | 4490/250000 [04:12<3:49:50, 17.80it/s]


Keyboard interrupted. Stopping early


sample: 100%|██████████| 2000/2000 [00:20<00:00, 95.56it/s, 7 steps of size 6.23e-01. acc. prob=0.83] 
sample: 100%|██████████| 2000/2000 [00:18<00:00, 109.52it/s, 7 steps of size 5.28e-01. acc. prob=0.89]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.68      0.02      0.68      0.65      0.70   1633.97      1.00
Param:0[1]     -2.87      0.02     -2.87     -2.90     -2.84   2303.68      1.00
Param:0[2]      0.47      0.45      0.49     -0.23      1.16   1059.04      1.00
Param:0[3]      0.29      0.32      0.30     -0.21      0.81   2324.96      1.00
Param:0[4]      1.67      0.56      1.68      0.81      2.63   1475.34      1.00

Number of divergences: 85


  plt.show()


In [9]:
def potential_fn(theta):
    if len(theta.shape) == 1:
        theta = theta[None, :]
    log_post = (
        -log_pdf(
            model_params.fast if hasattr(model_params, "fast") else model_params,
            X_true,
            theta,
        )
        - log_prior(theta)
    )
    return log_post.sum()

In [10]:
num_chains = 2
init_theta = sample_prior(rng, num_samples=num_chains)

mcmc = hmc(
    rng,
    potential_fn,
    init_theta,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=True,
    step_size=1e0,
    max_tree_depth=6,
    num_warmup=2000,
    num_samples=2000,
    num_chains=num_chains,
)
mcmc.print_summary()

theta_samples = mcmc.get_samples(group_by_chain=False).squeeze()

theta_dim = theta_samples.shape[-1]
true_theta = onp.array([0.7, -2.9, -1.0, -0.9, 0.6])

sample: 100%|██████████| 4000/4000 [00:34<00:00, 114.98it/s, 3 steps of size 5.44e-01. acc. prob=0.88]
sample: 100%|██████████| 4000/4000 [00:29<00:00, 133.83it/s, 3 steps of size 6.08e-01. acc. prob=0.85]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.68      0.02      0.68      0.66      0.71   4041.89      1.00
Param:0[1]     -2.90      0.02     -2.90     -2.93     -2.87   3791.87      1.00
Param:0[2]      0.52      0.40      0.52     -0.16      1.18   4680.21      1.00
Param:0[3]      0.26      0.33      0.27     -0.28      0.79   4019.45      1.00
Param:0[4]      1.67      0.57      1.68      0.70      2.62   3292.51      1.00

Number of divergences: 143


In [12]:
corner.corner(
    onp.array(theta_samples),
    range=[(-3, 3) for i in range(theta_dim)],
    truths=true_theta,
    bins=75,
    smooth=(1.0),
    smooth1d=(1.0),
)
plt.show()
# plt.savefig("hmc_corner.png")