In [None]:
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint
from os import path, walk
import matplotlib.pyplot as plt

In [None]:
# Notebook-wide configuration: dataset locations, checkpoint target and
# the number of training epochs.
NUM_EPOCHS = 25
TRAIN_DIR, VAL_DIR = (path.join('Dataset', split_name) for split_name in ('Train', 'Val'))
MODEL_DIR = path.join(*('ModelDir', 'MobileNetV2', 'TransferLearning'))

In [None]:
# Get the names of the classes.
# Each immediate sub-directory of TRAIN_DIR is one class. Only the first
# entry produced by walk() is used (the listing of TRAIN_DIR itself), so
# nested folders inside a class directory are not counted, and no
# platform-specific path separator needs to be parsed.
# The names are sorted so that class_names[i] lines up with the alphanumeric
# label ordering Keras documents for flow_from_directory.
class_names = sorted(next(walk(TRAIN_DIR), (TRAIN_DIR, [], []))[1])
if not class_names:
    # Fail here with a clear message instead of building a 0-unit output layer.
    raise FileNotFoundError(
        f"No class sub-directories found under {TRAIN_DIR!r}")
num_classes = len(class_names)
class_names

In [None]:
# Create the model: a frozen MobileNetV2 backbone built without its top
# layers (include_top=False), followed by a new softmax output layer.
base_model = MobileNetV2(include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # only the new "out" layer will be trained

model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(num_classes, activation='softmax', name="out"))
model.summary()

In [None]:
# Initialize the training and validation generators.
# Both pipelines pass every image through `preprocess_input`; only the
# training pipeline adds random augmentation.
train_data_gen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    shear_range=10,
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=45,
    brightness_range=[0.5, 1],
    channel_shift_range=50,
)

# Directory-reading options shared by both generators, kept in one place so
# training and validation cannot drift apart.
flow_options = dict(
    target_size=(224, 224),
    color_mode='rgb',
    batch_size=20,
    class_mode='categorical',
)

train_generator = train_data_gen.flow_from_directory(
    TRAIN_DIR, shuffle=True, **flow_options)

val_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
# Validation data is not shuffled: shuffling does not change the validation
# metrics, and a fixed order keeps batches aligned with
# `val_generator.classes` / `val_generator.filenames` for later inspection.
val_generator = val_data_gen.flow_from_directory(
    VAL_DIR, shuffle=False, **flow_options)

In [None]:
# For experimenting with image augmentation: a default-configured
# ImageDataGenerator that yields one training image per batch, unshuffled.
exp_gen = ImageDataGenerator().flow_from_directory(
    TRAIN_DIR,
    shuffle=False,
    batch_size=1,
    class_mode='categorical',
    color_mode='rgb',
)

In [None]:
# Pull the next batch from the experiment generator and display its image.
# The builtin next() goes through the iterator protocol instead of relying
# on a generator-specific `.next()` method.
img_arr = next(exp_gen)
# img_arr[0] is the image batch and [0] selects its first image.
# Dividing by 255 presumably maps 0-255 pixel values into the [0, 1] range
# imshow expects for floats - TODO confirm against the generator output.
plt.imshow(img_arr[0][0] / 255);

In [None]:
# Compile the model, and set up a checkpoint callback that keeps only the
# best model as measured by the highest validation accuracy.
model.compile(
    loss='categorical_crossentropy',
    optimizer='Adam',
    metrics=['accuracy'],
)
# NOTE(review): MODEL_DIR has no file extension; newer Keras releases may
# require a `.keras` / `.h5` checkpoint path - confirm for the installed version.
checkpoint = ModelCheckpoint(
    MODEL_DIR,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Train for NUM_EPOCHS epochs, validating each epoch; the checkpoint
# callback is responsible for saving only the best model.
model.fit(
    train_generator,
    validation_data=val_generator,
    callbacks=[checkpoint],
    epochs=NUM_EPOCHS,
)