In [1]:
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras.callbacks import ModelCheckpoint
from os import path, walk

In [2]:
# Configuration: dataset locations, checkpoint destination and training length
DATASET_ROOT = 'Dataset'
TRAIN_DIR = path.join(DATASET_ROOT, 'Train')  # training images, one sub-folder per class
VAL_DIR = path.join(DATASET_ROOT, 'Val')      # validation images, same layout
MODEL_DIR = path.join('ModelDir', 'MobileNetV2', 'TransferLearning')  # checkpoint target
NUM_EPOCHS = 25  # upper bound on the number of training epochs

In [3]:
# Get the names of the classes: one class per immediate sub-directory of TRAIN_DIR.
# Only the first entry yielded by walk() is used, so the result does not depend on
# the platform's path separator or on how deep TRAIN_DIR is, and folders nested
# inside a class directory are not counted as additional classes.
# If TRAIN_DIR does not exist walk() yields nothing; fall back to an empty listing.
_, class_dirs, _ = next(walk(TRAIN_DIR), (TRAIN_DIR, [], []))
# Sorted so the order is deterministic; Keras' flow_from_directory documents an
# alphanumeric class ordering, which this is intended to mirror.
class_names = sorted(class_dirs)
num_classes = len(class_names)
class_names

['HDMI', 'USB-A']

In [4]:
# Build the transfer-learning model: a MobileNetV2 backbone created without its
# top classifier (include_top=False) and frozen, followed by a new softmax output
# layer with one unit per class
base_model = MobileNetV2(include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # only the new "out" layer is trained

model = Sequential([
    base_model,
    Flatten(),
    Dense(num_classes, activation='softmax', name="out"),
])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
mobilenetv2_1.00_224 (Functi (None, 7, 7, 1280)        2257984   
_________________________________________________________________
flatten (Flatten)            (None, 62720)             0         
_________________________________________________________________
out (Dense)                  (None, 2)                 125442    
Total params: 2,383,426
Trainable params: 125,442
Non-trainable params: 2,257,984
_________________________________________________________________


In [5]:
# Initialize the training and validation generators.
# Both directories are read with identical settings, so the shared options are
# declared once and passed to each flow_from_directory call.
flow_options = dict(
    target_size=(224, 224),
    color_mode='rgb',
    batch_size=16,
    class_mode='categorical',
    shuffle=True,
)

train_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
train_generator = train_data_gen.flow_from_directory(TRAIN_DIR, **flow_options)

val_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_generator = val_data_gen.flow_from_directory(VAL_DIR, **flow_options)

Found 84 images belonging to 2 classes.
Found 102 images belonging to 2 classes.


In [6]:
# Configure training, and create a checkpoint callback that writes the model to
# MODEL_DIR only when validation accuracy reaches a new maximum
model.compile(
    optimizer='Adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
checkpoint = ModelCheckpoint(
    MODEL_DIR,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1,
)

In [7]:
# Fit on the training generator for up to NUM_EPOCHS epochs, validating against
# the validation generator; the checkpoint callback decides what gets saved
model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=NUM_EPOCHS,
    callbacks=[checkpoint],
)

Epoch 1/25
Epoch 00001: val_accuracy improved from -inf to 0.81373, saving model to ModelDir\MobileNetV2\TransferLearning
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ModelDir\MobileNetV2\TransferLearning\assets
Epoch 2/25
Epoch 00002: val_accuracy did not improve from 0.81373
Epoch 3/25
Epoch 00003: val_accuracy did not improve from 0.81373
Epoch 4/25
Epoch 00004: val_accuracy did not improve from 0.81373
Epoch 5/25
Epoch 00005: val_accuracy did not improve from 0.81373
Epoch 6/25
Epoch 00006: val_accuracy improved from 0.81373 to 0.88235, saving model to ModelDir\MobileNetV2\TransferLearning
INFO:tensorflow:Assets written to: ModelDir\MobileNetV2\TransferLearning\assets
Epoch 7/25
Epoch 00007: val_accuracy did not improve from 0.88235
Epoch 8/25
Epoch 00008: val_a

KeyboardInterrupt: 