# Diabetic Retinopathy Detection using ResNet with Attention Mechanism

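This notebook implements a deep learning model for diabetic retinopathy detection using a ResNet-50 backbone with CBAM (Convolutional Block Attention Module) attention. The dataset contains 5 classes of diabetic retinopathy severity levels and is strongly imbalanced (No_DR has roughly nine times as many images as Severe), which is why training uses a class-weighted loss.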

## Dataset Classes:
- **No_DR**: No diabetic retinopathy (1805 images)
- **Mild**: Mild diabetic retinopathy (370 images)
- **Moderate**: Moderate diabetic retinopathy (999 images)
- **Severe**: Severe diabetic retinopathy (193 images)
- **Proliferate_DR**: Proliferative diabetic retinopathy (295 images)


## 1. Import Required Libraries


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings('ignore')

# Deep Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import torch.nn.functional as F

# Choose where tensors live: the default CUDA GPU when one is visible, else the CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
print(f'Using device: {device}')

# Seed the torch and NumPy generators (and CUDA's when present) so runs are
# more repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)
if _cuda_available:
    torch.cuda.manual_seed(42)


## 2. Dataset Configuration and Data Loading


In [None]:
# Dataset configuration
# DATA_DIR may be overridden with the DR_DATA_DIR environment variable so the
# notebook can run on other machines; without it the original path is used.
DATA_DIR = os.environ.get(
    'DR_DATA_DIR',
    '/Users/landaganesh/Documents/Projects /Miniproject/colored_images',
)
# One sub-folder of DATA_DIR per class; a name's position in this list is the
# integer label assigned to its images.
CLASSES = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR']
NUM_CLASSES = len(CLASSES)
IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001

print(f"Number of classes: {NUM_CLASSES}")
print(f"Classes: {CLASSES}")


In [None]:
# Create dataset mapping
def create_dataset_mapping(data_dir, classes):
    """
    Build a DataFrame that indexes every image under ``data_dir``.

    Each entry of ``classes`` is looked up as a sub-directory of ``data_dir``;
    the images inside it get the entry's position in ``classes`` as label.

    Args:
        data_dir: Root directory containing one folder per class.
        classes: Ordered class/folder names (list position == integer label).

    Returns:
        pandas.DataFrame with columns ``image_path``, ``label`` and
        ``class_name``; one row per ``.png``/``.jpg``/``.jpeg`` file (extension
        matched case-insensitively). Rows are ordered by class, then by file
        name, so the result is identical across file systems. Class folders
        that do not exist are skipped with a printed warning. The columns are
        present even when no image is found.
    """
    data = []

    for class_idx, class_name in enumerate(classes):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            print(f"Warning: class directory not found, skipping: {class_dir}")
            continue
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sorting keeps the DataFrame (and therefore the seeded train/val
        # split built from it) reproducible.
        for img_file in sorted(os.listdir(class_dir)):
            if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(class_dir, img_file)
                data.append({
                    'image_path': img_path,
                    'label': class_idx,
                    'class_name': class_name
                })

    # Explicit columns so an empty result still supports df['label'] etc.
    return pd.DataFrame(data, columns=['image_path', 'label', 'class_name'])

# Index every image on disk for the configured classes
df = create_dataset_mapping(DATA_DIR, CLASSES)

# Overall size, then how many images each class contributes
print(f"Total images: {len(df)}")
print("\nClass distribution:")
print(df['class_name'].value_counts())

# First few rows, as a quick sanity check of paths and labels
print("\nSample data:")
print(df.head())


In [None]:
# Bar chart of images per class (bars ordered from most to least frequent)
class_counts = df['class_name'].value_counts()
_bar_colors = ['skyblue', 'lightcoral', 'lightgreen', 'orange', 'purple']

plt.figure(figsize=(10, 6))
plt.bar(class_counts.index, class_counts.values, color=_bar_colors)
plt.title('Class Distribution in Dataset')
plt.xlabel('Diabetic Retinopathy Classes')
plt.ylabel('Number of Images')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# 'balanced' weights are inversely proportional to class frequency; they are
# handed to the loss so rare classes are not drowned out by No_DR.
_label_values = df['label']
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(_label_values),
    y=_label_values,
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
print(f"\nClass weights: {class_weights}")


## 3. Custom Dataset Class and Data Preprocessing


In [None]:
class DiabeticRetinopathyDataset(Dataset):
    """
    Map-style dataset yielding ``(image, label)`` pairs described by a DataFrame.

    The DataFrame must provide an ``image_path`` column (file to open) and a
    ``label`` column. Images are converted to RGB and passed through
    ``transform`` when one is given. An unreadable file is reported and
    replaced by an all-black IMG_SIZE x IMG_SIZE image, so one bad file does
    not abort an epoch.
    """
    def __init__(self, dataframe, transform=None):
        # Renumber rows 0..n-1 so positional iloc lookups line up with idx.
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]
        img_path = record['image_path']
        label = record['label']

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as err:
            print(f"Error loading image {img_path}: {err}")
            # Black placeholder keeps the batch shape intact
            image = Image.new('RGB', (IMG_SIZE, IMG_SIZE), (0, 0, 0))

        image = self.transform(image) if self.transform else image
        return image, label

# Normalisation statistics shared by both pipelines (the standard ImageNet
# values, matching the ImageNet-pretrained backbone used later).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_TARGET_SIZE = (IMG_SIZE, IMG_SIZE)

# Training: resize, random flip/rotation/colour/affine augmentation, then
# tensor conversion and normalisation. Order matters for the random ops.
train_transform = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation: deterministic resize + normalisation only.
val_transform = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

print("Data transforms defined successfully!")


## 4. Train-Validation Split


In [None]:
import multiprocessing

# Split the dataset into train and validation sets.
# stratify keeps each class's proportion the same in both splits, which
# matters for the small minority classes.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['label']
)

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")

# Create datasets
train_dataset = DiabeticRetinopathyDataset(train_df, transform=train_transform)
val_dataset = DiabeticRetinopathyDataset(val_df, transform=val_transform)

# Worker processes started with the 'spawn' method (the default on macOS and
# Windows) must unpickle the dataset, which fails for classes defined in a
# notebook cell ("Can't get attribute 'DiabeticRetinopathyDataset' on
# <module '__main__'>"). Use background workers only under 'fork'.
NUM_WORKERS = 4 if multiprocessing.get_start_method() == 'fork' else 0
# Pinned host memory only helps host-to-CUDA copies; without CUDA it just warns.
PIN_MEMORY = torch.cuda.is_available()

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print(f"Data loaders created successfully!")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")


## 5. Attention Mechanism Implementation


In [None]:
class ChannelAttention(nn.Module):
    """
    Channel attention module of CBAM.

    Each channel is squeezed to one value by global average pooling and by
    global max pooling. Both descriptors go through a shared two-layer 1x1-conv
    bottleneck; the results are summed and passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``in_channels // reduction``, clamped to at least 1 so inputs with
            fewer channels than ``reduction`` still get a valid layer.

    Forward returns per-channel weights in (0, 1) of shape
    ``(N, in_channels, 1, 1)`` for an ``(N, in_channels, H, W)`` input, meant
    to be multiplied with that input.
    """
    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Without the clamp, in_channels < reduction gives a zero-width bottleneck.
        hidden_channels = max(1, in_channels // reduction)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same MLP (self.fc) scores both pooled descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """
    Spatial attention module of CBAM.

    The channel axis is summarised by its mean and by its max. The two maps
    are stacked and reduced to one channel by a ``kernel_size`` x
    ``kernel_size`` convolution (padding ``kernel_size // 2`` keeps H and W
    for odd kernel sizes), then squashed by a sigmoid.

    Forward returns a ``(N, 1, H, W)`` weight map in (0, 1).
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        half_width = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=half_width, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """
    Convolutional Block Attention Module (CBAM).

    Re-weights a feature map first per channel (ChannelAttention), then per
    spatial location (SpatialAttention). The output has the input's shape.
    """
    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # The spatial gate is computed from the channel-refined features,
        # not from the raw input.
        refined = x * self.channel_attention(x)
        return refined * self.spatial_attention(refined)

print("Attention mechanisms defined successfully!")


## 6. ResNet with Attention Model


In [None]:
class ResNetWithAttention(nn.Module):
    """
    ResNet-50 with CBAM attention after layer2, layer3 and layer4.

    The torchvision ResNet-50 is kept intact as ``self.backbone``. Its own
    average pool and 1000-way ImageNet head are replaced by identities and are
    never called. ``forward`` runs the stages explicitly so the attention
    modules can sit between them, then applies global average pooling and a
    dropout-regularised MLP classifier.

    Args:
        num_classes: Number of output logits.
        pretrained: Load ImageNet (IMAGENET1K_V2) weights into the backbone.

    Forward returns logits of shape ``(N, num_classes)``.
    """
    def __init__(self, num_classes=5, pretrained=True):
        super(ResNetWithAttention, self).__init__()

        # Load (optionally pretrained) ResNet50
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = resnet50(weights=weights)

        # Channel width of layer4's output (2048 for ResNet-50); read it
        # before the original classifier is dropped.
        num_features = self.backbone.fc.in_features

        # Drop the ImageNet head. The backbone must remain a ResNet module
        # rather than an nn.Sequential of its children, because forward()
        # accesses its stages by name (conv1, bn1, layer1, ...), and
        # nn.Sequential only exposes numeric child names.
        self.backbone.avgpool = nn.Identity()
        self.backbone.fc = nn.Identity()

        # CBAM modules sized to each stage's output channels. ResNet-50
        # bottleneck stages emit 256/512/1024/2048 channels for layer1..layer4.
        self.attention1 = CBAM(512, reduction=16)           # After layer2
        self.attention2 = CBAM(1024, reduction=16)          # After layer3
        self.attention3 = CBAM(num_features, reduction=16)  # After layer4

        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier with dropout for regularization
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # Stem of the ResNet backbone
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Layer 1 (no attention)
        x = self.backbone.layer1(x)

        # Layer 2 with attention
        x = self.backbone.layer2(x)
        x = self.attention1(x)

        # Layer 3 with attention
        x = self.backbone.layer3(x)
        x = self.attention2(x)

        # Layer 4 with attention
        x = self.backbone.layer4(x)
        x = self.attention3(x)

        # Global average pooling -> (N, num_features)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)

        # Classification
        x = self.classifier(x)

        return x

# Create model
model = ResNetWithAttention(num_classes=NUM_CLASSES, pretrained=True)
model = model.to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model created successfully!")

# Smoke-test the forward pass with a random input. Run it in eval mode: in
# train mode, BatchNorm would fold this random batch's statistics into its
# running mean/var (corrupting the pretrained statistics), and dropout would
# be active. Train mode is restored afterwards.
model.eval()
sample_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(device)
with torch.no_grad():
    sample_output = model(sample_input)
model.train()
print(f"Sample output shape: {sample_output.shape}")


## 7. Training Configuration


In [None]:
# Define loss function and optimizer
# class_weights_tensor holds 'balanced' per-class weights derived from the
# label counts, so mistakes on rare classes cost more.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# Halve the LR once the monitored value (validation loss, passed in by the
# training loop) has not improved for more than `patience`=5 epochs.
# The `verbose` flag is deprecated in recent PyTorch releases, and the
# training loop already prints the current LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5
)

print(f"Loss function: CrossEntropyLoss with class weights")
print(f"Optimizer: Adam with learning rate {LEARNING_RATE}")
print(f"Scheduler: ReduceLROnPlateau")
print(f"Training configuration completed!")


## 8. Training and Validation Functions


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; switched to train mode here.
        train_loader: Iterable of ``(inputs, targets)`` batches with a length.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``: the unweighted average of the
        per-batch losses, and the percentage of samples whose arg-max logit
        equals the target.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = len(train_loader)

    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value
        preds = logits.detach().argmax(dim=1)
        n_seen += targets.size(0)
        n_correct += int((preds == targets).sum())

        # Progress line every 50 batches, starting with the first.
        if step % 50 == 0:
            print(f'Batch {step}/{n_batches}, Loss: {batch_loss_value:.4f}')

    return loss_sum / n_batches, 100. * n_correct / n_seen

def validate_epoch(model, val_loader, criterion, device):
    """
    Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        val_loader: Iterable of ``(inputs, targets)`` batches with a length.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device each batch is moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent, predictions, targets)``, where
        the last two are flat lists of per-sample NumPy integer scalars, in
        loader order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)

            loss_sum += criterion(logits, targets).item()
            preds = logits.argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += int((preds == targets).sum())

            all_predictions.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy, all_predictions, all_targets

print("Training and validation functions defined successfully!")


## 9. Training Loop


In [None]:
# Initialize training history
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
best_val_acc = 0.0
best_model_state = None

print("Starting training...")
print(f"Training for {EPOCHS} epochs")
print(f"Device: {device}")
print("-" * 50)

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 30)

    # Training
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validation
    val_loss, val_acc, val_predictions, val_targets = validate_epoch(model, val_loader, criterion, device)

    # Update learning rate; the plateau scheduler monitors validation loss
    scheduler.step(val_loss)

    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # Save best model (selected by validation accuracy)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns tensors that share storage with the live
        # parameters/buffers, and dict.copy() only copies the dict. The
        # snapshot would therefore keep changing as training continues.
        # Clone each tensor; they are kept on the CPU to avoid a second copy
        # in GPU memory.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"New best validation accuracy: {best_val_acc:.2f}%")

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Heuristic early stop: after epoch 11, stop if this epoch's accuracy is
    # more than 5 points below the best of the last 10 epochs (this one included).
    if epoch > 10 and val_acc < max(val_accuracies[-10:]) - 5:
        print("Early stopping triggered!")
        break

print("\nTraining completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")


## 10. Load Best Model and Evaluate


In [None]:
# Restore the weights stored in best_model_state by the training loop, if any.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Best model loaded!")

# Re-evaluate on the validation split.
# NOTE: this is the same split that selected the best epoch, so the score is
# an optimistic estimate of performance on unseen data.
print("\nFinal evaluation on validation set...")
val_loss, val_acc, val_predictions, val_targets = validate_epoch(
    model, val_loader, criterion, device
)

print(f"Final Validation Accuracy: {val_acc:.2f}%")
print(f"Final Validation Loss: {val_loss:.4f}")

# Per-class precision / recall / F1
print("\nClassification Report:")
_report_text = classification_report(val_targets, val_predictions, target_names=CLASSES)
print(_report_text)


## 11. Visualization of Training Progress


In [None]:
# Plot training history: loss on the left, accuracy on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

for _ax, _metric, _train_series, _val_series, _ylabel in (
    (ax1, 'Loss', train_losses, val_losses, 'Loss'),
    (ax2, 'Accuracy', train_accuracies, val_accuracies, 'Accuracy (%)'),
):
    _ax.plot(_train_series, label=f'Training {_metric}', color='blue')
    _ax.plot(_val_series, label=f'Validation {_metric}', color='red')
    _ax.set_title(f'Training and Validation {_metric}')
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.legend()
    _ax.grid(True)

plt.tight_layout()
plt.show()


## 12. Confusion Matrix


In [None]:
# Confusion matrix of the final validation predictions (rows = actual class)
cm = confusion_matrix(val_targets, val_predictions)

plt.figure(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=CLASSES,
    yticklabels=CLASSES,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.show()

# "Per-class accuracy" here is each class's recall: correctly predicted
# samples of a class divided by all samples that truly belong to it.
class_accuracies = np.diag(cm) / cm.sum(axis=1)
print("\nPer-class Accuracy:")
for _idx, _name in enumerate(CLASSES):
    print(f"{_name}: {class_accuracies[_idx]:.3f}")


## 13. Sample Predictions Visualization


In [None]:
def visualize_predictions(model, val_loader, class_names, num_samples=8):
    """
    Plot images from the first batch of ``val_loader`` with their predictions.

    Each panel shows the true label, the predicted label and the softmax
    probability of the prediction. Titles are green for correct predictions
    and red for wrong ones.

    Args:
        model: Trained classifier; switched to eval mode here.
        val_loader: Loader yielding ``(images, labels)``; only its first batch
            is used. Images are de-normalised with the ImageNet mean/std used
            by this notebook's transforms before display.
        class_names: Names indexed by integer label.
        num_samples: Maximum number of images to show (capped at the batch
            size). The subplot grid is sized to fit, with up to 4 per row.
    """
    model.eval()
    # Run on whichever device the model's parameters live on.
    model_device = next(model.parameters()).device

    images, labels = next(iter(val_loader))
    n_show = min(num_samples, len(images))
    if n_show <= 0:
        print("No samples to visualize.")
        return
    images, labels = images[:n_show], labels[:n_show]

    with torch.no_grad():
        outputs = model(images.to(model_device))
        _, predicted = torch.max(outputs, 1)
        probabilities = F.softmax(outputs, dim=1)
    predicted = predicted.cpu()
    probabilities = probabilities.cpu()

    # Inverse of transforms.Normalize, hoisted out of the loop.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    ncols = min(4, n_show)
    nrows = (n_show + ncols - 1) // ncols
    # squeeze=False keeps axes 2-D even for a single row/column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    for i in range(n_show):
        img = (images[i] * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()
        true_idx = int(labels[i])
        pred_idx = int(predicted[i])
        confidence = probabilities[i, pred_idx].item()

        ax = axes[i]
        ax.imshow(img)
        ax.set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}\nConf: {confidence:.3f}')
        ax.axis('off')
        # Color code the title based on correctness
        ax.title.set_color('green' if true_idx == pred_idx else 'red')

    # Hide grid cells left empty in the last row.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize sample predictions
print("Sample Predictions:")
visualize_predictions(model, val_loader, CLASSES, num_samples=8)


## 14. Save Model


In [None]:
# Save the trained model
# The checkpoint goes into the parent directory of DATA_DIR instead of
# repeating the absolute path. With the default DATA_DIR this is the same
# file as before (.../Miniproject/diabetic_retinopathy_model.pth).
# normpath drops a trailing separator so dirname returns the real parent.
model_save_path = os.path.join(
    os.path.dirname(os.path.normpath(DATA_DIR)),
    'diabetic_retinopathy_model.pth',
)
# model.state_dict() is whatever weights the model holds at this point.
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': CLASSES,
    'num_classes': NUM_CLASSES,
    'img_size': IMG_SIZE,
    'best_val_acc': best_val_acc,
    'training_history': {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }
}, model_save_path)

print(f"Model saved to: {model_save_path}")
print(f"Model includes:")
print(f"- Model weights")
print(f"- Class names: {CLASSES}")
print(f"- Number of classes: {NUM_CLASSES}")
print(f"- Image size: {IMG_SIZE}")
print(f"- Best validation accuracy: {best_val_acc:.2f}%")


## 15. Model Summary and Performance Metrics


In [None]:
# Print final summary
_rule = "=" * 60
print(_rule)
print("DIABETIC RETINOPATHY DETECTION MODEL SUMMARY")
print(_rule)
for _summary_line in (
    f"Dataset: {len(df)} total images",
    f"Classes: {CLASSES}",
    f"Training samples: {len(train_df)}",
    f"Validation samples: {len(val_df)}",
    "Model: ResNet50 with CBAM Attention",
    f"Image size: {IMG_SIZE}x{IMG_SIZE}",
    f"Batch size: {BATCH_SIZE}",
    f"Epochs trained: {len(train_losses)}",
    f"Best validation accuracy: {best_val_acc:.2f}%",
    f"Final validation accuracy: {val_acc:.2f}%",
    f"Total parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
):
    print(_summary_line)
print(_rule)

# Class-wise performance (per-class recall computed in the confusion-matrix cell)
print("\nCLASS-WISE PERFORMANCE:")
print("-" * 30)
for _idx, _name in enumerate(CLASSES):
    print(f"{_name}: {class_accuracies[_idx]:.3f} accuracy")

print("\nTraining completed successfully!")
print(f"Model saved to: {model_save_path}")
