In [None]:
import warnings

import matplotlib.pyplot as plt
import seaborn as sns
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.optimizers import Adam
from keras.utils import disable_interactive_logging, set_random_seed, to_categorical

# Silence all Python warnings for this notebook.
# (A message-specific filter here would be redundant: a blanket "ignore"
# already matches every warning.)
# NOTE(review): a blanket ignore also hides useful warnings; narrow it with
# `message=` / `category=` once the noisy warnings have been identified.
warnings.filterwarnings("ignore")

# Turn off Keras' interactive logging so training/evaluation output goes
# through standard logging instead of live progress output.
disable_interactive_logging()

In [None]:
# Training configuration (tune here rather than inside the calls below).
SEED = 42
LEARNING_RATE = 0.001
EPOCHS = 2
LOG_DIR = 'logs'  # must match --logdir in the TensorBoard cell

# Seed Python, NumPy and backend RNGs so weight initialisation and batch
# shuffling are reproducible on Restart & Run All.
set_random_seed(SEED)

# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Reshape to (N, 28, 28) to match the model's input_shape below; the reshape
# raises if the arrays do not hold N * 28 * 28 values.
train_images_reshaped = train_images.reshape(train_images.shape[0], 28, 28)
test_images_reshaped = test_images.reshape(test_images.shape[0], 28, 28)

# Scale pixel values to [0, 1] (assumes raw pixels in 0..255).
# The "_resized" names are kept because later cells refer to them.
train_images_resized = train_images_reshaped.astype('float32') / 255
test_images_resized = test_images_reshaped.astype('float32') / 255

# One-hot encode labels (required by categorical_crossentropy).
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# Simple MLP: flatten 28x28 image -> 128 ReLU units -> 10 class probabilities.
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    # softmax so the outputs are probabilities: the 'categorical_crossentropy'
    # string loss uses from_logits=False and would misinterpret raw logits.
    Dense(10, activation='softmax'),
])

# Compile model
model.compile(
    optimizer=Adam(LEARNING_RATE),
    loss='categorical_crossentropy',  # matches the one-hot encoded labels
    metrics=['accuracy'],
)

# Fit model. The MNIST test split doubles as validation data, so the
# "val_*" entries in `history` are test-set metrics.
history = model.fit(
    x=train_images_resized,
    y=train_labels,
    epochs=EPOCHS,
    validation_data=(test_images_resized, test_labels),
    # Write event files to LOG_DIR so the TensorBoard cell has data to show.
    # NOTE: the Profile tab additionally needs `profile_batch=...` here
    # (plus the TensorBoard profiler plugin) — not enabled by default.
    callbacks=[TensorBoard(log_dir=LOG_DIR)],
)


In [None]:
# Load the TensorBoard notebook extension (provides the %tensorboard magic).
%load_ext tensorboard

# Launch TensorBoard inline, reading event files from ./logs, on port 6009.
# Dashboards stay empty unless training wrote event files into ./logs;
# the Profile tab additionally requires profiling data to have been captured.
%tensorboard --logdir=logs --port=6009

# NOTE: if the output says "Reusing TensorBoard on port XXXX", an earlier
# TensorBoard instance is being reused; change --port above to start a fresh one.

In [None]:
# Plot per-epoch loss for training and validation. val_loss is labelled
# "Test" because the model was validated on the MNIST test split.
fig, ax = plt.subplots()
for history_key, curve_label in (('loss', 'Train'), ('val_loss', 'Test')):
    losses = history.history[history_key]
    # x axis is the 1-based epoch number.
    sns.lineplot(x=range(1, len(losses) + 1), y=losses, label=curve_label, ax=ax)

ax.set(title='Model loss', xlabel='Epoch', ylabel='Loss')
ax.legend(loc='upper right')
# Render now so the figure appears before the evaluation output below.
plt.show()

# Evaluate the model on the test set.
test_loss, test_acc = model.evaluate(test_images_resized, test_labels)
print('Test accuracy:', test_acc)