# Constants 

In [7]:
# Two-dimensional image size handed to the dataset loaders as `image_size`.
IMG_SIZE = (224, 224)

# Root folders of the training and test images.
train_dir = "../data/data_min_balanced_df/train/"
test_dir = "../data/data_min_balanced_df/test/"

In [9]:
import tensorflow as tf 

# Both splits are loaded with exactly the same settings; only the source
# directory differs, so the shared keyword arguments are written once.
dataset_options = dict(
    image_size=IMG_SIZE,
    label_mode="categorical",
    batch_size=32,
)

train_data = tf.keras.preprocessing.image_dataset_from_directory(
    directory=train_dir, **dataset_options)

test_data = tf.keras.preprocessing.image_dataset_from_directory(
    directory=test_dir, **dataset_options)

Found 11987 files belonging to 6 classes.
Found 1500 files belonging to 6 classes.


In [12]:
# Display the class-name lists of both splits side by side for comparison.
(train_data.class_names, test_data.class_names)

(['class_0', 'class_1', 'class_2', 'class_3', 'class_4', 'class_5'],
 ['class_0', 'class_1', 'class_2', 'class_3', 'class_4', 'class_5'])

In [13]:
# Pull a single batch from the training dataset to check the tensor shapes
# it yields (image batch first, label batch second).
for images, labels in train_data.take(1): 
    print(images.shape, labels.shape)

(32, 224, 224, 3) (32, 6)


# Model 0: EfficientNetB0

In [3]:
# Backbone: EfficientNetB0 without its classification head. It is frozen so
# that only the layers added below are trainable.
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model.trainable = False

# The input shape is derived from IMG_SIZE (the same constant the dataset
# loaders use) instead of repeating the literal 224s, so the model input and
# the data pipeline cannot drift apart. The trailing 3 is the channel count.
inputs = tf.keras.layers.Input(shape=(*IMG_SIZE, 3), name="input_layer")
x = base_model(inputs)
# Collapse the spatial feature map into one feature vector per image.
x = tf.keras.layers.GlobalAveragePooling2D(name='global_average_pooling_layer')(x)
# One softmax unit per class; 6 matches the number of classes in the datasets.
outputs = tf.keras.layers.Dense(6, activation="softmax", name="output_layer")(x)

model_0 = tf.keras.Model(inputs, outputs)

In [4]:
# Compile for single-label, multi-class classification: the labels are one-hot
# vectors and the output layer is a softmax, hence categorical cross-entropy.
model_0.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[
        # CategoricalAccuracy checks argmax(prediction) == argmax(label) once
        # per sample. BinaryAccuracy would instead threshold every one of the
        # six outputs at 0.5 and average element-wise matches, which badly
        # overstates accuracy on one-hot targets (predicting all zeros already
        # scores 5/6).
        tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)


In [5]:
# Checkpoint file pattern; `{epoch:01d}` is filled in with the epoch number,
# so every save goes to its own file.
checkpoint_path = "checkpoints/MODEL_0/checkpoint-{epoch:01d}.ckpt"

# Save weights only (not the full model), do not restrict saving to the best
# result so far, and log a message on each save.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=False,
    verbose=1,
)