In [14]:
import jax.numpy as np
import jax.scipy as sp
import jax
from jax import jit, grad, vmap

# from tensorflow_probability.substrates import jax as tfp

import matplotlib.pyplot as plt
plt.style.use('dark_background')

### The Idea

In Bayesian statistics, we search for a posterior distribution of parameters:
$$p(\theta | D) = \frac{p(D | \theta)p(\theta)}{p(D)} \propto p(D | \theta)p(\theta)$$

However, if $\theta$ are the parameters of a function $f$, then what we are really looking for is the posterior distribution over **functions**:

$$p(f | D) = \frac{p(D | f)p(f)}{p(D)} \propto p(D | f)p(f)$$

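Writing the data as $D = (X, Y)$ and assuming the inputs $X$ are independent of $f$, so that $p(X | f) = p(X)$, the likelihood factorizes as:

$$p(D | f) = p(X, Y | f) = p(Y | X, f)p(X | f) = p(Y | X, f)p(X)$$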

$$p(D) = p(X, Y) = p(Y | X)p(X) $$

$$p(f | D) = \frac{p(D | f)p(f)}{p(D)} = \frac{p(Y | X, f)p(X)p(f)}{p(Y | X)p(X)} = \frac{p(Y | X, f)p(f)}{p(Y | X)}$$

We approximate this posterior with a surrogate distribution $q(f | D)$ by minimizing the KL divergence between the two:

$$\mathbb{D}_{KL}[q(f | D) \| p(f | D)] = -\mathbb{E}_{f \sim q}\text{log}\frac{p(f|D)}{q(f|D)} = -\mathbb{E}_{f \sim q}\text{log}\frac{p(Y | X, f)p(f)}{p(Y | X)q(f|D)} = -\mathbb{E}_{f \sim q}\text{log}p(Y | X, f) - \mathbb{E}_{f \sim q}\text{log}\frac{p(f)}{q(f|D)} + C$$

where $C = \text{log}p(Y | X)$ does not depend on $q$.

$$\mathbb{D}_{KL}[q(f | D) \| p(f | D)] = -\mathbb{E}_{f \sim q}\text{log}p(Y | X, f) + \mathbb{D}_{KL}[q(f|D) \| p(f)] + C $$

$$\mathbb{D}_{KL}[q(f | D) \| p(f | D)] = -\mathbb{E}_{f \sim q}\text{log}p(Y | X, f) + \mathbb{H}[q(f|D), p(f)] - \mathbb{H}[q(f|D)] + C $$
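
The last step splits the KL term between the surrogate and the prior into a cross-entropy minus an entropy. A minimal numerical sketch of that split, assuming one-dimensional Gaussians stand in for $q(f|D)$ and $p(f)$ (the values `mu_q`, `sigma_q`, `mu_p`, `sigma_p` are made up for illustration and do not come from the notebook):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical stand-ins: q plays the surrogate posterior, p the prior.
mu_q, sigma_q = 1.0, 0.5
mu_p, sigma_p = 0.0, 1.0

key = jax.random.PRNGKey(0)
# Draw samples f ~ q.
f = mu_q + sigma_q * jax.random.normal(key, (200_000,))

# Cross-entropy H[q, p] = -E_q log p(f): only needs samples and the prior density.
cross_entropy = -jnp.mean(norm.logpdf(f, mu_p, sigma_p))
# Entropy H[q] = -E_q log q(f): here we cheat and use the known density of q.
entropy = -jnp.mean(norm.logpdf(f, mu_q, sigma_q))

kl_mc = cross_entropy - entropy
# Closed-form KL between two univariate Gaussians, for comparison.
kl_exact = (jnp.log(sigma_p / sigma_q)
            + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5)

print(kl_mc, kl_exact)
```

The two printed values should agree up to Monte Carlo error. Note that the entropy line relies on knowing the density of $q$, which is precisely what is unavailable when $q$ is only given as a sampler.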

The cross-entropy $\mathbb{H}[q(f|D), p(f)] = -\mathbb{E}_{f \sim q}\text{log}p(f)$ is simple to estimate, since we can sample from $q$ and we choose the prior $p(f)$ ourselves. With a sampler we can also easily estimate the expected negative log-likelihood $-\mathbb{E}_{f \sim q}\text{log}p(Y | X, f)$, which plays the role of the usual cross-entropy "loss". The entropy $\mathbb{H}[q(f|D)]$ is harder: it requires the density of the surrogate posterior $q$, not just samples from it. The trick is to use the derivative of our sampler instead:

$$\text{log}\frac{d}{dp}F^{-1}(p) = \text{log}\frac{1}{F'(F^{-1}(p))} = \text{log}\frac{1}{q(f|D)} = -\text{log}q(f|D)$$
Since the inverse CDF $F^{-1}$ is a sampler (!!!), we have $F^{-1}(p) = f$ for $p \sim U(0, 1)$, and the derivative of the CDF is the density, so $F' = q(\cdot|D)$. Therefore, we can calculate the entropy of $q$ as follows:

$$\mathbb{H}[q(f|D)] = -\mathbb{E}_{f \sim q}\text{log}q(f | D)= -\mathbb{E}_{p \sim U}\text{log}q(F^{-1}(p) | D) = \mathbb{E}_{p \sim U}\text{log}\frac{d}{dp}F^{-1}(p)$$

where $F^{-1}: p \mapsto (\mathcal{X} \rightarrow \mathcal{Y})$ maps a random draw $p$ to a neural network $f = F^{-1}(p)$, $\mathcal{X}$ is the input space, $\mathcal{Y}$ is the output space, and $p \sim U$ is the source of randomness.
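
The entropy identity can be checked on a toy case. The sketch below assumes a scalar $f$ and a logistic surrogate with made-up location `mu` and scale `s`, chosen only because its inverse CDF is elementary; it is not the neural-network sampler described above. It differentiates the sampler with `jax.grad`, compares the result pointwise with $-\text{log}q(f)$, and then estimates the entropy from the sampler derivative alone.

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap

# Hypothetical 1-D surrogate: logistic distribution with location mu, scale s.
mu, s = 0.5, 2.0

def sampler(p):
    """Inverse CDF F^{-1}(p) of the logistic distribution."""
    return mu + s * (jnp.log(p) - jnp.log1p(-p))

def log_q(f):
    """Closed-form log-density of the same distribution (used only as a check)."""
    z = (f - mu) / s
    return -z - 2.0 * jnp.log1p(jnp.exp(-z)) - jnp.log(s)

def log_dsampler(p):
    """log d/dp F^{-1}(p), via automatic differentiation of the sampler."""
    return jnp.log(grad(sampler)(p))

key = jax.random.PRNGKey(0)
# Keep p strictly inside (0, 1) so the logs stay finite; this clipping
# introduces a negligible bias in the estimate below.
p = jax.random.uniform(key, (100_000,), minval=1e-6, maxval=1.0 - 1e-6)

# Pointwise identity: log d/dp F^{-1}(p) = -log q(F^{-1}(p)).
lhs = vmap(log_dsampler)(p)
rhs = -log_q(vmap(sampler)(p))
max_gap = jnp.max(jnp.abs(lhs - rhs))

# Monte Carlo entropy using only the sampler and its derivative.
entropy_mc = jnp.mean(lhs)
entropy_exact = jnp.log(s) + 2.0  # known entropy of a logistic distribution

print(max_gap, entropy_mc, entropy_exact)
```

For a sampler with vector-valued output, the scalar derivative would be replaced by the absolute determinant of the Jacobian (the standard change-of-variables formula), which is considerably more expensive to compute for a network-sized output.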