In [1]:
from functools import partial
import os; os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import matplotlib.pyplot as plt
import numpy as onp

from jax import jit, lax, random, vmap
from jax.config import config; config.update('jax_platform_name', 'cpu')
from jax.experimental import optimizers, stax
import jax.numpy as np

from numpyro.contrib.autoguide import AutoIAFNormal
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.handlers import sample, scale, seed, substitute, trace
from numpyro.hmc_util import consensus, initialize_model, parametric_draws
from numpyro.mcmc import mcmc
from numpyro.svi import elbo, svi

### load data

In [2]:
def _load_dataset():
    _, fetch = load_dataset(COVTYPE, shuffle=False)
    features, labels = fetch()

    # normalize features and add intercept
    features = (features - features.mean(0)) / features.std(0)
    features = np.hstack([features, np.ones((features.shape[0], 1))])

    # make binary feature
    _, counts = onp.unique(labels, return_counts=True)
    specific_category = np.argmax(counts)
    labels = (labels == specific_category)

    N, dim = features.shape
    print("Data shape:", features.shape)
    print("Label distribution: {} has label 1, {} has label 0"
          .format(labels.sum(), N - labels.sum()))
    return features, labels

X_full, y_full = _load_dataset()

Data shape: (581012, 55)
Label distribution: 211840 has label 1, 369172 has label 0


### prepare train shards and test set

In [3]:
def get_train_shards_and_test_data(X, y, K, N, rng=None):
    if rng is not None:
        idxs = random.shuffle(rng, np.arange(X.shape[0]))
        X = X[idxs]
        y = y[idxs]
    shards = []
    for i in range(K):
        shards.append((X[i * N: (i + 1) * N], y[i * N: (i + 1) * N]))
    train_data = (X[:K * N], y[:K * N])
    test_data = (X[K * N:], y[K * N:])
    return shards, train_data, test_data

K, N = 40, 10000
shards, (X_train, y_train), (X_test, y_test) = get_train_shards_and_test_data(
    X_full, y_full, K, N, random.PRNGKey(0))
print("Train set contains {} ({}%) data points.".format(
    K * N, round(K * N / X_full.shape[0] * 100, 2)))
print("Test set contains {} ({}%) data points.".format(
    X_full.shape[0] - K * N, round(100 - K * N / X_full.shape[0] * 100, 2)))

Train set contains 400000 (68.85%) data points.
Test set contains 181012 (31.15%) data points.


### model

In [4]:
def model(X, y, prior_scale=1, likelihood_scale=1):
    with scale(prior_scale):
        coefs = sample('coefs', dist.Normal(np.zeros(X.shape[-1]), np.ones(X.shape[-1])))
    with scale(likelihood_scale):
        return sample('y', dist.Bernoulli(logits=np.dot(X, coefs)), obs=y)

### sampling

In [5]:
def get_subposterior(rng, shard, K):
    rngs = random.split(rng, 4)
    X, y = shard
    init_params, potential_fn, _ = initialize_model(rngs, model, X, y, prior_scale=1 / K)
    samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,
                   num_chains=4, potential_fn=potential_fn)
    return samples

In [6]:
rngs = random.split(random.PRNGKey(1), K)
subposteriors = []
for i, (rng, shard) in enumerate(zip(rngs, shards)):
    sep = '=' * 31
    print('\n ' + sep + ' SUBPOSTERIOR {:02d} '.format(i) + sep, end='')
    samples = get_subposterior(rng, shard, K)
    if not os.path.exists('.results'):
        os.makedirs('.results')
    np.save('.results/subposterior_{:02d}.npy'.format(i), samples)
    subposteriors.append(samples)



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.94      0.06      1.94      1.84      2.04  10596.19      1.00
  coefs[1]      0.00      0.03      0.00     -0.06      0.06  13018.15      1.00
  coefs[2]     -0.05      0.07     -0.05     -0.15      0.06   3962.76      1.00
  coefs[3]     -0.26      0.03     -0.26     -0.32     -0.21   8349.60      1.00
  coefs[4]     -0.13      0.04     -0.13     -0.19     -0.07   9673.05      1.00
  coefs[5]     -0.09      0.03     -0.09     -0.15     -0.04  11774.95      1.00
  coefs[6]      0.33      0.25      0.32     -0.08      0.73   3044.25      1.00
  coefs[7]     -0.70      0.15     -0.69     -0.94     -0.45   3230.34      1.00
  coefs[8]      0.61      0.29      0.60      0.12      1.08   3048.04      1.00
  coefs[9]     -0.04      0.03     -0.04     -0.08      0.01  11912.89      1.00
 coefs[10]      0.38      0.65      0.37     -0.66      1.50   3045.41      1.00
 coefs[11]     -0.05      

### merge subposteriors

In [8]:
consensus_samples = consensus(subposteriors, 10000)
parametric_samples = parametric_draws(subposteriors, 10000)
np.save('.results/consensus_samples.npy', consensus_samples)
np.save('.results/parametric_samples.npy', parametric_samples)

### predict

In [10]:
def predict(model, X, rng, sample):
    model = substitute(seed(model, rng), sample)
    model_trace = trace(model).get_trace(X, None)
    return model_trace['y']['value']

In [11]:
rngs = random.split(random.PRNGKey(2), 10000)

y_consensus = vmap(partial(predict, model, X_test))(rngs, consensus_samples)
y_consensus = (y_consensus.sum(axis=0) / y_consensus.shape[0]) >= 0.5
acc = (y_consensus == y_test).sum() / y_test.shape[0]
print('Consensus accuaracy: {}'.format(round(acc.item(), 4)))

y_parametric = vmap(partial(predict, model, X_test))(rngs, parametric_samples)
y_parametric = (y_parametric.sum(axis=0) / y_parametric.shape[0]) >= 0.5
acc = (y_parametric == y_test).sum() / y_test.shape[0]
print('Parametric accuaracy: {}'.format(round(acc.item(), 4)))

Consensus accuaracy: 0.771
Parametric accuaracy: 0.771


### train iaf

In [30]:
rng_guide, rng_init, rng_train = random.split(random.PRNGKey(3), 3)
opt_init, opt_update, get_params = optimizers.adam(0.001)
guide = AutoIAFNormal(rng_guide, model, get_params, nonlinearity=stax.Elu,
                      skip_connections=False)
batch_size = 1000
ll_scale = X_train.shape[0] // batch_size
svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params,
                              likelihood_scale=ll_scale)
num_iters = ll_scale

def epoch_train(epoch, opt_state, rng):
    # shuffle data
    rng, rng_shuffle = random.split(rng)
    idx = random.shuffle(rng_shuffle, np.arange(X_train.shape[0]))
    X = X_train[idx]
    y = y_train[idx]
    offset = epoch * num_iters

    def body_fn(val, i):
        batch_X = lax.dynamic_slice_in_dim(X, i * batch_size, batch_size)
        batch_y = lax.dynamic_slice_in_dim(y, i * batch_size, batch_size)
        opt_state_, rng_ = val
        loss, opt_state_, rng_ = svi_update(offset + i, rng_, opt_state_,
                                            (batch_X, batch_y), (batch_X, batch_y))
        return (opt_state_, rng_), loss

    (opt_state, _), losses = lax.scan(body_fn, (opt_state, rng), np.arange(num_iters))
    return opt_state, losses

In [31]:
opt_state, _ = svi_init(rng_init,
                        (X_train[:batch_size], y_train[:batch_size]),
                        (X_train[:batch_size], y_train[:batch_size]))

In [32]:
losses = np.array([])
num_epochs = 10
rngs = random.split(random.PRNGKey(2), 10000)
for epoch in range(num_epochs):
    rng = random.fold_in(rng_train, epoch)
    opt_state, epoch_loss = epoch_train(epoch, opt_state, rng)
    iaf_samples = guide.sample_posterior(random.PRNGKey(4), opt_state, sample_shape=(10000,))
    y_iaf = vmap(partial(predict, model, X_test))(rngs, iaf_samples)
    y_iaf = (y_iaf.sum(axis=0) / y_iaf.shape[0]) >= 0.5
    acc = (y_iaf == y_test).sum() / y_test.shape[0]
    print("Epoch {:02d} - loss {:.4f} - acc {}".format(
        epoch, np.mean(epoch_loss), round(acc.item(), 4)))
    losses = np.concatenate([losses, epoch_loss])

Epoch 00 - loss 1883.5496 - acc 0.7612
Epoch 01 - loss 688.8702 - acc 0.7706
Epoch 02 - loss 617.3635 - acc 0.7685
Epoch 03 - loss 601.3757 - acc 0.7711
Epoch 04 - loss 598.3813 - acc 0.7687
Epoch 05 - loss 593.7759 - acc 0.7708
Epoch 06 - loss 593.0955 - acc 0.7692
Epoch 07 - loss 590.5724 - acc 0.7714
Epoch 08 - loss 589.2113 - acc 0.7693
Epoch 09 - loss 588.8226 - acc 0.7721


### Flow HMC

In [77]:
transform = guide.get_transform(opt_state)
unpack_fn = guide.unpack_latent
latent_size = guide.latent_size

In [88]:
def make_transformed_pe(potential_fn, transform, unpack_fn, prior_scale):
    def transformed_potential_fn(z):
        u, intermediates = transform.call_with_intermediates(z)
        logdet = transform.log_abs_det_jacobian(z, u, intermediates=intermediates) * prior_scale
        return potential_fn(unpack_fn(u)) + logdet

    return transformed_potential_fn

In [79]:
def get_flow_subposterior(rng, shard, K):
    X, y = shard
    init_params = random.normal(rng, (4, latent_size))
    _, potential_fn, _ = initialize_model(rng, model, X, y, prior_scale=1 / K)
    transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn, 1 / K)
    samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,
                   num_chains=4, potential_fn=transformed_potential_fn)
    return samples

In [89]:
init_params = random.normal(random.PRNGKey(1), (latent_size,))
_, potential_fn, _ = initialize_model(rng, model, X, y, prior_scale=1 / K)
transformed_potential_fn = make_transformed_pe(potential_fn, transform, unpack_fn, 1 / K)
samples = mcmc(num_warmup=1000, num_samples=2500, init_params=init_params,
               num_chains=1, potential_fn=transformed_potential_fn)

warmup: 100%|██████████| 1000/1000 [03:07<00:00,  5.63it/s, 1023 steps of size 6.92e-05. acc. prob=0.78]
sample: 100%|██████████| 2500/2500 [07:29<00:00,  5.58it/s, 1023 steps of size 6.92e-05. acc. prob=0.88]




                 mean       std    median      5.0%     95.0%     n_eff     r_hat
 Param:0[0]    -13.94      0.25    -14.01    -14.26    -13.50      6.02      1.01
 Param:0[1]      3.60      0.10      3.58      3.46      3.77      4.80      1.81
 Param:0[2]     -1.00      0.36     -1.06     -1.56     -0.42      4.99      1.35
 Param:0[3]     -5.20      0.33     -5.32     -5.57     -4.54      3.37      1.68
 Param:0[4]     -2.45      0.67     -2.45     -3.60     -1.51      2.59      2.69
 Param:0[5]      6.63      0.44      6.74      5.81      7.18      3.79      1.65
 Param:0[6]      0.34      0.57      0.30     -0.56      1.29      7.26      1.01
 Param:0[7]     -1.72      0.30     -1.72     -2.22     -1.24      4.35      1.86
 Param:0[8]     10.02      0.59      9.89      9.13     11.04      4.90      1.15
 Param:0[9]     -6.03      0.49     -5.89     -6.84     -5.33      3.51      1.69
Param:0[10]     11.40      0.53     11.44     10.59     12.21      3.57      2.19
Param:0[11]   

In [94]:
flow_samples = vmap(lambda x: unpack_fn(transform(x)))(samples)

In [82]:
from numpyro.diagnostics import summary

In [97]:
rngs = random.split(random.PRNGKey(2), 2500)
y_flow = vmap(partial(predict, model, X_test))(rngs, flow_samples)
y_flow = (y_flow.sum(axis=0) / y_flow.shape[0]) >= 0.5
acc = (y_flow == y_test).sum() / y_test.shape[0]
acc

DeviceArray(0.7682474, dtype=float32)

In [92]:
summary({'coefs': real_samples['coefs'][None, ...]})



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  coefs[0]      1.97      0.07      1.97      1.86      2.09     17.99      1.01
  coefs[1]     -0.02      0.02     -0.02     -0.05      0.02     21.80      1.02
  coefs[2]     -0.12      0.02     -0.12     -0.14     -0.08      8.34      1.05
  coefs[3]     -0.31      0.02     -0.31     -0.34     -0.28      7.72      1.17
  coefs[4]     -0.11      0.02     -0.11     -0.14     -0.08     10.66      1.01
  coefs[5]     -0.10      0.02     -0.09     -0.12     -0.07     65.96      1.00
  coefs[6]      0.01      0.02      0.01     -0.03      0.05      7.04      1.15
  coefs[7]     -0.49      0.03     -0.49     -0.53     -0.44     14.74      1.00
  coefs[8]      0.24      0.02      0.23      0.20      0.27      5.82      1.47
  coefs[9]     -0.01      0.01     -0.01     -0.03      0.02      6.66      1.48
 coefs[10]      1.67      0.04      1.68      1.61      1.74      6.33      1.03
 coefs[11]      0.53      

In [52]:
X, y = shards[0]

In [51]:
rngs = random.split(random.PRNGKey(1), K)
subposteriors = []
for i, (rng, shard) in enumerate(zip(rngs, shards)):
    if i > 0:
        break
    sep = '=' * 31
    print('\n ' + sep + ' SUBPOSTERIOR {:02d} '.format(i) + sep, end='')
    samples = get_flow_subposterior(rng, shard, K)
    if not os.path.exists('.results/flow'):
        os.makedirs('.results/flow')
    np.save('.results/flow/subposterior_{:02d}.npy'.format(i), samples)
    subposteriors.append(samples)



                 mean       std    median      5.0%     95.0%     n_eff     r_hat


KeyboardInterrupt: 