In [None]:
import datetime
import os

import tensorflow as tf
import tensorflow_datasets as tfds


In [None]:
# Load the MNIST train/test splits as (image, label) tuples (as_supervised=True)
# together with the DatasetInfo metadata (with_info=True); later cells read the
# train split size and the image shape from `info`.
(train_ds, test_ds), info = tfds.load(
    'mnist',
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
    split=['train', 'test'],
)


In [None]:
BATCH_SIZE = 128
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow so that the training-set shuffle
# built below and the model's weight initialisation are repeatable across
# Restart & Run All. Caveats: ops created before this point (presumably any
# file-level shuffle that tfds.load set up in the load cell; pass
# `read_config=tfds.ReadConfig(shuffle_seed=SEED)` there if that matters) are
# not covered, and some GPU kernels stay nondeterministic unless
# `tf.config.experimental.enable_op_determinism()` is also called.
tf.keras.utils.set_random_seed(SEED)


def normalize_img(image, label):
    """Cast `image` to `float32` and divide by 255; pass `label` through.

    For `uint8` pixels this maps [0, 255] onto [0, 1]. The `(image, label)`
    signature lets the function be used directly with `Dataset.map` on a
    supervised dataset.
    """
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label


# Training pipeline: normalise, cache the normalised examples in memory,
# shuffle over the whole training split, batch, and prefetch so input
# preparation overlaps with training.
#
# NOTE(review): both `train_ds` and `test_ds` are rebound in place. Running
# this cell a second time without first re-running the load cell would map
# and batch already-processed data, so re-run from the load cell instead.
train_ds = (
    train_ds
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation pipeline: no shuffle, so every epoch yields the same batches.
# That is why the cache sits after batch() here.
test_ds = (
    test_ds
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# Small CNN: two 3x3 ReLU convolutions (default 'valid' padding), flattened
# into a Dense layer producing 10 raw logits. No softmax layer is needed
# because the loss below is built with from_logits=True.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=info.features['image'].shape))  # type: ignore
model.add(tf.keras.layers.Conv2D(filters=8, kernel_size=3, activation='relu'))
model.add(tf.keras.layers.Conv2D(filters=4, kernel_size=3, activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(10))

# Labels are integer class ids (sparse), so use the Sparse* loss and metric.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Give every run its own timestamped log directory. With a fixed directory,
# re-running the notebook adds new event files next to the old ones and
# TensorBoard mixes the curves of different runs together.
# histogram_freq=1 computes weight histograms every epoch; write_graph=True
# logs the model graph.
run_log_dir = os.path.join(
    'logs', 'mnist', datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
)
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=run_log_dir,
    histogram_freq=1,
    write_graph=True,
)

# NOTE(review): the test split doubles as validation data. If the val_* numbers
# guide any tuning, they are no longer an unbiased held-out estimate.
# Keep the History object (per-epoch loss/metric values) for later inspection
# instead of leaving its bare repr as the cell output.
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=test_ds,
    callbacks=[tensorboard_callback],
)
