In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization, GlobalAveragePooling2D

from src.algonauts.data_processors.coco_dataset import create_datasets_from_coco
from src.algonauts.data_processors.image_transforms import transform_vgg16

In [None]:
# Repository root, relative to this notebook's working directory.
base_dir = '../..'
# Local TFDS cache; tfds.load downloads/prepares COCO 2017 here if it is not present yet.
# NOTE(review): the full COCO 2017 download is large — confirm disk space on the target machine.
tfds_dir = f'{base_dir}/data/tensorflow_datasets'

# No `split` is passed, so `dataset` is a dict of all splits; `info` holds the
# dataset metadata (features, label names, split sizes).
dataset, info = tfds.load(
    'coco/2017',
    data_dir=tfds_dir,
    with_info=True,
)

In [None]:
# Training configuration.
batch_size = 32
num_classes = 80  # number of object categories in COCO 2017

# Fail fast if the hard-coded class count disagrees with the dataset's own label metadata.
assert info.features['objects']['label'].num_classes == num_classes, (
    "num_classes does not match the COCO 2017 label metadata"
)

# Build batched training and validation pipelines from the TFDS splits, passing
# transform_vgg16 as the image transform (presumably VGG16-specific preprocessing).
# NOTE(review): create_datasets_from_coco is project code; presumably it converts each
# image's object labels into a multi-hot target of length num_classes — confirm.
train_ds, val_ds = create_datasets_from_coco(dataset, num_classes, transform_vgg16, batch_size)

## Load a pretrained VGG16 backbone and replace its top layers with a new multi-label classification head (the backbone can be fine-tuned, or — as here — frozen so that only the new head is trained)

In [None]:
# ImageNet-pretrained VGG16 convolutional backbone; include_top=False drops its
# original fully connected classifier so a COCO-specific head can be attached.
base_model = tf.keras.applications.VGG16(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
)
# Freeze every backbone weight: only the head built below is trainable.
base_model.trainable = False

# Classification head: global average pooling, then two ReLU dense layers,
# each followed by 50% dropout.
x = GlobalAveragePooling2D()(base_model.output)
for units, layer_name in [(2048, 'dense_additional1'), (1024, 'dense_additional2')]:
    x = Dense(units, activation='relu', name=layer_name)(x)
    x = Dropout(0.5)(x)
# One sigmoid unit per class so every label is scored independently (multi-label).
predictions = Dense(num_classes, activation='sigmoid')(x)

model = tf.keras.Model(inputs=base_model.input, outputs=predictions)

# Show the layer sequence, e.g. to check where the new head starts.
print(' -> '.join(layer.name for layer in model.layers))

# Adam optimiser. NOTE(review): 1e-5 is a low learning rate for a head trained from
# scratch on a frozen backbone — consider tuning.
optimiser = tf.keras.optimizers.Adam(learning_rate=1e-5)

# Macro-averaged F1 over all classes. threshold=0.5 binarises each sigmoid output
# independently; with tfa's default threshold=None only the argmax class of each
# sample counts as positive, which undercounts positives for multi-label targets.
# NOTE(review): tensorflow_addons is end-of-life; TF >= 2.13 ships
# tf.keras.metrics.F1Score with the same `average`/`threshold` arguments.
f1_score = tfa.metrics.F1Score(num_classes=num_classes, average='macro', threshold=0.5)

# Binary cross-entropy treats every class as an independent yes/no decision,
# matching the per-class sigmoid output layer.
model.compile(optimizer=optimiser, loss='binary_crossentropy', metrics=[f1_score])

## Train, save, and plot the loss and macro F1 score on the training and validation sets

In [None]:
# Destination of the trained model; the .h5 extension selects Keras' HDF5 format.
# NOTE(review): reloading with tf.keras.models.load_model likely needs custom_objects
# for the tfa F1Score metric — confirm.
model_path = f'{base_dir}/data/models/vgg16_imagenet_trained_on_coco.h5'

# Train for 10 epochs, evaluating on val_ds after each one; history.history keeps
# the per-epoch loss/metric values for plotting.
history = model.fit(train_ds, validation_data=val_ds, epochs=10)
model.save(model_path)

def plot_history(history, metric, title, ylabel):
    """Plot the per-epoch training and validation curves of one metric.

    Args:
        history: ``History`` object returned by ``model.fit``; ``history.history``
            must contain both ``metric`` and ``f'val_{metric}'`` (the ``val_`` entries
            are recorded because ``validation_data`` was passed to ``fit``).
        metric: Name of the training metric key in ``history.history``.
        title: Figure title.
        ylabel: Label for the y axis.
    """
    train_values = history.history[metric]
    val_values = history.history[f'val_{metric}']
    # Number epochs from 1 so the first epoch is plotted at x=1, not x=0.
    epochs = range(1, len(train_values) + 1)

    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, label='Train')
    ax.plot(epochs, val_values, label='Validation')
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend(loc='upper left')
    plt.show()


# Training & validation loss (binary cross-entropy).
plot_history(history, 'loss', 'Model loss', 'Loss')

# Training & validation macro F1 score — the tfa F1Score metric passed to
# model.compile, not accuracy.
plot_history(history, f1_score.name, 'Model F1 score (macro)', 'Macro F1 score')
