Transfer learning takes a model that has already been trained on one task (the base model) and replaces its final prediction layer with a new layer that targets a new objective.

In this example, we take the same model we built for computer vision and add a new output layer that detects whether a photo is correctly oriented.

In [9]:
# Use the public tensorflow.keras API rather than the private tensorflow.python package
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.applications.resnet import ResNet50

num_classes = 2
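# 'imagenet' loads pretrained weights (downloaded on first use);
# None would leave ResNet50 randomly initialized, i.e. not pretrained
resnet_weights_path = 'imagenet'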

my_new_model = Sequential()
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights=resnet_weights_path))
my_new_model.add(Dense(num_classes, activation='softmax'))

# Freeze the ResNet50 base (layer 0): it is already trained, so only the new Dense layer learns
my_new_model.layers[0].trainable = False

#### Compile the model

In [10]:
my_new_model.compile(optimizer='sgd', 
                     loss='categorical_crossentropy', 
                     metrics=['accuracy'])

#### Fit the model

In [17]:
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pathlib

image_size = 224  # default input size for ResNet50
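# preprocess_input must be passed by keyword; the first positional argument is featurewise_center
data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)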

train_generator = data_generator.flow_from_directory(
                                        directory=str(pathlib.Path().absolute())+'/ressources',
                                        target_size=(image_size, image_size),
                                        batch_size=4,
                                        class_mode='categorical')

validation_generator = data_generator.flow_from_directory(
                                        directory=str(pathlib.Path().absolute())+'/ressources',
                                        target_size=(image_size, image_size),
                                        class_mode='categorical')

# fit_stats stores statistics describing how training went.
# The main effect of the call is that it updates my_new_model in place by fitting it to the data.
# In TensorFlow 2, Model.fit accepts generators directly (fit_generator is deprecated).
#fit_stats = my_new_model.fit(train_generator,
#                             steps_per_epoch=1,
#                             validation_data=validation_generator,
#                             validation_steps=1)

Found 0 images belonging to 0 classes.
Found 0 images belonging to 0 classes.


## Data Augmentation

A common trick for increasing the amount of training data is to apply edits to the images that do not change their labels. In this example, mirroring a photo left to right does not change whether it is correctly oriented. We can therefore add the mirror image of each photo and double the amount of data.
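As a minimal sketch of the mirroring idea, the snippet below doubles a toy dataset with NumPy. The array names, the image dimensions, and the one-hot labels are hypothetical and chosen only for illustration; they are not part of the notebook's data.

```python
import numpy as np

# Hypothetical toy batch: 3 RGB "photos" of 4x5 pixels, shape (N, H, W, C).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 4, 5, 3), dtype=np.uint8)

# Hypothetical one-hot labels for the two orientation classes.
labels = np.array([[1, 0], [0, 1], [1, 0]])

# Mirror every image left to right by reversing the width axis (axis 2).
mirrored = images[:, :, ::-1, :]

# Mirroring does not change the orientation label, so the labels are reused as-is.
augmented_images = np.concatenate([images, mirrored], axis=0)
augmented_labels = np.concatenate([labels, labels], axis=0)

print(augmented_images.shape)  # (6, 4, 5, 3)
print(augmented_labels.shape)  # (6, 2)
```

Note that `ImageDataGenerator` with `horizontal_flip=True` works slightly differently: it flips images at random as batches are drawn instead of storing a second copy of the dataset.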

In [18]:
# Generator with augmentation: random horizontal flips plus
# random horizontal and vertical shifts of up to 10% of the image size.
data_generator_with_aug = ImageDataGenerator(preprocessing_function=preprocess_input,
                                              horizontal_flip=True,
                                              width_shift_range=0.1,
                                              height_shift_range=0.1)
            
data_generator_no_aug = ImageDataGenerator(preprocessing_function=preprocess_input)