In [58]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

import numpy as np
import matplotlib.pyplot as plt

In [59]:
# Download (or load from the local TFDS cache) Fashion-MNIST. With
# `as_supervised=True` each element is an (image, label) tuple, and the two
# requested splits are returned in the order listed.
train_data, test_data = tfds.load(
    'fashion_mnist',
    split=['train', 'test'],
    as_supervised=True,
)

def ohe_normalize(images, labels):
    """Convert a batch to model-ready tensors.

    Args:
        images: raw image tensor (presumably uint8 pixels in 0-255 -- TODO
            confirm); cast to float32 and divided by 255.
        labels: integer class indices.

    Returns:
        Tuple ``(images, labels)``: float32 scaled images and labels one-hot
        encoded with depth 10.
    """
    scaled_images = tf.cast(images, tf.float32) / 255.0
    one_hot_labels = tf.one_hot(labels, depth=10)
    return scaled_images, one_hot_labels

# Shuffle individual examples *before* batching. Shuffling after `.batch()`
# only permutes whole batches, so every epoch would reuse the same fixed
# 128-image groups. The raw examples are cached first so later epochs don't
# re-read them from disk; normalization runs on whole batches (vectorized).
SHUFFLE_BUFFER = 10_000
BATCH_SIZE = 128

train_data = (
    train_data
    .cache()
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .map(ohe_normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation data is never shuffled; order does not affect the metrics.
test_data = (
    test_data
    .cache()
    .batch(BATCH_SIZE)
    .map(ohe_normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [60]:
# Number of training examples; used to scale the KL term per example.
train_size = 60000


def divergence_fn(q, p, q_tensor):
  """Single-sample Monte Carlo estimate of KL(q || p), divided by `train_size`.

  TFP's reparameterization layers call this with the surrogate posterior `q`,
  the prior `p`, and a sample drawn from `q` (`q_tensor`). The estimate is
  log q(w) - log p(w) at that sample. Dividing by the dataset size puts the
  KL on the same per-example scale as the batch-averaged likelihood loss.
  """
  log_posterior = q.log_prob(q_tensor)
  log_prior = p.log_prob(q_tensor)
  return tf.reduce_mean(log_posterior - log_prior) / train_size

def _reparam_layer_kwargs():
  """Return the prior/posterior/divergence settings shared by all Bayesian layers.

  Kernel and bias both use a trainable mean-field normal posterior and a
  fixed standard-normal prior. The KL term is computed by `divergence_fn`.
  A fresh dict is built on every call, so each layer gets its own posterior
  factory functions, exactly as when the kwargs were written out inline.
  """
  tfpl = tfp.layers
  return dict(
      kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
      bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
      kernel_prior_fn=tfpl.default_multivariate_normal_fn,
      bias_prior_fn=tfpl.default_multivariate_normal_fn,
      kernel_divergence_fn=divergence_fn,
      bias_divergence_fn=divergence_fn,
  )


def createConvReparamLayer(filters, kernel_size, activation, padding='same'):
  """Build a Bayesian 2-D convolution (reparameterization estimator).

  Args:
    filters: number of output channels.
    kernel_size: convolution window size.
    activation: activation applied to the layer output.
    padding: Keras padding mode; defaults to 'same', preserving spatial size.

  Returns:
    A `tfp.layers.Convolution2DReparameterization` layer.
  """
  return tfp.layers.Convolution2DReparameterization(
      filters=filters,
      kernel_size=kernel_size,
      activation=activation,
      padding=padding,
      **_reparam_layer_kwargs())


def createDenseReparamLayer(unit_num):
  """Build a Bayesian dense layer that emits OneHotCategorical logits.

  Args:
    unit_num: number of classes. The layer width is the parameter count a
      `tfpl.OneHotCategorical(unit_num)` head expects.

  Returns:
    A `tfp.layers.DenseReparameterization` layer with no activation, so its
    outputs are raw logits.
  """
  tfpl = tfp.layers
  return tfpl.DenseReparameterization(
      units=tfpl.OneHotCategorical.params_size(unit_num),
      activation=None,
      **_reparam_layer_kwargs())

def createModel(num_classes=10, input_shape=(28, 28, 1)):
  """Build the Bayesian CNN classifier.

  Three Bayesian conv blocks (swish activations, max pooling) are followed
  by global max pooling and a Bayesian dense layer. Its logits parameterize
  a OneHotCategorical output distribution.

  Args:
    num_classes: number of target classes (default 10, Fashion-MNIST).
    input_shape: per-example input shape (default 28x28 grayscale).

  Returns:
    An uncompiled `tf.keras.Sequential` model whose output is a distribution.
    When the output is converted to a tensor (Keras metrics, `predict`), the
    most probable class (the mode), one-hot encoded, is used.
  """
  tfpl = tfp.layers
  tfd = tfp.distributions
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape),
      createConvReparamLayer(16, 3, 'swish'),
      tf.keras.layers.MaxPooling2D(2),
      createConvReparamLayer(32, 3, 'swish'),
      tf.keras.layers.MaxPooling2D(2),
      createConvReparamLayer(64, 3, 'swish'),
      tf.keras.layers.GlobalMaxPooling2D(),
      createDenseReparamLayer(num_classes),
      # The default convert_to_tensor_fn draws a random *sample*, which makes
      # Keras' 'accuracy' metric noisy and biased low. Use the mode
      # (arg-max class) instead. The NLL loss uses log_prob and is unaffected.
      tfpl.OneHotCategorical(num_classes,
                             convert_to_tensor_fn=tfd.Distribution.mode),
  ])
  return model

In [61]:
def negative_log_likelihood(y, y_hat):
  """Negative log-probability of the true labels `y` under the predicted distribution `y_hat`."""
  return -y_hat.log_prob(y)


# The KL terms from the Bayesian layers are added to this loss automatically
# through the layers' `losses` collection.
bcnn = createModel()
bcnn.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=negative_log_likelihood,
    metrics=['accuracy'],
)
bcnn.summary()

Model: "sequential_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_reparameterization_  (None, 28, 28, 16)        320       
 27 (Conv2DReparameterizati                                      
 on)                                                             
                                                                 
 max_pooling2d_17 (MaxPooli  (None, 14, 14, 16)        0         
 ng2D)                                                           
                                                                 
 conv2d_reparameterization_  (None, 14, 14, 32)        9280      
 28 (Conv2DReparameterizati                                      
 on)                                                             
                                                                 
 max_pooling2d_18 (MaxPooli  (None, 7, 7, 32)          0         
 ng2D)                                               

In [62]:
EPOCHS = 12
# Keep the History object so per-epoch loss/accuracy curves (train and
# validation) can be plotted afterwards instead of being discarded.
history = bcnn.fit(train_data, epochs=EPOCHS, validation_data=test_data)

Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<keras.src.callbacks.History at 0x793ed76a3490>