In [15]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from sklearn.datasets import make_moons
from jax.flatten_util import ravel_pytree

try:
    import jaxopt
except ModuleNotFoundError:
    %pip install -qq jaxopt

try:
    import distrax
except ModuleNotFoundError:
    %pip install -qq distrax
    import distrax


try:
    import flax.linen as nn
except ModuleNotFoundError:
    %pip install -qq flax
    import flax.linen as nn

In [19]:
class MLP1D(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(10)(x))
        x = nn.relu(nn.Dense(10)(x))
        x = nn.relu(nn.Dense(10)(x))
        x = nn.Dense(1)(x)
        return x


def bnn_log_joint(params, X, y, model):
    logits = model.apply(params, X).ravel()
    
    flatten_params, _ = ravel_pytree(params)
    print(flatten_params.shape)
    log_prior = distrax.Normal(0.0, 1.0).log_prob(flatten_params).sum()
    log_likelihood = distrax.Bernoulli(logits=logits).log_prob(y).sum()

    log_joint = log_prior + log_likelihood
    return log_joint


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [20]:
key = jax.random.PRNGKey(314)
key_samples, key_init, key_warmup, key = jax.random.split(key, 4)

In [21]:
noise = 0.2
num_samples = 50
num_warmup = 1000
num_steps = 500

model = MLP1D()
batch = jnp.ones((num_samples, 2))
params = model.init(key_init, batch)
print(params)
flatten_params, _ = ravel_pytree(params)
print(flatten_params)
print(flatten_params.shape)
# X, y = make_moons(n_samples=num_samples, noise=noise, random_state=314)
# potential = partial(bnn_log_joint, X=X, y=y, model=model)
# potential(params)

{'params': {'Dense_0': {'kernel': Array([[ 0.31588256, -0.7311128 ,  0.14564472, -0.04507473, -0.3482501 ,
        -0.04134907, -0.53790003,  0.13042098, -0.8531028 ,  0.3607035 ],
       [ 1.192346  , -0.03632545, -0.91390526,  0.85345274, -0.04192005,
         0.90277904, -0.52848047, -0.78681064,  0.32862702, -0.67841786]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_1': {'kernel': Array([[ 0.4883353 , -0.44211686,  0.18643142,  0.18440264, -0.36997104,
         0.10482579,  0.1068711 , -0.29041114,  0.21867244,  0.23569597],
       [-0.00518206,  0.3340689 ,  0.32882866,  0.37924156, -0.34622124,
        -0.06718794,  0.20170245, -0.3759444 ,  0.19442414,  0.52131015],
       [-0.36445212,  0.32422343,  0.05612694, -0.00069386, -0.4933609 ,
        -0.02358037, -0.30396432, -0.02799126, -0.4133881 ,  0.2287416 ],
       [-0.26134524, -0.53119165, -0.08851612,  0.35735786,  0.10675976,
        -0.20044345, -0.22024532, -0.5258

In [None]:
# adapt = blackjax.window_adaptation(blackjax.nuts, potential, num_warmup)
# final_state, kernel, _ = adapt.run(key_warmup, params)
# states = inference_loop(key_samples, kernel, final_state, num_samples)

# sampled_params = states.position

In [None]:
step = 0.2
vmin, vmax = X.min() - step, X.max() + step
X_grid = jnp.mgrid[vmin:vmax:100j, vmin:vmax:100j]

In [None]:
vapply = jax.vmap(model.apply, in_axes=(0, None), out_axes=0)
vapply = jax.vmap(vapply, in_axes=(None, 1), out_axes=1)
vapply = jax.vmap(vapply, in_axes=(None, 2), out_axes=2)

logits_grid = vapply(sampled_params, X_grid)[..., -1]
p_grid = jax.nn.sigmoid(logits_grid)

In [None]:
fig, ax = plt.subplots(figsize=(5, 4))
colors = ["tab:red" if yn == 1 else "tab:blue" for yn in y]
plt.scatter(*X.T, c=colors, zorder=1)
plt.contourf(*X_grid, p_grid.mean(axis=0), zorder=0, cmap="twilight")
plt.axis("off")
plt.title("Posterior mean")

In [None]:
fig, ax = plt.subplots(figsize=(5, 4))
colors = ["tab:red" if yn == 1 else "tab:blue" for yn in y]
plt.scatter(*X.T, c=colors, zorder=1)
plt.contourf(*X_grid, p_grid.std(axis=0), zorder=0, cmap="viridis")
plt.axis("off")
plt.title("Posterior std")
plt.colorbar()