# Bayes-by-backprop

In [1]:
import jax
import distrax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
from typing import Callable
from flax.training.train_state import TrainState

In [2]:
%config InlineBackend.figure_format = "retina"

In [5]:
import flax.linen as nn
from functools import partial

In [10]:
import numpy as np
from dynamax.utils import datasets
from jax.flatten_util import ravel_pytree

In [4]:
np.random.seed(314)
train, test = datasets.load_rotated_mnist()

X_train, y_train = train

In [6]:
@partial(jax.jit, static_argnums=(1,2))
def get_batch_train_ixs(key, num_samples, batch_size):
    """
    Obtain the training indices to be used in an epoch of
    mini-batch optimisation.
    """
    steps_per_epoch = num_samples // batch_size
    
    batch_ixs = jax.random.permutation(key, num_samples)
    batch_ixs = batch_ixs[:steps_per_epoch * batch_size]
    batch_ixs = batch_ixs.reshape(steps_per_epoch, batch_size)
    
    return batch_ixs

In [7]:
class VWeights(nn.Module):
    dim_out: int = 1
    normal_init: Callable = nn.initializers.normal()
    
    def setup(self):
        self.mean = self.param("mean", normal_init, (self.dim_out,))
        self.rho = self.param("rho", normal_init, (self.dim_out,))
    
    def __call__(self, key):
        eps = jax.random.normal(key)
        w = self.mean + self.rho * eps
        
        return w
        
    
    def log_prob(self, w):
        """
        Variational log-probability
        
        TODO: Try logvar trick and compare
              to rho-estimate
        """
        mean, rho = self.mean, self.rho
        sigma = jnp.log(1 + jnp.exp(rho))
        
        lprob = distrax.Normal(loc=mean, scale=sigma).log_prob(w)
        return lprob
    
    def sample_and_eval(self, key):
        """
        Sample weights and evaluate its
        log-probability
        """
        w = self.__call__(key)
        lprob = self.log_prob(w)
        
        return w, lprob


def bbb_lossfn(key, params, apply_fn, X_batch):
    w, variational_logprob = ...

In [8]:
class MLP(nn.Module):
    dim_out: int
    dim_hidden: int = 100
    activation: Callable = nn.relu
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.dim_hidden)(x)
        x = self.activation(x)
        x = nn.Dense(self.dim_hidden)(x)
        x = self.activation(x)
        x = nn.Dense(self.dim_out)(x)
        
        return x

key = jax.random.PRNGKey(314)
key_init, key_train = jax.random.split(key)

batch = (100, 28 ** 2)
model = MLP(1)
batch = jnp.ones(batch)

### Initialise params

In [11]:
key_mean, key_logvar = jax.random.split(key_init)
params_mean = model.init(key_mean, batch)

num_params, reconstruct_fn = ravel_pytree(params_mean)
num_params = len(num_params)

params_logvar = jax.random.normal(key_logvar, (num_params,))
params_logvar = reconstruct_fn(params_logvar)

### Define sampling functions

In [85]:
from jaxtyping import Float, PyTree, Array
from chex import dataclass

@dataclass
class BBBParams:
    mean: PyTree[Float]
    logvar: PyTree[Float]

# Bayes-by-backprop params
b3p = BBBParams(
    mean=params_mean,
    logvar=params_logvar,
)

In [68]:
def transform(eps, mean, logvar):
    std = jnp.exp(logvar / 2)
    weight = mean + std * eps
    return weight


@partial(jax.jit, static_argnames=("num_params", "reconstruct_fn"))
def sample_params(key, state:BBBParams, num_params, reconstruct_fn:Callable):
    eps = jax.random.normal(key, (num_params,))
    eps = reconstruct_fn(eps)
    params = jax.tree_map(transform, eps, state.mean, state.logvar)
    return params


@partial(jax.jit, static_argnames=("num_params", "reconstruct_fn"))
def sample_and_eval(key, state, X, num_params, reconstruct_fn):
    params_sample = sample_params(key, state, num_params, reconstruct_fn)
    return model.apply(params_sample, X).sum()

@jax.jit
def get_leaves(params):
    flat_params, _ = ravel_pytree(params)
    return flat_params

In [128]:
def cost_fn(
    key: jax.random.PRNGKey,
    state: BBBParams,
    X: Float[Array, "num_obs dim_obs"],
    y: Float[Array, "num_obs"],
    reconstruct_fn: Callable,
    scale_obs=1.0,
    scale_prior=1.0,
):
    """
    TODO:
    Add more general way to compute observation-model log-probability
    """
    
    # Sampled params
    params = sample_params(key, state, num_params, reconstruct_fn)
    params_flat = get_leaves(params)
    
    # Prior log probability (use initialised vals for mean?)
    logp_prior = distrax.Normal(loc=0.0, scale=scale_prior).log_prob(params_flat).sum()
    # Observation log-probability
    mu_obs = model.apply(params, X).ravel()
    logp_obs = distrax.Normal(loc=mu_obs, scale=scale_obs).log_prob(y).sum()
    # Variational log-probability
    logp_variational = jax.tree_map(
        lambda mean, logvar, x: distrax.Normal(loc=mean, scale=jnp.exp(logvar / 2)).log_prob(x),
        b3p.mean, b3p.logvar, params_sample
    )
    logp_variational = get_leaves(logp_variational).sum()
    
    return logp_variational - logp_prior - logp_obs


def lossfn(key, state, X, y, num_samples=10):
    # TODO: add costfn as input
    keys = jax.random.split(key, num_samples)
    cost_vmap = jax.vmap(cost_fn, in_axes=(0, None, None, None, None))
    loss = cost_vmap(keys, state, X, y, reconstruct_fn).mean()
    return loss

In [138]:
lossfn(key, b3p, X_train[:100], y_train[:100])

DeviceArray(60080128., dtype=float32)

In [139]:
tx = optax.adam(1e-3)
opt_state = tx.init(b3p)

In [140]:
opt_state = TrainState.create(
    apply_fn=model.apply,
    params=b3p,
    tx=tx
)

grads = jax.grad(lossfn, 1)(key, b3p, X_train[:100], y_train[:100])
opt_state = opt_state.apply_gradients(grads=grads)

In [141]:
lossfn(key, opt_state.params, X_train[:100], y_train[:100])

DeviceArray(56161204., dtype=float32)

## References

* [1] Laurent Valentin Jospin, Wray Buntine, Farid Boussaid, Hamid Laga, Mohammed Bennamoun: “Hands-on Bayesian Neural Networks -- a Tutorial for Deep Learning Users”, 2020, IEEE Computational Intelligence Magazine ( Volume: 17, Issue: 2, May 2022); [arXiv:2007.06823](http://arxiv.org/abs/2007.06823).