In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
# Download (or read from the local Keras cache) the Fashion-MNIST splits.
(train_imgs, train_labels), (test_imgs, test_labels) = (
    keras.datasets.fashion_mnist.load_data()
)

# Normalisation: divide every pixel value by 255. Both splits get exactly
# the same scaling so the model sees identically distributed inputs.
train_imgs = train_imgs / 255
test_imgs = test_imgs / 255

Normalisation (scaling the pixel values from the 0–255 range down to 0–1) improves accuracy dramatically.

In [None]:
# Small CNN: two (3x3 convolution -> 2x2 max-pool) stages, followed by a
# dense head with dropout and a 10-unit softmax output.
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),
    layers.MaxPooling2D(pool_size=2, strides=2),
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu),
    layers.MaxPooling2D(pool_size=2, strides=2),
    layers.Flatten(),
    layers.Dense(units=64, activation=tf.nn.relu),
    layers.Dropout(rate=.2),
    layers.Dense(units=10, activation=tf.nn.softmax),
])

# The sparse variant of categorical cross-entropy is used, so the labels
# are passed as integers rather than one-hot vectors. Requesting the
# 'accuracy' metric makes Keras report it under that key in the epoch logs.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)
model.summary()

In [4]:
class AccCallback(keras.callbacks.Callback):
    """Stop training as soon as the training accuracy exceeds 95 %."""

    def on_epoch_end(self, epoch, logs=None):
        """Check the epoch's accuracy and request an early stop if it is high enough.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Mapping of metric names to values for the epoch. May be
                ``None`` or lack the ``'accuracy'`` key, in which case the
                callback does nothing.
        """
        # `logs=None` avoids a shared mutable default; a missing metric must
        # not crash training with a `None > float` TypeError.
        accuracy = (logs or {}).get('accuracy')
        if accuracy is not None and accuracy > 0.95:
            print('\nReached 95% accuracy so cancelling training!')
            self.model.stop_training = True

In [None]:
# Train for up to 5 epochs; AccCallback may request an earlier stop.
model.fit(
    x=train_imgs,
    y=train_labels,
    epochs=5,
    callbacks=[AccCallback()],
)

# With a single compiled metric, evaluate() yields (loss, accuracy).
test_loss, test_acc = model.evaluate(x=test_imgs, y=test_labels)
print(f'Test accuracy: {test_acc * 100:.4}%, Test loss: {test_loss:.4}')

Inspect convolutional and pooling layers

In [6]:
# How many leading entries of `model.layers` to expose for inspection.
num_l = 4
inspected_layers = model.layers[:num_l]
layer_names = [inspected.name for inspected in inspected_layers]

# Auxiliary model sharing the trained model's inputs; it returns the
# activations of each inspected layer as a list, in layer order.
features_model = keras.Model(
    inputs=model.inputs,
    outputs=[inspected.output for inspected in inspected_layers],
)

In [None]:
convolution_number = 6   # channel of each layer's output to display
img_indexes = 0, 23, 28  # test-set images to inspect, one per subplot row

# One row per image and one column per inspected layer. The row count is
# derived from `img_indexes` so the grid always matches the data, and
# `squeeze=False` keeps `axs` two-dimensional even for a single row.
# The figure is bound to `fig` rather than `_`, which IPython uses for the
# last cell output.
fig, axs = plt.subplots(len(img_indexes), num_l, layout='constrained',
                        squeeze=False)

# Run all selected images through the network in a single predict() call
# instead of one call per image; the result is a list with one activation
# array per inspected layer, each indexed by position in `img_indexes`.
batch = test_imgs[list(img_indexes)].reshape(-1, 28, 28, 1)
features = features_model.predict(batch)

for row in range(len(img_indexes)):
    for col in range(num_l):
        ax = axs[row, col]
        ax.imshow(features[col][row, :, :, convolution_number], cmap='inferno')
        ax.grid(False)
        ax.set_title(layer_names[col])