In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV3Large
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Model
import numpy as np
from sklearn.utils import shuffle

In [2]:
# Data Augmentation for Training
train_datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)

In [3]:
# Data Generators
train_batches = train_datagen.flow_from_directory(
    'reduceddata/train',
    target_size=(224, 224),
    batch_size=20,
    class_mode='categorical'
)

Found 8755 images belonging to 5 classes.


In [4]:
valid_batches = ImageDataGenerator(rescale=1.0/255.0).flow_from_directory(
    'reduceddata/valid',
    target_size=(224, 224),
    batch_size=30,
    class_mode='categorical'
)

Found 1840 images belonging to 5 classes.


In [6]:
# Pretrained MobileNetV3Large backbone (ImageNet weights, original classifier
# top removed), followed by a new classification head.
base_model = MobileNetV3Large(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Both generators must map class folders to the same label indices, otherwise
# the one-hot validation labels would silently disagree with the training ones.
assert valid_batches.class_indices == train_batches.class_indices, \
    "train/valid class folders differ"

# Size the softmax layer from the data instead of hard-coding it, so the head
# always matches the number of class sub-directories the generator found.
num_classes = len(train_batches.class_indices)

x = base_model.output
x = GlobalAveragePooling2D()(x)  # collapse the spatial feature map to one vector per image
x = Dense(1024, activation='relu')(x)
x = BatchNormalization()(x)
predictions = Dense(num_classes, activation='softmax')(x)

In [7]:
# Wire the backbone's input to the new softmax head to get the trainable network.
new_model = Model(base_model.input, predictions)

In [8]:
# Freeze the whole pretrained backbone (every layer, not a subset): only the
# newly added head (Dense -> BatchNormalization -> Dense) will be trained.
# This runs before compile() so the trainable flags take effect.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

In [9]:
# Compile with Adam at a small learning rate (1e-4). The generators use
# class_mode='categorical' (one-hot labels), hence categorical cross-entropy.
new_model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

In [10]:
# Callbacks: keep the best model on disk and stop once validation loss stalls.
# `monitor='val_loss'` is ModelCheckpoint's default; it is spelled out so both
# callbacks visibly track the same metric.
checkpointer = ModelCheckpoint(filepath='drdetection.hdf5', monitor='val_loss', save_best_only=True)
# restore_best_weights=True: without it, an early stop leaves `new_model` with the
# weights of the last epoch (up to `patience` epochs past the best one), so the
# in-memory model would not match the best checkpoint saved above.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

In [None]:
# Train for up to 30 epochs. Each epoch runs len(generator) batches, which is
# one pass over each directory. Checkpointing and early stopping use val_loss.
history = new_model.fit(
    train_batches,
    epochs=30,
    steps_per_epoch=len(train_batches),
    validation_data=valid_batches,
    validation_steps=len(valid_batches),
    callbacks=[checkpointer, early_stopping],
    verbose=1,
)

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30