# Confusion matrix
### Exercise
Implement a custom Keras callback that computes a confusion matrix and sends it to TensorBoard as an image.


In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import io

from sklearn.metrics import confusion_matrix
from toy_model import create_model

%load_ext autoreload
%autoreload 2

In [None]:
# Load the MNIST dataset as (train, test) pairs of (images, labels).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# The number of classes is the number of distinct labels in the test split.
distinct_labels = tf.unique(y_test).y
n_classes = distinct_labels.shape[0]

In [None]:
# Append a trailing axis to the image array (assumed to be the channel
# axis expected by the model -- TODO confirm against toy_model).
x_train_net = tf.expand_dims(x_train, axis=-1)
# One-hot encode the integer labels, one column per class.
y_train_net = tf.one_hot(y_train, depth=n_classes)
# Display both shapes as the cell output.
x_train_net.shape, y_train_net.shape

In [None]:
# Build the network with the project helper from toy_model, passing the
# number of classes (presumably used to size the output layer -- confirm
# in toy_model.create_model).
model = create_model(n_classes)

In [None]:
# Categorical cross-entropy over one-hot targets; from_logits=True means the
# loss applies the softmax itself (assumes the model emits raw logits --
# TODO confirm against toy_model).
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Adam with its default hyper-parameters.
optimizer = tf.keras.optimizers.Adam()
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=["accuracy"],
)

In [None]:
def matplotlib_to_png(figure):
    """Convert a matplotlib figure to a PNG-decoded image tensor.

    The figure is rendered to an in-memory PNG, closed, and the PNG bytes
    are decoded into a 4-channel image with a leading batch dimension.
    The supplied figure is closed and inaccessible after this call.

    Args:
        figure: The matplotlib figure to convert.

    Returns:
        A tensor of shape (1, height, width, 4) holding the decoded image.
    """
    # Save through the figure object itself: ``plt.savefig`` would render
    # whichever figure is currently active, which is not necessarily the
    # one passed in.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    # Convert PNG buffer to TF image (getvalue() returns the whole buffer
    # regardless of the current stream position).
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    return tf.expand_dims(image, 0)

In [None]:
class ConfusionMatrixVis(tf.keras.callbacks.Callback):
    """Keras callback that logs a confusion-matrix image after each epoch.

    At the end of every epoch the model predicts on the stored inputs, a
    row-normalized confusion matrix is computed against the stored targets,
    plotted with matplotlib and written to the given summary writer under
    the tag "Confusion Matrix".
    """

    def __init__(self, writer, x, y, class_names):
        """Store the evaluation data and the summary writer.

        Args:
            writer: Summary writer used as the default writer when the
                image summary is emitted.
            x: Inputs passed to ``model.predict``.
            y: Targets matching ``x``; the true class is taken as the
                argmax along axis 1 (i.e. one-hot or score vectors).
            class_names: Sized sequence of tick labels, one per class, in
                class-index order.
        """
        super().__init__()
        self.writer = writer
        self.inputs = x
        self.outputs = y
        self.class_names = class_names

    def on_epoch_end(self, epoch, logs=None):
        """Compute, plot and log the confusion matrix for this epoch.

        Args:
            epoch: Index of the epoch that just finished; used as the
                summary step.
            logs: Metric results supplied by Keras (unused).
        """
        y_pred = self.model.predict(self.inputs)
        # Convert to plain NumPy index arrays before handing them to
        # scikit-learn.
        pred_labels = tf.argmax(y_pred, 1).numpy()
        true_labels = tf.argmax(self.outputs, 1).numpy()

        # Pin the label set explicitly: otherwise a class that appears in
        # neither the true nor the predicted labels is dropped from the
        # matrix and the tick labels below no longer line up.
        tick_marks = np.arange(len(self.class_names))
        cm = confusion_matrix(true_labels, pred_labels,
                              labels=tick_marks, normalize='true')

        figure = plt.figure(figsize=(8, 8))
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title("Confusion matrix")
        plt.colorbar()
        plt.xticks(tick_marks, self.class_names, rotation=45)
        plt.yticks(tick_marks, self.class_names)
        # Rows of the matrix are true classes, columns are predictions.
        plt.ylabel("True label")
        plt.xlabel("Predicted label")

        image = matplotlib_to_png(figure)

        with self.writer.as_default():
            tf.summary.image("Confusion Matrix", image, step=epoch)

In [None]:
# Summary writer that receives the confusion-matrix images.
image_writer = tf.summary.create_file_writer('logs/validation')

# Evaluate the confusion matrix on every sample from index 100 onwards.
# The class labels are derived from n_classes rather than hard-coded so
# they stay consistent with the one-hot encoding and the model.
custom_cb = ConfusionMatrixVis(image_writer,
                               x=x_train_net[100:],
                               y=y_train_net[100:],
                               class_names=range(n_classes))

In [None]:
# Train on the first 100 samples only, holding out 20% of them for
# validation; the confusion-matrix callback runs at the end of each epoch.
model.fit(
    x=x_train_net[:100],
    y=y_train_net[:100],
    validation_split=0.2,
    callbacks=[custom_cb],
    epochs=5,
)