In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Dropout, Input, Conv2D, BatchNormalization, MaxPooling2D, Activation, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import callbacks
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix


In [None]:
# Dataset roots, relative to the notebook; each is expected to contain one sub-folder per class.
train_dir, test_dir = "Training", "Testing"

In [None]:
def list_folders(path):
    """Return the names of the immediate sub-directories of ``path``, sorted alphabetically.

    Plain files are skipped, and the names are returned without the ``path`` prefix.
    """
    return sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


In [None]:
train_classes = list_folders(train_dir)
test_classes = list_folders(test_dir)

print("Training ")
print(f"Found {len(train_classes)} classes in training data:")
print('\n'.join(train_classes))

print("\nTesting")
print(f"Found {len(test_classes)} classes in testing data:")
print('\n'.join(test_classes))

# flow_from_directory assigns label indices from each directory's own sorted
# sub-folder names, so both splits must contain exactly the same class folders;
# otherwise test labels would silently not line up with the model's outputs.
if train_classes != test_classes:
    raise ValueError(
        "Training and testing class folders differ: "
        f"only in training={sorted(set(train_classes) - set(test_classes))}, "
        f"only in testing={sorted(set(test_classes) - set(train_classes))}"
    )

In [None]:
# Target (height, width) that every image is resized to when loaded.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
# Maximum number of training epochs.
EPOCHS = 50
# Seed Python's `random`, NumPy and TensorFlow in one call so that batch shuffling,
# flip augmentation and weight initialisation are repeatable on Restart & Run All
# (GPU kernels may still introduce some non-determinism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)


In [None]:
# Training pipeline: rescale pixel values by 1/255 (0-255 -> 0-1 for 8-bit images)
# and augment with random horizontal and vertical flips.
train_datagen = ImageDataGenerator(rescale=1 / 255, horizontal_flip=True, vertical_flip=True)
# Test pipeline: identical rescaling, no augmentation.
test_datagen = ImageDataGenerator(rescale=1 / 255)

In [None]:
# Settings shared by the training and test directory iterators.
flow_kwargs = dict(target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='categorical')

train_generator = train_datagen.flow_from_directory(train_dir, **flow_kwargs)
# shuffle=False keeps the test images in a fixed file order, so
# test_generator.classes lines up row-for-row with predictions on test_generator.
test_generator = test_datagen.flow_from_directory(test_dir, shuffle=False, **flow_kwargs)

In [None]:
# One batch from the (augmented) training stream, to sanity-check inputs and labels.
images, labels = next(train_generator)

print("Batch shape:", images.shape)

# Invert class_indices so plot titles show folder names instead of bare indices.
index_to_class = {index: name for name, index in train_generator.class_indices.items()}

plt.figure(figsize=(12, 6))
# Never ask for more images than the batch holds (the 2x3 grid fits at most 6).
num_images = min(6, len(images))
# Sample without replacement so the same image is not shown twice.
sample_idx = np.random.choice(len(images), size=num_images, replace=False)
for i, rand_idx in enumerate(sample_idx):
    plt.subplot(2, 3, i + 1)
    plt.imshow(images[rand_idx])
    plt.axis("off")
    plt.title(f"Class: {index_to_class[int(np.argmax(labels[rand_idx]))]}")
plt.tight_layout()
plt.show()

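## CNN model

A ResNet-style classifier built from bottleneck residual blocks (1×1 → 3×3 → 1×1 convolutions with a skip connection).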

In [None]:
def residual_block(x, filters, stride=1, is_conv_block=False):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    Args:
        x: input tensor.
        filters: three filter counts ``(f1, f2, f3)`` for the first 1x1, the 3x3 and
            the final 1x1 convolution.
        stride: stride of the first 1x1 convolution and of the shortcut projection.
        is_conv_block: if True, the shortcut is projected with a strided 1x1 conv +
            BatchNorm so it matches the main path. If False, the input is added
            unchanged, which only works when ``stride == 1`` and ``x`` already has
            ``f3`` channels (``Add`` needs equal shapes).

    Returns:
        ReLU(main_path(x) + shortcut(x)).
    """
    def conv_bn(t, n_filters, kernel, strides, padding, relu):
        # Conv2D -> BatchNorm, optionally followed by a ReLU.
        t = Conv2D(n_filters, kernel_size=kernel, strides=strides, padding=padding)(t)
        t = BatchNormalization()(t)
        return Activation('relu')(t) if relu else t

    reduce_filters, mid_filters, out_filters = filters
    shortcut = x

    # Main path; layers are created in the same order as the shortcut comes last.
    out = conv_bn(x, reduce_filters, 1, stride, 'valid', relu=True)
    out = conv_bn(out, mid_filters, 3, 1, 'same', relu=True)
    out = conv_bn(out, out_filters, 1, 1, 'valid', relu=False)

    if is_conv_block:
        shortcut = conv_bn(shortcut, out_filters, 1, stride, 'valid', relu=False)

    out = Add()([out, shortcut])
    return Activation('relu')(out)

def custom_CNN(input_shape=(224, 224, 3), num_classes=4):
    """Build a ResNet-style image classifier from bottleneck residual blocks.

    Layout:
        * stem: 7x7 stride-2 conv (64 filters) + BatchNorm + ReLU, then 3x3 stride-2 max-pool;
        * three residual stages of 4, 3 and 3 blocks with 256, 512 and 1024 output
          channels (the 2nd and 3rd stages start with a stride-2 block);
        * head: global average pooling, Dense(256, ReLU), Dropout(0.5), softmax.

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
        num_classes: number of softmax output units.

    Returns:
        An uncompiled Keras ``Model``.
    """
    # (filters, stride of the stage's first block, number of blocks) per stage.
    stage_specs = (
        ([64, 64, 256], 1, 4),
        ([128, 128, 512], 2, 3),
        ([256, 256, 1024], 2, 3),
    )

    inputs = Input(shape=input_shape)

    # Stem.
    x = Conv2D(64, kernel_size=7, strides=2, padding='same')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D(pool_size=3, strides=2, padding='same')(x)

    for stage_filters, first_stride, n_blocks in stage_specs:
        # The first block of each stage projects the shortcut (channel change / downsampling).
        x = residual_block(x, filters=stage_filters, stride=first_stride, is_conv_block=True)
        for _ in range(n_blocks - 1):
            x = residual_block(x, filters=stage_filters)

    # Classification head.
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=inputs, outputs=outputs)

# Take the input shape and class count from the training iterator instead of
# custom_CNN's hard-coded defaults, so the softmax width always equals the number
# of class folders found (and the input matches IMG_SIZE and the colour mode).
cnn = custom_CNN(
    input_shape=train_generator.image_shape,
    num_classes=train_generator.num_classes,
)


cnn.compile(
    optimizer=Adam(learning_rate=0.0001),
    # One-hot targets come from class_mode='categorical' in the iterators.
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
cnn.summary()

In [None]:
import visualkeras

# Layered architecture diagram of the model; legend=True adds a key for the layer types.
architecture_diagram = visualkeras.layered_view(cnn, legend=True)
architecture_diagram

In [None]:
# Save the model to disk whenever validation accuracy reaches a new maximum.
checkpoint_cb = ModelCheckpoint(
    'cnn_model.keras', monitor='val_accuracy', mode='max', save_best_only=True, verbose=1
)

# Stop after 15 epochs without val_accuracy improvement; restore_best_weights
# brings back the weights from the best epoch.
early_stopping_cb = EarlyStopping(
    monitor='val_accuracy', patience=15, restore_best_weights=True, verbose=1
)

# Multiply the learning rate by 0.3 after 7 stagnant epochs, never going below 1e-6.
reduce_lr_cb = ReduceLROnPlateau(
    monitor='val_accuracy', factor=0.3, patience=7, min_lr=1e-6, verbose=1
)

In [None]:
# NOTE(review): test_generator doubles as the validation set, so checkpointing,
# early stopping and LR reduction all select on test data, and the test metrics
# reported later are optimistically biased. Consider holding out part of the
# training data instead (e.g. ImageDataGenerator(validation_split=...) with subset=).
history = cnn.fit(
    train_generator,
    epochs=EPOCHS,  # config constant, not a duplicated literal
    validation_data=test_generator,
    callbacks=[checkpoint_cb, early_stopping_cb, reduce_lr_cb]
)

In [None]:
# Per-epoch learning curves recorded by cnn.fit (x-axis is the 0-based epoch index).
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

acc_ax.plot(history.history['accuracy'], label='Training Accuracy')
acc_ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
acc_ax.set(title='Training and Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')
acc_ax.legend()

loss_ax.plot(history.history['loss'], label='Training Loss')
loss_ax.plot(history.history['val_loss'], label='Validation Loss')
loss_ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

plt.show()

In [None]:
# Loss and accuracy on the test split using the model's current weights.
# `accuracy` is reused by the misclassification-rate cell below.
loss, accuracy = cnn.evaluate(test_generator)

print(f"Test Accuracy: {100 * accuracy:.2f}%")
print(f"Test Loss: {loss:.4f}")


In [None]:
# Fraction of test images classified incorrectly: the complement of test accuracy.
misclassification_rate = 1.0 - accuracy
print("Misclassification Rate: {:.3f}".format(misclassification_rate))

In [None]:
y_true = test_generator.classes
# Order class names by their integer index so position i is the name of class i
# (and of column i of the prediction matrix).
class_indices = test_generator.class_indices
class_labels = sorted(class_indices, key=class_indices.get)
label_ids = list(range(len(class_labels)))

# Rows follow the file order of y_true because test_generator uses shuffle=False.
y_pred_probs = cnn.predict(test_generator)
y_pred = np.argmax(y_pred_probs, axis=1)

# Pass all labels explicitly so there is one row/column per class even when a
# class never occurs in y_true or y_pred; otherwise the matrix shrinks and no
# longer matches the tick labels (and target_names in the report).
conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

report = classification_report(
    y_true, y_pred, labels=label_ids, target_names=class_labels, zero_division=0
)
print("Classification Report:\n", report)
