# Diabetic Retinopathy Detection using Transfer Learning

This notebook implements a deep-learning pipeline that classifies diabetic retinopathy from retinal fundus images. It uses transfer learning (a VGG16 backbone pretrained on ImageNet) and data augmentation.

## 1. Imports

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16, ResNet50
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns

# Report how many GPU devices TensorFlow can see (0 means it will run on CPU)
available_gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(available_gpus))

## 2. Preprocessing: Black Border Removal

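We define helper functions that crop away the black border around each fundus image, so that the model focuses on the retina itself. The preprocessing function then resizes the cropped image to 224×224 and applies VGG16's standard input normalization.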

In [None]:
def crop_image_from_gray(img, tol=7):
    """
    Crop the dark border surrounding a fundus image.

    A pixel counts as foreground when its intensity (the grayscale intensity
    for colour input) is strictly greater than ``tol``. The image is cropped to
    the rows and columns that contain at least one foreground pixel.

    Args:
        img: 2-D grayscale array ``(H, W)`` or 3-D colour array ``(H, W, C)``.
            Colour input is converted to grayscale with ``cv2.COLOR_RGB2GRAY``,
            i.e. its channels are treated as RGB.
        tol: Intensity threshold; pixels at or below it are treated as border.

    Returns:
        The cropped array. For colour input only the first three channels are
        returned. If no pixel exceeds ``tol`` (the image is too dark), the
        original image is returned unchanged, so callers never receive an empty
        array (which would break a subsequent resize).

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.
    """
    if img.ndim == 2:
        gray_img = img
    elif img.ndim == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        raise ValueError(f"Expected a 2-D or 3-D image, got ndim={img.ndim}")

    mask = gray_img > tol
    rows = mask.any(axis=1)
    cols = mask.any(axis=0)

    # Nothing brighter than the threshold: cropping would yield an empty array.
    if not rows.any():
        return img

    # np.ix_ selects the bounding rows/cols; a trailing channel axis is kept.
    cropped = img[np.ix_(rows, cols)]
    if img.ndim == 3:
        # Always return three channels (e.g. drop alpha from 4-channel input).
        cropped = cropped[:, :, :3]
    return cropped

def preprocess_input_image(img, target_size=(224, 224)):
    """
    Preprocessing function for ``ImageDataGenerator(preprocessing_function=...)``.

    Keras runs this after the generator has loaded, resized and augmented the
    image, and expects an array of the same shape back.

    Steps:
        1. Clip to [0, 255] and cast to ``uint8``. The generator supplies float
           arrays, and casting out-of-range floats to ``uint8`` wraps around.
        2. Crop the black border with ``crop_image_from_gray``.
        3. Resize back to ``target_size``, because cropping changes the spatial
           size.
        4. Apply VGG16 ImageNet normalisation
           (``tf.keras.applications.vgg16.preprocess_input``).

    No contrast enhancement (e.g. CLAHE) is applied. Step 4 already normalises
    the input for VGG16, so do not also set ``rescale=1./255`` on a generator
    that uses this function.

    Args:
        img: Image array, presumably RGB, channels-last, with values in
            [0, 255] as produced by ``ImageDataGenerator`` (TODO confirm for
            other loaders).
        target_size: ``(height, width)`` of the returned image. Keep it equal
            to the generator's ``target_size``.

    Returns:
        The VGG16-preprocessed image of spatial size ``target_size``.
    """
    # Clip before the cast: float -> uint8 of out-of-range values wraps around.
    img = np.clip(img, 0, 255).astype('uint8')
    img = crop_image_from_gray(img)

    # cv2.resize takes dsize as (width, height), the reverse of Keras' order.
    height, width = target_size
    img = cv2.resize(img, (width, height))
    return tf.keras.applications.vgg16.preprocess_input(img)


## 3. Data Loading & Augmentation

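We use `ImageDataGenerator` to load images from the `datasets/` directory, which must contain one sub-directory per class. 20% of the images are held out for validation. Random augmentation (rotation, shifts, shear, zoom and horizontal flips) is applied to the training subset only.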

In [None]:
BATCH_SIZE = 32
IMG_SIZE = (224, 224)  # must match the 224x224 size preprocess_input_image returns
DATA_DIR = 'datasets/' # Update this path if your data is elsewhere
VALIDATION_SPLIT = 0.2  # fraction of each class held out for validation

# VGG16's ImageNet weights expect `vgg16.preprocess_input` normalisation
# (RGB->BGR plus per-channel mean subtraction, no scaling), not a 1/255 rescale.
# `preprocess_input_image` crops the black border, resizes back to 224x224 and
# applies that normalisation, so no `rescale` is set on either generator.
# Keras runs `preprocessing_function` after resizing and augmentation.

# Data Augmentation for Training
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=VALIDATION_SPLIT,
    preprocessing_function=preprocess_input_image,
)

# Validation generator: same preprocessing, no augmentation. Its
# validation_split must equal the training one so that the 'training' and
# 'validation' subsets of DATA_DIR are complementary.
val_datagen = ImageDataGenerator(
    validation_split=VALIDATION_SPLIT,
    preprocessing_function=preprocess_input_image,
)

# Load Data (DATA_DIR must contain one sub-directory per class)
print("Loading Training Data:")
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    # Keep 'categorical' even for 2 classes: the model head is softmax
    # trained with categorical_crossentropy.
    class_mode='categorical',
    subset='training'
)

print("\nLoading Validation Data:")
validation_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False # keep file order so predictions align with .classes
)

## 4. Model Definition: Transfer Learning (VGG16)

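We load the VGG16 model pre-trained on ImageNet, without its fully-connected top, and freeze its layers. We then add a custom classification head: global average pooling, a 128-unit dense layer with dropout, and a softmax output layer.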

In [None]:
def build_model(num_classes, input_shape=(224, 224, 3), dense_units=128, dropout_rate=0.5):
    """
    Build a VGG16 transfer-learning classifier.

    The ImageNet-pretrained VGG16 convolutional base is loaded without its
    fully-connected top and frozen. Only the new head is trainable:
    GlobalAveragePooling2D -> Dense(dense_units, relu) -> Dropout(dropout_rate)
    -> Dense(num_classes, softmax).

    Args:
        num_classes: Number of output classes (width of the softmax layer).
        input_shape: ``(height, width, 3)`` input size. Defaults to VGG16's
            native 224x224; the ImageNet weights require 3 channels.
        dense_units: Width of the hidden Dense layer in the head.
        dropout_rate: Dropout rate applied after the hidden Dense layer.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freezing the base model marks all of its layers as non-trainable.
    base_model.trainable = False

    # Custom classification head
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(dense_units, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=base_model.input, outputs=predictions)

# One softmax output per class sub-directory found by the training generator
num_classes = len(train_generator.class_indices)
model = build_model(num_classes)

# Adam with a learning rate of 1e-4, categorical cross-entropy for one-hot labels
optimizer = Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

model.summary()

## 5. Training

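Train the model with early stopping to prevent overfitting. Training stops once the validation loss has not improved for 5 epochs, and the best weights are restored.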

In [None]:
epochs = 20 # Adjust as needed

# Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# steps_per_epoch / validation_steps are deliberately left unset. The directory
# generators are Keras Sequences, so fit() runs len(generator) ==
# ceil(samples / batch_size) steps and sees every image each epoch. Floor
# division (samples // BATCH_SIZE) would drop the final partial batch, and
# would be 0 (an error) for a subset with fewer than BATCH_SIZE images.
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=epochs,
    callbacks=[early_stopping]
)

## 6. Evaluation

Plot the training history, then evaluate the model on the validation set using a confusion matrix and ROC curves.

In [None]:
# Plot Accuracy and Loss
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.legend()
plt.title('Loss')
plt.show()

# Predict on the whole validation set. Model.predict's second positional
# parameter is batch_size (not steps), so no step count is passed. With a
# Sequence input, predict() iterates len(generator) batches and returns one row
# per validation image, in the same order as validation_generator.classes
# (the generator was created with shuffle=False).
Y_pred = model.predict(validation_generator)
y_pred = np.argmax(Y_pred, axis=1)
y_true = validation_generator.classes

# Class names ordered by their integer label, independent of dict ordering.
class_indices = train_generator.class_indices
target_names = sorted(class_indices, key=class_indices.get)
labels = np.arange(len(target_names))

# Confusion Matrix. Passing `labels` keeps it num_classes x num_classes even if
# a class is missing from y_true/y_pred, so it always lines up with target_names.
cm = confusion_matrix(y_true, y_pred, labels=labels)
print('Confusion Matrix')
print(cm)

print(classification_report(y_true, y_pred, labels=labels, target_names=target_names))

# Plot Confusion Matrix
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_names, yticklabels=target_names)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.title('Confusion Matrix')
plt.show()

# ROC Curves: one-vs-rest, using each class's softmax probability as the score
plt.figure(figsize=(8, 6))
for class_idx, class_name in enumerate(target_names):
    is_positive = (y_true == class_idx).astype(int)
    # A ROC curve needs both positive and negative samples for this class.
    if is_positive.all() or not is_positive.any():
        continue
    fpr, tpr, _ = roc_curve(is_positive, Y_pred[:, class_idx])
    plt.plot(fpr, tpr, label=f'{class_name} (AUC = {auc(fpr, tpr):.2f})')
plt.plot([0, 1], [0, 1], 'k--', label='Chance')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves (one-vs-rest)')
plt.legend(loc='lower right')
plt.show()
