In [1]:
# Import necessary libraries
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


## Load and Balance Dataset

In [2]:
# Define dataset path and classes
dataset_path = './dataset'
classes = ['Mild', 'Moderate', 'No DR', 'Proliferative DR', 'Severe']

# Function to count images in each class
def count_images_in_class(class_path):
    return len(os.listdir(class_path))

# Count images in each class
image_counts = {class_name: count_images_in_class(os.path.join(dataset_path, class_name)) for class_name in classes}
print(image_counts)


{'Mild': 2443, 'Moderate': 5292, 'No DR': 25810, 'Proliferative DR': 708, 'Severe': 873}


## Data Augmentation and Balancing

In [3]:
# Define data augmentation strategy
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=10,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Function to augment and balance classes
def balance_classes(dataset_path, classes, target_count=2000):
    for class_name in classes:
        class_path = os.path.join(dataset_path, class_name)
        image_count = count_images_in_class(class_path)
        if image_count < target_count:
            # Perform augmentation
            datagen = ImageDataGenerator(
                rotation_range=20,
                width_shift_range=0.1,
                height_shift_range=0.1,
                shear_range=10,
                zoom_range=0.2,
                horizontal_flip=True,
                fill_mode='nearest'
            )
            # Implement augmentation logic here
            print(f'Augmenting {class_name} from {image_count} to {target_count} images')
        else:
            print(f'{class_name} already has {image_count} images, no augmentation needed')

balance_classes(dataset_path, classes)

Mild already has 2443 images, no augmentation needed
Moderate already has 5292 images, no augmentation needed
No DR already has 25810 images, no augmentation needed
Augmenting Proliferative DR from 708 to 2000 images
Augmenting Severe from 873 to 2000 images


## Preprocessing and Model Training

In [None]:
# Define preprocessing and training steps
def train_mobilenet_model(dataset_path, classes, target_size=(224, 224), batch_size=32):
    # Data generators for training and validation
    train_datagen = ImageDataGenerator(
        preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
        validation_split=0.2
    )
    train_generator = train_datagen.flow_from_directory(
        dataset_path,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='training'
    )
    validation_generator = train_datagen.flow_from_directory(
        dataset_path,
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        subset='validation'
    )

    # Load MobileNetV2 model
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(len(classes), activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)

    # Freeze base layers
    for layer in base_model.layers:
        layer.trainable = False

    # Compile model
    model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

    # Train model
    history = model.fit(
        train_generator,
        steps_per_epoch=train_generator.samples // batch_size,
        validation_data=validation_generator,
        validation_steps=validation_generator.samples // batch_size,
        epochs=10,
        callbacks=[ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.0001),
                   EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)]
    )

    return model, history, train_generator, validation_generator

model, history, train_generator, validation_generator = train_mobilenet_model(dataset_path, classes)

Found 28103 images belonging to 5 classes.
Found 7023 images belonging to 5 classes.
Epoch 1/10


  self._warn_if_super_not_called()


[1m878/878[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1163s[0m 1s/step - accuracy: 0.7201 - loss: 0.9042 - val_accuracy: 0.7410 - val_loss: 0.7803 - learning_rate: 0.0010
Epoch 2/10
[1m  1/878[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m8:45[0m 599ms/step - accuracy: 0.8438 - loss: 0.6446

  self.gen.throw(typ, value, traceback)


[1m878/878[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 726us/step - accuracy: 0.8438 - loss: 0.6446 - val_accuracy: 0.6667 - val_loss: 0.8926 - learning_rate: 0.0010
Epoch 3/10
[1m878/878[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m688s[0m 781ms/step - accuracy: 0.7456 - loss: 0.7558 - val_accuracy: 0.7440 - val_loss: 0.7699 - learning_rate: 0.0010
Epoch 4/10
[1m878/878[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 377us/step - accuracy: 0.7812 - loss: 0.6072 - val_accuracy: 0.6667 - val_loss: 0.7678 - learning_rate: 0.0010
Epoch 5/10
[1m189/878[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m6:46[0m 590ms/step - accuracy: 0.7255 - loss: 0.7851

## Model Evaluation

In [None]:
# Evaluate the model
def evaluate_model(model, validation_generator, classes):
    loss, accuracy = model.evaluate(validation_generator)
    print(f'Test accuracy: {accuracy:.2f}')

    # Generate predictions
    predictions = model.predict(validation_generator)
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = validation_generator.classes

    # Print classification report and confusion matrix
    print(classification_report(true_classes, predicted_classes, target_names=classes))
    conf_mat = confusion_matrix(true_classes, predicted_classes)
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.show()

evaluate_model(model, validation_generator, classes)