# EfficientNet-B0 Training

This notebook trains one convolutional neural network architecture, **EfficientNet-B0**, for binary classification of mammograms (benign vs malignant).

The training was performed on **Kaggle** using GPU acceleration.  
If you wish to run this notebook locally, you may need to modify some directory paths (e.g., for data loading, model saving, and output locations) to match your machine's folder structure.

Key components of this notebook include:
- Loading and preprocessing the training and validation datasets
- Applying data augmentation to improve model generalisation
- Fine-tuning an EfficientNet-B0 model pre-trained on ImageNet
- Tracking training and validation performance over epochs
- Saving the best-performing model checkpoint (the epoch with the lowest validation loss), with early stopping


In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, random_split, ConcatDataset, Subset, DataLoader
from torchvision import datasets
import torchvision.transforms as T
import random
import matplotlib.pyplot as plt

import time
from PIL import Image
from tqdm import tqdm
import os
import cv2
import pandas as pd
import numpy as np
from torchvision.io import read_image

from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split

In [2]:
class CustomImageDataset(Dataset):
    """Image dataset backed by a CSV annotations file.

    The CSV is read with pandas. By position, column 1 holds the image file
    name relative to ``img_dir`` and column 2 holds the class label.
    NOTE(review): column 0 is not used here; presumably an index/ID column,
    confirm against the label CSVs.

    Args:
        annotations_file: Path to the CSV with one row per image.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the annotations file."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, label)`` after transforms."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 1])
        # Open in a context manager so the file handle is released
        # deterministically (relevant with several DataLoader workers).
        # convert() returns an independent, fully loaded copy.
        with Image.open(img_path) as img:
            image_rgb = img.convert("RGB")
        label = self.img_labels.iloc[idx, 2]

        if self.transform:
            image_rgb = self.transform(image_rgb)
        if self.target_transform:
            label = self.target_transform(label)
        return image_rgb, label

In [None]:
# Transform that adds Gaussian noise to a tensor image
class AddGaussianNoise(torch.nn.Module):
    """Randomly add Gaussian noise N(mean, std^2) to an image tensor.

    Intended to run after ``T.ToTensor()`` (values in [0, 1]). The noisy
    result is clamped back to [0, 1].

    Args:
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.
        p: Probability of applying the noise on a given call.
    """

    def __init__(self, mean=0.0, std=0.05, p=0.3):
        super().__init__()
        self.mean = mean
        self.std = std
        self.p = p

    def forward(self, tensor):
        """Return ``tensor`` with noise added with probability ``p``, else unchanged."""
        if random.random() < self.p:
            # Scale unit-normal noise by std and shift it by mean.
            noise = torch.randn_like(tensor) * self.std + self.mean
            tensor = torch.clamp(tensor + noise, 0, 1)
        return tensor

In [None]:
# Custom CLAHE (Contrast Limited Adaptive Histogram Equalisation) transform
class ApplyCLAHE:
    """Apply CLAHE to a PIL image and return a 3-channel grayscale PIL image.

    Args:
        clip_limit: OpenCV CLAHE contrast clip limit (``clipLimit``).
        tile_grid_size: Tile grid passed as ``tileGridSize``.
        p: Probability of applying the transform on a given call; otherwise
            the input image is returned unchanged.
    """

    def __init__(self, clip_limit=5.0, tile_grid_size=(8, 8), p=1.0):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size
        self.p = p

    def __call__(self, img):
        """Return the CLAHE-equalised image with probability ``p``, else ``img``."""
        if random.random() >= self.p:
            return img

        img_np = np.array(img)
        # Accept single-channel images (e.g. PIL mode "L") as well as RGB;
        # cvtColor would reject a 2-D array.
        if img_np.ndim == 2:
            img_gray = img_np
        else:
            img_gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)

        # Built per call rather than in __init__ so the transform stays
        # picklable for DataLoader worker processes.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        img_clahe = clahe.apply(img_gray)

        # Replicate to 3 channels: the ImageNet-pretrained backbone and the
        # 3-channel Normalize used in the transforms expect RGB-shaped input.
        img_clahe = cv2.merge([img_clahe, img_clahe, img_clahe])
        return Image.fromarray(img_clahe)


In [None]:
# Image-preprocessing pipelines. Inputs are normalised with the ImageNet
# channel statistics because the backbone starts from ImageNet weights.
# (ApplyCLAHE can be inserted after Resize; it is currently disabled.)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_INPUT_SIZE = (224, 224)

# Training: light photometric augmentation plus occasional Gaussian noise
train_transform = T.Compose(
    [
        T.Resize(_INPUT_SIZE),
        T.ColorJitter(brightness=0.1, contrast=0.1),
        T.ToTensor(),
        AddGaussianNoise(std=0.05, p=0.3),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Evaluation: deterministic resize and normalisation only
test_transform = T.Compose(
    [
        T.Resize(_INPUT_SIZE),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

In [None]:
# Root folder of the cropped PNG data. Set the MAMMOGRAM_DATA_ROOT environment
# variable (e.g. to a Kaggle input directory) instead of editing every path
# below; the default is the original local location.
data_root = os.environ.get(
    "MAMMOGRAM_DATA_ROOT",
    "/Users/giulia/Desktop/dissertation-mammogram-classification/mammogram-ai-project/Data/Data png cropped",
)

# Define paths to cropped calcification data
calc_train_label_dir = os.path.join(data_root, "Calc-Training-png-cropped", "labels", "calc-train_labels.csv")
calc_test_label_dir = os.path.join(data_root, "Calc-Test-png-cropped", "labels", "calc-test_labels.csv")

calc_train_img_dir = os.path.join(data_root, "Calc-Training-png-cropped", "images")
calc_test_img_dir = os.path.join(data_root, "Calc-Test-png-cropped", "images")

# Apply the transformations to the calcification images
calc_training_data = CustomImageDataset(calc_train_label_dir, calc_train_img_dir, train_transform)
calc_test_data = CustomImageDataset(calc_test_label_dir, calc_test_img_dir, test_transform)

# Define paths to cropped mass data
mass_train_label_dir = os.path.join(data_root, "Mass-Training-png-cropped", "labels", "mass-train_labels.csv")
mass_test_label_dir = os.path.join(data_root, "Mass-Test-png-cropped", "labels", "mass-test_labels.csv")

mass_train_img_dir = os.path.join(data_root, "Mass-Training-png-cropped", "images")
mass_test_img_dir = os.path.join(data_root, "Mass-Test-png-cropped", "images")

# Apply the transformations to the mass images
mass_training_data = CustomImageDataset(mass_train_label_dir, mass_train_img_dir, train_transform)
mass_test_data = CustomImageDataset(mass_test_label_dir, mass_test_img_dir, test_transform)

In [None]:
# Merge training datasets
combined_train_data = ConcatDataset([calc_training_data, mass_training_data])

# Merge test datasets
combined_test_data = ConcatDataset([calc_test_data, mass_test_data])

# Validation images come from the training CSVs but must NOT get the training
# augmentation (colour jitter, Gaussian noise). Build a second view of the
# same CSVs with the deterministic test_transform. Both views read the same
# files in the same order, so indices line up with combined_train_data.
combined_train_eval_data = ConcatDataset([
    CustomImageDataset(calc_train_label_dir, calc_train_img_dir, test_transform),
    CustomImageDataset(mass_train_label_dir, mass_train_img_dir, test_transform),
])

# Stratified split (80% train, 20% validation) so both sides keep the same
# class balance
val_size = 0.2
split_seed = 42

# Labels (column 2 of each CSV) in ConcatDataset order, for stratification
labels = []
for dataset in combined_train_data.datasets:
    labels.extend(dataset.img_labels.iloc[:, 2].tolist())

train_idx, val_idx = train_test_split(
    np.arange(len(combined_train_data)),
    test_size=val_size,
    stratify=labels,
    random_state=split_seed,
)

# Create subsets: augmented view for training, clean view for validation
train_dataset = Subset(combined_train_data, train_idx)
val_dataset = Subset(combined_train_eval_data, val_idx)

# Shuffle validation and test order once (fixed for all epochs), so that the
# first batches (used for the Grad-CAM samples in test()) do not all come
# from one CSV. Seeded so the order is reproducible between runs.
split_rng = np.random.default_rng(split_seed)
val_indices = split_rng.permutation(len(val_dataset))
test_indices = split_rng.permutation(len(combined_test_data))

shuffled_val_data = Subset(val_dataset, val_indices)
shuffled_test_data = Subset(combined_test_data, test_indices)

# Check dataset sizes
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(shuffled_val_data)}")
print(f"Total testing samples: {len(shuffled_test_data)}")

In [None]:
# Create DataLoaders
batch_size = 32
num_workers = 4
# Pinned host memory is what lets the `.to(device, non_blocking=True)` copies
# in train()/val() overlap with compute; it only helps when CUDA is present.
pin_memory = torch.cuda.is_available()

# Training batches are reshuffled every epoch
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
# shuffle=False: the validation/test order was fixed once when the Subsets
# were built, so batches stay identical across epochs and results are comparable
val_dataloader = DataLoader(shuffled_val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
test_dataloader = DataLoader(shuffled_test_data, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)


In [None]:
labels_map = {
    0: "Benign",
    1: "Malignant",
}

# ImageNet statistics used by T.Normalize in the transforms. They are needed
# to undo the normalisation so imshow receives values in [0, 1] instead of
# clipping them.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Show a 3x3 grid of random (augmented) training samples
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_dataset), size=(1,)).item()
    img, label = train_dataset[sample_idx]

    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[int(label)])
    plt.axis("off")

    if torch.is_tensor(img):
        # (C, H, W) normalised tensor -> (H, W, C) image in [0, 1]
        img = (img * display_std + display_mean).clamp(0, 1).permute(1, 2, 0)
    plt.imshow(img)
plt.show()

In [None]:
# training
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return ``(mean_batch_loss, accuracy_percent)``.

    Labels are reshaped to ``(batch, 1)`` floats to match the single-logit
    output. A sample counts as predicted class 1 when the sigmoid of its
    logit exceeds 0.5. Batches are moved to the module-level ``device``.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.train()

    loss_sum = 0.0
    n_correct = 0

    progress = tqdm(dataloader, total=n_batches, desc="Training", leave=True)
    for inputs, targets in progress:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True).view(-1, 1).float()

        # Forward pass and loss
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        loss_sum += loss.item()

        # Running accuracy at the 0.5 threshold
        predicted = (logits.sigmoid() > 0.5).float()
        n_correct += (predicted == targets).sum().item()

        # Parameter update; gradients are cleared after each step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    epoch_accuracy = n_correct / n_samples * 100
    # Mean of the per-batch losses
    epoch_loss = loss_sum / n_batches
    print(f"Epoch Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.2f}%")

    return epoch_loss, epoch_accuracy

In [None]:
# validation
def val(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Labels are reshaped to ``(batch, 1)`` floats. Predictions threshold the
    sigmoid of the single output logit at 0.5, and ROC AUC uses the sigmoid
    probabilities. Batches are moved to the module-level ``device``.

    Returns:
        ``(avg_batch_loss, accuracy_percent, precision, recall, f1, auc)``.
    """
    dataset_size = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()

    summed_loss = 0
    hits = 0
    y_true, y_pred, y_score = [], [], []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Testing", total=batch_count, leave=True)
        for inputs, targets in progress:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True).view(-1, 1).float()

            logits = model(inputs)
            scores = logits.sigmoid()
            decisions = (scores > 0.5).float()

            summed_loss += loss_fn(logits, targets).item()
            hits += (decisions == targets).sum().item()

            # Per-sample values moved to the CPU for scikit-learn
            y_true.extend(targets.cpu().numpy())
            y_pred.extend(decisions.cpu().numpy())
            y_score.extend(scores.cpu().numpy())

    avg_loss = summed_loss / batch_count
    accuracy = hits / dataset_size * 100
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_score)

    print(f"Test Error: \n Accuracy: {accuracy:.1f}%, Avg loss: {avg_loss:.6f}")
    print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1-score: {f1:.3f}, AUC: {auc:.3f}\n")

    return avg_loss, accuracy, precision, recall, f1, auc

In [None]:
# Module-level slots written by the Grad-CAM hooks below and read by
# generate_gradcam()
gradients = activations = None


def backward_hook(module, grad_input, grad_output):
    """Full-backward hook: remember the hooked layer's output gradients."""
    global gradients
    gradients = grad_output


def forward_hook(module, args, output):
    """Forward hook: remember the hooked layer's output feature map."""
    global activations
    activations = output

def generate_gradcam(model, image):
    """Compute a Grad-CAM heatmap for one image.

    Expects ``forward_hook``/``backward_hook`` to be registered on the target
    layer, so that the module-level ``activations`` and ``gradients`` are
    filled by the forward and backward passes below.

    Args:
        model: Network returning one logit per image.
        image: Input batch. The heatmap is squeezed, so this is meant to hold
            a single image (e.g. shape ``(1, C, H, W)``).

    Returns:
        Detached CPU tensor with the target layer's spatial shape and values
        in [0, 1]. An all-zero map (rather than NaNs) is returned when no
        location has a positive contribution.
    """
    global gradients, activations

    model.zero_grad()
    output = model(image)

    # Back-propagate the logit itself, so the map highlights evidence for
    # class 1 ("Malignant")
    output.backward(torch.ones_like(output))

    with torch.no_grad():
        # Channel weights: gradients averaged over batch and spatial dims
        pooled_gradients = torch.mean(gradients[0], dim=[0, 2, 3])
        # Weight channels out-of-place so the captured activations are left intact
        weighted = activations * pooled_gradients.view(1, -1, 1, 1)
        heatmap = torch.relu(torch.mean(weighted, dim=1).squeeze())
        max_val = torch.max(heatmap)
        # Avoid 0/0 -> NaN when every location is non-positive
        if max_val > 0:
            heatmap = heatmap / max_val

    return heatmap.detach().cpu()

def overlay_heatmap(img_tensor, heatmap):
    """Prepare an image and a coloured Grad-CAM map for overlay plotting.

    Args:
        img_tensor: Normalised ``(3, H, W)`` image tensor (ImageNet mean/std,
            as produced by this notebook's transforms).
        heatmap: 2-D float tensor in [0, 1], e.g. from ``generate_gradcam``.

    Returns:
        ``(original_img, overlay)``. ``original_img`` is an RGB PIL image.
        ``overlay`` is an ``(H, W, 3)`` uint8 jet-coloured array resized to
        the image size.
    """
    # Imported locally: these names are not imported at the top of the notebook
    from matplotlib import colormaps
    from torchvision.transforms.functional import to_pil_image

    unnorm_img = unnormalise(img_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    original_img = to_pil_image(unnorm_img.clamp(0, 1), mode='RGB')

    # Upsample the coarse layer-resolution heatmap to the image size
    # (PIL's resize takes (width, height))
    height, width = img_tensor.shape[-2:]
    overlay = to_pil_image(heatmap, mode='F').resize((width, height), resample=Image.BICUBIC)

    # Squaring compresses low values so the strongest regions stand out
    cmap = colormaps['jet']
    overlay = (255 * cmap(np.asarray(overlay) ** 2)[:, :, :3]).astype(np.uint8)

    return original_img, overlay

In [None]:
# Function to unnormalise images
def unnormalise(img_tensor, mean, std):
    """Invert ``T.Normalize``: return ``img_tensor * std + mean`` per channel.

    ``mean`` and ``std`` are per-channel sequences broadcast over the last two
    (spatial) dimensions. Both ``(C, H, W)`` and ``(N, C, H, W)`` inputs work
    for any channel count ``C``. The statistics are created on the input's
    device, so GPU tensors can be unnormalised directly. Floating inputs keep
    their dtype; integer inputs are promoted via float32 statistics.
    """
    stat_dtype = img_tensor.dtype if img_tensor.is_floating_point() else torch.float32
    mean = torch.as_tensor(mean, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=stat_dtype, device=img_tensor.device).view(-1, 1, 1)
    return img_tensor * std + mean


In [None]:
def _gradcam_target_layer(model, model_code):
    """Return the layer of ``model`` whose output feature maps feed Grad-CAM.

    Raises:
        ValueError: if ``model_code`` is not a supported architecture.
    """
    code = model_code.lower().replace("_", "").replace("-", "")
    if code == "densenet121":
        return model.features.denseblock4.denselayer16.conv2
    if code == "densenet169":
        return model.features.denseblock4.denselayer32.conv2
    if code.startswith("efficientnet"):
        # In torchvision's EfficientNet, features[-1] is the final 1x1 conv
        # block (1280 output channels for B0): the last spatial feature map
        # before global pooling.
        return model.features[-1]
    raise ValueError(f"Unsupported model_code for Grad-CAM: {model_code!r}")


def test(dataloader, model, model_code, num_rows=2):
    """Evaluate on the test set, plot a confusion matrix and Grad-CAM samples.

    Args:
        dataloader: Test DataLoader yielding ``(images, labels)``.
        model: Single-logit binary classifier (sigmoid > 0.5 => Malignant).
        model_code: Selects the Grad-CAM layer. Accepts ``"densenet121"``,
            ``"densenet169"`` or an EfficientNet code such as
            ``"efficientnetB0"`` (case/underscore insensitive).
        num_rows: Number of rows of 5 sample images to visualise.

    Returns:
        ``(accuracy_percent, precision, recall, f1, roc_auc)``.

    Raises:
        ValueError: if ``model_code`` is not supported.
    """
    # Imported locally: not imported at the top of the notebook
    from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

    print("\nEvaluating on Test Set...")
    num_batches = len(dataloader)

    model.eval()

    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    all_probs = []

    sample_images = []
    sample_heatmaps = []
    sample_labels = []
    sample_preds = []
    sample_confidences = []

    num_examples = num_rows * 5

    # Resolve the layer first so an unsupported model_code fails with a clear
    # error. Keep the hook handles so the hooks are removed afterwards rather
    # than accumulating on the layer with every call.
    last_conv_layer = _gradcam_target_layer(model, model_code)
    hook_handles = [
        last_conv_layer.register_full_backward_hook(backward_hook),
        last_conv_layer.register_forward_hook(forward_hook),
    ]

    try:
        with torch.no_grad():
            for X, y in tqdm(dataloader, desc="Testing", total=num_batches, leave=True):
                X, y = X.to(device), y.to(device).view(-1, 1).float()
                pred = model(X)
                prob = pred.sigmoid()
                pred_labels = (prob > 0.5).float()

                correct += (pred_labels == y).sum().item()
                total += y.size(0)

                all_labels.extend(y.cpu().numpy())
                all_preds.extend(pred_labels.cpu().numpy())
                all_probs.extend(prob.cpu().numpy())

                # Collect the first `num_examples` images with Grad-CAM maps
                # (the range is empty once enough samples are stored)
                n_needed = num_examples - len(sample_images)
                for i in range(min(n_needed, X.shape[0])):
                    img_tensor = X[i].cpu()
                    # Grad-CAM needs gradients, so re-enable them locally
                    with torch.set_grad_enabled(True):
                        heatmap = generate_gradcam(model, X[i].unsqueeze(0))

                    original_img, overlay_img = overlay_heatmap(img_tensor, heatmap)

                    sample_images.append(original_img)
                    sample_heatmaps.append(overlay_img)
                    sample_labels.append(int(y[i].item()))
                    sample_preds.append(int(pred_labels[i].item()))
                    # Confidence: distance of the probability from the 0.5
                    # threshold, rescaled to 0-100 %
                    sample_confidences.append(abs(prob[i].item() - 0.5) * 200)
    finally:
        for handle in hook_handles:
            handle.remove()

    # Compute final metrics
    accuracy = correct / total * 100
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)

    print(f"Test Accuracy: {accuracy:.2f}%")
    print(f"Precision: {precision:.3f}, Recall: {recall:.3f}, F1-score: {f1:.3f}, AUC: {auc:.3f}")

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Benign", "Malignant"])
    disp.plot(cmap="Blues", values_format="d")
    plt.title("Confusion Matrix")
    plt.show()

    n_shown = len(sample_images)
    if n_shown == 0:
        return accuracy, precision, recall, f1, auc

    # Original images (even rows) and Grad-CAM overlays (odd rows)
    fig, axes = plt.subplots(num_rows * 2, 5, figsize=(15, 6 * num_rows), squeeze=False)
    fig.suptitle("Sample Predictions with Grad-CAM", fontsize=16)

    for i in range(num_examples):
        row = (i // 5) * 2
        col = i % 5

        if i >= n_shown:
            # Fewer test images than requested: leave the cells blank
            axes[row, col].axis("off")
            axes[row + 1, col].axis("off")
            continue

        axes[row, col].imshow(sample_images[i], cmap="gray")
        true_label = "Malignant" if sample_labels[i] == 1 else "Benign"
        pred_label = "Malignant" if sample_preds[i] == 1 else "Benign"
        confidence = sample_confidences[i]

        axes[row, col].set_title(f"True: {true_label}\nPred: {pred_label}\nConf: {confidence:.1f}%")
        axes[row, col].axis("off")

        axes[row + 1, col].imshow(sample_images[i], cmap="gray")
        axes[row + 1, col].imshow(sample_heatmaps[i], alpha=0.4, interpolation="nearest")
        axes[row + 1, col].set_title("Grad-CAM")
        axes[row + 1, col].axis("off")

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

    return accuracy, precision, recall, f1, auc

In [18]:
# Report the CUDA device, if any. get_device_name()/current_device() raise on
# machines without a usable GPU, so only query them when CUDA is available.
print("GPU Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
    print("Current Device:", torch.cuda.current_device())

GPU Available: True
GPU Name: Tesla T4
Current Device: 0


In [None]:
# cuDNN: keep it enabled (the default, stated explicitly) and let it
# benchmark convolution algorithms, which pays off because every input is
# resized to the same 224x224 shape.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

In [None]:
# Release cached GPU memory and reclaim memory freed through CUDA IPC.
# Guarded because torch.cuda.ipc_collect() initialises CUDA and raises on
# machines without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

# Using pretrained EfficientNet-B0 (ImageNet-1k weights)
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

# Replace the classifier with a single logit for binary classification.
# The feature width is read from the stock head instead of hard-coding 1280.
num_features = model.classifier[-1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(in_features=num_features, out_features=1),
)

learning_rate = 1e-5
weight_decay = 1e-4
epochs = 10

# Move model to GPU if available (train()/val()/test() read this global)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(device)

# Loss on raw logits; pos_weight > 1 up-weights the positive class
# (label 1, "Malignant").
# NOTE(review): 1.7 is presumably the benign/malignant ratio of the training
# data; confirm.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.7], device=device))
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Learning-rate schedule: divide the LR by 10 every 3 epochs.
# Alternative tried: CosineAnnealingLR(optimizer, T_max=epochs)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

cuda:0


In [None]:
# Per-epoch metric history, used for the plots below
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
val_precisions = []
val_recalls = []
val_f1s = []
val_aucs = []

# Early stopping: stop once the validation loss has not improved for
# `patience` consecutive epochs
patience = 3
patience_counter = 0

# The checkpoint with the lowest validation loss is written here
best_model_path = "best_model_eb0.pth"

# Best values seen so far. Checkpoint selection uses the loss; the AUC is
# tracked for reporting only.
best_val_loss, best_val_auc = float("inf"), 0.0

for t in range(epochs):
    print(f"\nEpoch {t+1}\n-------------------------------")

    start_time = time.time()

    train_loss, train_accuracy = train(train_dataloader, model, loss_fn, optimizer)
    val_loss, val_acc, val_prec, val_rec, val_f1, val_auc = val(val_dataloader, model, loss_fn)

    epoch_minutes = (time.time() - start_time) / 60
    print(f"Epoch {t+1} took {epoch_minutes:.2f} minutes")

    # Store metrics for plotting
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_precisions.append(val_prec)
    val_recalls.append(val_rec)
    val_f1s.append(val_f1)
    val_aucs.append(val_auc)

    best_val_auc = max(best_val_auc, val_auc)

    # Early stopping and best model saving
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print(f"New best model saved with loss: {val_loss:.4f}, accuracy: {val_acc:.2f}%, AUC: {val_auc:.3f}\n")
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {t+1}")
        break

    # Step the scheduler once per epoch (StepLR counts epochs)
    scheduler.step()
    print(f"Epoch {t+1} completed, LR: {scheduler.get_last_lr()}")

print("Training Done!")
print(f"Best validation loss: {best_val_loss:.4f}, best validation AUC: {best_val_auc:.3f}")

# x-axis in 1-based epoch numbers, matching the printed "Epoch N" labels
epoch_numbers = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 6))

# Loss
plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, train_losses, label="Training Loss", marker="o")
plt.plot(epoch_numbers, val_losses, label="Validation Loss", marker="o")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training & Validation Loss")
plt.legend()

# Accuracy
plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, train_accuracies, label="Training Accuracy", marker="o")
plt.plot(epoch_numbers, val_accuracies, label="Validation Accuracy", marker="o")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Training & Validation Accuracy")
plt.legend()

plt.show()

In [None]:
# Rebuild the architecture and load the best checkpoint saved during training
best_model_path = "best_model_eb0.pth"

# EfficientNet-B0 skeleton; the weights come from the checkpoint
best_model = efficientnet_b0(weights=None)

# Same head as in training. Dropout is inactive in eval mode, but matching it
# keeps the two definitions consistent. Width is read from the stock head.
num_features = best_model.classifier[-1].in_features
best_model.classifier = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(in_features=num_features, out_features=1),
)

# Move model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
best_model = best_model.to(device)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
best_model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
best_model.eval()

# Run final test
test_accuracy, test_precision, test_recall, test_f1, test_auc = test(test_dataloader, best_model, model_code="efficientnetB0", num_rows=5)