In [9]:
# Start TensorBoard.
%reload_ext tensorboard

import tensorflow as tf
from tensorboard.plugins.hparams import api as hp
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from pathlib import Path
import itertools
import sklearn
import sklearn.metrics
import io
from tensorflow import keras

In [10]:
# Download the data. The data is already divided into train and test.
# The labels are integers representing classes.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = \
    fashion_mnist.load_data()

# Names of the integer classes, i.e., 0 -> T-short/top, 1 -> Trouser, etc.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


In [11]:
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)


In [12]:
def plot_confusion_matrix(cm, class_names):
  """
  Returns a matplotlib figure containing the plotted confusion matrix.

  Args:
    cm (array, shape = [n, n]): a confusion matrix of integer classes
    class_names (array, shape = [n]): String names of the integer classes
  """
  figure = plt.figure(figsize=(8, 8))
  plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
  plt.title("Confusion matrix")
  plt.colorbar()
  tick_marks = np.arange(len(class_names))
  plt.xticks(tick_marks, class_names, rotation=45)
  plt.yticks(tick_marks, class_names)

  # Compute the labels from the normalized confusion matrix.
  labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)

  # Use white text if squares are dark; otherwise black.
  threshold = cm.max() / 2.
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    color = "white" if cm[i, j] > threshold else "black"
    plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  return figure


In [13]:
# Clear out prior logging data.
!rm -rf logs/

logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Define the basic TensorBoard callback.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')


In [14]:
def plot_to_image(figure):
  """Converts the matplotlib plot specified by 'figure' to a PNG image and
  returns it. The supplied figure is closed and inaccessible after this call."""
  # Save the plot to a PNG in memory.
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  # Closing the figure prevents it from being displayed directly inside
  # the notebook.
  plt.close(figure)
  buf.seek(0)
  # Convert PNG buffer to TF image
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  # Add the batch dimension
  image = tf.expand_dims(image, 0)
  return image

In [15]:
def log_confusion_matrix(epoch, logs):
  print("-----")
  # Use the model to predict the values from the validation dataset.
  test_pred_raw = model.predict(test_images)
  print(len(test_pred_raw))

  test_pred = np.argmax(test_pred_raw, axis=1)
  print(len(test_pred))
  # Calculate the confusion matrix.
  cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
  # Log the confusion matrix as an image summary.
  figure = plot_confusion_matrix(cm, class_names=class_names)
  cm_image = plot_to_image(figure)

  # Log the confusion matrix as an image summary.
  with file_writer_cm.as_default():
    tf.summary.image("Confusion Matrix", cm_image, step=epoch)

# Define the per-epoch callback.
cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)


In [16]:
# # Start TensorBoard.
%reload_ext tensorboard
%tensorboard --logdir logs --bind_all --purge_orphaned_data True

# Train the classifier.
model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, # Suppress chatty output
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels),
)


Reusing TensorBoard on port 6006 (pid 1670), started 0:01:57 ago. (Use '!kill 1670' to kill it.)

-----
[[5.5399297e-36 0.0000000e+00 0.0000000e+00 ... 2.9944065e-01
  5.9337261e-24 5.7420468e-01]
 [6.6975720e-02 3.4300022e-02 1.6531739e-01 ... 4.1683599e-02
  1.6462050e-01 5.8530502e-02]
 [2.3791726e-14 9.9999964e-01 0.0000000e+00 ... 0.0000000e+00
  8.9320259e-23 7.8401113e-38]
 ...
 [6.6975720e-02 3.4300022e-02 1.6531739e-01 ... 4.1683599e-02
  1.6462050e-01 5.8530502e-02]
 [3.2261640e-30 9.9978501e-01 0.0000000e+00 ... 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [2.3902217e-03 4.6109073e-03 9.9074142e-03 ... 4.5424235e-01
  6.0284059e-03 5.0826404e-02]]
[9 2 1 ... 2 1 7]
-----
[[0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 5.7795830e-02
  8.6946001e-25 8.5066849e-01]
 [5.1347181e-02 1.7271357e-02 2.0602329e-01 ... 2.4220850e-02
  1.9772097e-01 1.8474031e-02]
 [8.2250323e-10 1.0000000e+00 0.0000000e+00 ... 0.0000000e+00
  7.4648212e-28 0.0000000e+00]
 ...
 [5.1347181e-02 1.7271357e-02 2.0602329e-01 ... 2.4220850e-02
  1.9772097e-01 1.8474031e-02]
 [1.1239955e-08 1.0000000e+00

<keras.callbacks.History at 0x7fc724125ba8>