# Transfer Learning with Xception

## Import Libraries and Seed

In [None]:
import os
import random
from datetime import datetime

import splitfolders

import numpy as np
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.applications.xception import preprocess_input

In [None]:
# Short aliases for the Keras API bundled with TensorFlow (used by every
# later cell), then report which TensorFlow version this kernel is running.
tfk = tf.keras
tfkl = tfk.layers
print(tf.__version__)

In [None]:
# Fix the seed of every random number generator the notebook relies on
# (Python's `random`, NumPy, TensorFlow) so that runs are reproducible.
seed = 42

# NOTE(review): PYTHONHASHSEED is normally read when an interpreter starts,
# so setting it here likely only affects child processes — confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

random.seed(seed)
np.random.seed(seed)

# Seed TensorFlow through both its TF2 and compat.v1 entry points.
for set_tf_seed in (tf.random.set_seed, tf.compat.v1.set_random_seed):
    set_tf_seed(seed)

## Dataset Configuration

In [None]:
# Split the main dataset into train and val folders (80% / 20%) the first
# time the notebook runs; later runs reuse the existing split.
dataset_dir = '../datasetNoTest'

if not os.path.exists(dataset_dir):
    print('splitting')
    # Write the split into dataset_dir itself, so the folder that is checked
    # above is the same folder that gets created here and read below.
    # NOTE(review): the source folder 'dataset' is resolved relative to the
    # current working directory, unlike dataset_dir — confirm its location.
    splitfolders.ratio('dataset', output=dataset_dir, seed=seed, ratio=(0.8, 0.2))

# Setting dataset directories
training_dir = os.path.join(dataset_dir, 'train')
validation_dir = os.path.join(dataset_dir, 'val')

## Model Parameters and Class Weights

In [None]:
# Class names of the dataset. Their position in this list is the integer
# class index: the list is passed as `classes=` to flow_from_directory and
# `labels.index(...)` maps class folders to indices for the class weights.
labels = [
    "Apple", "Blueberry", "Cherry", "Corn",        # 0-3
    "Grape", "Orange", "Peach", "Pepper",          # 4-7
    "Potato", "Raspberry", "Soybean", "Squash",    # 8-11
    "Strawberry", "Tomato",                        # 12-13
]

In [None]:
# Input Parameters
img_w = 256
img_h = 256
# Keras expects (height, width, channels); derived from img_h/img_w so the
# generators' target size and the model input cannot drift apart.
# 3 channels because the generators load images with color_mode='rgb'.
input_shape = (img_h, img_w, 3)
# Number of output classes, derived from the label list so that adding or
# removing a class only requires editing `labels`.
classes = len(labels)

# Training Parameters
epochs = 90
batch_size = 16
# NOTE(review): reg_rate is not referenced by the model cell — confirm
# whether a kernel regularizer was intended.
reg_rate = 0.001

# Fine Tuning Parameters
fine_tuning = True
# Backbone layers with index < last_nonTrainable_layer stay frozen.
last_nonTrainable_layer = 86

# Earlystopping Parameters
early_stopping = False
patience_epochs = 9

In [None]:
# Compute per-class weights for model.fit(class_weight=...) using the
# "balanced" inverse-frequency heuristic:
#     weight_i = total_images / (classes * images_in_class_i)
# so under-represented classes get a weight > 1 and frequent ones < 1.
# Keys are indices into `labels`, i.e. the same order given to
# flow_from_directory(classes=labels).
elements_per_class = {i: 0 for i in range(classes)}

# The immediate sub-directories of training_dir are the class folders.
_, classes_directories, _ = next(os.walk(training_dir))

for img_class in classes_directories:
    class_dir = os.path.join(training_dir, img_class)
    # Count only the files directly inside the class folder.
    _, _, files = next(os.walk(class_dir))
    # labels.index raises ValueError for a folder that is not a known label,
    # surfacing a dataset/label mismatch early.
    elements_per_class[labels.index(img_class)] = len(files)

total_images = sum(elements_per_class.values())

# A class with no training images would cause a division by zero; such a
# class never contributes to the loss, so it simply gets weight 0.0.
category_weight = {
    i: total_images / (classes * count) if count else 0.0
    for i, count in elements_per_class.items()
}

## Data Augmentation

In [None]:
# Training images are randomly augmented (rotation, shifts, zoom, flips,
# brightness) and then run through Xception's own preprocess_input.
train_data_gen = ImageDataGenerator(rotation_range=20,
                                    height_shift_range=0.3,
                                    width_shift_range=0.4,
                                    zoom_range=0.4,
                                    horizontal_flip=True,
                                    vertical_flip=True,
                                    brightness_range=[0.3, 1.4],
                                    fill_mode='nearest',
                                    preprocessing_function=preprocess_input)

# target_size is (height, width) and uses the same parameters as the
# model's input_shape; classes=labels fixes the class-index order.
train_gen = train_data_gen.flow_from_directory(directory=training_dir,
                                               target_size=(img_h, img_w),
                                               color_mode='rgb',
                                               classes=labels,
                                               class_mode='categorical',
                                               batch_size=batch_size,
                                               shuffle=True,
                                               seed=seed)

In [None]:
# Validation images get only Xception's preprocessing, with no augmentation,
# so validation metrics are measured on the unmodified images.
valid_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Built from valid_data_gen (not the augmenting train_data_gen). shuffle=False
# keeps a fixed order across epochs.
valid_gen = valid_data_gen.flow_from_directory(directory=validation_dir,
                                               target_size=(img_h, img_w),
                                               color_mode='rgb',
                                               classes=labels,
                                               class_mode='categorical',
                                               batch_size=batch_size,
                                               shuffle=False,
                                               seed=seed)

## Xception Model Setup

In [None]:
# Download Xception with ImageNet weights and without its classification
# head, to be used as the feature-extraction backbone.
supernet = tfk.applications.Xception(
    include_top=False,
    weights="imagenet",
    input_shape=input_shape,
)

# Fine-tuning switch:
#   fine_tuning True  -> the backbone is trainable, except its first
#                        `last_nonTrainable_layer` layers (frozen below);
#   fine_tuning False -> every backbone layer is frozen.
supernet.trainable = fine_tuning

for layer in supernet.layers[:last_nonTrainable_layer]:
    layer.trainable = False

# Report index, name and trainable flag for each backbone layer.
for i, layer in enumerate(supernet.layers):
    print(i, layer.name, layer.trainable)

## Callbacks

In [None]:
# Utility function to create folders and callbacks for training

def create_folders_and_callbacks(model_name):
    """Create a timestamped experiment folder and the training callbacks.

    The folder ``data_augmentation_experiments/<model_name>_<timestamp>`` is
    created with two sub-folders: ``ckpts`` for model checkpoints and
    ``tb_logs`` for TensorBoard event files.

    Args:
        model_name: prefix used for the experiment folder name.

    Returns:
        A list containing a ModelCheckpoint and a TensorBoard callback, plus
        an EarlyStopping callback when the global ``early_stopping`` flag is
        True (with patience taken from the global ``patience_epochs``).
    """
    exps_dir = 'data_augmentation_experiments'
    now = datetime.now().strftime('%b%d_%H-%M-%S')
    exp_dir = os.path.join(exps_dir, model_name + '_' + now)
    # exist_ok avoids the check-then-create race and also creates exps_dir.
    os.makedirs(exp_dir, exist_ok=True)

    callbacks = []

    # Model checkpoint ---------------------------------------------------
    ckpt_dir = os.path.join(exp_dir, 'ckpts')
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, 'cp'),
        save_weights_only=False,  # save the whole model, not just the weights
        save_best_only=True,      # overwrite only when the monitored value improves
    )
    callbacks.append(ckpt_callback)

    # Visualize Learning on Tensorboard ----------------------------------
    tb_dir = os.path.join(exp_dir, 'tb_logs')  # event files read by TensorBoard
    os.makedirs(tb_dir, exist_ok=True)

    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_dir,
                                                 profile_batch=0,
                                                 histogram_freq=1)
    callbacks.append(tb_callback)

    # Early Stopping -----------------------------------------------------
    if early_stopping:
        es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                       patience=patience_epochs,
                                                       restore_best_weights=True)
        callbacks.append(es_callback)

    return callbacks

## Complete Model

In [None]:
# Build the complete model: the Xception backbone, then global average
# pooling, a 512-unit Dense layer with LeakyReLU activation and dropout,
# and finally a softmax output layer with one unit per class.
# NOTE(review): reg_rate from the parameters cell is not applied to any layer
# here — confirm whether a kernel_regularizer was intended.
inputs = tfk.Input(shape=input_shape)

# training=False keeps the backbone's BatchNormalization layers in inference
# mode even when they are unfrozen for fine-tuning, as the Keras
# transfer-learning guide recommends. Otherwise their moving statistics would
# be updated from small batches (batch_size) and disturb the pre-trained features.
x = supernet(inputs, training=False)

# Global Average Pooling Layer -----------------------------------------------------------
glob_pooling = tfkl.GlobalAveragePooling2D(name='GlobalPooling')(x)


# Dense Layer -----------------------------------------------------------
x = tfkl.Dense(
    512,
    kernel_initializer=tfk.initializers.GlorotUniform(seed)
)(glob_pooling)

leaky_relu_layer = tfkl.LeakyReLU()(x)

x = tfkl.Dropout(0.3)(leaky_relu_layer)


# Output Layer -----------------------------------------------------------
outputs = tfkl.Dense(
    classes,
    activation='softmax',
    kernel_initializer=tfk.initializers.GlorotUniform(seed)
)(x)


model = tfk.Model(inputs=inputs, outputs=outputs, name='model')


model.compile(
    loss=tfk.losses.CategoricalCrossentropy(),
    optimizer=tfk.optimizers.Adam(learning_rate=1e-5),
    # Explicit names keep the history keys 'precision'/'recall' (read by the
    # plotting cells) stable. Otherwise, re-running this cell makes Keras
    # auto-name the new metric objects 'precision_1', 'recall_1', ...
    metrics=['accuracy',
             tf.metrics.Precision(name='precision'),
             tf.metrics.Recall(name='recall')]
)
model.summary()

## Training

In [None]:
callbacks = create_folders_and_callbacks(model_name='transferLearningModel')

# batch_size is not passed to fit(): train_gen/valid_gen already yield
# batches of batch_size, and Keras documents that batch_size must not be
# specified when x is a generator/Sequence.
history = model.fit(
    x=train_gen,
    epochs=epochs,
    validation_data=valid_gen,
    class_weight=category_weight,
    callbacks=callbacks,
).history

model.save("transferLearningModel")

## Training Curves (Metrics and Loss)

In [None]:
# Training vs. validation curves for accuracy, precision and recall on a
# single figure. ALPHA (line transparency) is also used by the cells below.
ALPHA = 0.5

metric_curves = (
    ('accuracy', 'Accuracy Train', '#E64A19'),
    ('val_accuracy', 'Accuracy Val', '#F57C00'),
    ('precision', 'Precision Train', '#388E3C'),
    ('val_precision', 'Precision Val', '#689F38'),
    ('recall', 'Recall Train', '#303F9F'),
    ('val_recall', 'Recall Val', '#1976D2'),
)

plt.figure(figsize=(20,10))
for metric_key, curve_label, curve_color in metric_curves:
    plt.plot(history[metric_key], label=curve_label, alpha=ALPHA, color=curve_color)

# The y-axis shows only [0.5, 1]; values below 0.5 fall outside the view.
plt.ylim(.5, 1)
plt.title('Metrics')
plt.legend(loc='lower right')
plt.grid(alpha=.3)
plt.show()

In [None]:
# Accuracy Graph: training vs. validation accuracy per epoch.
plt.figure(figsize=(20,10))

for acc_key, acc_label, acc_color in (
    ('accuracy', 'Accuracy Train', '#E64A19'),
    ('val_accuracy', 'Accuracy Val', '#F57C00'),
):
    plt.plot(history[acc_key], label=acc_label, alpha=ALPHA, color=acc_color)

plt.ylim(.5, 1)
plt.title('Accuracy')
plt.legend(loc='lower right')
plt.grid(alpha=.3)
plt.show()

In [None]:
# Precision Graph: training vs. validation precision per epoch.
plt.figure(figsize=(20,10))

precision_curves = {
    'precision': ('Precision Train', '#388E3C'),
    'val_precision': ('Precision Val', '#689F38'),
}
for prec_key, (prec_label, prec_color) in precision_curves.items():
    plt.plot(history[prec_key], label=prec_label, alpha=ALPHA, color=prec_color)

plt.ylim(.5, 1)
plt.title('Precision')
plt.legend(loc='lower right')
plt.grid(alpha=.3)
plt.show()

In [None]:
# Recall Graph: training vs. validation recall per epoch.
plt.figure(figsize=(20,10))

recall_keys = ('recall', 'val_recall')
recall_labels = ('Recall Train', 'Recall Val')
recall_colors = ('#303F9F', '#1976D2')
for rec_key, rec_label, rec_color in zip(recall_keys, recall_labels, recall_colors):
    plt.plot(history[rec_key], label=rec_label, alpha=ALPHA, color=rec_color)

plt.ylim(.5, 1)
plt.title('Recall')
plt.legend(loc='lower right')
plt.grid(alpha=.3)
plt.show()

In [None]:
# Loss Graph: training vs. validation loss per epoch, y-axis fixed to [0, 4].
plt.figure(figsize=(15,10))

for loss_key, loss_label, loss_color in (
    ('loss', 'Loss Train', '#ff7f0e'),
    ('val_loss', 'Loss Val', '#4D61E2'),
):
    plt.plot(history[loss_key], label=loss_label, alpha=ALPHA, color=loss_color)

plt.ylim(0, 4)
plt.title('Loss')
plt.legend(loc='upper right')
plt.grid(alpha=.3)
plt.show()