In [20]:
import os
import warnings
import matplotlib
matplotlib.use('Agg')
from matplotlib import figure
from matplotlib.backends import backend_agg
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
warnings.simplefilter(action='ignore')
import seaborn as sns
tfd = tfp.distributions

In [21]:
IMAGE_SHAPE = [28, 28, 1]
NUM_TRAIN_EXAMPLES = 60000
NUM_HELDOUT_EXAMPLES = 10000
NUM_CLASSES = 10
LEARNING_RATE = 0.001
NUM_EPOCHS = 300
BATCH_SIZE =128

In [22]:
def create_model():
  kl_divergence_function = (lambda q, p, _: tfd.kl_divergence(q, p) / tf.cast(NUM_TRAIN_EXAMPLES, dtype=tf.float32))

  model = tf.keras.models.Sequential([
      tfp.layers.Convolution2DFlipout(
          6, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          16, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.MaxPooling2D(
          pool_size=[2, 2], strides=[2, 2],
          padding='SAME'),
      tfp.layers.Convolution2DFlipout(
          120, kernel_size=5, padding='SAME',
          kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tf.keras.layers.Flatten(),
      tfp.layers.DenseFlipout(
          84, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.relu),
      tfp.layers.DenseFlipout(
          NUM_CLASSES, kernel_divergence_fn=kl_divergence_function,
          activation=tf.nn.softmax)
  ])

  model.compile(tf.keras.optimizers.Adam(lr= LEARNING_RATE),
                loss='categorical_crossentropy',
                metrics=['accuracy'],
                experimental_run_tf_function=False)
  return model


In [23]:
class DatasetSequence(tf.keras.utils.Sequence):

  def __init__(self, data, batch_size=128):
    images, labels = data
    self.images, self.labels = DatasetSequence.__preprocessing(images, labels)
    self.batch_size = batch_size

  @staticmethod
  def __preprocessing(images, labels):
    images = 2 * (images / 255.) - 1.
    images = images[..., tf.newaxis]
    labels = tf.keras.utils.to_categorical(labels)
    return images, labels

  def __len__(self):
    return int(tf.math.ceil(len(self.images) / self.batch_size))

  def __getitem__(self, idx):
    batch_x = self.images[idx * self.batch_size: (idx + 1) * self.batch_size]
    batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
    return batch_x, batch_y

In [24]:
train_set, heldout_set = tf.keras.datasets.mnist.load_data()
labels = pd.Series(train_set[1])
print("Training set class count")
labels.value_counts()

Training set class count


1    6742
7    6265
3    6131
2    5958
9    5949
0    5923
6    5918
8    5851
4    5842
5    5421
dtype: int64

In [25]:
where_out = np.where(train_set[1] == 3)
where_out
where_in = np.where(train_set[1] != 3)
where_in

(array([    0,     1,     2, ..., 59997, 59998, 59999], dtype=int64),)

In [26]:
train_out = train_set[0][where_out]
train_set_out = (train_set[0][where_out], train_set[1][where_out])
train_out.shape

(6131, 28, 28)

In [27]:
train_in = train_set[0][where_in]
train_set_in = (train_set[0][where_in], train_set[1][where_in])
train_in.shape

(53869, 28, 28)

In [28]:
train_set[0].shape

(60000, 28, 28)

In [29]:
heldout_y_srs = pd.Series(heldout_set[1])
heldout_y_srs.value_counts()

1    1135
2    1032
7    1028
3    1010
9    1009
4     982
0     980
8     974
6     958
5     892
dtype: int64

In [30]:
heldout_x = heldout_set[0]
heldout_y = heldout_set[1]
reorder_heldout = list()
for digit in range(10):
    where_digit = np.where(heldout_y == digit)
    reorder_heldout.append(where_digit[0][0])
reorder_heldout.extend(list(range(len(heldout_x))))
heldout_x = heldout_x[reorder_heldout]
heldout_y = heldout_y[reorder_heldout]
heldout_set = (heldout_x, heldout_y)

In [31]:
train_seq = DatasetSequence(data=train_set_in, batch_size=BATCH_SIZE)
heldout_seq = DatasetSequence(data=heldout_set, batch_size=BATCH_SIZE)

In [32]:
model = create_model()
model.build(input_shape=[None, 28, 28, 1])

In [34]:
print(' ... Training convolutional neural network')
for epoch in range(NUM_EPOCHS):
    print(epoch)
    epoch_accuracy, epoch_loss = [], []
    for step, (batch_x, batch_y) in enumerate(train_seq):
      batch_loss, batch_accuracy = model.train_on_batch(batch_x, batch_y)
      print(step)
      # epoch_accuracy.append(batch_accuracy)
      # epoch_loss.append(batch_loss)
      #
      # if step % 100 == 0:
      #   print('Epoch: {}, Batch index: {}, '
      #         'Loss: {:.3f}, Accuracy: {:.3f}'.format(
      #             epoch, step,
      #             tf.reduce_mean(epoch_loss),
      #             tf.reduce_mean(epoch_accuracy)))
      #
      # if (step+1) % viz_steps == 0:
      #   # Compute log prob of heldout set by averaging draws from the model:
      #   # p(heldout | train) = int_model p(heldout|model) p(model|train)
      #   #                   ~= 1/n * sum_{i=1}^n p(heldout | model_i)
      #   # where model_i is a draw from the posterior p(model|train).
      #   print(' ... Running monte carlo inference')
      #   probs = tf.stack([model.predict(heldout_seq, verbose=1)
      #                     for _ in range(num_monte_carlo)], axis=0)
      #   mean_probs = tf.reduce_mean(probs, axis=0)
      #   heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))
      #   print(' ... Held-out nats: {:.3f}'.format(heldout_log_prob))
      #
      #   if HAS_SEABORN:
      #     names = [layer.name for layer in model.layers
      #              if 'flipout' in layer.name]
      #     qm_vals = [layer.kernel_posterior.mean().numpy()
      #                for layer in model.layers
      #                if 'flipout' in layer.name]
      #     qs_vals = [layer.kernel_posterior.stddev().numpy()
      #                for layer in model.layers
      #                if 'flipout' in layer.name]
      #     plot_weight_posteriors(names, qm_vals, qs_vals,
      #                            fname=os.path.join(
      #                                model_dir,
      #                                'epoch{}_step{:05d}_weights.png'.format(
      #                                    epoch, step)))
      #     plot_heldout_prediction(heldout_seq.images, probs.numpy(),
      #                             fname=os.path.join(
      #                                 model_dir,
      #                                 'epoch{}_step{}_pred.png'.format(
      #                                     epoch, step)),
      #                             title='mean heldout logprob {:.2f}'
      #                             .format(heldout_log_prob))
model.save('mnist_bayes.tf')

 ... Training convolutional neural network
0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
INFO:tensorflow:Assets written to: mnist_bayes.tf\assets


In [35]:
from tensorflow import keras
model = keras.models.load_model('mnist_bayes.tf')
num_monte_carlo = 5

probs = tf.stack([model.predict(heldout_seq, verbose=1)
                  for _ in range(num_monte_carlo)], axis=0)
mean_probs = tf.reduce_mean(probs, axis=0)
heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))

model_dir = 'final_bayes'
# plot_heldout_prediction(heldout_seq.images, probs.numpy(),
#                                   fname=os.path.join(
#                                       model_dir,
#                                       'pred.png'.format(
#                                           epoch, step)),
#                                   title='mean heldout logprob {:.2f}'
#                                   .format(heldout_log_prob))



In [36]:
heldout_x = heldout_set[0]
heldout_y = heldout_set[1]
reorder_heldout = list()

where_digit = np.where(heldout_y == 3)[0].tolist()

heldout_x = heldout_x[where_digit]
heldout_y = heldout_y[where_digit]
heldout_set = (heldout_x, heldout_y)
heldout_seq = DatasetSequence(data=heldout_set, batch_size=BATCH_SIZE)

In [None]:
where_digit = np.where(heldout_y == 3)[0].tolist()

In [None]:
num_monte_carlo = 100

probs = tf.stack([model.predict(heldout_seq, verbose=1)
                  for _ in range(num_monte_carlo)], axis=0)
mean_probs = tf.reduce_mean(probs, axis=0)
heldout_log_prob = tf.reduce_mean(tf.math.log(mean_probs))
for i in range(5):
    plot_heldout_prediction(heldout_seq.images[10*i:10*i+10], probs.numpy()[:,10*i:10*i+10,:],
                                      fname=os.path.join(
                                          model_dir,
                                          f'pred_{i}.png'.format(
                                              epoch, step)),
                                      title='mean heldout logprob {:.2f}'
                                      .format(heldout_log_prob))

In [None]:
probs_np = probs.numpy()
probs_np = probs_np[:,10*i:10*i+10,:]

In [None]:
num_monte_carlo = 10
probs = tf.stack([model.predict(heldout_seq, verbose=1)
                  for _ in range(num_monte_carlo)], axis=0).numpy()

In [None]:
images = heldout_seq.images
labels = heldout_seq.labels

In [None]:
true_labels = list()
pred_labels = list()
for i in range(len(images)):
    true_label = np.argmax(labels[i])
    pred_means = np.mean(probs[:, i, :], axis=0)
    pred_label = np.argmax(pred_means)
    if pred_means[pred_label] > 0.95:
        true_labels.append(true_label)
        pred_labels.append(pred_label)

In [None]:
len(pred_labels)/len(labels)