# Bayes-by-backprop

In [14]:
import jax
import distrax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
from typing import Callable
from flax.training.train_state import TrainState

In [2]:
%config InlineBackend.figure_format = "retina"

In [3]:
import numpy as np
from dynamax.utils import datasets

In [5]:
# Seed NumPy's global RNG *before* loading the data.
# NOTE(review): presumably the rotated-MNIST loader draws its randomness from
# NumPy's global RNG, which would make this seed control it — TODO confirm.
np.random.seed(314)
train, test = datasets.load_rotated_mnist()

# The training split unpacks into two parts, presumably (inputs, targets).
X_train, y_train = train

In [90]:
import flax.linen as nn
from functools import partial

In [91]:
@partial(jax.jit, static_argnums=(1, 2))
def get_batch_train_ixs(key, num_samples, batch_size):
    """
    Shuffle the indices ``0 .. num_samples - 1`` and group them into the
    mini-batches of one training epoch.

    Indices left over after the last complete batch are dropped, so the
    result has shape ``(num_samples // batch_size, batch_size)``.
    ``num_samples`` and ``batch_size`` are static under ``jit`` because they
    determine the output shape.
    """
    num_batches = num_samples // batch_size
    num_kept = num_batches * batch_size

    shuffled = jax.random.permutation(key, num_samples)
    return shuffled[:num_kept].reshape(num_batches, batch_size)

In [92]:
class VWeights(nn.Module):
    """
    Mean-field Gaussian variational distribution over ``dim_out`` weights.

    Each weight is ``w_i ~ N(mean_i, sigma_i ** 2)`` with
    ``sigma_i = softplus(rho_i)``, which keeps the standard deviation
    positive for any real ``rho``. Sampling (``__call__``) and scoring
    (``log_prob``) both use this same parameterisation.

    Attributes
    ----------
    dim_out: number of weights (length of the ``mean`` and ``rho`` params).
    normal_init: initialiser used for both ``mean`` and ``rho``.
    """
    dim_out: int = 1
    normal_init: Callable = nn.initializers.normal()

    def setup(self):
        # The initialiser is a module field, so it must be read through
        # ``self``; a bare ``normal_init`` is not in scope inside a method.
        self.mean = self.param("mean", self.normal_init, (self.dim_out,))
        self.rho = self.param("rho", self.normal_init, (self.dim_out,))

    def __call__(self, key):
        """
        Draw one reparameterised sample ``w = mean + softplus(rho) * eps``.

        ``eps`` has shape ``(dim_out,)``, so every weight receives its own
        independent standard-normal noise.
        """
        eps = jax.random.normal(key, (self.dim_out,))
        sigma = jax.nn.softplus(self.rho)
        w = self.mean + sigma * eps
        return w

    def log_prob(self, w):
        """
        Element-wise variational log-density of ``w``.

        Returns one log-probability per weight (not summed).

        TODO: Try logvar trick and compare
              to rho-estimate
        """
        # softplus(rho) == log(1 + exp(rho)), computed without overflowing
        # for large rho.
        sigma = jax.nn.softplus(self.rho)
        lprob = distrax.Normal(loc=self.mean, scale=sigma).log_prob(w)
        return lprob

    def sample_and_eval(self, key):
        """
        Sample weights and evaluate their variational log-probability.

        Returns the pair ``(w, log_prob(w))``.
        """
        w = self(key)
        lprob = self.log_prob(w)
        return w, lprob


def bbb_lossfn(key, params, apply_fn, X_batch):
    """
    Placeholder for the Bayes-by-backprop mini-batch loss.

    Raises
    ------
    NotImplementedError
        Always, until the loss is written. This makes the unfinished state
        explicit instead of failing with an unrelated error.
    """
    # TODO: w, variational_logprob = <sample weights and evaluate log q(w)>
    raise NotImplementedError("bbb_lossfn is not implemented yet")

In [110]:
class MLP(nn.Module):
    """
    Fully-connected network.

    Two hidden ``Dense(dim_hidden)`` layers are each followed by
    ``activation``. A final linear ``Dense(dim_out)`` layer follows with no
    activation.
    """
    dim_out: int
    dim_hidden: int = 100
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        num_hidden_layers = 2
        h = x
        # Layers are created in the same order as before, so Flax still
        # names them Dense_0, Dense_1 (hidden) and Dense_2 (output).
        for _ in range(num_hidden_layers):
            h = self.activation(nn.Dense(self.dim_hidden)(h))
        return nn.Dense(self.dim_out)(h)

# Master PRNG key, split into an initialisation sub-key and a second sub-key
# (presumably reserved for training).
key = jax.random.PRNGKey(314)
key_init, key_train = jax.random.split(key)

# Single-output MLP, plus a dummy all-ones input of 100 rows x (28 ** 2)
# features that is used below to initialise the parameters.
model = MLP(1)
batch = jnp.ones((100, 28 ** 2))

### Initialise params

In [111]:
# Imported here because this cell uses it; the later import cell would
# otherwise run too late under top-to-bottom execution.
from jax.flatten_util import ravel_pytree

# Mean parameters: a standard Flax initialisation of the MLP.
key_mean, key_logvar = jax.random.split(key_init)
params_mean = model.init(key_mean, batch)

# Flatten once to count the parameters and to obtain the inverse map
# (flat vector -> pytree with the structure of params_mean).
flat_params_mean, reconstruct_fn = ravel_pytree(params_mean)
num_params = flat_params_mean.size

# Log-variance parameters: same pytree structure as the means, filled with
# standard-normal draws.
# NOTE(review): logvar ~ N(0, 1) gives initial standard deviations around
# exp(0 / 2) = 1; a smaller (negative) initial logvar may be intended —
# confirm.
params_logvar = jax.random.normal(key_logvar, (num_params,))
params_logvar = reconstruct_fn(params_logvar)

### Define sampling functions

In [142]:
from jaxtyping import Float, PyTree
from chex import dataclass

@dataclass
class BBBParams:
    """
    Variational parameters for Bayes-by-backprop.

    The state holds one mean and one log-variance per network parameter.
    Each field is stored as a pytree.
    """
    mean: PyTree[Float]  # per-parameter means
    logvar: PyTree[Float]  # per-parameter log-variances, log(sigma ** 2)

bbbp = BBBParams(
    logvar=params_logvar,
    mean=params_mean,
)

In [165]:
from jax.flatten_util import ravel_pytree

def transform(eps, mean, logvar):
    """
    Reparameterisation map for a diagonal Gaussian.

    Standard-normal noise ``eps`` becomes ``mean + sigma * eps``, where
    ``sigma = exp(logvar / 2)`` (so ``sigma ** 2 == exp(logvar)``).
    """
    scale = jnp.exp(0.5 * logvar)
    return mean + scale * eps


@partial(jax.jit, static_argnames=("num_params", "reconstruct_fn"))
def sample_params(key, state: BBBParams, num_params, reconstruct_fn: Callable):
    """
    Draw one reparameterised sample of the network parameters.

    A single standard-normal vector of length ``num_params`` is drawn and
    unflattened with ``reconstruct_fn``. Each leaf is then mapped to
    ``mean + exp(logvar / 2) * eps`` via ``transform``.

    ``reconstruct_fn`` is presumably the unravel function returned by
    ``ravel_pytree`` on a pytree with the structure of ``state.mean`` — the
    tree_map below requires the three pytrees to match.
    ``num_params`` and ``reconstruct_fn`` are static under ``jit``.
    """
    eps = jax.random.normal(key, (num_params,))
    eps = reconstruct_fn(eps)
    # ``jax.tree_map`` is deprecated; ``jax.tree_util.tree_map`` is the
    # supported equivalent.
    return jax.tree_util.tree_map(transform, eps, state.mean, state.logvar)


def sample_and_eval(key, state, X, num_params, reconstruct_fn):
    """
    Sample one set of network parameters from ``state`` and return the mean
    of ``model``'s outputs on ``X``.

    The result is a scalar, which lets ``jax.value_and_grad`` differentiate
    it with respect to ``state``.
    """
    sampled = sample_params(key, state, num_params, reconstruct_fn)
    outputs = model.apply(sampled, X)
    return outputs.mean()

In [167]:
# sample_and_eval (as written) returns a scalar mean, so ravel() gives a
# length-1 array.
yhat = sample_and_eval(
    key, bbbp, X_train[:100], num_params, reconstruct_fn
).ravel()

In [168]:
# Differentiate with respect to argument 1, the BBBParams state. The value
# returned alongside the gradients is the scalar computed by sample_and_eval.
grad_eval = jax.value_and_grad(sample_and_eval, argnums=1)
yhat, grads = grad_eval(
    key, bbbp, X_train[:100], num_params, reconstruct_fn,
)

In [170]:
# Shape of every gradient leaf. ``jax.tree_map`` is deprecated;
# ``jax.tree_util.tree_map`` is the supported equivalent.
jax.tree_util.tree_map(jnp.shape, grads)

BBBParams(mean=FrozenDict({
    params: {
        Dense_0: {
            bias: (100,),
            kernel: (784, 100),
        },
        Dense_1: {
            bias: (100,),
            kernel: (100, 100),
        },
        Dense_2: {
            bias: (1,),
            kernel: (100, 1),
        },
    },
}), logvar=FrozenDict({
    params: {
        Dense_0: {
            bias: (100,),
            kernel: (784, 100),
        },
        Dense_1: {
            bias: (100,),
            kernel: (100, 100),
        },
        Dense_2: {
            bias: (1,),
            kernel: (100, 1),
        },
    },
}))

In [155]:
def lossfn(key):
    """
    Placeholder loss function; not implemented yet.

    Raises
    ------
    NotImplementedError
        Always. Accidental use then fails loudly instead of silently
        returning ``None``.
    """
    raise NotImplementedError("lossfn is not implemented yet")

In [None]:
def elbo_bern(key, params, apply_fn, X_batch):
    """
    Negative ELBO for a Bernoulli decoder and a Gaussian encoder.

    For each sample ``w`` returned by ``apply_fn``, the ELBO term is

        log p(x | w) + log p(w) - log q(w | x)

    The three distributions are:

    * prior ``p(w) = N(0, I)``;
    * Bernoulli likelihood ``p(x | w)`` with logits ``logit_mean_x``;
    * diagonal-Gaussian posterior ``q(w | x) = N(mean_w, exp(logvar_w))``.

    The terms are averaged over all remaining axes. This is a plain Monte
    Carlo average, not an importance-weighted (logsumexp) bound.

    Parameters
    ----------
    key : PRNG key forwarded to ``apply_fn``.
    params : parameters forwarded to ``apply_fn``.
    apply_fn : callable ``apply_fn(params, X_batch, key)`` returning
        ``(w, (mean_w, logvar_w), logit_mean_x)``.
    X_batch : observations scored under the Bernoulli likelihood.

    Returns
    -------
    Scalar loss equal to minus the mean ELBO estimate.
    """
    w, (mean_w, logvar_w), logit_mean_x = apply_fn(params, X_batch, key)
    # w is presumably (batch, num_samples, dim_latent); only the trailing
    # latent dimension is needed to build the prior.
    dim_latent = w.shape[-1]

    std_w = jnp.exp(logvar_w / 2)

    dist_prior = distrax.MultivariateNormalDiag(jnp.zeros(dim_latent),
                                                jnp.ones(dim_latent))
    dist_decoder = distrax.Bernoulli(logits=logit_mean_x)
    dist_posterior = distrax.Normal(mean_w, std_w)

    # log p(w): the multivariate prior already sums over the latent axis.
    log_prob_w_prior = dist_prior.log_prob(w)
    # log p(x | w), summed over the observed dimensions.
    # NOTE(review): if logit_mean_x carries a samples axis that X_batch lacks,
    # X_batch may need an extra axis (e.g. X_batch[:, None]) to broadcast
    # correctly — confirm shapes against apply_fn.
    log_prob_x = dist_decoder.log_prob(X_batch).sum(axis=-1)
    # log q(w | x): factorised Gaussian, summed over the latent axis.
    log_prob_w_post = dist_posterior.log_prob(w).sum(axis=-1)

    # ELBO = log p(x|w) + log p(w) - log q(w|x)
    elbo = log_prob_x + log_prob_w_prior - log_prob_w_post
    return -elbo.mean()


## References

* [1] Laurent Valentin Jospin, Wray Buntine, Farid Boussaid, Hamid Laga, Mohammed Bennamoun: “Hands-on Bayesian Neural Networks — a Tutorial for Deep Learning Users”, IEEE Computational Intelligence Magazine, vol. 17, no. 2, May 2022 (preprint 2020); [arXiv:2007.06823](https://arxiv.org/abs/2007.06823).