In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from keras.applications import EfficientNetV2B0
from sklearn.metrics import confusion_matrix

In [None]:
# Sanity check before a long training run: how many GPUs can TensorFlow see?
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

In [None]:
# Load labelled images from one directory (class inferred from the sub-folder
# name) and split them 80/20 into training and validation subsets.
training_data_path = "./train_dataset/"

# Fixed seed. Both calls below must share the SAME seed so the two subsets are
# complementary, and it must not change between runs: otherwise every run gets a
# different split, results are not comparable, and a "./best_model.keras" saved
# in an earlier run could be evaluated on images it was trained on.
seed = 1337

# Options shared by both subsets; only `subset` differs between the two calls.
dataset_options = dict(
    labels='inferred',
    label_mode='categorical',  # one-hot labels, matching categorical_crossentropy below
    batch_size=32,
    image_size=(256, 256),
    validation_split=0.2,
    seed=seed,
)

train_data = tf.keras.utils.image_dataset_from_directory(
    training_data_path, subset="training", **dataset_options
)
validation_data = tf.keras.utils.image_dataset_from_directory(
    training_data_path, subset="validation", **dataset_options
)

# The one-hot label index is a position in `class_names`, so both subsets must
# agree on the class list and its order.
assert train_data.class_names == validation_data.class_names

In [None]:
# Random train-time augmentation applied to the raw input images.
augmentation_layers = [
    tf.keras.layers.RandomFlip("horizontal"),
    # factor is a fraction of a full turn: 0.08 * 360 degrees ~= +-29 degrees
    tf.keras.layers.RandomRotation(0.08),
    tf.keras.layers.RandomTranslation(0.1, 0.1, fill_mode="nearest"),
    tf.keras.layers.RandomZoom(0.1, fill_mode="nearest"),
    tf.keras.layers.RandomContrast(0.15),
    tf.keras.layers.RandomBrightness(0.15),
    # NOTE(review): the images presumably arrive un-rescaled (0-255 pixel values),
    # so a stddev of 0.2 is a very small perturbation -- confirm the intended scale.
    tf.keras.layers.GaussianNoise(0.2),
]
augmentation = tf.keras.Sequential(augmentation_layers)

In [None]:
# Transfer learning: a frozen ImageNet-pretrained EfficientNetV2-B0 backbone with
# a new classification head. tf.keras is used throughout (instead of mixing in the
# standalone `keras` package) so all layers come from one Keras implementation.
base_model = tf.keras.applications.EfficientNetV2B0(weights="imagenet", include_top=False)

# Freeze the whole backbone for this training phase. To fine-tune afterwards,
# set base_model.trainable = True (optionally re-freezing early layers such as
# base_model.layers[:80]) and recompile with a lower learning rate.
base_model.trainable = False

# Size the head from the data rather than hard-coding it, so the output layer
# always matches the number of class folders found by the dataset loader.
num_classes = len(train_data.class_names)

inputs = tf.keras.layers.Input(shape=(256, 256, 3))
x = augmentation(inputs)
x = tf.keras.applications.efficientnet_v2.preprocess_input(x)
# training=False keeps the backbone's BatchNormalization layers in inference mode,
# which also stays correct if the backbone is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.5)(x)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()

In [None]:
# Training callbacks; every one of them watches the validation loss.

# Write the model to disk whenever val_loss reaches a new minimum.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="./best_model.keras",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    restore_best_weights=True,
    patience=10,
)

# Halve the learning rate after 3 stagnant epochs, never going below 1e-6.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=3,
    min_delta=1e-4,
    min_lr=1e-6,
    verbose=1,
)

callbacks = [checkpoint, early_stop, reduce_lr]

In [None]:
# Train the model. `epochs` is only an upper bound: the EarlyStopping callback
# ends the run once val_loss stops improving.
history = model.fit(
    x=train_data,
    validation_data=validation_data,
    callbacks=callbacks,
    epochs=300,
    verbose=1,
)

In [None]:
# Learning curves: training vs. validation value of each tracked metric per epoch.
fig, axes = plt.subplots(1, 2, figsize=(15, 8))
for ax, metric in zip(axes, ["loss", "accuracy"]):
    ax.plot(history.history[metric], label=metric)
    ax.plot(history.history[f"val_{metric}"], label=f"val_{metric}")
    ax.set(title=f"Training vs. validation {metric}", xlabel="epochs", ylabel=metric)
    ax.grid(True)
    ax.legend()
plt.show()

In [None]:
# Evaluate the best checkpoint written by ModelCheckpoint on the validation split.
model = tf.keras.models.load_model('./best_model.keras')

num_classes = len(validation_data.class_names)
# conf_matrix[i, j] counts validation images of true class i predicted as class j.
conf_matrix = np.zeros((num_classes, num_classes), dtype=int)

y_true = []
y_pred = []

for images, labels in validation_data:
    # One forward pass per batch, without the per-call overhead and progress
    # bar that calling model.predict() once per batch incurs.
    batch_probs = model.predict_on_batch(images)
    y_true_batch = np.argmax(labels.numpy(), axis=1)
    y_pred_batch = np.argmax(batch_probs, axis=1)
    y_true.extend(y_true_batch)
    y_pred.extend(y_pred_batch)
    # Unbuffered scatter-add, so repeated (true, pred) pairs in a batch all count.
    np.add.at(conf_matrix, (y_true_batch, y_pred_batch), 1)

In [None]:
# Plot the validation confusion matrix. Reuses the y_true / y_pred lists collected
# in the previous cell instead of running inference over the validation set again.
# (confusion_matrix and sns are imported in the top import cell.)
class_names = validation_data.class_names

# Pass `labels` explicitly so the matrix is always len(class_names) square. Without
# it, sklearn only uses classes that occur in y_true or y_pred, so a class that
# never appears would drop a row/column and misalign the tick labels below.
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(title='Confusion Matrix', xlabel='Predicted label', ylabel='True label')
plt.show()