In [1]:
import zipfile
import os

In [2]:
# Locations of the zipped dataset and of the folder it is unpacked into.
dataset_zip = "C:\\Users\\Dhruv\\Downloads\\Dataset.zip"
dataset_dir = "C:\\Users\\Dhruv\\Downloads\\Dataset"
training_dir = os.path.join(dataset_dir, "training")
validation_dir = os.path.join(dataset_dir, "validation")

# Unpack the archive only when at least one of the two split folders is missing.
if os.path.exists(training_dir) and os.path.exists(validation_dir):
    print("Training and validation directories already exist. Skipping extraction.")
else:
    print("Training or validation directory does not exist. Extracting the dataset...")
    with zipfile.ZipFile(dataset_zip, mode="r") as archive:
        archive.extractall(path=dataset_dir)

Training and validation directories already exist. Skipping extraction.


In [3]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
     

In [4]:
!pip install openpyxl

Defaulting to user installation because normal site-packages is not writeable


DEPRECATION: Loading egg at c:\program files\python312\lib\site-packages\vboxapi-1.0-py3.12.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330

[notice] A new release of pip is available: 24.0 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
def plot_confusion_matrix(cm, class_names):
    """Render a confusion matrix as an annotated heatmap and show it.

    Args:
        cm: 2-D matrix of values to draw; every cell is annotated with the
            value formatted to two decimal places.
        class_names: Tick labels applied to both the x (predicted) and
            y (true) axes.
    """
    _, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt=".2f",
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_title('Normalized Confusion Matrix')
    plt.show()

def plot_roc_curves(y_true, y_pred, class_names):
    """Plot a ROC curve for every class on one shared figure and show it.

    Each class column is treated as its own binary problem: column ``i`` of
    ``y_true`` is scored against column ``i`` of ``y_pred``.

    Args:
        y_true: 2-D array indexed as ``[sample, class]`` holding the true
            labels per class.
        y_pred: 2-D array indexed as ``[sample, class]`` holding the
            predicted scores per class.
        class_names: Names used in the legend; position ``i`` labels
            column ``i``.
    """
    figure, axes = plt.subplots(figsize=(12, 8))
    for column, label in enumerate(class_names):
        false_pos_rate, true_pos_rate, _ = roc_curve(y_true[:, column], y_pred[:, column])
        area = auc(false_pos_rate, true_pos_rate)
        axes.plot(false_pos_rate, true_pos_rate, label=f'{label} (AUC = {area:.2f})')

    # Dashed diagonal drawn as a reference line from (0, 0) to (1, 1).
    axes.plot([0, 1], [0, 1], 'k--')
    axes.set_xlim([-0.01, 1.01])
    axes.set_ylim([-0.01, 1.01])
    axes.set_xlabel('False Positive Rate')
    axes.set_ylabel('True Positive Rate')
    axes.set_title('Multi-class ROC Curve')
    axes.legend(loc='lower right')
    plt.show()

if __name__ == "__main__":
    # End-to-end script: build the datasets, train a frozen-ResNet50 classifier,
    # evaluate it on the validation split and export per-image predictions.
    # Everything stays at module level so the resulting objects (model, history,
    # predictions, ...) remain available to later cells.

    # Set parameters
    base_dir = r"C:\Users\Dhruv\Downloads\Dataset"
    train_dir = os.path.join(base_dir, 'training')
    val_dir = os.path.join(base_dir, 'validation')
    image_size = (224, 224)
    batch_size = 32
    epochs = 10
    class_names = ['Angioectasia', 'Bleeding', 'Erosion', 'Erythema', 'Foreign Body', 'Lymphangiectasia', 'Normal', 'Polyp', 'Ulcer', 'Worms']
    class_indices = list(range(len(class_names)))

    # Create training and validation datasets
    train_dataset = tf.keras.utils.image_dataset_from_directory(
        train_dir,
        labels='inferred',
        label_mode='categorical',
        class_names=class_names,
        image_size=image_size,
        batch_size=batch_size,
        shuffle=True,
        seed=123
    )

    # shuffle=False keeps the validation order fixed, so predictions, labels
    # and file paths collected in separate passes line up index by index.
    val_dataset = tf.keras.utils.image_dataset_from_directory(
        val_dir,
        labels='inferred',
        label_mode='categorical',
        class_names=class_names,
        image_size=image_size,
        batch_size=batch_size,
        shuffle=False
    )

    # Capture the file list BEFORE applying prefetch(): the dataset returned by
    # a tf.data transformation no longer exposes the `file_paths` attribute.
    val_file_paths = list(val_dataset.file_paths)

    # Prefetching data for performance optimization
    AUTOTUNE = tf.data.AUTOTUNE
    train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
    val_dataset = val_dataset.prefetch(buffer_size=AUTOTUNE)

    # Calculate class weights to handle class imbalance.
    # NOTE: this makes one full pass over the training images just to read
    # their one-hot labels.
    train_labels = np.concatenate([y for x, y in train_dataset], axis=0)
    train_labels_numeric = np.argmax(train_labels, axis=1)
    present_classes = np.unique(train_labels_numeric)
    class_weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=train_labels_numeric
    )
    # Key each weight by its real class index. The weights are ordered like
    # `present_classes`, which skips any class absent from the training data,
    # so keying by position would shift weights onto the wrong classes.
    class_weights_dict = {
        int(class_index): float(weight)
        for class_index, weight in zip(present_classes, class_weights)
    }

    # Build the model
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(image_size[0], image_size[1], 3))
    base_model.trainable = False  # Freeze the base model

    inputs = tf.keras.Input(shape=(image_size[0], image_size[1], 3))
    x = preprocess_input(inputs)
    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu')(x)
    outputs = Dense(len(class_names), activation='softmax')(x)
    model = Model(inputs, outputs)

    # Compile the model
    model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

    # Train the model
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        class_weight=class_weights_dict
    )

    # Evaluate the model
    val_predictions = model.predict(val_dataset)
    y_true = np.concatenate([y for x, y in val_dataset], axis=0)
    y_pred_classes = np.argmax(val_predictions, axis=1)
    y_true_classes = np.argmax(y_true, axis=1)

    # Confusion matrix. Passing the full label list keeps the matrix aligned
    # with `class_names` even if a class never occurs in the validation data.
    cm = confusion_matrix(y_true_classes, y_pred_classes, labels=class_indices, normalize='true')
    plot_confusion_matrix(cm, class_names)

    # Classification report (same explicit label list, so `target_names`
    # always matches the number of reported classes).
    report = classification_report(
        y_true_classes,
        y_pred_classes,
        labels=class_indices,
        target_names=class_names
    )
    print("Classification Report:\n", report)

    # ROC Curves
    plot_roc_curves(y_true, val_predictions, class_names)

    # Save predictions to Excel: one row per validation image, in dataset order.
    results_df = pd.DataFrame({
        'Image_path': val_file_paths,
        'Actual_class': [class_names[i] for i in y_true_classes],
        'Predicted_class': [class_names[i] for i in y_pred_classes]
    })
    results_file_path = os.path.join(base_dir, "results.xlsx")
    results_df.to_excel(results_file_path, index=False)

Found 37607 files belonging to 10 classes.
Found 16132 files belonging to 10 classes.
Epoch 1/10
[1m 115/1176[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m22:55[0m 1s/step - accuracy: 0.2256 - loss: 2.3815