# Transfer learning

In this notebook, we look at transfer learning. We saw how long it can take to train a model for classifying cats and dogs from scratch. Let us instead reuse an existing model that has already been trained for image classification to improve our classification of cats and dogs. Throughout this notebook we use the `VGG16` model for this purpose, although many newer and better models exist.

The code is heavily inspired by the Chapter 8 notebook from the book "Deep learning with Python" by François Chollet, available [here](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/chapter08_intro-to-dl-for-computer-vision.ipynb).

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Use the Keras API bundled with TensorFlow throughout the notebook
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
from tensorflow.keras.utils import image_dataset_from_directory

train_dataset = image_dataset_from_directory(
    "dogs_vs_cats_tiny/train",
    image_size=(180, 180),
    batch_size=32)
validation_dataset = image_dataset_from_directory(
    "dogs_vs_cats_tiny/validation",
    image_size=(180, 180),
    batch_size=32)
test_dataset = image_dataset_from_directory(
    "dogs_vs_cats_tiny/test",
    image_size=(180, 180),
    batch_size=32)

## Feature extraction once

We will first load a version of the VGG16 model. Note that the first time the code is run, the VGG16 weights are downloaded. The `weights` argument specifies which pretrained weights to load (here, those trained on ImageNet), and `include_top=False` means that we leave out the top dense layers, as we want to add our own.

In [None]:
conv_base = keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(180, 180, 3))
conv_base.summary()
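
The last line of the summary gives the shape of the feature maps that the rest of this section builds on. As a rough check, and assuming the standard VGG16 layout (five blocks that each end in a 2x2 max-pooling with stride 2, and 512 channels in the last block), we can trace the spatial size by hand. The helper name `vgg16_feature_map_shape` is our own and not part of Keras.

```python
def vgg16_feature_map_shape(height, width, n_pool=5, channels=512):
    """Trace the spatial size through the max-pooling layers of VGG16.

    Assumes the convolutions preserve the spatial size, so only the
    five 2x2 poolings (integer division by 2) change it.
    """
    for _ in range(n_pool):
        height //= 2
        width //= 2
    return (height, width, channels)


# 180 -> 90 -> 45 -> 22 -> 11 -> 5
print(vgg16_feature_map_shape(180, 180))  # (5, 5, 512)
```

This is where the input shape `(5, 5, 512)` of the dense classifier further down comes from.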

Let us now pass our images through this convolutional base to extract a new feature representation:

In [None]:
def get_features_and_labels(dataset):
    """Run every batch through the pretrained conv base and collect the outputs."""
    all_features = []
    all_labels = []
    for images, labels in dataset:
        # Apply the VGG16-specific pixel preprocessing before the conv base
        preprocessed_images = keras.applications.vgg16.preprocess_input(images)
        features = conv_base.predict(preprocessed_images)
        all_features.append(features)
        all_labels.append(labels)
    # Stack the per-batch arrays into one array per split
    return np.concatenate(all_features), np.concatenate(all_labels)

train_features, train_labels = get_features_and_labels(train_dataset)
val_features, val_labels = get_features_and_labels(validation_dataset)
test_features, test_labels = get_features_and_labels(test_dataset)

print("Shape of training features:", train_features.shape)
print("Shape of training labels:", train_labels.shape)
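
The `preprocess_input` call inside `get_features_and_labels` prepares the raw pixel values for VGG16. As far as we understand it, for VGG16 this means reordering the channels from RGB to BGR and subtracting the ImageNet mean of each channel, without any rescaling. The sketch below mimics this for a single pixel in plain Python. The function name and the constant are our own, and the mean values are the ones commonly quoted for this preprocessing, so treat them as an assumption and not as the exact library code.

```python
# Per-channel ImageNet means in BGR order, as commonly quoted for VGG16
# (assumption: not read from the Keras source).
IMAGENET_BGR_MEAN = (103.939, 116.779, 123.68)


def vgg16_preprocess_pixel(rgb):
    """Mimic the VGG16 preprocessing for one (R, G, B) pixel in 0..255."""
    r, g, b = rgb
    bgr = (b, g, r)  # swap the channel order
    # Zero-center each channel with its ImageNet mean; no scaling to [0, 1]
    return tuple(value - mean for value, mean in zip(bgr, IMAGENET_BGR_MEAN))


print(vgg16_preprocess_pixel((255.0, 128.0, 0.0)))
```

Note that the input is expected in the range 0 to 255, which matches the pixel values that `image_dataset_from_directory` yields.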

We can now train a small dense neural network on this new feature representation:

In [None]:
model_dense = keras.Sequential(
    [
        keras.Input(shape=(5, 5, 512)),
        layers.Flatten(),
        layers.Dense(256),
        layers.Dropout(0.5),
        layers.Dense(1, activation="sigmoid")
    ]
)

model_dense.summary()
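
The parameter counts in this summary can be checked by hand: a dense layer has one weight per input-output pair plus one bias per unit. A minimal sketch, where the helper `dense_params` is our own name:

```python
def dense_params(n_inputs, n_units):
    """Number of weights plus biases in a fully connected layer."""
    return n_inputs * n_units + n_units


flattened = 5 * 5 * 512                 # size after layers.Flatten()
hidden = dense_params(flattened, 256)   # layers.Dense(256)
output = dense_params(256, 1)           # layers.Dense(1, activation="sigmoid")

print(flattened, hidden, output, hidden + output)
# 12800 3277056 257 3277313
```

The `Flatten` and `Dropout` layers have no parameters of their own, so the total is the sum of the two dense layers.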

In [None]:
model_dense.compile(loss="binary_crossentropy",
                    optimizer="rmsprop",
                    metrics=["accuracy"])
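
To make the loss concrete: binary cross-entropy penalises the sigmoid output `p` with `-log(p)` when the label is 1 and with `-log(1 - p)` when the label is 0. The sketch below is a plain-Python version for single predictions. The function name and the clipping constant `eps` are our own choices and not taken from Keras.

```python
import math


def binary_crossentropy(y_true, p_pred, eps=1e-7):
    """Cross-entropy for one label (0 or 1) and one predicted probability."""
    # Clip the probability to avoid log(0)
    p = min(max(p_pred, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


print(binary_crossentropy(1, 0.9))  # confident and correct: small loss
print(binary_crossentropy(1, 0.5))  # undecided: log(2)
print(binary_crossentropy(1, 0.1))  # confident and wrong: large loss
```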

In [None]:
callbacks = [
    keras.callbacks.ModelCheckpoint(
      filepath="feature_extraction_once.keras",
      save_best_only=True,
      monitor="val_loss")
]
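
With `save_best_only=True` and `monitor="val_loss"`, the checkpoint file is, as we understand the callback, only overwritten when the validation loss improves on the best value seen so far. The sketch below mimics that rule in plain Python. The function name and the loss values are made up for illustration.

```python
def epochs_that_trigger_save(val_losses):
    """Return the (1-based) epochs at which a new best val_loss occurs."""
    best = float("inf")
    saved = []
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best:  # strictly better than anything seen so far
            best = val_loss
            saved.append(epoch)
    return saved


example_val_losses = [0.9, 0.6, 0.7, 0.5, 0.55]  # made-up values
print(epochs_that_trigger_save(example_val_losses))  # [1, 2, 4]
```

The saved file therefore holds the weights from the best epoch, not from the last one, which is why we reload it before evaluating on the test set.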

In [None]:
history = model_dense.fit(
    train_features, train_labels,
    epochs=20,
    validation_data=(val_features, val_labels),
    callbacks=callbacks)

Note how fast it went, as we only have to train a few dense layers! Let us plot the training and validation accuracy and loss.

In [None]:
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, "bo", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()
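
The epoch at which the validation loss bottoms out can also be read directly from the history instead of from the plot. A small sketch with made-up numbers (the helper `best_epoch` and the example values are ours, not results from this notebook):

```python
def best_epoch(history_dict, metric="val_loss"):
    """Return (epoch, value) for the lowest value of the given metric."""
    values = history_dict[metric]
    index = min(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]  # epochs are counted from 1


example_history = {
    "loss": [0.8, 0.4, 0.2, 0.1],
    "val_loss": [0.7, 0.5, 0.55, 0.6],
}
print(best_epoch(example_history))  # (2, 0.5)
```

In the notebook, the same helper could be called as `best_epoch(history.history)`. A training loss that keeps falling after this epoch while the validation loss rises is the usual sign of overfitting.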

We quickly reached good performance and then started to overfit! This is probably also because the cat and dog images we train on here were likely part of the ImageNet dataset that VGG16 was trained on, so the base model may already have seen these training images. Still, we also get really good performance on the test set:

In [None]:
test_model = keras.models.load_model("feature_extraction_once.keras")
test_loss, test_acc = test_model.evaluate(test_features, test_labels)
print(f"Test accuracy: {test_acc:.3f}")

## Feature extraction integrated in the learning step

We will now set up a complete model consisting of the convolutional base with a new classifier of dense layers on top. We will then freeze the parameters of the convolutional base and train the model end to end on our data, so that only the new dense layers are updated. In this setup we can also add data augmentation in front of the base, which was not possible when the features were extracted only once.

In [None]:
conv_base = keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(180, 180, 3))

In [None]:
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)
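
The arguments of the augmentation layers are relative factors. As we read the Keras documentation, `RandomRotation(0.1)` draws an angle from within 10% of a full circle in either direction, and `RandomZoom(0.2)` zooms by up to 20% in either direction. The helpers below only convert the factors into more familiar units. Their names are our own.

```python
def rotation_range_degrees(factor):
    """Rotation limits in degrees for a factor given as a fraction of 360."""
    return (-factor * 360.0, factor * 360.0)


def zoom_range_percent(factor):
    """Zoom limits in percent for a symmetric zoom factor."""
    return (-factor * 100.0, factor * 100.0)


print(rotation_range_degrees(0.1))  # roughly -36 to +36 degrees
print(zoom_range_percent(0.2))      # roughly -20% to +20%
```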

In [None]:
inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs)
x = keras.applications.vgg16.preprocess_input(x)
x = conv_base(x)
x = layers.Flatten()(x)
x = layers.Dense(256)(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model_dense_learning = keras.Model(inputs, outputs)
model_dense_learning.summary()

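Let us freeze the convolutional base, so that its weights are not updated during training. The summary should now list them as non-trainable parameters.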

In [None]:
conv_base.trainable = False
model_dense_learning.summary()
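
The split between trainable and non-trainable parameters in this summary can be reproduced by hand. The sketch below assumes the standard VGG16 configuration of thirteen 3x3 convolutions. The channel list and the helper `conv_params` are written out by us, not read from the model.

```python
def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel * kernel * in_channels * out_channels + out_channels


# (input channels, output channels) of the 13 conv layers, block by block
VGG16_CONV_CHANNELS = [
    (3, 64), (64, 64),
    (64, 128), (128, 128),
    (128, 256), (256, 256), (256, 256),
    (256, 512), (512, 512), (512, 512),
    (512, 512), (512, 512), (512, 512),
]

base_params = sum(conv_params(i, o) for i, o in VGG16_CONV_CHANNELS)
head_params = (5 * 5 * 512 * 256 + 256) + (256 * 1 + 1)  # Dense(256) + Dense(1)

print(base_params)                # 14714688, frozen
print(head_params)                # 3277313, still trainable
print(base_params + head_params)  # 17992001 in total
```

So after freezing, less than a fifth of the parameters are updated during training.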

In [None]:
model_dense_learning.compile(loss="binary_crossentropy",
                             optimizer="rmsprop",
                             metrics=["accuracy"])

In [None]:
callbacks = [
    keras.callbacks.ModelCheckpoint(
      filepath="feature_extraction_learning.keras",
      save_best_only=True,
      monitor="val_loss")
]

We can now train the model. However, this will take quite some time, as we have to pass the training data through the entire model in every epoch, and we now also apply data augmentation.

In [None]:
history = model_dense_learning.fit(
    train_dataset,
    epochs=30,
    validation_data=validation_dataset,
    callbacks=callbacks)

In [None]:
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, "bo", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

In [None]:
test_model = keras.models.load_model("feature_extraction_learning.keras")
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")

## Fine-tuning

We can now take the model we just trained, unfreeze the last few layers of the convolutional base, and then retrain the model. Let us first unfreeze these layers.

In [None]:
conv_base.trainable = True
for layer in conv_base.layers[:-4]:
    layer.trainable = False
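
To see which layers the slice `[:-4]` leaves trainable, we can apply it to the list of layer names. The names below follow the usual VGG16 naming as printed by `conv_base.summary()`. The name of the input layer differs between Keras versions, so `"input"` is just a placeholder.

```python
VGG16_LAYER_NAMES = [
    "input",  # placeholder: the actual name depends on the Keras version
    "block1_conv1", "block1_conv2", "block1_pool",
    "block2_conv1", "block2_conv2", "block2_pool",
    "block3_conv1", "block3_conv2", "block3_conv3", "block3_pool",
    "block4_conv1", "block4_conv2", "block4_conv3", "block4_pool",
    "block5_conv1", "block5_conv2", "block5_conv3", "block5_pool",
]

frozen = VGG16_LAYER_NAMES[:-4]     # everything except the last four layers
trainable = VGG16_LAYER_NAMES[-4:]  # the last four layers

print(len(VGG16_LAYER_NAMES), len(frozen))  # 19 15
print(trainable)
# ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
```

Since the pooling layer has no weights, this effectively fine-tunes the three convolutional layers of the last block.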

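Before we can retrain the model, we need to compile it again. This time we use a much lower learning rate, so that the pretrained weights are only adjusted slightly.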

In [None]:
model_dense_learning.compile(loss="binary_crossentropy",
                             optimizer=keras.optimizers.RMSprop(learning_rate=1e-5),
                             metrics=["accuracy"])
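
To get a feeling for what `learning_rate=1e-5` means, we can compare it with the default RMSprop learning rate in Keras, which is `1e-3`. The sketch below uses a textbook form of the RMSprop update for a single weight on its first step. The actual Keras implementation may differ in details such as where `epsilon` is added, so this is only an illustration, and the function name is our own.

```python
import math


def rmsprop_first_step(gradient, learning_rate, rho=0.9, epsilon=1e-7):
    """Size of the first RMSprop update for one weight (textbook form)."""
    # Running average of squared gradients, starting from zero
    velocity = (1.0 - rho) * gradient ** 2
    return learning_rate * gradient / (math.sqrt(velocity) + epsilon)


default_step = rmsprop_first_step(0.5, learning_rate=1e-3)
fine_tune_step = rmsprop_first_step(0.5, learning_rate=1e-5)

print(default_step, fine_tune_step)
print(default_step / fine_tune_step)  # about 100 times smaller steps
```

The step size scales linearly with the learning rate, so each update during fine-tuning changes the pretrained weights about a hundred times less than with the default setting.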

In [None]:
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="fine_tuning.keras",
        save_best_only=True,
        monitor="val_loss")
]

In [None]:
history = model_dense_learning.fit(
    train_dataset,
    epochs=30,
    validation_data=validation_dataset,
    callbacks=callbacks)

In [None]:
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, "bo", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

In [None]:
model = keras.models.load_model("fine_tuning.keras")
test_loss, test_acc = model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")