**Libraries**

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns


In [None]:

# Dataset root (relative to the working directory). It is expected to hold
# train/, val/ and test/ sub-folders laid out for torchvision's ImageFolder.
# NOTE(review): on Colab this is usually '/content/...' (absolute) — confirm.
base_dir = 'content/Diabetic_Balanced_Data'
train_dir, val_dir, test_dir = (os.path.join(base_dir, split) for split in ('train', 'val', 'test'))


In [None]:
# Every image is resized to a fixed 224x224 input; batches hold 32 images.
img_height = img_width = 224
batch_size = 32

In [None]:
# Per-channel normalisation statistics (the standard ImageNet values) used by
# every split, matching the ImageNet-pretrained weights loaded in main().
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]


def _build_transform(augment):
    """Return the preprocessing pipeline; ``augment`` adds random flip/rotation."""
    steps = [transforms.Resize((img_height, img_width))]
    if augment:
        # Light augmentation for the training split only.
        steps += [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
    steps += [transforms.ToTensor(), transforms.Normalize(_imagenet_mean, _imagenet_std)]
    return transforms.Compose(steps)


# Preprocessing per split: 'train' is augmented, 'val' and 'test' are deterministic.
data_transforms = {split: _build_transform(augment=(split == 'train')) for split in ('train', 'val', 'test')}


In [None]:

def load_datasets():
    """Build the train/val/test ImageFolder datasets and wrap them in DataLoaders.

    Prints each split's size, the class list, and the per-class image count of
    the training split.

    Returns:
        ``(train_loader, val_loader, test_loader, classes)``, where ``classes``
        is the class-name list discovered in the training folder.
    """
    print("Loading datasets...")
    split_dirs = {'train': train_dir, 'val': val_dir, 'test': test_dir}
    image_sets = {
        split: datasets.ImageFolder(folder, data_transforms[split])
        for split, folder in split_dirs.items()
    }

    # num_workers=0 loads every batch in the main process instead of spawning
    # worker processes (chosen to sidestep multiprocessing problems).
    # Only the training split is shuffled.
    loaders = {
        split: DataLoader(image_set, batch_size=batch_size, shuffle=(split == 'train'), num_workers=0)
        for split, image_set in image_sets.items()
    }

    train_set = image_sets['train']
    classes = train_set.classes

    print("\n--- Dataset Information ---")
    for title, split in (("Training", 'train'), ("Validation", 'val'), ("Test", 'test')):
        print(f"{title} set size: {len(image_sets[split])} images")
    print(f"Classes: {classes}")
    print("---------------------------\n")

    # samples holds (path, class_index) pairs; tally images per class index.
    per_class = [0] * len(classes)
    for _, class_index in train_set.samples:
        per_class[class_index] += 1
    print("Class distribution in training set:")
    for class_name, count in zip(classes, per_class):
        print(f"{class_name}: {count} images")
    print()

    return loaders['train'], loaders['val'], loaders['test'], classes


In [None]:

# Visualise how training progressed.
def plot_training_history(history, save_path=None):
    """Plot accuracy (left) and loss (right) curves for train and validation.

    Args:
        history: dict with ``train_acc``, ``val_acc``, ``train_loss`` and
            ``val_loss`` sequences, one value per epoch.
        save_path: when truthy, the figure is also written to this path before
            being shown.
    """
    # (y-axis label, train key, val key, train legend, val legend, title)
    panels = (
        ('Accuracy', 'train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy', 'Accuracy Over Epochs'),
        ('Loss', 'train_loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Loss Over Epochs'),
    )

    plt.figure(figsize=(12, 5))
    for position, (metric, train_key, val_key, train_label, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[train_key], label=train_label)
        plt.plot(history[val_key], label=val_label)
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        print(f"Training history plot saved to {save_path}")

    plt.show()


In [None]:

def _run_epoch(model, loader, criterion, device, optimizer=None, desc="Training"):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch; without one it is put in eval mode with gradients disabled.

    Raises:
        ValueError: if ``loader.dataset`` is empty (the averages would be undefined).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError(f"{desc} loader has an empty dataset")

    total_loss = 0.0
    total_correct = 0
    progress = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(training):
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)
            n_batch = inputs.size(0)
            batch_correct = (preds == labels).sum().item()
            # criterion is assumed to return the batch mean; re-weight by batch size
            # so the epoch average is per-sample even when the last batch is short.
            total_loss += loss.item() * n_batch
            total_correct += batch_correct

            progress.set_postfix({
                'loss': f"{loss.item():.4f}",
                'acc': f"{batch_correct / n_batch:.4f}"
            })

    return total_loss / num_samples, total_correct / num_samples


# Training function with detailed progress tracking
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=30, phase="initial_training", save_path=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Inputs are moved to the device of the model's parameters.

    Args:
        model: the network to train (updated in place).
        train_loader / val_loader: DataLoaders yielding ``(inputs, labels)``.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        num_epochs: number of epochs to run.
        phase: label used in progress messages.
        save_path: if truthy, a checkpoint (epoch, model/optimizer state dicts,
            val_loss, val_acc) is written here whenever validation accuracy
            improves, and the history plot is saved next to it.

    Returns:
        ``(model, history)``: the model after the LAST epoch (not necessarily
        the best one), and a dict of per-epoch float lists ``train_loss``,
        ``train_acc``, ``val_loss`` and ``val_acc``.
    """
    device = next(model.parameters()).device
    best_val_acc = None  # None until the first epoch, so epoch 1 is always a candidate
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    total_start_time = time.time()

    for epoch in range(num_epochs):
        print(f"\n[{phase}] Epoch {epoch+1}/{num_epochs}")
        epoch_start_time = time.time()

        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, device, optimizer, desc="Training")
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device, desc="Validation")
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        epoch_time = time.time() - epoch_start_time

        print(f"\n[{phase}] Epoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Training   - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
        print(f"  Validation - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")
        print(f"  Time: {epoch_time:.2f}s")

        # Track the best accuracy regardless of save_path; checkpoint only if asked.
        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            if save_path:
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_acc': val_acc,
                }, save_path)
                print(f"  Saved best model with validation accuracy: {val_acc:.4f}")

    total_time = time.time() - total_start_time
    print(f"\n[{phase}] Training completed in {total_time/60:.2f} minutes")
    if best_val_acc is not None:
        print(f"Best validation accuracy: {best_val_acc:.4f}")

    # splitext only strips the real extension (split('.') broke on './dir/x.pth').
    plot_path = f"{os.path.splitext(save_path)[0]}_history.png" if save_path else None
    plot_training_history(history, save_path=plot_path)

    return model, history


In [None]:

def evaluate_model(model, test_loader, criterion, classes, save_path='confusion_matrix.png'):
    """Evaluate ``model`` on the test set, print metrics and plot a confusion matrix.

    Inputs are moved to the device of the model's parameters.

    Args:
        model: trained classifier.
        test_loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss function applied to ``(outputs, labels)``.
        classes: class names, indexed by integer label.
        save_path: where to write the confusion-matrix image; falsy to skip saving.

    Returns:
        dict with ``'loss'`` and ``'accuracy'`` (floats) plus the collected
        ``'labels'`` and ``'preds'`` lists.

    Raises:
        ValueError: if the test loader's dataset is empty.
    """
    print("\n=== FINAL MODEL EVALUATION ===")

    # Follow the model's own device instead of relying on a module-level global.
    device = next(model.parameters()).device
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError("test_loader has an empty dataset")

    model.eval()
    test_loss = 0.0
    test_corrects = 0
    all_preds = []
    all_labels = []

    print("Evaluating on test set...")
    test_bar = tqdm(test_loader, desc="Testing")
    with torch.no_grad():
        for inputs, labels in test_bar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            test_loss += loss.item() * inputs.size(0)
            test_corrects += (preds == labels).sum().item()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_loss /= num_samples
    test_acc = test_corrects / num_samples

    print("\nTest Set Results:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")

    # Pin the label set so the report and the matrix always have one row per
    # class, even when a class never occurs in the labels or predictions
    # (otherwise target_names/tick labels would not match and sklearn raises).
    label_ids = list(range(len(classes)))

    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=classes))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path)
    plt.show()

    return {
        'loss': test_loss,
        'accuracy': test_acc,
        'labels': all_labels,
        'preds': all_preds,
    }



In [None]:

def _load_best_checkpoint(model, checkpoint_path, device):
    """Load the ``model_state_dict`` stored at ``checkpoint_path`` into ``model``.

    Returns the checkpoint dict, or ``None`` when the file does not exist (no
    epoch was ever saved), in which case ``model`` keeps its in-memory weights.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Warning: no checkpoint found at {checkpoint_path}; keeping current weights")
        return None
    # map_location restores tensors onto the current device even if the
    # checkpoint was written on a different one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint


def _feature_blocks(features):
    """Flatten an EfficientNet ``features`` Sequential into its individual blocks.

    Stages that are plain ``nn.Sequential`` containers are expanded into their
    child blocks; any other module is kept whole. ``type(...) is`` is used
    instead of ``isinstance`` because torchvision's conv-norm-activation layers
    subclass ``nn.Sequential`` and should stay a single unit.
    """
    blocks = []
    for module in features:
        if type(module) is nn.Sequential:
            blocks.extend(module.children())
        else:
            blocks.append(module)
    return blocks


def main():
    """Train EfficientNet-B3 in two phases, then evaluate on the test set.

    Phase 1 trains all parameters with Adam (lr=1e-3). Phase 2 reloads the best
    phase-1 checkpoint, freezes everything except the last 10 feature blocks and
    the classifier, and continues with Adam (lr=1e-4). The best fine-tuned
    checkpoint is then evaluated.
    """
    # Load datasets
    train_loader, val_loader, test_loader, classes = load_datasets()

    # Other functions in this file may read the module-level ``device``, so it
    # is set as a global.
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pre-trained EfficientNet-B3 model (weights= replaces pretrained=True)
    print("\nInitializing EfficientNet-B3 model...")
    model = models.efficientnet_b3(weights="IMAGENET1K_V1")

    # Replace the final classifier layer so it outputs one logit per class.
    num_classes = len(classes)
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    model = model.to(device)
    print(f"Model output layer modified for {num_classes} classes")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # PHASE 1: train every parameter.
    initial_model_path = 'efficientnet_b3_initial.pth'
    print("\n=== STARTING INITIAL TRAINING ===")
    model, _ = train_model(model, train_loader, val_loader, criterion, optimizer,
                           num_epochs=30, phase="Initial Training", save_path=initial_model_path)

    print(f"\nInitial training completed. Best model saved to {initial_model_path}")

    # PHASE 2: fine-tune only the last feature blocks and the classifier.
    print("\n=== STARTING FINE-TUNING ===")
    print("Loading the best initial model...")
    checkpoint = _load_best_checkpoint(model, initial_model_path, device)
    if checkpoint is not None:
        print(f"Loaded model from epoch {checkpoint['epoch']} with validation accuracy {checkpoint['val_acc']:.4f}")

    print("Freezing early layers...")
    for param in model.parameters():
        param.requires_grad = False

    # For B3, torchvision's ``model.features`` has only 9 top-level modules
    # (stem, 7 stages, final conv). Slicing ``features[-10:]`` therefore selected
    # all of them and froze nothing. Count individual blocks instead.
    num_unfrozen_blocks = 10
    print(f"Unfreezing the last {num_unfrozen_blocks} feature blocks and classifier...")
    for block in _feature_blocks(model.features)[-num_unfrozen_blocks:]:
        for param in block.parameters():
            param.requires_grad = True
    for param in model.classifier.parameters():
        param.requires_grad = True

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")

    # Fresh optimizer over the trainable subset only, with a lower learning rate.
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.0001)
    print("Optimizer reset with learning rate 0.0001")

    final_model_path = 'efficientnet_b3_finetuned.pth'
    model, _ = train_model(model, train_loader, val_loader, criterion, optimizer,
                           num_epochs=30, phase="Fine-tuning", save_path=final_model_path)

    print(f"\nFine-tuning completed. Best model saved to {final_model_path}")

    # Evaluate the best fine-tuned weights rather than the last-epoch ones.
    _load_best_checkpoint(model, final_model_path, device)

    evaluate_model(model, test_loader, criterion, classes)

    print("Training and evaluation completed!")
    print(f"Initial model saved to: {initial_model_path}")
    print(f"Fine-tuned model saved to: {final_model_path}")
    print("Training plots and confusion matrix saved to current directory")


if __name__ == "__main__":
    main()