# Import dependencies

In [None]:
# Standard library
import warnings
from datetime import datetime

# For graph
import matplotlib.pyplot as plt
import seaborn as sns

# For datascience
import numpy as np

from keras.callbacks import EarlyStopping, TerminateOnNaN, ProgbarLogger, TensorBoard
from keras.datasets import mnist
from keras.layers import Flatten, Dense, Dropout
from keras.models import Sequential
from keras.optimizers import Adam
from keras.utils import set_random_seed, to_categorical

# Suppress all Python warnings to keep the notebook output clean.
# NOTE(review): this also hides library deprecation warnings emitted through the
# `warnings` module — consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")

Warnings are muted above. Next, fix the random seeds for reproducibility.

In [None]:
# Set random seed for reproducibility.
# `set_random_seed` seeds Python's `random`, NumPy and the Keras backend, so
# weight initialization and dropout are reproducible too; `np.random.seed`
# alone only covers NumPy (e.g. the digit sampling for the preview grid).
SEED = 42
set_random_seed(SEED)
np.random.seed(SEED)  # explicit, although set_random_seed already seeds NumPy

# Getting data, observations
## Get dataset

In [None]:
# Download (or read from the local Keras cache) the MNIST train/test splits
train_split, test_split = mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

Show original images

In [None]:
# Sample 25 MNIST digits (with replacement) from the training set
indexes = np.random.randint(0, train_images.shape[0], size=25)
images = train_images[indexes]
labels = train_labels[indexes]

# Plot the 25 digits on a 5x5 grid, each titled with its label
fig, axes = plt.subplots(5, 5, figsize=(5, 5))

for ax, image, label in zip(axes.flat, images, labels):
    ax.imshow(image, cmap="gray")
    ax.set_title(str(label), fontsize=8)
    ax.axis("off")

fig.tight_layout()
# Save BEFORE showing: plt.show() may release the current figure, so calling
# savefig afterwards would write an empty canvas to disk.
fig.savefig("mnist-samples.png")
plt.show()
plt.close(fig)

## Preprocessing
Reshape data

In [None]:
# Make the (samples, 28, 28) layout expected by the Flatten input layer explicit;
# -1 lets NumPy infer the sample count from each array's size.
train_images_reshaped = train_images.reshape(-1, 28, 28)
test_images_reshaped = test_images.reshape(-1, 28, 28)

Normalize pixel values

In [None]:
# Normalize pixel values to be between 0 and 1: cast to float32 first so the
# division happens in floating point, then divide by the maximum intensity 255
# (assumes 0-255 pixel values).
train_images_resized, test_images_resized = (
    split.astype("float32") / 255
    for split in (train_images_reshaped, test_images_reshaped)
)

One-hot encode the labels

In [None]:
# One-hot encode labels into 10 digit classes (matches the 10-unit output layer).
# The ndim guard makes the cell idempotent: re-running it must not encode the
# already one-hot (2-D) labels a second time into 3-D arrays.
NUM_CLASSES = 10
if train_labels.ndim == 1:
    train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
if test_labels.ndim == 1:
    test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

# Create model
## Set model

In [None]:
# Number of neurons in the hidden fully connected layer
hidden_units = 128

# Simple MLP classifier: 28x28 image -> 10 class probabilities
model = Sequential(
    [
        Flatten(input_shape=(28, 28)),  # Flatten the 28x28 image into a 784-vector
        Dense(
            units=hidden_units, activation="relu"
        ),  # Hidden layer sized by `hidden_units` (was hard-coded to 128) with ReLU
        Dropout(0.2),  # Randomly drop 20% of activations during training to curb over-fitting
        Dense(
            units=10, activation="softmax"
        ),  # Output layer with 10 neurons (one for each class) and softmax activation
    ]
)

## Compile model

In [None]:
# Configure training: categorical cross-entropy matches the one-hot encoded
# labels, Adam runs with a 0.001 learning rate, and plain accuracy is reported.
model.compile(
    loss="categorical_crossentropy",
    optimizer=Adam(learning_rate=0.001),
    metrics=["accuracy"],
)

## Set callbacks

In [None]:
# Timestamped log directory so every run gets its own TensorBoard sub-folder
logs = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

# TensorBoard: histograms every epoch and a performance profile of batches 500-520
tboard_callback = TensorBoard(
    log_dir=logs,
    histogram_freq=1,
    profile_batch="500,520",
)

# Define early stopping callback
early_stopping = EarlyStopping(
    monitor="val_loss",  # Monitor validation loss
    patience=3,  # Number of epochs with no improvement after which training will be stopped
    restore_best_weights=True,  # Restore weights from the epoch with the best validation loss
)

# Abort training as soon as a NaN loss appears
terminate_on_nan = TerminateOnNaN()

# Progress bar with metrics; the compiled metric is logged as "accuracy", not "acc"
progbar_logger = ProgbarLogger(count_mode="samples", stateful_metrics=["accuracy"])

## Train model

In [None]:
# Fit model.
# NOTE: the test split doubles as validation data, so early stopping selects the
# epoch on the same data that is later reported as "test" — the final test score
# is therefore optimistically biased. A `validation_split` would avoid this.
history = model.fit(
    x=train_images_resized,
    y=train_labels,
    epochs=6,
    validation_data=(
        test_images_resized,
        test_labels,
    ),
    callbacks=[
        tboard_callback,  # writes the logs that the TensorBoard cell below reads
        early_stopping,
        terminate_on_nan,
        progbar_logger,
    ],
)

# Result visualization
## Get tensorboard

In [None]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

# Launch TensorBoard on the parent `logs` directory, so every timestamped run
# written by the TensorBoard callback is listed, then open the Profile tab to
# view the performance profile.
%tensorboard --logdir=logs --port=6011

# If the output says 'Reusing TensorBoard on port XXXX', an earlier instance is
# being reused instead of a fresh one — change --port above to start a new one.

The trends in the Scalars tab are quite representative.

# Get graph
## Plot the loss trend across epochs

In [None]:
# Plot per-epoch training vs. validation loss. Validation loss rising while
# training loss keeps falling is the usual sign of over-fitting.
fig, ax = plt.subplots()

# Plot train loss
sns.lineplot(
    x=range(1, len(history.history["loss"]) + 1),
    y=history.history["loss"],
    label="Train",
    ax=ax,
)

# Plot validation loss (the validation data passed to fit is the test split)
sns.lineplot(
    x=range(1, len(history.history["val_loss"]) + 1),
    y=history.history["val_loss"],
    label="Test",
    ax=ax,
)

ax.set(title="Model loss", xlabel="Epoch", ylabel="Loss")
ax.legend(loc="upper right")
# Render the figure now so it appears above the evaluation output
plt.show()

# Evaluate the model on the test split
test_loss, test_acc = model.evaluate(test_images_resized, test_labels)
print(f"Test loss: {test_loss:.4f}. Test accuracy: {test_acc:.2%}.")

It looks like the model starts to overfit by the 6th epoch: validation loss stops improving while training loss keeps decreasing.