# Image classification with transfer learning

**Data Pre-processing**

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

#Importing dataset from CIFAR
from tensorflow.keras.datasets import cifar10
from sklearn.model_selection import train_test_split


# Load CIFAR-10 (Keras downloads the archive if it is not cached yet) and report each split's shape.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

for split_name, split_array in (
    ("x_train", x_train),
    ("y_train", y_train),
    ("x_test", x_test),
    ("y_test", y_test),
):
    print(f"{split_name} shape:", split_array.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step
x_train shape: (50000, 32, 32, 3)
y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3)
y_test shape: (10000, 1)


**Pre-processing classes (y)**

In [2]:
# Convert the integer class labels into one-hot vectors (needed by categorical cross-entropy) and check shapes.
from tensorflow.keras.utils import to_categorical

num_classes = 10


def _ensure_one_hot(labels, num_classes):
    """Return `labels` one-hot encoded, leaving labels that are already one-hot untouched.

    This cell overwrites y_train / y_test with their encoded form, so without this guard
    re-running it would one-hot encode the (N, num_classes) matrix a second time and
    silently produce an (N, num_classes, num_classes) array.
    """
    if np.ndim(labels) == 2 and np.shape(labels)[1] == num_classes:
        return labels
    return to_categorical(labels, num_classes)


y_train = _ensure_one_hot(y_train, num_classes)
y_test = _ensure_one_hot(y_test, num_classes)

print("y_train_cat shape:", y_train.shape)
print("y_test_cat shape:", y_test.shape)

y_train_cat shape: (50000, 10)
y_test_cat shape: (10000, 10)


**Pre-processing images**

Resize and rescale the images so that their shape and pixel range match what our base model (MobileNetV2) expects.


In [4]:
# Reshaping image to be compatible with base model. In batches, to avoid crashing.

def resize_in_batches_cpu(images, new_size=(160, 160), batch_size=1000):
    """Resize a stack of images batch by batch on the CPU.

    Args:
        images: array-like stack of images indexed along axis 0
            (e.g. an (N, H, W, C) NumPy array such as CIFAR-10's x_train).
        new_size: target (height, width) passed to ``tf.image.resize``.
        batch_size: number of images resized per ``tf.image.resize`` call;
            bounds the size of each intermediate tensor.

    Returns:
        NumPy array of the resized images, in the same order as ``images``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # range() would reject 0 with a cryptic message and silently skip everything for < 0.
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_images = len(images)
    resized = None
    with tf.device('/CPU:0'):  # force CPU to avoid GPU OOM
        for start in range(0, n_images, batch_size):
            batch_resized = tf.image.resize(images[start:start + batch_size], new_size).numpy()
            if resized is None:
                # Allocate the full output once. Collecting batches in a list and calling
                # np.concatenate keeps the list AND the result alive at the same time,
                # i.e. twice the (very large) output size at peak.
                resized = np.empty((n_images,) + batch_resized.shape[1:], dtype=batch_resized.dtype)
            resized[start:start + len(batch_resized)] = batch_resized

    if resized is None:
        # Empty input: return an empty stack instead of crashing in np.concatenate([]).
        # float32 matches tf.image.resize's output dtype for the default (bilinear) method.
        resized = np.empty((0, *new_size) + tuple(np.shape(images)[3:]), dtype=np.float32)
    return resized

# Upsample both splits from 32x32 to the 160x160 input size used for the base model.
x_train_resized, x_test_resized = (
    resize_in_batches_cpu(split_images) for split_images in (x_train, x_test)
)

print("Resized shapes:", x_train_resized.shape, x_test_resized.shape)

Resized shapes: (50000, 160, 160, 3) (10000, 160, 160, 3)


In [7]:
# This Keras function is specifically designed to make images compatible with the MobileNetV2 model.
# It scales pixel values from [0, 255] to [-1, 1], the range the model was trained on.
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

# Keras documents that preprocess_input writes its result over float NumPy input, so after this
# cell x_train / x_test share memory with x_train_resized / x_test_resized (no second full-size
# copy). The catch: re-running this cell would rescale the already-scaled arrays a second time.
# Resized raw pixels are never negative, while scaled ones essentially always are, so refuse to
# run on data that has already been preprocessed.
if x_train_resized.min() < 0 or x_test_resized.min() < 0:
    raise RuntimeError(
        "Resized images look already preprocessed; re-run the resize cell before this one."
    )

# Preprocess your data
x_train = preprocess_input(x_train_resized)
x_test = preprocess_input(x_test_resized)


In [None]:
# tf.data pipelines consume (features, label) pairs, so zip every preprocessed image with its
# one-hot label; later cells batch these datasets and feed them to the transfer-learning model.
train_dataset, test_dataset = (
    tf.data.Dataset.from_tensor_slices(image_label_pair)
    for image_label_pair in ((x_train, y_train), (x_test, y_test))
)

## Create the base model from the pre-trained convnet

In [None]:
# MobileNetV2 backbone with ImageNet weights and without its ImageNet classifier
# (include_top=False), so it outputs feature maps for our own classification head.
IMG_SHAPE = (160, 160, 3)  # must match the 160x160 resize target used during preprocessing
backbone_kwargs = dict(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
base_model = tf.keras.applications.MobileNetV2(**backbone_kwargs)

# Freeze every backbone weight so that training only updates the new head.
base_model.trainable = False



### Building and training transfer learning model

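**Batch the datasets**

Group the examples into fixed-size batches before feeding them to the model. (Global Average Pooling, which converts the base model's feature maps into vectors, is applied in the next step.)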

In [None]:
# Create batched datasets. BATCH_SIZE is a standalone variable so it can be tuned later.
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000  # number of (image, label) elements held in the shuffle buffer
SHUFFLE_SEED = 42

# Re-running this cell would batch the already-batched datasets (batches of batches), so make
# sure each element is still a single (height, width, channels) image.
if any(ds.element_spec[0].shape.rank != 3 for ds in (train_dataset, test_dataset)):
    raise RuntimeError("Datasets are already batched; re-run the dataset-creation cell first.")

# Model.fit does not shuffle tf.data inputs itself, so reshuffle the training data every epoch.
# Shuffle BEFORE batching so that batch composition changes, not just batch order. The test set
# keeps its fixed order. prefetch overlaps input preparation with training.
train_dataset = (
    train_dataset
    .shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)



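**Apply Global Average Pooling and add a classification head**

Global Average Pooling converts the base model's feature maps into vectors; a dense softmax layer then maps each vector to the 10 CIFAR-10 classes.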

In [None]:
# Pooling layer that collapses each feature map to a single value per channel.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# Classification head: 10 units (one per CIFAR-10 category) with softmax for class probabilities.
prediction_layer = tf.keras.layers.Dense(10, activation='softmax')

# Push one training batch through backbone -> pooling -> head to check the shapes
# (this also builds the head's weights before the full model is assembled).
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)

feature_batch_average = global_average_layer(feature_batch)
print(f"Feature batch average shape: {feature_batch_average.shape}")

prediction_batch = prediction_layer(feature_batch_average)
print(f"Prediction batch shape: {prediction_batch.shape}")



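**Build the full model**: data augmentation → frozen MobileNetV2 base → global average pooling → dropout → classification head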

In [None]:
# Random augmentation: horizontal flips plus rotations by up to 0.2 of a full turn in either
# direction. Keras only applies these random transforms in training mode.
augmentation_layers = [
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
]
data_augmentation = tf.keras.Sequential(augmentation_layers)

In [None]:
# Assemble the end-to-end model: augmentation -> frozen MobileNetV2 -> pooling -> dropout -> softmax head.
# The datasets were already passed through preprocess_input (pixels in [-1, 1]) in the preprocessing
# cell, so the model must NOT rescale again: a second preprocess_input would squash every pixel into
# roughly [-1.008, -0.992] and leave the backbone almost blind.
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
# training=False keeps the frozen backbone (including its BatchNormalization layers) in
# inference mode even while the head is being trained.
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

model.summary()

# With the base frozen, only the Dense head's kernel and bias should be trainable (expected: 2).
print("Trainable variables:", len(model.trainable_variables))
tf.keras.utils.plot_model(model, show_shapes=True)

## Compile and fit the model

In [None]:
# Compile with Adam; categorical cross-entropy because the labels are one-hot encoded.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train for 10 epochs, reporting metrics on the test split after every epoch.
# NOTE(review): the test split doubles as validation data here, so the "test accuracy" printed
# below comes from the same data the training curves were monitored on, not from an
# independent hold-out set.
history = model.fit(train_dataset, validation_data=test_dataset, epochs=10)

# Final evaluation on the test split.
test_loss, test_accuracy = model.evaluate(test_dataset)
print("\nTest accuracy: {:.4f}".format(test_accuracy))

# Side-by-side learning curves: accuracy on the left, loss on the right.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

for ax, metric, metric_title in ((acc_ax, 'accuracy', 'Accuracy'), (loss_ax, 'loss', 'Loss')):
    ax.plot(history.history[metric], label=f'Training {metric_title}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {metric_title}')
    ax.set_title(f'Model {metric_title}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_title)
    ax.legend()

fig.tight_layout()
plt.show()

In [None]:
## Development helper: reset Keras' global state between experiments.

# Use the same `tf.keras` namespace as the rest of the notebook. A bare `from keras.backend ...`
# imports the standalone `keras` package, which is not always the package behind `tf.keras`
# (e.g. when tf.keras is routed to the legacy tf_keras package), and would then clear the
# wrong session.
tf.keras.backend.clear_session()