# Model Exploration Using the CUB-200-2011 Dataset

## References
* [Transfer Learning with Hub](https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub)
* [`tf.keras.utils.image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/image_dataset_from_directory)
* [Limiting GPU Memory Growth](https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth)

## Setup

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import datetime
import os

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from tensorflow.keras.preprocessing.image import ImageDataGenerator

import tensorflow_hub as hub
from keras.utils.layer_utils import count_params


In [2]:
def limit_memory_growth(limit=True):
    """Apply one memory-growth setting to every physical GPU TensorFlow sees.

    Args:
        limit: Flag forwarded unchanged to
            ``tf.config.experimental.set_memory_growth`` for each GPU.

    Prints the number of physical and logical GPUs on success, or the
    ``RuntimeError`` raised by TensorFlow. Does nothing when no GPU is found.
    """
    physical_gpus = tf.config.list_physical_devices('GPU')
    if not physical_gpus:
        return

    try:
        # The setting has to be identical across GPUs, so apply it to each one.
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, limit)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth must be configured before the GPUs are initialized.
        print(err)

In [3]:
# Configure GPU memory growth first: the helper's own note says this must be
# set before the GPUs have been initialized.
limit_memory_growth()

1 Physical GPUs, 1 Logical GPUs


## Utility

In [4]:
def plot_predictions(
    image_batch,
    predicted_class_names,
):
    """Show up to 30 images of a batch in a 6x5 grid, titled with predictions.

    Args:
        image_batch: Indexable, sized collection of images accepted by
            ``plt.imshow``.
        predicted_class_names: Indexable, sized collection holding one title
            per image, in the same order as ``image_batch``.
    """
    rows, cols = 6, 5
    # A batch can hold fewer images than the grid has cells (e.g. the last
    # batch of an epoch); only plot what is actually available.
    num_images = min(rows * cols, len(image_batch), len(predicted_class_names))

    plt.figure(figsize=(10,9))
    plt.subplots_adjust(hspace=0.5)
    for n in range(num_images):
        plt.subplot(rows, cols, n+1)
        plt.imshow(image_batch[n])
        plt.title(predicted_class_names[n])
        plt.axis('off')
    _ = plt.suptitle("Predictions")

def plot_images(
    ds,
    class_names,
):
    """Show up to 9 images from the first batch of ``ds`` in a 3x3 grid.

    Args:
        ds: Dataset exposing ``take(n)`` and yielding ``(images, labels)``
            batches whose images provide ``.numpy()``.
        class_names: Sequence indexed by each label to obtain the plot title.
    """
    rows, cols = 3, 3
    plt.figure(figsize=(10, 10))
    for images, labels in ds.take(1):
        # Do not index past the end of a batch smaller than the grid.
        for i in range(min(rows * cols, len(images))):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.title(class_names[labels[i]])
            plt.axis("off")
    
def get_timestamp():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = datetime.datetime.now()
    return f"{now:%Y%m%d-%H%M%S}"

## Enumerate Datasets to test

In [5]:
import pathlib

flowers_dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

# Fetch and untar the flowers archive through Keras' get_file helper, then
# wrap the returned location in a Path.
flowers_data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        'flower_photos',
        origin=flowers_dataset_url,
        untar=True,
    )
)

# Image directories available for experiments.
datasets = [
    '/mnt/cub/CUB_200_2011/images',
    flowers_data_dir,
]

## Dataset

In [6]:
def build_dataset(
    data_dir = '/mnt/cub/CUB_200_2011/images',
    batch_size = 64,
    image_size = (299,299),
    preprocess_input = None,
):
    """Create training and validation image iterators from a class-per-folder directory.

    Args:
        data_dir: Directory containing one sub-directory of images per class.
        batch_size: Number of images per yielded batch.
        image_size: ``(height, width)`` every image is resized to. Must match
            the input size expected by the model being trained.
        preprocess_input: Optional model-specific preprocessing function
            applied to every image (training and validation).

    Returns:
        Tuple ``(train_ds, val_ds, num_classes)`` where the first two are the
        iterators returned by ``flow_from_directory`` and ``num_classes`` is
        the number of classes discovered in ``data_dir``.
    """
    validation_split = 0.2

    # Random augmentation is only meaningful for the training images.
    train_datagen = ImageDataGenerator(
        preprocessing_function = preprocess_input,
        validation_split=validation_split,
        rotation_range=10,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.15,
        zoom_range=0.1,
        channel_shift_range=10.,
        horizontal_flip=True,
    )

    # Validation uses the same preprocessing and the same split fraction but
    # no random distortions, so validation metrics are measured on the
    # unmodified images.
    val_datagen = ImageDataGenerator(
        preprocessing_function = preprocess_input,
        validation_split=validation_split,
    )

    train_ds = train_datagen.flow_from_directory(
        data_dir,
        target_size=image_size,  # honour the caller's size instead of a fixed 299x299
        batch_size=batch_size,
        subset='training',
        shuffle=True,
    )

    val_ds = val_datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        subset='validation',
        shuffle=False,  # order is irrelevant for evaluation
    )

    # Read the class count from the iterator rather than assuming CUB's 200,
    # so other datasets (e.g. the flowers directory) work too.
    num_classes = train_ds.num_classes

    return (train_ds, val_ds, num_classes)

## Enumerate Models to test

In [7]:
# Base models to evaluate. Each active entry is a 3-tuple
#   (model_handle, input_dimension, preprocess_input)
# because the consumers unpack exactly three values. model_handle is either a
# callable model constructor or a TF-Hub URL string (see get_model_name).
# NOTE(review): the commented-out entries below only have two elements and
# would need a third (preprocessing function or None) before being re-enabled.
base_models_metadata = [
    # ('https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', 224),
    # ('https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4', 299),
    # ('https://tfhub.dev/google/inaturalist/inception_v3/feature_vector/5', 299),
    (tf.keras.applications.Xception, 299, tf.keras.applications.xception.preprocess_input),
    # (tf.keras.applications.resnet.ResNet101, 224),
    # (tf.keras.applications.ResNet50, 224),
    # (tf.keras.applications.InceptionResNetV2, 299),
    # (tf.keras.applications.efficientnet_v2.EfficientNetV2B0, 224)
]

def get_model_name( model_handle ):
    """Derive a short dotted identifier for a base model.

    Args:
        model_handle: Either a callable (its ``__name__`` is used) or a
            '/'-separated string handle.

    Returns:
        ``'keras.applications.<name>'`` for callables; otherwise the
        fifth-, fourth- and third-from-last '/'-separated segments of the
        string joined with dots.
    """
    if not callable(model_handle):
        parts = model_handle.split('/')
        return '.'.join((parts[-5], parts[-4], parts[-3]))

    return 'keras.applications.' + model_handle.__name__
    

## Model Building

In [8]:
def print_weight_counts(model, label=None):
    """Print the non-trainable and trainable parameter counts of ``model``.

    Args:
        model: Object exposing ``non_trainable_weights`` and
            ``trainable_weights`` (a Keras model or layer).
        label: Optional text prepended to each line as ``'<label> - '``.
            Left out by default because callers print their own heading and
            the previously hard-coded "Full Model" prefix mislabelled the
            counts of the base model.
    """
    prefix = f'{label} - ' if label else ''
    print(f'{prefix}Non-trainable weights: {count_params(model.non_trainable_weights)}')
    print(f'{prefix}Trainable weights: {count_params(model.trainable_weights)}')

def build_base_model_layer(
    model_handle,
    name="base_model_layer",
):
    """Create the frozen feature-extractor that sits at the bottom of the model.

    Args:
        model_handle: Either a callable model constructor, invoked with
            ``include_top=False``, ``weights='imagenet'`` and
            ``pooling='avg'``, or a handle passed to ``hub.KerasLayer``.
        name: Layer name; only used for the ``hub.KerasLayer`` branch.

    Returns:
        The created layer/model with training disabled. Its weight counts are
        printed under a "Base Model:" heading as a side effect.
    """
    if not callable(model_handle):
        # Non-callable handles are loaded through TF-Hub; trainable=False is
        # spelled out for clarity.
        base_model_layer = hub.KerasLayer(model_handle, name=name, trainable=False)
    else:
        base_model_layer = model_handle(include_top=False, weights='imagenet', pooling='avg')
        base_model_layer.trainable = False

    # Report the weight counts of the base model.
    print("Base Model:")
    print_weight_counts(base_model_layer)
    print()

    return base_model_layer

def build_model(
    base_model_metadata,
    dropout,
    num_classes = 200,
):
    """Stack a classification head on top of a frozen base model.

    Args:
        base_model_metadata: Sequence whose first element is the model handle
            understood by ``build_base_model_layer``; further elements are
            not used here.
        dropout: Dropout rate applied to the base model's features.
        num_classes: Number of output classes.

    Returns:
        A ``Sequential`` model: base model -> Dropout -> Dense(num_classes)
        -> softmax. Weight counts are printed under a "Full Model:" heading
        as a side effect.
    """
    # Only the handle is needed here; indexing (instead of a fixed 3-way
    # unpack) also accepts metadata entries with fewer fields.
    model_handle = base_model_metadata[0]

    model = Sequential([
        build_base_model_layer(
            model_handle,
        ),
        # Dropout regularises the extracted features. It must come before the
        # output layer: placed after it, it would randomly zero class logits
        # during training and corrupt the softmax distribution.
        layers.Dropout(dropout),
        layers.Dense(
            num_classes,
        ),
        # Softmax is a separate float32 layer.
        layers.Activation("softmax", dtype="float32"),
    ])

    # Print weight counts
    print("Full Model:")
    print_weight_counts(model)
    print()

    return model

## Build and run all models

In [9]:
# Hyperparameters
batch_size = 64
max_epochs = 15
dropout = 0.4
learning_rate = 0.0005

# Directory for logs
base_log_dir = "models_cub_logs"

# Build, compile and train one classifier per configured base model.
for base_model_metadata in base_models_metadata:

    model_handle, input_dimension, preprocess_input = base_model_metadata

    image_size = (input_dimension, input_dimension)

    # Build dataset/pipeline
    train_ds, val_ds, num_classes = build_dataset(
        datasets[0],
        batch_size = batch_size,
        image_size = image_size,
        preprocess_input = preprocess_input,
    )

    # Build model
    model = build_model(
        base_model_metadata,
        dropout,
        num_classes,
    )

    # Compile model. The model already ends in a softmax layer, so the loss
    # keeps its default of expecting probabilities rather than logits.
    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[
            'accuracy',
        ],
    )

    # Logging
    model_id = get_model_name(model_handle)
    log_dir = os.path.join( base_log_dir, model_id )

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=1,
    )

    # Early stopping. Two fixes compared to the previous version:
    #  * no trailing comma after the call, which used to turn this variable
    #    into a 1-tuple instead of a callback object;
    #  * the validation metric is monitored, because training accuracy keeps
    #    improving while the model overfits and would not trigger the stop.
    early_stopping_callback = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        min_delta=0.001,
    )

    print()
    print(model_id)

    # Train
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=max_epochs,
        callbacks=[
            tensorboard_callback,
            early_stopping_callback,
        ]
    )

    # Save model
    # model.save(os.path.join(log_dir, 'final_model' ))
    

Found 9465 images belonging to 200 classes.
Found 2323 images belonging to 200 classes.
Base Model:
Full Model - Non-trainable weights: 20861480
Full Model - Trainable weights: 0

Full Model:
Full Model - Non-trainable weights: 20861480
Full Model - Trainable weights: 409800


keras.applications.Xception
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


KeyboardInterrupt

