# Bayesian CNN

__Objective:__ experiment with a Bayesian CNN to classify images from the CIFAR-10 dataset.

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import seaborn as sns

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

# Apply seaborn's default theme to the figures drawn below.
sns.set_theme()

## Load data

In [None]:
# Load the CIFAR-10 train and test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [None]:
# Display the shapes of the four arrays (train/test images and labels).
x_train.shape, y_train.shape, x_test.shape, y_test.shape

In [None]:
# Human-readable class names keyed by integer label: the position of each
# name in the list below is the label it is mapped from.
class_labels = dict(enumerate([
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]))

In [None]:
# Draw six randomly chosen training images side by side, each titled with
# the name of its class.
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(14, 6))

# Shuffle all training indices and keep the first six.
random_indices = tf.random.shuffle(tf.range(0, x_train.shape[0]))[:6]

for ax, sample_index in zip(axs, random_indices):
    ax.imshow(x_train[sample_index, ...])
    ax.set_title(f'Class: {class_labels[y_train[sample_index][0]]}')

Remove class 7 (horses) from the training dataset, so that the model can later be tested on a class it has never seen.

In [None]:
# Convert both image arrays to float32 tensors.
x_train, x_test = (
    tf.constant(images, dtype=tf.float32) for images in (x_train, x_test)
)

In [None]:
# Boolean mask selecting the training samples whose label is not 7 (horse);
# the same mask is applied to images and labels so they stay aligned.
keep_mask = y_train[:, 0] != 7

x_train_missing_class = x_train[keep_mask]
y_train_missing_class = y_train[keep_mask]

In [None]:
y_train_missing_class_one_hot = tf.one_hot(
    y_train_missing_class[:, 0],
    depth=9
)

y_train_missing_class_one_hot

## Model building

Build a Bayesian CNN with `Convolution2DFlipout` and `DenseFlipout` layers. By default, only the kernel weights are treated in a Bayesian way: we could do the same for the bias terms, but that would further increase the number of parameters. This number is already doubled w.r.t. the non-Bayesian counterpart, since each of the original weights now has the $\mu$ and $\sigma$ parameters of its approximate posterior (variational distribution).

In [None]:
def kernel_divergence_fn(q, p, _):
    """
    Scaled KL divergence used as the kernel regularization term.

    Intended to be passed as `kernel_divergence_fn` to the Flipout layers.

    Args:
        q: Surrogate posterior distribution over the kernel.
        p: Prior distribution over the kernel.
        _: Unused third argument supplied by the layer.

    Returns:
        KL(q || p) divided by the number of training samples.
    """
    # The ELBO penalizes the divergence of the posterior from the prior,
    # i.e. KL(q || p), so `q` must be the first argument.
    # Dividing by the size of the dataset actually used for training spreads
    # the full KL term over the samples seen in one epoch.
    num_train_samples = float(x_train_missing_class.shape[0])

    return tfd.kl_divergence(q, p) / num_train_samples


class BayesianCNN(tf.keras.layers.Layer):
    """
    Keras `Layer` object implementing a Bayesian CNN. Structure:
      * Convolutional block (2 `Convolution2DFlipout` layers).
      * Maxpooling.
      * Convolutional block (2 `Convolution2DFlipout` layers).
      * Maxpooling.
      * Flattening.
      * Fully connected block (2 `DenseFlipout` layers and a
          final output one).

    Every Flipout layer is given the module-level `kernel_divergence_fn`
    as its kernel divergence function.
    """
    def __init__(self, num_classes=9):
        """
        Create the internal layers.

        Args:
            num_classes: Number of units of the softmax output layer.
                Defaults to 9.
        """
        super().__init__()

        # Initialize the internal layers. The two convolutional blocks only
        # differ in their number of filters.
        self.conv_block_1 = self._make_conv_block(filters=8)
        self.conv_block_2 = self._make_conv_block(filters=16)

        self.flatten = tf.keras.layers.Flatten()

        self.dense_block = [
            tfp.layers.DenseFlipout(
                units=100,
                activation='relu',
                kernel_divergence_fn=kernel_divergence_fn
            )
            for _ in range(2)
        ]

        self.output_layer = tfp.layers.DenseFlipout(
            units=num_classes,
            activation='softmax',
            kernel_divergence_fn=kernel_divergence_fn
        )

    @staticmethod
    def _make_conv_block(filters):
        """
        Build one convolutional block.

        Args:
            filters: Number of filters of both convolutional layers.

        Returns:
            List with two 3x3 `Convolution2DFlipout` layers ('same'
            padding, ReLU activation) followed by a 2x2 `MaxPooling2D`
            layer.
        """
        conv_layers = [
            tfp.layers.Convolution2DFlipout(
                filters,
                kernel_size=(3, 3),
                padding='same',
                activation='relu',
                kernel_divergence_fn=kernel_divergence_fn
            )
            for _ in range(2)
        ]

        return conv_layers + [tf.keras.layers.MaxPooling2D((2, 2))]

    def call(self, x):
        """
        Forward pass.

        Args:
            x: Input batch. The notebook feeds image tensors of shape
                (batch, 32, 32, 3).

        Returns:
            Output of the softmax output layer, with `num_classes`
            values per sample.
        """
        for conv_layer in self.conv_block_1:
            x = conv_layer(x)

        for conv_layer in self.conv_block_2:
            x = conv_layer(x)

        x = self.flatten(x)

        for dense_layer in self.dense_block:
            x = dense_layer(x)

        x = self.output_layer(x)

        return x

In [None]:
# Instantiate the Bayesian CNN layer.
bayesian_cnn_layer = BayesianCNN()

# Test: run a forward pass on the first 15 training images and display the
# shape of the output.
bayesian_cnn_layer(x_train[:15]).shape

In [None]:
# Test: sum the outputs over the last axis for 15 training images. The output
# layer uses a softmax activation, so each sum is expected to be 1.
tf.reduce_sum(bayesian_cnn_layer(x_train[:15]), axis=-1)

Build a Keras `Model` object.

In [None]:
# Symbolic input for 32x32 images with 3 channels.
inputs = tf.keras.Input(shape=(32, 32, 3,))

# Reuse the layer instantiated above as the whole body of the model.
outputs = bayesian_cnn_layer(inputs)

bayesian_cnn_model = tf.keras.Model(
    inputs=inputs,
    outputs=outputs
)

bayesian_cnn_model.summary()

Repeated predictions on the same input return different outputs, since the Flipout layers are stochastic.

In [None]:
# Call the model five times on the same (first) test image and print each
# output.
for _ in range(5):
    print(bayesian_cnn_model(x_test[:1, ...]))

## Model training

In [None]:
# Step size for the SGD optimizer.
learning_rate = 1e-3

optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

# Categorical cross-entropy matches the one-hot targets and the softmax
# output of the model; accuracy is tracked as an additional metric.
bayesian_cnn_model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy']
)

In [None]:
# Number of passes over the training data.
epochs = 1

# Train on the dataset without class 7, using the one-hot encoded labels.
# The returned object is kept in `history`.
history = bayesian_cnn_model.fit(
    x_train_missing_class,
    y_train_missing_class_one_hot,
    batch_size=128,
    epochs=epochs
)