#### Import Libraries

In [None]:
import tensorflow as tf
from tensorflow.data import AUTOTUNE
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D, Dropout, Dense, RandomFlip, RandomRotation, RandomZoom, Rescaling, Conv2D, MaxPooling2D, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNetV3Small
import matplotlib.pyplot as plt

In [None]:
# Input-pipeline settings. Every image is resized to IMG_HEIGHT x IMG_WIDTH
# and grouped into batches of BATCH_SIZE. DATA_DIRECTORY is the root folder
# handed to image_dataset_from_directory.
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 32
DATA_DIRECTORY = './img/'

#### Prepare Image Data

In [None]:
# 80% -> training, 20% -> validation. Both subsets are read with the same
# seed and validation_split so that Keras partitions the files consistently.
def _load_subset(subset):
  """Return the `subset` ("training" or "validation") split of DATA_DIRECTORY."""
  return image_dataset_from_directory(
    DATA_DIRECTORY,
    shuffle=True,
    validation_split=0.2,
    subset=subset,
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE)

train_ds = _load_subset("training")
val_ds = _load_subset("validation")

# Class names come from the training subset; index i is label i.
class_names = train_ds.class_names
num_classes = len(class_names)
print(class_names)

#### Visualize Image Data

In [None]:
# Draw a 3x3 grid with the first nine images of one training batch, each
# titled with its class name.
plt.figure(figsize=(10, 10))
for batch_images, batch_labels in train_ds.take(1):
  for idx in range(9):
    plt.subplot(3, 3, idx + 1)
    pixels = batch_images[idx].numpy().astype("uint8")
    plt.imshow(pixels)
    plt.title(class_names[batch_labels[idx]])
    plt.axis("off")

#### Create Test Dataset [Optional]

In [None]:
# Carve a test set out of the validation data: the first fifth of the
# validation batches becomes test_ds, and those batches are REMOVED from
# val_ds so the two splits are not the same data.
val_batches = tf.data.experimental.cardinality(val_ds)
test_batches = val_batches // 5
if test_batches == 0:
  # take(0) would yield an empty test set and model.evaluate() would fail later.
  raise ValueError(
      'Validation set has fewer than 5 batches; too small to carve out a test set.')
test_ds = val_ds.take(test_batches)
val_ds = val_ds.skip(test_batches)
# NOTE(review): val_ds was built with shuffle=True, and tf.data shuffles are
# reshuffled on every iteration by default, so take/skip may not select the
# same batches on every pass. The cache() calls in the next cell pin each split
# after its first pass, but those first passes can still overlap. Confirm this,
# or build the split from a fixed order if strict separation matters.

print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))

#### Iterate Dataset & Retrieve Batches

In [None]:
# for image_batch, labels_batch in train_ds:
#   print(image_batch.shape)
#   print(labels_batch.shape)
#   break

#### Configure Dataset For Performance

In [None]:
# Cache decoded batches after the first pass and prefetch so input loading
# overlaps training. Only the training split gets the extra shuffle; the
# datasets are already batched, so shuffle(1000) reorders whole batches.
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds, test_ds = (ds.cache().prefetch(buffer_size=AUTOTUNE) for ds in (val_ds, test_ds))

#### Data Augmentation

In [None]:
# Random horizontal flips, rotations and zooms applied to input images.
# The first layer declares the input shape so the sub-model is built up front.
_augmentation_layers = [
    RandomFlip("horizontal", input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    RandomRotation(0.2),
    RandomZoom(0.2),
]
data_augmentation = Sequential(_augmentation_layers)

# plt.figure(figsize=(10, 10))
# for images, _ in train_ds.take(1):
#   for i in range(9):
#     augmented_images = data_augmentation(images)
#     ax = plt.subplot(3, 3, i + 1)
#     plt.imshow(augmented_images[0].numpy().astype("uint8"))
#     plt.axis("off")

#### Building The Model [CNN Architecture]

In [None]:
# CNN Architecture: augmentation -> [0, 1] rescaling -> three conv/pool stages
# (16, 32, 64 filters) -> dropout -> dense head producing one logit per class.
cnn_layers = [data_augmentation, Rescaling(1./255)]
for filters in (16, 32, 64):
  cnn_layers += [Conv2D(filters, 3, padding='same', activation='relu'), MaxPooling2D()]
cnn_layers += [Dropout(0.2), Flatten(), Dense(128, activation='relu'), Dense(num_classes)]
model = Sequential(cnn_layers)

#### Building The Model [MobileNetV3 Transfer Learning]

In [None]:
# Transfer Learning, MobileNetV3Small: an ImageNet-pretrained backbone without
# its classifier, plus a new pooling/dropout/dense head for num_classes.
base_model = MobileNetV3Small(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), include_top=False, weights='imagenet')
# Freeze the backbone before building the model; initially only the head trains.
base_model.trainable = False

global_average_layer = GlobalAveragePooling2D()
# One raw logit per class (no activation); the loss uses from_logits=True.
prediction_layer = Dense(num_classes)

inputs = tf.keras.Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
x = data_augmentation(inputs)
# Keras documents this as a pass-through for MobileNetV3 (input rescaling is
# built into the model); kept so swapping backbones stays a one-line change.
x = tf.keras.applications.mobilenet_v3.preprocess_input(x)
# training=False keeps the backbone's BatchNorm layers in inference mode, even
# after the backbone is unfrozen for fine-tuning.
x = base_model(x, training=False)
x = global_average_layer(x)
x = Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
# model.summary()

#### Train The Model

In [None]:
# Save the weights with the best val_accuracy, and stop (restoring the best
# weights) once val_loss has not improved for 10 consecutive epochs.
checkpoint_filepath = './tmp/'
model_checkpoint_callback = ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only=True, monitor='val_accuracy', mode='max', save_best_only=True)
early_stopping_monitor = EarlyStopping(monitor='val_loss', patience=10, mode='min', restore_best_weights=True)

EPOCHS = 100  # upper bound; EarlyStopping may end training sooner
LEARNING_RATE = 0.0001

model.compile(optimizer=Adam(learning_rate=LEARNING_RATE), loss=SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
# train_ds/val_ds are already batched tf.data datasets, so batch_size must not
# be passed to fit().
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[early_stopping_monitor, model_checkpoint_callback])

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Early stopping can end training before EPOCHS, so size the x-axis from the
# recorded history. range(EPOCHS) would not match the curves' length.
epochs_range = range(len(acc))

#### Visualize Training Results

In [None]:
# Side-by-side panels: accuracy curves on the left, loss curves on the right.
plt.figure(figsize=(8, 8))
panels = [
    (acc, val_acc, 'Training Accuracy', 'Validation Accuracy', 'lower right', 'Training and Validation Accuracy'),
    (loss, val_loss, 'Training Loss', 'Validation Loss', 'upper right', 'Training and Validation Loss'),
]
for position, (train_curve, val_curve, train_label, val_label, legend_loc, panel_title) in enumerate(panels, start=1):
  plt.subplot(1, 2, position)
  plt.plot(epochs_range, train_curve, label=train_label)
  plt.plot(epochs_range, val_curve, label=val_label)
  plt.legend(loc=legend_loc)
  plt.title(panel_title)
plt.show()

#### Test & Evaluate Initial Model

In [None]:
# Restore the weights saved by the ModelCheckpoint callback (best
# val_accuracy), then score the held-out test batches.
model.load_weights(checkpoint_filepath)
test_metrics = model.evaluate(test_ds)
loss, accuracy = test_metrics
print('Test accuracy :', accuracy)

#### Fine Tuning [Optional]

In [None]:
# Fine-tuning: unfreeze the backbone except its first FINE_TUNE_AT layers and
# continue from the best checkpoint with a smaller learning rate.
FINE_TUNE_AT = 10
FINE_TUNE_EPOCHS = 10
FINE_TUNE_LEARNING_RATE = 0.00001

base_model.trainable = True
print("Number of layers in the base model: ", len(base_model.layers))

# NOTE(review): only the first FINE_TUNE_AT layers stay frozen. If the printed
# layer count is much larger, almost the whole backbone is being fine-tuned.
# Confirm this is intended.
for layer in base_model.layers[:FINE_TUNE_AT]:
  layer.trainable = False

# Start from the best weights of the initial training run.
model.load_weights(checkpoint_filepath)
# Recompile so the trainability changes take effect.
model.compile(optimizer=Adam(learning_rate=FINE_TUNE_LEARNING_RATE), loss=SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
# Datasets are already batched, so batch_size is not passed to fit().
history = model.fit(train_ds, validation_data=val_ds, epochs=FINE_TUNE_EPOCHS, callbacks=[early_stopping_monitor, model_checkpoint_callback])

#### Test & Evaluate Fine-Tuned Model

In [None]:
# Re-declared here with the same value as in the training cell; reload the
# weights stored there and re-score the held-out test batches.
checkpoint_filepath = './tmp/'
model.load_weights(checkpoint_filepath)
evaluation = model.evaluate(test_ds)
loss, accuracy = evaluation
print('Test accuracy :', accuracy)

#### Save Model & Labels

In [None]:
# Load Best Model Weights, then export the full model to model.h5 and the class
# names to labels.txt (comma-separated, in label-index order).
model.load_weights(checkpoint_filepath)
model.save('model.h5')
# Write UTF-8 explicitly so non-ASCII class (directory) names are written
# correctly regardless of the platform's default encoding.
with open('labels.txt', 'w', encoding='utf-8') as f:
  f.write(', '.join(str(name) for name in class_names))