In [2]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPool2D, SpatialDropout2D,
    Flatten, Dense, Dropout
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# Fetch MNIST. Training labels become 10-way one-hot vectors (for categorical
# cross-entropy); training images are cast to float32, scaled by 1/255 into
# [0, 1], and given an explicit trailing channel axis for Conv2D.
# x_test / labels_test are returned untouched and are not used for training.
(x_train, labels_train), (x_test, labels_test) = mnist.load_data()
y_train = tf.keras.utils.to_categorical(labels_train, 10)
x_train = (x_train.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

# Carve a 10% validation split out of the training set; random_state pins the
# shuffle so the same split is produced on every run.
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, random_state=42, test_size=0.1,
)

# On-the-fly geometric augmentation, applied per batch by `datagen.flow(...)`
# during training: random rotations (rotation_range, in degrees), horizontal and
# vertical shifts (as a fraction of width/height), and random zoom.
#
# `datagen.fit(...)` is intentionally not called. It only computes dataset
# statistics for featurewise_center, featurewise_std_normalization or
# zca_whitening, none of which are enabled here, so it did no useful work.
#
# NOTE(review): ImageDataGenerator is deprecated in recent TensorFlow/Keras
# releases (the `_warn_if_super_not_called` warning in the cell output likely
# comes from its data adapter). Consider Keras preprocessing layers such as
# RandomRotation / RandomTranslation / RandomZoom when migrating.
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
)

# CNN classifier:
# - two convolutional stages (32 then 64 filters), each made of two 3x3 ReLU
#   convolutions with batch normalization, followed by 2x2 max-pooling and
#   channel-wise (spatial) dropout;
# - a 256-unit dense head with batch norm and dropout;
# - a 10-way softmax output.
#
# The input shape is declared with an explicit Input as the first element
# instead of `input_shape=` on the first Conv2D, which is the pattern recent
# Keras versions recommend (they warn on `input_shape=` in layers).
#
# The dense layer's L2 penalty is spelled out: the string 'l2' resolves to a
# factor of 0.01. That penalty likely accounts for most of the large early
# training loss (~3.9 at ~73% accuracy in epoch 1).
model = Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation='relu'),
    BatchNormalization(),
    Conv2D(32, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPool2D((2, 2)),
    SpatialDropout2D(0.2),

    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),
    Conv2D(64, (3, 3), activation='relu'),
    BatchNormalization(),
    MaxPool2D((2, 2)),
    SpatialDropout2D(0.2),

    Flatten(),
    Dense(256, activation='relu',
          kernel_regularizer=tf.keras.regularizers.l2(0.01)),
    BatchNormalization(),
    Dropout(0.5),
    Dense(10, activation='softmax'),
])

# Adam at its default learning rate (the training log reports
# learning_rate 0.0010 until the LR scheduler intervenes) with categorical
# cross-entropy, matching the one-hot targets; accuracy is tracked for logging.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training callbacks. All three monitor validation loss, so "best epoch" means
# the same thing for the checkpoint on disk and for the weights EarlyStopping
# restores. Monitoring different quantities lets the two pick different epochs.
# - EarlyStopping: stop after 5 epochs without val_loss improvement and
#   restore the best-val_loss weights (restore_best_weights=True).
# - ModelCheckpoint: overwrite 'best_model.h5' whenever val_loss improves.
# - ReduceLROnPlateau: multiply the learning rate by 0.2 after 3 stagnant
#   epochs (floor 1e-6). Its patience is shorter than EarlyStopping's, so the
#   LR can drop before training is stopped.
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_loss'),
    ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6),
]

# Train for up to 50 epochs on freshly augmented batches of 256 images
# (the log shows 211 steps per epoch). Validation runs on the held-out split
# without augmentation. The callbacks may lower the learning rate or end
# training early.
history = model.fit(
    x=datagen.flow(x_train, y_train, batch_size=256),
    epochs=50,
    callbacks=callbacks,
    validation_data=(x_val, y_val),
)

# Persist the in-memory model. This is not necessarily identical to the file
# written by ModelCheckpoint: it holds whatever weights remain after
# model.fit(), as affected by EarlyStopping(restore_best_weights=True).
# NOTE(review): older Keras versions only restore the best weights when early
# stopping actually triggers; confirm for the installed version.
model.save('improved_mnist_cnn.h5')

# Report performance on the official test split. Unlike x_val, it played no
# part in early stopping, LR scheduling or checkpoint selection.
# It is preprocessed exactly like the training images.
x_test_prepared = x_test.astype('float32') / 255.0
x_test_prepared = x_test_prepared.reshape(-1, 28, 28, 1)
y_test = tf.keras.utils.to_categorical(labels_test, 10)
test_metrics = model.evaluate(x_test_prepared, y_test, verbose=0, return_dict=True)
print('Test set:', ', '.join(f'{name}={value:.4f}' for name, value in test_metrics.items()))

  self._warn_if_super_not_called()


Epoch 1/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 492ms/step - accuracy: 0.7295 - loss: 3.8845



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 507ms/step - accuracy: 0.7301 - loss: 3.8776 - val_accuracy: 0.1090 - val_loss: 5.2893 - learning_rate: 0.0010
Epoch 2/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 484ms/step - accuracy: 0.9502 - loss: 0.5440



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 496ms/step - accuracy: 0.9502 - loss: 0.5435 - val_accuracy: 0.9300 - val_loss: 0.4351 - learning_rate: 0.0010
Epoch 3/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 483ms/step - accuracy: 0.9615 - loss: 0.2812



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 496ms/step - accuracy: 0.9615 - loss: 0.2812 - val_accuracy: 0.9870 - val_loss: 0.1778 - learning_rate: 0.0010
Epoch 4/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 487ms/step - accuracy: 0.9662 - loss: 0.2450



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 500ms/step - accuracy: 0.9662 - loss: 0.2449 - val_accuracy: 0.9882 - val_loss: 0.1649 - learning_rate: 0.0010
Epoch 5/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 500ms/step - accuracy: 0.9735 - loss: 0.2165 - val_accuracy: 0.9877 - val_loss: 0.1674 - learning_rate: 0.0010
Epoch 6/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 488ms/step - accuracy: 0.9734 - loss: 0.2148



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 501ms/step - accuracy: 0.9734 - loss: 0.2148 - val_accuracy: 0.9887 - val_loss: 0.1629 - learning_rate: 0.0010
Epoch 7/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 488ms/step - accuracy: 0.9747 - loss: 0.2057



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 501ms/step - accuracy: 0.9747 - loss: 0.2056 - val_accuracy: 0.9892 - val_loss: 0.1526 - learning_rate: 0.0010
Epoch 8/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 490ms/step - accuracy: 0.9756 - loss: 0.2020



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 504ms/step - accuracy: 0.9756 - loss: 0.2020 - val_accuracy: 0.9907 - val_loss: 0.1522 - learning_rate: 0.0010
Epoch 9/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 501ms/step - accuracy: 0.9765 - loss: 0.1991 - val_accuracy: 0.9902 - val_loss: 0.1475 - learning_rate: 0.0010
Epoch 10/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 500ms/step - accuracy: 0.9788 - loss: 0.1879 - val_accuracy: 0.9897 - val_loss: 0.1483 - learning_rate: 0.0010
Epoch 11/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 489ms/step - accuracy: 0.9811 - loss: 0.1767



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 501ms/step - accuracy: 0.9811 - loss: 0.1768 - val_accuracy: 0.9917 - val_loss: 0.1595 - learning_rate: 0.0010
Epoch 12/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 500ms/step - accuracy: 0.9792 - loss: 0.1851 - val_accuracy: 0.9913 - val_loss: 0.1447 - learning_rate: 0.0010
Epoch 13/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 501ms/step - accuracy: 0.9800 - loss: 0.1805 - val_accuracy: 0.9905 - val_loss: 0.1613 - learning_rate: 0.0010
Epoch 14/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 506ms/step - accuracy: 0.9810 - loss: 0.1806 - val_accuracy: 0.9902 - val_loss: 0.1468 - learning_rate: 0.0010
Epoch 15/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 493ms/step - accuracy: 0.9806 - loss: 0.1776



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 506ms/step - accuracy: 0.9806 - loss: 0.1776 - val_accuracy: 0.9918 - val_loss: 0.1371 - learning_rate: 0.0010
Epoch 16/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 496ms/step - accuracy: 0.9819 - loss: 0.1695



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 508ms/step - accuracy: 0.9819 - loss: 0.1695 - val_accuracy: 0.9925 - val_loss: 0.1368 - learning_rate: 0.0010
Epoch 17/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 497ms/step - accuracy: 0.9828 - loss: 0.1682



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 510ms/step - accuracy: 0.9828 - loss: 0.1682 - val_accuracy: 0.9938 - val_loss: 0.1336 - learning_rate: 0.0010
Epoch 18/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m108s[0m 509ms/step - accuracy: 0.9819 - loss: 0.1738 - val_accuracy: 0.9910 - val_loss: 0.1376 - learning_rate: 0.0010
Epoch 19/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 540ms/step - accuracy: 0.9828 - loss: 0.1685 - val_accuracy: 0.9892 - val_loss: 0.1502 - learning_rate: 0.0010
Epoch 20/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 532ms/step - accuracy: 0.9829 - loss: 0.1704 - val_accuracy: 0.9922 - val_loss: 0.1368 - learning_rate: 0.0010
Epoch 21/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 518ms/step - accuracy: 0.9870 - loss: 0.1437



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 531ms/step - accuracy: 0.9870 - loss: 0.1436 - val_accuracy: 0.9942 - val_loss: 0.0777 - learning_rate: 2.0000e-04
Epoch 22/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 534ms/step - accuracy: 0.9898 - loss: 0.0888 - val_accuracy: 0.9930 - val_loss: 0.0623 - learning_rate: 2.0000e-04
Epoch 23/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 529ms/step - accuracy: 0.9906 - loss: 0.0695 - val_accuracy: 0.9942 - val_loss: 0.0569 - learning_rate: 2.0000e-04
Epoch 24/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 530ms/step - accuracy: 0.9902 - loss: 0.0685 - val_accuracy: 0.9935 - val_loss: 0.0589 - learning_rate: 2.0000e-04
Epoch 25/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 533ms/step - accuracy: 0.9916 - loss: 0.0633 - val_accuracy: 0.9938 - val_loss: 0.0564 - learning_rate: 2.0000e-04
Epoch 26/50
[1m211/211[0m [32m━━━━



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 528ms/step - accuracy: 0.9903 - loss: 0.0677 - val_accuracy: 0.9950 - val_loss: 0.0538 - learning_rate: 2.0000e-04
Epoch 27/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 529ms/step - accuracy: 0.9913 - loss: 0.0627 - val_accuracy: 0.9938 - val_loss: 0.0555 - learning_rate: 2.0000e-04
Epoch 28/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 530ms/step - accuracy: 0.9901 - loss: 0.0665 - val_accuracy: 0.9948 - val_loss: 0.0512 - learning_rate: 2.0000e-04
Epoch 29/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 528ms/step - accuracy: 0.9916 - loss: 0.0606 - val_accuracy: 0.9938 - val_loss: 0.0522 - learning_rate: 2.0000e-04
Epoch 30/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m110s[0m 519ms/step - accuracy: 0.9912 - loss: 0.0643 - val_accuracy: 0.9940 - val_loss: 0.0502 - learning_rate: 2.0000e-04
Epoch 31/50
[1m211/211[0m [32m━━━━



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 528ms/step - accuracy: 0.9909 - loss: 0.0611 - val_accuracy: 0.9952 - val_loss: 0.0490 - learning_rate: 2.0000e-04
Epoch 34/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 514ms/step - accuracy: 0.9914 - loss: 0.0642



[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 527ms/step - accuracy: 0.9914 - loss: 0.0642 - val_accuracy: 0.9953 - val_loss: 0.0490 - learning_rate: 2.0000e-04
Epoch 35/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 525ms/step - accuracy: 0.9910 - loss: 0.0640 - val_accuracy: 0.9948 - val_loss: 0.0489 - learning_rate: 2.0000e-04
Epoch 36/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 527ms/step - accuracy: 0.9916 - loss: 0.0600 - val_accuracy: 0.9938 - val_loss: 0.0511 - learning_rate: 2.0000e-04
Epoch 37/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 525ms/step - accuracy: 0.9913 - loss: 0.0626 - val_accuracy: 0.9933 - val_loss: 0.0517 - learning_rate: 2.0000e-04
Epoch 38/50
[1m211/211[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m111s[0m 523ms/step - accuracy: 0.9910 - loss: 0.0620 - val_accuracy: 0.9932 - val_loss: 0.0528 - learning_rate: 2.0000e-04
Epoch 39/50
[1m211/211[0m [32m━━━━

