In [10]:
import tensorflow as tf

# Load data: MNIST comes back as two (images, labels) pairs,
# the first for training and the second for testing.
train_split, test_split = tf.keras.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
import numpy as np

# Explore data range: overall training-array shape, then the value range
# and dtype of a single sample.
print(x_train.shape)

first_sample = x_train[0]
sample_min, sample_max = first_sample.min(), first_sample.max()

print(f"Max value: {sample_max}")
print(f"Min value: {sample_min}")
print(f"Data type: {first_sample.dtype}")


In [None]:
# Explore data: preview the first ten training images, each titled with its label.

import matplotlib.pyplot as plt

n_preview = 10
fig, axes = plt.subplots(1, n_preview, figsize=(10, 1))

for idx, ax in enumerate(axes):
    ax.imshow(x_train[idx], cmap='gray')
    ax.set_title(f"{y_train[idx]}")
    ax.axis("off")

fig.tight_layout()
plt.show()

In [None]:
# Preprocess data: add a trailing channel axis to match the model's
# (28, 28, 1) input, scale pixel values from [0, 255] into [0, 1], and
# one-hot encode the labels for the categorical_crossentropy loss.

NUM_CLASSES = 10  # digits 0-9; must match the width of the model's softmax layer

x_train_scaled = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test_scaled = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Guard the assumption that the raw pixels were in [0, 255].
assert 0.0 <= x_train_scaled.min() and x_train_scaled.max() <= 1.0, "unexpected pixel range"

print(f"Min input value after scaling: {np.min(x_train_scaled[0])}")
print(f"Max input value after scaling: {np.max(x_train_scaled[0])}")

# Pass num_classes explicitly. Without it, to_categorical infers the width
# from max(label) + 1, so a split that lacks the highest digit would get
# fewer columns than the model's 10-unit output.
y_train_categorical = tf.keras.utils.to_categorical(y_train, num_classes=NUM_CLASSES)
y_test_categorical = tf.keras.utils.to_categorical(y_test, num_classes=NUM_CLASSES)

print(f"Shape of y after one-hot encoding: {y_train_categorical.shape}")
print(f"Each sample looks like this: {y_train_categorical[0]}")

In [None]:
from tensorflow.keras import models, layers

# Seed Python's `random`, NumPy and TensorFlow in one call. This makes weight
# initialisation and fit()'s per-epoch shuffle repeat across Restart & Run All.
# Some GPU kernels can still be nondeterministic.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

EPOCHS = 5
BATCH_SIZE = 64
VALIDATION_SPLIT = 0.1  # fraction of the training arrays held out for validation

# Build the model: two Conv2D/MaxPooling2D stages followed by a dense head
# that ends in a 10-way softmax (one probability per digit class).
model = models.Sequential([
    # Declare the input shape with an explicit Input layer. Passing
    # `input_shape=` to the first layer is deprecated in recent Keras releases.
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Compile the model. categorical_crossentropy expects one-hot targets,
# such as the y_*_categorical arrays built during preprocessing.
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model. Keras takes the validation split from the last samples
# of the arrays, before shuffling.
history = model.fit(
    x_train_scaled,
    y_train_categorical,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

In [None]:

# Evaluate the model on the test split, which was not used for training
# or validation. The progress bar is silenced and the two metrics are
# reported explicitly instead.

loss, accuracy = model.evaluate(x_test_scaled, y_test_categorical, verbose=0)
print(f"Test loss: {loss:.4f}")
print(f"Test accuracy: {accuracy:.4f}")


In [None]:
# Plot training vs. validation accuracy per epoch to spot over- or under-fitting.
train_accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
# Keras numbers epochs from 1, so put the points at 1..N rather than at the
# list indices 0..N-1 (which mislabels every epoch by one).
epoch_numbers = range(1, len(train_accuracy) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_accuracy, label='accuracy')
ax.plot(epoch_numbers, val_accuracy, label='val_accuracy')
ax.set_title('Training vs. validation accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend(loc='lower right')
plt.show()