In [None]:
import marimo as mo
import tensorflow as tf
from tensorflow.keras.models import Sequential # type: ignore
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout # type: ignore
from tensorflow.keras.callbacks import EarlyStopping # type: ignore
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

import numpy as np
import matplotlib.pyplot as plt
# Record library versions so results can be tied to a specific environment.
for _label, _module in (("tf", tf), ("np", np), ("mo", mo)):
    print(f"{_label}.__version__ = {_module.__version__!r}")

# MNIST Digit Recognizer

In [None]:
# Load the MNIST dataset from Keras.
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Hold out 12,000 training images as a validation set.
# `stratify` keeps the per-digit class proportions the same in both splits, so
# validation metrics are not skewed by an accidental class imbalance.
x_train, x_val, y_train, y_val = train_test_split(
    x_train_full, y_train_full,
    test_size=12000,
    random_state=42,  # for reproducibility
    stratify=y_train_full,
)

def process_data(x):
    """Rescale pixel values into [0, 1] as float32 and append a channel axis.

    Normalising the inputs helps the network train more efficiently, and the
    CNN expects (batch, height, width, channels) input. MNIST images are
    grayscale, so a single trailing channel of size 1 is added.

    Args:
        x: image array of shape (batch, height, width); assumed to hold raw
            0-255 pixel intensities.

    Returns:
        float32 array of shape (batch, height, width, 1).
    """
    scaled = x.astype(np.float32) / 255.0
    return scaled[..., np.newaxis]

# Apply identical preprocessing to every split so train/val/test stay consistent.
x_train, x_val, x_test = (process_data(_split) for _split in (x_train, x_val, x_test))

for _name, _images in (("x_train", x_train), ("x_val", x_val), ("x_test", x_test)):
    print(f"Shape of {_name}: {_images.shape}")

## Show random images

In [None]:
# Interactive control for how many random training images the preview grid shows.
num_images_slider = mo.ui.slider(
    start=1,
    stop=20,
    step=1,
    value=10,
    label="#### Number of images to display",
)

In [None]:
# --- Plotting function with dynamic grid ---
def plot_images(num_images, x_train, y_train, rng=None):
    """Plot a grid of randomly chosen images, each titled with its label.

    The grid has at most 5 columns and as many rows as needed.

    Args:
        num_images: number of distinct images to show (>= 1 and no more than
            ``x_train.shape[0]``).
        x_train: image array indexed along axis 0; each image may be (H, W) or
            (H, W, 1).
        y_train: labels aligned with ``x_train`` along axis 0.
        rng: optional ``numpy.random.Generator`` for a reproducible selection.
            When None, the global NumPy RNG is used (the original behaviour).

    Returns:
        The matplotlib ``Figure`` holding the grid.

    Raises:
        ValueError: if ``num_images`` is less than 1 (or, via NumPy, larger than
            the number of available images).
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Select distinct random indices to display.
    picks = (np.random if rng is None else rng).choice(
        x_train.shape[0], size=num_images, replace=False
    )

    # Determine the grid size dynamically.
    cols = min(num_images, 5)
    rows = int(np.ceil(num_images / cols))

    # squeeze=False always returns a 2-D array of Axes (even for a 1x1 grid),
    # so no special case is needed before flattening.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
    axes = axes.ravel()

    for ax, idx in zip(axes, picks):
        image = x_train[idx]
        # imshow expects (H, W) for a colormapped grayscale image; drop a
        # trailing singleton channel axis if present.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        ax.imshow(image, cmap='gray')
        ax.set_title(f"Label: {y_train[idx]}")

    # Hide axis frames on every cell, including unused trailing ones.
    for ax in axes:
        ax.axis('off')

    fig.tight_layout()

    # Detach the figure from pyplot's global figure manager: the slider re-runs
    # this on every change, and unclosed figures would otherwise accumulate.
    # The returned Figure object can still be rendered by the caller.
    plt.close(fig)
    return fig

# This cell re-runs whenever `num_images_slider.value` changes, drawing a fresh random sample.
_preview_html = mo.as_html(plot_images(num_images_slider.value, x_train, y_train))

mo.md(
    f"""
    {num_images_slider}

    **Here's a preview of the images you've selected**:
    {_preview_html}

    """
)

Each element of an image tensor is a single number giving the intensity of one pixel, as illustrated below:

<img src="https://s3-api.us-geo.objectstorage.softlayer.net/cf-courses-data/CognitiveClass/DL0110EN/notebook_images%20/chapter3/3.32_image_values.png" width="550" alt="MNIST elements" />

In [None]:
# Only a cell's last expression is displayed, so the horizontal rule and the
# heading must be emitted as a single markdown object.
mo.md("---\n\n## Model Definition & Training")

In [None]:
def create_cnn_model(input_shape=(28, 28, 1), num_classes=10):
    """Define and compile a small convolutional classifier.

    Architecture: two Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32 and
    64 filters, Dropout(0.25), Flatten, Dense(128, ReLU), Dropout(0.5), and a
    softmax output layer with one unit per class.

    Args:
        input_shape: (height, width, channels) of a single input image. The
            default matches the preprocessed 28x28 single-channel MNIST arrays.
        num_classes: number of output classes (10 digits by default).

    Returns:
        A compiled ``Sequential`` model using the Adam optimizer, sparse
        categorical cross-entropy loss, and accuracy as the metric.
    """
    model = Sequential([
        Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        # Output layer: one softmax probability per class.
        Dense(num_classes, activation='softmax'),
    ])

    # Sparse categorical cross-entropy takes integer class labels directly
    # (no one-hot encoding needed).
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Seed the Python, NumPy and TensorFlow RNGs (same seed as the train/val split)
# so weight initialisation and dropout are reproducible across runs.
# NOTE: bit-exact GPU results may additionally need
# tf.config.experimental.enable_op_determinism().
tf.keras.utils.set_random_seed(42)

# Build a fresh model and display its architecture summary.
model = create_cnn_model()
model.summary()

In [None]:
# epochs_slider = mo.ui.slider(
#     start=1, 
#     stop=20, 
#     value=5, 
#     label="Number of Epochs"
# )
# batch_size_slider = mo.ui.slider(
#     start=16, 
#     stop=164, 
#     value=64, 
#     step=16, 
#     label="Batch Size"
# )
# mo.md(
#     f"""
#     {epochs_slider}  
#     {batch_size_slider}
#     """
# )

In [None]:
# Train the model using the preprocessed data.
epochs = 20
batch_size = 128

# Stop when 'val_loss' has not improved for 3 consecutive epochs, and keep the
# weights from the best epoch.
# NOTE(review): in some Keras versions the best weights are restored only if
# training actually stops early - confirm for the installed version.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Keras runs one step per batch, including a final partial batch, hence ceil.
_steps_per_epoch = int(np.ceil(len(x_train) / batch_size))

print("--- Starting Model Training ---")
print(f"--- epochs = {epochs}")
print(f"--- batch_size = {batch_size} ==> steps per epoch: {_steps_per_epoch}")

# fit() continues from the model's current weights: re-running this cell
# without rebuilding the model trains it further rather than from scratch.
history = model.fit(x_train, y_train,
                    epochs=epochs, batch_size=batch_size,
                    validation_data=(x_val, y_val),
                    callbacks=[early_stopping],
                    verbose=2)

## Model Evaluation & Saving

In [None]:
# Score the trained model on the held-out test split.
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f"Final Test Accuracy: {accuracy:.4f}")

# Learning curves: accuracy (top) and loss (bottom) per epoch. A widening gap
# between the training and validation curves indicates overfitting.
_curves = history.history
_fig, (_ax_acc, _ax_loss) = plt.subplots(2, 1, figsize=(15, 10))

_ax_acc.plot(_curves['accuracy'], label='Training Accuracy')
_ax_acc.plot(_curves['val_accuracy'], label='Validation Accuracy')
_ax_acc.set_title('Model Accuracy over Epochs')
_ax_acc.set_xlabel('Epoch')
_ax_acc.set_ylabel('Accuracy')
_ax_acc.legend()

_ax_loss.plot(_curves['loss'], label='Training loss')
_ax_loss.plot(_curves['val_loss'], label='Validation loss')
_ax_loss.set_title('Model Loss over Epochs')
_ax_loss.set_xlabel('Epoch')
_ax_loss.set_ylabel('Loss')
_ax_loss.legend()

plt.show()

In [None]:
# Softmax class probabilities for each test image, then the highest-scoring class index per row.
y_pred_probs = model.predict(x_test)
y_pred = y_pred_probs.argmax(axis=1)

In [None]:
# Per-class precision, recall and F1 on the test split.
_title = "\n--- Classification Report ---"
print(_title)
_report = classification_report(y_true=y_test, y_pred=y_pred)
print(_report)

In [None]:
# Rows are true labels, columns are predicted labels; off-diagonal counts are misclassifications.
cm = confusion_matrix(y_test, y_pred)

_fig, _ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=_ax)
_ax.set_xlabel('Predicted Label')
_ax.set_ylabel('True Label')
_ax.set_title('Confusion Matrix')
plt.show()

In [None]:
# Save the trained model in the native Keras format.
_model_dir = 'model'
_model_path = f"{_model_dir}/mnist_cnn_model.keras"

# Create the target directory first; saving into a missing directory fails.
# tf.io.gfile.makedirs creates parent directories and succeeds if the path already exists.
tf.io.gfile.makedirs(_model_dir)

model.save(_model_path)
print(f"Model successfully saved as '{_model_path}'.")