In [17]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import pickle as pickle # for saving loss objects

import dataset as dd # custom dataset class
import models as md

# so that when you change an imported file, it changes in the notebook
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Notebook cell for running all models we are interested in for the spotlight talk,
# each trained for the same number of epochs. For every variant this cell:
#   1. builds the train / test generators,
#   2. builds and compiles the network,
#   3. trains it,
#   4. saves the model to models/<tag>.h5 and its loss curves to models/<tag>_loss.pkl.
epochs_to_train = 10
batch_size = 10

# The train / test split is identical for every variant, so define it once.
training_scans = [1, 2, 3, 4]
testing_scans = [5]

# Layer specification handed to every model builder.
unet_shape = [(2, 32), (3, 64)]

# Root directory for TensorBoard event files. Each variant logs into its own
# sub-directory (named after its tag) so the runs are not mixed together.
tensorboard_log_root = '/home/pkllee/tmp/'

# One dict per model variant:
#   'tag'                     : name used in printed output and in saved file names
#   'model_fn'                : builder, called as model_fn(inputs, unet_shape, ...)
#   'do_channel_augmentation' : forwarded to the generators as augment_channels
#   'use_pool'                : OPTIONAL; forwarded to model_fn only when present
model_params = [{'tag': 'pooling_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': True,
                 'model_fn': md.get_unet},
                {'tag': 'no_pooling_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': True,
                 'model_fn': md.get_unet},
                {'tag': 'pooling_no_channel_aug_small', 'use_pool': True, 'do_channel_augmentation': False,
                 'model_fn': md.get_unet},
                {'tag': 'no_pooling_no_channel_aug_small', 'use_pool': False, 'do_channel_augmentation': False,
                 'model_fn': md.get_unet},
                {'tag': 'kaist', 'do_channel_augmentation': False, 'model_fn': md.get_kaist_unet}
               ]

for model_param in model_params:

    tag = model_param['tag']
    print(tag)

    # make generators (image augmentation is only requested for the training data)
    generator_train = dd.MRImageSequence(scan_numbers=training_scans, batch_size=batch_size,
                                         augment_channels=model_param['do_channel_augmentation'],
                                         augment_images=True)
    generator_test = dd.MRImageSequence(scan_numbers=testing_scans, batch_size=batch_size,
                                        augment_channels=model_param['do_channel_augmentation'])

    # make model; the input shape is taken from the first test array with its leading axis dropped
    input_shape = generator_test.x_transformed[0].shape[1:]
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Not every variant defines 'use_pool' (the 'kaist' entry does not), so only
    # forward it when it is configured instead of indexing the dict unconditionally,
    # which raised a KeyError after the earlier variants had already been trained.
    # NOTE(review): assumes md.get_kaist_unet works without a use_pool argument -- confirm.
    model_kwargs = {}
    if 'use_pool' in model_param:
        model_kwargs['use_pool'] = model_param['use_pool']

    out = model_param['model_fn'](inputs, unet_shape, **model_kwargs)
    model = tf.keras.models.Model(inputs=inputs, outputs=out)

    # make callbacks
    history_callback = dd.LossHistory(test_data=(generator_test.x_transformed[0], generator_test.y_transformed[0]))
    # per-variant log directory: <root><tag>
    tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_root + tag)

    adam_optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=0.01)

    model.compile(optimizer=adam_optimizer, loss='mean_squared_error', metrics=['mse'])

    # train model
    # use_multiprocessing=True is slower by about 50% compared to model.fit() so set it to False
    model.fit_generator(generator_train, callbacks=[history_callback, tb_callback], epochs=epochs_to_train,
                        use_multiprocessing=False)

    # save model
    save_path_model = 'models/' + tag + '.h5'
    print(save_path_model)
    model.save(save_path_model)

    # save loss object (plain dict of the histories recorded by the callback)
    loss_dict = {'train_losses_batch': history_callback.train_losses_batch,
                 'train_losses_epoch': history_callback.train_losses_epoch,
                 'test_losses': history_callback.test_losses}

    save_path_loss_object = 'models/' + tag + '_loss' + '.pkl'
    with open(save_path_loss_object, 'wb') as output:
        pickle.dump(loss_dict, output, pickle.HIGHEST_PROTOCOL)


pooling_channel_aug_small
('loading scan ', 1)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('augment_images: ', True)
('loading scan ', 5)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('augment_images: ', False)
get_unet
('use_pool: ', True)
('gen_fn: ', 'gen_conv_relu')
('unet_shape: ', [(2, 32), (3, 64)])
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
models/pooling_channel_aug_small.h5
no_pooling_channel_aug_small
('loading scan ', 1)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('augment_images: ', True)
('loading scan ', 5)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('augment_images: ', False)
get_unet
('use_pool: ', False)
('gen_fn: ', 'gen_conv_relu')
('unet_shape: ', [(2, 32), (3, 64)])
Epoch 1/10
 5/32 [===>..........................] - ETA: 21s - loss: 8.5162e-04 - mean_squared_error: 8.5162e-04