In [20]:
import os
import warnings
import matplotlib
matplotlib.use('Agg')
from matplotlib import figure
from matplotlib.backends import backend_agg
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
warnings.simplefilter(action='ignore')
import seaborn as sns
tfd = tfp.distributions
from utils import *

In [21]:
NUM_CLASSES = 10
LEARNING_RATE = 0.001
NUM_EPOCHS = 300
BATCH_SIZE =128

In [24]:
train_set, heldout_set = tf.keras.datasets.mnist.load_data()
labels = pd.Series(train_set[1])
print("Training set class count")
labels.value_counts()

Training set class count


1    6742
7    6265
3    6131
2    5958
9    5949
0    5923
6    5918
8    5851
4    5842
5    5421
dtype: int64

In [25]:
where_in = np.where(train_set[1] != 3)
where_in
train_in = train_set[0][where_in]
train_set_in = (train_set[0][where_in], train_set[1][where_in])
train_in.shape

(array([    0,     1,     2, ..., 59997, 59998, 59999], dtype=int64),)

In [31]:
def create_model():
  kl_divergence_function = (lambda q, p, _: tfd.kl_divergence(q, p) / tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

  model = tf.keras.models.Sequential([
      tfp.layers.Convolution2DFlipout(
          6, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          16, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          120, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.Flatten(),
      tfp.layers.DenseFlipout(
          84, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tfp.layers.DenseFlipout(
          NUM_CLASSES, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.softmax)
  ])

  model.compile(tf.keras.optimizers.Adam(lr= LEARNING_RATE),
                loss='categorical_crossentropy',
                metrics=['accuracy'],
                experimental_run_tf_function=False)
  return model


train_seq = DatasetSequence(data=train_set_in, batch_size=BATCH_SIZE)
heldout_seq = DatasetSequence(data=heldout_set, batch_size=BATCH_SIZE)
model = create_model()
model.build(input_shape=[None, 28, 28, 1])

In [34]:
print(' ... Training convolutional neural network')
for epoch in range(NUM_EPOCHS):
    print(epoch)
    epoch_accuracy, epoch_loss = [], []
    for step, (batch_x, batch_y) in enumerate(train_seq):
      batch_loss, batch_accuracy = model.train_on_batch(batch_x, batch_y)
model.save('mnist_bayes.tf')

 ... Training convolutional neural network
0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
INFO:tensorflow:Assets written to: mnist_bayes.tf\assets
