In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import pickle as pickle # for saving loss objects

import dataset as dd # custom dataset class
import models as md

# so that when you change an imported file, it changes in the notebook
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

  from ._conv import register_converters as _register_converters


In [2]:
# Experiment grid: U-Net variants with and without pooling, trained with
# channel augmentation (model_params_1) and without it (model_params_2).
# The KAIST U-Net entry in model_params_2 carries no 'use_pool' key.
model_params_1 = [
    dict(tag=tag, use_pool=use_pool, do_channel_augmentation=True, model_fn=md.get_unet)
    for tag, use_pool in [('pooling_channel_aug_small', True),
                          ('no_pooling_channel_aug_small', False)]
]

model_params_2 = [
    dict(tag=tag, use_pool=use_pool, do_channel_augmentation=False, model_fn=md.get_unet)
    for tag, use_pool in [('pooling_no_channel_aug_small', True),
                          ('no_pooling_no_channel_aug_small', False)]
] + [dict(tag='kaist', do_channel_augmentation=False, model_fn=md.get_kaist_unet)]

# Augmented variants first, then the non-augmented ones (including KAIST).
model_params = model_params_1 + model_params_2

In [None]:
# cell for going backwards (loading data): reload each trained model and its
# pickled loss history from models/<tag>.h5 and models/<tag>_loss.pkl
from keras.models import load_model

# one (model_param, model, loss_dict) tuple per entry of model_params
results = []

for model_param in model_params:
    print(model_param)
    save_path_model = 'models/' + model_param['tag'] + '.h5'
    save_path_loss_object = 'models/' + model_param['tag'] + '_loss' + '.pkl'

    model = load_model(save_path_model)
    # Context manager closes the handle as soon as the object is read.
    # NOTE: pickle.load can execute arbitrary code -- only load loss files
    # produced by this project's own training runs.
    with open(save_path_loss_object, 'rb') as loss_file:
        loss_dict = pickle.load(loss_file)

    results.append((model_param, model, loss_dict))

Using TensorFlow backend.


{'tag': 'pooling_channel_aug_small', 'do_channel_augmentation': True, 'use_pool': True, 'model_fn': <function get_unet at 0x7f57fa572cf8>}
{'tag': 'no_pooling_channel_aug_small', 'do_channel_augmentation': True, 'use_pool': False, 'model_fn': <function get_unet at 0x7f57fa572cf8>}
{'tag': 'pooling_no_channel_aug_small', 'do_channel_augmentation': False, 'use_pool': True, 'model_fn': <function get_unet at 0x7f57fa572cf8>}
{'tag': 'no_pooling_no_channel_aug_small', 'do_channel_augmentation': False, 'use_pool': False, 'model_fn': <function get_unet at 0x7f57fa572cf8>}
{'tag': 'kaist', 'do_channel_augmentation': False, 'model_fn': <function get_kaist_unet at 0x7f57fa572c80>}


In [None]:
# plot train and test loss curves on a log10 scale, one figure each

# NOTE(review): the test-loss x axis was scaled by 9 in the original plot,
# i.e. it assumes test loss was recorded once every 9 epochs -- confirm
# against the training loop.
TEST_LOSS_EPOCH_STRIDE = 9


def _plot_log_loss_curves(results, loss_key, title, epoch_stride=1):
    """Plot log10 of one loss history per model on a newly created figure.

    The figure is created explicitly: with an interactive backend such as
    ``%matplotlib notebook`` the previous figure can remain the current one
    after ``plt.show()``, so bare ``plt.plot`` calls could draw onto it.

    Args:
        results: list of (model_param, model, loss_dict) tuples.
        loss_key: key in each loss_dict selecting the loss sequence to plot.
        title: figure title.
        epoch_stride: number of epochs between consecutive loss entries;
            scales the x axis so it reads in epochs.
    """
    fig, ax = plt.subplots()
    for model_param, _, loss_dict in results:
        losses = loss_dict[loss_key]
        ax.plot(np.arange(len(losses)) * epoch_stride, np.log10(losses),
                label=model_param['tag'])
    ax.legend()
    ax.set_title(title)
    ax.set_ylim([-5, -3])
    ax.set_xlabel('Epoch Number')
    # the plotted values are log10-transformed, so the label says so
    ax.set_ylabel('log10(Mean Squared Error)')
    plt.show()


_plot_log_loss_curves(results, 'train_losses_epoch', 'train log loss')
_plot_log_loss_curves(results, 'test_losses', 'test log loss',
                      epoch_stride=TEST_LOSS_EPOCH_STRIDE)

In [None]:
# predict test images with every loaded model
testing_scans = [6]
PREDICT_BATCH_SIZE = 10
# one (model_param, model, loss_dict, model_output) tuple per loaded model
results_with_model_output = []

# NOTE: after this loop `generator_test` and `model_output` stay bound to the
# LAST model's values; later cells read `generator_test` (and possibly
# `model_output`) from here.
for model_param, model, loss_dict in results:
    print(model_param)

    # rebuilt per model because augment_channels follows each model's setting
    generator_test = dd.MRImageSequence(
        scan_numbers=testing_scans,
        batch_size=PREDICT_BATCH_SIZE,
        augment_channels=model_param['do_channel_augmentation'])

    # NOTE(review): only x_transformed[0] is predicted -- presumably it holds
    # the whole test scan; confirm against dataset.MRImageSequence.
    model_output = model.predict(generator_test.x_transformed[0],
                                 batch_size=PREDICT_BATCH_SIZE)

    # append in place instead of rebuilding the list on every iteration
    results_with_model_output.append((model_param, model, loss_dict, model_output))
    

In [None]:
# plot test images: the same slice from every model side by side, plus the
# amplified absolute error against the ground truth
slice_to_show = 120
diff_scale = 10  # multiplies |prediction - target| so small errors are visible

# Take the shape from the first stored prediction rather than whichever
# `model_output` the prediction loop left behind; all models are assumed to
# produce outputs of the same shape -- TODO confirm.
nx, ny, nz, _ = results_with_model_output[0][3].shape
n_models = len(results_with_model_output)

compound_image = np.zeros((ny, nz * n_models))
compound_image_diff = np.zeros((ny, nz * n_models))
order = '| '

# NOTE(review): `generator_test` is the generator built for the LAST model in
# the prediction loop. This assumes y_transformed does not depend on that
# model's augment_channels flag -- confirm in dataset.MRImageSequence.
# The target slice is identical for every model, so compute it once.
target_slice = np.squeeze(generator_test.y_transformed[0][slice_to_show, :, :])

for idx, (model_param, model, loss_dict, model_output) in enumerate(results_with_model_output):
    # each model occupies its own nz-wide column band
    z_min = idx * nz
    z_max = (idx + 1) * nz

    order = order + model_param['tag'] + ' | '

    predicted_slice = np.squeeze(model_output[slice_to_show, :, :])
    compound_image[:, z_min:z_max] = predicted_slice
    compound_image_diff[:, z_min:z_max] = diff_scale * np.abs(predicted_slice - target_slice)

fig = plt.figure(figsize=(20, 20))
plt.imshow(compound_image, cmap='gray')
plt.title('test predictions \n ' + order)
plt.axis('off')
plt.show()

fig = plt.figure(figsize=(20, 20))
plt.imshow(compound_image_diff, cmap='gray', vmin=0, vmax=1)
plt.title('test predictions diff x ' + str(diff_scale))
plt.axis('off')
plt.show()

In [None]:
# 3D slice viewer from https://www.datacamp.com/community/tutorials/matplotlib-3d-volumetric-data
def remove_keymap_conflicts(new_keys_set):
    """Unbind matplotlib's default shortcuts for every key in ``new_keys_set``.

    Each ``keymap.*`` entry of ``plt.rcParams`` has the clashing keys removed
    in place, so a custom key_press_event handler can own those keys.
    """
    keymap_names = (name for name in plt.rcParams if name.startswith('keymap.'))
    for keymap_name in keymap_names:
        bound_keys = plt.rcParams[keymap_name]
        for clashing_key in set(bound_keys) & new_keys_set:
            bound_keys.remove(clashing_key)

def multi_slice_viewer(volume, **imshow_kwargs):
    """Open an interactive figure that pages through ``volume`` along axis 0.

    Press 'u' for the previous slice and 'i' for the next (handled by
    process_key); the index wraps around at both ends. Matplotlib's own
    bindings for those two keys are removed first.

    Args:
        volume: array indexed as ``volume[i]`` to get one displayable slice.
        **imshow_kwargs: forwarded to ``Axes.imshow`` (e.g. ``cmap='gray'``,
            ``vmin``/``vmax``). With none given the display is unchanged.
    """
    remove_keymap_conflicts({'u', 'i'})
    fig, ax = plt.subplots()
    # viewer state is stored on the Axes so the key handler can reach it
    ax.volume = volume
    ax.index = volume.shape[0] // 2  # start at the middle slice
    ax.imshow(volume[ax.index], **imshow_kwargs)
    fig.canvas.mpl_connect('key_press_event', process_key)

def process_key(event):
    """Key handler for the slice viewer: 'u' steps back, 'i' steps forward.

    Acts on the first Axes of the figure that received the event and always
    redraws the canvas afterwards, whatever key was pressed.
    """
    canvas = event.canvas
    target_ax = canvas.figure.axes[0]
    pressed = event.key
    if pressed in ('u', 'i'):
        step = previous_slice if pressed == 'u' else next_slice
        step(target_ax)
    canvas.figure.canvas.draw()

def previous_slice(ax):
    """Step ``ax`` back one slice through ``ax.volume``; slice 0 wraps to the last."""
    stack = ax.volume
    slice_count = stack.shape[0]
    ax.index = (ax.index - 1) % slice_count
    displayed_image = ax.images[0]
    displayed_image.set_array(stack[ax.index])

def next_slice(ax):
    """Advance ``ax`` one slice through ``ax.volume``; the last slice wraps to 0."""
    stack = ax.volume
    new_index = (ax.index + 1) % stack.shape[0]
    ax.index = new_index
    ax.images[0].set_array(stack[new_index])

In [None]:
# Stack every model's full predicted volume side by side along the last axis
# (model order matches results_with_model_output) for the slice viewer.
# Shape is derived here so this cell does not depend on the previous one;
# all outputs are assumed to share one shape -- TODO confirm.
nx, ny, nz, _ = results_with_model_output[0][3].shape
compound_volume = np.zeros((nx, ny, nz * len(results_with_model_output)))

for idx, (model_param, model, loss_dict, model_output) in enumerate(results_with_model_output):
    # each model occupies its own nz-wide band along the last axis
    z_min = idx * nz
    z_max = (idx + 1) * nz

    compound_volume[:, :, z_min:z_max] = np.squeeze(model_output)

In [None]:

# interactive viewer over all models' predictions: 'u' = previous slice, 'i' = next slice
multi_slice_viewer(volume=compound_volume)