
## Variational Inference: Background

**Goal**: Approximate the posterior distribution $p(w|x)$ of weights $w$ given data $x$ by a *tractable surrogate* $q(w)$.

**Solution**:
Sampling from the posterior (hard) is replaced by the (easier) task of minimizing the divergence
$$
\min_{q}\mathrm{KL}(q(w) \parallel p(w|x)),
$$
which, after some basic algebra using Bayes' rule $p(w|x) = p(x|w)\,p(w)/p(x)$, decomposes as
$$
\mathrm{KL}(q(w) \parallel p(w|x)) = -\mathbf{E}_{w\sim q}\log p(x|w) + \mathrm{KL}(q(w) \parallel p(w)) + \log p(x),
$$
so that, dropping $\log p(x)$ (which does not depend on $q$), we are left with the task of optimizing
$$
\min_q \left[-\mathbf{E}_{w\sim q}\log p(x|w) + \mathrm{KL}(q(w) \parallel p(w))\right],
$$
the sum of the *expected negative log-likelihood* and the *(approximate) posterior/prior divergence*. This objective is the negative of the ELBO (evidence lower bound).
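The decomposition can be checked numerically. The sketch below assumes a made-up discrete model in which $w$ takes three values and a single observation $x$ is fixed (all probabilities are invented for illustration). It computes the exact posterior by Bayes' rule, compares both sides of the identity, and confirms that the ELBO never exceeds $\log p(x)$.

```python
import numpy as np

# Made-up discrete model: w takes 3 values, x is one fixed observation
prior = np.array([0.5, 0.3, 0.2])   # p(w)
lik = np.array([0.10, 0.40, 0.70])  # p(x|w) for the observed x
q = np.array([0.2, 0.3, 0.5])       # surrogate q(w)

evidence = np.sum(prior * lik)      # p(x) = sum_w p(w) p(x|w)
posterior = prior * lik / evidence  # p(w|x) by Bayes' rule

def kl(a, b):
    """KL(a || b) for discrete distributions with full support."""
    return np.sum(a * np.log(a / b))

lhs = kl(q, posterior)                    # KL(q(w) || p(w|x))
expected_nll = -np.sum(q * np.log(lik))   # -E_{w~q} log p(x|w)
rhs = expected_nll + kl(q, prior) + np.log(evidence)

objective = expected_nll + kl(q, prior)   # the quantity minimized over q
elbo = -objective                         # ELBO = log p(x) - KL(q || posterior)
print(lhs, rhs, elbo, np.log(evidence))
```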

**Interpretation**: 

The first term rewards fitting the data, while the second term penalizes deviating from the prior.

**Implementation**: 

Theoretical surveys [2] usually work with the ELBO in its undecomposed form and don't discuss batch training. Practical implementations, such as TensorFlow Probability [1], take advantage of the decomposition and use batches:


*   the log-likelihood term is approximated by a stochastic forward pass: network weights are sampled and the loss is computed as usual
*   the KL terms, assuming a Gaussian surrogate, are computed analytically and passed through the layers' internal losses (as with regularizers)
*   for batch training, the first term is estimated on a minibatch, so the KL term must be scaled consistently with it
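For a Gaussian surrogate the KL term has a closed form. The sketch below assumes a mean-field $\mathcal{N}(\mu_q, \sigma_q^2)$ surrogate and a standard normal prior; the prior choice and the parameter values are illustrative assumptions, not read from the TensorFlow Probability source. It compares the analytic value with a Monte Carlo estimate that, like the log-likelihood term, uses sampled weights $w = \mu_q + \sigma_q \varepsilon$.

```python
import numpy as np

def kl_gauss(mu_q, sd_q, mu_p=0.0, sd_p=1.0):
    """Closed-form KL(N(mu_q, sd_q^2) || N(mu_p, sd_p^2)), elementwise."""
    return (np.log(sd_p / sd_q)
            + (sd_q**2 + (mu_q - mu_p)**2) / (2 * sd_p**2)
            - 0.5)

def log_normal(w, mu, sd):
    """Log-density of N(mu, sd^2) evaluated at w."""
    return -0.5 * np.log(2 * np.pi * sd**2) - (w - mu)**2 / (2 * sd**2)

rng = np.random.default_rng(0)
mu_q = np.array([0.3, -1.2, 0.8])  # surrogate means for 3 weights (made up)
sd_q = np.array([0.5, 0.2, 1.1])   # surrogate std devs (made up)

# Mean-field surrogate: the KL of the product distribution is a sum over weights
kl_exact = kl_gauss(mu_q, sd_q).sum()

# Monte Carlo estimate of E_q[log q(w) - log p(w)] with reparameterized samples
eps = rng.standard_normal((200_000, 3))
w = mu_q + sd_q * eps
kl_mc = (log_normal(w, mu_q, sd_q) - log_normal(w, 0.0, 1.0)).sum(axis=1).mean()
print(kl_exact, kl_mc)
```

The analytic form is exact and noise-free, which is why it is preferable to sampling whenever the surrogate and prior admit it.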


```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
```

```python
# dataset and general setup

from sklearn.datasets import load_wine

X, y = load_wine(return_X_y=True)  # 178 samples, 13 features, 3 classes
X = (X - X.mean(0)) / X.std(0)     # standardize each feature

N_SAMPLES = X.shape[0]
N_BATCH = 32
```

## Logistic Regression via ML

```python
### build model: logistic regression via Maximum Likelihood ###

tf.random.set_seed(1234)

ds = tf.data.Dataset.from_tensor_slices((X, y))
ds = ds.shuffle(N_SAMPLES).batch(N_BATCH).prefetch(1)

raw_inputs = tf.keras.Input(shape=(13,))
features = raw_inputs
logits = tf.keras.layers.Dense(3)(features)  # deterministic weights

model = tf.keras.Model(raw_inputs, logits)

def neg_loglike(y_true, y_pred):
  # per-example negative log-likelihood of the integer labels under softmax(logits)
  return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

model.compile(optimizer=tf.optimizers.Adam(0.01), loss=neg_loglike, metrics=['accuracy'])

model.fit(ds, epochs=10)
```

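The `neg_loglike` loss above is the categorical negative log-likelihood: for each example it equals $-\log \mathrm{softmax}(\text{logits})_y$, the negative log-probability the model assigns to the true class $y$. A NumPy sketch of that quantity follows; `sparse_ce_from_logits` is a hypothetical helper written for illustration, not Keras's implementation.

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Per-example NLL of integer labels under softmax(logits) (hypothetical helper)."""
    # numerically stable log-softmax: subtract the row maximum first
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # pick the log-probability of the true class in each row
    return -log_probs[np.arange(len(labels)), labels]

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
nll = sparse_ce_from_logits(labels, logits)
print(nll)
```

With all logits equal, every class gets probability $1/3$ and the loss is $\log 3$, the value of an uninformed classifier on the three wine classes.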

## Logistic Regression via VI Sampling



```python
### build model: logistic regression via VI ###

tf.random.set_seed(1234)

ds = tf.data.Dataset.from_tensor_slices((X, y))
ds = ds.shuffle(N_SAMPLES).repeat(10).batch(N_BATCH).prefetch(1)  # 10 epochs

raw_inputs = tf.keras.Input(shape=(13,))
features = raw_inputs
# Bayesian dense layer: weights are sampled on every forward pass,
# and the KL(q || prior) terms are added to model.losses
logits = tfp.layers.DenseReparameterization(units=3)(features)
model = tf.keras.Model(raw_inputs, logits)

def vi_loss(N_SAMPLES):

  def loss(y_true, y_pred):
    # y_true is the class label tensor, shape = (N_BATCH,)
    # y_pred is the logits tensor, shape = (N_BATCH, N_CLASSES)
    neg_loglike = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)  # shape: (N_BATCH,)
    KL = sum(model.losses)  # scalar, from the most recent forward pass
    # full-data objective divided by N_SAMPLES: mean batch NLL + KL / N_SAMPLES
    return tf.reduce_mean(neg_loglike) + KL / N_SAMPLES

  return loss

optimizer = tf.optimizers.Adam(0.01)
loss_fn = vi_loss(N_SAMPLES)

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    # the forward pass runs inside the tape, so both the sampled logits
    # and the KL terms in model.losses are differentiated
    loss_value = loss_fn(y, model(x, training=True))
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  return loss_value

for x_b, y_b in ds:
  train_step(x_b, y_b)

# every call samples new weights, so the accuracy varies between passes
accs = [(model(X).numpy().argmax(-1) == y).mean() for _ in range(200)]
accs = np.array(accs)
print('Accuracy=%s \u00B1 %s' % (accs.mean(), accs.std()))
```



Accuracy=0.9778370786516852 ± 0.007909136734394449
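The ± spread reported above arises because every call `model(X)` draws a fresh weight sample. The NumPy sketch below mimics a reparameterized dense forward pass under stated assumptions: a mean-field Gaussian kernel posterior with a point-estimate bias, random stand-in parameters (`kernel_loc` and `kernel_scale` are hypothetical names), and synthetic labels taken from the mean kernel. Its numbers therefore say nothing about the wine result.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior parameters for a 13 -> 3 mean-field Gaussian kernel
n_in, n_out = 13, 3
kernel_loc = rng.normal(size=(n_in, n_out))  # posterior means
kernel_scale = np.full((n_in, n_out), 0.1)   # posterior std devs
bias = np.zeros(n_out)                       # point-estimate bias (assumed)

def stochastic_forward(x):
    """One forward pass: sample W = loc + scale * eps, then compute the logits."""
    eps = rng.standard_normal(kernel_loc.shape)
    kernel = kernel_loc + kernel_scale * eps
    return x @ kernel + bias

x = rng.normal(size=(178, n_in))            # synthetic inputs, same shape as the wine data
y_ref = (x @ kernel_loc + bias).argmax(-1)  # synthetic labels from the mean kernel

# 200 passes, each with a fresh weight sample, as in the accuracy loop above
accs = np.array([(stochastic_forward(x).argmax(-1) == y_ref).mean()
                 for _ in range(200)])
print('Accuracy=%s \u00B1 %s' % (accs.mean(), accs.std()))
```

Points whose logits are nearly tied flip between classes from one sample to the next, which is what produces a nonzero standard deviation.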


## Deeper Network and Elegant Loss Handling

Let's train a two-layer Bayesian network (with 128 and 10 neurons) on MNIST; it should reach roughly 96% accuracy on the test data.



1.   We pass only the negative log-likelihood (the first term) to `model.compile` as the loss.
2.   The KL terms are collected in `model.losses` and handled like any other regularizer in [`model.train_step`](https://github.com/keras-team/keras/blob/2c48a3b38b6b6139be2da501982fd2f61d7d48fe/keras/engine/training.py#L780).
3.   The KL term refers to the full dataset (as in the equation), [so we scale the average batch loss by the number of samples](https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DenseReparameterization).
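The scaling in point 3 is consistent because, for a uniformly sampled batch, $N$ times the average batch negative log-likelihood is an unbiased estimate of the full-data sum $-\sum_i \log p(x_i|w)$; adding the unscaled KL then estimates the full objective. The sketch below assumes synthetic per-example NLL values and a made-up KL value, and samples batches with replacement for simplicity (Keras shuffles without replacement within an epoch). It also checks that this per-step objective is exactly $N$ times the `reduce_mean(neg_loglike) + KL / N_SAMPLES` loss of the custom loop, so both have the same minimizer.

```python
import numpy as np

rng = np.random.default_rng(0)

N, batch = 60_000, 64                     # dataset and batch size as in the MNIST cell
nll = rng.exponential(scale=0.3, size=N)  # synthetic per-example NLL values (stand-ins)
kl = 1234.5                               # made-up stand-in for sum(model.losses)

# Full-data objective: -sum_i log p(x_i|w) + KL(q(w) || p(w))
full_objective = nll.sum() + kl

def batch_objective(idx):
    """Per-step value: N * mean batch NLL (compiled loss) + unscaled KL (regularizer)."""
    return N * nll[idx].mean() + kl

# Average many minibatch estimates (uniform sampling with replacement, a simplification)
estimates = np.array([batch_objective(rng.integers(0, N, size=batch))
                      for _ in range(5_000)])
rel_err = abs(estimates.mean() - full_objective) / full_objective
print(rel_err)
```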


```python
### build model: 2-layer Bayesian network via Variational Inference ###

(X, y), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# standardize each pixel with training-set statistics, applied to both sets;
# pixels that are constant in the training set get unit scale (no division by zero)
mu, sd = X.mean(0), X.std(0)
sd[sd == 0] = 1.0
X = (X - mu) / sd
X_test = (X_test - mu) / sd

N_SAMPLES = X.shape[0]
N_BATCH = 64
N_CLASS = 10

model = tf.keras.Sequential([
    tf.keras.Input(shape=X.shape[1:]),
    tf.keras.layers.Flatten(),
    tfp.layers.DenseReparameterization(128, activation='relu'),
    tfp.layers.DenseReparameterization(N_CLASS, activation=None)]
)

def vi_loss(N_SAMPLES):
  ''' note: this handles only the log-likelihood term; the KL term is handled via model regularizers '''

  def loss(y_true, y_pred):
    neg_loglike = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    neg_loglike = tf.reduce_mean(neg_loglike)
    # scale the batch average up to the full dataset, matching the unscaled KL in model.losses
    return N_SAMPLES * neg_loglike

  return loss


model.compile(optimizer=tf.keras.optimizers.Adam(0.005),
              loss=vi_loss(N_SAMPLES),
)
model.fit(X, y, batch_size=N_BATCH, epochs=10, shuffle=True)

# test accuracy of a single stochastic forward pass
print((model(X_test).numpy().argmax(-1) == y_test).mean())
```


## Literature

[1] TensorFlow Probability, `tfp.layers.DenseReparameterization` API documentation, https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DenseReparameterization

[2] Advances in Variational Inference, https://arxiv.org/pdf/1711.05597.pdf
