In [None]:
%matplotlib widget
import time
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
class GraphLoss(keras.callbacks.Callback):
    """Keras callback that plots the training loss collected at each batch.

    Losses reported at the end of each training batch are buffered. At the end
    of every epoch, the figure owned by this callback is redrawn with that
    epoch's values and the buffer is cleared.

    Args:
        save_prefix: Optional filename prefix. When given, the figure is also
            saved after each epoch as ``f"{save_prefix}_{epoch}"`` (e.g. pass
            ``"plt_at_epoch"``). ``None`` (the default) disables saving.
    """

    def __init__(self, save_prefix=None):
        super().__init__()
        self.save_prefix = save_prefix
        self.per_batch_losses = []
        self.fig = None
        self.ax = None

    def on_train_begin(self, logs=None):
        self.per_batch_losses = []
        # Draw on a dedicated figure instead of plt.gcf(), which may be a
        # figure left over from an earlier cell.
        self.fig, self.ax = plt.subplots()

    def on_batch_end(self, batch, logs=None):
        # NOTE(review): in recent tf.keras versions logs["loss"] at batch end is
        # the running mean over the epoch so far, not this batch's loss alone —
        # confirm for the installed version before interpreting the curve.
        logs = logs or {}
        if "loss" in logs:
            self.per_batch_losses.append(logs["loss"])

    def on_epoch_end(self, epoch, logs=None):
        if self.ax is None:
            # on_train_begin was never called (callback driven manually).
            self.fig, self.ax = plt.subplots()
        self.ax.clear()
        self.ax.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                     label="Training loss for each batch")
        self.ax.set_xlabel(f"Batch (epoch {epoch})")
        self.ax.set_ylabel("Loss")
        self.ax.legend()
        # Ask the (possibly interactive/widget) canvas to refresh.
        self.fig.canvas.draw_idle()
        if self.save_prefix is not None:
            self.fig.savefig(f"{self.save_prefix}_{epoch}")
        # Start the next epoch with an empty buffer.
        self.per_batch_losses = []

In [None]:
def get_mnist_model(hidden_units=512, dropout_rate=0.5, num_classes=10,
                    input_dim=28 * 28):
    """Build an uncompiled one-hidden-layer MLP classifier.

    Architecture: Input(input_dim) -> Dense(hidden_units, relu)
    -> Dropout(dropout_rate) -> Dense(num_classes, softmax).
    The defaults reproduce the original MNIST model exactly.

    Args:
        hidden_units: Width of the single hidden Dense layer.
        dropout_rate: Dropout rate applied to the hidden activations.
        num_classes: Number of softmax outputs.
        input_dim: Length of each flattened input vector (28 * 28 for MNIST).

    Returns:
        A ``keras.Model`` with input shape ``(None, input_dim)`` and output shape
        ``(None, num_classes)``. The caller is responsible for ``compile``.
    """
    inputs = keras.Input(shape=(input_dim,))
    features = layers.Dense(hidden_units, activation="relu")(inputs)
    features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    return keras.Model(inputs, outputs)

SEED = 42
NUM_VAL = 10_000  # the first NUM_VAL training images are held out for validation
EPOCHS = 10

# Seed NumPy and TensorFlow so weight initialisation and dropout are repeatable
# under Restart & Run All (some GPU kernels may still be non-deterministic).
np.random.seed(SEED)
tf.random.set_seed(SEED)

(images, labels), (test_images, test_labels) = keras.datasets.mnist.load_data()
# Flatten each 28x28 image into a vector and scale pixels by 1/255.
# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000.
images = images.reshape((-1, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255
train_img, val_img = images[NUM_VAL:], images[:NUM_VAL]
train_lbl, val_lbl = labels[NUM_VAL:], labels[:NUM_VAL]

mdl = get_mnist_model()
mdl.compile(optimizer="rmsprop",
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"])
# Keep the History object (per-epoch loss/accuracy) instead of discarding it.
history = mdl.fit(train_img, train_lbl, epochs=EPOCHS,
                  validation_data=(val_img, val_lbl),
                  callbacks=[GraphLoss()])