In [None]:
"""
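Notebook for deep-learning deformable registration of 3D cardiac images using MONAI and PyTorch.

It trains and evaluates the FBA-SCA DLIR model on pairs of end-diastolic (ED) and end-systolic (ES) frames. The model predicts a displacement field that warps the moving image onto the fixed image.
Registration quality is reported as image MSE and as Dice scores of the warped label maps (background, MYO, LV/ENDO, EPI). MONAI augmentation transforms are imported but are not currently applied: the datasets are built with `transform=None`.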

Authors:
- Md Kamrul Hasan

Date:
- 19-Nov-2024

Dependencies:
- PyTorch
- MONAI
- Additional libraries as listed in the import section

Note:
Ensure all required dependencies are installed before running the script.
"""

# Import essential libraries and modules
from monai.utils import set_determinism, first  # Utilities for reproducibility and data operations
from monai.transforms import (  # Preprocessing and augmentation pipelines
    EnsureChannelFirstD,
    Compose,
    LoadImageD,
    RandRotateD,
    RandZoomD,
    ScaleIntensityRanged,
)

import monai  # MONAI library for medical image analysis
from monai.data import DataLoader, Dataset, CacheDataset  # Data loading and caching utilities
from monai.config import print_config, USE_COMPILED  # MONAI configuration utilities
from monai.networks.blocks import Warp  # Spatial transformation module for medical image registration
from monai.apps import MedNISTDataset  # Prebuilt medical datasets for quick prototyping

# PyTorch imports for model creation and evaluation
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.autograd import Variable  # Automatic differentiation for operations on tensors
import torch.nn.functional as F  # Commonly used activation functions and loss utilities
import torch.optim as optim  # Optimizers for model training
from torch.utils.data import DataLoader  # Data loading utility for PyTorch

# Libraries for visualization and metric calculation
import torchmetrics  # Metrics for evaluating model performance
import matplotlib.pyplot as plt  # Visualization of results
from torchviz import make_dot, make_dot_from_trace  # Visualize computation graphs
from piqa import SSIM  # Structural Similarity Index Metric for image quality assessment
import visdom  # Visualization library for tracking training progress

# Additional utilities and libraries for data processing
from glob import glob  # File searching and pattern matching
import cv2  # Computer vision operations (e.g., image manipulation)
from scipy.spatial.distance import directed_hausdorff  # Calculate Hausdorff distance
import pandas as pd  # Data analysis and manipulation
import numpy as np  # Numerical operations
import tempfile  # Temporary file creation
import nibabel as nib  # Neuroimaging data I/O
import os  # Operating system interface

from modules.layers import *       # Import custom layers
from utils.helper import *         # Import helper functions
from utils.losses import *         # Import custom loss functions
from modules.FBA_SCA_DLIR import * # Import the FBA-SCA DLIR model implementation

# MONAI-specific loss functions and metrics
from monai.losses import *  # Predefined loss functions for medical image analysis
from monai.metrics import *  # Evaluation metrics for medical imaging tasks

# Custom configuration file
import config  # Configuration file (ensure 'config.py' exists in the same directory)

# Report the installed MONAI / PyTorch versions and optional packages, then pin
# the global random seed so repeated runs of this notebook are reproducible.
print_config()
set_determinism(seed=42)


In [None]:
"""
GPU Availability Check and Device Setup

This cell reports how many GPUs are visible and selects `cuda:0` when CUDA is available.
Otherwise it stops with an error, because training on the CPU would be too slow.
"""

# Number of CUDA devices visible to this process.
print('How many GPUs = ' + str(torch.cuda.device_count()))

# Prefer the first GPU. The CPU branch only exists so the message below can be printed.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Selected device: {device}")

# Fail fast: the rest of the notebook assumes CUDA. RuntimeError is a subclass of
# Exception, so any existing `except Exception` handler still catches it.
if device.type != 'cuda':
    raise RuntimeError("GPU not available. CPU training will be too slow.")

# Human-readable name of the GPU that was actually selected above.
print("Device name:", torch.cuda.get_device_name(device))


In [None]:
# Root of the "Adult_ED_ES" dataset (end-diastolic / end-systolic frames).
# The loader cell below expects train/ and val/ sub-folders, each holding image/ and mask/.
data_dir = 'Datasets/Adult_ED_ES/'
print(data_dir)  # echo the path so a wrong working directory is obvious early

# Experiment identifier, reused as the stem of the checkpoint file name.
saveFile = 'FBA_SCA_DLIR'

# PyTorch weights file: written by the training cell, read back by the evaluation cell.
checkpoint_path = f'{saveFile}.pth'


In [None]:
from utils.data_loader import NiftiDataset  # custom dataset pairing NIfTI images with masks


def _nifti_paths(split, kind):
    """Return the sorted ``*.nii`` paths under ``<data_dir>/<split>/<kind>/``.

    Sorting is what lines up the i-th image with the i-th mask. This assumes image
    and mask files share a naming scheme (TODO confirm for this dataset).
    """
    return sorted(glob(os.path.join(data_dir, split, kind, "*.nii")))


def _make_loader(split, batch_size, shuffle):
    """Print the sample count of ``split`` and build a DataLoader over its image/mask pairs.

    Args:
        split: sub-folder of ``data_dir`` ("train" or "val").
        batch_size: batch size for the DataLoader.
        shuffle: whether the DataLoader shuffles the samples.

    Raises:
        RuntimeError: if the split contains no images, or if the image and mask counts
            differ. The dataset receives the two lists positionally, so a mismatch would
            silently misalign the pairs.
    """
    image_paths = _nifti_paths(split, "image")
    mask_paths = _nifti_paths(split, "mask")
    print(f'Total {split} Samples={len(image_paths)}')

    if not image_paths:
        raise RuntimeError(f"No '{split}' images found under {data_dir} (check data_dir).")
    if len(image_paths) != len(mask_paths):
        raise RuntimeError(
            f"'{split}' split has {len(image_paths)} images but {len(mask_paths)} masks."
        )

    return DataLoader(
        NiftiDataset(image_paths, mask_paths, transform=None),  # no augmentation applied
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
    )


trainData = _make_loader("train", config.trainBatch, config.shuffle_)
valData = _make_loader("val", config.valBatch, config.shuffle_val)

# Sanity check: print the tensor shapes of the first training batch.
train_sample = first(trainData)
for key in ('fixed_img', 'fixed_mask', 'moving_img', 'moving_mask'):
    print(train_sample[key].shape)


In [None]:
# ---- Evaluation metric ----------------------------------------------------------
# Multi-class 3D Dice. The number of classes comes from config.
dice_multiclass_metric = Dice3DMultiClass(num_classes=config.num_classes)

# ---- Loss terms (kept in this construction order) -------------------------------
transport_loss = OptimalTransportLoss()    # optimal-transport loss on the attention weights
twin_loss = twinLoss(0.005)                # twin feature-similarity loss, regularization factor 0.005
NCC_loss = NCCLoss()                       # normalized cross-correlation; used via NCC_loss.loss(...)
mse_loss = torch.nn.MSELoss()              # stock PyTorch mean squared error
regularization_loss = BendingEnergyLoss()  # smoothness penalty on the displacement field

# ---- Registration network and warping layer -------------------------------------
model = FBA_SCA_DLIR3D(device)
model = model.to(device)
# To start from saved weights instead of from scratch, uncomment:
#   model.load_state_dict(torch.load(checkpoint_path, map_location=device))

spatial_transform = SpatialTransformer()   # warps a volume with a displacement field
spatial_transform = spatial_transform.to(device)

# Model size as reported by the estParams helper (trainable-parameter count).
print(f'Total Params = {estParams(model)}')

In [None]:
# Train and validate the registration model, optionally with two-GPU data parallelism.


def _batch_to_device(batch):
    """Move the four registration tensors of a loader batch to ``device``.

    Returns:
        (fixed_img, fixed_mask, moving_img, moving_mask)
    """
    return (
        batch['fixed_img'].to(device),
        batch['fixed_mask'].to(device),
        batch['moving_img'].to(device),
        batch['moving_mask'].to(device),
    )


def _registration_losses(fixed_img, fixed_msk, moving_img, moving_msk):
    """Run the model on one batch and compute the weighted objective.

    Shared by the training and validation loops so both compute the loss identically.
    Loss terms use local names, so the imported ``twinLoss`` class is never shadowed
    and the initialization cell stays re-runnable.

    Returns:
        total_loss: weighted sum of all loss terms (attached to the autograd graph
            unless called under ``torch.no_grad()``).
        label_dsc: per-class Dice of the warped moving labels vs. the fixed labels.
            Index 1 is reported as MYO and index 2 as LV.
        twin_value: twin (feature-similarity) loss term.
        transport_value: optimal-transport loss term on the attention weights.
    """
    outputs = model(fixed_img, moving_img)
    pred_disp_ES_ED = outputs[0]                              # predicted displacement field
    fixed_with_att, moving_with_att = outputs[1], outputs[2]  # attention-weighted images
    fixed_att_mask, moving_att_mask = outputs[3], outputs[4]
    fixed_feature, moving_feature = outputs[5], outputs[6]    # projection1 / projection2
    att_exemplar_to_query, att_query_to_exemplar = outputs[7], outputs[8]

    twin_value = twin_loss(moving_feature, fixed_feature)
    transport_value = transport_loss(att_exemplar_to_query, att_query_to_exemplar, device)

    # Warp the moving inputs with the predicted field and compare them to the fixed ones.
    pred_fixed = spatial_transform(moving_img, pred_disp_ES_ED).to(device)
    pred_fixed_att = spatial_transform(moving_with_att, pred_disp_ES_ED).to(device)
    pred_moving_att_mask = spatial_transform(moving_att_mask, pred_disp_ES_ED)

    loss_raw_img = mse_loss(pred_fixed, fixed_img)
    loss_att_img = mse_loss(pred_fixed_att, fixed_with_att)
    loss_att_mask = NCC_loss.loss(pred_moving_att_mask, fixed_att_mask).to(device)

    # Dice between the warped (thresholded) moving label map and the fixed label map.
    pred_msk = spatial_transform(moving_msk, pred_disp_ES_ED).to(device)
    label_dsc = dice_multiclass_metric(
        make_one_hot(thresholded(pred_msk, config.lower_bound, config.upper_bound), device, C=3),
        make_one_hot(fixed_msk, device, C=3),
    )

    # Same term order as before, so the floating-point sum is unchanged.
    total_loss = (
        config.w1 * loss_raw_img
        + config.w2 * regularization_loss(pred_disp_ES_ED)  # smoothness of the field
        + config.w5 * loss_att_img
        + config.w6 * loss_att_mask
        + config.w3 * twin_value
        + config.w4 * transport_value
    )
    return total_loss, label_dsc, twin_value, transport_value


# Two-GPU wrapping. The isinstance guard keeps the cell re-runnable (no double wrapping).
# output_device takes a single device; outputs are gathered on GPU 0.
if config.double_GPU and not isinstance(model, nn.DataParallel):
    if torch.cuda.device_count() < 2:
        raise RuntimeError("config.double_GPU is set but fewer than 2 GPUs are visible.")
    model = nn.DataParallel(model, device_ids=[0, 1], output_device=0)
    spatial_transform = nn.DataParallel(spatial_transform, device_ids=[0, 1], output_device=0)

optimizer = torch.optim.Adam(model.parameters(), config.LR)

# Best validation DSC seen so far and the epoch it occurred in. Both are defined up
# front so the summary print below can never hit a NameError.
best_dsc = 0
best_loss_epoch = 0

for epoch in range(config.Epochs):
    print("-" * 100)
    print(f"Epoch {epoch + 1}/{config.Epochs}")
    model.train()

    # ---- Training ---------------------------------------------------------------
    step = 0
    epoch_img_loss = 0
    epoch_DSC_MYO = 0
    epoch_DSC_LV = 0
    twinLoss_ = 0
    OtransportLossT_ = 0

    for trainBatch_data in tqdm(trainData):
        step += 1
        optimizer.zero_grad()

        total_loss, labelDSC, twin_value, transport_value = _registration_losses(
            *_batch_to_device(trainBatch_data)
        )
        total_loss.backward()
        optimizer.step()

        epoch_img_loss += total_loss.item()
        epoch_DSC_MYO += labelDSC[1].item()
        epoch_DSC_LV += labelDSC[2].item()
        twinLoss_ += twin_value.item()
        OtransportLossT_ += transport_value.item()

    epoch_img_loss /= step
    epoch_DSC_MYO /= step
    epoch_DSC_LV /= step
    twinLoss_ /= step
    OtransportLossT_ /= step

    print(f"Epoch {epoch + 1}, Avg Train Img Loss: {epoch_img_loss:.5f}")
    print(f"Epoch {epoch + 1}, Avg Train DSC MYO: {epoch_DSC_MYO:.5f}")
    print(f"Epoch {epoch + 1}, Avg Train DSC LV: {epoch_DSC_LV:.5f}")
    print(f"Epoch {epoch + 1}, Avg Train Twin Loss: {twinLoss_:.5f}")
    print(f"Epoch {epoch + 1}, Avg Train OT Loss: {OtransportLossT_:.5f}")
    print("-" * 60)

    # ---- Validation (first epoch and every `val_interval` epochs) ----------------
    if (epoch + 1) % config.val_interval == 0 or epoch == 0:
        model.eval()
        step = 0
        epoch_img_loss = 0
        epoch_DSC_MYO = 0
        epoch_DSC_LV = 0

        with torch.no_grad():
            for testBatch_data in valData:
                step += 1
                total_loss, labelDSC_test = _registration_losses(
                    *_batch_to_device(testBatch_data)
                )[:2]
                epoch_img_loss += total_loss.item()
                epoch_DSC_MYO += labelDSC_test[1].item()
                epoch_DSC_LV += labelDSC_test[2].item()

        epoch_img_loss /= step
        epoch_DSC_MYO /= step
        epoch_DSC_LV /= step
        epoch_DSC = (epoch_DSC_MYO + epoch_DSC_LV) / 2

        print(f"Epoch {epoch + 1}, Avg Val Img Loss: {epoch_img_loss:.5f}")
        print(f"Epoch {epoch + 1}, Avg Val DSC MYO: {epoch_DSC_MYO:.5f}")
        print(f"Epoch {epoch + 1}, Avg Val DSC LV: {epoch_DSC_LV:.5f}")

        if epoch_DSC > best_dsc:
            best_loss_epoch = epoch + 1
            # Save the unwrapped module's weights. Keys then carry no 'module.' prefix,
            # and the checkpoint loads into a plain FBA_SCA_DLIR3D (as the eval cell does).
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(model_to_save.state_dict(), checkpoint_path)
            print(f"Validation DSC improved from {best_dsc:.5f} to {epoch_DSC:.5f}! Model saved at {checkpoint_path}.")
            best_dsc = epoch_DSC

        print(f"Best model epoch: {best_loss_epoch}, Best DSC: {best_dsc:.4f}")


In [None]:
# Evaluate the saved checkpoint on the validation set.
# A fresh, unwrapped model and warping layer are built, so this cell does not depend
# on the training cell's (possibly DataParallel-wrapped) objects.
model = FBA_SCA_DLIR3D(device).to(device)
spatial_transform = SpatialTransformer().to(device)

# NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Per-batch metric values, one entry per validation batch.
DSC_back, DSC_MYO, DSC_LV, DSC_EPI = [], [], [], []
imageMatrix = []  # per-batch MSE between the warped moving image and the fixed image

with torch.no_grad():
    for testBatch_data in valData:
        fixed_test_img = testBatch_data['fixed_img'].to(device)
        fixed_test_msk = testBatch_data['fixed_mask'].to(device)
        moving_test_img = testBatch_data['moving_img'].to(device)
        moving_test_msk = testBatch_data['moving_mask'].to(device)

        # Only the displacement field (output 0) is needed for evaluation.
        pred_disp = model(fixed_test_img, moving_test_img)[0]

        # Image similarity after warping the moving image onto the fixed one.
        pred_fixed = spatial_transform(moving_test_img, pred_disp).to(device)
        imageMatrix.append(mse_loss(pred_fixed, fixed_test_img).item())

        # Warp the moving labels and snap the interpolated values back to label values.
        pred_mask_test = spatial_transform(moving_test_msk, pred_disp).to(device)
        pred_mask_test = thresholded(pred_mask_test, config.lower_bound, config.upper_bound)

        labelDSC_test = dice_multiclass_metric(
            make_one_hot(pred_mask_test, device, C=3),
            make_one_hot(fixed_test_msk, device, C=3),
        )
        DSC_back.append(labelDSC_test[0].item())
        DSC_MYO.append(labelDSC_test[1].item())
        DSC_LV.append(labelDSC_test[2].item())

        # Epicardium = MYO ∪ LV: relabel 2 -> 1 on copies (the inputs stay untouched)
        # and read the Dice of label 1.
        fixed_epi_msk = fixed_test_msk.masked_fill(fixed_test_msk == 2, 1)
        pred_epi_msk = pred_mask_test.masked_fill(pred_mask_test == 2, 1)
        labelDSC_epi = dice_multiclass_metric(
            make_one_hot(pred_epi_msk, device, C=3),
            make_one_hot(fixed_epi_msk, device, C=3),
        )
        DSC_EPI.append(labelDSC_epi[1].item())

# Keep every per-batch value. Deduplicating (e.g. with np.unique) would drop repeated
# scores, such as several perfect 1.0 Dice values, and bias the mean and std.
DSC_back = np.asarray(DSC_back)
DSC_MYO = np.asarray(DSC_MYO)
DSC_LV = np.asarray(DSC_LV)
DSC_EPI = np.asarray(DSC_EPI)
imageMatrix = np.asarray(imageMatrix)


def _report(label, values):
    """Print the mean and (population) standard deviation of per-batch metric values."""
    print(f"Average test {label}: {np.mean(values):.5f} +/- {np.std(values):.5f}!")


_report("DSC Back", DSC_back)
_report("DSC MYO", DSC_MYO)
_report("DSC ENDO", DSC_LV)
_report("DSC EPI", DSC_EPI)
_report("MSE", imageMatrix)
