In [1]:
import os
from random import sample
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Must be set before importing TF to supress messages
os.environ["CUDA_VISIBLE_DEVICES"]= '3'

import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger
import numpy as np
from utils.loader import DataLoader
from utils.tools import test_model
from utils.data_sampler import CustomDataGenerator, CustomIterator
from utils.configs import config
from typing import List

def load_VGG_model(img_height: int, img_width: int, lr: float, loss: tf.keras.losses.Loss, metrics: List[str], trainable: bool = True) -> tf.keras.Model:
    """ Loads and compiles a VGG-16 model.

    The network is created with the Keras defaults for every argument other
    than ``input_shape``.
    NOTE(review): per the Keras documentation those defaults are ImageNet
    weights plus the classification head, which only accepts a 224x224x3
    input — confirm against the installed TensorFlow version.

    Args:
        img_height (int): Image height.
        img_width (int): Image width.
        lr (float): Learning rate handed to the Adam optimizer.
        loss (tf.keras.losses.Loss): Model loss.
        metrics (List[str]): Training metrics.
        trainable (bool, optional): Value assigned to ``model.trainable``;
            ``False`` keeps the model weights frozen. Defaults to True.

    Returns:
        tf.keras.Model: Compiled TensorFlow VGG-16 model.
    """
    model = tf.keras.applications.vgg16.VGG16(input_shape=(img_height, img_width, 3))
    model.trainable = trainable

    # NOTE(review): epsilon=0.1 is far larger than the Keras default; the Adam
    # documentation mentions values such as 0.1 or 1.0 for ImageNet training —
    # confirm this is deliberate.
    optimizer = tf.keras.optimizers.Adam(lr, epsilon=0.1)
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=metrics)

    return model

def train_model(model: tf.keras.Model, train_set: CustomIterator, val_set: CustomIterator, epochs: int, batch_size: int, callbacks=None):
    """ Fit the model on the training data, validating against the validation data.

    Args:
        model (tf.keras.Model): Model whose ``fit`` method is called.
        train_set (CustomIterator): Training data.
        val_set (CustomIterator): Validation data.
        epochs (int): Number of epochs to train for.
        batch_size (int): Batch size used to derive the step counts from the
            ``n`` attribute of each data set.
        callbacks (optional): Callbacks forwarded to ``model.fit``. Defaults to None.

    Returns:
        history: The value returned by ``model.fit`` (training history information).
    """
    # Floor division: only full batches count towards the step totals.
    fit_options = dict(
        validation_data=val_set,
        epochs=epochs,
        steps_per_epoch=train_set.n // batch_size,
        validation_steps=val_set.n // batch_size,
        verbose=1,
        callbacks=callbacks,
    )

    return model.fit(train_set, **fit_options)

In [4]:
# Hyper-parameters and output locations
img_height, img_width = 224, 224
batch_size = 64
epochs = 10
lr = 3e-5
log_path = os.path.join(config['logs_path'], 'vgg_training_new.csv')

# Augmentation and pre-processing. Both generators apply the VGG-16 input
# pre-processing; only the training generator is configured with horizontal
# flips and a validation split.
train_datagen = CustomDataGenerator(
    horizontal_flip=True,
    validation_split=0.1,
    preprocessing_function=tf.keras.applications.vgg16.preprocess_input,
    dtype=tf.float32,
)
test_datagen = CustomDataGenerator(
    preprocessing_function=tf.keras.applications.vgg16.preprocess_input,
    dtype=tf.float32,
)

# Load ImageNet dataset with the VGG augmentation.
# The training and validation sets are both built from `train_datagen`;
# the test set uses `test_datagen`.
loader = DataLoader(batch_size, (img_height, img_width))
train_set = loader.load_train_set(
    aug_train=train_datagen, class_mode='categorical', shuffle=True)
val_set = loader.load_val_set(
    aug_val=train_datagen, class_mode='categorical', shuffle=True)
test_set = loader.load_test_set(
    aug_test=test_datagen, set_batch_size=False)

# Presumably limits the training iterator to 200000 samples — confirm
# against CustomIterator.set_subsampling.
train_set.set_subsampling(200000)

Loading test set...
Found 48238 images belonging to 1000 classes.


In [3]:
# Build a trainable VGG-16, compiled with categorical cross-entropy and an
# accuracy metric, using the image size and learning rate configured earlier.
model = load_VGG_model(
    img_height=img_height,
    img_width=img_width,
    lr=lr,
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'],
    trainable=True,
)

In [7]:
# Train the model and record per-epoch logs with a CSV logger.
# NOTE(review): the file name used here ('vgg_training_neww.csv') and the epoch
# count (5) differ from the `log_path` and `epochs` values set in the
# configuration cell — confirm which values are intended.
training_log_file = os.path.join(config['logs_path'], 'vgg_training_neww.csv')

# Create an empty log file up front when none exists yet.
if not os.path.exists(training_log_file):
    with open(training_log_file, "w") as my_empty_csv:
        pass

csv_logger = CSVLogger(training_log_file, separator=',', append=False)

# NOTE(review): with patience=5 and only 5 epochs, early stopping does not
# appear able to trigger before training ends — confirm against the Keras
# EarlyStopping documentation.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=5,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
)

train_history = train_model(
    model=model,
    train_set=train_set,
    val_set=val_set,
    epochs=5,
    batch_size=batch_size,
    callbacks=[csv_logger, early_stop],
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
 717/3125 [=====>........................] - ETA: 29:51 - loss: 0.7279 - accuracy: 0.8011



Epoch 5/5


In [6]:
# Load the model stored under 'vgg_trained' and rebind `model` to it.
# NOTE(review): when the notebook is run top to bottom this replaces the model
# trained in the cell above, and 'vgg_trained' is only written by the
# `model.save` cell at the end of the notebook — confirm the intended order.
model = tf.keras.models.load_model('vgg_trained')

In [7]:
# Evaluate the current `model` on `test_set` with the project helper from
# utils.tools; judging by the recorded output it reports the test-set accuracy.
test_model(model, test_set)


Predicting on test-set...
Computing accuracy...

-----------------------------------------
Model Accuracy on test-set: 0.6941622787014387
-----------------------------------------



In [8]:
# Persist the current `model` under 'vgg_trained'.
# NOTE(review): an earlier cell rebinds `model` via load_model('vgg_trained'),
# so in top-to-bottom order this re-saves the loaded model rather than the
# freshly trained one — confirm the intended order.
model.save('vgg_trained')