# Import libs

In [None]:
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from efficientnet.tfkeras import EfficientNetB0
import efficientnet.tfkeras as efc

In [12]:
SEED = 24


def set_seed(seed=SEED):
    """Seed the global RNGs of Python's `random`, NumPy and TensorFlow.

    Args:
        seed: integer seed applied to all three generators (defaults to SEED).
    """
    # Same order as: random -> numpy -> tensorflow.
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


set_seed()

# Load data

In [13]:
PATH = './Dataset'

IMAGE_SIZE = (224, 224)
EPOCHS = 70
BATCH_SIZE = 128
# Must equal the number of class sub-folders under PATH: flow_from_directory
# finds 2 classes in this dataset, and a softmax head of a different width
# cannot be trained against its one-hot labels with categorical_crossentropy.
NUM_CLASSES = 2

In [14]:
# Fraction of each class held out for validation. Both generators must use the
# same value so that the 'training' and 'validation' subsets taken from PATH
# are complementary (no overlap, no unused images).
VALIDATION_SPLIT = 0.2

# Random augmentation is applied to the training subset only.
train_gen = ImageDataGenerator(
    horizontal_flip=True,
    zoom_range=0.3,
    rotation_range=45,
    validation_split=VALIDATION_SPLIT,
)

# Validation images are left un-augmented; only the split fraction is shared.
val_gen = ImageDataGenerator(validation_split=VALIDATION_SPLIT)

In [15]:
# Settings shared by both subsets of the same directory.
_flow_kwargs = dict(
    directory=PATH,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    seed=SEED,
)

# Augmented, shuffled batches from the 'training' subset.
train_data = train_gen.flow_from_directory(subset='training', **_flow_kwargs)

# Un-augmented batches from the 'validation' subset, in a fixed order.
val_data = val_gen.flow_from_directory(
    subset='validation',
    shuffle=False,
    **_flow_kwargs,
)

Found 151 images belonging to 2 classes.
Found 37 images belonging to 2 classes.


# Define model

**EfficientNetB0** performs well on image classification compared with other pretrained models, so it is used here as the backbone.

In [None]:
# Pretrained EfficientNetB0 backbone without its ImageNet classification head.
# Keras' EfficientNet models do their own input rescaling and expect raw
# [0, 255] pixels, which is why the generators above do not rescale.
base_model = tf.keras.applications.EfficientNetB0(
    weights='imagenet',  # the Keras default, spelled out for clarity
    include_top=False,
)

In [None]:
# Backbone -> spatial average pooling -> softmax classifier.
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    # Width is taken from the classes flow_from_directory actually found, so
    # it always matches the one-hot labels fed to categorical_crossentropy.
    tf.keras.layers.Dense(train_data.num_classes, activation='softmax'),
])

The callbacks below guard against overfitting. Training stops once validation loss stops improving, and the learning rate is reduced when progress plateaus. The model with the lowest validation loss is saved to `mask_model.h5` so it can be reused in other tasks.

In [None]:
# All three callbacks watch validation loss ('val_loss' is the Keras default,
# made explicit here).

# Stop after 5 epochs without improvement and roll the in-memory model back to
# its best epoch; without restore_best_weights the model left after fit() holds
# the weights of the last epoch rather than the best one.
early_stopping = EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True, verbose=1
)
# Save the model to disk each time val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    'mask_model.h5', monitor='val_loss', save_best_only=True, verbose=1
)
# Lower the learning rate after 2 stagnant epochs, before early stopping fires.
lr_reduce = ReduceLROnPlateau(monitor='val_loss', patience=2, verbose=1)

In [None]:
training_callbacks = [early_stopping, checkpoint, lr_reduce]

model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(amsgrad=True),  # AMSGrad variant of Adam
    # Plain accuracy is only a fair summary when classes are roughly balanced;
    # the author states they are. NOTE(review): confirm per-class image counts.
    metrics=['acc'],
)

# Train on augmented batches, evaluating on the held-out subset each epoch.
history = model.fit(
    train_data,
    epochs=EPOCHS,
    validation_data=val_data,
    callbacks=training_callbacks,
)

In [None]:
# Training vs. validation loss per epoch, to spot overfitting.
fig, ax = plt.subplots()
for metric_key in ('loss', 'val_loss'):
    ax.plot(history.history[metric_key])
ax.set_title('model loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend(['train_loss', 'val_loss'], loc='upper right')
plt.show()