In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.init as init
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.utils as torch_utils
import torch.nn.functional as F
import numpy as np
import os
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
from utils import *
from train_test_split import *

In [9]:
# Reload the (train, validation) DataLoader pair produced by the preprocessing step.
# NOTE(review): pickle.load can execute arbitrary code — only open files you created or trust.
file_path = r'dataloaders_AT_pretrain_3d_rev.pkl'
with open(file_path, mode='rb') as pickled_file:
    loader_pair = pickle.load(pickled_file)
train_dataloader, val_dataloader = loader_pair

In [1]:
from VAEres3d_mri_unsuper import VAE, VAE_Loss
import math
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from utils import *

# ---- Loss history (passed to append_losses_unsuper / plot_losses_unsuper) ----
train_losses = []
val_losses = []
# NOTE(review): the two survival-loss lists are never filled in this notebook.
train_survlosses = []
val_survlosses = []

# ---- Model / optimisation hyper-parameters ----
beta = 1  # passed to VAE_Loss (presumably the KL weight — confirm in VAEres3d_mri_unsuper)
latent_dim = 256
red = 1
drop = 0.0
num_epochs = 500
best_loss = float('inf')
no_improvement_counter = 0
early_stopping_patience = 10
checkpoint_interval = 1
checkpoint_counter = 0
early_stopping_activate = True
# `checkpoint_folder` is never assigned in this notebook; it may be exported by
# `from utils import *`. Keep that value if present, otherwise fall back to a local
# default so the checkpoint call below cannot raise NameError.
checkpoint_folder = globals().get('checkpoint_folder', 'checkpoints')

vae = VAE(latent_dim, red, drop).to(device)
vae_loss_fn = VAE_Loss(beta).to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.00001)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', cooldown = 2, factor=0.1, patience=5, verbose=True)

for epoch in range(num_epochs):
    # ---- Training pass ----
    total_loss = 0.0
    total_kl_loss = 0.0
    vae.train()
    for batch in train_dataloader:
        inputs = batch[0].to(device)
        optimizer.zero_grad()
        x_recon, z, mu, log_var = vae(inputs)
        # First return value of VAE_Loss is the loss that is back-propagated
        # (the print below labels it "Rec."; verify whether it includes the KL term).
        loss, _, kl_loss = vae_loss_fn(x_recon, inputs, mu, log_var)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), 0.5)
        optimizer.step()
        total_loss, total_kl_loss = update_total_losses_unsuper(
            total_loss, total_kl_loss, loss, kl_loss
        )

    avg_loss, avg_kl_loss = calculate_average_losses_unsuper(
        total_loss, total_kl_loss, train_dataloader
    )

    # ---- Validation pass (no gradients, eval-mode layers) ----
    vae.eval()
    total_val_loss = 0.0
    total_val_kl = 0.0
    with torch.no_grad():
        for val_batch in val_dataloader:
            val_inputs = val_batch[0].to(device)
            val_x_recon, val_z, val_mu, val_log_var = vae(val_inputs)
            val_loss, _, val_kl_loss = vae_loss_fn(val_x_recon, val_inputs, val_mu, val_log_var)
            total_val_loss, total_val_kl = update_total_losses_unsuper(
                total_val_loss, total_val_kl, val_loss, val_kl_loss
            )
    # Average once, after every validation batch has been accumulated.
    avg_val_loss, avg_val_kl_loss = calculate_average_losses_unsuper(
        total_val_loss, total_val_kl, val_dataloader
    )

    print(f"Epoch {epoch+1}/{num_epochs}, Rec.T {avg_loss:.0f}, Rec.V: {avg_val_loss:.0f}, KL.T: {avg_kl_loss:.2f}, KL.V: {avg_val_kl_loss:.2f}")
    append_losses_unsuper(train_losses, val_losses, avg_loss, avg_val_loss)
    scheduler_on_loss = avg_val_loss
    scheduler.step(scheduler_on_loss)

    # ---- Early stopping on validation loss ----
    if scheduler_on_loss < best_loss:
        best_loss = scheduler_on_loss
        no_improvement_counter = 0
    else:
        no_improvement_counter += 1
        if early_stopping_activate and no_improvement_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1} due to no improvement in validation loss.")
            break

    save_checkpoint_intraining(
        epoch, vae, optimizer, best_loss, checkpoint_folder, checkpoint_interval, checkpoint_counter, active=False)

# NOTE(review): `vae` now holds the weights of the LAST epoch run, not of the epoch
# that achieved `best_loss`; the next cell saves these last-epoch weights.
plot_losses_unsuper(train_losses, val_losses)


# Save weights

In [53]:
# Persist only the learned parameters (the state_dict) so they can be reloaded
# later into a freshly constructed VAE with `load_state_dict`.
final_model_dir = r"\\WEIGHTS PRETRAIN"
final_model_path = os.path.join(final_model_dir, 'CMR-VAE.pt')
os.makedirs(final_model_dir, exist_ok=True)  # no-op when the folder already exists
model_weights = vae.state_dict()
torch.save(model_weights, final_model_path)

In [4]:
import torch
import matplotlib.pyplot as plt 
from mri_reconstruction_plottings import *
%matplotlib inline
from VAEres3d_mri_unsuper import *
vae = VAE(latent_dim=256, red=1, drop=0).to(device)
path_weights = r'\\CMR-VAE.pt'
vae.load_state_dict(torch.load(path_weights))

output_folder = r'\\Reconstruction examples'
get_mri_reconstructions_unsuper_3d(vae, train_dataloader, num_epochs=5, number_of_examples=5, output_folder=output_folder, device='cuda')


In [7]:
def get_mri_reconstructions_metrics(vae, train_dataloader, number_of_examples, device=None):
    """Compute per-patient image-quality metrics of VAE reconstructions.

    Every sample ("patient") in every batch is reconstructed. Each slice along
    dim 1 (channel 0 of the last dim) is compared with its reconstruction using
    SSIM, PSNR, MSE and NRMSE. Slice scores are averaged per patient, and the
    mean and std are then taken across patients. A patient is kept only if at
    least one of its slices has a finite NRMSE.

    Args:
        vae: model whose forward returns a 4-tuple with the reconstruction first.
        train_dataloader: iterable of batches where ``batch[0]`` is the input
            tensor, indexed as ``[sample, slice, :, :, channel]``.
        number_of_examples: unused; kept for backward compatibility.
        device: device used for inference. Defaults to the device of the
            model's first parameter (CPU if the model has none).

    Returns:
        dict with "Average <M>" / "Std <M>" for M in SSIM, PSNR, MSE, NRMSE.
    """
    if device is None:
        first_param = next(vae.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    vae.to(device)

    # Evaluate in eval mode (dropout off, BatchNorm uses running stats and is not
    # updated); restore the caller's mode afterwards.
    was_training = vae.training
    vae.eval()

    ssim_scores = []
    psnr_scores = []
    mse_scores = []
    nrmse_scores = []

    try:
        with torch.no_grad():
            for batch in train_dataloader:
                inputs = batch[0].to(device)
                outputs, _, _, _ = vae(inputs)
                # Copy each batch to host memory once instead of once per slice.
                inputs_np = inputs.detach().cpu().numpy()
                outputs_np = outputs.detach().cpu().numpy()

                for j in range(inputs_np.shape[0]):
                    ssim_perpatient = []
                    psnr_perpatient = []
                    mse_perpatient = []
                    nrmse_perpatient = []
                    for slice_idx in range(inputs_np.shape[1]):
                        original_image = inputs_np[j, slice_idx, :, :, 0]
                        reconstructed_image = outputs_np[j, slice_idx, :, :, 0]

                        # NOTE(review): data_range is taken from the reconstruction, not
                        # the reference image; kept as-is to preserve reported numbers.
                        recon_range = reconstructed_image.max() - reconstructed_image.min()
                        ssim_value = ssim(original_image, reconstructed_image, data_range=recon_range)
                        psnr_value = psnr(original_image, reconstructed_image, data_range=recon_range)
                        mse_value = mean_squared_error(original_image, reconstructed_image)
                        nrmse_value = normalized_root_mse(original_image, reconstructed_image)

                        ssim_perpatient.append(ssim_value)
                        psnr_perpatient.append(psnr_value)
                        mse_perpatient.append(mse_value)
                        # NRMSE is inf/nan when normalisation by the reference image
                        # degenerates (e.g. an all-zero slice); skip such slices for NRMSE.
                        if not (np.isinf(nrmse_value) or np.isnan(nrmse_value)):
                            nrmse_perpatient.append(nrmse_value)

                    if len(nrmse_perpatient) > 0:
                        ssim_scores.append(np.mean(ssim_perpatient))
                        psnr_scores.append(np.mean(psnr_perpatient))
                        mse_scores.append(np.mean(mse_perpatient))
                        nrmse_scores.append(np.mean(nrmse_perpatient))
    finally:
        vae.train(was_training)

    return {
        "Average SSIM": np.mean(ssim_scores),
        "Std SSIM": np.std(ssim_scores),
        "Average PSNR": np.mean(psnr_scores),
        "Std PSNR": np.std(psnr_scores),
        "Average MSE": np.mean(mse_scores),
        "Std MSE": np.std(mse_scores),
        "Average NRMSE": np.mean(nrmse_scores),
        "Std NRMSE": np.std(nrmse_scores)
    }

# Score reconstructions over every batch of the training loader and report each metric.
# NOTE(review): these numbers come from the training split; pass val_dataloader for held-out quality.
metrics = get_mri_reconstructions_metrics(vae, train_dataloader, 5)
print("Metrics:")
print("\n".join(f"{key}: {value:.4f}" for key, value in metrics.items()))
