In [None]:
# Import the `set_determinism` utility to ensure reproducibility of results
from monai.utils import set_determinism, first

# Import necessary components for handling data in MONAI
from monai.data import DataLoader, Dataset, CacheDataset

# Import PyTorch's functional API for implementing layers and functions
import torch.nn.functional as F

# Import Structural Similarity Index (SSIM) from the piqa library for image quality assessment
from piqa import SSIM

# Import pandas for data manipulation and analysis
import pandas as pd

# Import NumPy for numerical operations and PyTorch for deep learning tasks
import numpy as np
import torch, torchinfo, torchvision

# Import `glob` for file pattern matching and `cv2` for image processing
from glob import glob
import cv2

# Import `torch.nn` for constructing neural networks
import torch.nn as nn
from torch.autograd import Variable

# Import various loss functions and metrics from MONAI for medical image analysis
from monai.losses import BendingEnergyLoss, MultiScaleLoss, DiceLoss
from monai.metrics import *

# Import `pyplot` from matplotlib for data visualization
from matplotlib import pyplot as plt

# Set determinism for reproducibility
# Uncomment the following line if reproducibility is required:
# set_determinism(42)


In [None]:
# Report how many CUDA devices PyTorch can see
print('How many GPUs = ' + str(torch.cuda.device_count()))

# Select GPU index 3 when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
print(device)  # Show which device was selected

# Abort early when no GPU is present, since CPU training would be too slow
if not torch.cuda.is_available():
    raise Exception("GPU not available. CPU training will be too slow.")

# Query the selected device (not a fixed index 0) so that the printed name
# always matches the GPU that `device` actually points at
print("device name", torch.cuda.get_device_name(device))


In [None]:
# ---- Experiment configuration ----

# Glob patterns that locate the training and validation PNG files
TrImage = "data_VAE/aug_train/*.png"
ValImage = "data_VAE/aug_val/*.png"

# Side length (pixels) every image is resized to; images are made square
img_size = 256

# How many epochs to train for
epoch = 50

# Number of workers each DataLoader uses to load data in parallel
num_workers = 8

# Truthy -> load previously saved weights before training; 0 -> start from scratch
preTrained = 0


In [None]:
# Dataset of grayscale mask images with optional removal of one label value
class EchoDatasetMask(Dataset):
    """Load mask images from disk as normalised single-channel float32 arrays.

    Each item is read as grayscale, optionally has every pixel equal to
    ``labtozero`` set to 0, is resized to ``img_size`` x ``img_size`` with
    nearest-neighbour interpolation, and is divided by its own maximum.

    Parameters:
    - images_path: indexable sequence of image file paths
    - labtozero: pixel value to overwrite with 0, or None to keep every value
    """

    def __init__(self, images_path, labtozero):
        self.images_path = images_path
        self.labtozero = labtozero
        self.n_samples = len(images_path)  # Number of samples in the dataset

    def __getitem__(self, index):
        """Return the processed image at ``index`` as a (1, img_size, img_size) float32 array.

        Raises:
        - OSError: if OpenCV cannot read the file (``cv2.imread`` returned None)
        """
        path = self.images_path[index]
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None instead of raising
        if image is None:
            raise OSError(f"cv2.imread could not read image file: {path}")

        # Blank out the requested label value, if any
        if self.labtozero is not None:
            image[image == self.labtozero] = 0

        # Nearest-neighbour keeps the original pixel values (no interpolated values)
        image = cv2.resize(image, (img_size, img_size), interpolation=cv2.INTER_NEAREST)

        # Scale by the image's own maximum; an all-zero image is left as zeros
        # rather than being divided by zero (which would fill it with NaN)
        peak = image.max()
        if peak > 0:
            image = image / peak

        # Add the leading channel axis and convert to float32
        return np.expand_dims(image, axis=0).astype(np.float32)

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return self.n_samples

# Helper that wraps a list of image paths in a shuffling DataLoader
def get_batches_mask(_dir, _labtozero, batch_size, num_workers, pin_memory):
    """Build a shuffled DataLoader over an ``EchoDatasetMask``.

    Parameters:
    - _dir: sequence of image file paths handed to the dataset
    - _labtozero: pixel value the dataset overwrites with 0 (None disables this)
    - batch_size: number of images per batch
    - num_workers: number of loader workers
    - pin_memory: forwarded to the DataLoader

    Returns:
    - A DataLoader created with ``shuffle=True``
    """
    return DataLoader(
        EchoDatasetMask(images_path=_dir, labtozero=_labtozero),
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

# Show how many files match the training pattern, then the validation pattern
# (one count per line)
print(*(len(sorted(glob(pattern))) for pattern in (TrImage, ValImage)), sep="\n")


In [None]:
# Training loader for the LV masks: pixels equal to 100 are zeroed by the dataset
train_mask_LV = get_batches_mask(
    _dir=sorted(glob(TrImage)),
    _labtozero=100,
    batch_size=128,
    num_workers=num_workers,
    pin_memory=True,
)

# Validation loader for the LV masks, built the same way with smaller batches
val_mask_LV = get_batches_mask(
    _dir=sorted(glob(ValImage)),
    _labtozero=100,
    batch_size=8,
    num_workers=num_workers,
    pin_memory=True,
)

# Keep both loaders addressable by split name
dataloaders = dict(_train=train_mask_LV, _val=val_mask_LV)

# Take the first batch of each split, then its first sample and first channel,
# which leaves a single 2-D mask per split
_train_mask_ = first(dataloaders["_train"])[0][0]
_val_mask_ = first(dataloaders["_val"])[0][0]

# Sanity checks: shapes first, then max / min / distinct values
print(f"_train_mask_ shape: {_train_mask_.shape}")
print(f"_val_mask_ shape: {_val_mask_.shape}")
print(f"_train_mask_ range: {_train_mask_.max()} {_train_mask_.min()} {np.unique(_train_mask_)}")
print(f"_val_mask_ range: {_val_mask_.max()} {_val_mask_.min()} {np.unique(_val_mask_)}")

# Show the two example masks side by side
plt.figure("check", (10, 5))
for panel_index, (panel_title, panel_mask) in enumerate(
    [("Train mask example", _train_mask_), ("Val mask example", _val_mask_)],
    start=1,
):
    plt.subplot(1, 2, panel_index)
    plt.title(panel_title)
    plt.imshow(panel_mask, cmap="gray")
    plt.axis('off')  # No ticks or frame around the image

plt.show()


In [None]:
# Training loader for the Myo masks: pixels equal to 200 are zeroed by the dataset
train_mask_Myo = get_batches_mask(
    _dir=sorted(glob(TrImage)),
    _labtozero=200,
    batch_size=128,
    num_workers=num_workers,
    pin_memory=True,
)

# Validation loader for the Myo masks, built the same way with smaller batches
val_mask_Myo = get_batches_mask(
    _dir=sorted(glob(ValImage)),
    _labtozero=200,
    batch_size=8,
    num_workers=num_workers,
    pin_memory=True,
)

# Keep both loaders addressable by split name (rebinds `dataloaders`)
dataloaders = dict(_train=train_mask_Myo, _val=val_mask_Myo)

# Take the first batch of each split, then its first sample and first channel,
# which leaves a single 2-D mask per split
_train_mask_ = first(dataloaders["_train"])[0][0]
_val_mask_ = first(dataloaders["_val"])[0][0]

# Sanity checks: shapes first, then max / min / distinct values
print(f"_train_mask_ shape: {_train_mask_.shape}")
print(f"_val_mask_ shape: {_val_mask_.shape}")
print(f"_train_mask_ range: {_train_mask_.max()} {_train_mask_.min()} {np.unique(_train_mask_)}")
print(f"_val_mask_ range: {_val_mask_.max()} {_val_mask_.min()} {np.unique(_val_mask_)}")

# Show the two example masks side by side
plt.figure("check", (10, 5))
for panel_index, (panel_title, panel_mask) in enumerate(
    [("Train mask example", _train_mask_), ("Val mask example", _val_mask_)],
    start=1,
):
    plt.subplot(1, 2, panel_index)
    plt.title(panel_title)
    plt.imshow(panel_mask, cmap="gray")
    plt.axis('off')  # No ticks or frame around the image

plt.show()


In [None]:
# Training loaders, looked up by mask type ('Myo' or 'LV')
trainData = dict(Myo=train_mask_Myo, LV=train_mask_LV)

# Validation loaders, looked up by the same mask-type keys
valData = dict(Myo=val_mask_Myo, LV=val_mask_LV)


In [None]:
# Convolutional encoder of the VAE: image batch -> sampled latent vector
class VariationalEncoder(nn.Module):
    """Encode single-channel images into a sampled latent vector.

    Three stride-2 convolutions halve the spatial size three times, so the
    input side length must be a multiple of 8. Two linear heads produce the
    latent mean and the log of the latent standard deviation.

    Parameters:
    - latent_dims: size of the latent vector
    - image_size: side length of the square input images; when None the
      notebook-level ``img_size`` is used

    Attributes:
    - kl: KL divergence of the most recent forward pass, summed over the
      batch and the latent dimensions (0 before the first forward pass)

    Raises:
    - ValueError: if the image size is not a positive multiple of 8
    """

    def __init__(self, latent_dims, image_size=None):
        super(VariationalEncoder, self).__init__()

        if image_size is None:
            image_size = img_size
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Side length of the feature map after the three stride-2 convolutions
        feature_side = image_size // 8

        # Convolutions: each one halves the spatial size and grows the channel count
        self.conv1 = nn.Conv2d(1, 8, 3, stride=2, padding=1)    # 1 -> 8 channels
        self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)   # 8 -> 16 channels
        self.batch2 = nn.BatchNorm2d(16)                        # Normalises conv2's output
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)  # 16 -> 32 channels

        # Linear layers mapping the flattened features to the latent parameters
        self.linear1 = nn.Linear(feature_side * feature_side * 32, 128)
        self.linear2 = nn.Linear(128, latent_dims)  # Latent mean (mu)
        self.linear3 = nn.Linear(128, latent_dims)  # Log of the latent std (log sigma)

        self.kl = 0  # Updated on every forward pass

    def forward(self, x):
        """Return a latent sample for ``x`` and store the KL term in ``self.kl``."""
        # Follow the module's own parameters instead of a global device variable
        x = x.to(self.conv1.weight.device)

        # Convolutional feature extractor with ReLU activations
        x = F.relu(self.conv1(x))
        x = F.relu(self.batch2(self.conv2(x)))
        x = F.relu(self.conv3(x))

        # Flatten everything except the batch dimension
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))

        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)

        # Reparameterisation trick: noise is drawn with mu's device and dtype
        z = mu + sigma * torch.randn_like(mu)

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
        # log_sigma is used directly, avoiding log(exp(.)) which can overflow to NaN.
        self.kl = (0.5 * (sigma**2 + mu**2 - 1.0) - log_sigma).sum()

        return z


In [None]:
# Decoder of the VAE: latent vector -> reconstructed single-channel image
class Decoder(nn.Module):
    """Decode latent vectors into single-channel images with values in [0, 1].

    Two linear layers expand the latent vector to a 32-channel feature map of
    side ``image_size // 8``; three transposed convolutions then double the
    spatial size each time, so the output side is ``8 * (image_size // 8)``.

    Parameters:
    - latent_dims: size of the latent vector
    - image_size: side length of the target images; when None the
      notebook-level ``img_size`` is used
    """

    def __init__(self, latent_dims, image_size=None):
        super().__init__()

        if image_size is None:
            image_size = img_size
        # Computed once so the linear output and the unflatten shape always agree
        feature_side = image_size // 8

        # Linear layers expanding the latent vector to the flattened feature map
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, 128),
            nn.ReLU(True),
            nn.Linear(128, feature_side * feature_side * 32),
            nn.ReLU(True)
        )

        # Reshape the flat vector into (channels, height, width)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, feature_side, feature_side))

        # Softmax module is instantiated but forward() does not apply it
        self.m = nn.Softmax(dim=1)

        # Transposed convolutions; each one doubles the spatial size
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1)  # Back to 1 channel
        )

    def forward(self, x):
        """Return the reconstruction for the latent batch ``x``."""
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)

        # Sigmoid squashes the output into [0, 1]
        x = torch.sigmoid(x)

        return x


In [None]:
# Full VAE: a VariationalEncoder followed by a Decoder
class VariationalAutoencoder(nn.Module):
    """Chain the variational encoder and the decoder into one model.

    Parameters:
    - latent_dims: latent size shared by the encoder and the decoder
    """

    def __init__(self, latent_dims):
        super().__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """Encode ``x`` to a latent sample and return its reconstruction."""
        # The input is moved to the notebook-level `device` before encoding
        latent = self.encoder(x.to(device))
        return self.decoder(latent)


In [None]:
# Seed PyTorch's RNG so weight initialisation is repeatable
torch.manual_seed(0)

# Latent size of the VAE
d = 8

# Build the model and place it on the selected device in one step
# (nn.Module.to returns the module itself)
LV_VAE = VariationalAutoencoder(latent_dims=d).to(device)

# Optionally start from previously saved weights
if preTrained:
    # NOTE(review): this loads a 'Myo_VAE_…' file into the LV model, while the
    # training cell saves 'LV_VAE_…' checkpoints — confirm this is intended
    LV_VAE.load_state_dict(
        torch.load('Myo_VAE_' + str(img_size) + '_.pth', map_location=device)
    )

# Optional architecture / parameter summary:
# torchinfo.summary(LV_VAE, input_size=(2, 1, 512, 512), depth=100)


In [None]:
# Step size used by the optimizer
lr = 1e-3

# Adam over every parameter of the VAE; weight_decay adds L2 regularisation
optim_LV_VAE = torch.optim.Adam(LV_VAE.parameters(), lr=lr, weight_decay=1e-5)


In [None]:
# SSIM expressed as a loss: the parent's similarity score subtracted from one
class SSIMLoss(SSIM):
    """Loss wrapper around ``SSIM`` that returns ``1 - SSIM(x, y)``."""

    def forward(self, x, y):
        """Return one minus the parent class's SSIM value for ``x`` and ``y``."""
        similarity = super().forward(x, y)
        return 1. - similarity

# Dice loss with MONAI's default settings
dice_loss_function = DiceLoss()

# Dice metric object (background included, mean reduction)
dice_metric_calculator = DiceMetric(include_background=True, reduction="mean")

# Single-channel SSIM loss, moved to the selected device
ssim_loss_function = SSIMLoss(n_channels=1).to(device)

def train_epoch(model, device, dataloader, optimizer):
    """
    Perform one epoch of training for the VAE model.

    Parameters:
    - model: The Variational Autoencoder model; ``model.encoder.kl`` is read
      after each forward pass
    - device: The device batches are moved to (GPU or CPU)
    - dataloader: Iterable yielding image batches
    - optimizer: Optimizer for updating model weights

    Returns:
    - A list ``[total, dice, ssim, kl, mse, mean_dice]`` where
      * total: summed batch objective divided by the number of samples seen
      * dice, ssim: sample-weighted average of the per-batch Dice / SSIM losses
      * kl, mse: summed KL term / summed squared error per sample
      * mean_dice: average over batches of ``1 - dice_loss``

    Raises:
    - ValueError: if the dataloader yields no batches
    """
    # Set the model to training mode
    model.train()

    # Running totals over the epoch
    total_loss = 0.0
    total_dice_loss = 0.0
    total_ssim_loss = 0.0
    total_kl_divergence = 0.0
    total_mse_loss = 0.0
    mean_dice_score = 0.0
    step = 0          # Number of batches processed
    num_samples = 0   # Number of images processed

    for batch_images in dataloader:
        # Move batch to the appropriate device
        batch_images = batch_images.to(device)
        batch_size = batch_images.shape[0]

        # Forward pass: generate reconstructed images
        reconstructed_images = model(batch_images)

        # Loss components. The squared error is a SUM over the batch; the Dice
        # and SSIM functions are called as (prediction, target).
        mse_loss = ((batch_images - reconstructed_images)**2).sum()
        dice_loss = dice_loss_function(reconstructed_images, batch_images)
        kl_divergence = model.encoder.kl  # Stored by the encoder during forward
        ssim_loss = ssim_loss_function(reconstructed_images, batch_images)

        # Optimised objective: plain sum of the four components
        total_loss_value = kl_divergence + mse_loss + dice_loss + ssim_loss

        # Backward pass and optimization step
        optimizer.zero_grad()
        total_loss_value.backward()
        optimizer.step()

        # Accumulate for reporting. Dice and SSIM are per-batch scalars, so they
        # are weighted by the batch size to make the later division by the
        # sample count a true per-sample average.
        total_loss += total_loss_value.item()
        total_dice_loss += dice_loss.item() * batch_size
        total_ssim_loss += ssim_loss.item() * batch_size
        total_kl_divergence += kl_divergence.item()
        total_mse_loss += mse_loss.item()
        mean_dice_score += 1 - dice_loss.item()
        step += 1
        num_samples += batch_size

    if step == 0 or num_samples == 0:
        raise ValueError("dataloader yielded no batches; cannot average losses")

    # Compute and print mean Dice score for the epoch
    print(f"Train mean dice: {(mean_dice_score/step):.4f}")

    return [
        total_loss / num_samples,
        total_dice_loss / num_samples,
        total_ssim_loss / num_samples,
        total_kl_divergence / num_samples,
        total_mse_loss / num_samples,
        mean_dice_score / step
    ]


In [None]:
def test_epoch(model, device, dataloader):
    """
    Evaluate the model on the validation dataset.

    Parameters:
    - model: The Variational Autoencoder model
    - device: The device to which the model and data are transferred (GPU or CPU)
    - dataloader: DataLoader for fetching validation batches

    Returns:
    - A list of average losses and metrics for the epoch
    """
    # Set the model to evaluation mode (disables dropout and batch normalization)
    model.eval()
    
    # Initialize accumulators for different loss components
    total_loss = 0.0
    total_dice_loss = 0.0
    total_ssim_loss = 0.0
    total_kl_divergence = 0.0
    total_mse_loss = 0.0
    mean_dice_score = 0.0
    step =0 
    
    # Initialize Dice metric calculator
    dice_metric_calculator = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    
    with torch.no_grad():  # Disable gradient computation for evaluation
        for batch_images in dataloader:
            # Move batch to the appropriate device
            batch_images = batch_images.to(device)
            
            # Forward pass through the model to get reconstructed images
            reconstructed_images = model(batch_images)
            
            # Calculate various loss components
            mse_loss = ((batch_images - reconstructed_images)**2).sum()  # Mean squared error loss
            dice_loss = dice_loss_function(batch_images, reconstructed_images)  # Dice loss for segmentation accuracy
            kl_divergence = model.encoder.kl  # KL divergence from the encoder
            ssim_loss = ssim_loss_function(batch_images, reconstructed_images)  # SSIM loss (inverted)
            
            # Total loss is the sum of individual losses
            total_loss_value = kl_divergence + mse_loss + dice_loss + ssim_loss
            
            # Accumulate losses for reporting
            total_loss += total_loss_value.item()
            total_dice_loss += dice_loss.item()
            total_ssim_loss += ssim_loss.item()
            total_kl_divergence += kl_divergence.item()
            total_mse_loss += mse_loss.item()
            mean_dice_score += 1-dice_loss.item()
            step+=1
            

        # Compute and print mean Dice score for the validation set
        print(f"Validation mean dice: {(mean_dice_score/step):.4f}")
        
        # Calculate average losses over the epoch
        num_samples = len(dataloader.dataset)
        return [
            total_loss / num_samples,
            total_dice_loss / num_samples,
            total_ssim_loss / num_samples,
            total_kl_divergence / num_samples,
            total_mse_loss / num_samples,
            mean_dice_score / step
        ]

In [None]:
def plot_outputs(encoder, decoder, dataloader):
    """
    Show the first batch of a dataloader next to its VAE reconstructions.

    The top row holds the input images and the bottom row the decoder output
    for the same images. Both modules are switched to evaluation mode.

    Parameters:
    - encoder: The encoder part of the Variational Autoencoder
    - decoder: The decoder part of the Variational Autoencoder
    - dataloader: DataLoader providing batches of images for visualization
    """
    plt.figure(figsize=(16, 4.5))

    # Only the first batch is visualised; nothing more happens if there is none
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return

    encoder.eval()
    decoder.eval()

    # Reconstruct without tracking gradients
    with torch.no_grad():
        originals = first_batch.to(device)
        reconstructions = decoder(encoder(originals))

    count = originals.size(0)
    middle = count // 2  # Column that carries the row title
    rows = ((originals, 'Original Images'), (reconstructions, 'Reconstructed Images'))

    for column in range(count):
        for row, (images, heading) in enumerate(rows):
            # Row 0 fills subplot slots 1..count, row 1 fills count+1..2*count
            ax = plt.subplot(2, count, column + 1 + row * count)
            plt.imshow(images[column].cpu().squeeze().numpy(), cmap='gist_gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            if column == middle:
                ax.set_title(heading)

    plt.show()


In [None]:
# Number of epochs for training (taken from the configuration cell)
num_epochs = epoch

# Dice metric object (not used by the loop below)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# Per-epoch history of every loss component
train_losses = {
    'VAE': [],
    'DSC': [],
    'SSIM': [],
    'KL': [],
    'DiffLoss': []
}

val_losses = {
    'VAE': [],
    'DSC': [],
    'SSIM': [],
    'KL': [],
    'DiffLoss': []
}

# Per-epoch history of the mean Dice score
train_metrics = []
val_metrics = []

# Best validation Dice score seen so far and the epoch it occurred in
best_dice_score = 0
best_epoch = 0

# Path to save the best model
checkpoint_path = f'LV_VAE_{img_size}_.pth'

# Order of the first five values returned by train_epoch / test_epoch
loss_keys = ('VAE', 'DSC', 'SSIM', 'KL', 'DiffLoss')

# Training loop. The loop variable is deliberately NOT called `epoch`, so the
# configuration value `epoch` is not overwritten and re-running this cell
# trains for the same number of epochs.
for epoch_index in range(num_epochs):
    print(f'\n<<<----------------------- EPOCH {epoch_index + 1}/{num_epochs} ------------------------->>>')
    print('<<<-------------------------- LV ---------------------------->>>')

    # One training pass followed by one validation pass
    train_loss = train_epoch(LV_VAE, device, trainData['LV'], optim_LV_VAE)
    val_loss = test_epoch(LV_VAE, device, valData['LV'])

    # The sixth returned value is the mean Dice score
    train_dice = train_loss[5]
    val_dice = val_loss[5]

    # Checkpoint whenever the validation Dice score improves
    if val_dice > best_dice_score:
        best_epoch = epoch_index + 1
        print(f"Validation Dice Score improved from {best_dice_score:.4f} to {val_dice:.4f}! Saving the best model as {checkpoint_path}")
        best_dice_score = val_dice
        torch.save(LV_VAE.state_dict(), checkpoint_path)

    # Print losses for this epoch
    print(f'\nLV Train Loss: {train_loss[0]:.3f} \t LV Val Loss: {val_loss[0]:.3f}')
    print(f"Best Validation Dice Score: {best_dice_score:.4f} at Epoch {best_epoch}")

    # Record the losses and metrics
    for loss_key, train_value, val_value in zip(loss_keys, train_loss, val_loss):
        train_losses[loss_key].append(train_value)
        val_losses[loss_key].append(val_value)
    train_metrics.append(train_dice)
    val_metrics.append(val_dice)

    # Visualize reconstructions of one validation batch
    plot_outputs(LV_VAE.encoder, LV_VAE.decoder, valData['LV'])


# Collect the training and validation history, one row per epoch
metrics_df = pd.DataFrame({
    'Train_VAE_Loss': np.array(train_losses['VAE']),
    'Train_DSC_Loss': np.array(train_losses['DSC']),
    'Train_SSIM_Loss': np.array(train_losses['SSIM']),
    'Train_KL_Loss': np.array(train_losses['KL']),
    'Train_DiffLoss': np.array(train_losses['DiffLoss']),
    'Train_Metric': np.array(train_metrics),
    'Val_VAE_Loss': np.array(val_losses['VAE']),
    'Val_DSC_Loss': np.array(val_losses['DSC']),
    'Val_SSIM_Loss': np.array(val_losses['SSIM']),
    'Val_KL_Loss': np.array(val_losses['KL']),
    'Val_DiffLoss': np.array(val_losses['DiffLoss']),
    'Val_Metric': np.array(val_metrics)
})

# Save the metrics DataFrame to a CSV file
metrics_df.to_csv(f'LV_VAE_{img_size}.csv', index=False)


In [None]:
# Plot the training and validation curves of every LV VAE loss component
# side by side, then save the figure
plt.figure(figsize=(20, 5))

# (history key, panel title) for each of the five panels, left to right
loss_panels = [
    ('VAE', 'VAE Loss'),
    ('DSC', 'DSC Loss'),
    ('SSIM', 'SSIM Loss'),
    ('KL', 'KL Divergence Loss'),
    ('DiffLoss', 'Difference Loss'),
]

for panel_number, (loss_key, panel_title) in enumerate(loss_panels, start=1):
    plt.subplot(1, 5, panel_number)
    plt.plot(train_losses[loss_key], label=f'Train {panel_title}', color='blue')
    plt.plot(val_losses[loss_key], label=f'Validation {panel_title}', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(panel_title)
    plt.legend()

# Tidy the spacing, write the figure to disk, then display it
plt.tight_layout()
plt.savefig(f'LV_VAE_losses_{img_size}.png')
plt.show()
