In [None]:
# Import necessary modules
import glob
import hashlib
import io
import litdata
import matplotlib.pyplot as plt
import numpy as np
import optuna
import os
import pickle
import random
import re
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import tqdm

from datetime import datetime
from IPython.display import display, clear_output
from itertools import product
from PIL import Image
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import ConcatDataset, DataLoader, random_split, WeightedRandomSampler
from torchvision.models import vgg16
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid, save_image

In [None]:
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) global RNGs with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

set_seed(42)

# Root folder that holds the CarRecs dataset.
datapath = '/projects/ec232/data/'

# First GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel (R, G, B) ImageNet statistics.
in_mean = [0.485, 0.456, 0.406]
in_std = [0.229, 0.224, 0.225]

class ToRGBTensor:
    """Convert a decoded image to a 3-channel float tensor via ``TF.to_tensor``.

    PIL images whose mode is neither ``RGB`` nor single-channel ``L`` (for example
    CMYK JPEGs, RGBA/LA or palette images) are converted to RGB first. Left as-is,
    their raw channels would make ``expand(3, -1, -1)`` fail (2 or 4 channels), or
    yield palette indices instead of colours. Single-channel inputs are still
    broadcast to three identical channels. Inputs without a PIL ``mode`` attribute
    (e.g. NumPy arrays) are handed to ``TF.to_tensor`` unchanged, as before.
    """

    # PIL modes that TF.to_tensor + expand already turn into a correct 3-channel tensor.
    _NATIVE_MODES = ("RGB", "L")

    def __call__(self, img):
        mode = getattr(img, "mode", None)
        if mode is not None and mode not in self._NATIVE_MODES:
            img = img.convert("RGB")
        # expand is a no-op view for 3 channels and broadcasts a single channel to 3.
        return TF.to_tensor(img).expand(3, -1, -1)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

def scores_to_tensor(scores):
    """Convert the score array loaded from a ``.scores.npy`` file to a float32 tensor.

    ``torch.tensor`` always copies, so the result never shares memory with ``scores``.
    (A former path-based variant that called ``np.load`` itself was removed: it was
    immediately shadowed by this definition and could never be called.)
    """
    return torch.tensor(scores, dtype=torch.float32)

# One transform per modality, in the order pinned by `override_extensions` below:
# the first handles the decoded .jpg image, the second the .scores.npy array.
postprocess = (
    T.Compose([
        ToRGBTensor(),                         # decoded image -> 3-channel float tensor
        T.Resize((224, 224), antialias=True),  # 224x224 input size for ResNet18
        # No T.Normalize(in_mean, in_std) step: pixels stay in [0, 1] (presumably to
        # match the decoder's sigmoid output and the [0, 1] range the losses assume).
    ]),
    scores_to_tensor,
)

# Open CarRecs with the modality order fixed (image first, scores second) and attach
# the matching transform to each modality.
data = litdata.LITDataset(
    'CarRecs',
    datapath,
    override_extensions=['jpg', 'scores.npy'],
).map_tuple(*postprocess)

# Smoke test: decode the first sample and report what the transforms produced.
sample = data[0]
print("Image shape:", sample[0].shape)
print("Scores tensor:", sample[1])

data

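## Dataset split and weighted sampling

The training split is drawn with a `WeightedRandomSampler`. Two per-image weighting strategies, both based on the two reviewers' scores, are compared: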

**Product of Scores:**

Multiply the two reviewers' scores for each image.

- *Pros:* prioritizes images that both reviewers scored highly.
- *Cons:* a low score from either reviewer drastically reduces the image's weight. A score of 0 from either reviewer means the image is never sampled.

**Consensus-Based:**

Weight each image by $1 / (|s_\text{Moira} - s_\text{Ferdinando}| + 1)$. Images on which the two reviewers' scores are close (i.e. they agree) are then sampled more often.

- *Pros:* focuses on images the reviewers agree on.
- *Cons:* down-weights images that one reviewer rated much higher (or lower) than the other, even when that extreme score is informative.

In [None]:
# Extract scores for Moira and Ferdinando and build the (image, (moira, ferdinando)) dataset.
def _collect_scored_samples(dataset):
    """Gather both reviewers' scores and the images from ``dataset`` in a single pass.

    Every iteration over the dataset decodes and resizes an image, so reading it once
    (instead of once per score list plus once more for the samples) cuts the loading
    work to a third.

    Returns:
        (moira_scores, ferdinando_scores, samples): two lists of Python floats and a
        list of ``(image, (moira_tensor, ferdinando_tensor))`` tuples of 0-d tensors.
    """
    moira_list, ferdinando_list, samples = [], [], []
    for item in dataset:
        image, scores = item[0], item[1]
        moira = scores[0][1].item()       # index 1 is Moira's position in the score row
        ferdinando = scores[0][3].item()  # index 3 is Ferdinando's position
        moira_list.append(moira)
        ferdinando_list.append(ferdinando)
        samples.append((image, (torch.tensor(moira), torch.tensor(ferdinando))))
    return moira_list, ferdinando_list, samples


moira_scores, ferdinando_scores, full_dataset = _collect_scored_samples(data)

# Set image batch size
batch_size = 32

# 1. Product of scores: favours images that both reviewers rated highly
#    (a score of 0 from either reviewer gives the image zero sampling weight).
product_weights = [moira * ferdinando for moira, ferdinando in zip(moira_scores, ferdinando_scores)]

# 2. Consensus: 1 / (|difference| + 1), favours images the reviewers agree on.
consensus_weights = [1 / (abs(moira - ferdinando) + 1) for moira, ferdinando in zip(moira_scores, ferdinando_scores)]

# 90 / 10 train/validation split. random_split draws from torch's global RNG, which
# set_seed(42) seeded earlier, so the split is reproducible.
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Sampling weights for the training subset only, for both strategies.
train_product_weights = [product_weights[i] for i in train_dataset.indices]
train_consensus_weights = [consensus_weights[i] for i in train_dataset.indices]

# Weighted sampling with replacement, one epoch = len(train_dataset) draws.
sampler_product = WeightedRandomSampler(train_product_weights, num_samples=len(train_dataset), replacement=True)
sampler_consensus = WeightedRandomSampler(train_consensus_weights, num_samples=len(train_dataset), replacement=True)

train_loader_product = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_product)
train_loader_consensus = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler_consensus)

# Validation is never re-weighted, so one shared loader serves both strategies.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# The product-weighted loader is the one used for training.
train_loader = train_loader_product

## Conditional VAE with a pretrained ResNet18 encoder

In [None]:
class CVAE_ResNet18_DualEmbedding(nn.Module):
    """Conditional VAE: pretrained ResNet18 encoder plus a transposed-convolution decoder.

    The encoder runs ``x`` through ResNet18 without its final fully connected layer.
    It flattens the resulting 512 features, appends the condition vector ``c`` and
    predicts the mean and log-variance of a ``latent_dim``-dimensional Gaussian.
    The decoder appends ``c`` to a latent sample and projects it to a 512x7x7 map.
    Five stride-2 transposed convolutions (each doubling height and width) then
    produce a 3x224x224 image squashed into (0, 1) by a sigmoid.

    Args:
        conditional_dim: Length of the condition vector ``c`` (2 in this notebook: one
            score per reviewer). Sizes the encoder heads and the decoder projection.
        latent_dim: Dimensionality of the latent space.
        debug: If True, print intermediate tensor shapes on every forward pass.
    """

    def __init__(self, conditional_dim=2, latent_dim=512, debug=False):
        super(CVAE_ResNet18_DualEmbedding, self).__init__()
        self.debug = debug
        # Batch norms applied after deconv1..deconv4 (256, 128, 64, 32 channels).
        self.bn1 = nn.BatchNorm2d(256)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(32)

        self.conditional_dim = conditional_dim
        self.latent_dim = latent_dim

        # Load the ResNet18 model with pretrained weights. It is kept as an attribute so
        # existing state_dict keys (resnet18.*) stay valid; its unused fc layer is
        # therefore still registered.
        self.resnet18 = models.resnet18(weights='DEFAULT')

        # Every ResNet18 child except the final classifier: 512 features per image.
        self.features = nn.Sequential(*list(self.resnet18.children())[:-1])

        # Encoder heads: [512 image features | condition] -> mean / log-variance.
        # Sized from conditional_dim so the constructor argument actually takes effect.
        self.fc_mu = nn.Linear(512 + self.conditional_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(512 + self.conditional_dim, self.latent_dim)

        # Decoder: [z | condition] -> 512x7x7, then 7 -> 14 -> 28 -> 56 -> 112 -> 224.
        self.decoder_input = nn.Linear(self.latent_dim + self.conditional_dim, 512 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

    def encode(self, x, c):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim), for images ``x`` and conditions ``c``."""
        # NOTE(review): the pretrained backbone was trained on ImageNet-normalised
        # images; if normalisation is disabled upstream it sees raw [0, 1] pixels.
        x = self.features(x)
        if self.debug: print(f"After ResNet18 features extraction: {x.shape}")
        x = x.view(x.size(0), 512)  # only flatten the spatial dimensions
        if self.debug: print(f"After reshaping: {x.shape}")

        # Concatenate condition vectors (for Moira and Ferdinando)
        x = torch.cat([x, c], dim=1)
        if self.debug: print(f"After concatenating with condition vector: {x.shape}")

        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Map latent codes ``z`` and conditions ``c`` to images of shape (batch, 3, 224, 224) in (0, 1)."""
        # Concatenate z with condition vectors (for Moira and Ferdinando)
        z = torch.cat([z, c], dim=1)
        if self.debug: print(f"After concatenating z with condition vector: {z.shape}")

        x = F.relu(self.decoder_input(z))
        x = x.view(x.size(0), 512, 7, 7)  # reshape to (batch, channels, height, width)
        x = F.relu(self.bn1(self.deconv1(x)))
        x = F.relu(self.bn2(self.deconv2(x)))
        x = F.relu(self.bn3(self.deconv3(x)))
        x = F.relu(self.bn4(self.deconv4(x)))
        x = torch.sigmoid(self.deconv5(x))

        return x

    def forward(self, x, c):
        """Encode, sample and decode; return ``(recon_x, mu, logvar)``."""
        if self.debug: print(f"Input image shape: {x.shape}, Condition shape: {c.shape}")
        mu, logvar = self.encode(x, c)
        if self.debug: print(f"Encoding outputs - mu: {mu.shape}, logvar: {logvar.shape}")
        z = self.reparameterize(mu, logvar)
        if self.debug: print(f"Latent representation z shape: {z.shape}")
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar
    
# debug=True prints several shape lines on every forward pass, which floods the output
# once the model runs inside the training loop. Enable it only to trace tensor shapes.
cvae = CVAE_ResNet18_DualEmbedding(conditional_dim=2, latent_dim=512, debug=False).to(device)
print(cvae)

## Initializing metrics

**Importance of Losses for VAE:**

**MSE (Mean Squared Error):**

- *Importance:* fundamental for VAEs; enforces pixel-level fidelity.
- *When to use:* always, especially in the early stages of training.
- *Issues:* used on its own, it tends to produce blurry reconstructions.

**Perceptual Loss:**

- *Importance:* captures high-level semantic differences between images, so that generated images are perceptually similar to their targets.
- *When to use:* once the VAE produces reasonable reconstructions, to refine their quality.
- *Issues:* can introduce artifacts if weighted too heavily.

**Histogram Loss:**

- *Importance:* makes the distribution of pixel intensities in the generated image match that of the target. This matters for image characteristics such as brightness and contrast.
- *When to use:* when the color or intensity distribution is important.
- *Issues:* may be unnecessary if the other losses already produce satisfactory results.

In general, the primary VAE objective is a reconstruction loss (such as MSE) plus the KL divergence, which gives the latent space its desirable properties. The perceptual and histogram losses are supplementary terms that refine image quality according to the application's requirements.

How to weight these terms depends on the application's goals:

- If pixel-level fidelity is crucial, prioritize MSE.
- If perceptual quality matters more (e.g. art generation or photo enhancement), emphasize the perceptual loss.
- If preserving image characteristics such as brightness or contrast is essential, emphasize the histogram loss.

In [None]:
# Loss function: combination of reconstruction MSE loss and KL divergence
def mse_loss(recon_x, x, mu, logvar, beta):
    """VAE loss: summed squared pixel error plus beta-weighted KL divergence.

    Both terms are summed over all elements and divided by the batch size ``x.size(0)``.

    Returns:
        ``(total, reconstruction, kld, 0, 0, 0)``. The trailing zeros pad the result
        to the same 6-tuple layout that ``ms_ssim_loss`` returns.
    """
    n_samples = x.size(0)
    reconstruction = F.mse_loss(recon_x, x, reduction='sum') / n_samples
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * kl_terms.sum() / n_samples
    total = reconstruction + beta * kld
    return total, reconstruction, kld, 0, 0, 0


# Computes the VAE loss = reconstruction MS_SSIM loss + KL divergence
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure

# MS-SSIM over pixels in [0, 1] (data_range=1.0). Only the first three entries of the
# five-weight tuple are passed, i.e. three scales are used instead of five.
# NOTE(review): `postprocess` resizes images to 224x224; confirm that the reduced
# number of scales is still wanted at this resolution.
ms_ssim_module = MultiScaleStructuralSimilarityIndexMeasure(
    data_range=1.0,
    betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)[:3],
).to(device)


def ms_ssim_loss(recon_x, x, mu, logvar, beta=1.0):
    """VAE loss with an MS-SSIM reconstruction term.

    Returns:
        ``(total, reconstruction, kld, 0, 0, 0)``:
        - ``reconstruction = 1 - MS-SSIM(recon_x, x)``;
        - ``kld`` is the KL divergence summed over all elements and divided by the batch size;
        - ``total = reconstruction + beta * kld``.
    """
    reconstruction = 1 - ms_ssim_module(recon_x, x)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return reconstruction + beta * kld, reconstruction, kld, 0, 0, 0


# Load the VGG16 model and extract features from an intermediate layer
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG16 feature maps of two image batches.

    The extractor is the first 16 layers of ``vgg16(weights='DEFAULT').features``.
    Its parameters are frozen (``requires_grad=False``), but gradients still flow back
    to the inputs.

    Args:
        normalize: If True, standardise both inputs with ``mean``/``std`` before
            feature extraction. The pretrained VGG16 weights expect ImageNet-normalised
            images, whereas this notebook presumably feeds pixels in [0, 1]. The default
            (False) keeps the original, un-normalised behaviour.
        mean, std: Per-channel statistics used when ``normalize`` is True
            (defaults: ImageNet RGB statistics).
    """

    def __init__(self, normalize=False, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(VGGPerceptualLoss, self).__init__()
        model = vgg16(weights='DEFAULT')
        self.features = model.features[:16]
        for param in self.features.parameters():
            param.requires_grad = False
        self.normalize = normalize
        # Non-persistent buffers follow .to(device) but stay out of state_dict().
        self.register_buffer('mean', torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1), persistent=False)

    def forward(self, x, y):
        """Return the mean squared difference between the feature maps of ``x`` and ``y``."""
        if self.normalize:
            x = (x - self.mean) / self.std
            y = (y - self.mean) / self.std
        x_features = self.features(x)
        y_features = self.features(y)
        return F.mse_loss(x_features, y_features)

# Shared frozen VGG16 feature extractor used by combined_cvae_loss; nn.Module.to moves it in place.
perceptual_criterion = VGGPerceptualLoss()
perceptual_criterion.to(device)


# Enhanced CVAE loss which combines MSE, Perceptual, and Histogram-based loss
def histogram_loss(img1, img2, bins=64):
    """MSE between the normalised intensity histograms of two images (or image batches).

    All elements of each input are pooled into one histogram over [0, 1] with ``bins``
    equal-width bins. Each histogram is divided by its total mass, and the two are
    compared with ``F.mse_loss``. Values outside [0, 1] are ignored, matching
    ``torch.histc(min=0, max=1)``.

    The histogram is "soft": each value is split linearly between its two nearest bin
    centres instead of being counted in a single bin. ``torch.histc`` has no gradient,
    so a histc-based version of this term could not influence training (depending on
    the PyTorch version its output is either detached from the graph or makes
    ``backward`` raise). Linear binning keeps the loss differentiable w.r.t. the pixels.

    Args:
        img1, img2: Tensors of any shape, expected to hold values in [0, 1].
        bins: Number of histogram bins.

    Returns:
        Scalar tensor: mean squared difference between the two normalised histograms.
    """
    def _soft_histogram(img):
        # Linearly binned histogram of every element of ``img``, normalised to sum to 1.
        values = img.reshape(-1)
        # Continuous bin coordinate: bin k is centred at (k + 0.5) / bins. Clamping sends
        # values in the outer half-bins entirely to the first / last bin.
        position = (values * bins - 0.5).clamp(0, bins - 1)
        lower = position.floor()
        upper_weight = position - lower  # gradient flows here; floor() has zero gradient
        lower_idx = lower.long()
        upper_idx = (lower_idx + 1).clamp(max=bins - 1)
        inside = ((values >= 0) & (values <= 1)).to(values.dtype)
        hist = values.new_zeros(bins)
        hist = hist.scatter_add(0, lower_idx, (1 - upper_weight) * inside)
        hist = hist.scatter_add(0, upper_idx, upper_weight * inside)
        # Avoid 0/0 (NaN) when no value lies in [0, 1].
        return hist / hist.sum().clamp_min(torch.finfo(hist.dtype).eps)

    return F.mse_loss(_soft_histogram(img1), _soft_histogram(img2))

def combined_cvae_loss(recon_x, x, mu, logvar, beta, alpha, theta, lamda):
    """Weighted CVAE objective: MSE + perceptual + histogram reconstruction terms plus KL.

    Args:
        recon_x, x: Reconstructed and target image batches.
        mu, logvar: Encoder outputs for the batch.
        beta: Weight of the KL divergence.
        alpha, theta, lamda: Weights of the MSE, perceptual and histogram terms.

    Returns:
        ``(total_loss, loss_dict)``. In ``loss_dict``, "recon_loss" is the *weighted*
        reconstruction sum, while "mse_loss", "perceptual_loss" and "hist_loss" are the
        unweighted components. "kld_loss" is the unweighted KL term.
    """
    n_samples = x.size(0)

    # Individual reconstruction terms (MSE summed over pixels, averaged over the batch).
    pixel_term = F.mse_loss(recon_x, x, reduction='sum') / n_samples
    feature_term = perceptual_criterion(recon_x, x)
    histogram_term = histogram_loss(recon_x, x)
    weighted_reconstruction = alpha * pixel_term + theta * feature_term + lamda * histogram_term

    # KL divergence to N(0, I), summed over all elements and divided by the batch size.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_divergence = kl_divergence / n_samples

    components = {
        "recon_loss": weighted_reconstruction,
        "kld_loss": kl_divergence,
        "mse_loss": pixel_term,
        "perceptual_loss": feature_term,
        "hist_loss": histogram_term,
    }
    return weighted_reconstruction + beta * kl_divergence, components


# Adam over every CVAE parameter, including the pretrained ResNet18 backbone.
# NOTE(review): learning rate is fixed at 1e-3 here; confirm how this relates to the
# "learning_rate" hyperparameter read inside the training function.
optimizer = torch.optim.Adam(params=cvae.parameters(), lr=1e-3)

# Cell output: the default loss function and the optimizer, as a quick sanity check.
mse_loss, optimizer

## Training the model

In [None]:
# Timestamp (YYYYmmdd_HHMM) embedded in the names of the results folders.
current_time = f"{datetime.now():%Y%m%d_%H%M}"


# Main function for training the model and visualizing losses.
def train_and_visualize_losses(
    model,
    train_loader,                                                # DataLoader for training data
    val_loader,                                                  # DataLoader for validation data
    optimizer,                                                   # Optimizer for model parameters
    scheduler,                                                   # Learning rate scheduler
    hyperparameters,
):

    start_times = []
    
    # Initialize all the required parameters from hyperparameter dictionary
    num_epochs=hyperparameters["num_epochs"]                    # Total number of epochs to train the model
    patience=hyperparameters["patience"]                        # Number of epochs with no improvement after which training will be stopped
    plateau_threshold=hyperparameters["plateau_threshold"]      # Number of epochs to wait before considering it a plateau (for learning rate adjustment)
    save_interval=hyperparameters["save_interval"]              # Interval at which the model checkpoints are saved
    learning_rate=hyperparameters["learning_rate"]              # Learning rate for the optimizer
    weight_decay=hyperparameters["weight_decay"]                # Weight decay regularization to prevent model from overfitting
    latent_dim=hyperparameters["latent_dim"]                    # Dimensionality of the latent space
    condition_dim=hyperparameters["condition_dim"]              # Dimensionality of the condition vector (e.g., 2 for two scores in our case)
    loss_type=hyperparameters["loss_type"]                      # Type of the loss function ('mse', 'ms-ssim', or 'combined')
    add_loss_on_epoch=hyperparameters["add_loss_on_epoch"]      # Epoch number after which additional loss terms start getting added
    alpha=hyperparameters["alpha"]                              # Weight for the MSE term in the loss when using 'combined' loss_type
    beta=hyperparameters["beta"]                                # Weight for the KL-divergence term in the loss
    theta=hyperparameters["theta"]                              # Weight for the perceptual loss term
    lamda=hyperparameters["lamda"]                              # Weight for the histogram loss term
    interval_perceptual=hyperparameters["interval_perceptual"]  # Delay prior to introducing additional loss.
    interval_kld=hyperparameters["interval_kld"]                # Delay prior to introducing additional loss.
    interval_hist=hyperparameters["interval_hist"]              # Delay prior to introducing additional loss.
    images_to_show=hyperparameters["images_to_show"]            # Number of images to display during training for visual inspection
    nrow=hyperparameters["nrow"]                                # Number of rows when showing the images using `make_grid`
    image_show=hyperparameters["image_show"]                    # Whether to show images during training
    debug=hyperparameters["debug"]                              # If true, print out additional debug information
    

    
    # Function to compute MD5 hash of a tensor. Used for debugging.
    def compute_hash(tensor):
        return hashlib.md5(tensor.cpu().numpy().tobytes()).hexdigest()

    # Function to denormalize a tensor (convert from [-1, 1] to [0, 255]).
    def denormalize(tensor, mean=in_mean, std=in_std):
        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)
        return tensor

    # Function to select the appropriate loss function based on the type.
    def select_loss_function(loss_type):
        if loss_type == 'mse':
            return mse_loss
        elif loss_type == 'ms-ssim':
            return ms_ssim_loss
        elif loss_type == 'combined':
            return combined_cvae_loss
        else:
            raise ValueError("Invalid loss_type. Expected 'mse', 'ms-ssim', or 'combined'.")


    def analyze_latent_space(corr_matrix, latent_vectors):
        # Thresholds for analysis
        high_corr_threshold = 0.9
        low_variance_threshold = 0.01

        # Identify pairs of highly correlated dimensions
        high_corr_pairs = []
        for i in range(corr_matrix.shape[0]):
            for j in range(i+1, corr_matrix.shape[1]):
                if abs(corr_matrix[i][j]) > high_corr_threshold:
                    high_corr_pairs.append((i, j))

        # Flag latent dimensions with low variance
        latent_variances = np.var(latent_vectors, axis=0)
        low_variance_dims = np.where(latent_variances < low_variance_threshold)[0]

        # Analysis and Interpretation
        num_high_corr_pairs = len(high_corr_pairs)
        num_low_variance_dims = len(low_variance_dims)

        print(f"Number of highly correlated pairs of latent dimensions: {num_high_corr_pairs}")
        print(f"Number of latent dimensions with low variance: {num_low_variance_dims}\n")

        if num_high_corr_pairs > (0.5 * corr_matrix.shape[0]):
            print("A significant portion of the latent dimensions are highly correlated with each other. This might suggest redundancy in the latent space.")
        else:
            print("The latent dimensions seem to be reasonably independent, suggesting efficient utilization of the latent space.")

        if num_low_variance_dims > (0.5 * latent_vectors.shape[1]):
            print("A significant number of latent dimensions exhibit low variance, indicating potential underutilization of the latent space.\n\n\n")
        else:
            print("Most latent dimensions have reasonable variance, suggesting they might be carrying significant information.\n\n\n")

        return high_corr_pairs, low_variance_dims

    # Create a folder to save interim images during the training process.
    def save_hyperparameters_to_txt(folder, hyperparameters):
        with open(os.path.join(folder, "hyperparameters.txt"), "w") as f:
            for key, value in hyperparameters.items():
                f.write(f"{key}: {value}\n")
  
    # Create a folder name based on hyperparameters.
    folder_suffix = f"beta_{hyperparameters['beta']}_ldim_{hyperparameters['latent_dim']}_{hyperparameters['loss_type']}_loss"
    images_folder = f"results/training_images_{current_time}_{folder_suffix}"
    os.makedirs(images_folder, exist_ok=True)

    # Save hyperparameters to a text file.
    save_hyperparameters_to_txt(images_folder, hyperparameters)
    
    # Initialize loss variables.
    epoch_mse_loss = 0.0
    epoch_perceptual_loss = 0.0
    epoch_hist_loss = 0.0
    
    mse_loss_accum = 0.0
    perceptual_loss_accum = 0.0
    hist_loss_accum = 0.0
    
    # Select the appropriate loss function.
    print(f"Received loss_type: '{loss_type}'")
    criterion = select_loss_function(loss_type)
    
    # Initialize losses to zero if not using the 'combined' loss.
    if loss_type != 'combined':
        kld_loss, perceptual_loss, hist_loss = 0, 0, 0
        
    # Initialize list to store latent vectors
    latent_vectors = []
    
    # Store initial model weights and initialize best loss to infinity.
    best_model_wts = model.state_dict()
    best_loss = float('inf')
    no_improve = 0
    
    # Dictionaries to store training and validation losses.
    train_losses = {
        'total': [], 
        'reconstruction_loss': [], 
        'kld_loss': [], 
        'mse_loss': [], 
        'perceptual_loss': [], 
        'hist_loss': []
    }
    val_losses = {
        'total': [], 
        'reconstruction_loss': [], 
        'kld_loss': [], 
        'mse_loss': [], 
        'perceptual_loss': [], 
        'hist_loss': []
    }
    
    # Get a fixed batch from the training set for visualization during the training phase.
    fixed_train_iter = iter(train_loader)
    fixed_train_batch, fixed_train_labels = next(fixed_train_iter)
    fixed_train_batch = fixed_train_batch.to(device)
    moira_fixed_train_scores = fixed_train_labels[0].unsqueeze(1).float().to(device)
    ferdinando_fixed_train_scores = fixed_train_labels[1].unsqueeze(1).float().to(device)
    fixed_train_labels = torch.cat([moira_fixed_train_scores, ferdinando_fixed_train_scores], dim=1)

    fixed_data_iter = iter(val_loader)
    fixed_val_batch, fixed_val_labels = next(fixed_data_iter)
    fixed_val_batch = fixed_val_batch.to(device)
    moira_fixed_val_scores = fixed_val_labels[0].unsqueeze(1).float().to(device)
    ferdinando_fixed_val_scores = fixed_val_labels[1].unsqueeze(1).float().to(device)
    fixed_val_labels = torch.cat([moira_fixed_val_scores, ferdinando_fixed_val_scores], dim=1)
    
    # Define beta values for KL divergence weight adjustment.
    beta_values = [0.01, 0.1, 0.5, 1.0]
    beta_update_interval = 10 * round((num_epochs // (len(beta_values) + 1)) / 10)
    beta_index = 0

    # Original values
    original_alpha = alpha
    original_beta = beta
    original_theta = theta
    original_lamda = lamda
    
    # Variables for monitoring validation loss and introducing additional losses.
    prev_val_loss = float('inf')
    plateau_count = 0
    plateau_threshold = 10  # Number of epochs to wait before considering it a plateau.
    
    # Flags for the introduction of additional losses.
    kld_introduced = False
    perceptual_introduced = False
    hist_introduced = False
    
    # Training loop.
    for epoch in range(num_epochs):
               
        # Record the start time of the epoch.
        start_time = time.time()
        
        # Initialize the list to store latent vectors for this epoch.
        latent_vectors = []  
            
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print('-' * 10)
        
        # Re-assign the original values at the start of each epoch       
        if kld_introduced:
            beta = original_beta
        if perceptual_introduced:
            theta = original_theta
        if hist_introduced:
            lamda = original_lamda
        
        # Set beta, theta, and lamda to 0 if current epoch is less than additional_loss_on_epoch.
        if epoch < add_loss_on_epoch:
            alpha, beta, theta, lamda = 1, 0, 0, 0 #hyperparameters['beta']
            can_monitor_plateau = False
        else:
            can_monitor_plateau = True 
        
        # Initialize running losses.
        running_loss = 0.0
        running_reconstruction_loss = 0.0
        running_kld = 0.0
        running_mse_loss = 0.0
        running_perceptual_loss = 0.0
        running_hist_loss = 0.0
        
        # Iterate over training and validation phases.
        for phase in ['train', 'val']:
            
            if phase == 'train':
                fixed_batch = fixed_train_batch
                fixed_labels = fixed_train_labels
            else: # 'val' phase
                fixed_batch = fixed_val_batch
                fixed_labels = fixed_val_labels
                
            # Compute and store the hash for the fixed batch for consistency check.
            fixed_batch_hash = compute_hash(fixed_batch)
            
            # Set the model to training mode during the training phase and evaluation mode during the validation phase.
            model.train() if phase == 'train' else model.eval()
            
            # Select the appropriate data loader.
            dataloader = train_loader if phase == 'train' else val_loader

            # Iterate over batches.
            for inputs, (moira_scores, ferdinando_scores) in dataloader:
                inputs = inputs.to(device)
                moira_scores = moira_scores.unsqueeze(1).float().to(device)
                ferdinando_scores = ferdinando_scores.unsqueeze(1).float().to(device)
                combined_scores = torch.cat([moira_scores, ferdinando_scores], dim=1)

                optimizer.zero_grad()

                # Forward pass.
                with torch.set_grad_enabled(phase == 'train'):
                    recon_batch, mu, logvar = model(inputs, combined_scores)
                    
                    # Store latent vectors
                    latent_vectors.append(mu.detach().cpu().numpy())
                    
                    # Compute the loss.
                    loss, loss_dict = combined_cvae_loss(recon_batch, inputs, mu, logvar, beta, alpha, theta, lamda)
                    
                    # Perform backward pass and optimization during the training phase.
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Update running losses.
                running_loss += loss.item() * inputs.size(0)
                running_reconstruction_loss += loss_dict['recon_loss'].item() * inputs.size(0)
                running_kld += loss_dict['kld_loss'].item() * inputs.size(0)

                if hyperparameters['loss_type'] == 'combined':
                    running_mse_loss += loss_dict['mse_loss'].item() * inputs.size(0)
                    running_perceptual_loss += loss_dict['perceptual_loss'].item() * inputs.size(0)
                    running_hist_loss += loss_dict['hist_loss'].item() * inputs.size(0)

            # Compute average losses.
            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_reconstruction_loss = running_reconstruction_loss / len(dataloader.dataset)
            epoch_kld = running_kld / len(dataloader.dataset)
            
            if hyperparameters['loss_type'] == 'combined':
                epoch_mse_loss = running_mse_loss / len(dataloader.dataset)
                epoch_perceptual_loss = running_perceptual_loss / len(dataloader.dataset)
                epoch_hist_loss = running_hist_loss / len(dataloader.dataset)

            # Store epoch losses.
            if phase == 'train':
                train_losses['mse_loss'].append(epoch_mse_loss)
                train_losses['perceptual_loss'].append(epoch_perceptual_loss)
                train_losses['hist_loss'].append(epoch_hist_loss)
            else:
                val_losses['mse_loss'].append(epoch_mse_loss)
                val_losses['perceptual_loss'].append(epoch_perceptual_loss)
                val_losses['hist_loss'].append(epoch_hist_loss)

            # Update the scheduler based on the validation loss.
            if phase == 'val' and scheduler:
                scheduler.step(epoch_loss)

            print(f"{phase} losses:       "
                  f"{hyperparameters['loss_type']}: {epoch_reconstruction_loss:.2f}, "
                  f"mse_loss: {alpha * epoch_mse_loss:.2f}, "
                  f"kld_loss: {beta * epoch_kld:.2f}, "
                  f"perceptual_loss: {theta * epoch_perceptual_loss:.2f}, "
                  f"histogram_loss: {lamda * epoch_hist_loss:.2f}")


            # Check for validation loss plateau and introduce additional losses.
            if phase == 'val' and can_monitor_plateau:

                if (abs(prev_val_loss - epoch_loss) / prev_val_loss) < 0.01:  # small threshold for considering it a plateau.
                    plateau_count += 1
                else:
                    plateau_count = 0

                # Introduce kld_loss either after a fixed interval or if a plateau is detected.
                if epoch == add_loss_on_epoch + interval_kld or (plateau_count >= plateau_threshold and not kld_introduced):
                    beta = original_beta
                    kld_introduced = True
                    print("\n\nIntroducing the KLD loss\n\n")
                    plateau_count = 0  # Reset plateau count after introducing a loss

                # Introduce perceptual_loss either after a fixed interval or if a plateau is detected.
                elif epoch == add_loss_on_epoch + interval_perceptual or (plateau_count >= plateau_threshold and not perceptual_introduced):
                    theta = original_theta
                    perceptual_introduced = True
                    print("\n\nIntroducing the perceptual loss\n\n")
                    plateau_count = 0  # Reset plateau count after introducing a loss

                # Introduce hist_loss either after a fixed interval or if a plateau is detected.
                elif epoch == add_loss_on_epoch + interval_hist or (plateau_count >= plateau_threshold and not hist_introduced):
                    lamda = original_lamda
                    hist_introduced = True
                    print("\n\nIntroducing the histogram loss\n\n")
                    plateau_count = 0  # Reset plateau count after introducing a loss

                prev_val_loss = epoch_loss
                

            
            # Visualize generated images during the training.
            if phase == 'train':
                with torch.no_grad():
                    model.eval()
                    
                    # Save original and reconstructed images from the training set.
                    save_image(fixed_train_batch[:10], f"{images_folder}/original_train.png", nrow=10)
                    recon_train, _, _ = model(fixed_train_batch, fixed_train_labels)
                    denormalized_recon_train = denormalize(recon_train.clone())
                    save_image(recon_train[:10], f"{images_folder}/recon_train_epoch_{epoch:03}.png", nrow=10)
                    
                    # Display images.
                    if image_show:
                        
                        grid_img_train_tensor = make_grid(recon_train[:images_to_show], nrow=nrow)
                        grid_img_train_resized = F.interpolate(grid_img_train_tensor.unsqueeze(0), size=(96, 128 * images_to_show), mode='bilinear', align_corners=True).squeeze(0)
                        plt.imshow(grid_img_train_resized.permute(1, 2, 0).cpu().numpy())
                        plt.title(f"Training Reconstruction - Epoch {epoch}")
                        plt.axis('off')
                        plt.show()

            # For the validation phase
            elif phase == 'val':
                with torch.no_grad():
                    model.eval()
                    
                    # Save original and reconstructed images from the validation set.
                    save_image(fixed_val_batch[:10], f"{images_folder}/original_val.png", nrow=10)
                    recon_val, _, _ = model(fixed_val_batch, fixed_val_labels)
                    denormalized_recon_val = denormalize(recon_val.clone())
                    save_image(recon_val[:10], f"{images_folder}/recon_val_epoch_{epoch:03}.png", nrow=10)
                    
                    # Display images.
                    if image_show:
                        
                        grid_img_val_tensor = make_grid(recon_val[:images_to_show], nrow=nrow)
                        grid_img_val_resized = F.interpolate(grid_img_val_tensor.unsqueeze(0), size=(96, 128 * images_to_show), mode='bilinear', align_corners=True).squeeze(0)
                        plt.imshow(grid_img_val_resized.permute(1, 2, 0).cpu().numpy())
                        plt.title(f"Validation Reconstruction - Epoch {epoch}")
                        plt.axis('off')
                        plt.show()

                model.train()

            # Update best model if current validation loss is lower.
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = model.state_dict()
                no_improve = 0
            elif phase == 'val':
                no_improve += 1
                
            
        # Check if it's time to analyze the latent space (every 10 epochs).
        if (epoch + 1) % 10 == 0:
            # Re-calculate the correlation matrix and perform analysis
            flattened_latents = np.vstack(latent_vectors)
            corr_matrix = np.corrcoef(flattened_latents, rowvar=False)

            # Plot the correlation matrix
            plt.figure(figsize=(10, 10))
            plt.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
            plt.colorbar()
            plt.title(f"Latent Variable Correlation Matrix - Epoch {epoch}")
            plt.show()

            # Perform latent space analysis
            analyze_latent_space(corr_matrix, flattened_latents)

            # Reset latent_vectors for the next set of epochs
            latent_vectors = []

                
        # Calculate the elapsed time and estimate the remaining time at the end of the epoch.
        end_time = time.time()
        elapsed_time = end_time - start_time
        start_times.append(elapsed_time)
        avg_time_per_epoch = sum(start_times) / len(start_times)
        estimated_time_left = avg_time_per_epoch * (num_epochs - epoch - 1)
        print(f"Estimated time left: {estimated_time_left // 3600:.0f}h {(estimated_time_left % 3600) // 60:.0f}m {estimated_time_left % 60:.0f}s")
        
        print()

        # Implement early stopping.
        if no_improve >= patience:
            print("Early stopping due to no improvement.")
            break

    print("Best val loss: {:4f}\n".format(best_loss))
    
    # Save best model weights and losses.
    torch.save(best_model_wts, os.path.join(images_folder, "best_conv_model_weights.pth"))
    with open(os.path.join(images_folder, "conv_losses.pkl"), "wb") as f:
        pickle.dump({'train': train_losses, 'val': val_losses}, f)
        
    # plot_losses_and_save(images_folder, hyperparameters['beta'], hyperparameters['latent_dim'], hyperparameters['loss_type'])
    # plot_losses_and_save(images_folder, hyperparameters['beta'], hyperparameters['latent_dim'], hyperparameters['loss_type'], scaled=True)

    # Load best model weights.
    model.load_state_dict(best_model_wts)

    return model, {
        'mse_loss': epoch_mse_loss, 
        'perceptual_loss': epoch_perceptual_loss, 
        'hist_loss': epoch_hist_loss
    }

In [None]:
# Train the model with combined MSE, KLD, perceptual and histogram loss.

# Define hyperparameters.
hyperparameters = {
    "num_epochs": 501,           # Total number of epochs for training
    "patience": 150,             # Number of epochs with no improvement after which training will be stopped (early stopping)
    "save_interval": 1,          # Interval at which the model checkpoints are saved
    "learning_rate": 0.0001,     # Learning rate for the optimizer
    "weight_decay": 1e-2,        # L2 regularization coefficient, helps in preventing overfitting
    "latent_dim": 512,           # Dimensionality of the latent space in the CVAE
    "condition_dim": 2,          # Dimensionality of the condition vector (e.g., 2 for two scores in our case)
    "loss_type": 'combined',     # Type of the loss function to use ('mse', 'ms-ssim', or 'combined')
    "plateau_threshold": 10,     # Number of epochs to wait before considering it a plateau (for learning rate adjustment or adding new losses)
    "alpha": 0.01,               # Weight for the MSE term in the loss when using 'combined' loss_type
    "beta": 0.1,                 # Weight for the KL-divergence term in the loss
    "theta": 1000,               # Weight for the perceptual loss term
    "lamda": 100,                # Weight for the histogram loss term
    "add_loss_on_epoch": 30,     # Epoch number after which additional loss terms start getting added
    "interval_perceptual": 0,    # Interval after 'add_loss_on_epoch' to introduce perceptual loss
    "interval_kld": 20,          # Interval after 'add_loss_on_epoch' to introduce KLD loss
    "interval_hist": 100,        # Interval after 'add_loss_on_epoch' to introduce histogram loss
    "images_to_show": 5,         # Number of images to display during training for visual inspection
    "nrow": 5,                   # Number of rows when showing the images using `make_grid`
    "image_show": True,          # Whether to show images during training
    "debug": False               # If true, print out additional debug information
}

# Create a model instance and initialize optimizer and scheduler.
# The condition size comes from the config dict so editing "condition_dim" actually takes effect.
cvae = CVAE_ResNet18_DualEmbedding(conditional_dim=hyperparameters['condition_dim'], latent_dim=hyperparameters['latent_dim'], debug=False).to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=hyperparameters['learning_rate'], weight_decay=hyperparameters['weight_decay'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# Train the model instance and visualize losses.
# NOTE(review): the training loop shown earlier in this notebook ends with `return model, {...losses...}`;
# if that is train_and_visualize_losses, `trained_model` is a (model, loss_dict) tuple — confirm.
trained_model = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler, hyperparameters) # Hyperparameters can be changed above before the function call

In [None]:
# Train the model with combined MSE, KLD, perceptual and histogram loss.

# Define hyperparameters.
hyperparameters = {
    "num_epochs": 301,           # Total number of epochs for training
    "patience": 150,             # Number of epochs with no improvement after which training will be stopped (early stopping)
    "save_interval": 1,          # Interval at which the model checkpoints are saved
    "learning_rate": 0.0001,     # Learning rate for the optimizer
    "weight_decay": 1e-2,        # L2 regularization coefficient, helps in preventing overfitting
    "latent_dim": 512,           # Dimensionality of the latent space in the CVAE
    "condition_dim": 2,          # Dimensionality of the condition vector (e.g., 2 for two scores in our case)
    "loss_type": 'combined',     # Type of the loss function to use ('mse', 'ms-ssim', or 'combined')
    "plateau_threshold": 10,     # Number of epochs to wait before considering it a plateau (for learning rate adjustment or adding new losses)
    "alpha": 0.01,               # Weight for the MSE term in the loss when using 'combined' loss_type
    "beta": 0.1,                 # Weight for the KL-divergence term in the loss
    "theta": 1000,               # Weight for the perceptual loss term
    "lamda": 100,                # Weight for the histogram loss term
    "add_loss_on_epoch": 50,     # Epoch number after which additional loss terms start getting added
    "interval_perceptual": 0,    # Interval after 'add_loss_on_epoch' to introduce perceptual loss
    "interval_kld": 20,          # Interval after 'add_loss_on_epoch' to introduce KLD loss
    "interval_hist": 100,        # Interval after 'add_loss_on_epoch' to introduce histogram loss
    "images_to_show": 5,         # Number of images to display during training for visual inspection
    "nrow": 5,                   # Number of rows when showing the images using `make_grid`
    "image_show": True,          # Whether to show images during training
    "debug": False               # If true, print out additional debug information
}

# Create a model instance and initialize optimizer and scheduler.
# The condition size comes from the config dict so editing "condition_dim" actually takes effect.
cvae = CVAE_ResNet18_DualEmbedding(conditional_dim=hyperparameters['condition_dim'], latent_dim=hyperparameters['latent_dim'], debug=False).to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=hyperparameters['learning_rate'], weight_decay=hyperparameters['weight_decay'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# Train the model instance and visualize losses.
# NOTE(review): the training loop shown earlier in this notebook ends with `return model, {...losses...}`;
# if that is train_and_visualize_losses, `trained_model` is a (model, loss_dict) tuple — confirm.
trained_model = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler, hyperparameters) # Hyperparameters can be changed above before the function call

In [None]:
# Train the model with MS-SSIM reconstruction loss and KLD only
# (theta = lamda = 0, so the perceptual and histogram terms carry zero weight).

# Define hyperparameters.
hyperparameters = {
    "num_epochs": 301,           # Total number of epochs for training
    "patience": 150,             # Number of epochs with no improvement after which training will be stopped (early stopping)
    "save_interval": 1,          # Interval at which the model checkpoints are saved
    "learning_rate": 0.0001,     # Learning rate for the optimizer
    "weight_decay": 1e-2,        # L2 regularization coefficient, helps in preventing overfitting
    "latent_dim": 512,           # Dimensionality of the latent space in the CVAE
    "condition_dim": 2,          # Dimensionality of the condition vector (e.g., 2 for two scores in our case)
    "loss_type": 'ms-ssim',      # Type of the loss function to use ('mse', 'ms-ssim', or 'combined')
    "plateau_threshold": 10,     # Number of epochs to wait before considering it a plateau (for learning rate adjustment or adding new losses)
    "alpha": 0.01,               # Weight for the MSE term in the loss when using 'combined' loss_type
    "beta": 1,                   # Weight for the KL-divergence term in the loss
    "theta": 0,                  # Weight for the perceptual loss term
    "lamda": 0,                  # Weight for the histogram loss term
    "add_loss_on_epoch": 50,     # Epoch number after which additional loss terms start getting added
    "interval_perceptual": 300,  # Interval after 'add_loss_on_epoch' to introduce perceptual loss
    "interval_kld": 0,           # Interval after 'add_loss_on_epoch' to introduce KLD loss
    "interval_hist": 300,        # Interval after 'add_loss_on_epoch' to introduce histogram loss
    "images_to_show": 5,         # Number of images to display during training for visual inspection
    "nrow": 5,                   # Number of rows when showing the images using `make_grid`
    "image_show": True,          # Whether to show images during training
    "debug": False               # If true, print out additional debug information
}

# Create a model instance and initialize optimizer and scheduler.
# The condition size comes from the config dict so editing "condition_dim" actually takes effect.
cvae = CVAE_ResNet18_DualEmbedding(conditional_dim=hyperparameters['condition_dim'], latent_dim=hyperparameters['latent_dim'], debug=False).to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=hyperparameters['learning_rate'], weight_decay=hyperparameters['weight_decay'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# Train the model instance and visualize losses.
# NOTE(review): the training loop shown earlier in this notebook ends with `return model, {...losses...}`;
# if that is train_and_visualize_losses, `trained_model` is a (model, loss_dict) tuple — confirm.
trained_model = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler, hyperparameters) # Hyperparameters can be changed above before the function call

## Run for further processing

In [None]:
# Find run folders that hold every saved artifact (weights .pth, losses .pkl, hyperparameters .txt).
def check_folders_for_files(base_folder):
    """Return the sorted names of complete run folders inside `base_folder`.

    A run folder is any entry whose name starts with "training_images_2023". It counts
    as complete only if it contains the model weights, the pickled losses and the
    hyperparameter dump.
    """
    needed = ("best_conv_model_weights.pth", "conv_losses.pkl", "hyperparameters.txt")

    def _is_complete(name):
        run_dir = os.path.join(base_folder, name)
        return all(os.path.exists(os.path.join(run_dir, fname)) for fname in needed)

    return sorted(
        name for name in os.listdir(base_folder)
        if name.startswith("training_images_2023") and _is_complete(name)
    )

# Mark incomplete run folders (missing weights, losses or hyperparameters) by prefixing them with "empty_".
def rename_folders(root_folder, valid_folders):
    """Interactively rename incomplete run folders in `root_folder` to ``empty_<name>``.

    Every entry starting with "training_images_2023" that is not in `valid_folders`
    is listed. The renames happen only if the user answers 'y' (case-insensitive)
    at the prompt; otherwise nothing is changed.
    """
    candidates = sorted(
        name for name in os.listdir(root_folder)
        if name.startswith("training_images_2023") and name not in valid_folders
    )

    if len(candidates) == 0:
        print("No folders need renaming.")
        return

    print("Folders to be renamed:")
    print(*candidates, sep="\n")

    answer = input("\nDo you want to rename these folders? (y/n): ")
    if answer.lower() != 'y':
        print("No folders were renamed.")
        return

    for name in candidates:
        src = os.path.join(root_folder, name)
        dst = os.path.join(root_folder, "empty_" + name)
        os.rename(src, dst)
        print(f"Renamed: {src} -> {dst}")

def get_folders_in_results_directory(base_path="results/"):
    """List the run directories under `base_path` whose names start with "training_images_2023".

    Folder names embed their creation date, so a plain lexicographic sort is chronological.
    """
    def _is_run_dir(name):
        return name.startswith("training_images_2023") and os.path.isdir(os.path.join(base_path, name))

    return sorted(filter(_is_run_dir, os.listdir(base_path)))

def extract_hyperparameters(folder):
    """Parse ``hyperparameters.txt`` in `folder` into a dict of strings.

    Each line containing a colon becomes one ``key -> value`` entry. The line is split
    at the FIRST colon only, so values that contain colons themselves (e.g. ``a: b``)
    are kept intact. Surrounding whitespace is stripped from both key and value.
    Lines without a colon are ignored. Values stay strings; callers convert them
    (e.g. ``int(hp['latent_dim'])``).

    Raises:
        FileNotFoundError: if `folder` contains no ``hyperparameters.txt``.
    """
    hyperparameters = {}
    with open(os.path.join(folder, "hyperparameters.txt"), "r") as f:
        for line in f:
            # partition() never raises, unlike indexing the result of split(": "), which
            # failed on "key:value" / "key:" lines and truncated values containing ": ".
            key, sep, value = line.partition(":")
            if not sep:
                continue
            hyperparameters[key.strip()] = value.strip()
    return hyperparameters

def load_model_weights(model, folder_path):
    """Load ``best_conv_model_weights.pth`` from `folder_path` into `model` and return it.

    The checkpoint tensors are mapped onto the notebook-level `device` while loading.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    weights_path = os.path.join(folder_path, "best_conv_model_weights.pth")
    if os.path.exists(weights_path):
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
        return model
    raise FileNotFoundError(f"No weights file found at {weights_path}")

## Load weights and generate embeddings

In [None]:
# Find run folders that hold every saved artifact (weights .pth, losses .pkl, hyperparameters .txt).
def check_folders_for_files(base_folder):
    """Return the sorted names of run folders in `base_folder` that contain all saved artifacts.

    Only entries starting with "training_images_2023" are considered.
    """
    expected = ("best_conv_model_weights.pth", "conv_losses.pkl", "hyperparameters.txt")
    complete = []
    for entry in os.listdir(base_folder):
        if not entry.startswith("training_images_2023"):
            continue
        entry_dir = os.path.join(base_folder, entry)
        if all(os.path.exists(os.path.join(entry_dir, name)) for name in expected):
            complete.append(entry)
    complete.sort()
    return complete

# Mark incomplete run folders (missing weights, losses or hyperparameters) by prefixing them with "empty_".
def rename_folders(root_folder, valid_folders):
    """Interactively rename incomplete run folders in `root_folder` to ``empty_<name>``.

    Every entry starting with "training_images_2023" that is not in `valid_folders`
    is listed. The renames happen only if the user answers 'y' (case-insensitive).
    """
    candidates = sorted(
        name for name in os.listdir(root_folder)
        if name.startswith("training_images_2023") and name not in valid_folders
    )

    if len(candidates) == 0:
        print("No folders need renaming.")
        return

    print("Folders to be renamed:")
    print(*candidates, sep="\n")

    answer = input("\nDo you want to rename these folders? (y/n): ")
    if answer.lower() != 'y':
        print("No folders were renamed.")
        return

    for name in candidates:
        src = os.path.join(root_folder, name)
        dst = os.path.join(root_folder, "empty_" + name)
        os.rename(src, dst)
        print(f"Renamed: {src} -> {dst}")
  
            
# List the run folders under "results" that contain every saved artifact.
results_root = "results"
result_folders = check_folders_for_files(results_root)

print("Folders containing the required files:")
for folder in result_folders:
    print(folder)
print()

# Offer to prefix the remaining (incomplete) run folders with "empty_".
rename_folders(results_root, result_folders)

In [None]:
def generate_and_save_images_and_embeddings(model, dataloader, num_samples=10, device="cuda", folder_path=None):
    """Encode a few images from `dataloader` and export them, their embeddings and the weights to ``ready/``.

    Takes the first batch of `dataloader` and keeps its first `num_samples` images and
    condition scores. The images are encoded with ``model.encode`` followed by
    ``model.reparameterize``. The function then writes:

    * ``ready/images.pth``     - the selected images (CPU tensor),
    * ``ready/embeddings.pth`` - latent samples ``z`` concatenated with the conditions ``c``,
    * ``ready/samples.png``    - a normalized preview grid of the images (one row),
    * ``ready/best_conv_model_weights.pth`` - a copy of the chosen run's weights.

    Args:
        model: CVAE exposing ``encode(images, c)`` and ``reparameterize(mu, logvar)``.
        dataloader: iterable whose first item is ``(images, c_list)``. ``c_list`` is a
            sequence of per-score tensors, stacked along dim 1 into the condition.
        num_samples: number of images to export.
        device: device on which to run the encoder.
        folder_path: run folder holding ``best_conv_model_weights.pth``. If omitted,
            the notebook-level ``folder_path`` variable is used (the previous, implicit
            behaviour), so existing calls keep working.

    Raises:
        ValueError: if no `folder_path` is given and none is defined at notebook level.
    """
    import shutil

    if folder_path is None:
        # Backward compatibility: this function used to read the global set by the selection cell.
        folder_path = globals().get("folder_path")
        if folder_path is None:
            raise ValueError("folder_path must be given (or defined at notebook level).")

    # 1. Get some images and their associated condition vectors from the dataloader.
    images, c_list = next(iter(dataloader))
    print(f"Type of images: {type(images)}, Shape of images: {images.shape}")
    print(f"Type of c_list: {type(c_list)}")

    # Stack the per-score tensors into one (batch, n_scores) condition tensor.
    c = torch.stack(c_list, dim=1).to(device)
    print(f"Combined condition tensor shape: {c.shape}")

    # Slice before moving so only the exported samples are transferred to the device.
    images = images[:num_samples].to(device)
    c = c[:num_samples]

    # 2. Encode in eval mode. If the model has BatchNorm layers, a train-mode forward pass,
    #    even under no_grad, would use batch statistics and overwrite the running statistics.
    # NOTE(review): if `reparameterize` branches on `self.training`, the embeddings are now
    #    the mean `mu` rather than a random sample — confirm this is desired.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            mu, logvar = model.encode(images, c)
            z = model.reparameterize(mu, logvar)
            embeddings = torch.cat([z, c], dim=1)  # concatenate conditions to embeddings before saving
    finally:
        model.train(was_training)

    # 3. Save the images and embeddings; create the export folder if it does not exist yet.
    os.makedirs('ready', exist_ok=True)
    torch.save(images.cpu(), 'ready/images.pth')
    torch.save(embeddings.cpu(), 'ready/embeddings.pth')

    # 4. Save the images to a visual format for verification (all samples on one row).
    save_image(images.cpu(), 'ready/samples.png', nrow=num_samples, padding=2, normalize=True)

    # 5. Copy the chosen run's weights next to the exported tensors.
    src_path = os.path.join(folder_path, "best_conv_model_weights.pth")
    dest_path = 'ready/best_conv_model_weights.pth'
    shutil.copy(src_path, dest_path)
    print(f"Images and embeddings have been saved. Model weights copied to: {dest_path}.")


# Build the model with its default constructor arguments and move it to the available device.
model = CVAE_ResNet18_DualEmbedding()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Let the user pick one finished run under "results".
base_folder = "results"
valid_folders = check_folders_for_files(base_folder)

if not valid_folders:
    raise ValueError("No valid folders found with trained models.")

for idx, folder_name in enumerate(valid_folders, 1):
    print(f"{idx}. {folder_name}")

chosen_idx = int(input("\nChoose a folder: ")) - 1
if not 0 <= chosen_idx < len(valid_folders):
    raise ValueError("Invalid choice. Please choose a valid number.")

# Load that run's weights, then export sample images/embeddings (the export also
# copies the weights from `folder_path`).
chosen_folder = valid_folders[chosen_idx]
folder_path = os.path.join(base_folder, chosen_folder)
model = load_model_weights(model, folder_path)
print(f"Model weights loaded from: {chosen_folder}")

generate_and_save_images_and_embeddings(model, train_loader, device=device)

In [None]:
# Reload the exported images and upsample them to 96x128 (height x width) for viewing.
saved_images = torch.load('ready/images.pth')
resized_images = F.interpolate(saved_images, size=(96, 128))

# Show the first five images side by side.
fig, axes = plt.subplots(1, 5, figsize=(15, 5))
for i in range(len(axes)):
    ax = axes[i]
    ax.imshow(resized_images[i].permute(1, 2, 0).cpu().numpy())
    ax.set_title(f"Original Image {i+1}")
    ax.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Plot the loss curves (raw, then scaled) for the run identified by `images_folder` / `hyperparameters`.
# NOTE(review): inside the training loop `images_folder` is a local variable; this cell only works
# if `images_folder` is also defined at notebook level — confirm before running.
loss_plot_args = (images_folder, hyperparameters['beta'], hyperparameters['latent_dim'], hyperparameters['loss_type'])
plot_losses_and_save(*loss_plot_args)
plot_losses_and_save(*loss_plot_args, scaled=True)

## Display reconstructed images

In [None]:
def get_folders_in_results_directory(base_path='results/'):
    """Return the sorted names of sub-directories of `base_path` that start with 'training_images'.

    Args:
        base_path: directory to scan. Defaults to ``'results/'``, so existing no-argument
            calls behave exactly as before. Exposing it matches the other
            ``get_folders_in_results_directory`` definitions in this notebook, which also
            accept ``base_path``.
    """
    return sorted(
        name for name in os.listdir(base_path)
        if name.startswith('training_images') and os.path.isdir(os.path.join(base_path, name))
    )

def display_folder_list(folders):
    """Print a 1-based numbered menu of `folders`, followed by a selection prompt."""
    lines = ["Available folders:"]
    lines.extend(f"{number}. {name}" for number, name in enumerate(folders, start=1))
    print("\n".join(lines))
    print("\nEnter the number of the folder to process:")

In [None]:
def display_images(folder_path, img_type='train'):
    """Replay one run's saved reconstructions epoch by epoch, under the original images.

    Each ``recon_<split>_epoch_*.png`` in `folder_path` is shown in sorted order. The
    zero-padded epoch numbers make sorted order equal epoch order. For each one, the
    cell output is cleared and a two-row figure is drawn: the ``original_<split>.png``
    strip on top and the reconstruction below. Both strips are resized to 1280x96 pixels.

    Args:
        folder_path: run folder containing the PNGs written during training.
        img_type: ``'train'`` selects the training images; any other value selects the
            validation images. The same split is now used for both the original and
            the reconstructions.
    """
    split = 'train' if img_type == 'train' else 'val'
    strip_size = (128 * 10, 96)  # (width, height) for PIL's resize: ten 128x96 tiles side by side

    # Context managers release each PNG's file handle as soon as it has been resized.
    with Image.open(os.path.join(folder_path, f"original_{split}.png")) as img:
        original_img = img.resize(strip_size)

    image_paths = sorted(glob.glob(f"{folder_path}/recon_{split}_epoch_*.png"))
    if not image_paths:
        print(f"No recon_{split}_epoch_*.png images found in {folder_path}")
        return

    for image_path in image_paths:
        with Image.open(image_path) as img:
            recon_img = img.resize(strip_size)

        clear_output(wait=True)

        fig, (ax_orig, ax_recon) = plt.subplots(2, 1, figsize=(15, 3))

        # Original images on top.
        ax_orig.imshow(original_img)
        ax_orig.set_title("Original")
        ax_orig.axis('off')

        # Reconstruction for this epoch below.
        ax_recon.imshow(recon_img)
        ax_recon.set_title(os.path.basename(image_path))
        ax_recon.axis('off')

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # don't accumulate one open figure per epoch
        time.sleep(0.01)
      

# Ask which run folder to replay. Blank, non-numeric or too-large input falls back to the
# latest folder (so does 0, via index -1).
folders = get_folders_in_results_directory()
display_folder_list(folders)

selected_folder_num = input().strip()
use_latest = (not selected_folder_num) or (not selected_folder_num.isdigit()) or int(selected_folder_num) > len(folders)
selected_folder = folders[-1] if use_latest else folders[int(selected_folder_num) - 1]

# Ask which split to replay; anything other than '2' means training.
img_type_choice = input("\nChoose image type:\n1. Training\n2. Validation\nEnter choice number (default is Training): ").strip()
img_type = 'val' if img_type_choice == '2' else 'train'

clear_output(wait=True)
print(f"Displaying images from: {selected_folder}")
display_images(os.path.join('results/', selected_folder), img_type)

## Save reconstructed images as a GIF

In [None]:
def save_images_to_gif(folder_path, gif_filename, img_type='train', frame_step=10, frame_duration=300):
    """Write an animated GIF showing one run's reconstructions evolving over the epochs.

    Each frame is a two-row matplotlib figure: the ``original_<split>.png`` strip on top
    and one ``recon_<split>_epoch_*.png`` strip below. Both are resized to 1280x96 pixels.
    Every `frame_step`-th reconstruction (in sorted epoch order) is used, and the final
    epoch is always included exactly once.

    Args:
        folder_path: run folder containing the PNGs written during training.
        gif_filename: output path of the GIF.
        img_type: ``'train'`` selects the training images; any other value selects the
            validation images (for both the original and the reconstructions).
        frame_step: keep every `frame_step`-th reconstruction (default 10).
        frame_duration: display time of each frame in milliseconds (default 300).

    Raises:
        FileNotFoundError: if the original image or all reconstructions for the split are missing.
    """
    split = 'train' if img_type == 'train' else 'val'
    strip_size = (128 * 10, 96)  # (width, height) for PIL's resize

    with Image.open(os.path.join(folder_path, f"original_{split}.png")) as img:
        original_img = img.resize(strip_size)

    all_image_paths = sorted(glob.glob(f"{folder_path}/recon_{split}_epoch_*.png"))
    if not all_image_paths:
        raise FileNotFoundError(f"No recon_{split}_epoch_*.png images found in {folder_path}")

    # Sub-sample the epochs, then append the last one only if the step did not already land
    # on it. Appending unconditionally duplicated the final frame in that case.
    image_paths = all_image_paths[::frame_step]
    if image_paths[-1] != all_image_paths[-1]:
        image_paths.append(all_image_paths[-1])

    gif_frames = []
    for image_path in tqdm(image_paths, desc="Generating GIF Image", ncols=100):
        with Image.open(image_path) as img:
            recon_img = img.resize(strip_size)

        fig, (ax_orig, ax_recon) = plt.subplots(2, 1, figsize=(12, 3))

        # Display original image.
        ax_orig.imshow(original_img)
        ax_orig.set_title("Original")
        ax_orig.axis('off')

        # Display reconstructed image.
        ax_recon.imshow(recon_img)
        ax_recon.set_title(os.path.basename(image_path))
        ax_recon.axis('off')

        fig.tight_layout()

        # Render the figure into an in-memory PNG. The frame keeps a reference to the
        # buffer, so it must not be closed before the GIF is written.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150)
        buf.seek(0)
        gif_frames.append(Image.open(buf))
        plt.close(fig)

    # Save all frames as a looping GIF with a fixed per-frame duration.
    gif_frames[0].save(gif_filename, save_all=True, append_images=gif_frames[1:], loop=0, duration=frame_duration)
    
# Ask which run folder to animate. Blank, non-numeric or too-large input falls back to the
# latest folder (so does 0, via index -1).
folders = get_folders_in_results_directory()
display_folder_list(folders)

selected_folder_num = input().strip()
if selected_folder_num.isdigit() and int(selected_folder_num) <= len(folders):
    selected_folder = folders[int(selected_folder_num) - 1]
else:
    selected_folder = folders[-1]

clear_output(wait=True)
print(f"Generating GIFs from: {selected_folder}")

# One GIF per split, written to results/gif/<folder>_<split>_reconstruction.gif.
os.makedirs('results/gif', exist_ok=True)
for img_type in ['train', 'val']:
    gif_filename = os.path.join('results', 'gif', f"{os.path.basename(selected_folder)}_{img_type}_reconstruction.gif")
    save_images_to_gif(os.path.join('results/', selected_folder), gif_filename, img_type)
    print(f"{img_type} gif saved as: {gif_filename}")

## Load the trained model and generate images

**Ensemble Strategy:** Generate an image for each score at or above the threshold, then average the generated images.

For threshold 7, generate images for scores 7, 8, and 9, then average them to get the final image.

Pros: Easy to implement, no retraining required.
Cons: May produce blurry images due to averaging.
__________

- for a threshold of 0, we generate and average images for scores 0 through 9.

- for a threshold of 7, we generate and average images for scores 7 through 9.

- for a threshold of 9, we would generate an image only for score 9.

**Probability Strategy:** Instead of setting a discrete score in the condition vector, use a probability distribution over scores.

For threshold 7, set probabilities for scores 7, 8, and 9 to be higher than the rest.

Generate the image from this modified condition vector.
Pros: More flexible representation.
Cons: Might produce images that don't strongly correlate to any specific score.

______
- Assign higher probabilities to scores at or above the threshold.
- Assign lower probabilities to scores below the threshold.
- Sample a score from this probability distribution.
- Generate the image using the sampled score.

**Iterative Refinement:** Start with the lowest acceptable score and iteratively refine the image.

Generate an image with score 7, then use it as the input conditioned on score 8, and so on.

Pros: Sequential enhancement of images.
Cons: Requires multiple forward passes, so it is slower.

In [None]:
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_folders_in_results_directory(base_path="results/", prefix="training_images_20231015"):
    """Return the sorted run directories under `base_path` whose names start with `prefix`.

    Args:
        base_path: directory to scan.
        prefix: folder-name prefix to match. Defaults to ``"training_images_20231015"``,
            the value previously hard-coded here, so existing calls are unchanged.

    Folder names embed their creation date, so sorting by name sorts them chronologically.
    """
    folders = [f for f in os.listdir(base_path) if f.startswith(prefix) and os.path.isdir(os.path.join(base_path, f))]
    return sorted(folders)

def extract_hyperparameters(folder):
    """Parse ``hyperparameters.txt`` in `folder` into a dict of strings.

    Each line containing a colon becomes one ``key -> value`` entry. The line is split
    at the FIRST colon only, so values that contain colons themselves (e.g. ``a: b``)
    are kept intact. Surrounding whitespace is stripped from both key and value.
    Lines without a colon are ignored. Values stay strings; callers convert them
    (e.g. ``int(hp['latent_dim'])``).

    Raises:
        FileNotFoundError: if `folder` contains no ``hyperparameters.txt``.
    """
    hyperparameters = {}
    with open(os.path.join(folder, "hyperparameters.txt"), "r") as f:
        for line in f:
            # partition() never raises, unlike indexing the result of split(": "), which
            # failed on "key:value" / "key:" lines and truncated values containing ": ".
            key, sep, value = line.partition(":")
            if not sep:
                continue
            hyperparameters[key.strip()] = value.strip()
    return hyperparameters

def load_model_weights(model, folder_path):
    """Load ``best_conv_model_weights.pth`` from `folder_path` into `model` and return `model`.

    The checkpoint is deserialized onto the CPU first. ``load_state_dict`` then copies each
    tensor into the model's existing parameters/buffers on whatever device the model
    already lives on. Without ``map_location``, a checkpoint saved from a GPU run fails
    to load on a CPU-only machine.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    weights_path = os.path.join(folder_path, "best_conv_model_weights.pth")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"No weights file found at {weights_path}")
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model


def tensor_to_image(tensor):
    """Turn a 3-D channel-first image tensor into a channel-last NumPy array for ``imshow``.

    The tensor is detached from autograd, clamped to [0, 1], and its first axis is
    moved to the end before conversion on the CPU.
    """
    return tensor.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()


def generate_images_for_strategy(model, strategy, hyperparameters):
    """Dispatch image generation to the helper that implements `strategy`.

    Args:
        model: trained CVAE, passed through to the strategy helper.
        strategy: one of ``"ensemble"``, ``"probability"`` or ``"refinement"``.
        hyperparameters: run hyperparameters as read from ``hyperparameters.txt``
            (string values). ``condition_dim`` defaults to 2 when absent, matching
            ``generate_and_display_images_for_each_folder``.

    Returns:
        Whatever the selected helper returns.

    Raises:
        ValueError: if `strategy` is not one of the supported names.
    """
    if strategy not in ("ensemble", "probability", "refinement"):
        raise ValueError(f"Invalid strategy: {strategy!r}")

    condition_dim = int(hyperparameters.get("condition_dim", 2))

    if strategy == "ensemble":
        return generate_images_for_thresholds(model, condition_dim, hyperparameters)
    if strategy == "probability":
        return generate_images_with_probability(model, condition_dim, hyperparameters)

    # Refinement runs from score 7 up to score 9.
    start_score, end_score = 7, 9
    return iterative_refinement(model, condition_dim, start_score, end_score, hyperparameters)

        
def generate_and_display_images_for_each_folder(strategy="ensemble", display=True, base_path="results/"):
    """Decode 10 random latent vectors with every trained run under `base_path` and save/show them.

    For each run folder returned by ``get_folders_in_results_directory(base_path)``, the
    model is rebuilt from ``hyperparameters.txt`` (``condition_dim`` defaults to 2 and
    ``latent_dim`` to 512) and its best weights are loaded. Ten standard-normal latent
    vectors are then decoded with an all-zero condition vector. Written into the run folder:

    * ``generated_images.pth`` - decoded image tensor (CPU),
    * ``embeddings.pth``       - the sampled latent vectors ``z`` (CPU),
    * ``generated_images/<strategy>_images/<strategy>_<folder>.png`` - a 1x10 image strip.

    NOTE(review): `strategy` only names the output directory/file here; the sampling itself
    does not depend on it (``generate_images_for_strategy`` is not called) — confirm intent.

    Args:
        strategy: label used in the output path.
        display: also show each figure inline, in addition to saving it.
        base_path: results directory to scan (previously ignored when listing folders).
    """
    folders = get_folders_in_results_directory(base_path)

    for folder in folders:
        full_folder_path = os.path.join(base_path, folder)

        # Rebuild the model from the stored hyperparameters, with defaults for missing keys.
        hyperparameters = extract_hyperparameters(full_folder_path)
        condition_dim = int(hyperparameters.get('condition_dim', 2))
        latent_dim = int(hyperparameters.get('latent_dim', 512))

        # Try the `CVAE_ResNet18_DualEmbedding_` variant first (presumably an alternative
        # architecture used by some runs — confirm). Fall back to the current class if that
        # variant is undefined (NameError), has a different constructor (TypeError), or its
        # weights do not match (RuntimeError from load_state_dict). Other errors propagate
        # instead of being swallowed by a bare `except:`.
        try:
            model_instance = CVAE_ResNet18_DualEmbedding_(conditional_dim=condition_dim,
                                                          latent_dim=latent_dim,
                                                          debug=False).to(device)
            trained_model = load_model_weights(model_instance, full_folder_path)
        except (NameError, TypeError, RuntimeError):
            model_instance = CVAE_ResNet18_DualEmbedding(conditional_dim=condition_dim,
                                                         latent_dim=latent_dim,
                                                         debug=False).to(device)
            trained_model = load_model_weights(model_instance, full_folder_path)

        # Inference mode, so BatchNorm/Dropout layers (if any) do not use batch statistics.
        trained_model.eval()

        # Sample 10 latent vectors with a constant (all-zero) condition.
        z = torch.randn(10, latent_dim).to(device)
        c = torch.zeros(10, condition_dim).to(device)

        with torch.no_grad():
            generated_images_tensor = trained_model.decode(z, c)
        generated_images = [tensor_to_image(img) for img in generated_images_tensor]

        # Save the decoded images and the latent vectors that produced them.
        torch.save(generated_images_tensor.cpu(), os.path.join(full_folder_path, 'generated_images.pth'))
        torch.save(z.cpu(), os.path.join(full_folder_path, 'embeddings.pth'))

        save_path = os.path.join(full_folder_path, f"generated_images/{strategy}_images", f"{strategy}_{folder}.png")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # squeeze=False always yields a 2-D axes array, even for a single image.
        fig, axes = plt.subplots(1, len(generated_images), figsize=(20, 2), squeeze=False)
        for ax, img in zip(axes[0], generated_images):
            ax.imshow(img)
            ax.axis('off')
        fig.suptitle(folder, y=1.05)
        fig.savefig(save_path, bbox_inches='tight')

        if display:
            plt.show()

        plt.close(fig)

        
# Run folders whose reconstructions are visualized in the next cell (default results directory).
result_folders = get_folders_in_results_directory(base_path="results/")

In [None]:
def _latest_reconstruction(folder_path, prefix):
    """Return ``(epoch, filename)`` of the highest-epoch ``<prefix><N>.png`` in ``folder_path``, or ``None``."""
    # Anchored, escaped pattern: only exact "<prefix><digits>.png" names count, so
    # stray files can no longer crash the epoch parsing.
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)\.png$")
    candidates = []
    for name in os.listdir(folder_path):
        match = pattern.match(name)
        if match:
            candidates.append((int(match.group(1)), name))
    return max(candidates) if candidates else None


def _hyperparameter_caption(hp):
    """Format the tuned hyperparameters in ``hp`` as one ``key: value, ...`` line."""
    keys = ("latent_dim", "loss_type", "learning_rate", "weight_decay",
            "alpha", "beta", "theta", "lamda")
    return ", ".join(f"{key}: {hp[key]}" for key in keys)


def _plot_latest_reconstructions(result_folders, prefix, output_name):
    """Stack the newest ``<prefix><N>.png`` of every folder into one figure, save it as results/<output_name> and show it."""
    entries = []
    for folder in result_folders:
        full_folder_path = os.path.join("results", folder)
        latest = _latest_reconstruction(full_folder_path, prefix)
        if latest is None:
            # Previously `max([])` raised ValueError and aborted the whole cell.
            print(f"No '{prefix}<epoch>.png' images in {full_folder_path}; skipping")
            continue
        epoch, image_name = latest
        entries.append((folder, full_folder_path, epoch, image_name))

    if not entries:
        return

    fig = plt.figure(figsize=(20, 4 * len(entries)))
    for position, (folder, full_folder_path, epoch, image_name) in enumerate(entries, start=1):
        hp = extract_hyperparameters(full_folder_path)
        ax = fig.add_subplot(len(entries), 1, position)
        ax.imshow(plt.imread(os.path.join(full_folder_path, image_name)))
        ax.set_title(f"Folder: {folder}\n"
                     f"Epoch Index: {epoch} | File: {image_name}\n"
                     f"{_hyperparameter_caption(hp)}",
                     loc='left')
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join("results", output_name))
    plt.show()
    plt.close(fig)  # free the figure; this runs once per image kind


def display_training_and_validation_images(result_folders):
    """Show the latest training and validation reconstructions of every results folder.

    For each folder name in ``result_folders`` (relative to ``results/``) the
    reconstruction image with the highest epoch number is picked
    (``recon_train_epoch_<N>.png`` / ``recon_val_epoch_<N>.png``), titled with
    the folder, epoch, file name and the hyperparameters returned by
    ``extract_hyperparameters``, and stacked vertically in one figure per kind.
    The figures are saved as ``results/train_results.png`` and
    ``results/val_results.png`` and shown.

    Folders without matching images are reported and skipped. If no folder has
    images of a kind, no figure is produced for that kind.

    Args:
        result_folders: Iterable of folder names inside ``results/``.
    """
    _plot_latest_reconstructions(result_folders, "recon_train_epoch_", "train_results.png")
    _plot_latest_reconstructions(result_folders, "recon_val_epoch_", "val_results.png")
    
        
# Show the latest training and validation reconstruction image of every
# folder in `result_folders`; the combined figures are also saved as
# results/train_results.png and results/val_results.png.
display_training_and_validation_images(result_folders)

In [None]:
# Function to generate images for direct scores
def generate_images_for_all_scores(model, c_dim, hyperparameters,
                                   num_scores=10, samples_per_score=10, output_size=(96, 128)):
    """Decode ``samples_per_score`` random latents for every score in ``range(num_scores)``.

    The score is written into both conditioning slots. Per the notebook's
    comments, index 0 is assumed to be Moira's score and index 1 Ferdinando's.
    Any further slots stay zero. Each score's latents are decoded as one batch
    under ``torch.no_grad()`` and resized with bilinear interpolation.

    Args:
        model: Model exposing ``decode(z, condition)``; it is switched to eval mode.
        c_dim: Width of the condition vector; must be at least 2.
        hyperparameters: Mapping with a ``"latent_dim"`` entry (cast to ``int``).
        num_scores: Number of scores, starting at 0 (default 10 -> scores 0..9).
        samples_per_score: Images generated per score (default 10).
        output_size: ``(height, width)`` the decoded images are resized to.

    Returns:
        torch.Tensor: shape ``(num_scores * samples_per_score, C, *output_size)``,
        grouped by score in ascending order.

    Raises:
        ValueError: if ``c_dim < 2`` or either count is not positive.
    """
    if c_dim < 2:
        raise ValueError(f"c_dim must be at least 2 (two score slots), got {c_dim}")
    if num_scores < 1 or samples_per_score < 1:
        raise ValueError("num_scores and samples_per_score must be positive")

    model.eval()
    latent_dim = int(hyperparameters["latent_dim"])
    per_score = []

    with torch.no_grad():
        for score in range(num_scores):
            # Draw on the default generator and then move, like the other generators here.
            z = torch.randn(samples_per_score, latent_dim).to(device)
            condition = torch.zeros(samples_per_score, c_dim).to(device)
            condition[:, :2] = score  # slots 0 and 1: the two raters' scores
            decoded = model.decode(z, condition)
            per_score.append(F.interpolate(decoded, size=output_size,
                                           mode='bilinear', align_corners=False))

    return torch.cat(per_score)



def display_generated_images_by_score(result_folders):
    """Plot, save and show a 10x10 grid of generated samples for each results folder.

    For every folder name in ``result_folders`` (relative to ``results/``):

    1. Read its hyperparameters with ``extract_hyperparameters``.
    2. Rebuild a ``CVAE_ResNet18_DualEmbedding`` from ``condition_dim`` and
       ``latent_dim`` and load its weights with ``load_model_weights``.
       Folders whose weights fail to load are reported and skipped.
    3. Generate samples with ``generate_images_for_thresholds`` and draw the
       first 100 in a 10x10 grid, with row ``r`` labelled ``Score: r``.
       Cells without an image stay empty.
    4. Save the figure to ``results/<folder>_generated_results.png`` and show it.

    Args:
        result_folders: Iterable of folder names inside ``results/``.
    """
    grid_size = 10

    for folder in result_folders:
        full_folder_path = os.path.join("results", folder)
        hp = extract_hyperparameters(full_folder_path)
        condition_dim = int(hp['condition_dim'])

        model = CVAE_ResNet18_DualEmbedding(conditional_dim=condition_dim,
                                            latent_dim=int(hp['latent_dim']), debug=False).to(device)
        # One attempt: retrying with an identically configured model fails the same
        # way. Report the error and move on to the next folder.
        try:
            model = load_model_weights(model, full_folder_path)
        except Exception as e:
            print(f"Error in folder {folder}: {e}")
            continue

        generated_images_tensor = generate_images_for_thresholds(model, condition_dim, hp)
        # NOTE(review): this flattens two levels (batch -> image), i.e. it assumes
        # the generator returns an iterable of image batches. Confirm against
        # generate_images_for_thresholds' return shape.
        generated_images = [tensor_to_image(img.cpu())
                            for tensor in generated_images_tensor for img in tensor]

        fig, axes = plt.subplots(grid_size, grid_size, figsize=(20, 20))

        num_epochs = hp.get("num_epochs", "N/A")
        title_str = (f"Folder: {folder}\n" +
                     f"Epoch Index: {num_epochs}, " +
                     f"latent_dim: {hp['latent_dim']}, " +
                     f"loss_type: {hp['loss_type']}, " +
                     f"learning_rate: {hp['learning_rate']}, " +
                     f"weight_decay: {hp['weight_decay']}, " +
                     f"alpha: {hp['alpha']}, " +
                     f"beta: {hp['beta']}, " +
                     f"theta: {hp['theta']}, " +
                     f"lamda: {hp['lamda']}")
        fig.suptitle(title_str, y=1.02, fontsize=10)

        for image_idx, ax in enumerate(axes.flat):
            if image_idx < len(generated_images):
                ax.imshow(generated_images[image_idx])
            # Hide ticks and frame but keep the axis on. `ax.axis('off')` also hides
            # axis labels, which made the "Score" y-labels invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if image_idx % grid_size == 0:
                ax.set_ylabel(f"Score: {image_idx // grid_size}", fontsize=12)

        if len(generated_images) < grid_size * grid_size:
            print(f"Folder {folder}: only {len(generated_images)} images generated; "
                  f"remaining grid cells left empty")

        gen_results_path = os.path.join("results", f"{folder}_generated_results.png")
        fig.tight_layout()
        # The suptitle sits above the figure (y=1.02); a tight bbox keeps it in the file.
        fig.savefig(gen_results_path, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # one figure per folder; avoid accumulating them

# Plot, save (results/<folder>_generated_results.png) and show a grid of
# generated samples for every folder in `result_folders`.
display_generated_images_by_score(result_folders)

In [None]:
# Generate and display images for each results folder with the "ensemble"
# strategy (each call runs one strategy; later cells run the others).
generate_and_display_images_for_each_folder(strategy="ensemble")

In [None]:
# Function to generate images with probability-based scores
def generate_images_with_probability(model, c_dim, hyperparameters):
    """Decode one image per threshold ``t`` in 0..9, conditioned on a score drawn uniformly from ``t..9``.

    For every threshold, a score is sampled with ``np.random.choice``. Scores
    below the threshold get probability 0 and the rest share it equally. That
    score is written into conditioning slots 0 and 1 (assumed to be Moira's and
    Ferdinando's scores). A fresh latent is then decoded without gradients and
    resized to 96x128.

    Returns:
        torch.Tensor: the 10 resized images stacked along a new leading dimension.
    """
    model.eval()
    images = []

    for threshold in range(10):
        # Zero mass below the threshold, equal mass on the remaining scores.
        share = 1 / (10 - threshold)
        probabilities = [0] * threshold + [share] * (10 - threshold)
        sampled_score = np.random.choice(range(10), p=probabilities)

        condition = torch.zeros(1, c_dim).to(device)
        condition[:, 0] = sampled_score  # slot 0: Moira's score (assumed)
        condition[:, 1] = sampled_score  # slot 1: Ferdinando's score (assumed)

        z = torch.randn(1, int(hyperparameters["latent_dim"])).to(device)
        with torch.no_grad():
            decoded = model.decode(z, condition)

        resized = F.interpolate(decoded, size=(96, 128), mode='bilinear', align_corners=False)
        images.append(resized.squeeze(0))

    return torch.stack(images)

In [None]:
# Generate and display images for each results folder with the "probability"
# strategy.
generate_and_display_images_for_each_folder(strategy="probability")

In [None]:
def iterative_refinement(model, c_dim, start_score, end_score, hyperparameters):
    """Generate one image at ``start_score`` and step it up, score by score, to ``end_score``.

    A random latent is decoded with conditioning slots 0 and 1 set to
    ``start_score``. Then, for each score ``start_score + 1 .. end_score``, the
    current image is passed through ``model(image, condition)`` with the
    condition set to that score. The first returned value becomes the new image.
    If ``end_score <= start_score``, the decoded image is used as is.

    Everything runs under ``torch.no_grad()``. The result therefore never
    requires grad, and no autograd graph is accumulated across the refinement
    passes.

    Args:
        model: Model exposing ``decode(z, c)`` and a forward ``model(x, c)``
            returning a 3-tuple whose first item is the reconstruction.
            It is switched to eval mode.
        c_dim: Width of the condition vector; must be at least 2.
        start_score, end_score: Integer scores (used with ``range``).
        hyperparameters: Mapping with a ``"latent_dim"`` entry (cast to ``int``).

    Returns:
        torch.Tensor: the final image resized to 96x128, with an extra leading
        dimension added by ``unsqueeze(0)``. NOTE(review): the decoded image is
        already batched (bilinear ``F.interpolate`` needs 4-D input), so the
        result is 5-D, e.g. ``(1, 1, C, 96, 128)``. The caller presumably relies
        on that shape.
    """
    model.eval()
    latent_dim = int(hyperparameters["latent_dim"])

    with torch.no_grad():
        z = torch.randn(1, latent_dim).to(device)
        condition = torch.zeros(1, c_dim).to(device)
        condition[:, 0] = start_score
        condition[:, 1] = start_score

        refined_sample = model.decode(z, condition)

        # Re-encode/decode the current image at each successive score.
        for score in range(start_score + 1, end_score + 1):
            condition[:, 0] = score
            condition[:, 1] = score
            refined_sample, _, _ = model(refined_sample, condition)

        refined_sample_resized = F.interpolate(refined_sample, size=(96, 128),
                                               mode='bilinear', align_corners=False)

    return refined_sample_resized.unsqueeze(0)

# Generate and display images for each results folder with the "refinement"
# strategy.
generate_and_display_images_for_each_folder(strategy="refinement")

### Hyperparameter search with Optuna

In [None]:
def hyperparameter_search(train_loader, val_loader, hyperparameters, n_trials):
    """Search loss weights with Optuna's TPE sampler so the weighted loss terms are balanced.

    Each trial samples ``beta``, ``alpha``, ``theta`` and ``lamda`` and trains
    via ``train_and_visualize_losses``. It returns the standard deviation of
    ``[alpha * mse_loss, theta * perceptual_loss, lamda * hist_loss]``, which
    the study minimises. The duration of each trial is used to print an ETA.

    NOTE(review): every trial reuses the notebook-level ``cvae``, ``optimizer``
    and ``scheduler``. Each trial therefore keeps training the same model
    instead of starting from scratch, so trials are not independent.

    Args:
        train_loader, val_loader: Data loaders forwarded to training.
        hyperparameters: Training config forwarded unchanged.
        n_trials: Number of Optuna trials to run (also used for the ETA).

    Returns:
        dict: ``study.best_params`` (keys ``beta``, ``alpha``, ``theta``, ``lamda``).
    """
    trial_durations = []

    def objective(trial):
        start_time = time.time()

        # Every range spans several orders of magnitude, so sample on a log scale.
        # A linear draw from [1, 1e10] puts only ~1% of `lamda` values below 1e8.
        beta = trial.suggest_float('beta', 1e-6, 1.0, log=True)
        alpha = trial.suggest_float('alpha', 1e-2, 100, log=True)
        theta = trial.suggest_float('theta', 1, 1000, log=True)
        lamda = trial.suggest_float('lamda', 1, 1e10, log=True)

        # Positional order (hyperparameters, beta, alpha, theta, lamda) matches the
        # adaptive-weights search; a keyword argument may not precede positional ones.
        _, loss_dict = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler,
                                                  hyperparameters, beta, alpha, theta, lamda,
                                                  image_show=False, debug=False)

        # Spread of the weighted losses: lower means the terms are better balanced.
        weighted_losses = [alpha * loss_dict['mse_loss'],
                           theta * loss_dict['perceptual_loss'],
                           lamda * loss_dict['hist_loss']]
        current_std = np.std(weighted_losses)

        trial_durations.append(time.time() - start_time)
        avg_time_per_trial = sum(trial_durations) / len(trial_durations)
        estimated_time_left = avg_time_per_trial * (n_trials - trial.number - 1)

        print(f"Trial {trial.number} completed with STD: {current_std}. Estimated time remaining: {estimated_time_left/60:.2f} minutes")

        return current_std

    study = optuna.create_study(sampler=optuna.samplers.TPESampler(), direction='minimize')
    study.optimize(objective, n_trials=n_trials)

    return study.best_params

# Run the Optuna search with a single training epoch per trial
hyperparameters.update(num_epochs=1)

# Execute hyperparameter search
best_params = hyperparameter_search(
    train_loader, val_loader, hyperparameters,
    n_trials=1000,
)
best_params

### Hyperparameter search with adaptive weights

In [None]:
def adaptive_weights(losses):
    """Return weights proportional to ``1 / (loss + 1e-5)``, normalised to sum to 1.

    Smaller losses get larger weights. The ``1e-5`` offset keeps a zero loss
    from dividing by zero.
    """
    eps = 1e-5
    inverses = list(map(lambda value: 1.0 / (value + eps), losses))
    normaliser = sum(inverses)
    return [inverse / normaliser for inverse in inverses]

def hyperparameter_search(train_loader, val_loader, hyperparameters, n_trials):
    """Optuna (TPE) search over ``beta`` and the loss weights, scored by an adaptively weighted spread.

    Per trial:
    1. Sample ``beta``, ``alpha``, ``theta`` and ``lamda`` (``lamda`` in a narrow 4e9..7e9 window).
    2. Train with ``train_and_visualize_losses``.
    3. Scale each loss by its weight and re-weight the scaled losses with ``adaptive_weights``.
    4. Return ``np.std`` of the re-weighted values.

    The study minimises that value, and an ETA is printed after every trial.

    NOTE(review): ``adaptive_weights`` makes each re-weighted term
    ``L_i / (L_i + 1e-5) / sum_j 1/(L_j + 1e-5)``. That is nearly identical for
    every term once the losses are much larger than 1e-5, so this objective may
    be close to zero for most trials. Also, ``cvae``/``optimizer``/``scheduler``
    are shared notebook globals reused across trials.

    Args:
        train_loader, val_loader: Data loaders forwarded to training.
        hyperparameters: Training config forwarded unchanged.
        n_trials: Number of Optuna trials to run (also used for the ETA).

    Returns:
        dict: ``study.best_params``.
    """
    elapsed_per_trial = []

    def objective(trial):
        began = time.time()

        beta = trial.suggest_float('beta', 1e-6, 1.0)
        alpha = trial.suggest_float('alpha', 1e-2, 10)
        theta = trial.suggest_float('theta', 1, 100)
        lamda = trial.suggest_float('lamda', 4*1e9, 7*1e9)

        _, losses = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler,
                                               hyperparameters, beta, alpha, theta, lamda,
                                               image_show=False, debug=False)

        scaled_losses = [
            alpha * losses['mse_loss'],
            theta * losses['perceptual_loss'],
            lamda * losses['hist_loss'],
        ]
        weights = adaptive_weights(scaled_losses)
        weighted_std = np.std([loss * weight for loss, weight in zip(scaled_losses, weights)])

        finished = time.time()
        elapsed_per_trial.append(finished - began)
        mean_elapsed = sum(elapsed_per_trial) / len(elapsed_per_trial)
        estimated_time_left = mean_elapsed * (n_trials - trial.number - 1)

        print(f"Trial {trial.number} completed with Weighted STD: {weighted_std}. Estimated time remaining: {estimated_time_left/60:.2f} minutes")

        return weighted_std

    study = optuna.create_study(sampler=optuna.samplers.TPESampler(), direction='minimize')
    study.optimize(objective, n_trials=n_trials)

    return study.best_params

# Run the adaptive-weights search with a single training epoch per trial
hyperparameters.update(num_epochs=1)

# Execute hyperparameter search
best_params = hyperparameter_search(
    train_loader, val_loader, hyperparameters,
    n_trials=10,
)
best_params

### Hyperparameter grid search

In [None]:
def hyperparameter_search(train_loader, val_loader, hyperparameters):
    """Exhaustively grid-search the loss weights ``alpha``, ``theta`` and ``lamda``.

    Each of the 3 x 4 x 5 = 60 combinations is trained with
    ``train_and_visualize_losses``. It is scored by the standard deviation of
    the returned ``mse_loss``, ``perceptual_loss`` and ``hist_loss``, where
    lower is better. Progress and an ETA are printed after every combination.

    NOTE(review):
      * The call passes ``hyperparameters, alpha, theta, lamda`` positionally,
        while the Optuna searches pass ``hyperparameters, beta, alpha, theta, lamda``.
        If the latter is the current signature, each weight lands one slot
        early (alpha -> beta, ...). Confirm against ``train_and_visualize_losses``.
      * The std is taken over the unweighted losses, unlike the Optuna searches,
        which scale each loss by its weight first.
      * ``cvae``, ``optimizer`` and ``scheduler`` are shared notebook globals, so
        each combination continues training the same model.

    Args:
        train_loader, val_loader: Data loaders forwarded to training.
        hyperparameters: Training config forwarded unchanged.

    Returns:
        dict: ``{'alpha', 'theta', 'lamda'}`` of the combination with the lowest
        std, or ``{}`` if no combination beat ``inf`` (e.g. all stds were NaN).
    """
    alpha_values = [1e-3, 1e-2, 1e-1]
    theta_values = [1, 5, 10, 100]
    lamda_values = [1, 1e3, 1e5, 1e7, 1e10]

    best_std = float('inf')  # We aim to minimize the standard deviation
    best_hyperparameters = {}

    total_combinations = len(alpha_values) * len(theta_values) * len(lamda_values)
    iteration_time_list = []  # Wall-clock seconds per combination, for the ETA

    for index, (alpha, theta, lamda) in enumerate(product(alpha_values, theta_values, lamda_values)):
        print(f"Training with alpha: {alpha}, theta: {theta}, lamda: {lamda}")

        start_time = time.time()

        # The weights reach training only through this call.
        _, loss_dict = train_and_visualize_losses(cvae, train_loader, val_loader, optimizer, scheduler,
                                                  hyperparameters, alpha, theta, lamda,
                                                  image_show=False, debug=False)

        iteration_time_list.append(time.time() - start_time)

        current_std = np.std([loss_dict['mse_loss'],
                              loss_dict['perceptual_loss'],
                              loss_dict['hist_loss']])

        if current_std < best_std:
            best_std = current_std
            best_hyperparameters = {
                'alpha': alpha,
                'theta': theta,
                'lamda': lamda
            }

        avg_time_per_iteration = sum(iteration_time_list) / len(iteration_time_list)
        estimated_time_left = avg_time_per_iteration * (total_combinations - index - 1)
        print(f"Estimated time remaining: {estimated_time_left/60:.2f} minutes")

    return best_hyperparameters

# Train for two epochs per grid point
hyperparameters.update(num_epochs=2)

# Execute hyperparameter search
best_params = hyperparameter_search(
    train_loader, val_loader,
    hyperparameters=hyperparameters,
)
best_params

## Hyperparameter tuning

In [None]:
# Trained models are handed from `objective` to the `save_best_model` callback
# through this dict instead of `trial.user_attrs`. Each model is popped (and can
# be freed) once its callback has run, rather than staying referenced by the
# study for every trial.
_TRIAL_MODELS = {}


def save_best_model(study, trial):
    """Optuna callback: save the trial's model weights if it is the best trial so far.

    Writes ``best_model_trial_<number>.pth`` (a ``state_dict``) to the current
    directory. Trials without a value (failed/pruned) or without a stored model
    are ignored. The stored model is always released.

    Args:
        study: The running ``optuna.Study``.
        trial: The just-finished (frozen) trial.
    """
    model = _TRIAL_MODELS.pop(trial.number, None)
    if model is None or trial.value is None:
        return
    if trial.value == study.best_value:
        torch.save(model.state_dict(), f"best_model_trial_{trial.number}.pth")


def objective(trial):
    """Optuna objective: train a CVAE with the trial's ``latent_dim``/``beta`` and return its best validation loss.

    The trained model is stashed in ``_TRIAL_MODELS`` for ``save_best_model``.
    """
    # 1. Hyperparameters to tune
    latent_dim = trial.suggest_int("latent_dim", 128, 512, step=128)
    beta = trial.suggest_float("beta", 0.0001, 10.0, log=True)
    loss_type = trial.suggest_categorical("loss_type", ["MSE"])

    hyperparameters = {
        "num_epochs": 101,
        "patience": 50,
        "beta": beta,
        "save_interval": 1,
        "learning_rate": 0.001,
        "latent_dim": latent_dim,
        "condition_dim": 2,
        "loss_type": loss_type,
    }

    # 2. Model, optimizer and scheduler for this trial.
    # NOTE(review): the model is built with conditional_dim=10 while
    # hyperparameters["condition_dim"] is 2. Other cells rebuild models with
    # conditional_dim=int(hp['condition_dim']). Confirm which value is intended.
    model = CVAE_ResNet18_DualEmbedding(conditional_dim=10, latent_dim=latent_dim, debug=False).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters["learning_rate"])
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

    # 3. Train; this call returns the model and its best validation loss.
    trained_model, best_val_loss = train_and_visualize_losses(model, train_loader, val_loader,
                                                              optimizer, scheduler, hyperparameters,
                                                              image_show=True, debug=False)

    # 4. Hand the model to the callback and report the metric to Optuna.
    _TRIAL_MODELS[trial.number] = trained_model
    return best_val_loss

# Create a study object and specify the direction is 'minimize'.
study = optuna.create_study(direction="minimize")

# Optimize the study; save_best_model writes the weights of each new best trial to disk.
study.optimize(objective, n_trials=50, callbacks=[save_best_model])

# Results. Count only trials that actually completed: `study.trials` also
# includes failed and pruned ones.
completed_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
print("Number of completed trials: ", len(completed_trials))
print("Best trial:")
trial = study.best_trial

print("Value: ", trial.value)
print("Params: ")
for key, value in trial.params.items():
    print(f"{key}: {value}")