# Bayesian CNN

__Objective:__ experiment with a Bayesian CNN to classify images from the CIFAR-10 dataset.

In [None]:
# Execute on Colab.
# !pip install keras_cv

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_probability as tfp
import keras_cv
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns

# Short alias for the TFP distributions module.
tfd = tfp.distributions

# Apply the default seaborn theme to the matplotlib plots in this notebook.
sns.set_theme()

## Load data

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Create a validation set (40% of the original training data).
# - `random_state` pins the split: training in this notebook is resumed
#   across sessions through a `BackupAndRestore` callback, so an unseeded
#   split would move samples between the training and validation sets on
#   every restart.
# - `stratify` keeps the class proportions identical in both subsets.
x_train, x_val, y_train, y_val = train_test_split(
    x_train,
    y_train,
    test_size=0.4,
    random_state=42,
    stratify=y_train[:, 0]
)

print(
    'Training:', x_train.shape, y_train.shape,
    '\nValidation:', x_val.shape, y_val.shape,
    '\nTest:', x_test.shape, y_test.shape
)

In [None]:
# Human-readable CIFAR-10 class names, keyed by their integer label
# (the position in the list is the label).
class_labels = dict(enumerate([
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]))

In [None]:
# Display six randomly picked training images side by side, each one
# titled with its class name.
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(14, 6))

random_indices = tf.random.shuffle(tf.range(0, x_train.shape[0]))[:6]

for ax, idx in zip(axs, random_indices):
    ax.imshow(x_train[idx, ...], cmap='gray')
    ax.set_title(f'Class: {class_labels[y_train[idx][0]]}')

### Preprocessing

In [None]:
def preprocess_data(x, y, class_to_eliminate=None, num_classes=10):
    """
    Preprocess images and targets.

    Steps:
      * Images are converted to grayscale, cast to float32 and
        divided by 255.
      * Targets are flattened from (n_samples, 1) to (n_samples,).
      * If `class_to_eliminate` is given, all the samples with that
        label are dropped and the labels above it are shifted down
        by one, so that the remaining classes are indexed
        contiguously in [0, num_classes - 2].
      * Targets are one-hot encoded.

    Parameters
    ----------
    x :
        Batch of RGB images, shape (n_samples, height, width, 3).
    y :
        Integer labels, shape (n_samples, 1).
    class_to_eliminate : int or None
        Label of the class to remove. If None all classes are kept.
    num_classes : int
        Total number of classes in the original labelling.

    Returns
    -------
    x_preprocessed, y_preprocessed :
        The grayscale, rescaled images and the one-hot targets. The
        targets have `num_classes` columns when no class is removed
        and `num_classes - 1` columns otherwise.
    """
    # Turn images to grayscale.
    # Shape transformation: (batch_shape, 32, 32, 3) -> (batch_shape, 32, 32, 1)
    # (a single channel).
    x_preprocessed = keras_cv.layers.Grayscale(output_channels=1)(x)

    # Convert data to tensors and normalize pixel values.
    x_preprocessed = tf.constant(x_preprocessed, dtype=tf.float32) / 255.

    # Change target tensor shape: (batch_shape, 1) -> (batch_shape,).
    y_preprocessed = tf.cast(y[:, 0], tf.int32)

    depth = num_classes

    # Eliminate class if required.
    if class_to_eliminate is not None:
        keep_mask = tf.not_equal(y_preprocessed, class_to_eliminate)

        x_preprocessed = tf.boolean_mask(x_preprocessed, keep_mask)
        y_preprocessed = tf.boolean_mask(y_preprocessed, keep_mask)

        # Re-index the labels above the removed class. Without this the
        # largest label equals the reduced depth, and `tf.one_hot` encodes
        # out-of-range indices as all-zero rows (i.e. samples with no
        # target class at all).
        y_preprocessed = tf.where(
            y_preprocessed > class_to_eliminate,
            y_preprocessed - 1,
            y_preprocessed
        )

        depth = num_classes - 1

    # One-hot encode the targets.
    y_preprocessed = tf.one_hot(y_preprocessed, depth=depth)

    return x_preprocessed, y_preprocessed

In [None]:
# Class 7 ('horse') is removed from the training and validation sets only;
# the test set keeps all 10 classes.
# NOTE(review): presumably the held-out class is meant to serve as
# out-of-distribution data at test time - confirm. As a consequence `y_test`
# is one-hot encoded with depth 10 while the models below output 9 units.
x_train, y_train = preprocess_data(x_train, y_train, class_to_eliminate=7)
x_val, y_val = preprocess_data(x_val, y_val, class_to_eliminate=7)
x_test, y_test = preprocess_data(x_test, y_test)

# Sanity check on the resulting shapes.
print(
    'Training:', x_train.shape, y_train.shape,
    '\nValidation:', x_val.shape, y_val.shape,
    '\nTest:', x_test.shape, y_test.shape
)

## Model building

Build a Bayesian CNN with `Convolution2DFlipout` layers. By default, only the kernel weights are treated in a Bayesian way: we could force the same for the bias terms, but that would further increase the number of parameters, which is already doubled w.r.t. the non-Bayesian counterpart, since each of the original weights now has the $\mu$ and $\sigma$ parameters of its approximate posterior (variational distribution).

Observations:
- Adding batch normalization layers helps avoid exploding gradients. These are usually placed after the activation function following convolutional and dense layers.
- Once the minimum of the loss has been reached (given a value for the learning rate), increasing the batch size can help "squeeze" some information by computing more exact gradients.

### Test on a non-Bayesian CNN model

In [None]:
class ClassicCNN(tf.keras.layers.Layer):
    """
    Deterministic (non-Bayesian) CNN implemented as a Keras `Layer`
    subclass. Structure:
      * Convolutional block: `Convolution2D` (8 filters, 3x3 kernel,
          ReLU), batch normalization, 2x2 max pooling.
      * Convolutional block: same as above with 16 filters.
      * Flattening.
      * Fully connected block: two `Dense` layers (100 units, ReLU),
          each followed by batch normalization.
      * `Dense` output layer (9 units, softmax).
    """
    def __init__(self):
        """
        Instantiate all the internal layers.
        """
        super().__init__()

        self.conv_block_1 = self._make_conv_block(filters=8)
        self.conv_block_2 = self._make_conv_block(filters=16)

        self.flatten = tf.keras.layers.Flatten()

        self.dense_block = [
            *self._make_dense_unit(),
            *self._make_dense_unit()
        ]

        self.output_layer = tf.keras.layers.Dense(units=9, activation='softmax')

    @staticmethod
    def _make_conv_block(filters):
        """
        Return the layers of one convolutional block: convolution,
        batch normalization and max pooling.
        """
        return [
            tf.keras.layers.Convolution2D(
                filters=filters,
                kernel_size=(3, 3),
                padding='same',
                activation='relu'
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D(pool_size=(2, 2))
        ]

    @staticmethod
    def _make_dense_unit():
        """
        Return a 100-unit ReLU dense layer followed by batch
        normalization.
        """
        return [
            tf.keras.layers.Dense(units=100, activation='relu'),
            tf.keras.layers.BatchNormalization()
        ]

    def call(self, x):
        """
        Forward pass: apply all the internal layers in sequence.
        """
        pipeline = [
            *self.conv_block_1,
            *self.conv_block_2,
            self.flatten,
            *self.dense_block,
            self.output_layer
        ]

        for layer in pipeline:
            x = layer(x)

        return x


# Define callback to save/reload the model automatically every time
# training ends/starts.
# The backup directory is specific to this model: the Bayesian CNN trained
# later in the notebook has its own `BackupAndRestore` callback, and sharing
# one directory would let an interrupted run of one model be restored into
# the other.
backup_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir='./models/classic_cnn/'
)

# Input has shape (32, 32, 3) for RGB images and (32, 32, 1)
# for grayscale ones.
inputs = tf.keras.layers.Input(shape=(32, 32, 1,))
outputs = ClassicCNN()(inputs)

# Wrap the layer into a trainable Keras `Model`.
cnn_model = tf.keras.Model(
    inputs=inputs,
    outputs=outputs
)

cnn_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Targets are one-hot encoded and the network outputs softmax
# probabilities, hence the categorical cross-entropy loss.
cnn_model.compile(
    loss='categorical_crossentropy',
    optimizer=cnn_optimizer,
    metrics=['accuracy']
)

cnn_model.summary()

In [None]:
# Inspect the learning rate currently set on the model's optimizer.
K.get_value(cnn_model.optimizer.lr)

In [None]:
# Manually override the learning rate on the already compiled model
# (no re-compilation, so the optimizer object is kept).
K.set_value(cnn_model.optimizer.lr, 1e-6)

print('New learning rate:', K.get_value(cnn_model.optimizer.lr))

In [None]:
epochs = 1000

# Full-batch training: the batch size is read from the data instead of being
# hard-coded. The training-set size depends on the train/validation split, so
# a fixed number silently turns into either a no-op or a second, tiny and
# noisy batch whenever the split changes.
cnn_history = cnn_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=epochs,
    batch_size=x_train.shape[0],
    callbacks=[backup_callback]
)

In [None]:
# Loss history: training loss of the non-Bayesian CNN, one point per epoch
# of the last `fit` call.
fig = plt.figure(figsize=(14, 6))

sns.lineplot(
    x=range(len(cnn_history.history['loss'])),
    y=cnn_history.history['loss'],
    label='Total',
    color=sns.color_palette()[0]
)

plt.title('Training loss', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss value')
plt.legend()

### Test on a Bayesian CNN model

Observations:
- The loss values in the Bayesian case are around an order of magnitude bigger than in the non-Bayesian one. I think this is expected, as VI adds the KL terms to the loss (one for each variational distribution).
- Training is more difficult, in that the NN seems to plateau at a worse performance w.r.t. the non-Bayesian case with the equivalent architecture. My impression is that there's a lot of noise in the training process (see points below).
- In the Bayesian CNN case, the loss tends to "bounce back up" at some point. This could be caused by vanishing/exploding gradients.
- Batch normalization - while probably a good idea by analogy with the non-Bayesian case - doesn't fully solve the above problem.
- Increasing the batch size seems to have a bigger effect on the above problem. Possible explanation: the Monte Carlo estimate of the NLL part of the loss adds noise to the loss itself, which probably adds noise to the gradients as well. This adds up with the noise already introduced by minibatch gradient descent (are these noise sources of the same size?): increasing the batch size reduces at least one source of noise.

In [None]:
def kernel_divergence_fn(q, p, _):
    """
    KL divergence between `q` and `p`, divided by the number of
    training examples (first dimension of the global `x_train`).

    Note: KL divergence is NOT symmetric. The approximate posterior
          (variational distribution) is expected as the FIRST
          argument and the prior as the SECOND one. The third
          argument is ignored.
    """
    n_training_examples = float(x_train.shape[0])

    return tfd.kl_divergence(q, p) / n_training_examples


class BayesianCNN(tf.keras.layers.Layer):
    """
    Bayesian CNN implemented as a Keras `Layer` subclass. Structure:
      * Convolutional block: one `Convolution2DFlipout` layer
          (8 filters, 3x3 kernel), LeakyReLU, batch normalization,
          2x2 max pooling.
      * Convolutional block: same as above with 16 filters.
      * Flattening.
      * Fully connected block: two `DenseFlipout` layers (100 units),
          each followed by LeakyReLU and batch normalization.
      * `DenseFlipout` output layer (9 units, softmax).

    Every Flipout layer is given the module-level
    `kernel_divergence_fn` as its kernel divergence function.
    """
    def __init__(self):
        """
        Instantiate all the internal layers.
        """
        super().__init__()

        self.conv_block_1 = self._make_conv_block(8)
        self.conv_block_2 = self._make_conv_block(16)

        self.flatten = tf.keras.layers.Flatten()

        self.dense_block = [
            *self._make_dense_unit(),
            *self._make_dense_unit()
        ]

        self.output_layer = tfp.layers.DenseFlipout(
            units=9,
            activation='softmax',
            kernel_divergence_fn=kernel_divergence_fn
        )

    @staticmethod
    def _make_conv_block(filters):
        """
        Return the layers of one convolutional block: Flipout
        convolution (no built-in activation), LeakyReLU, batch
        normalization and max pooling.
        """
        return [
            tfp.layers.Convolution2DFlipout(
                filters,
                kernel_size=(3, 3),
                padding='same',
                kernel_divergence_fn=kernel_divergence_fn
            ),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D((2, 2))
        ]

    @staticmethod
    def _make_dense_unit():
        """
        Return a 100-unit Flipout dense layer (no built-in
        activation) followed by LeakyReLU and batch normalization.
        """
        return [
            tfp.layers.DenseFlipout(
                units=100,
                kernel_divergence_fn=kernel_divergence_fn
            ),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.BatchNormalization()
        ]

    def call(self, x):
        """
        Forward pass: apply all the internal layers in sequence.
        """
        pipeline = [
            *self.conv_block_1,
            *self.conv_block_2,
            self.flatten,
            *self.dense_block,
            self.output_layer
        ]

        for layer in pipeline:
            x = layer(x)

        return x

In [None]:
# Standalone instance of the layer, used only for the quick checks below.
bayesian_cnn_layer = BayesianCNN()

# Test: run a forward pass on the first 15 training images and inspect the
# output shape.
bayesian_cnn_layer(x_train[:15]).shape

In [None]:
# Test: the final softmax activation should normalize all output vectors
# to 1, i.e. summing over the last (class) axis should give ones.
tf.reduce_sum(bayesian_cnn_layer(x_train[:15, ...]), axis=-1)

Build a Keras `Model` object.

In [None]:
# Input has shape (32, 32, 3) for RGB images and (32, 32, 1)
# for grayscale ones.
inputs = tf.keras.Input(shape=(32, 32, 1,))

# A fresh `BayesianCNN` instance (not the `bayesian_cnn_layer` tested above).
outputs = BayesianCNN()(inputs)

# NOTE(review): this model is not compiled here; the training section below
# rebuilds `bayesian_cnn_model` from scratch before compiling it.
bayesian_cnn_model = tf.keras.Model(
    inputs=inputs,
    outputs=outputs
)

bayesian_cnn_model.summary()

Multiple predictions return different outputs.

In [None]:
# Call the model five times on the same (first) test image and print each
# output vector for comparison.
for _ in range(5):
    print(bayesian_cnn_model(x_test[:1, ...]))

Training.

In [None]:
# Define callback to save/reload the model automatically every time
# training ends/starts.
# The backup directory is specific to this model: the non-Bayesian CNN
# trained earlier in the notebook has its own `BackupAndRestore` callback,
# and sharing one directory would let an interrupted run of one model be
# restored into the other.
backup_callback = tf.keras.callbacks.BackupAndRestore(
    backup_dir='./models/bayesian_cnn/'
)

# Recreate the model (for retraining purposes). The input shape is read
# from the preprocessed training data (everything but the batch dimension).
inputs = tf.keras.Input(shape=x_train.shape[1:])

outputs = BayesianCNN()(inputs)

bayesian_cnn_model = tf.keras.Model(
    inputs=inputs,
    outputs=outputs
)

bayesian_cnn_model.summary()

In [None]:
# Initial learning rate for the optimizer.
learning_rate = 1e-3

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # Previously tried SGD.

# Same loss and metric as the non-Bayesian model: targets are one-hot
# encoded and the network outputs softmax probabilities.
bayesian_cnn_model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy']
)

In [None]:
# Manually override the learning rate on the already compiled model
# (no re-compilation, so the optimizer object is kept).
K.set_value(bayesian_cnn_model.optimizer.lr, 1e-5)

print('New learning rate:', K.get_value(bayesian_cnn_model.optimizer.lr))

In [None]:
epochs = 1000

# The batch size equals the number of training examples, i.e. full-batch
# training with a single gradient step per epoch.
history = bayesian_cnn_model.fit(
    x_train,
    y_train,
    validation_data=(x_val, y_val),
    epochs=epochs,
    batch_size=x_train.shape[0],
    callbacks=[backup_callback]
)

In [None]:
# Loss history: training loss of the Bayesian CNN, one point per epoch of
# the last `fit` call.
fig = plt.figure(figsize=(14, 6))

sns.lineplot(
    x=range(len(history.history['loss'])),
    y=history.history['loss'],
    label='Total',
    color=sns.color_palette()[0]
)

plt.title('Training loss', fontsize=14)
plt.xlabel('Epoch')
plt.ylabel('Loss value')
plt.legend()