In [None]:
# MONAI (Medical Open Network for AI) libraries for medical imaging processing and transformations.
from monai.utils import set_determinism  # Utility to ensure reproducible results.
from monai.transforms import (
    EnsureChannelFirstD,  # Ensures the input image has channels as the first dimension.
    Compose,              # Allows combining multiple transformations.
    LoadImageD,           # Loads images from disk and wraps them in a dictionary.
    RandRotateD,          # Randomly rotates the image within a specified range.
    RandZoomD,            # Randomly zooms into the image within a specified range.
    ScaleIntensityRanged  # Scales image intensity to a specified range.
)
import monai
from monai.utils import set_determinism, first
from monai.networks.layers import Conv, Norm, Pool, same_padding
import torchinfo  # Library for summarizing the PyTorch model architecture.
from torchviz import make_dot  # Visualizes the computation graph of a PyTorch model.
from monai.data import DataLoader, Dataset, CacheDataset  # Utilities for handling datasets and data loading.
from monai.config import print_config  # Prints the MONAI configuration and environment info.
from monai.networks.blocks import Warp  # Warp block for applying a displacement field to images.
from monai.apps import MedNISTDataset  # Utility for working with the MedNIST dataset.
import torch.nn.functional as F  # Functional interface in PyTorch, includes many useful operations like activations.
from tqdm import tqdm  # Progress bar library for iterating over large loops.


# Core PyTorch imports
import torch  # PyTorch core library for building and training neural networks.
from torch import nn  # PyTorch module containing neural network components.
from collections.abc import Sequence  # Collection utilities for handling sequences.
from monai.networks.blocks import (
    Warp,                     # Warp block for applying a displacement field to images.
    Convolution               # Generic convolution block used in many MONAI network architectures.
)
from monai.networks.blocks.regunet_block import (
    RegistrationDownSampleBlock,  # Block for downsampling in a registration network.
    get_conv_block,               # Utility function to get a convolution block.
    get_deconv_block              # Utility function to get a deconvolution block.
)
from monai.networks.utils import meshgrid_ij  # Utility to generate a meshgrid for image coordinates.

# General-purpose imports for working with files, images, and metrics
import os  # Operating system interface for file handling and paths.
import cv2  # OpenCV library for image processing.
import torchmetrics  # Metrics library for evaluating PyTorch models.
from torch.autograd import Variable  # Enables automatic differentiation for tensor operations.
from scipy.spatial.distance import directed_hausdorff  # Computes the directed Hausdorff distance between point clouds.
import pandas as pd  # Data manipulation library, useful for handling tabular data.
import numpy as np  # Numerical operations on large, multi-dimensional arrays and matrices.
import matplotlib.pyplot as plt  # Plotting library for visualizing data.
import tempfile  # For creating temporary files and directories.
from glob import glob  # Unix-style pathname pattern expansion.
from monai.losses import *  # Import all loss functions provided by MONAI.
from monai.metrics import *  # Import all metrics provided by MONAI.
from piqa import SSIM  # Structural Similarity Index (SSIM) metric from PIQA.

# Log MONAI / PyTorch / dependency versions so the environment of this run is on record.
print_config()

# Fix the global random seed (42) through MONAI's helper so runs are repeatable.
set_determinism(seed=42)


In [None]:
# ---- Dataset location ---------------------------------------------------------
dataset_name = 'CAMUS_EStoED_A2C'            # Folder name under data/.
dataset_root_dir = f'data/{dataset_name}/'   # Must contain the train/ and val/ sub-folders used below.
print(f'Root directory: {dataset_root_dir}')

# ---- Data loading ---------------------------------------------------------------
training_batch_size = 6
testing_batch_size = 6
image_size = 512           # Images and masks are resized to image_size x image_size.
data_loader_workers = 0    # DataLoader num_workers (0 = load in the main process).

# ---- Training -------------------------------------------------------------------
num_epochs = 500
use_pretrained_model = 0          # Truthy -> warm-start from a saved checkpoint.
previous_model_weight_size = 256  # Size tag in the warm-start checkpoint's file name (presumably its training image size).

# ---- Experiment bookkeeping -----------------------------------------------------
experiment_name = "DdC-AC-DLIR_A2C"
# Common prefix for saved results / checkpoints of this run.
checkpoint_file_name = f'{experiment_name}_{dataset_name}_{image_size}_'


In [None]:
# Number of CUDA devices visible to this process.
num_gpus = torch.cuda.device_count()
print(f'Number of GPUs available: {num_gpus}')

# GPU this experiment prefers. If the machine exposes fewer devices, fall back
# to GPU 0 instead of failing later with an "invalid device ordinal" error
# when the model is moved with `.to(device)`.
preferred_gpu_index = 3

if torch.cuda.is_available():
    if preferred_gpu_index < num_gpus:
        gpu_index = preferred_gpu_index
    else:
        print(f'cuda:{preferred_gpu_index} not present; falling back to cuda:0')
        gpu_index = 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print(f'Selected device: {device}')

# Training this model on CPU is impractically slow, so stop here without a GPU.
if not torch.cuda.is_available():
    raise Exception("GPU not available. Training on CPU may be too slow.")

# Report the name of the device actually selected (not unconditionally GPU 0).
device_name = torch.cuda.get_device_name(device)
print(f'Device name: {device_name}')


In [None]:
# Define a dataset class for handling grayscale images.
class EchoDataset(Dataset):
    """
    Dataset of grayscale echo images.

    Each item is a float32 numpy array of shape (1, img_size, img_size). Its
    values are scaled to [0, 1] by dividing by the image's own maximum intensity.
    """

    def __init__(self, image_paths, img_size=image_size):
        """
        Initialize the dataset.

        Parameters:
        - image_paths (list of str): List of file paths to the images.
        - img_size (int): Side length each image is resized to (square output).
          Defaults to the notebook-level `image_size` captured when the class is defined.
        """
        self.image_paths = image_paths
        self.img_size = img_size
        self.n_samples = len(image_paths)

    def __getitem__(self, index):
        """
        Load, resize and normalise one image.

        Parameters:
        - index (int): The index of the image to retrieve.

        Returns:
        - image (numpy.ndarray): float32 array of shape (1, img_size, img_size).

        Raises:
        - FileNotFoundError: if OpenCV cannot read the file at that path.
        """
        path = self.image_paths[index]
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure (missing file, unsupported format) by
        # returning None instead of raising. Surface that with the offending
        # path rather than letting cv2.resize fail with an opaque assertion.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")

        image = cv2.resize(image, (self.img_size, self.img_size))

        # Scale to [0, 1]. An all-black frame (max == 0) would otherwise become
        # NaNs through 0 / 0, so such a frame is left as zeros.
        max_value = image.max()
        if max_value > 0:
            image = image / max_value

        # Add a leading channel dimension and convert to float32 for PyTorch.
        image = np.expand_dims(image, axis=0)
        return image.astype(np.float32)

    def __len__(self):
        """
        Return the total number of samples in the dataset.

        Returns:
        - int: Number of image paths given to the constructor.
        """
        return self.n_samples


# Define a dataset class for handling masks associated with grayscale images.
class EchoDatasetMask(Dataset):
    """
    Dataset of label masks stored as grayscale images.

    Each item is a float32 numpy array of shape (1, img_size, img_size). Label
    values are kept as stored in the file (no intensity rescaling), and
    nearest-neighbour resizing is used so that no new label values are created.
    """

    def __init__(self, mask_paths, img_size=image_size):
        """
        Initialize the dataset.

        Parameters:
        - mask_paths (list of str): List of file paths to the mask images.
        - img_size (int): Side length each mask is resized to (square output).
          Defaults to the notebook-level `image_size` captured when the class is defined.
        """
        self.mask_paths = mask_paths
        self.img_size = img_size
        self.n_samples = len(mask_paths)

    def __getitem__(self, index):
        """
        Load and resize one mask.

        Parameters:
        - index (int): The index of the mask to retrieve.

        Returns:
        - mask (numpy.ndarray): float32 array of shape (1, img_size, img_size).

        Raises:
        - FileNotFoundError: if OpenCV cannot read the file at that path.
        """
        path = self.mask_paths[index]
        mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising on unreadable files.
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {path}")

        # Nearest-neighbour keeps the original discrete label values intact.
        mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)

        # Add a leading channel dimension and convert to float32 for PyTorch.
        mask = np.expand_dims(mask, axis=0)
        return mask.astype(np.float32)

    def __len__(self):
        """
        Return the total number of samples in the dataset.

        Returns:
        - int: Number of mask paths given to the constructor.
        """
        return self.n_samples


In [None]:
# Function to create data loaders for image datasets.
def get_batches(image_paths, batch_size, num_workers, pin_memory, img_size=None):
    """
    Create a DataLoader for the image dataset.

    Parameters:
    - image_paths (list of str): List of file paths to the images.
    - batch_size (int): Number of samples per batch.
    - num_workers (int): Number of subprocesses to use for data loading.
    - pin_memory (bool): Whether to pin memory for faster data transfer to GPU.
    - img_size (int, optional): Side length images are resized to. When None,
      the notebook-level `image_size` is used (the previous, implicit behaviour).

    Returns:
    - DataLoader: DataLoader object for the image dataset, iterating in the
      order of `image_paths`.
    """
    if img_size is None:
        img_size = image_size
    dataset = EchoDataset(image_paths=image_paths, img_size=img_size)
    # shuffle=False keeps batches in `image_paths` order. This loader is built
    # independently of its fixed/moving and image/mask counterparts, so any
    # pairing between them can only be positional.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             shuffle=False)
    return data_loader

# Function to create data loaders for mask datasets.
def get_batches_mask(mask_paths, batch_size, num_workers, pin_memory, img_size=None):
    """
    Create a DataLoader for the mask dataset.

    Parameters:
    - mask_paths (list of str): List of file paths to the masks.
    - batch_size (int): Number of samples per batch.
    - num_workers (int): Number of subprocesses to use for data loading.
    - pin_memory (bool): Whether to pin memory for faster data transfer to GPU.
    - img_size (int, optional): Side length masks are resized to. When None,
      the notebook-level `image_size` is used (the previous, implicit behaviour).

    Returns:
    - DataLoader: DataLoader object for the mask dataset, iterating in the
      order of `mask_paths`.
    """
    if img_size is None:
        img_size = image_size
    dataset = EchoDatasetMask(mask_paths=mask_paths, img_size=img_size)
    # shuffle=False keeps batches in `mask_paths` order so masks stay aligned
    # (positionally) with the image loaders built from the same sorted listing.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=pin_memory,
                             shuffle=False)
    return data_loader

# Sub-folders every split ('train', 'val') is expected to contain.
data_subdirs = ("fixed_img", "fixed_msk", "moving_img", "moving_msk")

# Use ONE pattern for both counting and loading. Globbing "*" for the loaders
# while counting "*.png" let stray non-image files (e.g. OS metadata files)
# reach cv2.imread.
image_glob_pattern = "*.png"


def _list_split_files(split):
    """Return {sub_folder: sorted file paths} for one split ('train' or 'val')."""
    return {
        subdir: sorted(glob(os.path.join(dataset_root_dir, f"{split}/{subdir}/{image_glob_pattern}")))
        for subdir in data_subdirs
    }


train_files = _list_split_files("train")
val_files = _list_split_files("val")

# Print the number of training and validation samples for images and masks.
for subdir, paths in train_files.items():
    print(f'Train Sample numbers ({subdir}) = {len(paths)}')
print()
for subdir, paths in val_files.items():
    print(f'Val Sample numbers ({subdir}) = {len(paths)}')
print()

# The loaders below are built independently from sorted listings and do not
# shuffle, so images, masks, fixed and moving samples can only be matched by
# position. Fail early if a split's folders disagree in size rather than
# silently mis-pairing (or truncating) samples downstream.
for split_name, files in (("train", train_files), ("val", val_files)):
    counts = {subdir: len(paths) for subdir, paths in files.items()}
    assert len(set(counts.values())) == 1, f"Unequal file counts in '{split_name}' split: {counts}"

# Create data loaders for the training dataset.
fixed_train_img_loader = get_batches(
    image_paths=train_files["fixed_img"],
    batch_size=training_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

fixed_train_mask_loader = get_batches_mask(
    mask_paths=train_files["fixed_msk"],
    batch_size=training_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

moving_train_img_loader = get_batches(
    image_paths=train_files["moving_img"],
    batch_size=training_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

moving_train_mask_loader = get_batches_mask(
    mask_paths=train_files["moving_msk"],
    batch_size=training_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

# Print data loader objects to verify creation.
print("Train IMG FIXED Loader:", fixed_train_img_loader)
print("Train MSK FIXED Loader:", fixed_train_mask_loader)
print("Train IMG Moving Loader:", moving_train_img_loader)
print("Train MSK Moving Loader:", moving_train_mask_loader)

# Create data loaders for the validation dataset.
fixed_val_img_loader = get_batches(
    image_paths=val_files["fixed_img"],
    batch_size=testing_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

fixed_val_mask_loader = get_batches_mask(
    mask_paths=val_files["fixed_msk"],
    batch_size=testing_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

moving_val_img_loader = get_batches(
    image_paths=val_files["moving_img"],
    batch_size=testing_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

moving_val_mask_loader = get_batches_mask(
    mask_paths=val_files["moving_msk"],
    batch_size=testing_batch_size,
    num_workers=data_loader_workers,
    pin_memory=True
)

# Print data loader objects to verify creation.
print("Val IMG FIXED Loader:", fixed_val_img_loader)
print("Val MSK FIXED Loader:", fixed_val_mask_loader)
print("Val IMG Moving Loader:", moving_val_img_loader)
print("Val MSK Moving Loader:", moving_val_mask_loader)

In [None]:
# Register every loader under a short key ("<fixed|moving>_<train|val>_<img|msk>")
# so later cells can look loaders up by name.
dataloaders = dict(
    fixed_train_img=fixed_train_img_loader,
    fixed_train_msk=fixed_train_mask_loader,
    moving_train_img=moving_train_img_loader,
    moving_train_msk=moving_train_mask_loader,
    fixed_val_img=fixed_val_img_loader,
    fixed_val_msk=fixed_val_mask_loader,
    moving_val_img=moving_val_img_loader,
    moving_val_msk=moving_val_mask_loader,
)

# Sanity check: echo each registered loader object.
for loader_name in dataloaders:
    print(f"{loader_name} DataLoader: {dataloaders[loader_name]}")


In [None]:
# Loaders to preview, in the order they are printed and plotted
# (training row first, validation row second).
preview_keys = [
    "fixed_train_img", "fixed_train_msk", "moving_train_img", "moving_train_msk",
    "fixed_val_img", "fixed_val_msk", "moving_val_img", "moving_val_msk",
]

# First batch -> first item -> channel 0, i.e. one (H, W) slice per loader.
samples = {}
for key in preview_keys:
    first_batch = first(dataloaders[key])
    samples[f"{key}_sample"] = first_batch[0][0]

# Keep each preview available under its own name for later cells.
fixed_train_img_sample = samples["fixed_train_img_sample"]
fixed_train_msk_sample = samples["fixed_train_msk_sample"]
moving_train_img_sample = samples["moving_train_img_sample"]
moving_train_msk_sample = samples["moving_train_msk_sample"]
fixed_val_img_sample = samples["fixed_val_img_sample"]
fixed_val_msk_sample = samples["fixed_val_msk_sample"]
moving_val_img_sample = samples["moving_val_img_sample"]
moving_val_msk_sample = samples["moving_val_msk_sample"]

# Shapes first, then value ranges as "max min"; masks also list their distinct label values.
for name, sample in samples.items():
    print(f"{name} shape: {sample.shape}")
for name, sample in samples.items():
    if "_msk_" in name:
        print(f"{name} Range: {sample.max()} {sample.min()} {np.unique(sample)}")
    else:
        print(f"{name} Range: {sample.max()} {sample.min()}")

# 2 x 4 grid: training previews on the top row, validation previews on the bottom.
fig, axes = plt.subplots(2, 4, figsize=(15, 7))
for ax, (name, sample) in zip(axes.flat, samples.items()):
    ax.set_title(name)
    ax.imshow(sample.squeeze(), cmap="gray")
    ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
from __future__ import annotations

class RegistrationExtractionBlock(nn.Module):
    """
    Output head for RegUNet.

    For each requested decoder level, a MONAI `get_conv_block` (batch norm plus
    `activation`) maps that level's features to `out_channels`. Each result is
    resized to the full image size with `F.interpolate`, and the resized maps
    are averaged element-wise.
    """

    def __init__(
        self,
        spatial_dims,
        extract_levels,
        num_channels,
        out_channels,
        kernel_initializer,
        activation,
        mode,
        align_corners,
    ):
        """
        Args:
            spatial_dims: Number of spatial dimensions.
            extract_levels: Decoder levels to read from, e.g. [0, 1, 2]; the largest is
                treated as the deepest entry of the list passed to `forward`.
            num_channels: Channel count per level, indexed by level.
            out_channels: Number of output channels produced for every level.
            kernel_initializer: Weight initializer forwarded to `get_conv_block`.
            activation: Activation forwarded to `get_conv_block` (None for none).
            mode: `F.interpolate` mode used to resize each level.
            align_corners: `F.interpolate` align_corners flag.
        """
        super().__init__()
        self.extract_levels = extract_levels
        self.max_level = max(extract_levels)
        # One head per extracted level, built in `extract_levels` order.
        heads = []
        for level in extract_levels:
            heads.append(
                get_conv_block(
                    spatial_dims=spatial_dims,
                    in_channels=num_channels[level],
                    out_channels=out_channels,
                    norm="BATCH",
                    act=activation,
                    initializer=kernel_initializer,
                )
            )
        self.layers = nn.ModuleList(heads)
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x: list[torch.Tensor], image_size: list[int]) -> torch.Tensor:
        """
        Args:
            x: Feature maps ordered deepest level first, so level `L` is found at
                index `max_level - L` (RegUNet.forward builds the list this way).
            image_size: Spatial size every head output is resized to.

        Returns:
            Tensor of shape (batch, out_channels, *image_size): the mean of the
            resized per-level head outputs.
        """
        resized = []
        for head, level in zip(self.layers, self.extract_levels):
            level_features = head(x[self.max_level - level])
            resized.append(
                F.interpolate(level_features, size=image_size, mode=self.mode, align_corners=self.align_corners)
            )
        return torch.stack(resized, dim=0).mean(dim=0)

def get_conv_layer(
    spatial_dims: int, in_channels: int, out_channels: int, kernel_size: Sequence[int] | int = 3
) -> nn.Module:
    """
    Build a bare convolution (no norm, no activation, no bias) with 'same' padding.

    Args:
        spatial_dims: Number of spatial dimensions.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, as an int or one value per spatial dimension.

    Returns:
        A MONAI `Convolution` module with `conv_only=True` and `bias=False`,
        padded via `same_padding(kernel_size)`.
    """
    return Convolution(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        bias=False,
        conv_only=True,
        padding=same_padding(kernel_size),
    )

class RegistrationResidualConvBlock(nn.Module):
    """
    `num_layers` stages of conv -> batch norm -> ReLU with one residual connection.

    The block input is added after the last stage's batch norm, just before its
    ReLU. Because the input is added unchanged, `in_channels` must equal
    `out_channels` for `forward` to succeed.
    """

    def __init__(
        self, spatial_dims: int, in_channels: int, out_channels: int, num_layers: int = 2, kernel_size: int = 3
    ):
        """
        Args:
            spatial_dims: Number of spatial dimensions.
            in_channels: Channels of the input (first conv only).
            out_channels: Channels produced by every conv.
            num_layers: Number of conv/norm/ReLU stages.
            kernel_size: Kernel size of every conv.
        """
        super().__init__()
        self.num_layers = num_layers
        convs = []
        for stage in range(num_layers):
            convs.append(
                get_conv_layer(
                    spatial_dims=spatial_dims,
                    in_channels=out_channels if stage else in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                )
            )
        self.layers = nn.ModuleList(convs)
        batch_norm_cls = Norm[Norm.BATCH, spatial_dims]
        self.norms = nn.ModuleList([batch_norm_cls(out_channels) for _ in range(num_layers)])
        self.acts = nn.ModuleList([nn.ReLU() for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor with shape (batch, in_channels, insize_1, insize_2, [insize_3]).

        Returns:
            Output tensor with shape (batch, out_channels, insize_1, insize_2, [insize_3]).
        """
        residual = x
        stages = list(zip(self.layers, self.norms, self.acts))
        # Every stage except the last is a plain conv -> norm -> act.
        for conv, norm, act in stages[:-1]:
            x = act(norm(conv(x)))
        # Last stage: add the block input before the activation.
        if stages:
            conv, norm, act = stages[-1]
            x = act(norm(conv(x)) + residual)
        return x

class RegUNet(nn.Module):
    """
    UNet-style network for image registration.

    Encoder: `depth` stages of conv block + down-sampling, then a bottom block.
    Decoder: up-samples from level `depth - 1` down to `min(extract_levels)`,
    merging encoder skips by addition (or concatenation if `concat_skip`).
    Output: `RegistrationExtractionBlock` maps each level in `extract_levels`
    to `out_channels`, resizes it to the input's spatial size and averages them.
    """

    def __init__(
        self,
        spatial_dims,
        in_channels,
        num_channel_initial,
        depth,
        out_kernel_initializer,
        out_activation,
        out_channels,
        extract_levels,
        encode_kernel_sizes,
        pooling=True,
        concat_skip=False,
    ):
        """
        Args:
            spatial_dims: Number of spatial dimensions (1, 2 or 3).
            in_channels: Number of input channels.
            num_channel_initial: Number of channels at level 0; level d uses
                num_channel_initial * 2**d.
            depth: Number of down-sampling stages.
            out_kernel_initializer: Initializer for the output conv blocks.
            out_activation: Activation function for the output conv blocks.
            out_channels: Number of output channels.
            extract_levels: Levels to extract features from; defaults to (depth,)
                when empty. Its maximum must equal `depth`.
            encode_kernel_sizes: Kernel size(s) for the encoder, an int or a
                sequence of length depth + 1.
            pooling: Whether to use pooling for down-sampling.
            concat_skip: Whether to concatenate (True) or add (False) skip connections.

        Raises:
            AssertionError: if max(extract_levels) != depth or the kernel-size
                list has the wrong length.
        """
        super().__init__()
        if not extract_levels:
            extract_levels = (depth,)
        if max(extract_levels) != depth:
            raise AssertionError("Maximum extraction level must equal depth.")

        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.num_channel_initial = num_channel_initial
        self.depth = depth
        self.out_kernel_initializer = out_kernel_initializer
        self.out_activation = out_activation
        self.out_channels = out_channels
        self.extract_levels = extract_levels
        self.pooling = pooling
        self.concat_skip = concat_skip

        if isinstance(encode_kernel_sizes, int):
            encode_kernel_sizes = [encode_kernel_sizes] * (self.depth + 1)
        if len(encode_kernel_sizes) != self.depth + 1:
            raise AssertionError("Kernel sizes length must match depth + 1.")
        self.encode_kernel_sizes = encode_kernel_sizes

        self.num_channels = [self.num_channel_initial * (2**d) for d in range(self.depth + 1)]
        self.min_extract_level = min(self.extract_levels)

        # Initialize layers (construction order fixes the weight-init RNG sequence).
        self.build_layers()

    def build_layers(self):
        self.build_encode_layers()
        self.build_decode_layers()

    def build_encode_layers(self):
        # Encoding layers: level d maps num_channels[d-1] (or in_channels) -> num_channels[d].
        self.encode_convs = nn.ModuleList(
            [
                self.build_conv_block(
                    in_channels=self.in_channels if d == 0 else self.num_channels[d - 1],
                    out_channels=self.num_channels[d],
                    kernel_size=self.encode_kernel_sizes[d],
                )
                for d in range(self.depth)
            ]
        )
        self.encode_pools = nn.ModuleList(
            [self.build_down_sampling_block(channels=self.num_channels[d]) for d in range(self.depth)]
        )
        self.bottom_block = self.build_bottom_block(
            in_channels=self.num_channels[-2], out_channels=self.num_channels[-1]
        )

    def build_conv_block(self, in_channels, out_channels, kernel_size):
        return nn.Sequential(
            get_conv_block(
                spatial_dims=self.spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                act='RELU'
            ),
            RegistrationResidualConvBlock(
                spatial_dims=self.spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
        )

    def build_down_sampling_block(self, channels: int):
        return RegistrationDownSampleBlock(spatial_dims=self.spatial_dims, channels=channels, pooling=self.pooling)

    def build_bottom_block(self, in_channels: int, out_channels: int):
        kernel_size = self.encode_kernel_sizes[self.depth]
        return nn.Sequential(
            get_conv_block(
                spatial_dims=self.spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                act='RELU'
            ),
            RegistrationResidualConvBlock(
                spatial_dims=self.spatial_dims,
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
            ),
        )

    def build_decode_layers(self):
        # Decoder levels run depth-1 .. min_extract_level. When extract_levels is
        # just (depth,), this range is empty and no decoder stages are built.
        self.decode_deconvs = nn.ModuleList(
            [
                self.build_up_sampling_block(in_channels=self.num_channels[d + 1], out_channels=self.num_channels[d])
                for d in range(self.depth - 1, self.min_extract_level - 1, -1)
            ]
        )
        self.decode_convs = nn.ModuleList(
            [
                self.build_conv_block(
                    in_channels=(2 * self.num_channels[d] if self.concat_skip else self.num_channels[d]),
                    out_channels=self.num_channels[d],
                    kernel_size=3,
                )
                for d in range(self.depth - 1, self.min_extract_level - 1, -1)
            ]
        )
        self.output_block = self.build_output_block()

    def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
        return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)

    @staticmethod
    def _linear_resize_mode(spatial_dims: int) -> str:
        """
        Return the `F.interpolate` linear mode matching `spatial_dims`.

        'bilinear' only accepts 4-D (2 spatial dims) input, so a hard-coded
        'bilinear' breaks 1-D and 3-D networks.

        Raises:
            ValueError: if spatial_dims is not 1, 2 or 3.
        """
        modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
        if spatial_dims not in modes:
            raise ValueError(f"Unsupported spatial_dims={spatial_dims}; expected 1, 2 or 3.")
        return modes[spatial_dims]

    def build_output_block(self) -> nn.Module:
        return RegistrationExtractionBlock(
            spatial_dims=self.spatial_dims,
            extract_levels=self.extract_levels,
            num_channels=self.num_channels,
            out_channels=self.out_channels,
            kernel_initializer=self.out_kernel_initializer,
            activation=self.out_activation,
            mode=self._linear_resize_mode(self.spatial_dims),
            align_corners=True
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor with shape (batch, in_channels, insize_1, insize_2, [insize_3]).

        Returns:
            Output tensor with shape (batch, out_channels, insize_1, insize_2, [insize_3]).
        """
        image_size = x.shape[2:]
        skips = []  # Encoder outputs before pooling, shallowest first.
        encoded = x
        for encode_conv, encode_pool in zip(self.encode_convs, self.encode_pools):
            skip = encode_conv(encoded)
            encoded = encode_pool(skip)
            skips.append(skip)
        decoded = self.bottom_block(encoded)

        # Deepest level first; the output block relies on this ordering.
        outs = [decoded]

        for i, (decode_deconv, decode_conv) in enumerate(zip(self.decode_deconvs, self.decode_convs)):
            decoded = decode_deconv(decoded)
            # skips[-i - 1] is the encoder output at the level just reached.
            if self.concat_skip:
                decoded = torch.cat([decoded, skips[-i - 1]], dim=1)
            else:
                decoded = decoded + skips[-i - 1]
            decoded = decode_conv(decoded)
            outs.append(decoded)

        out = self.output_block(outs, image_size=image_size)
        return out


In [None]:
# Build the registration network (RegUNet is defined in the previous cell).
model = RegUNet(
    spatial_dims=2,                  # 2-D images.
    in_channels=2,                   # Two stacked input channels (presumably moving + fixed image; confirm in training loop).
    out_channels=2,                  # Two output channels (presumably an x/y displacement field for `Warp`; confirm).
    num_channel_initial=128,         # Channels at level 0; level d uses 128 * 2**d.
    depth=5,                         # Number of encoder down-sampling stages.
    extract_levels=[5],              # Only the bottom level (== depth) feeds the output head, so no decoder stages are built.
    encode_kernel_sizes=3,           # Same 3x3 kernel at every encoder level.
    concat_skip=False,               # Skip connections are added rather than concatenated.
    out_activation=None,             # No activation on the output conv block.
    out_kernel_initializer="zeros",  # Initializer forwarded to the output conv block.
)

# Layer-by-layer summary for a dummy input of shape (batch=2, channels=2, H, W);
# depth=100 expands every nested module. Kept as the last expression so the
# notebook renders it.
torchinfo.summary(model, input_size=(2, 2, image_size, image_size), depth=100)


In [None]:
# Move the model to the specified device (e.g., GPU or CPU)
model = model.to(device)

# If using a pretrained model, load its state dictionary from a file
if use_pretrained_model:
    # Construct the path to the pretrained weights file using root_dir, ExpName, dataDir, and previousWeight
    model.load_state_dict(torch.load(os.path.join(root_dir, ExpName + '_' + dataDir + '_' + str(previousWeight) + '_'+'.pth')))

# Initialize a Warp layer for image transformation with bilinear interpolation and zero-padding
warp_layer = Warp(mode='bilinear', padding_mode='zeros').to(device)

# Define the image loss function using Global Mutual Information
# This loss function measures the mutual information between images
image_loss = GlobalMutualInformationLoss()

# SSIM turned into a loss: lower structural similarity -> larger loss value.
class SSIMLoss(SSIM):
    '''Loss module returning ``1 - s``, where ``s`` is the score computed by the parent SSIM module.'''

    def forward(self, x, y):
        similarity = super().forward(x, y)
        return 1. - similarity
    
# SSIM-based loss for single-channel (grayscale) images, placed on `device`
# (.to(device) already covers the GPU case). Not referenced by the training loop.
label_SSIM = SSIMLoss(n_channels=1).to(device)

# Overlap loss between the fixed mask and the warped moving mask.
label_loss = DiceLoss()
# Alternative: Dice loss evaluated over several smoothing scales.
# label_loss = MultiScaleLoss(label_loss, scales=[0, 1, 2, 4, 8, 16])

# Smoothness penalty on the predicted displacement field.
regularization = BendingEnergyLoss()

# Adam optimizer for the registration network.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
# Optional step decay (halve the learning rate every 8 epochs):
# exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5, verbose=True)

# Dice score averaged over all channels, background included; the loop feeds it
# one-hot masks produced by make_one_hot.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

In [None]:
def make_one_hot(labels, device, C=2):
    '''
    Convert a label map to a one-hot encoded float tensor.

    Parameters
    ----------
    labels : torch.Tensor
        Label map of shape (N, 1, *spatial), e.g. (N, 1, H, W). Values are cast
        with ``.long()``, so fractional values (such as those produced by
        bilinear warping of a mask) are truncated toward zero. After the cast
        every value must lie in [0, C).
    device : torch.device or str
        Device on which the result is allocated (e.g. 'cpu' or 'cuda').
    C : int
        Number of classes (output channels).

    Returns
    -------
    target : torch.Tensor
        float32 tensor of shape (N, C, *spatial) holding 1 in the channel given
        by each label and 0 elsewhere. It carries no autograd history.
    '''
    # scatter_ needs an int64 index on the same device as the destination tensor.
    index = labels.detach().long().to(device)

    # Allocate the zero tensor directly on the target device (no CPU alloc + copy).
    shape = (labels.size(0), C) + tuple(labels.shape[2:])
    one_hot = torch.zeros(shape, dtype=torch.float32, device=device)

    # Write a 1 into the channel selected by each label value.
    target = one_hot.scatter_(1, index, 1.0)
    return target

In [None]:
class VariationalEncoder(nn.Module):
    '''
    Convolutional VAE encoder.

    Maps a single-channel image to a latent sample of size ``latent_dims``
    using the reparameterisation trick ``z = mu + sigma * eps`` with
    ``eps ~ N(0, 1)``. After each forward pass ``self.kl`` holds the closed-form
    KL(N(mu, sigma^2) || N(0, 1)), summed over the batch and latent dimensions.

    NOTE(review): relies on the module-level ``image_size`` and ``device``.
    ``image_size`` is assumed to be divisible by 8: three stride-2 convolutions
    give image_size//8 per side, which is what ``linear1`` is sized for.
    Layer names and order define the state_dict keys of the saved checkpoints
    and must not change.
    '''

    def __init__(self, latent_dims):  
        super(VariationalEncoder, self).__init__()
        
        # Convolutional feature extractor; each layer halves the spatial size.
        self.conv1 = nn.Conv2d(1, 8, 3, stride=2, padding=1)  # 1 -> 8 channels
        self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1) # 8 -> 16 channels
        self.batch2 = nn.BatchNorm2d(16)  # Batch normalization after conv2
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1) # 16 -> 32 channels
        
        # Fully connected head producing the latent distribution parameters.
        self.linear1 = nn.Linear(image_size//8 * image_size//8 * 32, 128)  # Flattened features -> 128
        self.linear2 = nn.Linear(128, latent_dims)  # Mean (mu) of the latent distribution
        self.linear3 = nn.Linear(128, latent_dims)  # log(sigma), i.e. log standard deviation (exp is applied in forward)

        # Standard normal used to draw eps for the reparameterisation trick.
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.to(device)  # Keep sampled noise on the same device as the activations
        self.N.scale = self.N.scale.to(device)
        self.kl = 0  # KL divergence of the most recent forward pass

    def forward(self, x):
        '''Encode ``x`` and return one latent sample per input; sets ``self.kl``.'''
        x = x.to(device)
        x = F.relu(self.conv1(x))
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = torch.flatten(x, start_dim=1)  # (N, 32 * h * w) for the linear layers
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)  # exp keeps the standard deviation positive
        z = mu + sigma * self.N.sample(mu.shape)  # Reparameterisation trick
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element:
        #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma)
        self.kl = (0.5 * (sigma ** 2 + mu ** 2 - 1) - log_sigma).sum()
        return z

class Decoder(nn.Module):
    '''
    Convolutional VAE decoder: latent vector -> single-channel image in [0, 1].

    A fully connected stage expands the latent code to a 32-channel feature map
    of side image_size//8. Three transposed convolutions (kernel 3, stride 2,
    padding 1, output_padding 1) each double the spatial size. The result is
    passed through a sigmoid.

    NOTE(review): relies on the module-level ``image_size`` (assumed divisible
    by 8 so the output is image_size x image_size). The layer names and their
    order inside the Sequentials define the state_dict keys of the saved
    checkpoints, so they must not change.
    '''
    def __init__(self, latent_dims):
        super().__init__()

        # Fully connected stage: latent_dims -> 128 -> 32 * (image_size//8)^2
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, 128),  # Expand the latent code
            nn.ReLU(True),  # ReLU activation (in place)
            nn.Linear(128, image_size//8 * image_size//8 * 32),  # Enough units to reshape into (32, image_size//8, image_size//8)
            nn.ReLU(True)  # ReLU activation (in place)
        )

        # Reshape the flat vector into a (32, image_size//8, image_size//8) feature map
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, image_size//8, image_size//8))
        
        # Upsampling stage: each transposed convolution doubles height and width
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # 32 -> 16 channels, x2 spatial
            nn.BatchNorm2d(16),  # Batch normalization
            nn.ReLU(True),  # ReLU activation
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),  # 16 -> 8 channels, x2 spatial
            nn.BatchNorm2d(8),  # Batch normalization
            nn.ReLU(True),  # ReLU activation
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)  # 8 -> 1 channel, x2 spatial
        )
        
    def forward(self, x):
        '''Decode a batch of latent vectors into images with values in [0, 1].'''
        x = self.decoder_lin(x)  # (N, latent_dims) -> (N, 32 * (image_size//8)^2)
        x = self.unflatten(x)  # -> (N, 32, image_size//8, image_size//8)
        x = self.decoder_conv(x)  # -> (N, 1, 8 * (image_size//8), 8 * (image_size//8))
        x = torch.sigmoid(x)  # Squash to [0, 1]
        return x  # Reconstructed image

class VariationalAutoencoder(nn.Module):
    '''
    Complete VAE: ``VariationalEncoder`` followed by ``Decoder``.

    ``forward`` moves its input to the module-level ``device``, draws a latent
    sample with the encoder, and returns the decoder's reconstruction. The
    attribute names ``encoder`` and ``decoder`` are the prefixes of the
    state_dict keys in the saved checkpoints.
    '''

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        # device transfer -> stochastic encoding -> reconstruction
        return self.decoder(self.encoder(x.to(device)))

# Seed torch's global RNG for reproducibility.
torch.manual_seed(0)

# Latent dimensionality; must match the one the checkpoints were trained with.
d = 8


def _load_frozen_vae(checkpoint_prefix):
    '''
    Build a VariationalAutoencoder, load '<checkpoint_prefix><image_size>_.pth'
    onto ``device``, and return it in eval mode with all parameters frozen.

    The VAEs are used only as fixed shape priors inside ``globalLoss`` and are
    not handed to either optimizer. Without freezing, every training step would
    back-propagate into their weights and accumulate gradients that are never
    cleared or used. Freezing does not affect gradients with respect to their
    inputs.
    '''
    vae = VariationalAutoencoder(latent_dims=d)
    vae.to(device)
    vae.load_state_dict(torch.load(checkpoint_prefix + str(image_size) + '_.pth', map_location=device))
    vae.eval()
    vae.requires_grad_(False)
    return vae


Myo_VAE = _load_frozen_vae('Myo_VAE_')  # myocardium shape prior
LV_VAE = _load_frozen_vae('LV_VAE_')    # left-ventricle shape prior

# Architecture summaries. The input must be image_size x image_size: the VAEs'
# linear layers are sized from image_size, so any other resolution fails.
torchinfo.summary(Myo_VAE, input_size=(2, 1, image_size, image_size), depth=100)
torchinfo.summary(LV_VAE, input_size=(2, 1, image_size, image_size), depth=100)


In [None]:
# Mean-squared-error criterion comparing VAE reconstructions in globalLoss
# (nn.MSELoss averages over all elements by default).
L2_loss = nn.MSELoss()

def globalLoss(trueMask, predMask):
    '''
    Shape-prior loss comparing ground-truth and predicted label maps through
    two pretrained VAEs: ``Myo_VAE`` (myocardium) and ``LV_VAE`` (left ventricle).

    Steps:
    1. Discretise the prediction into the labels {0, 1, 2}:
       * values < 0.98 -> 0,
       * values > 1.0 -> 2,
       * values in the closed band [0.98, 1.0] -> 1.
       Predictions strictly between 1.0 and 2 (e.g. 1.5) therefore become LV.
    2. Split each map into a binary myocardium map (label 2 zeroed) and a
       binary LV map (label 1 zeroed, then divided by 2 so label 2 becomes 1).
    3. Pass each pair through its VAE and compare the reconstructions with
       ``L2_loss``.

    NOTE(review): after step 1, every element of the prediction comes from a
    constant tensor (0, 1 or 2). No gradient therefore flows from this loss back
    to ``predMask``: during training it changes the reported total loss but not
    the registration network's update. Confirm whether that is intended.
    NOTE(review): the VAE forward pass samples its latent code, so the result
    is stochastic even with the VAEs in eval mode.

    Parameters
    ----------
    trueMask : torch.Tensor
        Ground-truth label map; presumably (N, 1, H, W) with values in {0, 1, 2}
        (TODO confirm).
    predMask : torch.Tensor
        Predicted (warped) label map with the same shape as ``trueMask``.

    Returns
    -------
    loss_ : torch.Tensor
        Scalar sum of the myocardium and LV reconstruction L2 terms.
    '''
    # Constant label tensors matching the prediction's shape, dtype and device.
    zeros = torch.zeros_like(predMask)
    ones = torch.ones_like(predMask)
    twos = 2 * torch.ones_like(predMask)

    # Sequential thresholding; each pass reads the previous pass's result.
    discrete = torch.where(predMask < 0.98, zeros, predMask)
    discrete = torch.where(discrete > 1.0, twos, discrete)
    in_myo_band = (discrete >= 0.98) & (discrete <= 1.0)
    discrete = torch.where(in_myo_band, ones, discrete)

    def _drop_label(mask, label):
        # Copy `mask` with every element equal to `label` set to 0.
        stripped = mask.clone()
        stripped[stripped == label] = 0
        return stripped

    myo_true = _drop_label(trueMask, 2)
    myo_pred = _drop_label(discrete, 2)
    lv_true = _drop_label(trueMask, 1) / 2
    lv_pred = _drop_label(discrete, 1) / 2

    # VAE calls keep the order true -> pred, myocardium before LV, because each
    # call consumes random numbers for latent sampling.
    myo_term = L2_loss(Myo_VAE(myo_true), Myo_VAE(myo_pred))
    lv_term = L2_loss(LV_VAE(lv_true), LV_VAE(lv_pred))

    loss_ = myo_term + lv_term
    return loss_


In [None]:
# Target tensor marking samples as "real" (1s) for the discriminator
def label_real(size):
    '''
    Create the "real" discriminator target: a column of ones.

    Parameters
    ----------
    size : int
        Number of samples (batch size); callers pass ``tensor.size(0)``.

    Returns
    -------
    data : torch.Tensor
        Tensor of shape (size, 1) filled with ones, in the default floating
        dtype, on ``device``.
    '''
    # Allocate directly on the target device instead of building on CPU and copying.
    data = torch.ones((size, 1), device=device)
    return data

# Target tensor marking samples as "fake" (0s) for the discriminator
def label_fake(size):
    '''
    Create the "fake" discriminator target: a column of zeros.

    Parameters
    ----------
    size : int
        Number of samples (batch size); callers pass ``tensor.size(0)``.

    Returns
    -------
    data : torch.Tensor
        Tensor of shape (size, 1) filled with zeros, in the default floating
        dtype, on ``device``.
    '''
    # Allocate directly on the target device instead of building on CPU and copying.
    data = torch.zeros((size, 1), device=device)
    return data

# Adversarial critic. The training loop scores fixed images as "real" and
# warped moving images as "fake".
discriminator = monai.networks.nets.Discriminator(
    in_shape=(1, image_size, image_size),  # single-channel image_size x image_size inputs
    channels=(8, 16, 32, 64, 1),           # channels per stage
    strides=(2, 2, 2, 2, 1),               # stride per stage
    num_res_units=2,                       # residual units per stage
    kernel_size=3,                         # convolution kernel size
    dropout=0.10,                          # dropout probability
    act='LEAKYRELU',                       # activation function
).to(device)

# Binary cross-entropy on the discriminator's outputs; nn.BCELoss expects
# probabilities in [0, 1] and targets of the same shape.
criterion = nn.BCELoss()

# Adam for the discriminator, same learning rate (2e-4) as the registration
# network's optimizer.
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Architecture summary (last expression of the cell, so the notebook displays it).
torchinfo.summary(discriminator, input_size=(2, 1, image_size, image_size), depth=100)

In [None]:
# ---------------------------------------------------------------------------
# Adversarial training of the registration network.
#
# Per training step:
#   1. predict a dense displacement field from the (moving, fixed) image pair
#      and warp the moving image and mask with it;
#   2. update the discriminator (fixed images = "real", warped images = "fake");
#   3. update the registration network with
#        MI image loss + bending-energy regularisation + 2 * VAE shape loss
#        + 2 * Dice loss + 1e-4 * adversarial loss.
# Validation runs after the first epoch and every `validation_interval`
# epochs. The model with the best validation DSC is checkpointed.
#
# NOTE(review): the four loaders are zipped, so samples are paired purely by
# iteration order and iteration stops at the shortest loader. This assumes the
# image and mask loaders yield matching items in the same order (e.g. no
# independent shuffling). Confirm where the loaders are built.
# ---------------------------------------------------------------------------

total_epochs = num_epochs  # Total number of epochs for training
train_epoch_losses, val_epoch_losses = [], []  # Average total loss per epoch (training / validated epochs)
validation_interval = 1  # Validate every `validation_interval` epochs (and always after the first)
best_dice_metric = -1  # Not updated in this cell; best_dice_score is the tracked best
best_dice_epoch = -1  # 1-based epoch at which best_dice_score was reached
dice_metric_history = []  # Validation DSC for each validated epoch
lowest_loss = 1e10  # Not updated in this cell
best_dice_score = 0  # Highest validation DSC so far; a checkpoint is saved when it improves
epoch_counter = 1  # Not updated in this cell

train_L2_losses = []  # Per-epoch mean VAE shape ("L2") loss, training
val_L2_losses = []  # Mean VAE shape loss for each validated epoch

train_MI_losses = []  # Per-epoch mean mutual-information image loss, training
val_MI_losses = []  # Mean mutual-information image loss for each validated epoch

validated_epochs = []  # 0-based epoch indices at which validation ran (aligns val columns in the CSV)

for epoch in range(total_epochs):
    print("-" * 100)
    print(f"Epoch {epoch + 1}/{total_epochs}")

    model.train()
    # The discriminator is switched to eval mode for validation; restore
    # training mode (dropout active) at the start of every epoch.
    discriminator.train()
    epoch_loss_sum, step_count = 0, 0
    epoch_dice_sum, epoch_L2_sum, epoch_MI_sum, epoch_reg_sum = 0, 0, 0, 0

    train_batches = zip(fixed_train_img_loader, fixed_train_mask_loader,
                        moving_train_img_loader, moving_train_mask_loader)
    for fixed_train_img, fixed_train_mask, moving_train_img, moving_train_mask in tqdm(train_batches):
        step_count += 1

        optimizer.zero_grad()
        discriminator_optimizer.zero_grad()

        fixed_train_img = fixed_train_img.to(device)
        fixed_train_mask = fixed_train_mask.to(device)
        moving_train_img = moving_train_img.to(device)
        moving_train_mask = moving_train_mask.to(device)

        # Displacement field from the channel-concatenated (moving, fixed) pair.
        deformation_field_train = model(torch.cat((moving_train_img, fixed_train_img), dim=1))

        # Resample the moving image and mask into the fixed frame.
        warped_image_train = warp_layer(moving_train_img, deformation_field_train)
        warped_mask_train = warp_layer(moving_train_mask, deformation_field_train)

        # --- Discriminator update -------------------------------------------
        real_label = label_real(fixed_train_img.size(0))
        fake_label = label_fake(fixed_train_img.size(0))

        discriminator_output_real = discriminator(fixed_train_img)
        loss_real = criterion(discriminator_output_real, real_label)
        loss_real.backward()

        # detach(): the discriminator's loss must not reach the registration network.
        discriminator_output_fake = discriminator(warped_image_train.detach())
        loss_fake = criterion(discriminator_output_fake, fake_label)
        loss_fake.backward()

        discriminator_optimizer.step()

        discriminator_loss = (loss_real + loss_fake) / 2  # Kept for inspection; not logged

        # --- Registration-network (generator) update -------------------------
        # Adversarial term: the generator wants warped images judged "real".
        # The gradients this leaves on the discriminator's parameters are
        # cleared by discriminator_optimizer.zero_grad() at the next step.
        generator_output = discriminator(warped_image_train)
        loss_generator = criterion(generator_output, real_label)

        image_similarity_loss = image_loss(warped_image_train, fixed_train_img)
        dice_loss = label_loss(fixed_train_mask, warped_mask_train)
        # The constant -1 only shifts the reported value; it does not change gradients.
        global_loss = globalLoss(fixed_train_mask, warped_mask_train) - 1
        regularization_loss = regularization(deformation_field_train)

        total_loss = (image_similarity_loss +
                      regularization_loss +
                      2 * global_loss +
                      2 * dice_loss +
                      0.0001 * loss_generator)
        total_loss.backward()
        optimizer.step()

        epoch_loss_sum += total_loss.item()
        epoch_dice_sum += dice_loss.item()
        epoch_L2_sum += global_loss.item()
        epoch_MI_sum += image_similarity_loss.item()
        epoch_reg_sum += regularization_loss.item()

        # Warped masks are fractional after bilinear warping; make_one_hot truncates them.
        dice_metric(y_pred=make_one_hot(warped_mask_train, device, C=3),
                    y=make_one_hot(fixed_train_mask, device, C=3))

    if step_count == 0:
        raise RuntimeError("Training loaders yielded no batches; check the fixed/moving image and mask loaders.")

    for param_group in optimizer.param_groups:
        print("Learning rate: ", param_group['lr'])

    avg_dice_metric = dice_metric.aggregate().item()
    dice_metric.reset()
    avg_epoch_loss = epoch_loss_sum / step_count
    avg_epoch_dice = epoch_dice_sum / step_count
    avg_epoch_L2 = epoch_L2_sum / step_count
    avg_epoch_MI = epoch_MI_sum / step_count
    avg_epoch_reg = epoch_reg_sum / step_count

    train_epoch_losses.append(avg_epoch_loss)
    train_L2_losses.append(avg_epoch_L2)
    train_MI_losses.append(avg_epoch_MI)

    print(f"Epoch {epoch + 1}: Average training total Loss: {avg_epoch_loss:.5f}")
    print(f"Epoch {epoch + 1}: Average training DSC: {avg_dice_metric:.5f}")
    print(f"Epoch {epoch + 1}: Average training DSC Loss: {avg_epoch_dice:.5f}")
    print(f"Epoch {epoch + 1}: Average training L2 Loss: {avg_epoch_L2:.5f}")
    print(f"Epoch {epoch + 1}: Average training MI Loss: {avg_epoch_MI:.5f}")
    print(f"Epoch {epoch + 1}: Average training DDF Loss: {avg_epoch_reg:.5f}")
    print("-" * 60)

    # --- Validation ----------------------------------------------------------
    if (epoch + 1) % validation_interval == 0 or epoch == 0:
        model.eval()
        discriminator.eval()  # Disable the discriminator's dropout for validation
        val_loss_sum, val_dice_sum = 0, 0
        val_L2_sum, val_MI_sum = 0, 0
        val_step_count = 0
        with torch.no_grad():
            val_batches = zip(fixed_val_img_loader, fixed_val_mask_loader,
                              moving_val_img_loader, moving_val_mask_loader)
            for fixed_val_img, fixed_val_mask, moving_val_img, moving_val_mask in val_batches:
                fixed_val_img = fixed_val_img.to(device)
                fixed_val_mask = fixed_val_mask.to(device)
                moving_val_img = moving_val_img.to(device)
                moving_val_mask = moving_val_mask.to(device)

                deformation_field_val = model(torch.cat((moving_val_img, fixed_val_img), dim=1))

                warped_image_val = warp_layer(moving_val_img, deformation_field_val)
                warped_mask_val = warp_layer(moving_val_mask, deformation_field_val)

                real_label = label_real(warped_image_val.size(0))
                generator_output_val = discriminator(warped_image_val)
                loss_generator_val = criterion(generator_output_val, real_label)

                image_similarity_loss_val = image_loss(warped_image_val, fixed_val_img)
                dice_loss_val = label_loss(fixed_val_mask, warped_mask_val)
                global_loss_val = globalLoss(fixed_val_mask, warped_mask_val) - 1
                regularization_loss_val = regularization(deformation_field_val)

                # Same weighting as the training objective.
                total_val_loss = (image_similarity_loss_val +
                                  regularization_loss_val +
                                  2 * global_loss_val +
                                  2 * dice_loss_val +
                                  0.0001 * loss_generator_val)

                val_loss_sum += total_val_loss.item()
                val_L2_sum += global_loss_val.item()
                val_MI_sum += image_similarity_loss_val.item()
                val_step_count += 1

                dice_metric(y_pred=make_one_hot(warped_mask_val, device, C=3),
                            y=make_one_hot(fixed_val_mask, device, C=3))

            if val_step_count == 0:
                raise RuntimeError("Validation loaders yielded no batches; check the fixed/moving image and mask loaders.")

            avg_val_loss = val_loss_sum / val_step_count
            avg_val_L2 = val_L2_sum / val_step_count
            avg_val_MI = val_MI_sum / val_step_count
            validated_epochs.append(epoch)
            val_L2_losses.append(avg_val_L2)
            val_MI_losses.append(avg_val_MI)
            val_epoch_losses.append(avg_val_loss)
            avg_dice_metric = dice_metric.aggregate().item()
            dice_metric_history.append(avg_dice_metric)
            dice_metric.reset()

            # Checkpoint whenever the validation DSC improves.
            if avg_dice_metric > best_dice_score:
                best_dice_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(dataset_root_dir, checkpoint_file_name + '.pth'))
                print(f"Validation DSC improved from {best_dice_score:2.4f} to {avg_dice_metric:2.4f}! "
                      f"Saving best model as {checkpoint_file_name + '.pth'}")
                best_dice_score = avg_dice_metric

            print(f"\nCurrent mean DSC: {avg_dice_metric:.4f} \t Current validation loss: {avg_val_loss:.4f}\n\n"
                  f"Best DSC: {best_dice_score:.4f} at epoch {best_dice_epoch}")

# Save per-epoch loss curves. Rows are indexed by 0-based epoch. Validation
# columns are aligned on the epochs that were actually validated, with NaN
# elsewhere, so the frame is valid for any `validation_interval`. With
# validation every epoch the output is one row per epoch, as before.
train_epoch_index = range(len(train_L2_losses))
metrics_dataframe = pd.DataFrame({
    'Validation_L2_Loss': pd.Series(val_L2_losses, index=validated_epochs, dtype=float),
    'Training_L2_Loss': pd.Series(train_L2_losses, index=train_epoch_index, dtype=float),
    'Validation_MI_Loss': pd.Series(val_MI_losses, index=validated_epochs, dtype=float),
    'Training_MI_Loss': pd.Series(train_MI_losses, index=train_epoch_index, dtype=float),
})

metrics_dataframe.to_csv(experiment_name + '.csv')
