<a href="https://www.kaggle.com/code/jfjerin/densenet201?scriptVersionId=278985243" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere below /kaggle/input.
for folder, _subfolders, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(folder, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader, random_split
import os
from PIL import Image
import time
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_curve, auc
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns

# --- 1. Custom Dataset Class ---
class CervicalDataset(Dataset):
    """Index the cropped ``.bmp`` images stored under per-class folders.

    Every immediate subdirectory of ``root_dir`` is treated as one class and
    its folder name is used as the class label. For a class folder ``<name>``
    images are read from ``<root_dir>/<name>/<name>/CROPPED`` when that folder
    exists, otherwise from ``<root_dir>/<name>/CROPPED``. Class folders with
    neither layout are skipped with a message.

    Class folders and file names are visited in sorted order so that
    ``image_paths`` does not depend on the arbitrary ordering returned by
    ``os.listdir``; a seeded split derived from these lists is therefore
    reproducible.

    Args:
        root_dir: Directory containing one subdirectory per class.
        transform: Optional callable applied to each loaded PIL image.
        split_type: Free-form tag, only used in the log messages.

    Raises:
        ValueError: If no ``.bmp`` image is found under ``root_dir``.
    """

    def __init__(self, root_dir, transform=None, split_type='train'):
        self.root_dir = root_dir
        self.transform = transform
        self.split_type = split_type
        self.image_paths = []  # one file path per image
        self.labels = []       # class folder name, parallel to image_paths
        self.label_map = {}    # class folder name -> integer index

        print(f"Initializing CervicalDataset for split '{split_type}' from root_dir: {root_dir}")

        # Sorted so the sample order is deterministic across filesystems.
        top_level_items = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )

        if not top_level_items:
            # Nothing to scan; the ValueError below reports the failure.
            print(f"WARNING: No subdirectories found in {root_dir}. Please check your `data_dir`.")

        for top_item in top_level_items:
            first_level_path = os.path.join(root_dir, top_item)
            # Preferred layout: <class>/<class>/CROPPED
            nested_cropped_path = os.path.join(first_level_path, top_item, 'CROPPED')
            # Fallback layout: <class>/CROPPED
            direct_cropped_path = os.path.join(first_level_path, 'CROPPED')

            if os.path.isdir(nested_cropped_path):
                print(f"Found CROPPED folder for class '{top_item}' at: {nested_cropped_path}")
                self._add_bmp_images(nested_cropped_path, top_item)
            elif os.path.isdir(direct_cropped_path):
                print(f"WARNING: Found CROPPED directly under {first_level_path} (no nested class folder).")
                self._add_bmp_images(direct_cropped_path, top_item)
            else:
                print(f"Neither nested 'CROPPED' nor direct 'CROPPED' found for '{top_item}' at {first_level_path}. Skipping.")

        if not self.image_paths:
            raise ValueError(f"No .bmp images found in the specified root directory: {root_dir}. "
                             f"Please check your `data_dir` path and the exact dataset structure.")

        print(f"Total images found for split '{split_type}': {len(self.image_paths)}")

        # Integer ids follow the alphabetical order of the class names.
        unique_labels = sorted(set(self.labels))
        self.label_map = {label: i for i, label in enumerate(unique_labels)}
        self.int_labels = [self.label_map[label] for label in self.labels]
        self.num_classes = len(self.label_map)

        print(f"Detected classes: {unique_labels}")
        print(f"Label map: {self.label_map}, Number of classes: {self.num_classes}")

    def _add_bmp_images(self, folder, label):
        """Append every ``.bmp`` file in ``folder`` (sorted by name) under ``label``."""
        for img_name in sorted(os.listdir(folder)):
            if img_name.lower().endswith('.bmp'):
                self.image_paths.append(os.path.join(folder, img_name))
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, int_label)`` for position ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one
        was supplied.
        """
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.int_labels[idx]
        return image, label

# Custom Dataset class to handle pre-split paths and labels
class SplitDataset(Dataset):
    """Dataset over an already-split list of image paths and their labels.

    ``image_paths`` and ``labels`` are stored as given and indexed in
    parallel; ``transform`` is an optional callable applied to each image
    after it has been loaded as RGB.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of image paths held by this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        rgb_image = Image.open(self.image_paths[idx]).convert('RGB')
        target = self.labels[idx]

        # Without a transform the raw PIL image is handed back unchanged.
        if not self.transform:
            return rgb_image, target
        return self.transform(rgb_image), target

# --- 2. Data Preprocessing and Loading ---

# Define transformations
# Per-split preprocessing pipelines. Training uses random crop/flip
# augmentation; 'val' and 'test' share the same deterministic
# resize + center-crop recipe. All splits are normalized with ImageNet stats.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    **{
        eval_split: transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        for eval_split in ('val', 'test')
    },
}

data_dir = '/kaggle/input/cervical-cancer-largest-dataset-sipakmed'

print(f"\nAttempting to load data from: {data_dir}")
if not os.path.exists(data_dir):
    raise FileNotFoundError(f"The specified data_dir does not exist: {data_dir}. Please check the path.")

# Index every image once (no transform); the per-split datasets are built
# from these path/label lists further down.
master_dataset = CervicalDataset(root_dir=data_dir, transform=None, split_type='master')
class_names = list(master_dataset.label_map)
num_classes = master_dataset.num_classes

print(f"\nMaster dataset size: {len(master_dataset)} images.")
print(f"Number of classes detected: {num_classes}")

# Parallel lists: file path and integer class id for every image.
all_image_paths = master_dataset.image_paths
all_int_labels = master_dataset.int_labels

# Stratified split in two steps: hold out 15% for test, then take
# 0.15/0.85 of the remainder for validation (i.e. 15% of the full set).
trainval_paths, test_paths, trainval_labels, test_labels = train_test_split(
    all_image_paths,
    all_int_labels,
    test_size=0.15,
    stratify=all_int_labels,
    random_state=42,
)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    trainval_paths,
    trainval_labels,
    test_size=(0.15/0.85),
    stratify=trainval_labels,
    random_state=42,
)

print(f"Dataset split sizes: Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

# One dataset per split, each with its own transform pipeline.
train_dataset = SplitDataset(train_paths, train_labels, data_transforms['train'])
val_dataset = SplitDataset(val_paths, val_labels, data_transforms['val'])
test_dataset = SplitDataset(test_paths, test_labels, data_transforms['test'])

# Refuse to continue if any split ended up empty.
if 0 in (len(train_dataset), len(val_dataset), len(test_dataset)):
    raise ValueError("One or more dataset splits resulted in zero size. "
                     "Ensure your `master_dataset` is large enough for splitting.")

# DataLoaders: only the training split is shuffled. num_workers can be adjusted.
dataloaders = {
    split_name: DataLoader(
        split_dataset,
        batch_size=32,
        shuffle=(split_name == 'train'),
        num_workers=2,
    )
    for split_name, split_dataset in (
        ('train', train_dataset),
        ('val', val_dataset),
        ('test', test_dataset),
    )
}

print("\nDataLoaders created successfully with the following dataset sizes:")
print(f"Train Dataset: {len(train_dataset)} images")
print(f"Validation Dataset: {len(val_dataset)} images")
print(f"Test Dataset: {len(test_dataset)} images")
print(f"Train DataLoader batches: {len(dataloaders['train'])}")
print(f"Validation DataLoader batches: {len(dataloaders['val'])}")
print(f"Test DataLoader batches: {len(dataloaders['test'])}")

# --- 3. Model Selection and Loading (DenseNet201 Architecture) ---

# Load pre-trained DenseNet201
# Using 'weights' parameter for recommended practice (instead of deprecated 'pretrained=True')
model_ft = models.densenet201(weights=models.DenseNet201_Weights.IMAGENET1K_V1)
print("Loaded pre-trained DenseNet201 model.")

# 4. Freeze the pretrained network.
# Every parameter loaded above is frozen here; the classifier created in
# step 5 is a fresh layer and is therefore the only trainable part.
model_ft.requires_grad_(False)
print("Frozen initial layers of DenseNet201.")

# 5. Swap the classification head.
# DenseNet exposes its final linear layer as `model_ft.classifier`.
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, num_classes)
print(f"Replaced DenseNet201's classifier head with output for {num_classes} classes.")

# Use the first GPU when one is available, otherwise stay on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model_ft = model_ft.to(device)

print(f"\nModel initialized and moved to device: {device}")

# --- 6. Define Loss Function and Optimizer ---
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that still require gradients,
# i.e. those of the newly added classification head.
trainable_params = [p for p in model_ft.parameters() if p.requires_grad]
optimizer_ft = optim.Adam(trainable_params, lr=0.001)

# --- 7. Training Loop ---
def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=10):
    """Train ``model`` and restore the weights of its best validation epoch.

    Each epoch runs a 'train' phase followed by a 'val' phase. After the last
    epoch the model is reloaded with the weights that produced the highest
    validation accuracy.

    Args:
        model: The ``nn.Module`` to train (modified in place).
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        dataloaders: Mapping with 'train' and 'val' loaders; each loader must
            expose a sized ``.dataset`` used to average the statistics.
        device: Device the batches are moved to.
        num_epochs: Number of epochs to run.

    Returns:
        Tuple ``(model, train_losses, train_accuracies, val_losses,
        val_accuracies)`` where every list holds one float per epoch.
    """
    # Imported locally so the function does not depend on the file's import block.
    import copy

    since = time.time()
    best_acc = 0.0
    # Snapshot the starting weights so a "best" state always exists, even when
    # no epoch beats 0.0 accuracy or ``num_epochs`` is 0.
    best_model_wts = copy.deepcopy(model.state_dict())

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_training:
                    optimizer.zero_grad()

                # Track gradients only in the training phase.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Accumulate plain Python numbers so the statistics stay valid
                # even if a loader yields no batches.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            dataset_size = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects / dataset_size

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if is_training:
                train_losses.append(epoch_loss)
                train_accuracies.append(epoch_acc)
            else:
                val_losses.append(epoch_loss)
                val_accuracies.append(epoch_acc)

                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    # state_dict() hands back references to the live tensors,
                    # so it must be deep-copied or later epochs overwrite it.
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, train_losses, train_accuracies, val_losses, val_accuracies
    

# --- 8. Evaluation Function ---
def evaluate_model(model, dataloader, device, class_names, history=None):
    """Evaluate ``model`` on ``dataloader``; print and plot the metrics.

    Computes overall accuracy, weighted and per-class precision/recall/F1,
    the confusion matrix, and one-vs-rest ROC curves with per-class, micro-
    and macro-averaged AUC. Shows the confusion matrix, the training curves
    (when available) and the ROC curves with matplotlib.

    Args:
        model: Trained ``nn.Module``; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        class_names: Class names ordered by integer label; position ``i``
            names label ``i``.
        history: Optional mapping with the keys ``'train_losses'``,
            ``'val_losses'``, ``'train_accuracies'`` and ``'val_accuracies'``
            (one value per epoch). When omitted, module-level variables of
            the same names are used if they all exist; otherwise the
            training-history plot is skipped.

    Returns:
        Tuple ``(accuracy, precision, recall, f1_score, cm, roc_auc,
        test_time)``. ``roc_auc`` maps class index to AUC (``nan`` when the
        ROC curve of that class cannot be computed); ``test_time`` is the
        inference loop duration in seconds.
    """
    model.eval()  # Set model to evaluate mode
    all_preds = []
    all_labels = []
    all_probs = []  # For ROC/AUC

    start_time = time.time()
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    test_time = time.time() - start_time

    n_classes = len(class_names)
    class_indices = list(range(n_classes))
    # Built once; the ROC computations below index into it repeatedly.
    prob_matrix = np.array(all_probs)

    # --- Calculate Metrics ---
    accuracy = accuracy_score(all_labels, all_preds)

    # Passing `labels=` pins every per-class result to the order and length of
    # `class_names`, even if some class never occurs in this split.
    precision, recall, f1_score, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=class_indices, average='weighted', zero_division=0
    )
    class_precision, class_recall, class_f1_score, _ = precision_recall_fscore_support(
        all_labels, all_preds, labels=class_indices, average=None, zero_division=0
    )
    cm = confusion_matrix(all_labels, all_preds, labels=class_indices)

    # ROC Curve and AUC (One-vs-Rest for multi-class)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    binarized_labels = label_binarize(all_labels, classes=class_indices)
    if n_classes == 2:
        # For two classes label_binarize returns a single column (class 1);
        # expand it so there is one column per class.
        binarized_labels = np.hstack([1 - binarized_labels, binarized_labels])

    for i in class_indices:
        class_truth = binarized_labels[:, i]
        if np.sum(class_truth) == 0:  # No true samples for this class
            fpr[i], tpr[i] = [0], [0]
            roc_auc[i] = np.nan
            continue

        # roc_curve requires both positive and negative samples.
        if len(np.unique(class_truth)) < 2:
            print(f"Warning: ROC curve for class {class_names[i]} cannot be computed (only one class present in true labels).")
            fpr[i], tpr[i] = [0], [0]
            roc_auc[i] = np.nan
            continue

        fpr[i], tpr[i], _ = roc_curve(class_truth, prob_matrix[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average ROC curve and AUC (good for imbalanced datasets)
    fpr_micro, tpr_micro, _ = roc_curve(binarized_labels.ravel(), prob_matrix.ravel())
    roc_auc_micro = auc(fpr_micro, tpr_micro)

    # Macro-average ROC curve and AUC over the classes with a valid curve
    valid_classes = [i for i in class_indices if not np.isnan(roc_auc[i])]
    if valid_classes:
        fpr_macro = np.unique(np.concatenate([fpr[i] for i in valid_classes]))
        tpr_macro = np.zeros_like(fpr_macro)
        for i in valid_classes:
            tpr_macro += np.interp(fpr_macro, fpr[i], tpr[i])
        tpr_macro /= len(valid_classes)
        roc_auc_macro = auc(fpr_macro, tpr_macro)
    else:
        # No class has a usable ROC curve, so there is nothing to average.
        fpr_macro = np.array([])
        tpr_macro = np.array([])
        roc_auc_macro = np.nan

    # --- Report Results ---
    print("\n--- Evaluation Results ---")
    print(f"Overall Accuracy: {accuracy:.4f}")
    print(f"Precision (weighted): {precision:.4f}")
    print(f"Recall (weighted): {recall:.4f}")
    print(f"F1-Score (weighted): {f1_score:.4f}")
    print(f"Test Inference Time: {test_time:.2f} seconds")

    print("\nPer-Class Metrics:")
    for i, name in enumerate(class_names):
        print(f"  Class '{name}' ({i}): Precision={class_precision[i]:.4f}, Recall={class_recall[i]:.4f}, F1={class_f1_score[i]:.4f}")

    print("\nConfusion Matrix:")
    print(cm)

    print("\nROC AUC:")
    for i, name in enumerate(class_names):
        if not np.isnan(roc_auc[i]):
            print(f"  Class '{name}' ({i}) AUC: {roc_auc[i]:.4f}")
        else:
            print(f"  Class '{name}' ({i}) AUC: Not available (single class present)")
    print(f"  Micro-average AUC: {roc_auc_micro:.4f}")
    print(f"  Macro-average AUC: {roc_auc_macro:.4f}")

    # --- Visualizations ---
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

    # Training history. Older callers relied on the curves being module-level
    # variables; keep that working, but do not crash when they are absent.
    if history is None:
        module_scope = globals()
        history_keys = ('train_losses', 'val_losses', 'train_accuracies', 'val_accuracies')
        if all(key in module_scope for key in history_keys):
            history = {key: module_scope[key] for key in history_keys}

    if history is not None:
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(range(len(history['train_losses'])), history['train_losses'], label='Train Loss')
        plt.plot(range(len(history['val_losses'])), history['val_losses'], label='Validation Loss')
        plt.title('Loss over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(range(len(history['train_accuracies'])), history['train_accuracies'], label='Train Accuracy')
        plt.plot(range(len(history['val_accuracies'])), history['val_accuracies'], label='Validation Accuracy')
        plt.title('Accuracy over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()
    else:
        print("Training history not available; skipping loss/accuracy plots.")

    # Plot ROC curves
    plt.figure(figsize=(10, 8))
    plt.plot(fpr_micro, tpr_micro, label=f'Micro-average ROC (AUC = {roc_auc_micro:.2f})', color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr_macro, tpr_macro, label=f'Macro-average ROC (AUC = {roc_auc_macro:.2f})', color='navy', linestyle=':', linewidth=4)

    # One evenly spaced 'jet' colour per class.
    cmap = plt.colormaps['jet']
    colors = [cmap(position) for position in np.linspace(0, 1, n_classes)]

    for i, color in enumerate(colors):
        if not np.isnan(roc_auc[i]):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'ROC curve of class {class_names[i]} (AUC = {roc_auc[i]:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")
    plt.show()

    return accuracy, precision, recall, f1_score, cm, roc_auc, test_time

# --- Main Execution ---
if __name__ == '__main__':
    print("Starting Cervical Cancer Classification Training with DenseNet201...")

    # Train the classifier head; raise num_epochs for a longer run.
    training_outputs = train_model(model_ft, criterion, optimizer_ft, dataloaders, device, num_epochs=10)
    trained_model, train_losses, train_accuracies, val_losses, val_accuracies = training_outputs

    print("\n--- Model Training Complete. Starting Evaluation ---")

    # Score the trained model on the held-out test split.
    evaluation_outputs = evaluate_model(trained_model, dataloaders['test'], device, class_names)
    test_accuracy, test_precision, test_recall, test_f1, test_cm, test_auc, test_time = evaluation_outputs

    # Show one test image together with its true and predicted class.
    print("\n--- Displaying a sample image with its predicted label ---")

    # First image of the first test batch.
    sample_inputs, sample_labels = next(iter(dataloaders['test']))
    first_image = sample_inputs[0]
    sample_input = first_image.unsqueeze(0).to(device)  # add batch dimension
    actual_label = sample_labels[0].item()

    # Predict its class without tracking gradients.
    trained_model.eval()
    with torch.no_grad():
        output = trained_model(sample_input)
        predicted_class_idx = torch.max(output, 1).indices

    predicted_label = predicted_class_idx.item()

    # Undo the ImageNet normalization so the image displays with natural colours.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip(channel_mean, channel_std)],
        std=[1 / s for s in channel_std],
    )
    restored = inv_normalize(first_image)
    display_image = restored.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC
    display_image = np.clip(display_image, 0, 1)  # keep values within [0, 1]

    plt.figure(figsize=(6, 6))
    plt.imshow(display_image)
    plt.title(f"Actual: {class_names[actual_label]}\nPredicted: {class_names[predicted_label]}", fontsize=16)
    plt.axis('off')
    plt.show()

    print("\n--- Visualizing the concept of cervical cell classification with DenseNet201 ---")


    print("\nThis image conceptually illustrates how a DenseNet201 model, like the one implemented, processes microscopic cervical cell images to classify them, aiding in early detection and diagnosis of cervical cancer.")