# Automating Transfer Learning: Gradual fine-tuning of a TensorFlow model raises “ValueError: Unknown metric function: val_loss” in the fit method of the fine-tuning stage

## Overview
I’m attempting to automate transfer-learning fine-tuning by gradually thawing (unfreezing) a frozen, pre-trained base network in stages. A compiled model containing the frozen, pre-trained base network is passed to an X4Learner class, which exposes two methods: `extract_features` and `fine_tune`. The `extract_features` method fits the model for a designated number of epochs and stores the last epoch as an instance variable. The `fine_tune` method then operates on the feature-extracted model and performs an iterative fine-tuning process: each iteration thaws some number of layers in the underlying base model and recompiles the model. The subsequent call to `fit` raises the following exception:

```
ValueError: Unknown metric function: val_loss. Please ensure this object is passed to the `custom_objects` argument.
```

The traceback is included at the end of the reproducible example.

## Research
Research into potential root causes has been complicated because this exception has been associated with a host of unrelated problems, from missing validation data to …

## The X4Learner Class Code
A simplified version of the transfer learning class follows. 

In [1]:
```python
from typing import Union
import os

import numpy as np
import pandas as pd
import tensorflow as tf
# ------------------------------------------------------------------------------------------------ #
class X4LearnerLite:
    """Performs transfer learning of a TensorFlow model containing a pre-trained base model.

    Two methods are exposed: extract_features and fine_tune. The extract_features method trains
    the model on the given data using the designated learning rate. The fine_tune method
    thaws one or more layers in the base model, then trains it with a decayed learning rate. Each
    fine-tuning session decays the learning rate by the learning_rate_decay_factory factor to
    mitigate catastrophic forgetting.

    Args:
        model (tf.keras.Model): Model containing a frozen, pre-trained base model.
        base_model_layer (int): Index of the base model within model.layers, used for thawing.
        train_ds (tf.data.Dataset): TensorFlow training dataset.
        val_ds (tf.data.Dataset): TensorFlow validation dataset.
        learning_rate (float): The learning rate for feature extraction. Default = 0.0001
        metric (str): The metric used to evaluate model fit performance. Default = 'val_loss'
        loss (str): The loss function. Default = 'binary_crossentropy'.
        activation (str): Activation function. Default = 'sigmoid'.

    """

    def __init__(
        self,
        model: tf.keras.Model,
        base_model_layer: int,
        train_ds: tf.data.Dataset,
        val_ds: tf.data.Dataset,
        learning_rate: float = 0.0001,
        metric: str = "val_loss",
        loss: str = "binary_crossentropy",
        activation: str = "sigmoid",
    ) -> None:
        self._model = model
        self._base_model_layer = base_model_layer

        self._train_ds = train_ds
        self._val_ds = val_ds

        self._learning_rate = learning_rate

        self._metric = metric
        self._loss = loss
        self._activation = activation
        # Used during the thawing process to determine the number of layers to thaw as a proportion
        # of the total number of layers in the underlying base model.
        self._n_layers = len(self._model.layers[self._base_model_layer].layers)
        self._initial_epoch = None

    # ------------------------------------------------------------------------------------------------ #
    def extract_features(self, epochs: int = 5) -> None:
        """Performs the feature extraction phase of transfer learning.

        Args:
            epochs (int): Number of epochs to execute
        """

        history = self._model.fit(
            self._train_ds,
            epochs=epochs,
            validation_data=self._val_ds,
        )

        # Save the last feature extraction epoch (0-based index) for the fine-tuning phase
        self._initial_epoch = history.epoch[-1]

    # ------------------------------------------------------------------------------------------------ #
    def fine_tune(
        self,
        epochs: int = 10,
        sessions: int = 10,
        learning_rate_decay_factory: float = 0.1,
        thaw_rate: Union[float, int] = 0.05,
    ) -> None:
        """Performs iterative fine-tuning using gradual unfreezing of the base model.

        Args:
            epochs (int): Number of epochs per session. Default = 10
            sessions (int): Number of fine-tuning sessions to execute. Default = 10
            learning_rate_decay_factory (float): Factor by which the learning rate is reduced each session.
            thaw_rate (Union[float, int]): Rate by which layers are thawed. This can be a raw
                integer or a float proportion of base model layers. Default = 0.05.
        """
        session = 0
        learning_rate = self._learning_rate
        initial_epoch = self._initial_epoch

        while session < sessions:
            session += 1

            learning_rate *= learning_rate_decay_factory

            # Thaw the top n layers of the base model, where n grows linearly with the session number
            n = max(int(self._n_layers * thaw_rate * session), 1)
            self._model.layers[self._base_model_layer].trainable = True
            for layer in self._model.layers[self._base_model_layer].layers[:-n]:
                layer.trainable = False

            self._model.compile(
                loss=self._loss,
                optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                metrics=[self._metric],
            )

            total_epochs = epochs + initial_epoch
            history = self._model.fit(
                self._train_ds,
                epochs=total_epochs,
                validation_data=self._val_ds,
                initial_epoch=initial_epoch,
            )

            # History exposes `epoch` (a list of 0-based epoch indices), not `epochs`
            initial_epoch = history.epoch[-1]
```
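The per-session arithmetic in `fine_tune` (learning-rate decay and the number of top base-model layers left trainable) can be checked without TensorFlow. The following is a minimal pure-Python sketch that mirrors that arithmetic; `session_schedule` and `trainable_flags` are hypothetical helpers written for illustration, not part of the class or of Keras, and the layer count of 154 is an assumption about the MobileNetV2 base used below.

```python
from typing import List, Tuple, Union


def session_schedule(
    n_layers: int,
    sessions: int,
    learning_rate: float,
    decay_factor: float,
    thaw_rate: Union[float, int],
) -> List[Tuple[int, float, int]]:
    """Return (session, learning_rate, n_thawed) for each fine-tuning session.

    Mirrors the arithmetic in X4LearnerLite.fine_tune: the learning rate is multiplied
    by decay_factor at the start of every session, and the number of top base-model
    layers left trainable grows linearly with the session index (minimum of 1).
    """
    schedule = []
    for session in range(1, sessions + 1):
        learning_rate *= decay_factor
        n = max(int(n_layers * thaw_rate * session), 1)
        schedule.append((session, learning_rate, n))
    return schedule


def trainable_flags(n_layers: int, n: int) -> List[bool]:
    """Per-layer trainable flags after freezing layers[:-n], i.e. only the last n stay trainable."""
    frozen = set(range(n_layers)[:-n])
    return [i not in frozen for i in range(n_layers)]


# 154 is assumed to be the layer count of MobileNetV2 with include_top=False.
for session, lr, n in session_schedule(154, 3, 1e-4, 0.1, 0.1):
    print(f"session {session}: learning_rate={lr:.0e}, top {n} layers trainable")
```

Under that assumption, `fine_tune(sessions=3, learning_rate_decay_factory=0.1, thaw_rate=0.1)` would leave the top 15, 30 and 46 base-model layers trainable, with learning rates of roughly 1e-5, 1e-6 and 1e-7. As written, an integer `thaw_rate` is multiplied by the layer count just like a float, so it is not interpreted as a raw number of layers.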


## Reproducible Example
This reproducible example borrows liberally from the [TensorFlow transfer learning and fine-tuning tutorial](https://www.tensorflow.org/tutorials/images/transfer_learning).

### Data Preprocessing
#### Data Download

In [2]:
```python
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)

# Preprocessing function for the MobileNetV2 architecture (rescales pixel values to [-1, 1]).
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
```

```
Found 2000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.
```
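`tf.keras.applications.mobilenet_v2.preprocess_input` rescales pixel values from the [0, 255] range to [-1, 1] by dividing by 127.5 and subtracting 1. A pure-Python sketch of that arithmetic follows; `mobilenet_v2_scale` is a hypothetical helper for illustration only, since the real function operates on arrays or tensors.

```python
from typing import List


def mobilenet_v2_scale(pixels: List[float]) -> List[float]:
    """Map pixel values in [0, 255] to [-1, 1] (divide by 127.5, then subtract 1)."""
    return [p / 127.5 - 1.0 for p in pixels]


print(mobilenet_v2_scale([0.0, 63.75, 127.5, 255.0]))
```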


#### Configure Dataset for Performance

In [3]:
```python
AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
```

### Modeling
#### Create Base Model from MobileNet V2

In [4]:
```python
# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
```



#### Freeze the Convolutional Base

In [5]:
```python
base_model.trainable = False
```

#### Add Task-Specific Layers

In [6]:
```python
# Add a classification head
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# Add a Dense layer to convert features into a single prediction (a logit) per image.
prediction_layer = tf.keras.layers.Dense(1)
```
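Because `prediction_layer` is a `Dense(1)` with no activation, the model outputs a logit rather than a probability, which is why the compile step uses `from_logits=True` and `BinaryAccuracy(threshold=0)`. The pure-Python sketch below illustrates the relationship: a logit of 0 corresponds to a probability of 0.5, and binary cross-entropy can be computed directly from the logit. `sigmoid`, `bce_from_logit` and `predicted_label` are hypothetical helpers, not Keras functions.

```python
import math


def sigmoid(z: float) -> float:
    """Convert a logit (raw Dense(1) output) into a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_logit(y_true: float, z: float) -> float:
    """Binary cross-entropy computed from a logit, using the numerically stable form.

    max(z, 0) - z * y + log(1 + exp(-|z|)) equals -[y log p + (1 - y) log(1 - p)] with p = sigmoid(z).
    """
    return max(z, 0.0) - z * y_true + math.log1p(math.exp(-abs(z)))


def predicted_label(z: float, threshold: float = 0.0) -> int:
    """Label a logit as positive when it exceeds the threshold (0 on logits ~ 0.5 on probabilities)."""
    return int(z > threshold)


for z in (-2.0, 0.0, 2.0):
    print(z, sigmoid(z), predicted_label(z), bce_from_logit(1.0, z))
```

For reference, the string loss `"binary_crossentropy"` that `fine_tune` passes to `compile` uses the Keras default `from_logits=False`, which treats the model output as a probability.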

#### Create Final Model 

In [7]:
```python
inputs = tf.keras.Input(shape=(160, 160, 3))
x = preprocess_input(inputs)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
```

#### Compile the Model

In [8]:
```python
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='val_loss')])
```
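For context on the name `val_loss`: during `fit`, Keras records the training loss as `loss`, each compiled metric under its own name (for example `binary_accuracy` for an unnamed `BinaryAccuracy`), and, when `validation_data` is supplied, the same quantities evaluated on the validation set with a `val_` prefix. The following is a simplified pure-Python sketch of that naming convention; `history_keys` is a hypothetical helper, and it ignores details such as multi-output models and duplicate-name handling.

```python
from typing import List


def history_keys(metric_names: List[str], with_validation: bool) -> List[str]:
    """Keys that fit() logs for a single-output compiled model (simplified sketch).

    'loss' is always logged; each compiled metric is logged under its name, and when
    validation_data is passed to fit, every key is repeated with a 'val_' prefix.
    """
    keys = ["loss"] + list(metric_names)
    if with_validation:
        keys += ["val_" + k for k in keys]
    return keys


print(history_keys(["binary_accuracy"], with_validation=True))
```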

#### Feature Extraction

In [9]:
```python
x4l = X4LearnerLite(model=model,
                    base_model_layer=3,
                    train_ds=train_dataset,
                    val_ds=validation_dataset,
                    learning_rate=0.0001,
                    metric="val_loss",
                    loss="binary_crossentropy",
                    activation="sigmoid")
```

In [10]:
```python
x4l.extract_features(epochs=5)
```

```
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
```
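The epoch numbers Keras prints follow from how `initial_epoch` is threaded through `fit`: `History.epoch` holds 0-based indices, so after `extract_features(epochs=5)` it is `[0, 1, 2, 3, 4]` and `self._initial_epoch` is 4. Below is a minimal pure-Python sketch of the `(initial_epoch, epochs)` pairs that `fine_tune` would pass to `fit`, assuming each session reads `history.epoch[-1]` (the `History` attribute is `epoch`, not `epochs`). `fine_tune_epoch_ranges` and `progress_header` are hypothetical helpers for illustration.

```python
from typing import List, Tuple


def fine_tune_epoch_ranges(
    extraction_epochs: int, epochs_per_session: int, sessions: int
) -> List[Tuple[int, int]]:
    """(initial_epoch, epochs) pairs fine_tune would pass to fit, one per session."""
    # After fit(epochs=extraction_epochs), History.epoch[-1] == extraction_epochs - 1.
    initial_epoch = extraction_epochs - 1
    ranges = []
    for _ in range(sessions):
        total_epochs = epochs_per_session + initial_epoch
        ranges.append((initial_epoch, total_epochs))
        # fit(initial_epoch=i, epochs=t) runs indices i..t-1, so History.epoch[-1] == t - 1.
        initial_epoch = total_epochs - 1
    return ranges


def progress_header(initial_epoch: int, total_epochs: int) -> str:
    """First 'Epoch x/y' line printed by fit(initial_epoch=..., epochs=...)."""
    return f"Epoch {initial_epoch + 1}/{total_epochs}"


for i, t in fine_tune_epoch_ranges(5, 5, 3):
    print(progress_header(i, t))
```

With 5 feature-extraction epochs and `fine_tune(epochs=5, sessions=3, ...)`, this gives the pairs (4, 9), (8, 13) and (12, 17); the first header is `Epoch 5/9`, the line printed just before the traceback in the fine-tuning cell. Because each session starts at the last completed index rather than the next one, the displayed numbering repeats one index between consecutive phases, although each call still runs five epochs.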


#### Fine Tuning

In [11]:
```python
x4l.fine_tune(epochs=5, sessions=3, learning_rate_decay_factory=0.1, thaw_rate=0.1)
```

```
Epoch 5/9


ValueError: in user code:

    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/training.py", line 864, in train_step
        return self.compute_metrics(x, y, y_pred, sample_weight)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/training.py", line 957, in compute_metrics
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 438, in update_state
        self.build(y_pred, y_true)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 358, in build
        self._metrics = tf.__internal__.nest.map_structure_up_to(y_pred, self._get_metric_objects,
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 484, in _get_metric_objects
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 484, in <listcomp>
        return [self._get_metric_object(m, y_t, y_p) for m in metrics]
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 503, in _get_metric_object
        metric_obj = metrics_mod.get(metric)
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/metrics.py", line 4262, in get
        return deserialize(str(identifier))
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/metrics.py", line 4218, in deserialize
        return deserialize_keras_object(
    File "/home/john/anaconda3/envs/bcd/lib/python3.10/site-packages/keras/utils/generic_utils.py", line 709, in deserialize_keras_object
        raise ValueError(

    ValueError: Unknown metric function: val_loss. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
```


## Environment and Versions


In [12]:
```
%load_ext watermark
%watermark -v -m -p tensorflow,keras
```

```
Python implementation: CPython
Python version       : 3.10.12
IPython version      : 8.14.0

tensorflow: 2.8.0
keras     : 2.8.0

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 5.15.133.1-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 24
Architecture: 64bit
```

