Finally, we will learn how to use pretrained models in Keras.

Using pretrained models is important because it often gives us a model with a "better start" than training from scratch: its weights already encode general visual features learned on a large dataset.

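The most common pretraining dataset is ImageNet. The full dataset contains about 14M images in roughly 21k categories; the ImageNet weights shipped with Keras come from its widely used 1,000-class subset (ILSVRC 2012, about 1.2M training images).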

![](imgs/imagenet_.jpg)

We will use two models: ResNet50 and ResNet101.

We will fine-tune ResNet50 on MNIST and use ResNet101 for feature extraction.

In [41]:
import keras
from keras.models import Sequential
from keras.layers import Dense, GlobalAveragePooling2D

from keras.applications.resnet import ResNet50, ResNet101, preprocess_input

import tensorflow as tf
import numpy as np


ResNet50:

In [42]:
cnn_pretrained = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

MNIST images are 28 × 28 and grayscale (a single channel); however, the pretrained model requires exactly 3 channels and, with `include_top=False`, inputs of at least 32 × 32. We will need to transform the MNIST dataset to meet these requirements before we can use the pretrained model. More on that later.
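One possible conversion, sketched with NumPy on a dummy batch: pad each 28 × 28 image with 2 pixels of zeros on every side to reach 32 × 32, then repeat the single grayscale channel three times. Zero padding is our assumption here; resizing with interpolation would also satisfy the size constraint. The function name `mnist_to_rgb32` is hypothetical.

```python
import numpy as np

def mnist_to_rgb32(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert (N, 28, 28) grayscale images to (N, 32, 32, 3).

    Assumption: zero padding (2 px per side) instead of interpolation-based resizing.
    """
    # Pad only the two spatial axes: 2 rows/cols before and after -> 28 + 4 = 32
    padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), mode="constant")
    # Add a channel axis and repeat it 3 times so the image looks like RGB
    rgb = np.repeat(padded[..., np.newaxis], 3, axis=-1)
    return rgb.astype("float32")

# Dummy batch standing in for the MNIST training images
dummy = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)
converted = mnist_to_rgb32(dummy)
print(converted.shape)  # (5, 32, 32, 3)
```

The converted pixels are still in the 0–255 range, so they can then go through `preprocess_input` like any other image fed to the pretrained ResNet.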

The model:

In [43]:
cnn = Sequential()
cnn.add(cnn_pretrained)
cnn.add(GlobalAveragePooling2D())
cnn.add(Dense(10, activation='softmax'))
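With a 32 × 32 input, ResNet50 downsamples by a factor of 32, so its convolutional base outputs a 1 × 1 × 2048 feature map per image. `GlobalAveragePooling2D` averages each channel over the spatial axes, giving a 2048-dimensional vector, and `Dense(10, activation='softmax')` maps it to 10 class probabilities. The NumPy sketch below reproduces those two head operations on a random stand-in feature map; the weights `W` and `b` are untrained placeholders, not the Keras layer's parameters.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the ResNet50 output on a batch of 4 images of size 32x32:
# shape (batch, height, width, channels) = (4, 1, 1, 2048)
feature_maps = rng.standard_normal((4, 1, 1, 2048)).astype("float32")

# GlobalAveragePooling2D: mean over the two spatial axes -> (4, 2048)
pooled = feature_maps.mean(axis=(1, 2))

# Dense(10, activation="softmax") with placeholder (untrained) weights
W = rng.standard_normal((2048, 10)).astype("float32") * 0.01
b = np.zeros(10, dtype="float32")
logits = pooled @ W + b
# Numerically stable softmax: subtract the row max before exponentiating
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

print(pooled.shape, probs.shape)  # (4, 2048) (4, 10)
```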

**DIY**: convert the MNIST images to the required format, compile this neural network, and train it for 1 or 2 epochs.

**Feature extraction**

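The following cells show how to perform feature extraction: the pretrained ResNet101 (without its classification head) turns each image into a feature vector, and a separate classifier (here an SVM) is trained on those vectors. With our very small dataset, the classifier will probably not generalize well...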

In [None]:
img_size = (224, 224)

cnn_pretrained = ResNet101(weights='imagenet', include_top=False, input_shape=img_size + (3,))

In [None]:
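dataset = keras.utils.image_dataset_from_directory(
    "dataset",               # one subdirectory per class
    labels="inferred",       # labels come from the subdirectory names
    batch_size=32,
    image_size=img_size,
    color_mode="rgb",
    interpolation="bilinear",
    shuffle=False,           # keep a fixed order so features and labels stay aligned
)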

In [None]:
def preprocess_batch(batch_images, batch_labels):
    # ResNet (Caffe-style) preprocessing: RGB -> BGR, then subtract the ImageNet per-channel mean
    batch_images = preprocess_input(batch_images)
    return batch_images, batch_labels

dataset = dataset.map(preprocess_batch)
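For the ResNet models in `keras.applications.resnet`, `preprocess_input` applies Caffe-style preprocessing: it converts images from RGB to BGR and subtracts the per-channel ImageNet mean, without rescaling pixels to [0, 1]. The NumPy function below is our own re-implementation for illustration only (the name `caffe_style_preprocess` is hypothetical); in practice, keep calling `preprocess_input` as above.

```python
import numpy as np

# ImageNet per-channel means in BGR order, as used by Caffe-style preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype="float32")

def caffe_style_preprocess(images: np.ndarray) -> np.ndarray:
    """Hypothetical re-implementation of Caffe-style preprocessing.

    Expects RGB pixel values in [0, 255], channels last, shape (..., 3).
    """
    bgr = images[..., ::-1].astype("float32")  # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR             # zero-center each channel

# A single pure-red 2x2 image (R=255, G=0, B=0)
red = np.zeros((1, 2, 2, 3), dtype="float32")
red[..., 0] = 255.0
out = caffe_style_preprocess(red)
print(out[0, 0, 0])  # approximately [-103.939, -116.779, 131.32]
```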

In [None]:
all_images = []
all_labels = []

# Iterate through the dataset to collect all images and labels
for images, labels in dataset:
    all_images.append(images.numpy())
    all_labels.append(labels.numpy())

all_images = np.concatenate(all_images, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

In [None]:
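# Predict on the collected arrays rather than on `dataset`, so the features stay
# aligned with `all_labels` even if the dataset is reshuffled on every pass
predictions = cnn_pretrained.predict(all_images, batch_size=32)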

In [None]:
predictions.shape

In [None]:
predictions = predictions.reshape(predictions.shape[0], -1)
predictions.shape
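At 224 × 224, ResNet101 downsamples by a factor of 32, so `cnn_pretrained.predict` returns an array of shape (N, 7, 7, 2048), and flattening it gives 7 × 7 × 2048 = 100,352 features per image. That is a lot of dimensions for an SVM trained on a handful of images. A common alternative, not used in the cells above, is to average over the 7 × 7 spatial grid (what `GlobalAveragePooling2D` does), which leaves 2048 features. The NumPy sketch below compares the two on a random stand-in array.

```python
import numpy as np

# Random stand-in for the ResNet101 output on 8 images: (N, 7, 7, 2048)
features = np.random.default_rng(0).standard_normal((8, 7, 7, 2048)).astype("float32")

# Option 1 (what the cell above does): flatten everything after the batch axis
flat = features.reshape(features.shape[0], -1)

# Option 2 (alternative): global average pooling over the 7x7 spatial grid
pooled = features.mean(axis=(1, 2))

print(flat.shape)    # (8, 100352)
print(pooled.shape)  # (8, 2048)
```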

In [None]:
from sklearn.svm import SVC

svc = SVC()
svc.fit(predictions, all_labels)
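Fitting the SVC on every image leaves nothing to measure generalization on, which matters given the warning above about our tiny dataset. A minimal sketch, assuming scikit-learn's `train_test_split` and using synthetic 2048-dimensional vectors as a stand-in for the real `predictions` and `all_labels`: hold out part of the data, fit on the rest, and report accuracy on the held-out part.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

rng = np.random.default_rng(0)

# Synthetic stand-in for extracted features: two classes, 30 samples each,
# drawn around well-separated centers
n_per_class, dim = 30, 2048
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(n_per_class, dim)),
    rng.normal(loc=1.0, scale=1.0, size=(n_per_class, dim)),
])
y = np.array([0] * n_per_class + [1] * n_per_class)

# Hold out 25% of the samples, keeping class proportions with stratify=y
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

svc = SVC()
svc.fit(X_train, y_train)
accuracy = svc.score(X_test, y_test)
print(f"Held-out accuracy: {accuracy:.2f}")
```

With real features, the same split applied to `predictions` and `all_labels` gives a first estimate of how well the extracted features transfer; with very few images, that estimate will itself be noisy.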

In [None]:
def classify_image(path):
    """Extract ResNet101 features from one image and classify them with the SVC."""
    # Resize with the same size and interpolation as the training dataset
    image = keras.utils.load_img(path, target_size=img_size, interpolation="bilinear")
    image = keras.utils.img_to_array(image)     # (224, 224, 3) float32 array
    image = np.expand_dims(image, axis=0)       # add a batch axis -> (1, 224, 224, 3)
    image = preprocess_input(image)             # same preprocessing as the dataset

    features = cnn_pretrained.predict(image)
    features = features.reshape(features.shape[0], -1)  # flatten like the training features
    return svc.predict(features)[0]

print("Predicted class:", classify_image("extra_dataset/ginger-cat.jpg"))

In [None]:
print("Predicted class:", classify_image("extra_dataset/golden-retriever.jpg"))
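Both cells print an integer label. With `labels="inferred"`, `image_dataset_from_directory` assigns these integers to the class subdirectories in alphanumeric order, and the dataset returned by `.map(...)` does not carry the `class_names` attribute of the original one. A minimal sketch of recovering readable names under that assumption; the helper `index_to_class_name` is hypothetical, and the demo builds a throwaway directory tree instead of using the real `dataset` folder.

```python
import os
import tempfile

def index_to_class_name(data_dir: str, index: int) -> str:
    """Hypothetical helper: map an integer label back to its class folder name.

    Assumes labels were inferred from the subdirectories of `data_dir`,
    ordered alphanumerically.
    """
    class_names = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    return class_names[index]

# Demo on a throwaway directory tree standing in for "dataset"
with tempfile.TemporaryDirectory() as root:
    for name in ["dogs", "cats"]:
        os.makedirs(os.path.join(root, name))
    names = [index_to_class_name(root, i) for i in range(2)]

print(names)  # ['cats', 'dogs']
```

On the real data, calling `index_to_class_name("dataset", label)` with the integer printed above would return the corresponding folder name.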