In [None]:
from monai.utils import set_determinism, first  # Utility functions from MONAI, including setting random seed for reproducibility and retrieving the first item from an iterable
from monai.transforms import (  # Importing MONAI data transformations used for preprocessing and augmentation
    EnsureChannelFirstD,  # Ensures image has a channel-first format, necessary for PyTorch models
    Compose,  # Composes multiple transformations into a single callable transform
    LoadImageD,  # Loads medical images from file, supporting multiple formats like NIfTI, PNG, and DICOM
    RandRotateD,  # Applies a random rotation transformation to images, useful for data augmentation
    RandZoomD,  # Applies a random zoom transformation, helping with scale-invariant learning
    ScaleIntensityRanged  # Normalizes image intensities within a given range, improving model performance
)

import monai  # Import MONAI framework for medical imaging AI, including deep learning utilities
from monai.data import DataLoader, Dataset, CacheDataset  # Data handling classes from MONAI, with caching for performance improvement
from monai.config import print_config, USE_COMPILED  # Print MONAI config details and check if compiled layers are used for performance
from monai.networks.nets import *  # Import MONAI's neural network architectures, including U-Net, SwinUNETR, etc.
from monai.networks.blocks import Warp  # Warp block for spatial transformations, useful for registration and motion estimation
from monai.apps import MedNISTDataset  # Pre-loaded medical imaging dataset for experimentation and benchmarking
import torch.nn.functional as F  # Import PyTorch functional utilities for operations like activation functions and loss calculations

from torchinfo import summary  # Utility for model summary visualization, displaying layer-wise structure and parameter counts

from fvcore.nn import FlopCountAnalysis  # Tool for analyzing floating point operations (FLOPs), helpful for model complexity estimation

from glob import glob  # Module for file path pattern matching, useful for loading datasets from directories
import cv2  # OpenCV for image processing, commonly used for visualization and preprocessing
import torchmetrics  # PyTorch Metrics for model evaluation, providing a variety of metrics for classification and segmentation

from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging utility for monitoring training progress and performance metrics

from torch.autograd import Variable  # Autograd wrapper for automatic differentiation, useful for tracking gradients

from scipy.spatial.distance import directed_hausdorff  # Function to compute Hausdorff distance, measuring shape similarity
import pandas as pd  # Pandas for data handling and analysis, commonly used for logging results
import torch.nn as nn  # PyTorch's neural network module, providing layers like Conv2D, Linear, BatchNorm, etc.

import numpy as np  # NumPy for numerical computations, widely used for array manipulations
import torch  # PyTorch main package for deep learning and tensor computations
from torch.nn import MSELoss  # Mean Squared Error loss function, commonly used for regression tasks
import matplotlib.pyplot as plt  # Plotting library for visualizing data, including loss curves and predictions
import os  # OS module for file handling, useful for managing datasets and saving models
import tempfile  # Utility for handling temporary files, often used in testing and caching
from monai.losses import *  # MONAI loss functions, including DiceLoss, FocalLoss, etc., for medical imaging segmentation
from monai.metrics import *  # MONAI metrics for model evaluation, such as Dice Score and Hausdorff Distance
from piqa import SSIM  # Structural Similarity Index (SSIM) for image quality assessment, measuring perceived similarity
import visdom  # Visualization tool for real-time data visualization, useful for monitoring model training
from tqdm import tqdm  # Progress bar utility, helpful for tracking training and data loading progress

import torch  # Re-importing PyTorch (redundant, but common in large projects)
import torch.nn as nn  # Re-importing PyTorch NN module (also redundant)
import torch.optim as optim  # Optimizers for training neural networks, including Adam, SGD, and RMSprop
from torch.utils.data import DataLoader  # DataLoader for batch processing and efficient data handling
from torchvision import transforms  # Torchvision transformations for image processing, useful for augmentations
# import torchio as tio  # (Commented out) TorchIO for medical image augmentation and preprocessing (not currently used)
import nibabel as nib  # NiBabel for handling medical imaging formats (NIfTI, DICOM, Analyze, etc.)
from helper import * # Load custom helper
import config  # Importing configuration file (assumed to be custom), likely containing hyperparameters and paths
print_config()  # Print MONAI configuration details, helping with debugging and environment setup
set_determinism(42)  # Set random seed for reproducibility, ensuring consistent experimental results
import torch, torchinfo  # Re-importing PyTorch and torchinfo (redundant)
from torchviz import make_dot, make_dot_from_trace  # Visualization tools for neural networks, creating computational graphs
from helper import make_one_hot  # Custom helper function to convert labels to one-hot encoding, useful for classification and segmentation

import albumentations as A  # Albumentations library for image augmentation, providing advanced transformations
from albumentations.pytorch import ToTensorV2  # Convert Albumentations output to PyTorch tensor, ensuring compatibility


In [None]:
# Report the GPUs visible to PyTorch and select the device used for training.
# torch.cuda.device_count() returns the number of CUDA devices PyTorch can see.
print('How many GPUs = ' + str(torch.cuda.device_count()))

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# CPU training is deliberately refused. RuntimeError (a subclass of Exception) is more
# specific than a bare Exception, while existing `except Exception` handlers still catch it.
if not torch.cuda.is_available():
    raise RuntimeError("GPU not available. CPU training will be too slow.")

# Name of the CUDA device at index 0 (only reached when CUDA is available).
print("device name", torch.cuda.get_device_name(0))


In [None]:
# Experiment settings: number of intermediate frames fed to the model and the number of
# attention heads; both are encoded into the checkpoint name below.
num_mid_frames = 3
num_heads_ = 8

# ``None`` intermediate frames selects the base CAMUS folder and a "_Mid_0" tag;
# otherwise the folder and the tag both carry the frame count.
if num_mid_frames is None:
    data_dir = 'data/CAMUS_data/'
    saveFile = f'FCN8s_TAM_MH_{num_heads_}_Mid_0'
else:
    data_dir = f'data/CAMUS_data_Mid{num_mid_frames}/'
    saveFile = f'FCN8s_TAM_MH_{num_heads_}_Mid_{num_mid_frames}'
print(data_dir)

# Model weights are stored next to the notebook under the run name.
checkpoint_path = f'{saveFile}.pth'


In [None]:
# Import the CardiacDataset class from the project's data_loader module; it is built
# below from sorted image/mask file lists.
from data_loader import CardiacDataset

# Collect the '*ED.png' image/mask files of each split once and reuse the lists for both
# the printed counts and the datasets, so what is reported is exactly what is loaded.
# The 'test' split is used as the validation set.
train_images = sorted(glob(data_dir + "train/image/*ED.png"))
train_masks = sorted(glob(data_dir + "train/mask/*ED.png"))
val_images = sorted(glob(data_dir + "test/image/*ED.png"))
val_masks = sorted(glob(data_dir + "test/mask/*ED.png"))

print('Total train image Samples=' + str(len(train_images)))
print('Total train mask Samples=' + str(len(train_masks)))
print('Total val image Samples=' + str(len(val_images)))
print('Total val mask Samples=' + str(len(val_masks)))

# Fail loudly on a wrong data_dir instead of building empty loaders, which would only
# fail later (and less clearly) when the first batch is inspected.
if not train_images or not val_images:
    raise FileNotFoundError(
        f"No '*ED.png' images found under '{data_dir}train/image/' "
        f"and/or '{data_dir}test/image/'"
    )

# Albumentations augmentation pipeline. The additional_targets mapping applies the same
# random transform to 'image2' (as an image) and 'mask2' (as a mask).
# NOTE(review): this pipeline is not passed to either dataset below (transform=None);
# confirm whether CardiacDataset augments internally or whether transform=transform was intended.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),         # Random horizontal flip with probability 0.5.
    A.VerticalFlip(p=0.5),           # Random vertical flip with probability 0.5.
    A.RandomRotate90(p=0.5),         # Random rotation by a multiple of 90 degrees with probability 0.5.
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, border_mode=0, rotate_limit=45, p=0.5),  # Random shift/scale/rotate with probability 0.5.
    ToTensorV2()                     # Convert image and mask to PyTorch tensors.
],
    additional_targets={
        'image2': 'image',
        'mask2': 'mask'
    })

# Training loader. It reads the *train* split, so training and validation data stay
# disjoint (it previously globbed the test split, i.e. trained on the validation data).
trainData = DataLoader(
    CardiacDataset(
        train_images,                       # Sorted list of training images.
        train_masks,                        # Sorted list of training masks.
        num_mid_frames=num_mid_frames,      # Number of intermediate frames per sample.
        transform=None,                     # No transform passed (see NOTE above).
    ),
    batch_size=config.trainBatch,           # Batch size for training.
    shuffle=config.shuffle_,                # Whether to shuffle training samples.
    num_workers=config.num_workers,         # Worker processes used for loading.
)

# Validation loader built from the test split.
valData = DataLoader(
    CardiacDataset(
        val_images,                         # Sorted list of validation images.
        val_masks,                          # Sorted list of validation masks.
        num_mid_frames=num_mid_frames,      # Number of intermediate frames per sample.
        transform=None,                     # No transform for validation.
    ),
    batch_size=config.valBatch,             # Batch size for validation.
    shuffle=config.shuffle_val,             # Whether to shuffle validation samples.
    num_workers=config.num_workers,         # Worker processes used for loading.
)

# Inspect the first training batch: its top-level keys, then the image/mask shapes
# stored under the first key.
train_sample = first(trainData)
print(train_sample.keys())

first_key = list(train_sample.keys())[0]
print('train ED img  ' + str(train_sample[first_key]['image'].shape))
print('train ED mask   ' + str(train_sample[first_key]['mask'].shape))


In [None]:
# Plot a sample from the training loader with the project helper `visualize_dataset_sample`.
# The key list is taken from the first training batch (`train_sample`); the helper renders
# on the CPU with 4 classes at the configured image size, labelled as the 'Train' split.
visualize_dataset_sample(
    trainData,
    [*train_sample.keys()],
    device='cpu',
    num_classes=4,
    img_size=config.img_size,
    dataset_type='Train',
)


In [None]:
# Plot a sample from the validation loader with the same helper and settings as the
# training plot. The key list still comes from the first *training* batch, which assumes
# validation batches use the same keys (TODO confirm if the datasets ever diverge).
visualize_dataset_sample(
    valData,
    [*train_sample.keys()],
    device='cpu',
    num_classes=4,
    img_size=config.img_size,
    dataset_type='Validation',
)


In [None]:
"""
Convolutional block:
    Two 3x3 convolutions (padding 1, so height and width are preserved), each followed by
    batch normalization and a ReLU activation. This pattern is commonly used in CNNs for
    feature extraction, with the normalization helping to stabilize training.
"""

class ConvBlock(nn.Module):
    """
    Two stacked ``conv3x3 -> BatchNorm -> ReLU`` stages.

    With kernel size 3, padding 1 and the default stride of 1, spatial size is unchanged:
    an input of shape (N, input_channels, H, W) yields (N, output_channels, H, W).

    Parameters:
    - input_channels: int, channels of the incoming tensor.
    - output_channels: int, number of filters produced by each convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they stay fixed for checkpoint
        # compatibility. Creation order is also kept so seeded weight initialisation is unchanged.
        self.conv_layer1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(output_channels)
        self.conv_layer2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(output_channels)
        self.relu = nn.ReLU()  # shared, parameter-free activation for both stages

    def forward(self, x):
        """Run both conv/BN/ReLU stages over ``x`` and return the activated feature map."""
        stages = (
            (self.conv_layer1, self.batch_norm1),
            (self.conv_layer2, self.batch_norm2),
        )
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

    
"""
Encoder block:
    A convolutional block followed by 2x2 max pooling. The pooling halves the spatial
    dimensions (height and width). In the encoder defined below, successive blocks are
    configured so that the number of filters doubles from one block to the next.
"""

class EncoderBlock(nn.Module):
    """
    ``ConvBlock`` followed by 2x2 max pooling.

    ``forward`` returns two tensors: the full-resolution ConvBlock output (kept for skip
    connections) and its max-pooled copy with half the height and width (fed to the next
    stage).

    Parameters:
    - input_channels: int, channels of the incoming tensor.
    - output_channels: int, filters produced by the inner ConvBlock.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv_block = ConvBlock(input_channels, output_channels)
        self.max_pool = nn.MaxPool2d((2, 2))

    def forward(self, x):
        """Return ``(features, pooled_features)`` for input ``x``."""
        features = self.conv_block(x)
        return features, self.max_pool(features)



    
"""
Decoder block:
    A 2x2 transposed convolution with stride 2 doubles the spatial dimensions (height and
    width) and maps the input channels to output_channels. The result is concatenated along
    the channel axis with the skip connection from the matching encoder block, and a
    convolutional block reduces the concatenated channels back to output_channels. When
    output_channels is half of input_channels, the number of filters is halved.
"""

class DecoderBlock(nn.Module):
    """
    Upsample by 2, concatenate a skip tensor, then apply a ``ConvBlock``.

    The transposed convolution (kernel 2, stride 2) maps ``input_channels`` to
    ``output_channels`` and doubles H and W. The skip tensor is concatenated *after* the
    upsampled tensor along dim 1, and the ConvBlock expects ``output_channels * 2`` channels.
    This assumes the skip tensor has ``output_channels`` channels and the same spatial size
    as the upsampled tensor; neither is checked here.

    Parameters:
    - input_channels: int, channels of the tensor being upsampled.
    - output_channels: int, channels after upsampling and of the block output.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.up_conv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2, padding=0)
        self.conv_block = ConvBlock(output_channels * 2, output_channels)

    def forward(self, x, skip_connection):
        """Upsample ``x``, append ``skip_connection`` on the channel axis and fuse them."""
        upsampled = self.up_conv(x)
        merged = torch.cat((upsampled, skip_connection), dim=1)
        return self.conv_block(merged)

    

class Encoder(nn.Module):
    """
    Shared-weight encoder for several frames with cross-frame attention at the bottleneck.

    Every frame passed to ``forward`` goes through the same four ``EncoderBlock``s and the
    same bottleneck ``ConvBlock``, so weights are shared across frames. The per-frame
    bottleneck maps are stacked along a new leading dimension and handed to
    ``AttentionKQV``, which returns one attended map per frame.

    Args:
        input_depth (int): Number of input channels of each frame (default: 1).
        features (sequence of int): Five channel widths: the four encoder stages followed
            by the bottleneck (default: (64, 128, 256, 512, 1024)).
        num_heads (int): Attention heads used by ``AttentionKQV`` (default: 8, the value it
            previously received implicitly through its own default). ``features[4]`` must be
            divisible by it, as required by ``nn.MultiheadAttention``.
    """

    def __init__(self, input_depth=1, features=(64, 128, 256, 512, 1024), num_heads=8):
        super().__init__()

        # Encoder blocks: each returns its pre-pooling features and a 2x2 max-pooled copy.
        self.encoder_block1 = EncoderBlock(input_depth, features[0])
        self.encoder_block2 = EncoderBlock(features[0], features[1])
        self.encoder_block3 = EncoderBlock(features[1], features[2])
        self.encoder_block4 = EncoderBlock(features[2], features[3])

        # Bottleneck block applied to the last pooled features.
        self.bottleneck = ConvBlock(features[3], features[4])

        # Cross-frame attention over the bottleneck maps. Its width equals the bottleneck channels.
        self.attention_bottleneck = AttentionKQV(
            all_channels=features[4], embedding_dim=features[4], num_heads=num_heads
        )

    def forward(self, *frames):
        """
        Encode every frame independently, then apply cross-frame attention at the bottleneck.

        Args:
            *frames: Frame tensors, presumably each (N, input_depth, H, W) — TODO confirm
                against the dataset.

        Returns:
            tuple: Flat tuple ``(stage1, stage2, stage3, stage4, attended_bottleneck)`` for
            frame 0, followed by the same five tensors for frame 1, and so on. Frame k's
            outputs are therefore at indices ``5k .. 5k + 4``.
        """
        skips = []        # per frame: (stage1, stage2, stage3, stage4)
        bottlenecks = []  # per frame: bottleneck ConvBlock output

        for frame in frames:
            stage1, pooled = self.encoder_block1(frame)
            stage2, pooled = self.encoder_block2(pooled)
            stage3, pooled = self.encoder_block3(pooled)
            stage4, pooled = self.encoder_block4(pooled)
            skips.append((stage1, stage2, stage3, stage4))
            bottlenecks.append(self.bottleneck(pooled))

        # Stack to (num_frames, N, C, h, w) so the attention module sees all frames at once.
        attended = self.attention_bottleneck(torch.stack(bottlenecks, dim=0))

        return tuple(
            tensor
            for frame_skips, frame_attended in zip(skips, attended)
            for tensor in (*frame_skips, frame_attended)
        )



class AttentionKQV(nn.Module):
    """
    Cross-frame attention over per-frame feature maps, followed by gating and fusion.

    For every target frame ``i`` and every *other* frame ``j``:
      * queries come from frame ``j``; keys and values come from frame ``i``. Each goes
        through its own bias-free linear projection, and all are fed to
        ``nn.MultiheadAttention`` over the H*W spatial positions;
      * the attention output is reshaped back to (N, C, H, W) and multiplied by a sigmoid
        mask from a 1x1 conv;
      * the gated map is concatenated with frame ``i`` and fused by a 3x3 conv, BatchNorm
        and ReLU.
    The fused maps are averaged over all ``j != i`` and passed through a final 1x1 conv.

    Args:
        all_channels (int): Channels C of every input frame (default: 1024).
        embedding_dim (int): Attention width (default: 1024). It must equal
            ``all_channels`` because the attention output is reshaped back to C channels.
            It must also be divisible by ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
        num_heads (int): Number of attention heads (default: 8).

    Raises:
        ValueError: if ``embedding_dim != all_channels``.
    """

    def __init__(self, all_channels=1024, embedding_dim=1024, num_heads=8):
        super().__init__()
        if embedding_dim != all_channels:
            # Previously this only surfaced in forward() as a shape error (or silently wrong batch size).
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must equal all_channels ({all_channels})"
            )

        # Layers are created in the original order so seeded initialisation is unchanged.
        # Linear projections for Query, Key and Value (applied per spatial position).
        self.query_linear = nn.Linear(all_channels, embedding_dim, bias=False)
        self.key_linear = nn.Linear(all_channels, embedding_dim, bias=False)
        self.value_linear = nn.Linear(all_channels, embedding_dim, bias=False)

        # Multi-head attention over sequences of H*W positions (batch dimension first).
        self.self_attention = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

        # Gating: 1x1 conv to a single channel, then sigmoid, giving a per-position mask.
        self.gate_conv = nn.Conv2d(all_channels, 1, kernel_size=1, bias=False)
        self.gate_activation = nn.Sigmoid()

        # Fusion of [gated attention, original frame] (2C channels) back to C channels.
        self.combine_conv = nn.Conv2d(all_channels * 2, all_channels, kernel_size=3, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(all_channels)
        self.activation = nn.ReLU(inplace=True)

        # Final 1x1 conv applied to the averaged fused map.
        self.classifier = nn.Conv2d(all_channels, all_channels, kernel_size=1, bias=True)

    def forward(self, frames):
        """
        Apply cross-frame attention to every frame.

        Args:
            frames: Either a sequence of F tensors of shape (N, C, H, W), or one tensor of
                shape (F, N, C, H, W) (e.g. built with ``torch.stack``). F must be >= 2.

        Returns:
            list of torch.Tensor: F tensors of shape (N, C, H, W), one per input frame, in
            input order.

        Raises:
            ValueError: if fewer than two frames are given, since a frame has no other frame
                to attend to. This previously failed with ZeroDivisionError.
        """
        num_frames = len(frames)
        if num_frames < 2:
            raise ValueError(f"AttentionKQV needs at least 2 frames, got {num_frames}")

        # Flatten each frame once to (N, H*W, C) and project it once. These values do not
        # depend on the (i, j) pair, so they are not recomputed inside the pairwise loop.
        flat = [frame.flatten(2).transpose(1, 2) for frame in frames]
        queries = [self.query_linear(f) for f in flat]
        keys = [self.key_linear(f) for f in flat]
        values = [self.value_linear(f) for f in flat]

        outputs = []
        for i, frame_i in enumerate(frames):
            n, c, h, w = frame_i.shape
            contributions = []
            for j in range(num_frames):
                if j == i:
                    continue  # only attend to *other* frames

                attn, _ = self.self_attention(queries[j], keys[i], values[i])  # (N, H*W, C)
                # reshape (not view): transpose makes the tensor non-contiguous.
                attn = attn.transpose(1, 2).reshape(n, c, h, w)

                # Gate the attention map with a sigmoid mask, then fuse it with frame i.
                attn = attn * self.gate_activation(self.gate_conv(attn))
                fused = self.combine_conv(torch.cat([attn, frame_i], 1))  # 2C -> C channels
                contributions.append(self.activation(self.batch_norm(fused)))

            # Average over the num_frames - 1 other frames (same summation order as before).
            combined = sum(contributions) / (num_frames - 1)
            outputs.append(self.classifier(combined))

        return outputs



class TAM_FCN8s(nn.Module):
    """FCN-style multi-frame segmentation network built on a shared ``Encoder``.

    All frames are passed to one ``Encoder`` call. For each frame, the deepest
    feature map is scored into ``num_classes`` channels and upsampled x2. It is
    then added to a 1x1 projection of the fourth encoder stage (``skip4``),
    upsampled x8, passed through a 1x1 classifier, and normalised with a softmax
    over the channel dimension.

    Args:
        num_classes: number of output segmentation classes.
        feature_depths: channel widths handed to ``Encoder``. ``feature_depths[1:4]``
            size the skip projections and ``feature_depths[-1]`` sizes the
            bottleneck scorer. Defaults to ``[64, 128, 256, 512, 1024]``.
    """

    def __init__(self, num_classes, feature_depths=None):
        super().__init__()

        # None instead of a list literal: a mutable default would be shared by every
        # instance (and by Encoder, should it ever mutate it).
        if feature_depths is None:
            feature_depths = [64, 128, 256, 512, 1024]

        # Shared encoder applied to all frames (``Encoder`` is defined elsewhere in the file).
        self.encoder = Encoder(features=feature_depths)

        # 1x1 scorer on the bottleneck. Its input width comes from feature_depths[-1]
        # (1024 by default) instead of a hard-coded 1024, so non-default widths work.
        # NOTE(review): assumes Encoder's bottleneck has feature_depths[-1] channels - confirm.
        self.score_fr = nn.Conv2d(feature_depths[-1], num_classes, kernel_size=1, padding=0)

        # x2 upsampling: kernel 4, stride 2, padding 1 maps H -> 2H exactly.
        self.upconv1 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1)
        # x8 upsampling: kernel 16, stride 8, padding 4 maps H -> 8H exactly.
        self.upconv3 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, padding=4)

        # 1x1 projections of encoder skip features to num_classes channels. Only
        # skip_conv3 is used in forward(). skip_conv1/skip_conv2 are still registered
        # so state_dict keys (and existing checkpoints) stay unchanged.
        self.skip_conv1 = nn.Conv2d(feature_depths[1], num_classes, kernel_size=1)
        self.skip_conv2 = nn.Conv2d(feature_depths[2], num_classes, kernel_size=1)
        self.skip_conv3 = nn.Conv2d(feature_depths[3], num_classes, kernel_size=1)

        # Final 1x1 classifier followed by a softmax over the class (channel) dimension.
        self.classifier = nn.Conv2d(num_classes, num_classes, kernel_size=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, *inputs):
        """Segment every input frame.

        Args:
            *inputs: one tensor per frame, all forwarded together to the shared encoder.

        Returns:
            A tuple with one softmax probability map (``num_classes`` channels) per frame,
            in the same order as ``inputs``.
        """
        # NOTE(review): assumes Encoder returns a flat, frame-major sequence of five maps
        # per frame (skip1, skip2, skip3, skip4, bottleneck) - confirm against Encoder.
        outputs_per_frame = 5
        encoder_outputs = self.encoder(*inputs)

        masks = []
        for i in range(len(inputs)):
            start = i * outputs_per_frame
            _, _, _, skip4, bottleneck = encoder_outputs[start:start + outputs_per_frame]

            # Score the bottleneck, upsample x2 and fuse with the projected skip4 features.
            x = self.upconv1(self.score_fr(bottleneck))
            x = x + self.skip_conv3(skip4)

            # Upsample x8, classify, and turn the logits into per-class probabilities.
            x = self.upconv3(x)
            masks.append(self.softmax(self.classifier(x)))

        return tuple(masks)


In [None]:
# Build a 4-class TAM_FCN8s and smoke-test it on three random single-channel frames.
model = TAM_FCN8s(num_classes=4)

# Every example frame has shape (batch, channels, height, width).
batch_size, channels = 1, 1
height = width = 256
frame_shape = (batch_size, channels, height, width)
num_example_frames = 3

# One random tensor per frame; the model takes the frames as separate positional args.
inputs = [torch.randn(*frame_shape) for _ in range(num_example_frames)]
outputs = model(*inputs)

# Shape of the prediction for the first frame.
print(outputs[0].shape)

# Layer-by-layer summary; torchinfo creates dummy inputs of the given shapes (one per frame).
summary(
    model,
    input_size=[frame_shape] * num_example_frames,
    col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
    depth=50,  # deep enough to expand every nested submodule
    col_width=20,
    row_settings=["var_names"],
)


In [None]:
# Count the model's FLOPs with fvcore, feeding one random 1x1x256x256 frame per model
# input: num_mid_frames + 2 frames when mid frames are configured, otherwise 2.
num_flop_frames = 2 if num_mid_frames is None else num_mid_frames + 2
input_tensors = [torch.randn(1, 1, 256, 256).to(device) for _ in range(num_flop_frames)]

# Trace the model once on the example inputs.
flops = FlopCountAnalysis(model.to(device), tuple(input_tensors))

# fvcore's total covers a single forward pass (fvcore counts one fused multiply-add as one flop).
forward_flops = flops.total()
print(f"FLOPs: {forward_flops}")

# Scale x2 (the original author's forward + backward estimate) and convert to GFLOPs (1e9).
# NOTE(review): backward is often estimated at ~2x forward (total ~3x); confirm the intended factor.
total_flops = forward_flops * 2 / 10**9
print(f"Total GFLOPs (forward + backward): {total_flops}")


In [None]:
# Optionally resume from an earlier checkpoint; map_location loads tensors straight onto `device`.
# model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Move the model to the specified device (GPU or CPU)
model = model.to(device)

# When config.double_GPU is set, replicate the model across GPUs with DataParallel.
# NOTE(review): despite the flag's name this uses four GPUs (0-3) - confirm the intended ids.
if config.double_GPU:
    data_parallel_ids = [0, 1, 2, 3]
    # output_device is documented as a single int / torch.device, not a list; gather the
    # outputs on the first replica's GPU explicitly (also DataParallel's default).
    model = nn.DataParallel(model, device_ids=data_parallel_ids, output_device=data_parallel_ids[0])

# Adam over all model parameters with the learning rate from the config
optimizer = torch.optim.Adam(model.parameters(), config.LR)

# Per-epoch training metrics
train_myo_dsc = []
train_endo_dsc = []
train_epi_dsc = []
train_LA_dsc = []
train_loss_history = []

# Per-validation-round metrics (appended only on epochs where validation runs)
val_myo_dsc = []
val_endo_dsc = []
val_epi_dsc = []
val_LA_dsc = []
val_loss_history = []

# Best mean validation DSC seen so far; used to decide when to checkpoint
best_val_dsc = 0

# Epoch (1-based) of the best validation score. Initialised here so the summary print
# below cannot raise NameError when no validation round ever improves on best_val_dsc.
best_val_epoch = 0


def _prepare_batch(batch_data):
    """Move every frame's image and mask in ``batch_data`` to ``device``.

    ``batch_data`` maps each frame key to a dict with ``'image'`` and ``'mask'`` entries.
    Returns ``(inputs, mask_true)``, two lists in the dict's key order.
    """
    frame_keys = list(batch_data.keys())
    inputs = [batch_data[key]['image'].to(device) for key in frame_keys]
    mask_true = [batch_data[key]['mask'].to(device) for key in frame_keys]
    return inputs, mask_true


def _frame_losses_and_dscs(masks, mask_true):
    """Score every predicted frame against its ground truth with ``train_helper``.

    Args:
        masks: per-frame model outputs, as returned by ``model(*inputs)``.
        mask_true: per-frame ground-truth masks, in the same order as ``masks``.

    Returns:
        ``(loss, myo, endo, la, epi)``. ``loss`` is ``sum(dsc terms) + sum(mse terms)``
        over ALL frames and stays a tensor so it can be back-propagated. The four DSC
        values are Python floats summed over frames 0 and 1 only, which is why callers
        divide them by ``2 * number_of_batches``.
    """
    dsc_terms, mse_terms, myo_endo_terms, epi_terms = [], [], [], []
    for pred_mask, true_mask in zip(masks, mask_true):
        target_dsc, target_mse, target_myo_endo, target_epi = train_helper(pred_mask, true_mask, device)
        dsc_terms.append(target_dsc)
        mse_terms.append(target_mse)
        myo_endo_terms.append(target_myo_endo)
        epi_terms.append(target_epi)

    loss = sum(dsc_terms) + sum(mse_terms)

    # Indices 1/2/3 of the myo_endo result are reported as MYO/ENDO/LA, index 1 of epi as EPI.
    # NOTE(review): only frames 0 and 1 feed the reported DSCs, even when extra middle
    # frames are passed through the model - confirm this is intended.
    myo = myo_endo_terms[0][1].item() + myo_endo_terms[1][1].item()
    endo = myo_endo_terms[0][2].item() + myo_endo_terms[1][2].item()
    la = myo_endo_terms[0][3].item() + myo_endo_terms[1][3].item()
    epi = epi_terms[0][1].item() + epi_terms[1][1].item()
    return loss, myo, endo, la, epi


# Train for config.Epochs epochs. Validate on the first epoch and every
# config.val_interval-th epoch, and checkpoint whenever validation DSC improves.
for epoch in range(config.Epochs):
    print("-" * 100)
    print(f"Epoch {epoch + 1}/{config.Epochs}")

    # ---- Training ----
    model.train()

    batch_count = 0
    epoch_train_loss = 0
    epoch_train_myo_dsc = 0
    epoch_train_LA_dsc = 0
    epoch_train_endo_dsc = 0
    epoch_train_epi_dsc = 0

    for batch_data in tqdm(trainData):
        batch_count += 1
        optimizer.zero_grad()

        inputs, mask_true = _prepare_batch(batch_data)
        masks = model(*inputs)
        batch_train_loss, myo, endo, la, epi = _frame_losses_and_dscs(masks, mask_true)

        batch_train_loss.backward()
        optimizer.step()

        epoch_train_loss += batch_train_loss.item()
        epoch_train_myo_dsc += myo
        epoch_train_endo_dsc += endo
        epoch_train_LA_dsc += la
        epoch_train_epi_dsc += epi

    # Loss is averaged per batch; DSCs per (batch, scored frame), with two frames scored.
    epoch_train_loss /= batch_count
    epoch_train_myo_dsc /= 2 * batch_count
    epoch_train_endo_dsc /= 2 * batch_count
    epoch_train_epi_dsc /= 2 * batch_count
    epoch_train_LA_dsc /= 2 * batch_count

    train_myo_dsc.append(epoch_train_myo_dsc)
    train_endo_dsc.append(epoch_train_endo_dsc)
    train_epi_dsc.append(epoch_train_epi_dsc)
    train_LA_dsc.append(epoch_train_LA_dsc)
    train_loss_history.append(epoch_train_loss)

    # ---- Validation (first epoch and every config.val_interval epochs) ----
    if (epoch + 1) % config.val_interval == 0 or epoch == 0:
        val_batch_count = 0
        epoch_val_loss = 0
        epoch_val_myo_dsc = 0
        epoch_val_endo_dsc = 0
        epoch_val_LA_dsc = 0
        epoch_val_epi_dsc = 0

        # eval(): dropout off, BatchNorm uses its running statistics.
        model.eval()

        # No autograd graph is needed for validation, which saves memory.
        with torch.no_grad():
            for batch_data in tqdm(valData):
                val_batch_count += 1

                inputs, mask_true = _prepare_batch(batch_data)
                masks = model(*inputs)
                batch_val_loss, myo, endo, la, epi = _frame_losses_and_dscs(masks, mask_true)

                epoch_val_loss += batch_val_loss.item()
                epoch_val_myo_dsc += myo
                epoch_val_endo_dsc += endo
                epoch_val_LA_dsc += la
                epoch_val_epi_dsc += epi

        epoch_val_loss /= val_batch_count
        epoch_val_myo_dsc /= 2 * val_batch_count
        epoch_val_endo_dsc /= 2 * val_batch_count
        epoch_val_LA_dsc /= 2 * val_batch_count
        epoch_val_epi_dsc /= 2 * val_batch_count

        val_loss_history.append(epoch_val_loss)
        val_myo_dsc.append(epoch_val_myo_dsc)
        val_endo_dsc.append(epoch_val_endo_dsc)
        val_epi_dsc.append(epoch_val_epi_dsc)
        val_LA_dsc.append(epoch_val_LA_dsc)

        # Checkpointing criterion: mean of the EPI and ENDO validation DSCs.
        epoch_val_avg_dsc = (epoch_val_epi_dsc + epoch_val_endo_dsc) / 2

        if epoch_val_avg_dsc > best_val_dsc:
            best_val_epoch = epoch + 1
            # Unwrap DataParallel so saved keys carry no "module." prefix and the
            # checkpoint loads into a plain, unwrapped TAM_FCN8s.
            model_to_save = model.module if isinstance(model, nn.DataParallel) else model
            torch.save(model_to_save.state_dict(), checkpoint_path)
            print(f"Valid DSC improved from {best_val_dsc:2.5f} to {epoch_val_avg_dsc:2.5f}! Best model is saving as---> {checkpoint_path}")
            best_val_dsc = epoch_val_avg_dsc

            # Epoch metrics are only reported when validation improves.
            print(f"For epoch: {epoch + 1}, Average train total Loss: {epoch_train_loss:.5f}!")
            print(f"For epoch: {epoch + 1}, Average train DSC MYO: {epoch_train_myo_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average train DSC ENDO: {epoch_train_endo_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average train DSC EPI: {epoch_train_epi_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average train DSC LA: {epoch_train_LA_dsc:.5f}!")
            print("-" * 60)

            print(f"For epoch: {epoch + 1}, Average validation total Loss: {epoch_val_loss:.5f}!")
            print(f"For epoch: {epoch + 1}, Average validation DSC MYO: {epoch_val_myo_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average validation DSC ENDO: {epoch_val_endo_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average validation DSC EPI: {epoch_val_epi_dsc:.5f}!")
            print(f"For epoch: {epoch + 1}, Average validation DSC LA: {epoch_val_LA_dsc:.5f}!")

        # Report the best checkpoint found so far.
        print()
        print(f'Best model at the epoch of {best_val_epoch:2.0f}, having DSC of {best_val_dsc:2.4f}!!')

# Collect the per-epoch metrics into a CSV. Validation only runs on epoch 1 and on every
# config.val_interval-th epoch, so the val_* lists are shorter than the train_* lists
# whenever config.val_interval > 1. A DataFrame built from unequal-length arrays raises
# ValueError, so both sets are aligned on the 1-based epoch number instead; epochs
# without validation get NaN in the val_* columns.
epoch_index = pd.RangeIndex(1, len(train_loss_history) + 1)

# Must mirror the validation condition of the training loop (epoch 1, or epoch number
# divisible by config.val_interval). Truncated in case the run stopped mid-validation.
val_epochs = [
    epoch_no for epoch_no in epoch_index
    if epoch_no % config.val_interval == 0 or epoch_no == 1
][:len(val_loss_history)]

train_metrics = pd.DataFrame(
    {
        'train_myo_dsc': np.array(train_myo_dsc),
        'train_endo_dsc': np.array(train_endo_dsc),
        'train_epi_dsc': np.array(train_epi_dsc),
        'train_LA_dsc': np.array(train_LA_dsc),
        'train_loss_history': np.array(train_loss_history),
    },
    index=epoch_index,
)
val_metrics = pd.DataFrame(
    {
        'val_myo_dsc': np.array(val_myo_dsc),
        'val_endo_dsc': np.array(val_endo_dsc),
        'val_epi_dsc': np.array(val_epi_dsc),
        'val_LA_dsc': np.array(val_LA_dsc),
        'val_loss_history': np.array(val_loss_history),
    },
    index=val_epochs,
)

# Left join keeps one row per trained epoch; columns follow the original CSV layout.
df = train_metrics.join(val_metrics)[[
    'train_myo_dsc', 'val_myo_dsc',
    'train_endo_dsc', 'val_endo_dsc',
    'train_epi_dsc', 'val_epi_dsc',
    'train_LA_dsc', 'val_LA_dsc',
    'train_loss_history', 'val_loss_history',
]]

# Save the metrics to a CSV file
df.to_csv("model_training_metrics.csv", index=False)

In [None]:
# Rebuild the model on `device`, restore the saved checkpoint, and write the predicted
# mask, ground-truth mask and input image of every frame in the test split as PNGs.
model = TAM_FCN8s(num_classes=4)
model = model.to(device)

# eval(): dropout off, BatchNorm uses its running statistics.
model.eval()

# Load the checkpoint onto `device`. Checkpoints saved from an nn.DataParallel wrapper
# prefix every key with "module."; strip it (a no-op when absent) so the state_dict
# loads into this unwrapped model.
state_dict = torch.load(checkpoint_path, map_location=device)
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")
model.load_state_dict(state_dict)

# Create the output directory if needed (exist_ok avoids a check-then-create race).
os.makedirs(saveFile, exist_ok=True)

# Test-split loader: one un-augmented sample per batch.
valData = DataLoader(
    CardiacDataset(
        sorted(glob(data_dir + "test/image/*ED.png")),
        sorted(glob(data_dir + "test/mask/*ED.png")),
        num_mid_frames=num_mid_frames,
        transform=None,  # No transformation applied
    ),
    batch_size=1,  # Process one image at a time
    shuffle=config.shuffle_val,
    num_workers=config.num_workers,
)

# Grey levels written for classes 1, 2 and 3 (class 0 stays 0) so the masks are visible.
label_to_grey = {1: 100, 2: 150, 3: 200}

# Inference only: no autograd graph is needed, which saves memory.
with torch.no_grad():
    for batch_data in tqdm(valData):
        frame_keys = list(batch_data.keys())
        inputs = [batch_data[key]['image'].to(device) for key in frame_keys]
        mask_true = [batch_data[key]['mask'].to(device) for key in frame_keys]

        masks = model(*inputs)

        for frame_ in range(len(masks)):
            # Per-pixel predicted class, kept as a (B, 1, H, W) float tensor.
            masks_ = torch.argmax(masks[frame_], dim=1).unsqueeze(1).to(torch.float32)
            mask_true_ = mask_true[frame_]
            true_img_ = inputs[frame_]

            # Remap class ids to grey levels in place. The targets (100/150/200) never
            # collide with later source ids (2/3), so the sequential remap is safe.
            for label, grey in label_to_grey.items():
                masks_[masks_ == label] = grey
                mask_true_[mask_true_ == label] = grey

            for batch in range(masks_.shape[0]):
                # NOTE(review): assumes square images of config.img_size pixels per side.
                true_mask = mask_true_[batch, :, :].reshape(config.img_size, config.img_size).detach().cpu().numpy()
                pred_mask = masks_[batch, :, :].reshape(config.img_size, config.img_size).detach().cpu().numpy()
                true_img = true_img_[batch, :, :].reshape(config.img_size, config.img_size).detach().cpu().numpy()

                # File names use the ED frame's file stem plus the current frame key.
                out_prefix = saveFile + '/' + os.path.splitext(batch_data['ED']['name'][batch])[0] + frame_keys[frame_]
                cv2.imwrite(out_prefix + '_pred_mask.png', pred_mask)
                cv2.imwrite(out_prefix + '_true_mask.png', true_mask)
                # NOTE(review): 255 * image assumes inputs are scaled to [0, 1] - confirm.
                cv2.imwrite(out_prefix + '_true_image.png', 255 * true_img)

