## Diffusion Model

In [None]:
# import torch utils
import torch
from torch.utils.data import DataLoader

# import custom src code
from real_src.utils import get_num_workers, set_device, load_model
from real_src.architecture_unet import UNet
from real_src.database_mri import MRImagesDB, sample_batch_mri

# import other src code
from src_diffusion.testing import sample
#from src_diffusion.training import train_diffusion_model

# reflect changes in src code immediately without restarting kernel
%load_ext autoreload
%autoreload 2

### Set Device

In [None]:
# choose the compute device (GPU if one is available) and the DataLoader worker count
device, num_workers = set_device(), get_num_workers()

### Define Parameters

In [None]:
# length of the conditioning vector fed to the UNet:
#   toy data -> 3 (1 timestep + 2 for the (x, y) angle on the unit circle)
#   MRI data -> 8 (1 timestep + 1 scalar b-value + 6 for the transformed b-vector)
cond_dim = 8

# input channels: shape image + noised image (shadow / DWI)
in_chan = 2

# build the UNet and place it on the selected device
model = UNet(in_chan=in_chan, cond_dim=cond_dim)
model.to(device)

# training hyperparameters
batch_size, epochs = 100, 3000

# Adam (it worked better than SGD in practice) with a fixed learning rate
learning_rate = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# mean-squared-error objective
loss_fn = torch.nn.MSELoss()

# slicing axis: 0 = sagittal, 1 = coronal, 2 = horizontal
slice_axis = 2

# diffusion noise schedule (values from the DDPM paper)
timesteps = 1000
beta_start = 1e-4
beta_end = 0.02

### Load Data

In [None]:
# dataset locations (relative to the notebook)
data_dir_path = '../DWsynth_project/'                # alternative absolute path: '/cs/student/projects3/cgvi/2024/morrison/DWsynth_project/'
bvals_path = data_dir_path + 'bvals_round.bval'
bvecs_path = data_dir_path + 'bvecs.bvec'
img_dir_path = data_dir_path + 'train'

# volume dimensions  [1, H, W, D]
volume_dims = [1, 96, 96, 70]

# number of samples
num_samples = 1000

# create dataset and dataloader
train_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=slice_axis)

# pinned memory speeds up host-to-GPU copies, and persistent workers avoid re-spawning
# worker processes every epoch. DataLoader raises ValueError for persistent_workers=True
# when num_workers == 0, so only enable it when there are workers to keep alive.
use_cuda = torch.cuda.is_available()
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, num_workers=num_workers,
                          pin_memory=use_cuda, persistent_workers=use_cuda and num_workers > 0)

In [None]:
# sanity check: display a sample from one (non-shuffled) batch of the training set
sample_batch_mri(DataLoader(batch_size=batch_size, dataset=train_set))

### Train Model

In [None]:
# the top-of-notebook import of train_diffusion_model is commented out, so import it
# here; otherwise this cell raises NameError
from src_diffusion.training import train_diffusion_model

# train model (per the author: saves weights in the model folder and plots the loss curve)
train_diffusion_model(model, device, train_set, train_loader, loss_fn, optimizer,
                      epochs, batch_size, learning_rate, timesteps, beta_start, beta_end)

### Load and Evaluate Model

In [None]:
# UNet dimensions; these should match the architecture the checkpoint was trained with
cond_dim = 8        # 1 for timestep, 1 for scalar bval, 6 for transformed bvec
in_chan = 2         # noisy shadow/DWI image + shape (blob) image

model_path = "models_diffusion/best_diffusion_model.pth"
loaded_model = load_model(model_path, device, cond_dim=cond_dim, in_chan=in_chan)

In [None]:
# evaluation data: same files as for training, but the 'test' split
data_dir_path = '../DWsynth_project/'
bvals_path = f'{data_dir_path}bvals_round.bval'
bvecs_path = f'{data_dir_path}bvecs.bvec'
img_dir_path = f'{data_dir_path}test'

# volume dimensions [1, H, W, D] and slicing axis (0 = sagittal, 1 = coronal, 2 = horizontal)
volume_dims = [1, 96, 96, 70]
slice_axis = 2

# number of samples to draw from the diffusion model
n_samples = 10

# noise schedule
timesteps = 1000
beta_start, beta_end = 1e-4, 0.02

# per-sample image shape [C, H, W] for MRI data
img_shape = [1, 96, 96]

# test dataset/loader: a single batch holding all n_samples
test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims,
                      num_samples=n_samples, slice_axis=slice_axis)
test_loader = DataLoader(dataset=test_set, batch_size=n_samples)

In [None]:
# sample the test batch with the DDPM sampler; also returns the ground truth and conditioning images
sampler = 'ddpm'
samples, gt_shadows, blobs = sample(
    loaded_model, test_loader, n_samples,
    timesteps, beta_start, beta_end,
    img_shape, device,
    sampler=sampler,
)

In [None]:
# inspect the shape of the conditioning (anatomical/blob) images returned by sample()
print(blobs.shape)

In [None]:
import matplotlib.pyplot as plt

def plot_samples(anat, gt, preds):
    """
    Plot anatomical input, ground truth and prediction side by side.

    anat, gt, preds: tensors of shape [b, 1, 96, 96]. One figure row is drawn
    per batch element (channel 0 only) on a fixed [0, 1] gray scale, with
    column titles on the first row.
    """
    n_rows = anat.shape[0]
    columns = (("Anatomical", anat), ("Ground Truth", gt), ("Prediction", preds))

    # squeeze=False keeps the axes grid 2-D even for a single row
    fig, grid = plt.subplots(n_rows, len(columns), figsize=(9, 3 * n_rows), squeeze=False)

    for col, (title, _) in enumerate(columns):
        grid[0, col].set_title(title)

    for row in range(n_rows):
        for col, (_, images) in enumerate(columns):
            panel = grid[row, col]
            panel.imshow(images[row, 0].detach().cpu().numpy(), cmap="gray", origin="lower", vmin=0, vmax=1)
            panel.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
plot_samples(anat=blobs, gt=gt_shadows, preds=samples)  # conditioning | ground truth | DDPM samples

In [None]:
# then try DDIM sampling (can be faster than DDPM when run with fewer steps)
ddim_steps = 100
ddim_eta = 0.0  # usually set to 0.0 for deterministic sampling

# NOTE(review): ddim_steps / ddim_eta are not forwarded to sample(); confirm whether its
# signature accepts them before relying on these settings
ddim_samples = sample(loaded_model, test_loader, n_samples, timesteps, beta_start, beta_end, img_shape, device, sampler='ddim')

## Make GIFS

In [None]:
from real_src.utils import create_rotation_gif_frames, set_device, load_model

# what to render: model type, subject and axial slice
method, subject, slice_idx = 'diffusion', 'sub-012-01', 32

# output folder for frames and dataset root
save_dir = 'bvec_rotation_diffusion'
data_dir_path = '../DWsynth_project/'

# diffusion UNet checkpoint and its dimensions
model_path = 'models_diffusion/best_diffusion_model.pth'
cond_dim, in_chan = 8, 2

device = set_device()
loaded_model = load_model(model_path, device, cond_dim=cond_dim, in_chan=in_chan)

In [None]:
# generate rotation frames for this subject/slice with the loaded model, writing to save_dir
# NOTE(review): behaviour inferred from the function name and the later
# create_rotation_gif(save_dir, ...) call -- confirm in real_src.utils
create_rotation_gif_frames(method, subject, slice_idx, data_dir_path, save_dir, loaded_model, device)

In [None]:
from real_src.utils import create_rotation_gif

# combine the frames in save_dir into one animated GIF
# (presumably written as gif_name -- confirm the output location in real_src.utils)
gif_name = "diffusion_rotation_gif_with_anat.gif"
create_rotation_gif(save_dir, gif_name)

## SLICE STACK

First, generate a random DWI volume — ideally one with a b-value of 2000, for better contrast.
Then create the full stack of slices from this volume.

In [None]:
import numpy as np

def generate_random_dwi_vol(bvals_path, bvecs_path, target_bval=None):
    """
    Randomly pick one diffusion-weighted (DW) volume and one b0 volume.

    Parameters
    ----------
    bvals_path : str
        Text file of b-values, one per volume (read with ``np.loadtxt``).
    bvecs_path : str
        Text file of b-vectors; column ``i`` is used as the b-vector of
        volume ``i`` (so presumably a 3 x N array -- confirm the file layout).
    target_bval : float, optional
        If given, only DW volumes whose b-value equals ``target_bval`` are
        candidates (e.g. 2000 for better contrast). If ``None`` (default),
        any volume with b > 0 may be chosen.

    Returns
    -------
    vol_ind, b0_ind, bvec, bval
        Index of the chosen DW volume, index of a random b0 (b == 0) volume,
        the DW volume's b-vector and its b-value.

    Raises
    ------
    ValueError
        If there is no eligible DW volume or no b0 volume.
    """
    bvals = np.loadtxt(bvals_path)
    bvecs = np.loadtxt(bvecs_path)

    dw_inds = np.where(bvals > 0)[0]
    if target_bval is not None:
        dw_inds = dw_inds[bvals[dw_inds] == target_bval]
    b0_inds = np.where(bvals == 0)[0]

    # explicit errors instead of randint's opaque "low >= high" on empty ranges
    if dw_inds.size == 0:
        raise ValueError(f"no diffusion-weighted volume found (target_bval={target_bval!r})")
    if b0_inds.size == 0:
        raise ValueError("no b0 (bval == 0) volume found")

    vol_ind = dw_inds[np.random.randint(0, len(dw_inds))]
    b0_ind = b0_inds[np.random.randint(0, len(b0_inds))]

    bvec = bvecs[:, vol_ind]
    bval = bvals[vol_ind]

    print("bvalue:", bval)

    return vol_ind, b0_ind, bvec, bval

In [None]:
# data paths (test split)
img_dir_path = data_dir_path + 'test'
bvals_path = data_dir_path + 'bvals_round.bval'
bvecs_path = data_dir_path + 'bvecs.bvec'

# redraw until the chosen DW volume has the desired b-value (2000 gives better contrast);
# after a bounded number of draws, keep the last one and warn instead of looping forever
desired_bval = 2000
max_draws = 100
for draw in range(max_draws):
    vol_ind, b0_ind, bvec, bval = generate_random_dwi_vol(bvals_path, bvecs_path)
    if bval == desired_bval:
        break
else:
    print(f"warning: no volume with bval={desired_bval} after {max_draws} draws; using bval={bval}")

In [None]:
##### FOR LATENT MODELS #####

# load encoder blob and decoder shadow
from src_latent.autoencoder import load_ae

from src_latent.testing import sample_latent

# pretrained autoencoders for DW ("shadow") and anatomical ("blob") images
shadow_ae = load_ae("models_diffusion/ae_dw.pth", device)
blob_ae = load_ae("models_diffusion/ae_anat.pth", device)

# pull out the encoder/decoder halves and move each one to the device
encoder_shadow = shadow_ae.encoder.to(device)
decoder_shadow = shadow_ae.decoder.to(device)
encoder_blob = blob_ae.encoder.to(device)

In [None]:
# code for slice consistency experiment
import numpy as np
import matplotlib.pyplot as plt
from src_unet.testing import test_unet
import os
import torchio as tio

# UNet dimensions: cond_dim is 8 for diffusion and 7 for the plain UNet;
# in_chan is 2 for diffusion and 1 for the plain UNet
cond_dim = 7
in_chan = 1

# UNet dims for the LATENT model (uncomment to use):
#latent_chan = 8
#in_chan = 2 * latent_chan # multiplied by the latent dim
#out_chan = latent_chan  # out_chan should be same as latent_chan, then we use decoder to get back to original chan
#cond_dim = 8

# load the trained checkpoint
model_path = "models_diffusion/best_unet_model.pth"
loaded_model = load_model(model_path, device, cond_dim, in_chan)

# generation method ('diffusion', 'latent' or 'unet') and subject to reconstruct
method = 'unet'
subject = 'sub-030-01'

# the diffusion/latent samplers draw this many samples per slice; only the first is kept
num_samples = 3

# ground-truth volumes: T1 anatomy, the chosen DW volume and the chosen b0 volume
anat_path = os.path.join(img_dir_path, subject, 'anat', f'{subject}_t1.nii.gz')
dw_path = os.path.join(img_dir_path, subject, 'dwi', f'{subject}_dwi_preproc_{vol_ind}.nii.gz')
b0_path = os.path.join(img_dir_path, subject, 'dwi', f'{subject}_dwi_preproc_{b0_ind}.nii.gz')

anat_vol, dw_vol, b0_vol = (tio.ScalarImage(p).data.numpy() for p in (anat_path, dw_path, b0_path))

# DW signal divided by b0, clipped to [0, 1], zeroed outside the anatomy,
# then the leading channel axis dropped -> [96, 96, 70]
mask = anat_vol > 0
dw_vol_norm = (np.clip(dw_vol / (b0_vol + 1e-10), 0, 1) * mask)[0]


In [None]:
# predict every horizontal slice (axis 2) of the chosen subject and stack the
# per-slice [1, 96, 96] predictions into a [70, 96, 96] array
slice_preds = []

for slice_idx in range(70):

    test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                          subject=subject, slice_idx=slice_idx, bval=bval, bvec=bvec)
    test_loader = DataLoader(dataset=test_set, batch_size=num_samples)

    # determine method: diffusion, latent diffusion, or unet
    if method == "diffusion":
        sampler = 'ddpm'
        timesteps = 1000
        beta_start, beta_end = 1e-4, 0.02
        img_shape = volume_dims[:3]

        ddpm_sample, gt_shadows, blobs = sample(loaded_model, test_loader, num_samples, timesteps, beta_start, beta_end, img_shape, device, sampler=sampler)
        yhat = ddpm_sample.detach().cpu().numpy()[0]  # [1, 96, 96]: first of the num_samples draws
        print(slice_idx)

    elif method == "unet":
        inputs, targets, preds = test_unet(loaded_model, test_set, device)
        yhat = preds[0][None, ...]  # add a leading axis so the slice is [1, 96, 96]

    elif method == "latent":
        sampler = 'ddpm'
        timesteps = 1000
        beta_start, beta_end = 1e-4, 0.02

        ddpm_sample, gt_shadows = sample_latent(loaded_model, test_loader, num_samples, timesteps, beta_start, beta_end, device, sampler,
                                                encoder_blob, decoder_shadow)
        yhat = ddpm_sample.detach().cpu().numpy()[0]  # [1, 96, 96]: first of the num_samples draws
        print(slice_idx)

    else:
        # otherwise yhat would be undefined (first slice) or silently reused from the previous slice
        raise ValueError(f"unknown method: {method!r} (expected 'diffusion', 'unet' or 'latent')")

    # float64 keeps double precision for the metric cells below
    slice_preds.append(np.asarray(yhat, dtype=np.float64))

# one concatenation instead of re-copying the growing stack on every slice
dw_hat = np.concatenate(slice_preds, axis=0)

In [None]:
# ground-truth normalised DW image at horizontal slice 34, rotated 90 degrees for display
plt.figure(figsize=(6, 6))
plt.imshow(np.rot90(dw_vol_norm[:, :, 34]), cmap="gray", vmin=0, vmax=1)
plt.axis('off')
plt.show()

In [None]:
# predicted slice 34 from the generated stack, rotated 90 degrees for display
plt.figure(figsize=(6, 6))
plt.imshow(np.rot90(dw_hat[34]), cmap="gray", vmin=0, vmax=1)
plt.axis('off')
plt.show()

In [None]:
# NOTE(review): the filename is fixed regardless of `method`; rename it when saving a unet/latent stack
np.save(file='slice_stack_DDPM_DIFF_NOW', arr=dw_hat)

In [None]:
# anatomical (T1) reference slices for the figure; sagittal/coronal are reused in the next cell
print(anat_vol.shape)
sagittal = np.rot90(anat_vol[0, 40, :, :], k=-1)
coronal = np.flipud(np.rot90(anat_vol[0, :, 43, :]))
axial = np.flipud(np.rot90(anat_vol[0, :, :, 36]))

plt.figure(figsize=(6, 6))
plt.imshow(axial, origin='lower', cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# ground-truth DW slices: sagittal = axis-0 index 40, coronal = axis-1 index 43 of the
# normalised [96, 96, 70] volume, rotated/flipped for display and for the metric cell
sagittal_target = np.fliplr(np.rot90(dw_vol_norm[40,:,:], k=1))
coronal_target = np.rot90(dw_vol_norm[:,43,:], k=1)

# display the DW targets computed above
plt.figure()
plt.subplot(1,2,1)
plt.imshow(sagittal_target, vmin=0, vmax=1)
plt.axis('off')
plt.title("Sagittal")
plt.subplot(1,2,2)
plt.imshow(coronal_target, vmin=0, vmax=1)
plt.axis('off')
plt.title("Coronal")
plt.tight_layout()
plt.show()

# each target on its own, in grayscale
plt.imshow(sagittal_target, vmin=0, vmax=1, cmap='gray')
plt.axis('off')
plt.show()
plt.imshow(coronal_target, vmin=0, vmax=1, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# generated stack re-sliced along the other two directions:
# sagittal = axis-1 index 40 (left-right flipped), coronal = axis-2 index 43
sagittal_pred = np.fliplr(dw_hat[:, 40, :])
coronal_pred = dw_hat[:, :, 43]

plt.figure()
for pos, (title, img) in enumerate((("Sagittal", sagittal_pred), ("Coronal", coronal_pred)), start=1):
    plt.subplot(1, 2, pos)
    plt.imshow(img, origin='lower', vmin=0, vmax=1)
    plt.axis('off')
    plt.title(title)
plt.tight_layout()
plt.show()

# each view on its own, in grayscale
for img in (sagittal_pred, coronal_pred):
    plt.imshow(img, vmin=0, vmax=1, cmap='gray', origin='lower')
    plt.axis('off')
    plt.show()

In [None]:
np.save(file='slice_stack_latent_ddpm', arr=dw_hat)  # stack produced by the latent DDPM run

In [None]:
# reload a saved slice stack and check its shape
# NOTE(review): the save cells shown here write 'slice_stack_DDPM_DIFF_NOW' / 'slice_stack_latent_ddpm',
# not 'slice_stack_diff' -- confirm this file exists
x = np.load('slice_stack_diff.npy')
print(x.shape)

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# predicted stack re-sliced and flipped so its row order matches the *_target arrays
# (both were built for the same display orientation -- verify visually if slicing changes)
sagittal_pred = np.flipud(np.fliplr(dw_hat[:,40,:]))
coronal_pred = np.flipud(dw_hat[:,:,43])

for view, pred, target in (("Sagittal", sagittal_pred, sagittal_target),
                           ("Coronal", coronal_pred, coronal_target)):
    # common float64 dtype (the stack and the targets can differ in precision) and a
    # data_range taken from each view's own target, as in batch_metrics
    pred64 = pred.astype(np.float64)
    target64 = target.astype(np.float64)
    data_range = target64.max() - target64.min()

    ssim_val = ssim(pred64, target64, data_range=data_range)
    psnr_val = psnr(pred64, target64, data_range=data_range)
    print(f"{view} slice:")
    print("SSIM:", ssim_val)
    print("PSNR:", psnr_val)

In [None]:
# code for slice consistency experiment
import numpy as np
import matplotlib.pyplot as plt
from src_unet.testing import test_model
import os
import torchio as tio

# plain UNet dimensions (diffusion uses cond_dim = 8 because of the timestep,
# and in_chan = 2 for noisy image + shape image)
cond_dim, in_chan = 7, 1

# load the trained UNet checkpoint
model_path = "models_diffusion/best_unet_model.pth"
loaded_model = load_model(model_path, device, cond_dim, in_chan)

# subject whose volume is reconstructed slice by slice
subject = 'sub-012-01'

# ground-truth volumes: T1 anatomy, the chosen DW volume and the chosen b0 volume
anat_path = os.path.join(img_dir_path, subject, 'anat', f'{subject}_t1.nii.gz')
dw_path = os.path.join(img_dir_path, subject, 'dwi', f'{subject}_dwi_preproc_{vol_ind}.nii.gz')
b0_path = os.path.join(img_dir_path, subject, 'dwi', f'{subject}_dwi_preproc_{b0_ind}.nii.gz')

anat_vol, dw_vol, b0_vol = (tio.ScalarImage(p).data.numpy() for p in (anat_path, dw_path, b0_path))

# DW signal divided by b0, clipped to [0, 1], zeroed outside the anatomy,
# then the leading channel axis dropped -> [96, 96, 70]
mask = anat_vol > 0
dw_vol_norm = (np.clip(dw_vol / (b0_vol + 1e-10), 0, 1) * mask)[0]

In [None]:
# choose method of model to generate slices ('diffusion' or 'unet')
method = 'unet'

# the diffusion sampler draws this many samples per slice; only the first is kept
num_samples = 3

# predict every horizontal slice (axis 2) and stack them into [70, 96, 96]
slice_preds2 = []

for slice_idx in range(70):

    test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                          subject=subject, slice_idx=slice_idx, bval=bval, bvec=bvec)
    test_loader = DataLoader(dataset=test_set, batch_size=num_samples)

    if method == "diffusion":
        sampler = 'ddim'
        timesteps = 1000
        beta_start, beta_end = 1e-4, 0.02
        img_shape = volume_dims[:3]

        # other cells unpack sample() as a tuple whose first item is the samples;
        # accept either a tuple or a bare tensor here instead of calling .detach() on a tuple
        result = sample(loaded_model, test_loader, num_samples, timesteps, beta_start, beta_end, img_shape, device, sampler=sampler)
        ddpm_sample = result[0] if isinstance(result, (tuple, list)) else result
        yhat = ddpm_sample.detach().cpu().numpy()[0]  # [1, 96, 96]: first of the num_samples draws

    elif method == "unet":
        inputs, targets, preds = test_model(loaded_model, test_set, device)
        yhat = preds[0][None, ...]  # add a leading axis so the slice is [1, 96, 96]

    else:
        # otherwise yhat would be undefined (first slice) or silently reused from the previous slice
        raise ValueError(f"unknown method: {method!r} (expected 'diffusion' or 'unet')")

    slice_preds2.append(np.asarray(yhat, dtype=np.float64))

dw_hat2 = np.concatenate(slice_preds2, axis=0)

In [None]:
# UNet stack re-sliced: sagittal = axis-1 index 40 (left-right flipped), coronal = axis-2 index 43
plt.figure()
unet_views = (("Sagittal", np.fliplr(dw_hat2[:, 40, :])), ("Coronal", dw_hat2[:, :, 43]))
for pos, (title, img) in enumerate(unet_views, start=1):
    plt.subplot(1, 2, pos)
    plt.imshow(img, origin='lower', vmin=0, vmax=1)
    plt.axis('off')
    plt.title(title)
plt.tight_layout()
plt.show()

### Model Comparison

Test each of the three models by taking one slice from each volume of every subject in the test set.

In [None]:
# dataloader for one specific subject / slice / b-value / b-vector combination;
# slice_axis is 2 because these models were only trained on horizontal slices
slice_idx = 3

test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims,
                      num_samples=num_samples, slice_axis=2,
                      subject=subject, slice_idx=slice_idx, bval=bval, bvec=bvec)
test_loader = DataLoader(dataset=test_set, batch_size=num_samples)

### SSIM / PSNR

Generating samples one at a time is far too slow; batching is needed to speed it up.

In [None]:
# the idx argument of __getitem__ is the sample's index in the dataset list, so with
# batch_size = 10 (and 10 samples) one batch covers idx 0-9

# idea: reuse this index for the generation task
# by setting the batch size to 118, the number of volumes per subject,
# so idx can stand in for the volume index

# correction: of the 118 volumes, 13 are b0, so only the 105 DW volumes need metrics

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import os
import numpy as np

In [None]:
# diffusion UNet: conditioning = 1 timestep + 6 bvec + 1 bval; input = noisy image + shape image
cond_dim, in_chan = 8, 2

model_path = "models_diffusion/best_diffusion_model.pth"
loaded_model = load_model(model_path, device, cond_dim, in_chan)

In [None]:
# data paths
data_dir_path = '../DWsynth_project/'
img_dir_path = os.path.join(data_dir_path, 'test')
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')

# load bvals and bvecs
bvals = np.loadtxt(bvals_path)
bvecs = np.loadtxt(bvecs_path)

# volume dimensions  [1, H, W, D]
volume_dims = [1, 96, 96, 70]

# choose method of generating images
method = 'diffusion'

# test subject directories (50 expected). os.walk lists directories in arbitrary,
# filesystem-dependent order, so sort them to make subject indices reproducible.
subjects = sorted(next(os.walk(img_dir_path))[1])

# indices where bvals are non-zero (DW volumes)
dw_inds = np.where(bvals > 0)[0]

# num samples = number of non-zero DW volumes (105 of the 118 volumes per subject)
num_samples = len(dw_inds)

# slice index to evaluate (close to the middle of the 70 horizontal slices)
slice_idx = 35

In [None]:
def batch_metrics(preds, targets):
    """
    Compute per-image SSIM and PSNR for a batch of 2-D predictions.

    Parameters
    ----------
    preds, targets : sequence of 2-D numpy arrays (e.g. a [B, H, W] array)
        Predicted and ground-truth images, paired by position.

    Returns
    -------
    ssim_vals, psnr_vals : list, list
        One value per image pair, in input order. Each pair is compared in
        float64 and uses its target's own range (max - min) as ``data_range``.

    Raises
    ------
    ValueError
        If ``preds`` and ``targets`` have different lengths (previously extra
        targets were silently ignored).
    """
    if len(preds) != len(targets):
        raise ValueError(f"got {len(preds)} predictions but {len(targets)} targets")

    ssim_vals, psnr_vals = [], []

    for pred, target in zip(preds, targets):
        pred = pred.astype(np.float64)
        target = target.astype(np.float64)

        data_range = target.max() - target.min()

        ssim_vals.append(ssim(pred, target, data_range=data_range))
        psnr_vals.append(psnr(pred, target, data_range=data_range))

    return ssim_vals, psnr_vals

In [None]:
# per-subject lists of SSIM / PSNR values (one entry per DW volume), turned into a DataFrame below
ssim_results_list = []
psnr_results_list = []

# iterate through each subject; ctr is a 1-based progress counter
for ctr, subject in enumerate(subjects, start=1):

    # dataloader for this subject at slice_idx; num_samples equals the number of b > 0 volumes,
    # so one batch presumably covers every DW volume -- confirm in MRImagesDB
    test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                          subject=subject, slice_idx=slice_idx)
    test_loader = DataLoader(dataset=test_set, batch_size=num_samples)

    if method == "diffusion":
        sampler = 'ddpm'
        timesteps = 1000
        beta_start, beta_end = 1e-4, 0.02
        img_shape = volume_dims[:3]

        # other cells unpack sample() into 3 values (samples, gt_shadows, blobs) and this one
        # into 2; keep only the first two so the cell works with either return form
        ddpm_sample, gt_shadows = sample(loaded_model, test_loader, num_samples, timesteps, beta_start, beta_end, img_shape, device, sampler=sampler)[:2]

        preds = ddpm_sample.detach().cpu().numpy()[:, 0, :, :]     # drop the channel dim -> [B, 96, 96]
        targets = gt_shadows.detach().cpu().numpy()[:, 0, :, :]
    else:
        # otherwise preds/targets would be undefined, or silently reused from the previous subject
        raise ValueError(f"metrics loop only supports method='diffusion', got {method!r}")

    print(ctr)

    # SSIM / PSNR for every volume of this subject, stored per subject
    ssim_results, psnr_results = batch_metrics(preds, targets)
    ssim_results_list.append(ssim_results)
    psnr_results_list.append(psnr_results)

In [None]:
import pandas as pd 


all_dfs = []

# b-values of the DW volumes, in the same order as each subject's metric lists
matched_bvals = [bvals[i] for i in dw_inds]

# one DataFrame per subject; "subject" is the subject's position in `subjects`.
# Iterating over the collected results (instead of a hard-coded range(50)) avoids an
# IndexError when fewer subjects were evaluated.
for subj, (subj_ssim, subj_psnr) in enumerate(zip(ssim_results_list, psnr_results_list)):
    df = pd.DataFrame({
        "subject": subj,
        "bval": matched_bvals,
        "ssim": subj_ssim,
        "psnr": subj_psnr
    })
    all_dfs.append(df)

# concatenate all subjects into one big dataframe
df_all = pd.concat(all_dfs, ignore_index=True)

# mean SSIM and PSNR per bval across all subjects
mean_metrics = df_all.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# persist the per-volume metrics so they can be reloaded without re-sampling
df_all.to_csv(path_or_buf="diffusion_metrics_df.csv", index=False)

In [None]:
# reload the saved metrics and report the mean SSIM / PSNR for each b-value
loaded_df = pd.read_csv("diffusion_metrics_df.csv")
mean_metrics = loaded_df.groupby("bval")[["ssim", "psnr"]].agg("mean")
print(mean_metrics)

In [None]:
# counts: 750 samples with bval=650, 1500 with bval=1000, 3000 with bval=2000 -> 5250 total (50 subjects x 105 DW volumes)

# SSIM ranges over [-1, 1]; values closer to 1 are better (1 = perfect reconstruction), and ~0.85 indicates good similarity with small differences such as noise
# PSNR: higher is better; > 40 dB is excellent, while 20-30 dB is acceptable but with noticeable differences

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# one box plot per metric, grouped by b-value (outliers hidden)
boxplot_specs = (
    ("ssim", "SSIM distribution per b-value", "SSIM"),
    ("psnr", "PSNR distribution per b-value", "PSNR (dB)"),
)
for metric, title, ylabel in boxplot_specs:
    plt.figure(figsize=(8, 6))
    sns.boxplot(x="bval", y=metric, data=loaded_df, palette="Set2", showfliers=False)
    plt.title(title)
    plt.xlabel("b-value")
    plt.ylabel(ylabel)
    plt.show()