## U-Net Brain

In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

# library for loading and manipulating medical images
import torchio as tio

# import custom src code
from src_unet.architecture import UNet
from src_unet.training import sample_batch, train_model
from src_unet.testing import test_unet

from real_src.database_mri import MRImagesDB, sample_batch_mri
from real_src.utils import get_num_workers, set_device, load_model

# reflect changes in src code immediately without restarting kernel
%load_ext autoreload
%autoreload 2

### Set Device

In [None]:
# set the device depending on available GPU (selection logic lives in real_src.utils.set_device)
device = set_device()
# NOTE(review): num_workers is not passed to any DataLoader in this notebook — confirm whether it should be
num_workers = get_num_workers()

### Create and Train Model

In [None]:
# create model and move to device
# (the same cond_dim value, 7, is passed to load_model() in the evaluation cells further down)
model = UNet(cond_dim=7)     # set the cond_dim based on bval and bvec
model.to(device)

# training hyperparameters
learning_rate = 0.01
batch_size = 64
epochs = 30

# SGD with momentum: 0.9 means 90% of the previous velocity is maintained at each step
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# loss function (mean squared error)
loss_fn = torch.nn.MSELoss()

# volume dimensions, taken from the data shape of one reference T1 volume
volume_dims = tio.ScalarImage('dataset_mini/train_mini/sub-051-01/anat/sub-051-01_t1.nii.gz').data.shape

# choose axis to slice along: 0 = sagittal, 1 = coronal, 2 = horizontal (axial)
slice_axis = 2

In [None]:
# paths to the mini dataset: b-values, gradient directions and image directory
bvals_path = 'dataset_mini/bvals_round_mini.bval'
bvecs_path = 'dataset_mini/bvecs_mini.bvec'
img_dir_path = 'dataset_mini/train_mini'

# create datasets
# NOTE(review): all three datasets are built from the same img_dir_path, so val/test
# may overlap with the training data — confirm this is intended for the mini dataset
# NOTE(review): train_set uses num_samples=10 while val_set uses 200 — confirm
train_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=10, slice_axis=slice_axis)
val_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=200, slice_axis=slice_axis)
test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=70, slice_axis=slice_axis)
# 4270 slices total (61x70)

# create dataloaders (batch_size comes from the hyperparameter cell above)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size) # normally set shuffle=True for training (but here dataset randomly generated)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size)

In [None]:
# display sample from batch to verify data
# (a separate DataLoader over val_set is built just for this preview)
sample_batch_mri(DataLoader(dataset=val_set, batch_size=batch_size))

In [None]:
# train model, saves weights in model folder, and plots loss curve
# (that behaviour is implemented in src_unet.training.train_model, which is not visible in this notebook)
train_model(model, device, train_loader, val_loader, loss_fn, optimizer, 
                epochs, batch_size, learning_rate, train_set)

### Load and Evaluate Model

In [None]:
# reload trained weights for evaluation; cond_dim is the same value used to build the model above
cond_dim = 7
# NOTE(review): the U-Net checkpoint is read from "models_diffusion/" — confirm that is where train_model saves it
model_path = "models_diffusion/best_unet_model.pth"
loaded_model = load_model(model_path, device, cond_dim=cond_dim)

In [None]:
def generate_random_dwi_vol(bvals_path, bvecs_path):
    """Pick one random diffusion-weighted volume and one random b0 volume.

    Args:
        bvals_path: Path to a text file of b-values readable by
            ``np.loadtxt``, one value per volume.
        bvecs_path: Path to a text file of gradient directions readable by
            ``np.loadtxt``, laid out with one column per volume (it is
            indexed as ``bvecs[:, vol_ind]``).

    Returns:
        Tuple ``(vol_ind, b0_ind, bvec, bval)``: the index of the chosen
        diffusion-weighted volume (b-value > 0), the index of the chosen b0
        volume (b-value == 0), and the gradient direction and b-value of the
        chosen diffusion-weighted volume.

    Raises:
        ValueError: If the b-value file contains no diffusion-weighted volume
            or no b0 volume.
    """
    # atleast_1d keeps a single-value file indexable (loadtxt returns a 0-d array for it)
    bvals = np.atleast_1d(np.loadtxt(bvals_path))
    bvecs = np.loadtxt(bvecs_path)

    dw_inds = np.where(bvals > 0)[0]
    b0_inds = np.where(bvals == 0)[0]

    # fail with a clear message instead of randint's opaque "low >= high" error
    if len(dw_inds) == 0:
        raise ValueError(f"no diffusion-weighted volumes (b-value > 0) in {bvals_path}")
    if len(b0_inds) == 0:
        raise ValueError(f"no b0 volumes (b-value == 0) in {bvals_path}")

    # draw order (DW first, then b0) is kept so seeded runs reproduce earlier picks
    vol_ind = dw_inds[np.random.randint(0, len(dw_inds))]
    b0_ind = b0_inds[np.random.randint(0, len(b0_inds))]

    bvec = bvecs[:, vol_ind]
    bval = bvals[vol_ind]

    print("bvalue:", bval)

    return vol_ind, b0_ind, bvec, bval

# NOTE(review): when cells run top to bottom, bvals_path / bvecs_path still point at the
# dataset_mini files here; the sampled values are only referenced by commented-out code in the next cell
vol_ind, b0_ind, bvec, bval = generate_random_dwi_vol(bvals_path, bvecs_path)

In [None]:
# define plot title
# NOTE(review): the title is hard-coded and not derived from the loaded checkpoint — keep it in sync manually
super_title = "UNet on real data with 4000 Samples and 1 FiLM Layer at Bottleneck"

# switch from the mini dataset to the full dataset for evaluation
data_dir_path = '../DWsynth_project/'
bvals_path = data_dir_path + 'bvals_round.bval'
bvecs_path = data_dir_path + 'bvecs.bvec'
img_dir_path = data_dir_path + 'test'

# volume dimensions  [1, H, W, D]
volume_dims = tio.ScalarImage(data_dir_path + 'train/sub-051-01/anat/sub-051-01_t1.nii.gz').data.shape

# test images
test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=10, slice_axis=2)
#test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=1, slice_axis=2, 
#                        subject='sub-051-01', slice_idx=2, bval=bval, bvec=bvec)

# test model on random unseen samples
# NOTE(review): this comment used to say "5 samples" but test_set has num_samples=10 — confirm how many test_unet uses
inputs, targets, preds = test_unet(loaded_model, test_set, device)

# display results 
from src_unet.testing import display_results
display_results(inputs, targets, preds, super_title)

In [None]:
import matplotlib.pyplot as plt

def plot_samples(anat, gt, preds):
    """Plot anatomical input, ground truth and prediction side by side.

    One figure row is drawn per sample, with the three images in columns
    titled "Anatomical", "Ground Truth" and "Prediction". Intensities are
    displayed on a fixed [0, 1] grey scale.

    Args:
        anat: Sequence of images; its length sets the number of rows.
        gt: Sequence of ground-truth images with at least ``len(anat)`` entries.
        preds: Sequence of predicted images with at least ``len(anat)`` entries.

    Each image may be 2-D ``[H, W]`` or carry leading singleton axes such as
    ``[1, H, W]`` (i.e. batches shaped ``[b, H, W]`` or ``[b, 1, H, W]``).
    Leading singleton axes are dropped before plotting because ``imshow``
    does not accept ``(1, H, W)`` data.
    """
    def _as_2d(image):
        # strip leading singleton axes (e.g. a channel axis of size 1)
        arr = np.asarray(image)
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr[0]
        return arr

    b = len(anat)

    # squeeze=False always returns a 2-D axes array, so b == 1 needs no special case
    fig, axes = plt.subplots(b, 3, figsize=(9, 3 * b), squeeze=False)

    columns = (
        ("Anatomical", anat),
        ("Ground Truth", gt),
        ("Prediction", preds),
    )

    for j, (title, images) in enumerate(columns):
        # column titles only on the first row
        axes[0, j].set_title(title)

        for i in range(b):
            axes[i, j].imshow(_as_2d(images[i]), cmap="gray", origin="lower", vmin=0, vmax=1)
            axes[i, j].axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# plot the inputs / targets / predictions returned by test_unet
plot_samples(inputs, targets, preds)

In [None]:
# show a single prediction on its own (index 2 requires at least three predictions);
# no vmin/vmax is given here, so the colour scale is chosen from this image's own range
plt.imshow(preds[2], cmap="gray", origin="lower")
plt.axis('off')
plt.show()

In [None]:
# starting code for SSIM / PSNR experiment

# so the goal is to iterate through each subject in the test dataset
# iterate through each dwi volume
# we want results for each bvalue separately
# get one particular slice
# compute SSIM and PSNR and then do boxplot of results

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [None]:
# load the checkpoint used for the SSIM / PSNR experiment
# (same path and cond_dim as the earlier "Load and Evaluate Model" cell)
cond_dim = 7
model_path = "models_diffusion/best_unet_model.pth"
loaded_model = load_model(model_path, device, cond_dim=cond_dim)

In [None]:
# data paths
data_dir_path = '../DWsynth_project/'
img_dir_path = os.path.join(data_dir_path, 'test')
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')

# load bvals and bvecs (bvecs is indexed with one column per volume)
bvals = np.loadtxt(bvals_path)
bvecs = np.loadtxt(bvecs_path)

# choose method of generating images
method = 'unet'

# list of test subjects (one sub-directory of img_dir_path per subject)
subjects = next(os.walk(img_dir_path))[1]

# indices where bvals non zero (dw)
dw_inds = np.where(bvals > 0)[0]

# total number of volumes per subject (b0 volumes included; they are skipped in the loop)
num_volumes = len(bvals)

# decide slice index (maybe middle slice)
slice_idx = 35

# num samples (1 for now)
num_samples = 1

# per-b-value metric lists
SSIM_b650 = []
SSIM_b1000 = []
SSIM_b2000 = []

PSNR_b650 = []
PSNR_b1000 = []
PSNR_b2000 = []

# one dict per evaluated (subject, volume); turned into a pandas DataFrame in a later cell
results = []

# iterate through each subject
for subject in subjects:

    # iterate through each volume
    for vol_idx in range(num_volumes):

        # get bval and bvec
        bval = bvals[vol_idx]
        bvec = bvecs[:, vol_idx]

        # skip b0 volumes (where normalizing by itself would cause errors)
        if bval == 0.0:
            continue

        # dataset holding exactly the requested subject / slice / volume
        test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                              subject=subject, slice_idx=slice_idx, bval=bval, bvec=bvec, dw_idx=vol_idx)

        if method == "unet":
            inputs, targets, preds = test_unet(loaded_model, test_set, device)
        else:
            # fail loudly instead of silently reusing pred/target from a previous iteration
            raise NotImplementedError(f"unknown generation method: {method!r}")

        # convert to float64 for precision
        target = targets[0].astype(np.float64)
        pred = preds[0].astype(np.float64)

        # data_range = intensity range of the ground-truth image
        data_range = target.max() - target.min()

        if data_range > 0:
            ssim_val = ssim(pred, target, data_range=data_range)
            psnr_val = psnr(pred, target, data_range=data_range)
        else:
            # constant target: both metrics are undefined, record NaN explicitly
            # (pandas means skip NaN; use np.nanmean on the raw lists)
            ssim_val = psnr_val = float("nan")

        # check which bval
        if bval == 650:
            SSIM_b650.append(ssim_val)
            PSNR_b650.append(psnr_val)
        elif bval == 1000:
            SSIM_b1000.append(ssim_val)
            PSNR_b1000.append(psnr_val)
        elif bval == 2000:
            SSIM_b2000.append(ssim_val)
            PSNR_b2000.append(psnr_val)

        # append results
        results.append({
            "subject": subject,
            "volume": vol_idx,
            "bval": int(bval),
            "SSIM": ssim_val,
            "PSNR": psnr_val
        })

# 750 of bval=650, 1500 of bval=1000, 3000 of bval=2000 -> 5250 samples total

# SSIM ranges over [-1, 1] -> values closer to 1 are better (1 = perfect reconstruction); ~0.85 is good similarity with small differences such as noise
# PSNR: higher values mean better quality; > 40 dB is excellent, 20-30 dB is acceptable but with noticeable differences

In [None]:
# sanity check: number of entries in the bvals file
print(num_volumes)

In [None]:
import pandas as pd
# make into DataFrame (one row per dict appended to `results`)
df = pd.DataFrame(results)

# quick overview: first rows, then mean SSIM / PSNR per b-value
print(df.head())
print(df.groupby("bval")[["SSIM", "PSNR"]].mean())

In [None]:
# mean of every per-b-value list: SSIM first (b=650, 1000, 2000), then PSNR in the same order
for metric_values in (SSIM_b650, SSIM_b1000, SSIM_b2000,
                      PSNR_b650, PSNR_b1000, PSNR_b2000):
    print(np.mean(metric_values))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One boxplot per metric, grouped by b-value (outliers hidden via showfliers=False).
# Optional tweaks for the SSIM figure: overlay the raw points with
#   sns.stripplot(x="bval", y="SSIM", data=df, color="black", size=3, alpha=0.5)
# or zoom in with plt.ylim(0.75, 1).
for metric_name, axis_label in (("SSIM", "SSIM"), ("PSNR", "PSNR (dB)")):
    plt.figure(figsize=(8,6))
    sns.boxplot(x="bval", y=metric_name, data=df, palette="Set2", showfliers=False)
    plt.title(f"{metric_name} distribution per b-value")
    plt.xlabel("b-value")
    plt.ylabel(axis_label)
    plt.show()

In [None]:
# data paths
data_dir_path = '../DWsynth_project/'
img_dir_path = os.path.join(data_dir_path, 'test')
bvals_path = os.path.join(data_dir_path, 'bvals_round.bval')
bvecs_path = os.path.join(data_dir_path, 'bvecs.bvec')

# load bvals and bvecs
bvals = np.loadtxt(bvals_path)
bvecs = np.loadtxt(bvecs_path)

# choose method of generating images
method = 'unet'

# list of 50 test subjects (one sub-directory of img_dir_path per subject)
subjects = next(os.walk(img_dir_path))[1]

# indices where bvals non zero (dw)
dw_inds = np.where(bvals > 0)[0]

# total number of volumes for each subject, b0 volumes included (118)
num_volumes = len(bvals)

# num samples (number of non-zero dw volumes = 105)
num_samples = len(dw_inds)

# decide slice index (maybe middle slice)
slice_idx = 35

In [None]:
def batch_metrics(preds, targets):
    """Compute SSIM and PSNR for each prediction/target pair of a batch.

    Args:
        preds: Sequence of predicted images as NumPy arrays.
        targets: Sequence of ground-truth images as NumPy arrays;
            ``targets[i]`` is compared with ``preds[i]``.

    Returns:
        Tuple ``(ssim_vals, psnr_vals)`` of two lists holding one value per
        entry of ``preds``. A pair whose target has zero intensity range
        (a constant image) yields NaN for both metrics, because both are
        undefined for ``data_range == 0``.
    """
    ssim_vals, psnr_vals = [], []

    for i, pred in enumerate(preds):
        # float64 for precision
        pred = pred.astype(np.float64)
        target = targets[i].astype(np.float64)

        # both metrics are scaled by the intensity range of the ground truth
        data_range = target.max() - target.min()

        if data_range > 0:
            ssim_val = ssim(pred, target, data_range=data_range)
            psnr_val = psnr(pred, target, data_range=data_range)
        else:
            # constant (or NaN-containing) target: record NaN rather than a division artefact
            ssim_val = psnr_val = float("nan")

        ssim_vals.append(ssim_val)
        psnr_vals.append(psnr_val)

    return ssim_vals, psnr_vals

In [None]:
# per-subject metric lists: one inner list per subject, in the order of `subjects`
ssim_results_list = []
psnr_results_list = []

# iterate through each subject
for subject in subjects:

    # dataset for this subject - just needs subject and slice idx
    test_set = MRImagesDB(img_dir_path, bvals_path, bvecs_path, volume_dims, num_samples=num_samples, slice_axis=2,
                              subject=subject, slice_idx=slice_idx)

    if method == "unet":
        inputs, targets, preds = test_unet(loaded_model, test_set, device)
    else:
        # fail loudly instead of silently reusing preds/targets from a previous subject
        raise NotImplementedError(f"unknown generation method: {method!r}")

    # get ssim and psnr results for the batch (subject) and add to master list
    ssim_results, psnr_results = batch_metrics(preds, targets)
    ssim_results_list.append(ssim_results)
    psnr_results_list.append(psnr_results)

    # progress indicator
    print(subject)

In [None]:
all_dfs = []

# match bvals (non-zero bvals, in file order)
# NOTE(review): assumes each subject's metric list is ordered like dw_inds — confirm against MRImagesDB
matched_bvals = [bvals[i] for i in dw_inds]

# iterate over the subjects that were actually evaluated instead of a hard-coded count of 50
for subj, (subj_ssim, subj_psnr) in enumerate(zip(ssim_results_list, psnr_results_list)):

    # store as df ("subject" is the position in `subjects`, not the subject id)
    # NOTE(review): this rebinds `df`, which the earlier per-b-value boxplot cell also uses
    df = pd.DataFrame({
        "subject": subj,
        "bval": matched_bvals,
        "ssim": subj_ssim,
        "psnr": subj_psnr
    })
    all_dfs.append(df)

# concatenate all subjects into one big dataframe
df_all = pd.concat(all_dfs, ignore_index=True)

# mean SSIM and PSNR per bval across all subjects
mean_metrics = df_all.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# save under the "metrics_df_<model>.csv" name that the comparison cells read back
# (previously written as "unet_metrics_df.csv", which no cell in this notebook loads)
df_all.to_csv("metrics_df_unet.csv", index=False)

In [None]:
# load the U-Net metrics from CSV and print mean SSIM / PSNR per b-value
df_unet = pd.read_csv("metrics_df_unet.csv")
mean_metrics = df_unet.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# load the diffusion-model metrics from CSV and print mean SSIM / PSNR per b-value
# (this CSV is not produced in this notebook)
df_diffusion = pd.read_csv("metrics_df_diffusion.csv")
mean_metrics = df_diffusion.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# load the latent-model metrics from CSV and print mean SSIM / PSNR per b-value
# (this CSV is not produced in this notebook)
df_latent = pd.read_csv("metrics_df_latent.csv")
mean_metrics = df_latent.groupby("bval")[["ssim", "psnr"]].mean()
print(mean_metrics)

In [None]:
# Tag every row with the model that produced it (adds a "model" column in place)
for frame, model_label in ((df_unet, "U-Net"), (df_diffusion, "Diffusion"), (df_latent, "Latent")):
    frame["model"] = model_label

# Stack the three tables into one long dataframe
df_combined = pd.concat([df_unet, df_diffusion, df_latent], ignore_index=True)

# One grouped boxplot per metric: metric vs b-value, one box per model, outliers hidden
for metric_col, metric_label in (("ssim", "SSIM"), ("psnr", "PSNR")):
    plt.figure(figsize=(12,6))
    sns.boxplot(data=df_combined, x="bval", y=metric_col, hue="model", palette="Set2", showfliers=False)

    # Axis labels, title and legend
    plt.xticks(rotation=45)
    plt.xlabel("b-value")
    plt.ylabel(metric_label)
    plt.title(f"{metric_label} comparison between models across b-values")
    plt.legend(title="Model")
    plt.tight_layout()
    plt.show()