# Setup

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Introduction
- **Transfer learning** consists of taking features learned on one problem, and leveraging them on a new, similar problem. 
    - For instance, features from a model that has learned to identify raccoons may be useful to kick-start a model meant to identify tanukis.
- Transfer learning is usually done for tasks where your dataset has too little data to train a full-scale model from scratch.
- The most common incarnation of transfer learning in the context of deep learning is the following workflow:
    1. Take layers from a previously trained model
    2. Freeze them, so as to avoid destroying any of the information they contain during future training rounds
    3. Add some new, trainable layers on top of the frozen layers
        - They will learn to turn the old features into predictions on a new dataset.
    4. Train the new layers on your dataset
- A last, optional step is **fine-tuning**, which consists of unfreezing the entire model you obtained above (or part of it), and re-training it on the new data with a very low learning rate. 
    - This can potentially achieve meaningful improvements, by incrementally adapting the pretrained features to the new data.

# Freezing layers: understanding the `trainable` attribute
- Layers & models have three weight attributes:
    - `weights` is the list of all weight variables of the layer.
    - `trainable_weights` is the list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
    - `non_trainable_weights` is the list of those that aren't meant to be trained.
        - Typically, they are updated by the model during the forward pass.

In [2]:
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 2
non_trainable_weights: 0


- In general, all weights are trainable weights. 
- The only built-in layer that has non-trainable weights is the `BatchNormalization` layer. 
    - It uses non-trainable weights to keep track of the mean and variance of its inputs during training. 
- To learn how to use non-trainable weights in your own custom layers, see the guide to writing new layers from scratch.

In [3]:
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 4
trainable_weights: 2
non_trainable_weights: 2
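
- As a minimal sketch of a non-trainable weight in a custom layer: the `ComputeSum` layer below is a hypothetical name chosen for illustration, not part of this notebook. It declares its state with `add_weight(..., trainable=False)` and updates it in the forward pass rather than through gradient descent.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class ComputeSum(keras.layers.Layer):
    """Hypothetical layer that accumulates the column-wise sum of its inputs."""

    def __init__(self, input_dim):
        super().__init__()
        # State updated during the forward pass, never by the optimizer
        self.total = self.add_weight(
            shape=(input_dim,), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        # Add the per-column sum of this batch to the running total
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return inputs


sum_layer = ComputeSum(2)
sum_layer(np.ones((3, 2), dtype="float32"))  # One forward pass over 3 rows of ones

print("trainable_weights:", len(sum_layer.trainable_weights))
print("non_trainable_weights:", len(sum_layer.non_trainable_weights))
print("running total:", np.array(sum_layer.total))
```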


- Layers & models also feature a boolean attribute `trainable`. 
- Setting `layer.trainable = False` moves all the layer's weights from trainable to non-trainable. 
    - This is called **"freezing" the layer**: the state of a frozen layer won't be updated during training (either when training with `fit()` or when training with any custom loop that relies on `trainable_weights` to apply gradient updates).

In [4]:
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 0
non_trainable_weights: 2


- When a trainable weight becomes non-trainable, its value is no longer updated during training.

In [5]:
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation='relu')
layer2 = keras.layers.Dense(3, activation='sigmoid')
model = keras.Sequential([
    keras.Input(shape=(3,)),
    layer1,
    layer2
])

In [6]:
# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

In [7]:
# Train the model
model.compile(optimizer='adam', loss='mse')
model.fit(np.random.random((2,3)), np.random.random((2, 3)))

Train on 2 samples


<tensorflow.python.keras.callbacks.History at 0x7fc99e5de910>

In [8]:
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()

np.testing.assert_allclose(initial_layer1_weights_values[0], final_layer1_weights_values[0])
np.testing.assert_allclose(initial_layer1_weights_values[1], final_layer1_weights_values[1])

- Note: do not confuse the `layer.trainable` attribute with the argument `training` in `layer.__call__()`, which controls whether the layer should run its forward pass in inference mode or training mode.
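- To illustrate the difference, here is a small sketch that uses a `Dropout` layer (chosen only because its two modes are easy to tell apart). The `training` argument changes the forward pass, while `trainable` is never touched.

```python
import numpy as np
from tensorflow import keras

dropout = keras.layers.Dropout(0.5)
x = np.ones((1, 8), dtype="float32")

# Inference mode: dropout is a no-op and the input passes through unchanged
inference_out = np.array(dropout(x, training=False))

# Training mode: entries are randomly zeroed and the survivors are rescaled
training_out = np.array(dropout(x, training=True))

print("trainable:", dropout.trainable)
print("training=False:", inference_out)
print("training=True: ", training_out)
```

- With a rate of 0.5, the surviving entries are scaled by 1 / (1 - 0.5) = 2 in training mode, so the second output contains only zeros and twos.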

# Recursive setting of the `trainable` attribute
- If you set `trainable = False` on a model or on any layer that has sublayers, all child layers become non-trainable as well.

In [9]:
inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation='relu'),
    keras.layers.Dense(3, activation='relu')
])

model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation='sigmoid')
])

In [10]:
model.trainable = False # Freeze the outer model

In [11]:
assert not inner_model.trainable # All layers in `model` are now frozen
assert not inner_model.layers[0].trainable # `trainable` is propagated recursively
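
- The effect of recursive freezing also shows up in the weight lists. The sketch below rebuilds the same nested model so that it is self-contained, then compares the counts before and after freezing the outer model.

```python
from tensorflow import keras

inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
    keras.layers.Dense(3, activation="relu"),
])
model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
])

# Three Dense layers, each with a kernel and a bias
trainable_before = len(model.trainable_weights)

model.trainable = False  # Freeze the outer model and everything nested in it

trainable_after = len(model.trainable_weights)
non_trainable_after = len(model.non_trainable_weights)

print("trainable before:", trainable_before)
print("trainable after:", trainable_after)
print("non-trainable after:", non_trainable_after)
```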

# The typical transfer-learning workflow
- This leads us to how a typical transfer learning workflow can be implemented in Keras:
    1. Instantiate a base model and load pre-trained weights into it
    2. Freeze all layers in the base model by setting `trainable = False`
    3. Create a new model on top of the output of one (or several) layers from the base model
    4. Train your new model on your new dataset
- Note that an alternative, more lightweight workflow could also be:
    1. Instantiate a base model and load pre-trained weights into it
    2. Run your new dataset through it and record the output of one (or several) layers from the base model 
        - This is called feature extraction.
    3. Use that output as input data for a new, smaller model
- A key advantage of that second workflow is that you only run the base model once on your data, rather than once per epoch of training. So it's a lot faster & cheaper.
- An issue with that second workflow, though, is that it doesn't allow you to dynamically modify the input data of your new model during training, which is required when doing data augmentation, for instance. 
- Transfer learning is typically used for tasks where your new dataset has too little data to train a full-scale model from scratch, and in such scenarios data augmentation is very important. So in what follows, we will focus on the first workflow.
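- For completeness, here is a rough sketch of the second (feature extraction) workflow. To avoid a weight download, it uses a tiny randomly initialized network as a stand-in for the pre-trained base model, plus random arrays as data. Both are assumptions made for illustration only.

```python
import numpy as np
from tensorflow import keras

# Stand-in for a pre-trained base model (in practice, e.g. Xception with ImageNet weights)
base_model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    keras.layers.Conv2D(8, 3, activation="relu"),
    keras.layers.GlobalAveragePooling2D(),
])

# Hypothetical dataset: 16 random images with binary labels
images = np.random.random((16, 32, 32, 3)).astype("float32")
labels = np.random.randint(0, 2, size=(16, 1)).astype("float32")

# Step 2: run the data through the base model once and record its output
features = base_model.predict(images, verbose=0)

# Step 3: train a new, smaller model on the recorded features
head = keras.Sequential([
    keras.Input(shape=features.shape[1:]),
    keras.layers.Dense(1),
])
head.compile(optimizer="adam",
             loss=keras.losses.BinaryCrossentropy(from_logits=True))
history = head.fit(features, labels, epochs=2, verbose=0)

print("features shape:", features.shape)
```

- The base model never runs during `fit()`, which is why this is cheap. It is also why the inputs cannot be augmented on the fly.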

In [12]:
# Instantiate a base model with pre-trained weights
base_model = keras.applications.Xception(
    weights='imagenet', # Load weights pre-trained on ImageNet
    input_shape=(150, 150, 3),
    include_top=False # Do not include the ImageNet classifier at the top
)

In [13]:
# Freeze the base model
base_model.trainable = False

In [14]:
# Create a new model on top
inputs = keras.Input(shape=(150, 150, 3))

# We make sure that the base_model is running in inference mode here by passing `training=False`.
X = base_model(inputs, training=False)

X = keras.layers.GlobalAveragePooling2D()(X)
outputs = keras.layers.Dense(1)(X)

model = keras.Model(inputs, outputs)

In [15]:
# Train the model on new data
model.compile(optimizer="adam", 
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
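
- `new_dataset` is not defined in this notebook. As a hedged sketch, a stand-in with simple on-the-fly augmentation could be built as follows. The random arrays, the batch size, and the choice of `RandomFlip` are assumptions, and `keras.layers.RandomFlip` requires a reasonably recent TensorFlow/Keras version.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical data: 16 random 150x150 RGB images with binary labels
images = np.random.random((16, 150, 150, 3)).astype("float32")
labels = np.random.randint(0, 2, size=(16, 1)).astype("float32")

# Augmentation layer applied to each batch as it is read
flip = keras.layers.RandomFlip("horizontal")

new_dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .batch(8)
    .map(lambda x, y: (flip(x, training=True), y))
)

# Pull one batch to check the shapes
first_images, first_labels = next(iter(new_dataset))
print(first_images.shape, first_labels.shape)
```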

# Fine-tuning
- Once your model has converged on the new data, you can try to unfreeze all or part of the base model and retrain the whole model end-to-end with a very low learning rate.
    - This is an optional last step that can potentially give you incremental improvements. 
    - It could also potentially lead to quick overfitting -- keep that in mind.
- It is critical to only do this step after the model with frozen layers has been trained to convergence. 
    - If you mix randomly-initialized trainable layers with trainable layers that hold pre-trained features, the randomly-initialized layers will cause very large gradient updates during training, which will destroy your pre-trained features.
- It's also critical to use a very low learning rate at this stage, because you are training a much larger model than in the first round of training, on a dataset that is typically very small. 
    - As a result, you are at risk of overfitting very quickly if you apply large weight updates. 
    - Here, you only want to readapt the pretrained weights in an incremental way.
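- To unfreeze only part of the base model, one common pattern is to set `trainable = True` on the whole model and then re-freeze the lower layers individually. The sketch below uses a small `Sequential` stand-in rather than a real pre-trained network, and the choice to keep only the last layer trainable is arbitrary.

```python
from tensorflow import keras

# Stand-in for a pre-trained base model (random weights, for illustration only)
base_model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(8, activation="relu"),
])
base_model.trainable = False  # State at the end of the first training phase

# Unfreeze everything, then freeze all layers except the last one again
base_model.trainable = True
for layer in base_model.layers[:-1]:
    layer.trainable = False

print("trainable_weights:", len(base_model.trainable_weights))
print("non_trainable_weights:", len(base_model.non_trainable_weights))
```

- Setting `trainable = True` on a single child while its parent stays frozen is not enough, because a frozen parent reports no trainable weights.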

In [16]:
# Unfreeze the base model
base_model.trainable = True

# Remember to recompile the model after making changes to the `trainable` attribute of any inner layer
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end
# Be careful to stop before overfitting
# model.fit(...)

**Important note about `compile()` and `trainable`**
- Calling `compile()` on a model is meant to "freeze" the behavior of that model. 
    - This implies that the `trainable` attribute values at the time the model is compiled should be preserved throughout the lifetime of that model, until `compile` is called again. 
    - Hence, if you change any `trainable` value, make sure to call `compile()` again on your model for your changes to be taken into account.

**Important notes about `BatchNormalization` layer**
- Many image models contain `BatchNormalization` layers. 
    - That layer is a special case on every imaginable count.
- Here are a few things to keep in mind.
    - `BatchNormalization` contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
    - When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will run in inference mode, and will not update its mean & variance statistics. 
        - This is not the case for other layers in general, as weight trainability & inference/training modes are two orthogonal concepts. But the two are tied in the case of the `BatchNormalization` layer.
    - When you unfreeze a model that contains `BatchNormalization` layers in order to do fine-tuning, you should keep the `BatchNormalization` layers in inference mode by passing `training=False` when calling the base model. 
        - Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.
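- The link between `trainable` and inference mode can be checked directly. This is a small sketch that assumes TensorFlow 2.x or Keras 3 behavior, with random input data made up for the demonstration.

```python
import numpy as np
from tensorflow import keras

bn = keras.layers.BatchNormalization()
x = np.random.normal(loc=5.0, scale=1.0, size=(64, 4)).astype("float32")

# Trainable layer in training mode: the moving mean moves away from zero
bn(x, training=True)
mean_after_training_call = np.array(bn.moving_mean)

# Frozen layer: it runs in inference mode even though training=True is passed
bn.trainable = False
bn(x, training=True)
mean_after_frozen_call = np.array(bn.moving_mean)

print("after training call:", mean_after_training_call)
print("after frozen call:  ", mean_after_frozen_call)
```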

# Transfer learning & fine-tuning with a custom training loop
- If instead of `fit()` you are using your own low-level training loop, the workflow stays essentially the same.
- You should be careful to only take into account the list `model.trainable_weights` when applying gradient updates.

In [17]:
# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
X = base_model(inputs, training=False)
X = keras.layers.GlobalAveragePooling2D()(X)
outputs = keras.layers.Dense(1)(X)
model = keras.Model(inputs, outputs)

In [18]:
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

In [19]:
# # Iterate over the batches of a dataset
# for inputs, targets in new_dataset:
#     # Open a GradientTape
#     with tf.GradientTape() as tape:
#         predictions = model(inputs) # Forward pass
#         loss_value = loss_fn(targets, predictions) # Compute the loss value for this batch
        
#     # Get gradients of loss w.r.t. the trainable_weights
#     gradients = tape.gradient(loss_value, model.trainable_weights)
    
#     # Update the weights of the model
#     optimizer.apply_gradients(zip(gradients, model.trainable_weights))
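
- The loop above is commented out because `new_dataset` is not defined in this notebook. The following is a runnable sketch of the same loop. It swaps in a tiny stand-in base model and random data (both hypothetical) and then checks that only the new head was updated.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Stand-in for the frozen pre-trained base model
base_model = keras.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(4, activation="relu"),
])
base_model.trainable = False

# New trainable head on top of the frozen base
inputs = keras.Input(shape=(8,))
x = base_model(inputs, training=False)
head = keras.layers.Dense(1)
model = keras.Model(inputs, head(x))

# Hypothetical dataset: 32 random samples in batches of 8
features = np.random.random((32, 8)).astype("float32")
targets = np.random.randint(0, 2, size=(32, 1)).astype("float32")
new_dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(8)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

base_before = base_model.get_weights()
head_before = head.get_weights()

for batch_inputs, batch_targets in new_dataset:
    with tf.GradientTape() as tape:
        predictions = model(batch_inputs, training=True)  # Forward pass
        loss_value = loss_fn(batch_targets, predictions)
    # Only the trainable weights (the head) receive gradient updates
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

base_after = base_model.get_weights()
head_after = head.get_weights()
```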

# An end-to-end example: fine-tuning an image classification model on cats vs. dogs
- The APIs have changed considerably since this notebook was written, and the end-to-end code from the guide no longer works as originally published.
- To view the code, visit:
https://www.tensorflow.org/guide/keras/transfer_learning