In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import pickle as pickle # for saving loss objects

import dataset as dd # custom dataset class
import models as md

# reload imported modules automatically so that edits to them take effect in the notebook without restarting the kernel
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

In [2]:
# Configuration of every network compared in this notebook. Each entry holds
# the tag used to build the saved-file names, the options that were varied and
# the factory function from `models` associated with the network.

# U-Net variants with channel augmentation enabled
model_params_1 = [
    dict(tag='pooling_channel_aug_small', use_pool=True,
         do_channel_augmentation=True, model_fn=md.get_unet),
    dict(tag='no_pooling_channel_aug_small', use_pool=False,
         do_channel_augmentation=True, model_fn=md.get_unet),
]

# U-Net variants without channel augmentation, plus the KAIST U-Net
model_params_2 = [
    dict(tag='pooling_no_channel_aug_small', use_pool=True,
         do_channel_augmentation=False, model_fn=md.get_unet),
    dict(tag='no_pooling_no_channel_aug_small', use_pool=False,
         do_channel_augmentation=False, model_fn=md.get_unet),
    dict(tag='kaist', do_channel_augmentation=False,
         model_fn=md.get_kaist_unet),
]

# Full list, first group followed by second group
model_params = [*model_params_1, *model_params_2]

In [3]:
# Load each saved model and its pickled loss record from the models/ directory.
from keras.models import load_model

# One (model_param, model, loss_dict) tuple per entry of model_params.
results = []

for model_param in model_params:
    print(model_param)
    save_path_model = 'models/' + model_param['tag'] + '.h5'
    save_path_loss_object = 'models/' + model_param['tag'] + '_loss' + '.pkl'

    model = load_model(save_path_model)

    # Use a context manager so the file handle is closed as soon as the loss
    # record has been read, instead of being left open until garbage collection.
    # NOTE: pickle.load can execute arbitrary code -- only load files you trust.
    with open(save_path_loss_object, 'rb') as loss_file:
        # latin1 is required if loading obj from python2
        loss_dict = pickle.load(loss_file, encoding='latin1')

    results.append((model_param, model, loss_dict))

Using TensorFlow backend.


{'tag': 'pooling_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': True, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
{'tag': 'no_pooling_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': True, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
{'tag': 'pooling_no_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': False, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
{'tag': 'no_pooling_no_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': False, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
{'tag': 'kaist', 'do_channel_augmentation': False, 'model_fn': <function get_kaist_unet at 0x7f4f4625ac80>}


In [4]:
# Plot the log10 train and test losses of every loaded model.

# Spacing on the epoch axis between consecutive entries of 'test_losses'.
# NOTE(review): presumably the test loss was recorded every 9 epochs during
# training -- confirm against the training script.
TEST_LOSS_EPOCH_INTERVAL = 9

# Curve labels, in the order the curves are drawn below.
legend = [model_param['tag'] for model_param, _, _ in results]

# plot train loss
fig = plt.figure()
for model_param, model, loss_dict in results:
    plt.plot(np.log10(loss_dict['train_losses_epoch']))

plt.legend(legend)
plt.title('train log loss')
plt.ylim([-5, -3])
plt.xlabel('Epoch Number')
# The curves are log10 of the loss, so say so on the axis.
plt.ylabel('log10 Mean Squared Error')
plt.show()


# plot test loss
fig = plt.figure()
for model_param, model, loss_dict in results:
    test_losses = loss_dict['test_losses']
    test_epochs = np.arange(0, len(test_losses)) * TEST_LOSS_EPOCH_INTERVAL
    plt.plot(test_epochs, np.log10(test_losses))

plt.legend(legend)
plt.title('test log loss')
plt.ylim([-5, -3])
plt.xlabel('Epoch Number')
plt.ylabel('log10 Mean Squared Error')
plt.show()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
# predict test images and calculate ssim
# Newer scikit-image releases provide these metrics in skimage.metrics and no
# longer ship the skimage.measure.compare_* names, so try the new location
# first and fall back to the old one. The local names stay `ssim` and
# `compare_mse` either way.
try:
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import mean_squared_error as compare_mse
except ImportError:  # older scikit-image without skimage.metrics
    from skimage.measure import compare_ssim as ssim
    from skimage.measure import compare_mse as compare_mse

def calc_volume_ssim(volume1, volume2):
    """Compute the SSIM between corresponding slices of two volumes.

    Slices are taken along the first axis, and singleton dimensions of each
    slice are squeezed away before the comparison.

    Args:
        volume1: array whose first axis indexes the slices; its length sets
            how many slices are compared.
        volume2: array indexed the same way as ``volume1``.

    Returns:
        A list with one SSIM value per slice, in slice order. SSIM is
        evaluated with ``data_range=1``.
    """
    # Build the list in a single pass instead of re-creating it with
    # ``list + [item]`` on every iteration.
    return [
        ssim(np.squeeze(volume1[slice_index, :, :]),
             np.squeeze(volume2[slice_index, :, :]),
             data_range=1)
        for slice_index in range(volume1.shape[0])
    ]

def calc_volume_mse(volume1, volume2):
    """Compute the mean squared error between corresponding slices of two volumes.

    Slices are taken along the first axis, and singleton dimensions of each
    slice are squeezed away before the comparison.

    Args:
        volume1: array whose first axis indexes the slices; its length sets
            how many slices are compared.
        volume2: array indexed the same way as ``volume1``.

    Returns:
        A list with one MSE value per slice, in slice order.
    """
    # Build the list in a single pass instead of re-creating it with
    # ``list + [item]`` on every iteration.
    return [
        compare_mse(np.squeeze(volume1[slice_index, :, :]),
                    np.squeeze(volume2[slice_index, :, :]))
        for slice_index in range(volume1.shape[0])
    ]

# Scan numbers passed to the test data generator.
testing_scans = [5]
# Collects (model_param, model, loss_dict, model_output, recon_metrics) tuples.
results_with_model_output = []

for model_param, model, loss_dict in results:
    print(model_param)

    # The generator is built with the channel-augmentation flag recorded for
    # this model.
    generator_test = dd.MRImageSequence(
        scan_numbers=testing_scans,
        batch_size=10,
        augment_channels=model_param['do_channel_augmentation'],
    )

    model_output = model.predict(generator_test.x_transformed[0], batch_size=10)

    # Per-slice quality metrics of the prediction against the target volume.
    reference_volume = generator_test.y_transformed[0]
    volume_ssim = calc_volume_ssim(model_output, reference_volume)
    volume_mse = calc_volume_mse(model_output, reference_volume)
    recon_metrics = {'ssim': volume_ssim, 'mse': volume_mse}

    results_with_model_output.append(
        (model_param, model, loss_dict, model_output, recon_metrics)
    )
    

{'tag': 'pooling_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': True, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
loading scan  5
X shape:  (320, 320, 256, 16)
y shape:  (320, 320, 256, 1)
augment_images:  False
{'tag': 'no_pooling_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': True, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
loading scan  5
X shape:  (320, 320, 256, 16)
y shape:  (320, 320, 256, 1)
augment_images:  False
{'tag': 'pooling_no_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': False, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
loading scan  5
X shape:  (320, 320, 256, 8)
y shape:  (320, 320, 256, 1)
augment_images:  False
{'tag': 'no_pooling_no_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': False, 'model_fn': <function get_unet at 0x7f4f4625ad08>}
loading scan  5
X shape:  (320, 320, 256, 8)
y shape:  (320, 320, 256, 1)
augment_images:  False
{'tag': 'kaist', 'do_channel_augmentat

In [6]:
# plot test images
slice_to_show = 120

# Read the volume shape from the stored predictions rather than from the
# `model_output` variable left behind by the evaluation loop, and size the
# canvases from the list that is actually iterated below.
nx, ny, nz, _ = results_with_model_output[-1][3].shape
n_models = len(results_with_model_output)

# The target comes from the generator created in the last iteration of the
# evaluation loop.
# NOTE(review): assumes the target volume is the same for every model -- confirm.
target_slice = np.squeeze(generator_test.y_transformed[0][slice_to_show, :, :])

# Predictions (and their errors) of all models are placed side by side.
compound_image = np.zeros((ny, nz * n_models))
compound_image_diff = np.zeros((ny, nz * n_models))
order = '| '

for idx, result in enumerate(results_with_model_output):
    # The fifth element is the metrics dict (keys 'ssim' and 'mse'); unused here.
    model_param, model, loss_dict, model_output, recon_metrics = result

    # Column range of this model inside the compound images
    z_min = idx * nz
    z_max = (idx + 1) * nz

    order = order + model_param['tag'] + ' | '

    predicted_slice = np.squeeze(model_output[slice_to_show, :, :])
    compound_image[:, z_min:z_max] = predicted_slice

    # Absolute error scaled by 10 so it is visible on the 0..1 display range
    diff = 10 * np.abs(predicted_slice - target_slice)
    compound_image_diff[:, z_min:z_max] = np.squeeze(diff)

fig = plt.figure(figsize=(10, 3))
plt.imshow(compound_image, cmap='gray')
plt.title('test predictions \n ' + order)
plt.axis('off')
plt.show()

fig = plt.figure(figsize=(10, 3))
plt.imshow(compound_image_diff, cmap='gray', vmin=0, vmax=1)
plt.title('test predictions diff x 10')
plt.axis('off')
plt.show()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
# 3D slice viewer from https://www.datacamp.com/community/tutorials/matplotlib-3d-volumetric-data
def remove_keymap_conflicts(new_keys_set):
    """Unbind the given keys from matplotlib's keyboard shortcuts.

    For every ``keymap.*`` entry of ``plt.rcParams``, each key that also
    appears in ``new_keys_set`` is removed from that entry's key list in place.

    Args:
        new_keys_set: set of key names to free up.
    """
    keymap_names = [name for name in plt.rcParams if name.startswith('keymap.')]
    for name in keymap_names:
        bound_keys = plt.rcParams[name]
        clashes = set(bound_keys) & new_keys_set
        for clash in clashes:
            bound_keys.remove(clash)

def multi_slice_viewer(volume):
    """Open a figure showing one slice of ``volume``, browsable with 'j' and 'k'.

    The volume and the current slice index are stored as attributes on the
    axes object, where the key handler reads them back.

    Args:
        volume: array indexed by slice along its first axis.
    """
    # Free 'j' and 'k' from matplotlib's own shortcuts before binding them.
    remove_keymap_conflicts({'j', 'k'})

    fig = plt.figure(figsize=(9, 3))
    ax = plt.gca()

    middle = volume.shape[0] // 2  # start in middle slice
    ax.volume, ax.index = volume, middle
    ax.imshow(volume[middle], cmap='gray', vmin=0, vmax=1)
    update_ax(ax)

    plt.axis('off')
    fig.canvas.mpl_connect('key_press_event', process_key)
    
def process_key(event):
    """Key-press callback: 'j' shows the previous slice, 'k' the next one.

    The figure's first axes is the one that gets updated; the canvas is
    redrawn after every key press, whether or not the key was handled.
    """
    fig = event.canvas.figure
    ax = fig.axes[0]

    handlers = {'j': previous_slice, 'k': next_slice}
    handler = handlers.get(event.key)
    if handler is not None:
        handler(ax)

    fig.canvas.draw()

def update_ax(ax):
    """Redisplay the slice selected by ``ax.index``.

    The index is first wrapped into the valid range with a modulo on the
    number of slices, then the first image on the axes and the title are
    updated to match.
    """
    n_slices = ax.volume.shape[0]
    ax.index %= n_slices
    current = ax.index
    ax.images[0].set_array(ax.volume[current])
    ax.set_title('Slice: {}'.format(current))
    
def previous_slice(ax):
    """Step back one slice and refresh the axes.

    The index is only decremented here; update_ax wraps it into range.
    """
    ax.index -= 1
    update_ax(ax)

def next_slice(ax):
    """Step forward one slice and refresh the axes.

    The index is only incremented here; update_ax wraps it into range.
    """
    ax.index += 1
    update_ax(ax)

In [8]:
''' Plot 3d scroller viewer and diff images '''
# Canvases holding, side by side along the last axis, every model's predicted
# volume and its x10 absolute error against the target.
canvas_shape = (nx, ny, nz * len(results))
compound_volume = np.zeros(canvas_shape)
compound_volume_diff = np.zeros(canvas_shape)

for idx, (model_param, model, loss_dict, model_output, volume_ssim) in enumerate(results_with_model_output):
    # Range along the last axis that belongs to this model
    columns = slice(idx * nz, (idx + 1) * nz)

    prediction = np.squeeze(model_output[:, :, :])
    compound_volume[:, :, columns] = prediction

    target = np.squeeze(generator_test.y_transformed[0][:, :, :])
    scaled_error = 10 * np.abs(prediction - target)
    compound_volume_diff[:, :, columns] = np.squeeze(scaled_error)

In [9]:
# Join the prediction and diff volumes along axis 1 and browse them together.
viewer_volume = np.concatenate((compound_volume, compound_volume_diff), axis=1)
multi_slice_viewer(viewer_volume)

<IPython.core.display.Javascript object>

In [10]:
# Plot the per-slice SSIM and MSE of every model.

# Rebuild the legend from the results being plotted. Appending to the `legend`
# list left over from the loss plots duplicated its entries and made it grow
# each time this cell was re-run.
legend = [model_param['tag'] for model_param, _, _, _, _ in results_with_model_output]

# Plot ssim
fig = plt.figure()
for result in results_with_model_output:
    model_param, model, loss_dict, model_output, recon_metrics = result
    plt.plot(recon_metrics['ssim'])

plt.legend(legend)
plt.title('SSIM')
plt.xlim([0,  nx])
plt.xlabel('Slice')
plt.ylabel('SSIM')
plt.show()

# Plot mse (on a log10 scale)
fig = plt.figure()
for result in results_with_model_output:
    model_param, model, loss_dict, model_output, recon_metrics = result
    plt.plot(np.log10(recon_metrics['mse']))

plt.legend(legend)
plt.title('MSE')
plt.xlim([0,  nx])
plt.xlabel('Slice')
plt.ylabel('log10 MSE')
plt.show()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>