In [15]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import dataset as dd # custom dataset class
import models as md

# so that when you change an imported file, it changes in the notebook
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

# Cleaning up variables to prevent loading data multiple times (which may cause memory issue)
try:
   del X_train, y_train
   del X_test, y_test
   print('Clear previously loaded data.')
except:
   pass

do_channel_augmentation = True
im_ref_test, im_us_test = dd.get_dataset(4)

if(do_channel_augmentation == True):
    im_us_test_aug = dd.augment_channel_image(im_us_test)
    X_test, y_test = (im_us_test_aug, im_ref_test)
else:
    X_test, y_test = (im_us_test, im_ref_test)




The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
('loading scan ', 4)


In [16]:
# Sanity-check the loaded test arrays (training data is streamed by a Sequence instead).
for _label, _array in (('Test data shape: ', X_test), ('Test labels shape: ', y_test)):
    print(_label, _array.shape)

('Test data shape: ', (320, 320, 256, 16))
('Test labels shape: ', (320, 320, 256, 1))


In [22]:
# Keras Input shapes exclude the leading sample axis, so use the per-sample shape
# of the test array (all axes after the first).
input_shape = X_test.shape[1:]
inputs = tf.keras.layers.Input(shape=input_shape)

# Network body comes from the project's models module.
out = md.get_very_small_unet(inputs, use_pool=True)
model = tf.keras.models.Model(inputs=inputs, outputs=out)


get_very_small_unet
('use_pool:', True)
('gen_fn: ', 'gen_conv_relu')


In [23]:
## Loss-tracking callback, adapted from the example at https://keras.io/callbacks/
class LossHistory(tf.keras.callbacks.Callback):
    """Record training losses per batch and per epoch, plus a periodic test loss.

    Lists (all reset at the start of every training run):
        train_losses_batch: `logs['loss']` after each training batch.
        train_losses_epoch: `logs['loss']` after each epoch.
        test_losses: loss from `self.model.evaluate(*test_data)`, recorded at
            epoch indices 0, test_every, 2 * test_every, ...
    """

    def __init__(self, test_data=None, test_every=10):
        """
        Args:
            test_data: optional `(x, y)` pair passed to `model.evaluate`; when None,
                no test loss is recorded.
            test_every: evaluate on `test_data` whenever the epoch index is a
                multiple of this value (default 10, the previous fixed interval).

        Raises:
            ValueError: if `test_every` is less than 1.
        """
        # Run the base Callback initializer so any attributes it sets up exist.
        super(LossHistory, self).__init__()
        if test_every < 1:
            raise ValueError('test_every must be >= 1, got %r' % (test_every,))
        self.test_data = test_data
        self.test_every = test_every

    def on_train_begin(self, logs=None):
        """Start fresh loss lists for this training run."""
        self.train_losses_batch = []
        self.train_losses_epoch = []
        self.test_losses = []

    def on_batch_end(self, batch, logs=None):
        """Append the batch training loss (None if `logs` has no 'loss')."""
        logs = logs or {}
        self.train_losses_batch.append(logs.get('loss'))

    def on_epoch_end(self, epochs, logs=None):
        """Append the epoch training loss and, every `test_every` epochs, the test loss.

        Args:
            epochs: index of the epoch that just finished (as passed by Keras).
            logs: metric dict for the epoch; may be None.
        """
        logs = logs or {}
        self.train_losses_epoch.append(logs.get('loss'))

        if self.test_data is not None and epochs % self.test_every == 0:
            x, y = self.test_data
            result = self.model.evaluate(x, y, verbose=0)
            # evaluate returns [loss, *metrics] when metrics were compiled and a
            # bare scalar loss otherwise; the loss is always first.
            loss = result[0] if isinstance(result, (list, tuple)) else result
            self.test_losses.append(loss)
            
        
        

history_callback = LossHistory(test_data=(X_test, y_test))

# TensorBoard log directory: set TB_LOG_DIR to override. The default expands to
# the current user's ~/tmp/ instead of a hard-coded /home/<user>/tmp/ path.
tb_log_dir = os.environ.get('TB_LOG_DIR', os.path.expanduser('~/tmp/'))
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

# NOTE(review): `lr`/`decay` are the legacy optimizer arguments used by this TF
# version; newer optimizer APIs rename/remove them — revisit when upgrading TF.
adam_optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=0.01)

model.compile(optimizer=adam_optimizer, loss='mean_squared_error', metrics=['mse'])

In [24]:
# Print the model's layer-by-layer architecture and parameter counts.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            (None, 320, 256, 16) 0                                            
__________________________________________________________________________________________________
conv2d_29 (Conv2D)              (None, 320, 256, 16) 2320        input_7[0][0]                    
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 320, 256, 16) 0           conv2d_29[0][0]                  
__________________________________________________________________________________________________
conv2d_30 (Conv2D)              (None, 320, 256, 16) 2320        activation_29[0][0]              
__________________________________________________________________________________________________
activation

In [25]:
# Training data for fit_generator: scan 1 only, batches of 10, using the same
# channel-augmentation setting as the test set.
# (An earlier configuration used scan_numbers=[1, 2, 3] with augment_images=True.)
generator = dd.MRImageSequence(
    scan_numbers=[1],
    batch_size=10,
    augment_channels=do_channel_augmentation,
)

('loading scan ', 1)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('augment_images: ', False)


In [28]:
# Train for 100 epochs from the Sequence, recording losses and TensorBoard logs.
# Multiprocessing stays off: in our setup it was about 50% slower than model.fit().
# (A commented-out model.fit alternative trained directly on X_test — the held-out
# scan — and has been removed.)
model.fit_generator(
    generator,
    epochs=100,
    callbacks=[history_callback, tb_callback],
    use_multiprocessing=False,
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
 7/32 [=====>........................] - ETA: 14s - loss: 6.8163e-05 - mean_squared_error: 6.8163e-05

KeyboardInterrupt: 

In [None]:
# Loss curves on a log10 scale. Each curve gets its own figure and axes: with
# `%matplotlib notebook` (interactive mode), plt.show() may leave the previous
# figure current, so bare plt.plot calls could draw both curves on one plot.
train_log_loss = np.log10(history_callback.train_losses_epoch)
fig, ax = plt.subplots()
ax.plot(train_log_loss)
ax.set_title('train log loss')
ax.set_xlabel('Epoch Number')
ax.set_ylabel('log10(Mean Squared Error)')
plt.show()

# The test loss is recorded only every `test_every` epochs (10 by default), so plot
# each value at the epoch index where it was measured.
test_every = getattr(history_callback, 'test_every', 10)
test_log_loss = np.log10(history_callback.test_losses)
fig, ax = plt.subplots()
ax.plot(np.arange(len(test_log_loss)) * test_every, test_log_loss)
ax.set_title('test log loss')
ax.set_xlabel('Epoch Number (evaluated every %d epochs)' % test_every)
ax.set_ylabel('log10(Mean Squared Error)')
plt.show()

In [None]:
# Held-out scan: undersampled input, reference, and the model's prediction.
to_show_us_test, to_show_ref_test = X_test, y_test
pred_test = model.predict(to_show_us_test)

# First entry of the training Sequence's transformed data, for comparison with
# the held-out results. NOTE(review): x_transformed / y_transformed are attributes
# of the project's MRImageSequence — confirm what index 0 holds.
to_show_us_train = generator.x_transformed[0]
to_show_ref_train = generator.y_transformed[0]
pred_train = model.predict(to_show_us_train)


In [None]:
def show_images(slice_to_show, pred, ref, us, diff_scale=10):
    """Display one slice of a prediction next to its reference and input.

    Produces two figures:
      1. `pred | ref | us` side by side on a shared gray scale, where `us` is
         collapsed over its last axis with `dd.sos` (presumably a
         sum-of-squares channel combination — confirm in the dataset module).
      2. `|ref - pred| * diff_scale` on a fixed [0, 1] gray scale.

    Args:
        slice_to_show: index into the first axis of `pred`, `ref`, and `us`.
        pred, ref: 4-D arrays indexed as [slice, row, col, channel]; channel 0 is shown.
        us: 4-D array indexed the same way; all channels are combined for display.
        diff_scale: factor applied to the absolute difference before display
            (default 10, as before). Scaled values above 1 render as white.
    """
    pred_im = pred[slice_to_show, :, :, 0]
    ref_im = ref[slice_to_show, :, :, 0]
    us_im = dd.sos(us[slice_to_show, :, :, :], axis=2)

    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.hstack((pred_im, ref_im, us_im)), cmap='gray')
    ax.set_title('pred | ref | us')
    ax.axis('off')
    plt.show()

    _, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(np.abs(ref_im - pred_im) * diff_scale, cmap='gray', vmin=0, vmax=1)
    ax.set_title('diff x%g' % diff_scale)
    ax.axis('off')
    plt.show()

In [None]:
# Slice index (first array axis) used for the visual comparisons.
slice_to_show = 100
show_images(slice_to_show, pred=pred_test, ref=to_show_ref_test, us=to_show_us_test)

In [None]:
show_images(slice_to_show, pred=pred_train, ref=to_show_ref_train, us=to_show_us_train)

In [None]:
#model.save('models/very_small_unet_no_aug_kernel_1_3_no_pooling.h5')