In [None]:
# Results sub-folder (under data/results/) to inspect. Known alternatives:
#   ps256_es64_loo_m3
#   ps256_es64_loo_m3_nonoise
#   ps256_es64_loo_m3_test
#   ps256_es64_th_m3_test
# Names containing 'test' are plotted with the test-set layout further below.
config_name = "ps256_es64_loo_m3_test"


%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import os
import sys

sys.path.insert(1, '/home/quahb/caipi_denoising/src')

from utils.data_io import load_dataset
from preparation.gen_data import get_train_data, get_test_data
from evaluation.compute_metrics import compute_metrics

# Root folder holding one results sub-folder per config; NOTE(review): absolute
# local path — consider making this configurable.
RESULTS_DIR = '/home/quahb/caipi_denoising/data/results/'
dataset_folder = os.path.join(RESULTS_DIR, config_name)

subj_volumes = load_dataset(dataset_folder)
print(len(subj_volumes), subj_volumes.keys())

# Fail with a readable message instead of an IndexError when nothing was loaded.
if not subj_volumes:
    raise ValueError('No subject volumes loaded from {}'.format(dataset_folder))

# Despite its name, `random_key` is simply the first subject key; later cells
# use it as the subject to inspect in detail.
random_key = next(iter(subj_volumes.keys()))
for vol in subj_volumes[random_key]:
    print(vol.shape)

In [None]:
def plot_trainloo_set(subj_volumes, slc_index=128, mode='sagittal', n_row=10):
    """Plot ground-truth / noise-added / denoised slices side by side (train/LOO results).

    One row per subject (the first ``n_row`` keys of ``subj_volumes``) and three
    columns: ground truth, noise-added input, denoised output. Every panel's x-label
    shows min/max/mean/std of the slice; the noise-added and denoised panels are
    titled with PSNR/SSIM against the ground-truth slice, as returned by
    ``compute_metrics``.

    Args:
        subj_volumes: mapping subject id -> sequence whose items 0, 1, 2 are the
            ground-truth, noise-added and denoised volumes. Item 0 is indexed with a
            trailing channel axis (``[:, :, :, 0]``); items 1 and 2 are indexed as
            3-D arrays (TODO confirm these layouts against ``load_dataset``).
        slc_index: slice index along the axis selected by ``mode``.
        mode: 'sagittal' takes ``vol[slc_index, :, :]``; 'axial' takes
            ``vol[:, slc_index, :]`` and transposes it.
        n_row: maximum number of subjects (figure rows) to plot.

    Raises:
        ValueError: if ``mode`` is unsupported or ``n_row`` < 1.
    """
    # Validate before touching matplotlib; previously an unknown mode silently
    # passed whole 3-D volumes on to compute_metrics / imshow.
    if mode not in ('sagittal', 'axial'):
        raise ValueError("mode must be 'sagittal' or 'axial', got {!r}".format(mode))
    if n_row < 1:
        raise ValueError('n_row must be >= 1, got {!r}'.format(n_row))

    def _view(vol):
        # Extract one 2-D slice in the requested orientation.
        if mode == 'sagittal':
            return vol[slc_index, :, :]
        return vol[:, slc_index, :].transpose()

    def _stats(img):
        # Intensity summary used as the x-label of every panel.
        return 'min: {:.3f}, max: {:.3f}, mean: {:.3f}, std: {:.3f}'.format(
            np.min(img), np.max(img), np.mean(img), np.std(img))

    fontsize = 14
    n_col = 3
    # Height scales with the row count (12 per row == original 120 for 10 rows).
    fig = plt.figure(figsize=(52, 12 * n_row))
    plt.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.55,
                        top=0.9,
                        wspace=0.01,
                        hspace=0.15)

    subj_ids = list(subj_volumes.keys())[:n_row]
    # Report how many subjects are actually plotted (may be fewer than n_row).
    print('Plotting {}/{} slices...'.format(len(subj_ids), len(subj_volumes)))
    for i, subj_id in enumerate(subj_ids):
        gt_img = _view(subj_volumes[subj_id][0][:, :, :, 0])
        X_img = _view(subj_volumes[subj_id][1])
        y_img = _view(subj_volumes[subj_id][2])

        metrics_gt_X = compute_metrics(gt_img, X_img)
        metrics_gt_y = compute_metrics(gt_img, y_img)

        panels = (
            (gt_img, f'{subj_id}, Slice: {slc_index}'),
            (X_img, 'Noise Added, PSNR: {}, SSIM: {}'.format(metrics_gt_X['psnr'], metrics_gt_X['ssim'])),
            (y_img, 'Denoised, PSNR: {}, SSIM: {}'.format(metrics_gt_y['psnr'], metrics_gt_y['ssim'])),
        )
        for col, (img, title) in enumerate(panels):
            ax = fig.add_subplot(n_row, n_col, i * n_col + col + 1)
            ax.imshow(img, cmap='gray')
            ax.set_title(title, fontsize=fontsize)
            ax.set_xlabel(_stats(img), fontsize=fontsize)

    plt.show()

def plot_test_set(subj_volumes, slc_index=128, mode='sagittal', n_rows=10):
    """Plot accelerated-input and denoised slices side by side (test results).

    One row per subject (the first ``n_rows`` keys of ``subj_volumes``) and two
    columns: accelerated input and denoised output. Each x-label shows
    min/max/mean/std of the slice. The denoised panel is titled with PSNR/SSIM
    computed by ``compute_metrics`` between the accelerated and denoised slices —
    this function receives no ground-truth volume to compare against.

    Args:
        subj_volumes: mapping subject id -> 2-item sequence
            ``(accelerated_volume, denoised_volume)``, both indexed as 3-D arrays
            (TODO confirm layout against ``load_dataset``).
        slc_index: slice index along the axis selected by ``mode``.
        mode: 'sagittal' takes ``vol[slc_index, :, :]``; 'axial' takes
            ``vol[:, slc_index, :]`` and transposes it.
        n_rows: maximum number of subjects (figure rows) to plot.

    Raises:
        ValueError: if ``mode`` is unsupported or ``n_rows`` < 1.
    """
    # Validate up front; previously an unknown mode left acc_slc unbound and
    # crashed with UnboundLocalError inside the loop.
    if mode not in ('sagittal', 'axial'):
        raise ValueError("mode must be 'sagittal' or 'axial', got {!r}".format(mode))
    if n_rows < 1:
        raise ValueError('n_rows must be >= 1, got {!r}'.format(n_rows))

    def _view(vol):
        # Extract one 2-D slice in the requested orientation.
        if mode == 'sagittal':
            return vol[slc_index, :, :]
        return vol[:, slc_index, :].transpose()

    def _stats(img):
        # Intensity summary used as the x-label of every panel.
        return 'min: {:.3f}, max: {:.3f}, mean: {:.3f}, std: {:.3f}'.format(
            np.min(img), np.max(img), np.mean(img), np.std(img))

    n_cols = 2
    # squeeze=False keeps `axis` 2-D even when n_rows == 1; height scales with
    # the row count (13 per row == original 130 for 10 rows).
    _fig, axis = plt.subplots(n_rows, n_cols, figsize=(25, 13 * n_rows), squeeze=False)

    subj_ids = list(subj_volumes.keys())[:n_rows]
    # Report how many subjects are actually plotted (may be fewer than n_rows).
    print('Plotting {}/{} slices...'.format(len(subj_ids), len(subj_volumes)))
    for row, subj_id in enumerate(subj_ids):
        acc_vol, denoised_vol = subj_volumes[subj_id]
        acc_slc = _view(acc_vol)
        denoised_slc = _view(denoised_vol)

        slc_metrics_acc_den = compute_metrics(acc_slc, denoised_slc)

        axis[row, 0].imshow(acc_slc, cmap='gray')
        axis[row, 0].set_title('{}, Slice: {}'.format(subj_id, slc_index))
        axis[row, 0].set(xlabel=_stats(acc_slc))

        axis[row, 1].imshow(denoised_slc, cmap='gray')
        axis[row, 1].set_title('Denoised, Slice PSNR: {}, SSIM: {}'.format(
            slc_metrics_acc_den['psnr'], slc_metrics_acc_den['ssim']))
        axis[row, 1].set(xlabel=_stats(denoised_slc))

    plt.show()

# Plot slices

In [None]:
view_mode = 'axial'   # 'axial' or 'sagittal'
slc_index = 128       # slice index along the axis chosen by view_mode

# Test-set results hold (accelerated, denoised) pairs, while train/LOO results
# also carry the ground truth, so each uses its own plotting layout.
if 'test' not in config_name:
    plot_trainloo_set(subj_volumes, slc_index=slc_index, mode=view_mode)
else:
    plot_test_set(subj_volumes, slc_index=slc_index, mode=view_mode)

# Histogram Plot

In [None]:
print(random_key)

# Intensity histogram of one sagittal slice (axis 0) of the inspected subject.
# Volume index 1 is the noise-added input for train/LOO results and the
# denoised output for test results (see how the plotting functions unpack
# subj_volumes).
hist_vol_index = 1
hist_slc_index = 128
image = subj_volumes[random_key][hist_vol_index][hist_slc_index, :, :]

img_min, img_max = np.min(image), np.max(image)
histogram, bin_edges = np.histogram(image, bins=256, range=(img_min, img_max))

fig, ax = plt.subplots()
ax.plot(bin_edges[:-1], histogram)
ax.set(title=random_key, xlabel="grayscale value", ylabel="pixel count")
# Equal limits (constant slice) would trigger matplotlib's singular-limits warning.
if img_max > img_min:
    ax.set_xlim(img_min, img_max)
plt.show()

# Axial view: inspect volume shapes

In [None]:
# Shapes of every volume stored for the inspected subject (`X` and `y` are not
# defined anywhere in this notebook, so they cannot be referenced here).
tuple(vol.shape for vol in subj_volumes[random_key])