# End-to-End Deep Learning Model Training and Evaluation Pipeline


## Overview

This document outlines a comprehensive procedure for training, evaluating, and testing a deep learning model using the SwinUNETR architecture. The workflow includes initialising the dataset and dataloader, setting up the model, training with early stopping, evaluating performance using metrics like SSIM and PSNR, and saving the results for further analysis.

### Key Steps:

1. **Data Preparation**: Loading and preprocessing the dataset for both training and validation.
2. **Model Initialisation**: Setting up the SwinUNETR model architecture and loading pre-trained weights.
3. **Training**: Executing the training loop with early stopping and logging the progress.
4. **Evaluation**: Running the trained model on validation data to compute key metrics.
5. **Testing**: Testing the model on a subset of the validation data and saving the results.
6. **Results Saving**: Exporting the performance metrics and model state for future reference and analysis.

This pipeline is designed to be modular and flexible, allowing for easy adjustments to hyperparameters, model architecture, and evaluation criteria.


## Importing Necessary Libraries

This section imports the necessary libraries for logging, file handling, system operations, and deep learning model development.


In [None]:
# Install tensorboardX, which is used for logging and visualizing training metrics in TensorBoard
!pip install tensorboardX

# Install MONAI (Medical Open Network for AI) version 1.3.2 from the conda-forge channel.
# MONAI is a framework for developing deep learning models in medical imaging.
!conda install -y -c conda-forge monai=1.3.2

In [None]:
# Importing necessary libraries for logging, file handling, and system operations
import logging  # Provides a flexible framework for emitting log messages from Python programs.
import os  # Provides a way of using operating system-dependent functionality, such as reading or writing to the file system.
import shutil  # Used for high-level file operations, such as copying or removing files and directories.
import sys  # Provides access to some variables used or maintained by the Python interpreter and to functions that interact with the interpreter.
import tempfile  # Used to create temporary files and directories.
import random  # Implements pseudo-random number generators for various distributions.
import numpy as np  # Fundamental package for scientific computing with Python, providing support for large, multi-dimensional arrays and matrices.
from tqdm import trange  # Provides a progress bar for loops, making it easier to track long-running tasks.
import matplotlib.pyplot as plt  # A plotting library used for creating static, animated, and interactive visualizations in Python.
import torch  # PyTorch, an open-source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing.

# Importing functions and classes from the MONAI library, a deep learning framework specialized for healthcare imaging
from monai.apps import download_and_extract  # Utility to download and extract files, particularly useful for datasets.
from monai.apps import download_url
from monai.config import print_config  # Prints the current configuration of MONAI, including the environment, installed packages, and versions.
from monai.data import CacheDataset, DataLoader  # CacheDataset caches data and is useful for datasets that fit in memory. DataLoader is used to load the data in batches.
from monai.networks.nets import AutoEncoder, SwinUNETR  # AutoEncoder is a network for unsupervised representation learning; SwinUNETR is the Swin-transformer-based encoder-decoder used in this notebook.
from monai.transforms import (  # Importing various transformations to be applied to the data.
    EnsureChannelFirstD,  # Ensures the channel dimension is first in the data tensor.
    Compose,  # Allows the chaining of multiple transformations to be applied sequentially.
    LoadImageD,  # Loads images from a file.
    RandFlipD,  # Randomly flips the image along a specified axis.
    RandRotateD,  # Randomly rotates the image within a specified angle range.
    RandZoomD,  # Randomly zooms in or out of the image within a specified range.
    ScaleIntensityD,  # Scales the intensity of the image to a specified range.
    EnsureTypeD,  # Ensures the output is of a specific data type.
    Lambda,  # Allows for custom transformations using a lambda function.
)
from monai.utils import set_determinism  # Sets the seed for random number generators to ensure reproducibility.
from monai.networks.utils import copy_model_state  # Utility function for copying the model state.
from monai.networks.nets.swin_unetr import filter_swinunetr  # Key filter used when copying pre-trained weights into a SwinUNETR model.
from monai.transforms import RandSpatialCropd  # Randomly crops a portion of the image.

from torch.amp import autocast, GradScaler  # Mixed precision training utilities
import math  # Provides access to mathematical functions.
import warnings  # Used to issue warning messages.
from typing import List  # Allows for type hinting, specifying that a variable is a list.

from torch import nn as nn  # PyTorch module containing all neural network layers and functions.
from torch.optim import Adam, Optimizer  # Adam optimizer and base optimizer class from PyTorch.
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler  # Learning rate scheduler and base class for custom schedulers.
import glob  # For finding all file paths matching a specified pattern.
import pydicom  # For reading, modifying, and writing DICOM files.
from tensorboardX import SummaryWriter  # Writes TensorBoard-compatible logs for PyTorch models.
from torch.utils.data import Dataset

# Print the current MONAI configuration
print_config()  # This line will print the MONAI configuration, including environment details and installed packages.

import torch.optim as optim  # Optimizer functions from PyTorch.
from torch.utils.data import DataLoader  # DataLoader class to load data in batches.
from torch.utils.tensorboard import SummaryWriter  # Writes TensorBoard logs, similar to tensorboardX but included in PyTorch.
from tqdm import tqdm  # Provides a progress bar for loops, similar to trange.
import time  # Provides various time-related functions.
import torch.nn.functional as F  # Contains functions used in building neural networks, such as activation functions and loss functions.
import psutil  # For memory tracking.
import gc  # Python's garbage collection module, used to manually free up memory.
# autocast is already imported from torch.amp above; importing torch.cuda.amp.autocast here would shadow it and break the autocast(device_type=...) calls used later.
import pandas as pd  # For data manipulation and analysis.

# Standard library imports
from pathlib import Path  # To handle and manipulate filesystem paths.

# Third-party imports
from PIL import Image  # For opening, manipulating, and saving many different image file formats.

# PyTorch-I/O extension
import torchio as tio  # For medical image processing in PyTorch.

# pydicom imports
from pydicom.data import get_testdata_file  # For accessing test DICOM files.
from pydicom.fileset import FileSet  # For working with DICOM FileSets.

# Scikit-learn imports
from sklearn.model_selection import train_test_split  # For splitting datasets into training and testing sets.

from collections import defaultdict  # For creating dictionaries with default values.
from monai.transforms import apply_transform  # Applies a transform to data.
from piqa import PSNR, SSIM


In [None]:
torch.cuda.empty_cache()  # Releases all unused memory cached by the CUDA backend, freeing up GPU memory.

In [3]:
# Set up basic logging configuration to output log messages to the console (stdout).
# The logging level is set to INFO, meaning all messages at this level and above
# (INFO, WARNING, ERROR, CRITICAL) will be displayed.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


# Set deterministic behavior for reproducibility. This is important in experiments where
# you want to ensure that the results are the same every time the code is run. 
# The seed value is set to 0, which will be used to initialize the random number generator.
set_determinism(0)

# Determine the device to run the computations on. If a CUDA-capable GPU is available,
# the device will be set to "cuda" (meaning GPU); otherwise, it will fall back to "cpu".
# This allows the code to take advantage of GPU acceleration if possible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Defining Transformations for Training and Testing

In this section, we define the transformation applied to the training and testing datasets. Both use `RandSpatialCropd`, which crops a randomly positioned, fixed-size 3D region from the volume stored under the `im` key.
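
As a rough illustration of what a fixed-size random crop does, the sketch below re-implements the idea in plain NumPy. It is not MONAI's implementation: `random_spatial_crop` and the array shapes are hypothetical. The sketch limits each ROI edge to the image size along that axis, which is also how MONAI's crop is understood to behave, so a 32-slice volume cropped with `roi_size=(96, 96, 96)` comes out as 32 x 96 x 96.

```python
import numpy as np

def random_spatial_crop(volume, roi_size, rng):
    """Crop a random fixed-size region from a channel-first array (C, D, H, W)."""
    spatial_shape = volume.shape[1:]
    # Limit each ROI edge to the image size along that axis.
    size = [min(r, s) for r, s in zip(roi_size, spatial_shape)]
    # Pick a random start index per axis so the crop stays inside the volume.
    starts = [int(rng.integers(0, s - c + 1)) for s, c in zip(spatial_shape, size)]
    slices = tuple(slice(st, st + c) for st, c in zip(starts, size))
    return volume[(slice(None),) + slices]

rng = np.random.default_rng(0)
volume = np.zeros((1, 32, 512, 512), dtype=np.float32)  # hypothetical 32-slice scan
patch = random_spatial_crop(volume, roi_size=(96, 96, 96), rng=rng)
print(patch.shape)  # (1, 32, 96, 96)
```

The resulting patch shape matches the `img_size=(32, 96, 96)` that the model is configured with later in the notebook.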


In [None]:
# Define transformations for training and testing datasets
# RandSpatialCropd randomly crops a 3D region from the image. 
# - keys=["im"]: Specifies the key in the dictionary that holds the image data.
# - roi_size=(96, 96, 96): Specifies the size of the region of interest (ROI) to be cropped (in 3D dimensions).
# - random_center=True: Ensures that the center of the cropped region is selected randomly.
# - random_size=False: The size of the cropped region is fixed, not random.

train_transforms = RandSpatialCropd(keys=["im"], roi_size=(96, 96, 96), random_center=True, random_size=False)
test_transforms = RandSpatialCropd(keys=["im"], roi_size=(96, 96, 96), random_center=True, random_size=False)

## Data Preparation: Splitting Labels into Training and Validation Sets

This section of the code reads label data from a CSV file, splits the data into training and validation sets, and saves these splits as separate CSV files for future reference.
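
To make the effect of the `stratify` argument concrete, here is a minimal sketch on a made-up label table. The patient IDs and class balance are hypothetical; only the column names follow the notebook.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical label table: 20 patients, 5 of whom progressed (25%).
labels = pd.DataFrame({
    "patient ID": [f"P{i:03d}" for i in range(20)],
    "progression": [1] * 5 + [0] * 15,
})

# Stratifying on 'progression' keeps the class ratio the same in both splits.
train_part, valid_part = train_test_split(
    labels, test_size=0.2, random_state=42, stratify=labels["progression"]
)

print(len(train_part), len(valid_part))  # 16 4
print(train_part["progression"].mean(), valid_part["progression"].mean())  # 0.25 0.25
```

Without stratification, a small validation set like this one could easily end up with no progressing patients at all.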


In [None]:
# Define the base directory and CSV path
training_data_dir = r'c:\Users\scmmw\OneDrive - University of Leeds\t1_vibe_we_hand_subset'
csv_path = r'C:\Users\scmmw\OneDrive - University of Leeds\t1_vibe_we_hand_subset\training_labels_subset.csv'

# Read the labels DataFrame
labels_df = pd.read_csv(csv_path)

# Split the data into training and validation sets
train_df, valid_df = train_test_split(labels_df, test_size=0.2, random_state=42, stratify=labels_df['progression'])

# Save the splits for reference (optional)
train_df.to_csv(os.path.join(training_data_dir, 'train_split.csv'), index=False)
valid_df.to_csv(os.path.join(training_data_dir, 'valid_split.csv'), index=False)


In [None]:
# Create a list of full paths for training patient directories.
# os.path.join() is used to concatenate the base directory with the patient IDs from the training DataFrame.
train_patient_dirs = [os.path.join(training_data_dir, subject_name) for subject_name in train_df['patient ID'].tolist()]

# Create a list of full paths for validation patient directories.
# os.path.join() is used to concatenate the base directory with the patient IDs from the validation DataFrame.
valid_patient_dirs = [os.path.join(training_data_dir, subject_name) for subject_name in valid_df['patient ID'].tolist()]

## HandScanDataset2: Custom PyTorch Dataset Class

This section defines the `HandScanDataset2` class, a custom PyTorch `Dataset` for handling and processing hand scan images stored in DICOM format. The class includes methods for loading, transforming, and preparing the data for use in deep learning models.
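
The two core ideas of the class (intensity normalisation with 95% clipping, and selecting a window of slices around the brightest one before zero-padding) can be sketched in NumPy as follows. This is a simplified illustration under stated assumptions, not the class itself: the function names are hypothetical, slices are indexed by list position rather than by DICOM Instance Number, and NumPy's population standard deviation is used where the notebook uses `torch.Tensor.std`.

```python
import numpy as np

def normalise_and_clip(image, lower_q=0.025, upper_q=0.975, eps=1e-7):
    """Z-score an image, clip to the central 95% of values, then z-score again."""
    image = (image - image.mean()) / (image.std() + eps)
    lo, hi = np.quantile(image, [lower_q, upper_q])
    image = np.clip(image, lo, hi)
    return (image - image.mean()) / (image.std() + eps)

def select_window(slices, half_width=2, seq_len=32):
    """Keep the slices around the brightest one and zero-pad to seq_len."""
    sums = [s.sum() for s in slices]
    centre = int(np.argmax(sums))  # position of the slice with the highest total intensity
    start = max(0, centre - half_width)
    end = min(len(slices), centre + half_width + 1)
    window = [normalise_and_clip(s) for s in slices[start:end]]
    padding = [np.zeros_like(slices[0]) for _ in range(seq_len - len(window))]
    return np.stack(window + padding, axis=0)

rng = np.random.default_rng(0)
# Hypothetical series: 10 slices of 8x8 pixels, with slice 6 made the brightest.
series = [rng.random((8, 8)) for _ in range(10)]
series[6] = series[6] + 5.0
volume = select_window(series)
print(volume.shape)  # (32, 8, 8)
```

Here only the first five entries of `volume` hold image data; the remaining 27 are zero padding.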


In [None]:


class HandScanDataset2(Dataset):
    def __init__(self, data_dir, transform=None, device=None):
        """
        Args:
            data_dir (list): List of paths to the patient directories.
            transform (callable, optional): Optional transform to be applied on a sample.
            device (torch.device, optional): Device to use for tensor operations (e.g., 'cuda' or 'cpu').
        """
        self.data_dir = sorted(data_dir)  # Sort the list of patient directories for consistent ordering.
        self.transform = transform  # Store the transform function if provided.
        self.device = device if device else torch.device('cpu')  # Use the given device, otherwise default to the CPU.

    def __len__(self):
        return len(self.data_dir)  # Return the total number of patient directories.

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()  # Convert tensor index to a list if necessary.

        # Get the patient directory based on the index
        patient_dir = self.data_dir[idx]

        # Load and process images from the patient directory
        images = self.get_best_patient_images(patient_dir)

        if len(images) == 0:
            raise ValueError(f"No images found for patient directory {patient_dir}")

        # Autocast context kept from the original code; the conversions below explicitly produce float32 tensors.
        with torch.amp.autocast(device_type=self.device.type, dtype=torch.float16):
            # Convert each slice to a float32 tensor on the target device
            images = [
                torch.tensor(img, dtype=torch.float32).to(self.device) if not isinstance(img, torch.Tensor) else img.to(self.device).float()
                for img in images
            ]
            
            images_tensor = torch.stack(images, dim=0).to(self.device)  # Stack images along a new dimension.
            images_tensor_channel = torch.unsqueeze(images_tensor, 0)  # Add a channel dimension.

        # Prepare the data dictionary with the image tensor
        data = {"im": images_tensor_channel}

        # Apply any provided transforms
        if self.transform:
            data = self.transform(data)  # Assuming transform takes and returns a dictionary.

        return data

    def get_best_patient_images(self, base_path):
        """ 
        Read the DICOM files in the subject's 't1_vibe_we' subfolder, sort them by
        Instance Number, keep up to five slices centred on the brightest slice, and
        zero-pad the result to a fixed sequence length (seq_len).
        """
        seq_len = 32  # Desired sequence length.
        all_images = []
        dicom_files = [] 
        target_size = (512, 512)  # Set a fixed image shape.

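        for root, dirs, files in os.walk(base_path):
            # Reset per directory so a series found earlier in the walk is not processed again.
            dicom_files = []
            if 't1_vibe_we' in dirs: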
                t1_vibe_we_path = os.path.join(root, 't1_vibe_we')
                
                # Get the images in the 't1_vibe_we' sequence
                dicom_files = []
                for image_path in glob.glob(os.path.join(t1_vibe_we_path, '*')):
                    try:
                        dicom_file = pydicom.dcmread(image_path)
                        dicom_files.append((dicom_file, image_path))
                    except Exception as e:
                        print(f"Error reading {image_path}: {e}")

            # Sort the files by Instance Number
            dicom_files.sort(key=lambda x: x[0].InstanceNumber)
            
            # Remove duplicates
            dicom_files = self.remove_duplicates(dicom_files)

            # Find the best slice
            if dicom_files:
                # Find the slice with the highest intensity
                max_sum = -1
                best_dicom_file, best_image_path = None, None
                for dicom_file, image_path in dicom_files:
                    image = dicom_file.pixel_array
                    image_sum = np.sum(image)
                    if image_sum > max_sum:
                        max_sum = image_sum
                        best_dicom_file, best_image_path = dicom_file, image_path

                if best_dicom_file is not None:
                    best_instance_number = best_dicom_file.InstanceNumber

                    # Calculate the central slice index
                    central_index = best_instance_number - 1  # InstanceNumber is 1-based.

                    # Determine the range of slices to extract the central 5 slices
                    start_index = max(0, central_index - 2)
                    end_index = min(len(dicom_files), central_index + 3)

                    # Extract the central 5 slices
                    selected_slices = dicom_files[start_index:end_index]

                    images = []
                    for dicom_file, image_path in selected_slices:
                        try:
                            image = self.process_dicom_image(image_path, target_size=target_size)
                            images.append(image)
                        except Exception as e:
                            print(f"Error processing image {image_path}: {e}")

                    # Pad to the required sequence length if needed
                    if len(images) < seq_len:
                        # Pad with zero images of the same shape as the original images
                        diff = seq_len - len(images)
                        images.extend([torch.zeros(target_size, dtype=torch.float32).to(self.device) for _ in range(diff)])

                    all_images.extend(images)

        return all_images

    def remove_duplicates(self, dicom_files):
        """ Remove duplicate instance numbers, keeping only the slice with the highest sum of intensities. """
        instance_dict = defaultdict(list)

        for dicom_file, image_path in dicom_files:
            instance_number = dicom_file.InstanceNumber
            instance_dict[instance_number].append((dicom_file, image_path))

        # Compare DICOM files with the same Instance Number
        unique_dicom_files = []
        for instance_number, files in instance_dict.items():
            if len(files) > 1:
                best_slice = self.find_best_slice(files)
                unique_dicom_files.append(best_slice)
            else:
                unique_dicom_files.append(files[0])

        return unique_dicom_files

    def find_best_slice(self, dicom_files):
        """ Find the slice with the 'DOTAREM' ContrastBolusAgent or, as a fallback, return the first available slice. """
        best_slice = None

        # Check for the slice with 'DOTAREM'
        for dicom_file, image_path in dicom_files:
            if hasattr(dicom_file, 'ContrastBolusAgent') and dicom_file.ContrastBolusAgent == 'DOTAREM':
                best_slice = (dicom_file, image_path)
                break  # Stop searching once we find the 'DOTAREM' slice

        # Fallback: If no slice with 'DOTAREM' is found, return the first slice
        if best_slice is None:
            best_slice = dicom_files[0]

        return best_slice

    def process_dicom_image(self, path: str, resize=True, target_size=(512, 512)) -> torch.Tensor:
        dicom_file = pydicom.dcmread(path)
        image = torch.tensor(dicom_file.pixel_array, dtype=torch.float32).to(self.device)
        
        # Skip invalid images
        if 0 in image.shape:
            print(f"Skipping image due to invalid shape: {image.shape}")
            return torch.zeros(target_size, dtype=torch.float32).to(self.device)
        
        # Use autocast to optimize tensor operations
        with torch.amp.autocast(device_type=self.device.type, dtype=torch.float16):
            # Normalize the image: Zero mean and unit variance
            mean = image.mean()
            std = image.std()
            image = (image - mean) / (std + 1e-7)

            # Apply 95% clipping
            lower_bound = torch.quantile(image, 0.025)
            upper_bound = torch.quantile(image, 0.975)
            image = torch.clamp(image, lower_bound, upper_bound)

            # Normalize again after clipping
            mean = image.mean()
            std = image.std()
            image = (image - mean) / (std + 1e-7)

            # Resize the image to the target size
            if resize:
                image = image.unsqueeze(0)  # Add channel dimension for resizing
                image = torch.nn.functional.interpolate(image.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
        
        return image


In [None]:
# Define the device to be used for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [8]:
batch_size = 1  # Set the batch size to 1 for loading data one sample at a time.
num_workers = 0  # Set the number of worker processes for data loading. 0 means data will be loaded in the main process.

# Select a subset of subjects from the training set (e.g., the first 80 subjects).
# Despite the "_df" suffix, this is a plain list of directory paths.
train_subset_df = train_patient_dirs[:80]

# Initialize the HandScanDataset2 dataset with the selected subjects for training.
# The dataset applies the specified transformations and uses the device selected above
# (GPU if available, otherwise CPU) for tensor operations.
train_ds = HandScanDataset2(data_dir=train_subset_df, transform=train_transforms, device=device)

# Create a DataLoader for the training dataset.
# The DataLoader will iterate over the dataset in batches, shuffling the data to ensure randomness.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Select a subset of subjects from the validation set (e.g., the first 20 subjects).
valid_subset_df = valid_patient_dirs[:20]

# Initialize the HandScanDataset2 dataset with the selected subjects for validation.
# Note: despite its name, train_dataset holds the validation subset.
# The dataset applies the specified transformations and uses the device selected above.
train_dataset = HandScanDataset2(data_dir=valid_subset_df, transform=test_transforms, device=device)

# Create a DataLoader for the validation dataset.
# The DataLoader will iterate over the dataset in batches, without shuffling, to preserve the order of data.
test_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


## SwinUNETR Model Initialisation and Loading Pre-Trained Weights

This section demonstrates how to initialise a SwinUNETR model with specific parameters and load pre-trained weights into the model. The weights are downloaded from a remote resource, and only the layers that match the pre-trained model are updated.
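
The idea of updating only the matching layers can be sketched with plain dictionaries. This is an illustration of the principle, not MONAI's `copy_model_state`: the function and parameter names are hypothetical, NumPy arrays stand in for tensors, and the key filtering done by `filter_swinunetr` is left out.

```python
import numpy as np

def copy_matching_weights(dst_state, src_state):
    """Copy entries whose name and shape match; report what was and was not updated."""
    updated, unchanged = [], []
    merged = dict(dst_state)
    for name, value in dst_state.items():
        src = src_state.get(name)
        if src is not None and src.shape == value.shape:
            merged[name] = src.copy()
            updated.append(name)
        else:
            unchanged.append(name)
    return merged, updated, unchanged

# Hypothetical parameter names and shapes, for illustration only.
model_state = {
    "encoder.weight": np.zeros((4, 4)),
    "decoder.weight": np.zeros((4, 2)),
    "out.bias": np.zeros(2),
}
pretrained_state = {
    "encoder.weight": np.ones((4, 4)),  # same name and shape: copied
    "decoder.weight": np.ones((8, 2)),  # shape mismatch: skipped
}
merged, updated, unchanged = copy_matching_weights(model_state, pretrained_state)
print(updated)    # ['encoder.weight']
print(unchanged)  # ['decoder.weight', 'out.bias']
```

Checking the two returned name lists after loading is a quick way to confirm that the pre-trained weights reached the layers you expected.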


In [None]:
# Path where the pre-trained weights will be saved
pre_training_weights_path = 'C:/Users/scmmw/OneDrive - University of Leeds/ssl_pretrained_weights'

# Initialize the SwinUNETR model with the specific parameters to match the pre-trained model
swin_model = SwinUNETR(
    img_size=(32, 96, 96),  # Input image size
    in_channels=1,  # Number of input channels (e.g., grayscale images)
    out_channels=1,  # Number of output channels (a single-channel output, compared against the input in the training loss)
    feature_size=48,  # Size of features to match the pre-trained model
    depths=(2, 2, 2, 2),  # Depths of the layers in the model
    num_heads=(3, 6, 12, 24),  # Number of attention heads in each layer
    spatial_dims=3,  # The model operates in 3D space
    use_checkpoint=True  # Enables gradient checkpointing to save memory during training
)
swin_model = swin_model.to(torch.float32)  # Ensure the model operates in 32-bit floating-point precision
swin_model = swin_model.to(device)  # Move the model to the specified device (e.g., GPU)

# URL of the pre-trained weights to download
resource = (
    "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
)

# Download the pre-trained weights from the given URL
download_url(resource, pre_training_weights_path)

# Load the downloaded weights into memory
ssl_weights = torch.load(pre_training_weights_path, map_location=device)["model"]

# Copy the pre-trained weights into the SwinUNETR model, filtering layers as needed
dst_dict, loaded, not_loaded = copy_model_state(swin_model, ssl_weights, filter_func=filter_swinunetr)


In [None]:
# Freeze specific layers if you don't want them to be updated during fine-tuning
for name, param in swin_model.named_parameters():
    if "swinViT.layers1" in name or "swinViT.layers2" in name:
        param.requires_grad = False  # Freeze these layers


# Freeze the patch-embedding parameters as well
for name, param in swin_model.named_parameters():
    if "patch_embed" in name:
        param.requires_grad = False  # Freeze these layers


for name, param in swin_model.named_parameters():
    if 'encoder' in name and 'encoder4' not in name and 'encoder10' not in name:
        param.requires_grad = False  # Freeze these layers

# Freeze some parts of the decoder if needed
for name, param in swin_model.named_parameters():
    if 'decoder' in name and 'decoder4' not in name and 'decoder5' not in name:
        param.requires_grad = False

# Print which layers will be fine-tuned
for name, param in swin_model.named_parameters():
    if param.requires_grad:
        print(f"Fine-tuning layer: {name}")
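
The combined effect of the freezing rules above can be checked without building the model. The sketch below applies the same substring tests to a handful of parameter names. The names are illustrative assumptions; inspect `swin_model.named_parameters()` for the real ones.

```python
def is_trainable(name):
    """Apply the notebook's name-based freezing rules to a parameter name."""
    if "swinViT.layers1" in name or "swinViT.layers2" in name:
        return False
    if "patch_embed" in name:
        return False
    if "encoder" in name and "encoder4" not in name and "encoder10" not in name:
        return False
    if "decoder" in name and "decoder4" not in name and "decoder5" not in name:
        return False
    return True

# Illustrative parameter names (assumed, not read from a real model).
names = [
    "swinViT.patch_embed.proj.weight",
    "swinViT.layers1.0.blocks.0.attn.qkv.weight",
    "swinViT.layers3.0.blocks.0.attn.qkv.weight",
    "encoder1.layer.conv1.conv.weight",
    "encoder10.layer.conv1.conv.weight",
    "decoder5.transp_conv.conv.weight",
    "decoder1.transp_conv.conv.weight",
    "out.conv.conv.weight",
]
trainable = [n for n in names if is_trainable(n)]
for n in trainable:
    print(n)
```

Under these rules, anything that matches none of the patterns (for example the output head in this list) stays trainable by default.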

In [None]:
# Define path to where I will save results
ResultsPath = '/Users/eleanorbolton/Documents/project_repo/predicting-rheumatoid-arthritis-progression-using-hand-mri-scans-from-at-risk-patients/Predicting-Rheumatoid-Arthritis-Progression-Using-Hand-MRI-Scans-from-At-Risk-Patients/predicting-rheumatoid-arthritis-progression-using-hand-mri-scans-from-at-risk-patients/src/results'
os.makedirs(ResultsPath, exist_ok=True)
ResultsPath

## Training Function with SwinUNETR and Early Stopping

This section defines an `EarlyStopper` helper and a `train` function that handles the training of the SwinUNETR model using PyTorch. It includes features like mixed precision training, early stopping, and logging to TensorBoard.
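
The patience rule behind early stopping can be traced on a short list of losses. The sketch below is a compact stand-in with the same comparison logic as the `EarlyStopper` class; the function name and the loss values are hypothetical. A loss that is no better than the best value but still within `min_delta` of it neither resets nor advances the counter.

```python
def first_stop_epoch(val_losses, patience, min_delta):
    """Return the 1-based epoch at which patience runs out, or None."""
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best value: remember it and reset the counter.
            best, waited = loss, 0
        elif loss > best + min_delta:
            # Clearly worse than the best value: count one more bad epoch.
            waited += 1
            if waited >= patience:
                return epoch
    return None

# Hypothetical validation losses: improvement stalls after epoch 3.
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.352, 0.38]
print(first_stop_epoch(losses, patience=3, min_delta=0.005))  # 7
```

Epoch 6 (0.352) falls inside the tolerance band around the best loss of 0.35, so the counter stays at 2 and the stop is reached one epoch later.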


In [10]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        """
        Args:
            patience (int): Number of epochs to wait after the last time validation loss improved.
            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience  # The number of epochs to wait before stopping after no improvement.
        self.min_delta = min_delta  # The minimum change in validation loss to consider as an improvement.
        self.counter = 0  # Counts the number of epochs with no improvement.
        self.min_validation_loss = float('inf')  # Stores the minimum validation loss encountered.
        self.activation_count = 0  # Counts the number of times early stopping has been activated.

    def early_stop(self, validation_loss):
        """
        Checks if early stopping should be triggered.

        Args:
            validation_loss (float): The current epoch's validation loss.

        Returns:
            bool: True if early stopping criteria are met, False otherwise.
        """
        if validation_loss < self.min_validation_loss:
            # Update the minimum validation loss and reset counter if there's an improvement.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            # Increment the counter if no improvement and check against patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.activation_count += 1
                return True  # Trigger early stopping
        return False

# Instantiate EarlyStopper for standalone use. Note that train() below creates its own instance with its own min_delta.
early_stopper = EarlyStopper(patience=10, min_delta=0.001)  # Creates an EarlyStopper with specified patience and minimum delta.


In [20]:
# Initialize paths
log_dir = ResultsPath  # Directory where logs and TensorBoard data will be saved

# Set up basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("swin_transformer_3d")  # Create a logger object for the training script

# Create a TensorBoard writer
writer = SummaryWriter(log_dir=log_dir)  # Initialize TensorBoard writer with the specified log directory

# Enable cuDNN auto-tuner to find the best algorithm for your hardware
torch.backends.cudnn.benchmark = True  # This can improve performance for some models

# Training function using SwinUNETR
def train(train_loader, test_loader, max_epochs=10, learning_rate=1e-3, patience=10, model=None):
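    # The model is supplied by the caller (e.g. the pre-trained swin_model); nothing is built here
    if model is None:
        raise ValueError("train() requires a model instance via the `model` argument")
    logger.info('Initializing model')
    
    # Move the model to the specified device (e.g., GPU)
    model.to(device)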

    # Define the loss function (Mean Squared Error) and optimizer (Adam)
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    early_stopper = EarlyStopper(patience=patience, min_delta=0.003)  # Early stopping with defined patience and delta

    # Initialize GradScaler for mixed precision training
    scaler = GradScaler()

    epoch_loss_values = []  # List to store the loss values for each epoch

    for epoch in range(max_epochs):
        epoch_start_time = time.time()  # Track the start time of the epoch
        model.train()  # Set the model to training mode
        epoch_loss = 0  # Initialize the loss for this epoch
        step = 0  # Initialize the step counter

        # Progress bar for the epoch
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{max_epochs}", unit="batch") as pbar:
            for batch_data in train_loader:
                batch_start_time = time.time()  # Track the start time of the batch
                step += 1

                inputs = batch_data["im"].to(torch.float32).to(device)  # Move input data to the device
                optimizer.zero_grad()  # Reset the gradients

                # Use autocast for mixed precision during the forward pass
                with autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(inputs)  # Forward pass through the model
                    loss = loss_function(outputs, inputs)  # Calculate the loss
                
                # Scale the loss and perform the backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)  # Update model parameters
                scaler.update()  # Update the scaler for the next iteration

                epoch_loss += loss.item()  # Accumulate the loss

                batch_time = time.time() - batch_start_time  # Calculate the time taken for the batch
                pbar.set_postfix({"Batch Time": f"{batch_time:.4f} sec"})  # Update the progress bar with batch time
                pbar.update(1)  # Move the progress bar forward

        epoch_loss /= step  # Calculate the average loss for the epoch
        epoch_loss_values.append(epoch_loss)  # Append the loss value to the list

        logger.info(f"Epoch {epoch + 1}/{max_epochs}, Loss: {epoch_loss:.4f}")  # Log the loss for this epoch

        # Validation loop with progress bar
        total_val_loss = 0.0  # Initialize the total validation loss
        model.eval()  # Set the model to evaluation mode
        with tqdm(total=len(test_loader), desc="Validation", unit="batch") as pbar:
            with torch.no_grad():  # Disable gradient computation for validation
                for batch_data in test_loader:
                    inputs = batch_data["im"].to(torch.float32).to(device)  # Move input data to the device

                    # Use autocast for mixed precision during the forward pass in validation
                    with autocast(device_type='cuda', dtype=torch.float16):
                        outputs = model(inputs)  # Forward pass through the model
                        val_loss = loss_function(outputs, inputs)  # Calculate the validation loss (MSE)
                        total_val_loss += val_loss.item()  # Accumulate the validation loss
                    
                    pbar.update(1)  # Move the progress bar forward

        avg_val_loss = total_val_loss / len(test_loader)  # Calculate the average validation loss
        logger.info(f"Validation Loss: {avg_val_loss:.4f}")  # Log the validation loss

        epoch_time = time.time() - epoch_start_time  # Calculate the time taken for the epoch
        logger.info(f"Epoch {epoch + 1} took {epoch_time:.4f} seconds")  # Log the epoch duration

        # Early stopping check
        if early_stopper.early_stop(avg_val_loss):
            # Save model parameters the first time the early stopping criteria are met.
            # activation_count is used here because counter has already reached `patience` by then.
            if early_stopper.activation_count == 1:
                logger.info(f"\nEarly stopping criteria reached!\nSaving model parameters to {ResultsPath}/optimal_model_weights_epoch_{epoch + 1}.pth\n")
                opt_model_filepath = f"{ResultsPath}/optimal_model_weights_epoch_{epoch + 1}.pth"
                torch.save(model.state_dict(), opt_model_filepath)  # Save the model's state dictionary

        # Logging to TensorBoard
        writer.add_scalar('Loss/train', epoch_loss, epoch)  # Log the training loss for TensorBoard
        writer.add_scalar('Loss/validation', avg_val_loss, epoch)  # Log the validation loss for TensorBoard

    return model, epoch_loss_values  # Return the trained model and the loss values
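
The `early_stopper` object used in the training loop is defined outside this section, so its exact behaviour is not shown here. As a rough illustration only, a patience-based stopper could look like the sketch below. The class name `EarlyStopperSketch`, the `min_delta` parameter, and the loss values are assumptions made for this example, and unlike the loop above the sketch leaves the loop as soon as the criterion is met.

```python
class EarlyStopperSketch:
    """Hypothetical patience-based early stopper (illustration only)."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # Epochs to wait without improvement
        self.min_delta = min_delta    # Minimum decrease that counts as improvement
        self.best_loss = float("inf")
        self.counter = 0              # Consecutive epochs without improvement

    def early_stop(self, val_loss):
        """Return True once the loss has not improved for `patience` epochs."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Made-up validation losses, used only to exercise the sketch
stopper = EarlyStopperSketch(patience=2)
val_history = [0.50, 0.40, 0.41, 0.42, 0.39]
stop_epoch = None
for epoch, val_loss in enumerate(val_history, start=1):
    if stopper.early_stop(val_loss):
        stop_epoch = epoch
        break

print(f"Stopped at epoch {stop_epoch}, best validation loss {stopper.best_loss:.2f}")
```

With a patience of 2, the two epochs after the best loss of 0.40 exhaust the budget, so the final value in the list is never reached.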


In [None]:
import matplotlib.pyplot as plt

def plot_loss_curve(epoch_losses, early_stop_epoch=None, save_path=None):
    """
    Plot the training and validation loss curves and mark the early stopping point.

    Args:
        epoch_losses (list of tuples): List of tuples where each tuple contains (train_loss, val_loss) for each epoch.
        early_stop_epoch (int, optional): Epoch where early stopping occurred. Defaults to None.
        save_path (str, optional): Path to save the plot image. Defaults to None.
    """
    epochs = range(1, len(epoch_losses) + 1)
    train_losses = [loss[0] for loss in epoch_losses]  # Extract training losses
    val_losses = [loss[1] for loss in epoch_losses]  # Extract validation losses

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, 'b', label='Training Loss', marker='o')
    plt.plot(epochs, val_losses, 'g', label='Validation Loss', marker='o')

    if early_stop_epoch is not None:
        plt.axvline(x=early_stop_epoch, color='r', linestyle='--', label='Early Stopping')

    plt.title('Training and Validation Loss Curves')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    if save_path:
        plt.savefig(save_path)

    plt.show()
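
The `early_stop_epoch` argument has to be supplied by the caller. One possible way to derive a marker from the same list of `(train_loss, val_loss)` tuples is to pick the epoch with the lowest validation loss, as sketched below. The helper name `find_best_epoch` and the sample history are hypothetical, and the epoch with the lowest validation loss is not necessarily the epoch at which the early stopper triggered.

```python
def find_best_epoch(epoch_losses):
    """Return the 1-based epoch index with the lowest validation loss.

    epoch_losses is a list of (train_loss, val_loss) tuples, one per epoch.
    """
    if not epoch_losses:
        raise ValueError("epoch_losses must not be empty")
    val_losses = [val for _, val in epoch_losses]
    return val_losses.index(min(val_losses)) + 1  # +1 because epochs are plotted from 1


# Made-up loss history, used only to exercise the helper
history = [(0.90, 0.95), (0.60, 0.70), (0.45, 0.72), (0.40, 0.75)]
best_epoch = find_best_epoch(history)
print(f"Lowest validation loss at epoch {best_epoch}")
```

The result could then be passed as `early_stop_epoch=best_epoch` when calling `plot_loss_curve`.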



In [None]:
# Set hyperparameters
max_epochs = 85  # Maximum number of training epochs
learning_rate = 1e-2  # Learning rate for the optimizer
patience = 10  # Number of epochs to wait for improvement before early stopping
models = []  # List to store trained models
epoch_losses = []  # List to store loss values for each epoch

# Run the training function
model, epoch_loss = train(train_loader, test_loader, max_epochs=max_epochs, learning_rate=learning_rate, patience=patience, model=swin_model)

# Store the trained model and loss values
models.append(model)  # Append the trained model to the models list
epoch_losses.append(epoch_loss)  # Append the epoch loss values to the epoch_losses list

# Convert the epoch loss list to a DataFrame
epoch_loss_df = pd.DataFrame(epoch_loss, columns=['Train Loss', 'Validation Loss'])  # Create a DataFrame for the losses

# Define the path to save the CSV file
csv_path = f"{ResultsPath}/epoch_losses.csv"  # Path to save the loss values as a CSV

# Save the DataFrame to a CSV file
epoch_loss_df.to_csv(csv_path, index=False)  # Save the epoch losses to a CSV file
logger.info(f"Epoch losses saved to {csv_path}")  # Log the save operation

# Save the model and state dict after training
model_path = f"{ResultsPath}/swin_unetr_model.pth"  # Path to save the full model
state_dict_path = f"{ResultsPath}/swin_unetr_state_dict.pth"  # Path to save the model's state dictionary

# Save the full model
torch.save(model, model_path)  # Save the complete model

# Save only the state dictionary (parameters) of the model
torch.save({"state_dict": model.state_dict()}, state_dict_path)
logger.info(f"State dict saved to {state_dict_path}")  # Log the save operation


In [None]:

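# Pass the per-epoch history returned by train(), not `epoch_losses`, which wraps one history per run
plot_loss_curve(epoch_loss, early_stop_epoch=None, save_path=f"{ResultsPath}/loss_curve_Val_train.png")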

## Testing Procedure: Evaluating the Model on the Test Set


In [None]:

def visualize_slice(input_image, output_image, slice_idx, title_prefix=""):
    """
    Visualize a slice from the input and output images.

    Args:
        input_image (numpy array): The input image as a numpy array.
        output_image (numpy array): The output image as a numpy array.
        slice_idx (int): The index of the slice to visualize.
        title_prefix (str): Prefix for the title of the plots.
    """
    plt.figure(figsize=(12, 6))  # Set up the figure size for the plot

    # Plot the input slice
    plt.subplot(1, 2, 1)  # Create a subplot for the input image
    plt.imshow(input_image[slice_idx], cmap='gray')  # Display the input slice with grayscale colormap
    plt.title(f'{title_prefix} Input Slice {slice_idx}')  # Set the title of the plot
    plt.axis('off')  # Turn off the axis for better visualization

    # Plot the output slice
    plt.subplot(1, 2, 2)  # Create a subplot for the output image
    plt.imshow(output_image[slice_idx], cmap='gray')  # Display the output slice with grayscale colormap
    plt.title(f'{title_prefix} Output Slice {slice_idx}')  # Set the title of the plot
    plt.axis('off')  # Turn off the axis for better visualization

    plt.show()  # Display the plot


In [None]:

def normalize_tensor(tensor):
    """Normalize the tensor to the [0, 1] range."""
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    return (tensor - tensor_min) / (tensor_max - tensor_min + 1e-7)

def clamp_tensor(tensor, min_val=0.0, max_val=1.0):
    """Clamp the tensor to the range [min_val, max_val]."""
    return torch.clamp(tensor, min=min_val, max=max_val)

def normalize_psnr(psnr_value, max_psnr=100.0):
    """Normalize the PSNR value to [0, 1] based on the max_psnr."""
    return min(psnr_value, max_psnr) / max_psnr

def run_test(model, test_loader, loss_function, max_psnr=100.0):
    """Run a test on the model using a test data loader, and calculate the loss, SSIM, and PSNR metrics."""
    model.eval()  # Set the model to evaluation mode
    total_test_loss = 0.0  # Initialize the total test loss
    individual_losses = []  # List to store individual loss values
    ssim_scores = []  # List to store SSIM scores
    psnr_scores = []  # List to store PSNR scores

    # Initialize SSIM and PSNR metrics for grayscale images (1 channel)
    ssim_metric = SSIM(n_channels=1).to(device)
    psnr_metric = PSNR().to(device)

    with torch.no_grad():  # Disable gradient computation for testing
        with tqdm(total=len(test_loader), desc="Testing on last 20 images", unit="batch") as pbar:
            for batch_idx, batch_data in enumerate(test_loader):
                inputs = batch_data["im"].to(torch.float32).to(device)  # Move inputs to the device

                with torch.autocast(device_type='cuda', dtype=torch.float16):  # Use autocast for mixed precision (same call form as in validation)
                    outputs = model(inputs)  # Forward pass through the model
                    loss = loss_function(outputs, inputs)  # Calculate the loss
                    total_test_loss += loss.item()  # Accumulate the loss

                for i in range(inputs.shape[0]):
                    # Cast to float32: outputs produced under autocast may be float16
                    input_i = inputs[i].to(torch.float32)
                    output_i = outputs[i].detach().to(torch.float32)

                    # Clamp tensors to the [0, 1] range for the PSNR calculation
                    inputs_psnr = clamp_tensor(input_i)
                    outputs_psnr = clamp_tensor(output_i)

                    # Min-max normalize input and output images to [0, 1] for the SSIM calculation
                    input_image = normalize_tensor(input_i)
                    output_image = normalize_tensor(output_i)

                    # Calculate SSIM and PSNR using piqa for grayscale images
                    ssim_score = ssim_metric(input_image.unsqueeze(0), output_image.unsqueeze(0)).item()
                    psnr_value = psnr_metric(inputs_psnr.unsqueeze(0), outputs_psnr.unsqueeze(0)).item()

                    ssim_scores.append(ssim_score)  # Store the SSIM score
                    psnr_scores.append(psnr_value)  # Store the PSNR score

                    # Store the loss for this image alone rather than repeating the batch-level loss
                    individual_losses.append(loss_function(output_i, input_i).item())

                    # Visualize a slice of the input and output images
                    slice_idx = input_image.shape[1] // 2
                    visualize_slice(input_image.cpu().numpy()[0], output_image.cpu().numpy()[0], slice_idx, title_prefix=f"Batch {batch_idx + 1}, Image {i + 1}")

                pbar.update(1)  # Update the progress bar

    avg_test_loss = total_test_loss / len(test_loader)  # Calculate the average test loss
    print(f"Average Test Loss on the last 20 images: {avg_test_loss:.4f}")

    # Create a DataFrame to store the detailed results
    results_df = pd.DataFrame({
        'Image Index': range(1, len(individual_losses) + 1),
        'Loss': individual_losses,
        'SSIM Score': ssim_scores,
        'PSNR Score': psnr_scores
    })

    print("\nDetailed Results for Last 20 Images:")
    print(results_df)  # Print the detailed results

    # Save the DataFrame to a CSV file
    csv_path = os.path.join(ResultsPath, "test_results.csv")
    results_df.to_csv(csv_path, index=False)  # Save the results as a CSV file
    print(f"Results saved to {csv_path}")

    return avg_test_loss, results_df  # Return the average test loss and the results DataFrame
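
To make the pre-processing and the PSNR figure easier to interpret, the sketch below works through min-max normalisation and PSNR on a tiny list of values in plain Python. It uses the textbook definition, PSNR = 10 · log10(value_range² / MSE), and is not a reproduction of piqa's implementation. The helper names `min_max_normalize` and `psnr` and the sample values are assumptions for this example, while `normalize_psnr` is copied from the cell above.

```python
import math


def min_max_normalize(values, eps=1e-7):
    """Rescale values to roughly [0, 1]; eps guards against division by zero."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo + eps) for v in values]


def psnr(reference, estimate, value_range=1.0):
    """Textbook PSNR in dB: 10 * log10(value_range**2 / MSE)."""
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return float("inf")  # Identical signals
    return 10.0 * math.log10(value_range ** 2 / mse)


def normalize_psnr(psnr_value, max_psnr=100.0):
    """Normalize the PSNR value to [0, 1] based on the max_psnr."""
    return min(psnr_value, max_psnr) / max_psnr


# Made-up intensities already in [0, 1], used only to exercise the helpers
reference = [0.0, 0.5, 1.0, 0.5]
estimate = [0.1, 0.5, 0.9, 0.5]

score = psnr(reference, estimate)  # MSE = 0.005, so about 23.01 dB
print(f"PSNR: {score:.2f} dB, normalised: {normalize_psnr(score):.4f}")
print(min_max_normalize([2.0, 4.0, 6.0]))
```

Higher PSNR means a closer reconstruction. Capping at `max_psnr` keeps the normalised score finite even when the two signals are identical.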


In [None]:
def plot_test_loss_curve(epoch_losses, early_stop_epoch=None, save_path=None):
    """
    Plot the training and validation loss curves over the epochs.

    Args:
        epoch_losses (list of tuples): A list where each tuple contains (train_loss, val_loss) for each epoch.
        early_stop_epoch (int, optional): The epoch number where early stopping occurred. Defaults to None.
        save_path (str, optional): Path to save the plot image. Defaults to None.
    """
    # Generate a range of epoch numbers
    epochs = range(1, len(epoch_losses) + 1)

    # Separate the training and validation losses from the epoch_losses tuples
    train_losses = [loss[0] for loss in epoch_losses]
    val_losses = [loss[1] for loss in epoch_losses]

    # Set up the plot with a specified figure size
    plt.figure(figsize=(10, 6))
    
    # Plot the training and validation loss curves
    plt.plot(epochs, train_losses, label='Training Loss', marker='o')
    plt.plot(epochs, val_losses, label='Validation Loss', marker='o')

    # If early stopping occurred, mark the epoch on the plot
    if early_stop_epoch is not None:
        plt.axvline(x=early_stop_epoch, color='r', linestyle='--', label=f'Early Stopping at Epoch {early_stop_epoch}')

    # Label the axes and add a title
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss per Epoch')
    
    # Add a legend and grid for clarity
    plt.legend()
    plt.grid(True)

    # If a save path is provided, save the plot as an image
    if save_path:
        plt.savefig(save_path)

    # Display the plot
    plt.show()


In [None]:
# Initialize your test loader with the last 20 images from the validation set
test_subset_df = valid_patient_dirs[-20:]  # Select the last 20 images for testing
test_dataset = HandScanDataset2(data_dir=test_subset_df, transform=test_transforms, device=torch.device("cuda"))  # Create a dataset with the selected images
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)  # Initialize the DataLoader for testing

# Load the state dictionary from the saved file
state_dict_path = r'C:\Users\scmmw\OneDrive - University of Leeds\Masters - 23-24\Project\results\swin_unetr_state_dict.pth'  # Path to the saved state dictionary
state_dict = torch.load(state_dict_path, map_location=device)  # Load the state dictionary, mapping tensors onto the active device

swin_model.load_state_dict(state_dict['state_dict'])  # Load the state dictionary into the SwinUNETR model

# Run the test and get the loss and similarity results
avg_test_loss, results_df = run_test(swin_model, test_loader, nn.MSELoss())  # Run the test and obtain the average loss and results DataFrame

# Define the path for the results CSV. `os` is already imported at the top of the notebook,
# and run_test also writes this file, so the save below overwrites it with the same data.
csv_path = os.path.join(ResultsPath, "test_results.csv")  # Set the path for saving the test results CSV

# Save the DataFrame to a CSV file
results_df.to_csv(csv_path, index=False)  # Save the test results to the CSV file without the index column

print(f"Results saved to {csv_path}")  # Print a message indicating where the results were saved
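
Once the per-image metrics are saved, a short summary is often easier to read than the full table. The sketch below shows one way to compute the mean, minimum and maximum of each metric with the standard library. The function name `summarize_metrics` is hypothetical, and the rows are made-up placeholder values, not results from this model.

```python
import statistics


def summarize_metrics(rows):
    """Summarise per-image metrics.

    rows is a list of dicts keyed by 'Loss', 'SSIM Score' and 'PSNR Score',
    mirroring the columns of the results table.
    """
    summary = {}
    for key in ("Loss", "SSIM Score", "PSNR Score"):
        values = [row[key] for row in rows]
        summary[key] = {
            "mean": statistics.mean(values),
            "min": min(values),
            "max": max(values),
        }
    return summary


# Placeholder values for illustration only
placeholder_rows = [
    {"Loss": 0.02, "SSIM Score": 0.90, "PSNR Score": 30.0},
    {"Loss": 0.04, "SSIM Score": 0.80, "PSNR Score": 26.0},
]

summary = summarize_metrics(placeholder_rows)
for metric, stats in summary.items():
    print(f"{metric}: mean={stats['mean']:.3f}, min={stats['min']:.3f}, max={stats['max']:.3f}")
```

With the real table, the rows could be obtained from `results_df.to_dict("records")`.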
