# Load data

In [None]:
# Name of the results sub-folder to inspect; it is joined onto
# /home/quahb/caipi_denoising/data/results/ when the dataset is loaded. The
# plotting cells also check whether it contains 'test' to choose between
# plot_test_set and plot_trainloo_set.
config_name = 'ps256_th_varynoise_m3_reg_test'

# IPython magics: automatically reload edited project modules before each cell runs.
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Make the project's src/ directory importable for the local packages below.
sys.path.insert(1, '/home/quahb/caipi_denoising/src')

from utils.data_io import load_dataset
# NOTE(review): get_train_data / get_test_data are not used in the cells visible here.
from preparation.gen_data import get_train_data, get_test_data
from evaluation.compute_metrics import compute_metrics

dataset_folder = os.path.join('/home/quahb/caipi_denoising/data/results/', config_name)

# Mapping of subject id -> sequence of volumes stored for this config.
subj_volumes = load_dataset(dataset_folder)
print(len(subj_volumes), subj_volumes.keys())

# Not actually random: this is the first subject in key order. Later cells
# reuse it as a default subject.
random_key = list(subj_volumes)[0]
for i in subj_volumes[random_key]:
    print(i.shape)

# Shape of the first volume of every subject, in key order.
shapes = [vols[0].shape for vols in subj_volumes.values()]
print(shapes)

In [None]:
# Show five neighbouring slices along axis 1 (the 'coronal' axis in this
# notebook's slicing convention), centred on slice `i`, of one subject's
# second stored volume. Each panel's x-label gives that slice's intensity stats.
vv = subj_volumes['1_07_006-V1_3D_EPI_2x2_Reg'][1]

_, axis = plt.subplots(1, 5, figsize=(28, 16))

i=92
# Actual slice indices i-2 .. i+2, in display order.
slice_indices = [i + offset for offset in range(-2, 3)]

for ax, slc in zip(axis, slice_indices):
    im = vv[:, slc, :]
    ax.imshow(im, cmap='gray')
    # Title with the real slice index rather than the panel position (0-4).
    ax.set_title(f'index: {slc}')
    ax.set(xlabel='min: {:.3f}, max: {:.3f}, mean: {:.3f}, std: {:.3f}'.format(
            np.min(im), np.max(im), np.mean(im), np.std(im)))

In [None]:
# View name -> array axis that is indexed to take a 2-D slice.
_VIEW_AXES = {'axial': 0, 'coronal': 1, 'sagittal': 2}


def _check_mode(mode):
    """Raise ValueError unless `mode` is one of the views in _VIEW_AXES."""
    if mode not in _VIEW_AXES:
        raise ValueError('mode must be one of {}, got {!r}'.format(sorted(_VIEW_AXES), mode))


def extract_slice(vol, slc_index, mode='sagittal'):
    """Return the slice of `vol` at `slc_index` for the view `mode`.

    'axial' indexes axis 0, 'coronal' axis 1 and 'sagittal' axis 2, i.e.
    ``vol[slc_index, :, :]``, ``vol[:, slc_index, :]`` and
    ``vol[:, :, slc_index]`` respectively.

    Raises:
        ValueError: if `mode` is not a supported view.
    """
    _check_mode(mode)
    index = [slice(None)] * vol.ndim
    index[_VIEW_AXES[mode]] = slc_index
    return vol[tuple(index)]


def format_stats(arr):
    """Return 'min: ..., max: ..., mean: ..., std: ...' (3 d.p.) for `arr`."""
    return 'min: {:.3f}, max: {:.3f}, mean: {:.3f}, std: {:.3f}'.format(
        np.min(arr), np.max(arr), np.mean(arr), np.std(arr))


def plot_test_set(subj_volumes, start_i=0, plot_n=10, slc_index=128, mode='sagittal', noise_map=False):
    """Plot acquired vs. denoised slices for a window of test subjects.

    Draws one row per subject for subjects ``start_i .. start_i + plot_n - 1``
    in the mapping's key order. Column 0 is the acquired volume, column 1 the
    denoised volume and, when `noise_map` is True, column 2 is their
    difference (acquired - denoised). Each x-label reports min/max/mean/std of
    the WHOLE volume, not only the displayed slice.

    Args:
        subj_volumes: mapping subject id -> (acquired_volume, denoised_volume).
        start_i: index of the first subject to plot.
        plot_n: number of rows (subjects) to allocate.
        slc_index: slice index along the axis selected by `mode`.
        mode: 'sagittal', 'axial' or 'coronal'.
        noise_map: also plot the acquired - denoised difference.

    Raises:
        ValueError: if `mode` is not a supported view.
    """
    _check_mode(mode)
    n_rows = plot_n
    n_cols = 3 if noise_map else 2

    # squeeze=False keeps `axis` 2-D even when plot_n == 1; otherwise
    # plt.subplots returns a 1-D array and axis[i, 0] fails.
    _, axis = plt.subplots(n_rows, n_cols, figsize=(n_cols * 8, n_rows * 10), squeeze=False)
    print('Plotting {}/{} slices...'.format(n_rows, len(subj_volumes.keys())))
    subj_ids = list(subj_volumes.keys())[start_i:start_i + n_rows]
    for i, subj_id in enumerate(subj_ids):
        acc_vol, denoised_vol = subj_volumes[subj_id]
        # float64 so the difference below cannot wrap around for unsigned-integer inputs.
        acc_vol = acc_vol.astype(np.float64)
        denoised_vol = denoised_vol.astype(np.float64)

        axis[i, 0].imshow(extract_slice(acc_vol, slc_index, mode), cmap='gray')
        axis[i, 0].set_title('{}, Slice: {}'.format(subj_id, slc_index))
        axis[i, 0].set(xlabel=format_stats(acc_vol))

        axis[i, 1].imshow(extract_slice(denoised_vol, slc_index, mode), cmap='gray')
        axis[i, 1].set_title('Denoised')
        axis[i, 1].set(xlabel=format_stats(denoised_vol))

        if noise_map:
            # Residual removed by the denoiser; computed only when it is shown.
            noise_vol = acc_vol - denoised_vol
            axis[i, 2].imshow(extract_slice(noise_vol, slc_index, mode), cmap='gray')
            axis[i, 2].set_title('')
            axis[i, 2].set(xlabel=format_stats(noise_vol))

    plt.show()


def plot_trainloo_set(subj_volumes, slc_index=128, mode='sagittal', plot_n=10, start_i=0):
    """Plot ground-truth, noise-added and denoised slices for a window of subjects.

    Draws one row per subject for subjects ``start_i .. start_i + plot_n - 1``
    in the mapping's key order. The noise-added and denoised columns are
    titled with the PSNR/SSIM that ``compute_metrics`` reports against the
    ground-truth slice. Each x-label shows min/max/mean/std of that 2-D slice.

    Args:
        subj_volumes: mapping subject id -> (gt_volume, noise_added_volume, denoised_volume).
        slc_index: slice index along the axis selected by `mode`.
        mode: 'sagittal', 'axial' or 'coronal'.
        plot_n: number of subjects (rows) to plot; defaults to 10.
        start_i: index of the first subject to plot; defaults to 0.

    Raises:
        ValueError: if `mode` is not a supported view.
    """
    _check_mode(mode)
    fontsize = 14
    n_row, n_col = plot_n, 3
    # Height scales with the number of rows (the old code used an undefined
    # name `n_rows` here and raised NameError).
    fig = plt.figure(figsize=(52, n_row * 12))
    plt.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.55,
                        top=0.9,
                        wspace=0.01,
                        hspace=0.15)

    print('Plotting {}/{} slices...'.format(n_row, len(subj_volumes.keys())))
    subj_ids = list(subj_volumes.keys())[start_i:start_i + n_row]
    for i, subj_id in enumerate(subj_ids):
        gt_vol, X_vol, y_vol = subj_volumes[subj_id]
        gt_img = extract_slice(gt_vol, slc_index, mode)
        X_img = extract_slice(X_vol, slc_index, mode)
        y_img = extract_slice(y_vol, slc_index, mode)

        metrics_gt_X = compute_metrics(gt_img, X_img)
        metrics_gt_y = compute_metrics(gt_img, y_img)

        ax_gt = fig.add_subplot(n_row, n_col, i * n_col + 1)
        ax_gt.imshow(gt_img, cmap='gray')
        ax_gt.set_title(f'{subj_id}, Slice: {slc_index}', fontsize=fontsize)
        ax_gt.set_xlabel(format_stats(gt_img), fontsize=fontsize)

        ax_X = fig.add_subplot(n_row, n_col, i * n_col + 2)
        ax_X.imshow(X_img, cmap='gray')
        ax_X.set_title('Noise Added, PSNR: {}, SSIM: {}'.format(metrics_gt_X['psnr'], metrics_gt_X['ssim']), fontsize=fontsize)
        ax_X.set_xlabel(format_stats(X_img), fontsize=fontsize)

        ax_y = fig.add_subplot(n_row, n_col, i * n_col + 3)
        ax_y.imshow(y_img, cmap='gray')
        ax_y.set_title('Denoised, PSNR: {}, SSIM: {}'.format(metrics_gt_y['psnr'], metrics_gt_y['ssim']), fontsize=fontsize)
        ax_y.set_xlabel(format_stats(y_img), fontsize=fontsize)

    # Render explicitly, consistent with plot_test_set.
    plt.show()

# Plot slices

In [None]:
# Sagittal view at slice 154. Test configs show three subjects starting at
# index 75; any other config is drawn with plot_trainloo_set.
view_mode, slc_index = 'sagittal', 154

if 'test' not in config_name:
    plot_trainloo_set(subj_volumes, slc_index=slc_index, mode=view_mode)
else:
    plot_test_set(subj_volumes, start_i=75, plot_n=3, slc_index=slc_index,
                  mode=view_mode, noise_map=False)

In [None]:
# Coronal view at slice 90. Test configs show two subjects starting at
# index 60; any other config is drawn with plot_trainloo_set.
view_mode, slc_index = 'coronal', 90

if 'test' in config_name:
    plot_test_set(subj_volumes, mode=view_mode, slc_index=slc_index,
                  start_i=60, plot_n=2)
else:
    plot_trainloo_set(subj_volumes, mode=view_mode, slc_index=slc_index)

# Histogram Plot

In [None]:
# Intensity histogram (500 bins) of slice 128 along axis 0 of one subject's
# second stored volume.
random_key = '1_01_020-V1_CAIPI1x2'
# Print the key AFTER assigning it, so the printed subject matches the plot.
print(random_key)
image = subj_volumes[random_key][1][128,:,:]
print(np.mean(image))
# np.histogram's default range is already (image.min(), image.max()).
histogram, bin_edges = np.histogram(image, bins=500)
print(len(histogram), len(bin_edges))
plt.figure()
plt.title(random_key)
plt.xlabel("grayscale value")
plt.ylabel("pixel count")
plt.xlim([np.min(image), np.max(image)])

# Counts plotted against each bin's left edge.
plt.plot(bin_edges[:-1], histogram)
plt.show()

In [None]:
# Intensity histogram (256 bins) of slice 230 along axis 0 of the same
# subject's second stored volume.
random_key = '1_01_020-V1_CAIPI1x2'
print(random_key)  # printed after assignment so it names the plotted subject
image = subj_volumes[random_key][1][230,:,:]
print(np.mean(image))
# Default range of np.histogram is (image.min(), image.max()).
histogram, bin_edges = np.histogram(image, bins=256)

plt.figure()
plt.title(random_key)
plt.xlabel("grayscale value")
plt.ylabel("pixel count")
plt.xlim([np.min(image), np.max(image)])

left_edges = bin_edges[:-1]
plt.plot(left_edges, histogram)
plt.show()

In [None]:
# Per-slice intensity-offset experiment on the current subject's second stored
# volume: show slice 128 (axis 0), shift every axis-0 slice by the change in
# mean intensity from the previous slice, then show slice 128 again.
print(random_key)
image = subj_volumes[random_key][1][:,:,:]
plt.figure(figsize=(10,10))
plt.imshow(image[128,:,:], cmap='gray')
print(image.shape)

# Exact mean intensity of each axis-0 slice. The previous histogram-based
# estimate divided the count-weighted sum of bin centres by the number of BINS
# (500) instead of the number of pixels, so its values were not means.
slc_means = np.mean(image, axis=(1, 2))

# astype copies, so `image` (a view into subj_volumes) is never modified.
# float64 avoids silently truncating fractional offsets if the volume has an
# integer dtype.
new_im = image.astype(np.float64)
# new_im[i] += slc_means[i] - slc_means[i - 1] for every i >= 1.
# NOTE(review): adding the difference widens slice-to-slice mean changes;
# subtracting it (or re-centring every slice on a common reference mean) may
# have been the intent — confirm.
new_im[1:] += np.diff(slc_means)[:, np.newaxis, np.newaxis]

plt.figure(figsize=(10,10))
plt.imshow(new_im[128,:,:], cmap='gray')