# Transfer learning and fine-tuning with pretrained ResNet-50

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
import math
from tensorflow import keras
from keras import layers
import numpy as np

The usual pipeline of a transfer-learning & fine-tuning workflow. See https://keras.io/guides/transfer_learning/.
1. Append trainable layers to a pretrained foundation model.
2. Freeze the base model, train the last layer until convergence.
3. Unfreeze the base model, train the whole model with very small learning rate.

## Data loading and preprocessing

In [2]:
data_train = tf.keras.utils.image_dataset_from_directory("Processed_Split/train", labels='inferred', image_size=(224, 224), batch_size=32)
data_val = tf.keras.utils.image_dataset_from_directory("Processed_Split/val", labels='inferred', image_size=(224, 224), batch_size=32)
data_test = tf.keras.utils.image_dataset_from_directory("Processed_Split/test", labels='inferred', image_size=(224, 224), batch_size=32)

Found 45828 files belonging to 50 classes.
Found 9438 files belonging to 50 classes.
Found 9504 files belonging to 50 classes.


## Data preprocessing

In [3]:
# Use the batch method to prepare batches.
data_train = data_train.map(lambda x, y: (tf.keras.applications.resnet50.preprocess_input(x), y))
data_val = data_val.map(lambda x, y: (tf.keras.applications.resnet50.preprocess_input(x), y))
data_test = data_test.map(lambda x, y: (tf.keras.applications.resnet50.preprocess_input(x), y))

## Model configuration

In [4]:
num_classes = 50

inputs = layers.Input(shape=(224, 224, 3))
resnet = keras.applications.resnet50.ResNet50(include_top=False, weights="imagenet")
resnet.trainable = False
x = resnet(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = layers.Flatten()(x)
outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="sparse_categorical_crossentropy",
    metrics="accuracy")

## Model training (freeze ResNet)

In [5]:
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', mode="max", patience=4, verbose=1, baseline=0.0, restore_best_weights=True) # monitor validation loss, stop training if loss stops decreasing
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', mode="max", factor=0.5, patience=2, min_lr=0.000003125)

log = model.fit(x=data_train, epochs=20, validation_data=data_val, callbacks=[early_stopping, reduce_lr])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [6]:
model.evaluate(data_test)
model.save("freeze_resnet")





INFO:tensorflow:Assets written to: freeze_resnet/assets


INFO:tensorflow:Assets written to: freeze_resnet/assets


In [None]:
data_test_fixed_order = tf.keras.utils.image_dataset_from_directory("Processed_Split/test", labels='inferred', image_size=(224, 224), batch_size=32, shuffle=False).map(lambda x, y: (tf.keras.applications.resnet50.preprocess_input(x), y))
y_pred = model.predict(data_test_fixed_order)
y_hat = np.argmax(y_pred, axis=1)
y_true = np.concatenate([y for x, y in data_test_fixed_order], axis=0)

In [9]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
print("Accuracy: ", accuracy_score(y_true, y_hat))
print("Precision: ", precision_score(y_true, y_hat, average='weighted'))
print("Recall: ", recall_score(y_true, y_hat, average='weighted'))
print("F1: ", f1_score(y_true, y_hat, average='weighted'))

Accuracy:  0.9446548821548821
Precision:  0.9449983750663737
Recall:  0.9446548821548821
F1:  0.9433189415173883


## Model training (unfreeze ResNet)

In [10]:
resnet.trainable = True
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss="sparse_categorical_crossentropy",
    metrics="accuracy")

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', mode="max", patience=2, verbose=1, baseline=0.0, restore_best_weights=True) # monitor validation loss, stop training if loss stops decreasing
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', mode="max", factor=0.5, patience=1, min_lr=0.000003125)

log = model.fit(x=data_train, epochs=10, validation_data=data_val, callbacks=[early_stopping, reduce_lr])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 7: early stopping


In [11]:
model.evaluate(data_test)
model.save('resnet')





INFO:tensorflow:Assets written to: resnet/assets


INFO:tensorflow:Assets written to: resnet/assets


## Evaluation metrics on test data

In [13]:
trained_model = tf.keras.models.load_model('resnet')

In [19]:
data_test_fixed_order = tf.keras.utils.image_dataset_from_directory("Processed_Split/test", labels='inferred', image_size=(224, 224), batch_size=32, shuffle=False).map(lambda x, y: (tf.keras.applications.resnet50.preprocess_input(x), y))
y_pred = trained_model.predict(data_test_fixed_order)
y_hat = np.argmax(y_pred, axis=1)
y_true = np.concatenate([y for x, y in data_test_fixed_order], axis=0)



In [23]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
print("Accuracy: ", accuracy_score(y_true, y_hat))
print("Precision: ", precision_score(y_true, y_hat, average='weighted'))
print("Recall: ", recall_score(y_true, y_hat, average='weighted'))
print("F1: ", f1_score(y_true, y_hat, average='weighted'))

Accuracy:  0.9638047138047138
Precision:  0.9640554083361087
Recall:  0.9638047138047138
F1:  0.962832199476148


In [22]:
import os
from sklearn.metrics import classification_report
print(classification_report(y_true, y_hat, digits=4, target_names=sorted(os.listdir("Processed_Split/test"))))

                                    precision    recall  f1-score   support

                 Amylax_triacantha     1.0000    0.5000    0.6667         4
           Aphanizomenon_flosaquae     0.9765    0.9924    0.9844      1049
       Aphanothece_paralleliformis     0.8333    1.0000    0.9091         5
                             Beads     1.0000    1.0000    1.0000        20
                      Centrales_sp     0.8625    0.9583    0.9079        72
             Ceratoneis_closterium     0.8333    0.6250    0.7143         8
                    Chaetoceros_sp     0.9700    0.9327    0.9510       208
             Chaetoceros_sp_single     0.9697    0.9697    0.9697        33
                    Chlorococcales     0.7857    0.7333    0.7586        15
                     Chroococcales     0.9412    0.7273    0.8205        22
                 Chroococcus_small     0.9535    0.9840    0.9685       125
                           Ciliata     0.9714    0.9189    0.9444        37
           