Showing off Sequential Neural Likelihood/Likelihoord Ratio method on the example problem from https://arxiv.org/abs/1805.07226 detailed in A.1 with posteriors plotted in Figure 5a. 

SNLR is performing quite well. Still need to figure out why SNL isn't working as well

In [1]:
import jax
import jax.numpy as np
import numpy as onp
import optax
from trax.jaxboard import SummaryWriter
from lbi.prior import SmoothedBoxPrior
from lbi.dataset import getDataLoaderBuilder
from lbi.diagnostics import MMD, ROC_AUC, LR_ROC_AUC
from lbi.sequential.sequential import sequential
from lbi.models.steps import get_train_step, get_valid_step
from lbi.models.flows import InitializeFlow
from lbi.models.classifier import InitializeClassifier
from lbi.trainer import getTrainer
from lbi.sampler import hmc
from lbi.examples.TractableProblem.tractable_problem_functions import get_simulator

import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
import datetime

%load_ext autoreload
%autoreload 2

2021-10-11 18:26:21.794890: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [2]:
# remove top and right axis from plots
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

In [3]:
model_type = "flow"  # "classifier" or "flow"

seed = 1234
rng, model_rng, hmc_rng = jax.random.split(jax.random.PRNGKey(seed), num=3)

# Model hyperparameters
num_layers = 5
hidden_dim = 512

# Optimizer hyperparmeters
max_norm = 1e-3
learning_rate = 3e-4
weight_decay = 1e-1
sync_period = 5
slow_step_size = 0.5

# Train hyperparameters
nsteps = 1000
patience = 50
eval_interval = 100

# Sequential hyperparameters
num_rounds = 10
num_initial_samples = 10000
num_samples_per_round = 10000
num_chains = 2

In [4]:
# set up simulation and observables
simulate, obs_dim, theta_dim = get_simulator()

# set up true model for posterior inference test
true_theta = np.array([0.7, -2.9, -1.0, -0.9, 0.6])
X_true = simulate(rng, true_theta, num_samples_per_theta=1)

In [5]:
data_loader_builder = getDataLoaderBuilder(
    sequential_mode=model_type,
    batch_size=128,
    train_split=0.95,
    num_workers=0,
    add_noise=False,
)

In [6]:
# set up prior
log_prior, sample_prior = SmoothedBoxPrior(
    theta_dim=theta_dim, lower=-3.0, upper=3.0, sigma=0.02
)

In [7]:
# Create model
if model_type == "classifier":
    model_params, loss, log_pdf = InitializeClassifier(
        model_rng=model_rng,
        obs_dim=obs_dim,
        theta_dim=theta_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    )
else:
    model_params, loss, (log_pdf, sample) = InitializeFlow(
        model_rng=model_rng,
        obs_dim=obs_dim,
        theta_dim=theta_dim,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
    )

# Create optimizer
optimizer = optax.chain(
    # Set the parameters of Adam optimizer
    optax.adamw(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
    ),
    optax.adaptive_grad_clip(max_norm),
)
optimizer = optax.lookahead(
    optimizer, sync_period=sync_period, slow_step_size=slow_step_size
)

model_params = optax.LookaheadParams.init_synced(model_params)
opt_state = optimizer.init(model_params)

# Create trainer
train_step = get_train_step(loss, optimizer)
valid_step = get_valid_step({"valid_loss": loss})

trainer = getTrainer(
    train_step,
    valid_step=valid_step,
    nsteps=nsteps,
    eval_interval=eval_interval,
    patience=patience,
    logger=None,
    train_kwargs=None,
    valid_kwargs=None,
)

In [8]:
# Train model sequentially
model_params, Theta_post = sequential(
    rng,
    X_true,
    model_params,
    log_pdf,
    log_prior,
    sample_prior,
    simulate,
    opt_state,
    trainer,
    data_loader_builder,
    num_rounds=num_rounds,
    num_initial_samples=num_initial_samples,
    num_samples_per_round=num_samples_per_round,
    num_samples_per_theta=1,
    num_chains=num_chains,
    logger=None,
)

STARTING ROUND 1


Valid loss: 1.5555: 100%|██████████| 1000/1000 [00:35<00:00, 27.79it/s]
sample: 100%|██████████| 20000/20000 [01:21<00:00, 244.72it/s, 15 steps of size 1.16e-01. acc. prob=0.94]
sample: 100%|██████████| 20000/20000 [00:53<00:00, 373.35it/s, 23 steps of size 1.69e-01. acc. prob=0.69]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.73      0.01      0.73      0.71      0.75   6693.40      1.00
Param:0[1]     -2.63      0.06     -2.62     -2.70     -2.53   3033.91      1.00
Param:0[2]      1.23      0.27      1.27      0.85      1.61   1672.12      1.00
Param:0[3]     -2.77      0.27     -2.84     -3.04     -2.51   2064.43      1.00
Param:0[4]      2.47      0.23      2.49      2.13      2.84   3975.95      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 2


Valid loss: -5.8855: 100%|██████████| 1000/1000 [00:16<00:00, 60.72it/s]
sample: 100%|██████████| 20000/20000 [01:12<00:00, 275.38it/s, 51 steps of size 6.47e-02. acc. prob=0.92]
sample: 100%|██████████| 20000/20000 [00:22<00:00, 886.03it/s, 7 steps of size 5.98e-01. acc. prob=0.85] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.70      0.01      0.70      0.68      0.73  13925.07      1.00
Param:0[1]     -2.91      0.02     -2.91     -2.94     -2.88  10898.42      1.00
Param:0[2]      0.92      0.09      0.92      0.77      1.07  12899.71      1.00
Param:0[3]      0.08      0.33      0.11     -0.46      0.61   9789.57      1.00
Param:0[4]      0.88      0.12      0.89      0.67      1.07  11923.47      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 3


Valid loss: -13.0592: 100%|██████████| 1000/1000 [00:14<00:00, 69.77it/s]
sample: 100%|██████████| 20000/20000 [00:31<00:00, 636.62it/s, 11 steps of size 2.98e-01. acc. prob=0.86]
sample: 100%|██████████| 20000/20000 [00:18<00:00, 1094.07it/s, 7 steps of size 5.04e-01. acc. prob=0.89]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.71      0.02      0.71      0.68      0.73      1.31      1.81
Param:0[1]     -2.86      0.06     -2.86     -2.94     -2.77      1.09      3.01
Param:0[2]     -0.06      1.09      0.04     -1.20      1.06      1.00     32.31
Param:0[3]      0.65      1.15      1.30     -0.80      1.95      1.04      4.38
Param:0[4]      0.76      0.11      0.76      0.57      0.93     17.51      1.03

Number of divergences: 0


  plt.show()


STARTING ROUND 4


Valid loss: -11.2087: 100%|██████████| 1000/1000 [00:15<00:00, 63.13it/s]
sample: 100%|██████████| 20000/20000 [00:20<00:00, 999.83it/s, 7 steps of size 5.46e-01. acc. prob=0.93] 
sample: 100%|██████████| 20000/20000 [00:17<00:00, 1135.59it/s, 7 steps of size 6.44e-01. acc. prob=0.89]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.72      0.01      0.72      0.70      0.73      1.16      2.28
Param:0[1]     -2.88      0.01     -2.88     -2.89     -2.86      2.68      1.19
Param:0[2]     -0.02      1.03     -0.01     -1.07      1.03      1.00     75.11
Param:0[3]      1.20      0.14      1.20      0.99      1.40      1.19      2.17
Param:0[4]      0.71      0.05      0.71      0.62      0.79      1.82      1.36

Number of divergences: 0


  plt.show()


STARTING ROUND 5


Valid loss: -13.6414: 100%|██████████| 1000/1000 [00:15<00:00, 63.12it/s]
sample: 100%|██████████| 20000/20000 [00:20<00:00, 976.25it/s, 7 steps of size 7.09e-01. acc. prob=0.89] 
sample: 100%|██████████| 20000/20000 [00:19<00:00, 1008.91it/s, 7 steps of size 5.21e-01. acc. prob=0.94]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.71      0.00      0.71      0.70      0.72  23866.00      1.00
Param:0[1]     -2.89      0.00     -2.89     -2.89     -2.88  23722.56      1.00
Param:0[2]      1.01      0.01      1.01      1.00      1.03  21948.38      1.00
Param:0[3]      0.93      0.03      0.93      0.89      0.97  24616.03      1.00
Param:0[4]      0.65      0.02      0.65      0.62      0.68  22487.81      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 6


Valid loss: -19.5586: 100%|██████████| 1000/1000 [00:16<00:00, 59.96it/s]
sample: 100%|██████████| 20000/20000 [00:18<00:00, 1058.48it/s, 7 steps of size 6.56e-01. acc. prob=0.90]
sample: 100%|██████████| 20000/20000 [00:16<00:00, 1231.72it/s, 7 steps of size 6.70e-01. acc. prob=0.90]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.69      0.00      0.69      0.69      0.70  22884.40      1.00
Param:0[1]     -2.88      0.00     -2.88     -2.89     -2.87  23644.81      1.00
Param:0[2]     -1.00      0.01     -1.00     -1.01     -0.99  23229.62      1.00
Param:0[3]      0.84      0.02      0.84      0.82      0.87  24097.15      1.00
Param:0[4]      0.63      0.01      0.63      0.61      0.65  23952.97      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 7


Valid loss: -19.7389: 100%|██████████| 1000/1000 [00:17<00:00, 57.19it/s]
sample: 100%|██████████| 20000/20000 [00:19<00:00, 1032.83it/s, 7 steps of size 6.29e-01. acc. prob=0.91]
sample: 100%|██████████| 20000/20000 [00:15<00:00, 1273.27it/s, 7 steps of size 7.36e-01. acc. prob=0.87]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.70      0.00      0.70      0.70      0.71  22104.39      1.00
Param:0[1]     -2.90      0.00     -2.90     -2.91     -2.89  24956.33      1.00
Param:0[2]      1.02      0.01      1.02      1.01      1.03  20739.00      1.00
Param:0[3]      0.90      0.01      0.90      0.89      0.91  25123.37      1.00
Param:0[4]      0.59      0.01      0.59      0.57      0.60  23964.88      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 8


Valid loss: -19.7875: 100%|██████████| 1000/1000 [00:18<00:00, 54.81it/s]
sample: 100%|██████████| 20000/20000 [00:18<00:00, 1078.65it/s, 7 steps of size 6.65e-01. acc. prob=0.89]
sample: 100%|██████████| 20000/20000 [00:15<00:00, 1261.06it/s, 7 steps of size 6.87e-01. acc. prob=0.90]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.69      0.00      0.69      0.69      0.70      2.21      1.26
Param:0[1]     -2.90      0.00     -2.90     -2.90     -2.89     11.98      1.04
Param:0[2]     -0.01      1.00      0.00     -1.01      1.00      1.00    152.68
Param:0[3]      0.91      0.01      0.91      0.90      0.93      4.54      1.10
Param:0[4]      0.63      0.01      0.63      0.61      0.65      2.61      1.20

Number of divergences: 0


  plt.show()


STARTING ROUND 9


Valid loss: -18.2406: 100%|██████████| 1000/1000 [00:18<00:00, 54.05it/s]
sample: 100%|██████████| 20000/20000 [00:17<00:00, 1159.45it/s, 7 steps of size 7.30e-01. acc. prob=0.88]
sample: 100%|██████████| 20000/20000 [00:16<00:00, 1236.00it/s, 7 steps of size 5.93e-01. acc. prob=0.93]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.71      0.00      0.71      0.71      0.72  24358.95      1.00
Param:0[1]     -2.89      0.00     -2.89     -2.90     -2.88  26507.03      1.00
Param:0[2]      1.01      0.01      1.01      1.00      1.02  22766.26      1.00
Param:0[3]      0.88      0.01      0.88      0.87      0.89  26304.88      1.00
Param:0[4]      0.61      0.01      0.61      0.60      0.63  24836.85      1.00

Number of divergences: 0


  plt.show()


STARTING ROUND 10


Valid loss: -21.2366: 100%|██████████| 1000/1000 [00:19<00:00, 51.41it/s]
sample: 100%|██████████| 20000/20000 [00:17<00:00, 1113.25it/s, 7 steps of size 6.85e-01. acc. prob=0.90]
sample: 100%|██████████| 20000/20000 [00:15<00:00, 1271.62it/s, 7 steps of size 6.75e-01. acc. prob=0.91]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.70      0.00      0.70      0.70      0.71  22972.41      1.00
Param:0[1]     -2.90      0.00     -2.90     -2.90     -2.89  26027.96      1.00
Param:0[2]      1.00      0.00      1.00      0.99      1.01  20544.33      1.00
Param:0[3]      0.90      0.01      0.90      0.89      0.91  24980.98      1.00
Param:0[4]      0.58      0.01      0.58      0.56      0.59  24734.77      1.00

Number of divergences: 0


  plt.show()


In [9]:
def potential_fn(theta):
    if len(theta.shape) == 1:
        theta = theta[None, :]
    log_post = (
        -log_pdf(
            model_params.fast if hasattr(model_params, "fast") else model_params,
            X_true,
            theta,
        )
        - log_prior(theta)
    )
    return log_post.sum()

In [10]:
num_chains = 20
init_theta = sample_prior(rng, num_samples=num_chains)

mcmc = hmc(
    rng,
    potential_fn,
    init_theta,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    dense_mass=True,
    step_size=1e0,
    max_tree_depth=6,
    num_warmup=2000,
    num_samples=2000,
    num_chains=num_chains,
)
mcmc.print_summary()

theta_samples = mcmc.get_samples(group_by_chain=False).squeeze()

theta_dim = theta_samples.shape[-1]
true_theta = onp.array([0.7, -2.9, -1.0, -0.9, 0.6])

sample: 100%|██████████| 4000/4000 [00:05<00:00, 685.73it/s, 3 steps of size 6.19e-01. acc. prob=0.86] 
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1110.10it/s, 7 steps of size 6.12e-01. acc. prob=0.89]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1057.88it/s, 3 steps of size 5.09e-01. acc. prob=0.91]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1074.26it/s, 7 steps of size 6.30e-01. acc. prob=0.90]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1056.41it/s, 7 steps of size 5.77e-01. acc. prob=0.90]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1218.10it/s, 3 steps of size 6.89e-01. acc. prob=0.88]
sample: 100%|██████████| 4000/4000 [00:00<00:00, 4741.30it/s, 1 steps of size 1.18e-38. acc. prob=0.00]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1145.26it/s, 7 steps of size 5.98e-01. acc. prob=0.91]
sample: 100%|██████████| 4000/4000 [00:03<00:00, 1142.16it/s, 7 steps of size 6.56e-01. acc. prob=0.87]
sample: 100%|██████████| 4000/4000 [00:04<00:00, 925.15it/s, 7 s


                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.71      0.26      0.70      0.03      0.72     10.00     86.53
Param:0[1]     -2.09      1.95     -2.90     -2.91      2.59     10.00    437.74
Param:0[2]      0.52      1.23      0.99     -1.01      2.40     10.00    256.07
Param:0[3]     -0.07      1.14     -0.01     -1.02      1.55     10.00    130.86
Param:0[4]      0.47      0.92      0.59     -2.43      0.61     10.01     58.70

Number of divergences: 6000


In [11]:
corner.corner(
    onp.array(theta_samples),
    range=[(-3, 3) for i in range(theta_dim)],
    truths=true_theta,
    bins=75,
    smooth=(1.0),
    smooth1d=(1.0),
)
# plt.show()
plt.savefig("hmc_corner.png")