In [1]:
from __future__ import absolute_import, division, print_function

In [2]:
from datetime import datetime
import io
import itertools
from packaging import version
from six.moves import range

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics

In [3]:
%load_ext tensorboard

### Download Fashion MNIST

In [4]:
# Fashion MNIST is bundled with Keras; load_data() hands back one
# (images, labels) pair for the training split and one for the test split.
fashion_mnist = keras.datasets.fashion_mnist
train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Human-readable name for each integer label; the list index is the label
# (0 -> T-shirt/top, 1 -> Trouser, ..., 9 -> Ankle boot).
class_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]

### Visualizing a Single Image

In [5]:
# Inspect the first training example: the array shape of its image, and its
# integer label alongside the class name that label indexes in class_names.
print(f"Shape: {train_images[0].shape}")
print(f"Label: {train_labels[0]} -> {class_names[train_labels[0]]}")

Shape: (28, 28)
Label: 9 -> Ankle boot


In [6]:
# The Summary API is fed a 4-D array, so wrap the single 28x28 image with a
# leading axis (inferred via -1, i.e. size 1 here) and a trailing axis of 1.
img = train_images[0].reshape((-1, 28, 28, 1))

In [7]:
# Clear out any prior log data
!rm -rf logs

# Sets up a timestamped log directory
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")

# Creates a file writer for the log directory
file_writer = tf.summary.create_file_writer(logdir)

# Using the file writer, log the reshaped image
with file_writer.as_default():
    tf.summary.image("Training data", img, step=0)

In [8]:
# Launch TensorBoard inline, reading the event files under logs/train_data.
%tensorboard --logdir logs/train_data

### Visualizing Multiple Images

In [9]:
with file_writer.as_default():
    images = np.reshape(train_images[0:25], (-1, 28, 28, 1))
    tf.summary.image("24 training data examples", images, max_outputs=25, step=0)
    
%tensorboard --logdir logs/train_data

Reusing TensorBoard on port 6006 (pid 10564), started 0:00:11 ago. (Use '!kill 10564' to kill it.)

### Logging Arbitrary Image Data

In [13]:
!rm -rf logs/plots

logdir = "logs/plots/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir)

def plot_to_image(figure):
    """
    Render a matplotlib figure to PNG and return it as a batched image tensor.

    The PNG is produced from ``figure`` itself rather than from whichever
    figure pyplot currently treats as active, so the result is correct even
    if other figures were created after ``figure``.

    Args:
        figure: The matplotlib figure to render. It is closed by this call
            (even if saving fails) and must not be used afterwards.

    Returns:
        The PNG decoded with 4 channels, with a leading axis of size 1
        added by ``tf.expand_dims``.
    """
    buf = io.BytesIO()
    try:
        # Save the supplied figure explicitly; plt.savefig would save the
        # *current* pyplot figure, which is not guaranteed to be this one.
        figure.savefig(buf, format='png')
    finally:
        # Always release the figure so repeated calls do not accumulate
        # open figures in memory.
        plt.close(figure)

    # getvalue() returns the whole buffer regardless of the stream position.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    return tf.expand_dims(image, 0)

def image_grid():
    """
    Draw the first 25 training images as a 5x5 grid of subplots.

    Each subplot is titled with the class name of its image's label and has
    ticks and grid lines turned off.

    Returns:
        The matplotlib figure holding the grid.
    """
    grid_rows, grid_cols = 5, 5
    figure = plt.figure(figsize=(10, 10))

    for index in range(grid_rows * grid_cols):
        label_name = class_names[train_labels[index]]
        # Subplot positions are 1-based, hence index + 1.
        plt.subplot(grid_rows, grid_cols, index + 1, title=label_name)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[index], cmap=plt.cm.binary)

    return figure

figure = image_grid()

with file_writer.as_default():
    tf.summary.image("Training Data", plot_to_image(figure), step=0)
    
%tensorboard --logdir logs/plots

### Building an Image Classifier

In [15]:
# A small fully connected classifier: flatten each 28x28 image, pass it
# through one 32-unit ReLU layer, and finish with a 10-way softmax.
classifier_layers = [
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
]
model = keras.models.Sequential(classifier_layers)

# Adam optimizer, sparse categorical cross-entropy loss, accuracy metric.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

In [16]:
def plot_confusion_matrix(cm, class_names):
    """
    Draw a confusion matrix as a colored grid annotated with per-row fractions.

    Cell colors come from the raw counts in ``cm``; the number printed in
    each cell is that count divided by its row total, rounded to two
    decimals. Text is white on cells darker than half the maximum count and
    black otherwise, so the contrast decision uses the same scale as the
    colors. Rows that sum to zero are labelled 0.0 instead of NaN.

    Args:
        cm: 2-D array-like of counts; rows are drawn on the "True Label"
            axis and columns on the "Predicted Label" axis.
        class_names: Tick labels, one per class, used on both axes.

    Returns:
        The matplotlib figure containing the plot.
    """
    counts = np.asarray(cm)

    figure = plt.figure(figsize=(8, 8))
    plt.imshow(counts, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Row-normalize for the labels only. Empty rows would divide by zero, so
    # divide them by 1 instead; their numerators are all zero anyway.
    row_totals = counts.sum(axis=1, keepdims=True)
    safe_totals = np.where(row_totals == 0, 1, row_totals)
    labels = np.around(counts.astype('float') / safe_totals, decimals=2)

    # Threshold on the raw counts, i.e. on the values imshow actually colored.
    threshold = counts.max() / 2.0

    for i, j in itertools.product(range(counts.shape[0]), range(counts.shape[1])):
        color = "white" if counts[i, j] > threshold else "black"
        # matplotlib's text() takes (x, y), so the column index comes first.
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Run last so the layout accounts for the axis labels as well.
    plt.tight_layout()

    return figure

In [17]:
!rm -rf logs/image

logdir = "logs/image/" + datetime.now().strftime("%Y%m%d-%H%M%S")

tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')

In [18]:
def log_confusion_matrix(epoch, logs):
    """
    Log a confusion-matrix image for the test set, using ``epoch`` as the step.

    The two-argument signature is what gets passed as ``on_epoch_end`` to
    ``keras.callbacks.LambdaCallback``; ``logs`` is accepted but not used.

    Args:
        epoch: Epoch index, recorded as the summary step.
        logs: Unused.
    """
    # Score every test image; the predicted label is the index of the
    # highest score in each row.
    class_scores = model.predict(test_images)
    predicted_labels = np.argmax(class_scores, axis=1)

    # Tally true vs. predicted labels, plot the tally, and turn the plot
    # into an image tensor.
    matrix = sklearn.metrics.confusion_matrix(test_labels, predicted_labels)
    matrix_figure = plot_confusion_matrix(matrix, class_names=class_names)
    matrix_image = plot_to_image(matrix_figure)

    # Write the image through the dedicated confusion-matrix writer.
    with file_writer_cm.as_default():
        tf.summary.image("Confusion Matrix", matrix_image, step=epoch)
        
# Wrap the logger in a Keras callback so it is invoked at the end of every epoch.
cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)

In [19]:
%tensorboard --logdir logs/image

model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0,
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels),
)

<tensorflow.python.keras.callbacks.History at 0x14c499550>