In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Fetch the MNIST digit images and their integer labels (train / test split).
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Scale pixel values from [0, 255] down to [0, 1] as float32 and append a
# trailing channel axis, giving arrays of shape (num_samples, 28, 28, 1).
X_train = (X_train.astype('float32') / 255).reshape(-1, 28, 28, 1)
X_test = (X_test.astype('float32') / 255).reshape(-1, 28, 28, 1)

# One-hot encode the digit labels (classes 0-9) for the digit classifier.
num_classes_digits = 10
y_train_digits, y_test_digits = (
    tf.keras.utils.to_categorical(digit_labels, num_classes_digits)
    for digit_labels in (y_train, y_test)
)

# Seed the Python, NumPy and TensorFlow random generators so that weight
# initialisation and batch shuffling are repeatable on Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Create and train the original MNIST model (if not already done).
# The input shape is declared with an explicit Input object instead of the
# `input_shape=` argument on the first layer; Keras emits a UserWarning for
# the latter when it is used inside a Sequential model.
mnist_model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    # 10-way softmax head; this is the layer the transfer-learning step
    # further down drops and replaces.
    layers.Dense(num_classes_digits, activation='softmax')
])

# One-hot targets -> categorical cross-entropy.
mnist_model.compile(optimizer='adam',
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

# Train the model (comment out if you've already trained and saved the model)
mnist_model.fit(X_train, y_train_digits, epochs=5, validation_split=0.1)

# Save the model (comment out if you've already saved the model).
# The file name is what the transfer-learning step loads, so keep it in sync.
mnist_model.save('mnist_cnn_model.h5')

# New task: classify each digit as even or odd.
# The parity label is the digit modulo 2 (0 = even, 1 = odd), cast to int32
# and then one-hot encoded over two classes.
num_classes_odd_even = 2
y_train_odd_even = tf.keras.utils.to_categorical(
    np.mod(y_train, 2).astype(np.int32), num_classes_odd_even
)
y_test_odd_even = tf.keras.utils.to_categorical(
    np.mod(y_test, 2).astype(np.int32), num_classes_odd_even
)

# Load the pre-trained MNIST model
base_model = tf.keras.models.load_model('mnist_cnn_model.h5')

# Create a new model for odd/even classification.
# Declaring the input explicitly (same per-sample shape as the base model)
# lets Keras build the model as layers are added, so summary() below can
# report output shapes and parameter counts.
odd_even_model = models.Sequential()
odd_even_model.add(layers.Input(shape=base_model.input_shape[1:]))

# Transfer every layer of the MNIST model except its final classification
# layer. Note this is more than the convolutional stack: whatever precedes
# the last layer (for the model trained above: conv/pool layers, Flatten and
# the 64-unit Dense) is reused as the feature extractor.
for layer in base_model.layers[:-1]:  # Exclude the last dense layer
    odd_even_model.add(layer)

# Freeze the weights of the transferred layers so only the new head trains.
# The layer objects are shared with base_model, so they become
# non-trainable there as well.
for layer in odd_even_model.layers:
    layer.trainable = False

# Add a new dense layer for odd/even classification.
# It gets an explicit, unique name: an auto-generated name such as 'dense'
# can collide with a transferred layer of the same name when this code runs
# in a fresh kernel with the training step commented out, and Sequential
# refuses duplicate layer names.
odd_even_model.add(layers.Dense(num_classes_odd_even,
                                activation='softmax',
                                name='odd_even_output'))

# Compile the model
odd_even_model.compile(optimizer='adam',
                       loss='categorical_crossentropy',
                       metrics=['accuracy'])

# Print the model summary
odd_even_model.summary()

# Fit the odd/even model on the parity labels; the last 10% of the training
# arrays is requested as validation data via validation_split.
fit_options = dict(
    batch_size=128,
    epochs=5,
    validation_split=0.1,
    verbose=1,
)
history = odd_even_model.fit(X_train, y_train_odd_even, **fit_options)

# Score on the held-out test set; evaluate() yields the loss followed by the
# compiled 'accuracy' metric.
test_metrics = odd_even_model.evaluate(X_test, y_test_odd_even, verbose=0)
test_loss, test_acc = test_metrics
print(f'Test accuracy: {test_acc:.4f}')

# Persist the transfer-learned model to disk.
odd_even_model.save('mnist_odd_even_transfer_learning_model.h5')

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 4ms/step - accuracy: 0.8895 - loss: 0.3558 - val_accuracy: 0.9863 - val_loss: 0.0507
Epoch 2/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.9845 - loss: 0.0513 - val_accuracy: 0.9888 - val_loss: 0.0423
Epoch 3/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - accuracy: 0.9895 - loss: 0.0344 - val_accuracy: 0.9885 - val_loss: 0.0357
Epoch 4/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 3ms/step - accuracy: 0.9918 - loss: 0.0267 - val_accuracy: 0.9907 - val_loss: 0.0359
Epoch 5/5
[1m1688/1688[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 3ms/step - accuracy: 0.9940 - loss: 0.0193 - val_accuracy: 0.9915 - val_loss: 0.0332




Epoch 1/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 6ms/step - accuracy: 0.7730 - loss: 0.7317 - val_accuracy: 0.9728 - val_loss: 0.0798
Epoch 2/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.9752 - loss: 0.0702 - val_accuracy: 0.9835 - val_loss: 0.0535
Epoch 3/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.9826 - loss: 0.0500 - val_accuracy: 0.9860 - val_loss: 0.0446
Epoch 4/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.9873 - loss: 0.0368 - val_accuracy: 0.9875 - val_loss: 0.0399
Epoch 5/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - accuracy: 0.9877 - loss: 0.0347 - val_accuracy: 0.9897 - val_loss: 0.0375




Test accuracy: 0.9885
