In [None]:
# List the available denoising result directories and prepared datasets
# (`ls -lrt`: long format, sorted by modification time, newest last).
!ls -lrt /home/quahb/caipi_denoising/data/results
!ls -lrt /home/quahb/caipi_denoising/data/datasets

In [None]:
denoised_name = 'ps256_th_varynoise_m3_reg_test'

import json
import numpy as np
import nibabel as nib
import matplotlib as mpl
from datetime import date
mpl.rc('image', cmap='gray')

import matplotlib.pyplot as plt


import sys
import nibabel as nib
import os

sys.path.insert(1, '/home/quahb/caipi_denoising/src')

%load_ext autoreload
%autoreload 2

from preparation.gen_data import get_masks, get_array_values_from_mask
from utils.data_io import load_dataset

from evaluation.compute_metrics import compute_metrics

# Load dataset

In [None]:
# Load masks, the EPI references (training set), the registered CAIPI volumes and the
# denoising results, then regroup everything per subject in `dataset`.
mask_data = get_masks()
train_data = load_dataset('/home/quahb/caipi_denoising/data/datasets/training_set_pp')
test_data = load_dataset('/home/quahb/caipi_denoising/data/datasets/testing_set_pp')
reg_test_data = load_dataset('/home/quahb/caipi_denoising/data/datasets/reg_testing_set_pp')
denoised = load_dataset(f'/home/quahb/caipi_denoising/data/results/{denoised_name}')

# (acceleration tag found in a dataset key, key to store it under). The tags are tried
# in this order and the first one contained in the key wins.
_REG_KEYS = (('1x2', 'reg1x2'), ('1x3', 'reg1x3'), ('2x2', 'reg2x2'))
_RUN_KEYS = (('1x2', 'run1x2'), ('1x3', 'run1x3'), ('2x2', 'run2x2'))


def _match_key(entry_key, tag_table):
    """Return the storage key for the first tag contained in ``entry_key``, else None."""
    for tag, out_key in tag_table:
        if tag in entry_key:
            return out_key
    return None


subj_ids = list(mask_data.keys())
dataset = {}

for subj_id in subj_ids:
    subj_data = {}
    dataset[subj_id] = {'data': subj_data}

    # EPI reference: element 0 of the first training entry whose key names the subject.
    for entry_key, entry in train_data.items():
        if subj_id in entry_key:
            subj_data['epi'] = entry[0]
            break

    # Registered CAIPI volumes (element 0); a later matching entry overwrites an earlier one.
    for entry_key, entry in reg_test_data.items():
        if subj_id not in entry_key:
            continue
        out_key = _match_key(entry_key, _REG_KEYS)
        if out_key is not None:
            subj_data[out_key] = entry[0]

    subj_masks = mask_data[subj_id]
    subj_data['lesion_mask'] = subj_masks['probability_map']
    subj_data['vein_mask'] = subj_masks['vein_mask']
    subj_data['wm_mask'] = subj_masks['3D_T1_Reg_pve_2']

    # Denoising results are kept whole: each entry is an (input image, denoised image) tuple.
    for entry_key, entry in denoised.items():
        if subj_id not in entry_key:
            continue
        out_key = _match_key(entry_key, _RUN_KEYS)
        if out_key is not None:
            subj_data[out_key] = entry

print(dataset['1_01_016-V1']['data'].keys())

# Compute & Load metrics into dataset

In [None]:
# Layout of `dataset` once this cell has run (keys exactly as used in this notebook):
#
# dataset: {
#     subj_id: {
#         'data': {
#             'epi':                           EPI reference volume
#             'reg1x2' / 'reg1x3' / 'reg2x2':  registered CAIPI volumes
#             'wm_mask', 'vein_mask', 'lesion_mask'
#             'run1x2' / 'run1x3' / 'run2x2':  (input image, denoised image) tuples
#         },
#         'metrics': {
#             '1x2': (native_metrics, denoised_metrics),
#             '1x3': (native_metrics, denoised_metrics),
#             '2x2': (native_metrics, denoised_metrics),
#         },
#     }
# }
#
# native_metrics   = compute_metrics on the registered CAIPI volume (before denoising),
# denoised_metrics = compute_metrics on the denoised volume, both against the EPI.
# Later cells index these by name: 'mse', 'psnr', 'snr', 'ssim', 'luminance',
# 'contrast', 'structure', 'cnr_vw', 'cnr_lv', 'cnr_lw'.

N = len(subj_ids)  # lower this to evaluate only the first N subjects (quick runs)

for i, subj_id in enumerate(list(dataset.keys())[:N]):
    print(i, subj_id)
    subj_data = dataset[subj_id]['data']
    subj_metrics = dataset[subj_id]['metrics'] = {}

    epi = subj_data['epi']
    wm_mask = subj_data['wm_mask']
    vein_mask = subj_data['vein_mask']
    lesion_mask = subj_data['lesion_mask']
    # TODO(review): metrics are computed without a brain mask. An earlier variant loaded one
    # from /home/quahb/Mask_Reg.nii.gz (axes 0/1 swapped via np.moveaxis, thresholded > 0.5).

    for acc in ('1x2', '1x3', '2x2'):
        registered = subj_data[f'reg{acc}']
        denoised_vol = subj_data[f'run{acc}'][1]  # [1] = denoised image of the (input, denoised) tuple
        subj_metrics[acc] = (
            compute_metrics(epi, registered, vein_mask, wm_mask, lesion_mask),
            compute_metrics(epi, denoised_vol, vein_mask, wm_mask, lesion_mask),
        )

# Save Metrics to File

In [None]:
# Write only the metrics (the 'data' volumes are not JSON-serialisable) to
# metrics/<denoised_name>_<YYYY-MM-DD>.json. The (native, denoised) tuples become JSON lists.
save_dict = {subj_id: dataset[subj_id]['metrics'] for subj_id in dataset.keys()}


def _json_default(obj):
    """json fallback for numpy values the encoder rejects (e.g. float32/int64 scalars, arrays)."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


metrics_dir = '/home/quahb/caipi_denoising/metrics'
os.makedirs(metrics_dir, exist_ok=True)  # otherwise open() fails on a fresh checkout

s_date = str(date.today())
with open(f'{metrics_dir}/{denoised_name}_{s_date}.json', 'w') as outfile:
    json.dump(save_dict, outfile, default=_json_default)

# Compute Average Values from Previously Computed Metrics

In [None]:
# Peek at which metric names compute_metrics produced (native side, one subject, 2x2).
sample_native_metrics = dataset['1_07_006-V1']['metrics']['2x2'][0]
sample_native_metrics.keys()

In [None]:
# Summarise the per-subject metrics for one acceleration factor: optionally print one
# comma-separated row per subject (before, after for each selected metric), then the
# across-subject mean +/- std. "before" = registered CAIPI vs. EPI, "after" = denoised vs. EPI.
ACC = '2x2'               # acceleration factor to summarise: '1x2', '1x3' or '2x2'
PRINT_PER_SUBJECT = True  # set False to print only the summary

metrics_to_compute = [
    'mse', 'psnr', 'snr',
    'ssim', 'luminance', 'contrast', 'structure',
    'cnr_vw', 'cnr_lv', 'cnr_lw'
]

# metric name -> per-subject values, in dataset order.
metrics_before = {metric: [] for metric in metrics_to_compute}
metrics_after = {metric: [] for metric in metrics_to_compute}

for subj_id in dataset.keys():
    acc_metrics = dataset[subj_id]['metrics'][ACC]  # (native_metrics, denoised_metrics)
    native_metrics, denoised_metrics = acc_metrics[0], acc_metrics[1]
    for metric in metrics_to_compute:
        metrics_before[metric].append(native_metrics[metric])
        metrics_after[metric].append(denoised_metrics[metric])

# Per-metric list names from earlier versions of this cell (empty if a metric is not selected).
mse_before, mse_after = metrics_before.get('mse', []), metrics_after.get('mse', [])
psnr_before, psnr_after = metrics_before.get('psnr', []), metrics_after.get('psnr', [])
snr_before, snr_after = metrics_before.get('snr', []), metrics_after.get('snr', [])
ssim_before, ssim_after = metrics_before.get('ssim', []), metrics_after.get('ssim', [])
luminance_before, luminance_after = metrics_before.get('luminance', []), metrics_after.get('luminance', [])
contrast_before, contrast_after = metrics_before.get('contrast', []), metrics_after.get('contrast', [])
structure_before, structure_after = metrics_before.get('structure', []), metrics_after.get('structure', [])
cnr_vw_before, cnr_vw_after = metrics_before.get('cnr_vw', []), metrics_after.get('cnr_vw', [])
cnr_lv_before, cnr_lv_after = metrics_before.get('cnr_lv', []), metrics_after.get('cnr_lv', [])
cnr_lw_before, cnr_lw_after = metrics_before.get('cnr_lw', []), metrics_after.get('cnr_lw', [])

if PRINT_PER_SUBJECT:
    # Only the selected metrics are printed, so deselecting one no longer raises IndexError.
    for idx in range(len(dataset)):
        row = []
        for metric in metrics_to_compute:
            row += [metrics_before[metric][idx], metrics_after[metric][idx]]
        print(*row, sep=', ')

print(f'Mean +/- std over {len(dataset)} subjects, acceleration {ACC}:')
for metric in metrics_to_compute:
    vals_before = np.asarray(metrics_before[metric], dtype=np.float64)
    vals_after = np.asarray(metrics_after[metric], dtype=np.float64)
    if vals_before.size == 0:
        continue  # no subjects -> nothing to average (avoids empty-mean warnings)
    print(f'  {metric:>9}: before {vals_before.mean():.4f} +/- {vals_before.std():.4f}, '
          f'after {vals_after.mean():.4f} +/- {vals_after.std():.4f}')

# Plot CAIPI Before/After Denoising with metrics

In [None]:
plt.rcParams.update({'font.size': 13})

# Volume axis sliced for each named view (same mapping as the original hand-written slices).
_VIEW_AXIS = {'axial': 0, 'coronal': 1, 'sagittal': 2}

_STATS_TEXT = 'min: {:.2f}, max: {:.2f}, mean: {:.2f}, std: {:.2f}'
_CAIPI_TEXT = ('PSNR: {:.3f}, SSIM: {:.3f}, CNR (V-WM, V-L): {:.3f}, {:.3f}' + '\n'
               + 'Luminance: {:.3f}, Contrast: {:.3f}, Structure: {:.3f}' + '\n'
               + _STATS_TEXT)


def _view_index(view, slice_idx):
    """Index tuple selecting slice ``slice_idx`` (kept as a length-1 axis) along ``view``.

    Non-string ``view`` values are returned unchanged so a ready-made index still works.
    """
    if not isinstance(view, str):
        return view
    if view not in _VIEW_AXIS:
        raise ValueError(f"view must be one of {sorted(_VIEW_AXIS)}, got {view!r}")
    index = [slice(None)] * 3  # full extent on the other two axes
    index[_VIEW_AXIS[view]] = slice(slice_idx, slice_idx + 1)
    return tuple(index)


def _volume_stats(volume):
    """(min, max, mean, std) of the whole volume, computed in float64."""
    vol = volume.astype(np.float64)
    return np.min(vol), np.max(vol), np.mean(vol), np.std(vol)


def _caipi_caption(acc_metrics, volume):
    """Caption with the stored metrics for one CAIPI panel plus whole-volume stats."""
    return _CAIPI_TEXT.format(
        acc_metrics['psnr'], acc_metrics['ssim'], acc_metrics['cnr_vw'], acc_metrics['cnr_lv'],
        acc_metrics['luminance'], acc_metrics['contrast'], acc_metrics['structure'],
        *_volume_stats(volume),
    )


def _show_panel(ax, volume, index, title, caption):
    """Draw the selected slice of ``volume`` on ``ax`` with a title and x-axis caption."""
    ax.imshow(np.squeeze(volume[index]))
    ax.set_title(title)
    ax.set(xlabel=caption)


def make_plot(subj_id, view, slice_idx):
    """Plot one slice of a subject's EPI, registered CAIPI and denoised volumes.

    Top row: EPI, then registered CAIPI 1x2 / 1x3 / 2x2. Bottom row: EPI again, then the
    corresponding denoised volumes. Every panel is captioned with whole-volume
    min/max/mean/std; CAIPI panels also show the metrics in ``dataset[subj_id]['metrics']``
    (index 0 = registered, 1 = denoised).

    Args:
        subj_id: key into the global ``dataset``.
        view: 'axial', 'coronal' or 'sagittal' (slices axis 0, 1 or 2 respectively), or an
            index tuple used as-is.
        slice_idx: slice position along the chosen axis.

    Raises:
        ValueError: if ``view`` is a string other than the three names above.
    """
    n_rows, n_cols = 2, 4
    _, axis = plt.subplots(n_rows, n_cols, figsize=(28, n_rows * 9))

    data, metrics = dataset[subj_id]['data'], dataset[subj_id]['metrics']
    index = _view_index(view, slice_idx)

    # The EPI reference is repeated in column 0 of both rows for side-by-side comparison.
    epi_caption = _STATS_TEXT.format(*_volume_stats(data['epi']))
    for row in range(n_rows):
        _show_panel(axis[row, 0], data['epi'], index, f'{subj_id}', epi_caption)

    for col, acc in enumerate(['1x2', '1x3', '2x2'], start=1):
        registered = data[f'reg{acc}']
        denoised_vol = data[f'run{acc}'][1]  # [1] = denoised image of the (input, denoised) tuple
        _show_panel(axis[0, col], registered, index, f'Registered CAIPI{acc}',
                    _caipi_caption(metrics[acc][0], registered))
        _show_panel(axis[1, col], denoised_vol, index, 'Denoised',
                    _caipi_caption(metrics[acc][1], denoised_vol))

In [None]:
# Subject ids available to make_plot below.
print(repr(subj_ids))

In [None]:
# Example: sagittal slice 128 for one subject.
make_plot(subj_id='1_01_016-V1', view='sagittal', slice_idx=128)