In [None]:
!pip install tensorflow-datasets
!pip install tensorflow-addons

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
import matplotlib.pyplot as plt

# Fine-tune AlexNet trained on Imagenette with COCO

Here we experiment with transfer learning and fine-tuning. We take our own AlexNet implementation, trained on the
Imagenette dataset, and use it for transfer learning on COCO.
- Take the self-implemented AlexNet model trained on the Imagenette dataset
- Freeze its layers
- Add new Dense layers on top for multi-label classification (an image can contain objects from several COCO classes)
- Train the additional layers on COCO (transfer learning)
- Unfreeze the layers
- Train the whole model further, with a very low learning rate (fine-tuning)

In [2]:
# Experiment identifier: names the checkpoint file and the TensorBoard log directory below.
experiment = "finetune_alexnet_imagenette_on_coco"

In [None]:
# Project root relative to this notebook.
# NOTE(review): assumes the notebook is run from a direct subdirectory of the project root — confirm.
base_dir = '..'

# Load COCO 2017 (with annotations) through TFDS, keeping the TFDS data under the project's data folder;
# with_info=True additionally returns the dataset metadata.
coco_data_dir = f'{base_dir}/data/tensorflow_datasets'
dataset, info = tfds.load('coco/2017', data_dir=coco_data_dir, with_info=True)

In [4]:
from src.algonauts.data_processors.coco_dataset import create_datasets_from_coco
from src.algonauts.data_processors.image_transforms import transform_alexnet

# Input-pipeline and head hyper-parameters.
num_classes, batch_size = 80, 32  # COCO 2017 defines 80 object classes

# Build the training and validation pipelines from the loaded COCO splits.
# NOTE(review): presumably create_datasets_from_coco applies transform_alexnet to each image and
# batches by batch_size — confirm against its implementation.
train_ds, val_ds = create_datasets_from_coco(
    dataset,
    num_classes,
    transform_alexnet,
    batch_size,
)

## Load the pretrained architecture and replace its top layers (base frozen for transfer learning, unfrozen later for fine-tuning)

In [5]:
from src.algonauts.models.model_loaders import load_from_file
from src.algonauts.feature_extractors.tf_feature_extractor import slice_model
# Imagenette-trained AlexNet weights (model trained for 16 epochs with early stopping).
model_filename = f'{base_dir}/data/models/alexnet_imagenette.h5'


def model_loader():
    """Load the Imagenette-trained AlexNet from `model_filename`, paired with the AlexNet transform."""
    return load_from_file(model_filename, transform_alexnet)


# load_from_file returns a pair; only the model itself is kept here.
base_model, _ = model_loader()

In [6]:
# Show the layer sequence of the full pretrained model.
print(' -> '.join(layer.name for layer in base_model.layers))

conv2d_1 -> conv2d_1_bn -> conv2d_1_pool -> conv2d_2 -> conv2d_2_bn -> conv2d_2_pool -> conv2d_3 -> conv2d_3_bn -> conv2d_4 -> conv2d_4_bn -> conv2d_5 -> conv2d_5_bn -> conv2d_5_pool -> flatten -> dense -> dropout -> dense_1 -> dropout_1 -> dense_2


In [7]:
# Truncate the pretrained model at 'conv2d_5_pool': the original flatten/dense/dropout classifier layers
# are dropped, so only the convolutional feature extractor is reused.
base_model = slice_model(base_model, 'conv2d_5_pool')

In [8]:
# Show the layer sequence that remains after slicing.
print(' -> '.join(layer.name for layer in base_model.layers))

conv2d_1_input -> conv2d_1 -> conv2d_1_bn -> conv2d_1_pool -> conv2d_2 -> conv2d_2_bn -> conv2d_2_pool -> conv2d_3 -> conv2d_3_bn -> conv2d_4 -> conv2d_4_bn -> conv2d_5 -> conv2d_5_bn -> conv2d_5_pool


In [9]:
# Freeze the pretrained convolutional base so only the new head is trained in the transfer-learning phase.
base_model.trainable = False

# New multi-label head on top of the conv2d_5_pool feature maps:
# global average pooling -> two ReLU dense layers, each followed by 50% dropout -> one sigmoid unit per class.
head = GlobalAveragePooling2D()(base_model.output)
for units, layer_name in ((2048, 'dense_additional1'), (1024, 'dense_additional2')):
    head = Dense(units, activation='relu', name=layer_name)(head)
    head = Dropout(0.5)(head)
predictions = Dense(num_classes, activation='sigmoid')(head)

# Assemble the end-to-end model and list its layers.
model = tf.keras.Model(inputs=base_model.input, outputs=predictions)
print(' -> '.join(layer.name for layer in model.layers))

# Independent sigmoid outputs are trained with binary cross-entropy; precision is the tracked metric.
precision = tf.keras.metrics.Precision(name='precision')
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=[precision])

conv2d_1_input -> conv2d_1 -> conv2d_1_bn -> conv2d_1_pool -> conv2d_2 -> conv2d_2_bn -> conv2d_2_pool -> conv2d_3 -> conv2d_3_bn -> conv2d_4 -> conv2d_4_bn -> conv2d_5 -> conv2d_5_bn -> conv2d_5_pool -> global_average_pooling2d -> dense_additional1 -> dropout -> dense_additional2 -> dropout_1 -> dense


## Define callbacks

Here we use the following callbacks:
- Early stopping: stop training once the validation loss has not improved by more than 0.001 for 3 consecutive epochs
- Checkpoint: save the model after every epoch in which the validation precision has improved
- TensorBoard: write logs that can be loaded later for comparison

In [10]:
import datetime


# Timestamp that gives every run its own TensorBoard log directory.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Stop once val_loss has not improved by more than 0.001 for 3 consecutive epochs
# (patience=3 matches the documented intent; patience=0 stopped at the first non-improving epoch).
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=3,
    verbose=0,
    mode="min",
    baseline=None,
    # In-memory weights stay those of the last epoch; the best model by val_precision is
    # kept on disk by `checkpoint` below.
    restore_best_weights=False
)

# Overwrite a single .h5 file whenever val_precision reaches a new maximum.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{base_dir}/data/out/checkpoints/{experiment}.h5',
    monitor='val_precision',
    mode='max',
    verbose=1,
    save_best_only=True)

# Per-run log directory: <base_dir>/data/out/training_logs/<experiment>/<timestamp>.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=f"{base_dir}/data/out/training_logs/{experiment}/{current_time}")

# The same callback objects are shared by the transfer-learning and fine-tuning runs.
callbacks = [early_stopping, tensorboard_callback, checkpoint]

## Transfer learning

Train the frozen model with the added layers on the COCO dataset; only the new top layers are updated. Fine-tuning of the whole network follows afterwards.

In [None]:
# Transfer learning: base_model was frozen before compiling, so only the new head is updated.
history = model.fit(train_ds, epochs=10, validation_data=val_ds, callbacks=callbacks)
model.save(f'{base_dir}/data/models/alexnet_imagenette_transfer_coco.h5')

# Plot training vs. validation curves for the loss and for the tracked metric.
# The compiled metric is precision, so it is labelled as precision (not accuracy).
for key, label in (('loss', 'Loss'), ('precision', 'Precision')):
    plt.plot(history.history[key])
    plt.plot(history.history[f'val_{key}'])
    plt.title(f'Model {label.lower()}')
    plt.ylabel(label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()


## Fine-tuning

Unfreeze the base model and continue training the whole network on COCO with a very low learning rate.

In [None]:
# Fine-tuning: unfreeze the pretrained base and train end-to-end with a 10x lower learning rate.
base_model.trainable = True

# Keep BatchNormalization layers frozen. A non-trainable BN layer runs in inference mode, so the
# moving mean/variance learned on Imagenette are not overwritten by COCO batch statistics. Overwriting
# them would disrupt the pretrained features.
# NOTE(review): assumes the loaded *_bn layers are tf.keras BatchNormalization instances — confirm.
for layer in base_model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = False

# Recompile so the changed trainable flags take effect.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=[precision])

# Continue the epoch counter after the transfer-learning run. The shared TensorBoard/checkpoint
# callbacks then log epochs after the earlier ones instead of restarting at 0 in the same log directory.
# Training still runs for at most fine_tune_epochs epochs.
fine_tune_epochs = 10
initial_epoch = len(history.history['loss'])
history = model.fit(train_ds, epochs=initial_epoch + fine_tune_epochs, initial_epoch=initial_epoch,
                    validation_data=val_ds, callbacks=callbacks)
model.save(f'{base_dir}/data/models/alexnet_imagenette_finetune_coco.h5')

# Plot training vs. validation curves for the loss and for the tracked metric.
# The compiled metric is precision, so it is labelled as precision (not accuracy).
for key, label in (('loss', 'Loss'), ('precision', 'Precision')):
    plt.plot(history.history[key])
    plt.plot(history.history[f'val_{key}'])
    plt.title(f'Model {label.lower()}')
    plt.ylabel(label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
    plt.show()