In [1]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import dataset as dd # custom dataset class

# Reload imported modules (e.g. dataset.py) before each cell runs, so edits to them take effect without restarting the kernel
%load_ext autoreload 
%autoreload 2
%matplotlib notebook

# Drop arrays left over from a previous run of this cell so that re-running it does
# not keep an old copy of the (large) data alive while the new one is being loaded.
def clear_loaded_data(namespace, names=('X_train', 'y_train', 'X_test', 'y_test')):
    """Delete every name in `names` that exists in `namespace` (e.g. ``globals()``).

    Each name is handled independently. A missing name does not stop the others
    from being freed; this matters because X_train is never defined while the
    training-data lines below are commented out.

    Returns:
        The list of names that were actually removed, in `names` order.
    """
    removed = []
    for name in names:
        if name in namespace:
            del namespace[name]
            removed.append(name)
    return removed


if clear_loaded_data(globals()):
    print('Clear previously loaded data.')

#im_ref, im_us = dd.get_dataset(1)
#im_us_aug = dd.augment_channel_image(im_us)
#X_train, y_train = (im_us_aug,  im_ref)
#X_train, y_train = (im_us,  im_ref)

# When set, dd.augment_channel_image is applied to the undersampled input; the same
# flag is passed to the training generator further down (augment_channels=...).
do_channel_augmentation = True
# Scan 4 is the held-out test scan. Judging by the names, get_dataset returns
# (reference, undersampled) images -- TODO confirm against dataset.py.
im_ref_test, im_us_test = dd.get_dataset(4)

# Test the flag's truthiness directly (PEP 8) instead of comparing with `== True`.
if do_channel_augmentation:
    X_test = dd.augment_channel_image(im_us_test)
else:
    X_test = im_us_test
y_test = im_ref_test




  from ._conv import register_converters as _register_converters


('loading scan ', 4)


In [2]:
# Training batches come from the generator built later, so only the test arrays exist here.
for label, array in [('Test data shape: ', X_test), ('Test labels shape: ', y_test)]:
    print(label, array.shape)

('Test data shape: ', (320, 320, 256, 16))
('Test labels shape: ', (320, 320, 256, 1))


In [71]:
from unetblocks import res_block, gen_conv_relu, gen_conv_bn_relu

# One training example = one slice of the test array (drop the leading slice axis).
input_shape = X_test.shape[1:]
inputs = tf.keras.layers.Input(shape=input_shape)


def gen_conv_params(num_filters):
    """Return Conv kwargs: `num_filters` filters, 3x3 kernel, stride 1, 'same' padding."""
    return {'filters': num_filters, 'kernel_size': (3, 3), 'strides': (1, 1), 'padding': 'same'}


f16 = gen_conv_params(16)
f32 = gen_conv_params(32)
f64 = gen_conv_params(64)
f128 = gen_conv_params(128)
f256 = gen_conv_params(256)
f512 = gen_conv_params(512)
f1024 = gen_conv_params(1024)

# Conv-block factory handed to res_block; gen_conv_bn_relu (imported above) is the alternative.
gen_fn = gen_conv_relu


def very_small_unet_no_pooling(x):
    """Outer 16-filter res_block built with use_pool=False around one 32-filter res_block."""
    return res_block(gen_fn(N=2, **f16), gen_fn(N=2, **f16), use_pool=False)(
        x, resblocks=[res_block(gen_fn(N=3, **f32))])


def very_small_unet(x):
    """The 'with pooling' variant: as above, but res_block is called without use_pool."""
    return res_block(gen_fn(N=2, **f16), gen_fn(N=2, **f16))(
        x, resblocks=[res_block(gen_fn(N=3, **f32))])


def small_unet(x):
    """Outer 16-filter res_block with 32-, 64- and 128-filter inner res_blocks."""
    return res_block(gen_fn(N=2, **f16), gen_fn(N=2, **f16))(
        x, resblocks=[res_block(gen_fn(N=3, **f32), gen_fn(N=3, **f32)),
                      res_block(gen_fn(N=3, **f64), gen_fn(N=3, **f64)),
                      res_block(gen_fn(N=3, **f128))])


def big_unet(x):
    """Outer 32-filter res_block with 64-, 128-, 256- and 512-filter inner res_blocks."""
    return res_block(gen_fn(N=2, **f32), gen_fn(N=2, **f32))(
        x, resblocks=[res_block(gen_fn(N=2, **f64), gen_fn(N=2, **f64)),
                      res_block(gen_fn(N=2, **f128), gen_fn(N=2, **f128)),
                      res_block(gen_fn(N=2, **f256), gen_fn(N=2, **f256)),
                      res_block(gen_fn(N=2, **f512))])


# Pick the architecture by name instead of toggling string-literal "comments";
# the default reproduces the previously active configuration.
UNET_ARCHITECTURE = 'very_small_unet'
unet_builders = {
    'very_small_unet_no_pooling': very_small_unet_no_pooling,
    'very_small_unet': very_small_unet,
    'small_unet': small_unet,
    'big_unet': big_unet,
}
res_out = unet_builders[UNET_ARCHITECTURE](inputs)

# Assuming res_out is (batch, H, W, C), Dense(1) acts on the last axis only, i.e. a
# per-pixel linear map to the single output channel the labels have (y_test is (..., 1)).
out = tf.keras.layers.Dense(1)(res_out)
model = tf.keras.models.Model(inputs=inputs, outputs=out)
## Adapted from the LossHistory example at https://keras.io/callbacks/
class LossHistory(tf.keras.callbacks.Callback):
    """Record training losses per batch and per epoch, plus periodic test-set losses.

    Lists (re-created at the start of every training run):
        train_losses_batch: logs['loss'] after each batch (None if absent).
        train_losses_epoch: logs['loss'] after each epoch (None if absent).
        test_losses: loss from `model.evaluate` on `test_data`, recorded on each
            epoch whose index is a multiple of `test_every` (0, 10, 20, ... by default).
    """

    def __init__(self, test_data=None, test_every=10):
        """
        Args:
            test_data: optional `(x, y)` pair evaluated periodically; None disables it.
            test_every: positive int; evaluate when the epoch index is divisible by it.
        """
        # Initialise base-class state before adding our own attributes.
        super(LossHistory, self).__init__()
        self.test_data = test_data
        self.test_every = test_every

    def on_train_begin(self, logs=None):
        self.train_losses_batch = []
        self.train_losses_epoch = []
        self.test_losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.train_losses_batch.append(logs.get('loss'))

    def on_epoch_end(self, epochs, logs=None):
        # Despite its name, `epochs` is the index of the epoch that just ended
        # (parameter name kept for backward compatibility).
        logs = logs or {}
        self.train_losses_epoch.append(logs.get('loss'))

        if self.test_data is not None and epochs % self.test_every == 0:
            x, y = self.test_data
            result = self.model.evaluate(x, y, verbose=0)
            # evaluate() returns [loss, *metrics] when metrics are compiled in, else a scalar.
            loss = result[0] if isinstance(result, (list, tuple)) else result
            self.test_losses.append(loss)
            
        
        

history_callback = LossHistory(test_data=(X_test, y_test))

# TensorBoard log directory: set the TB_LOG_DIR environment variable to override the
# user-specific default instead of editing this cell.
tb_log_dir = os.environ.get('TB_LOG_DIR', '/home/pkllee/tmp/')
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_log_dir)

# Previously tried: Adam(lr=0.1, beta_1=0.9, beta_2=0.999, decay=0.1).
adam_optimizer = tf.keras.optimizers.Adam(lr=0.001, decay=0.01)

# MSE is both the loss and the one compiled metric, so model.evaluate returns [loss, mse].
model.compile(optimizer=adam_optimizer, loss='mean_squared_error', metrics=['mse'])

In [72]:
# Layer-by-layer overview of the network: output shapes and parameter counts.
model.summary(line_length=None)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            (None, 320, 256, 16) 0                                            
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 320, 256, 16) 2320        input_9[0][0]                    
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 320, 256, 16) 0           conv2d_57[0][0]                  
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 320, 256, 16) 2320        activation_57[0][0]              
__________________________________________________________________________________________________
activation

In [73]:
# Training data: scans 1-3 (scan 4 is the held-out test scan loaded at the top).
TRAIN_SCANS = [1, 2, 3]
TRAIN_BATCH_SIZE = 10
generator = dd.MRImageSequence(scan_numbers=TRAIN_SCANS,
                               batch_size=TRAIN_BATCH_SIZE,
                               augment_channels=do_channel_augmentation)

('loading scan ', 1)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('loading scan ', 2)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))
('loading scan ', 3)
('X shape: ', (320, 320, 256, 16))
('y shape: ', (320, 320, 256, 1))


In [None]:
# Train from the scan generator; the in-memory model.fit(x=X_train, ...) path is unused
# because X_train is never loaded. history_callback records losses, tb_callback logs to TensorBoard.
train_callbacks = [history_callback, tb_callback]
model.fit_generator(generator,
                    epochs=1,
                    callbacks=train_callbacks,
                    use_multiprocessing=True)

Epoch 1/1
(2, 40, 50)
(2, 60, 70)
(1, 60, 70)
 2/96 [..............................] - ETA: 2:02 - loss: 0.0058 - mean_squared_error: 0.0058(0, 190, 200)
 3/96 [..............................] - ETA: 1:54 - loss: 0.0060 - mean_squared_error: 0.0060(2, 130, 140)
 4/96 [>.............................] - ETA: 1:46 - loss: 0.0055 - mean_squared_error: 0.0055(0, 180, 190)
 5/96 [>.............................] - ETA: 1:36 - loss: 0.0052 - mean_squared_error: 0.0052(2, 110, 120)
 6/96 [>.............................] - ETA: 1:30 - loss: 0.0048 - mean_squared_error: 0.0048(1, 120, 130)
 7/96 [=>............................] - ETA: 1:25 - loss: 0.0044 - mean_squared_error: 0.0044(0, 0, 10)
 8/96 [=>............................] - ETA: 1:21 - loss: 0.0043 - mean_squared_error: 0.0043(1, 40, 50)
 9/96 [=>............................] - ETA: 1:17 - loss: 0.0041 - mean_squared_error: 0.0041(0, 70, 80)
10/96 [==>...........................] - ETA: 1:14 - loss: 0.0038 - mean_squared_error: 0.0038(0,

In [None]:
# Training curve: one point per epoch (train_losses_epoch), shown on a log10 scale.
fig, ax = plt.subplots()
ax.plot(np.log10(history_callback.train_losses_epoch))
ax.set_title('train log loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('log10(Mean Squared Error)')
plt.show()

# Test curve: LossHistory evaluates the test set periodically (every 10 epochs by default).
fig, ax = plt.subplots()
ax.plot(np.log10(history_callback.test_losses))
ax.set_title('test log loss')
ax.set_xlabel('Evaluation index (one per 10 epochs)')
ax.set_ylabel('log10(Mean Squared Error)')
plt.show()

In [None]:
# Held-out scan: network output alongside its reference and undersampled input.
pred_test = model.predict(X_test)
to_show_ref_test, to_show_us_test = y_test, X_test

# Training data held by the generator (generator.x_ref / generator.y_ref).
pred_train = model.predict(generator.x_ref)
to_show_us_train = generator.x_ref
to_show_ref_train = generator.y_ref

# To report test MSE: np.mean((pred_test - to_show_ref_test) ** 2)

In [None]:
def show_images(slice_to_show, pred, ref, us, diff_scale=10):
    """Show one slice as prediction | reference | undersampled, then a scaled error map.

    Args:
        slice_to_show: index along the first axis of `pred`, `ref` and `us`.
        pred, ref: 4-D arrays; channel 0 of the selected slice is displayed.
        us: 4-D array; its selected slice is combined across the last axis with
            dd.sos(..., axis=2) before display.
        diff_scale: factor applied to |ref - pred| before it is drawn on a fixed
            [0, 1] grey scale (values above 1 saturate). Default 10 keeps the old output.
    """
    im_pred = pred[slice_to_show, :, :, 0]
    im_ref = ref[slice_to_show, :, :, 0]
    im_us = dd.sos(us[slice_to_show, :, :, :], axis=2)

    # Side-by-side panel; imshow normalises the grey map over all three images jointly.
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.hstack((im_pred, im_ref, im_us)), cmap='gray')
    ax.set_title('pred | ref | us')
    ax.axis('off')
    plt.show()

    _, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(np.abs(im_ref - im_pred) * diff_scale, cmap='gray', vmin=0, vmax=1)
    ax.set_title('diff x{}'.format(diff_scale))
    ax.axis('off')
    plt.show()

In [None]:
# Slice index (first array axis) used for both the test and the training visualisation.
slice_to_show = 130
show_images(slice_to_show=slice_to_show,
            pred=pred_test,
            ref=to_show_ref_test,
            us=to_show_us_test)

In [None]:
# Same slice index, this time from the training data held by the generator.
show_images(slice_to_show=slice_to_show,
            pred=pred_train,
            ref=to_show_ref_train,
            us=to_show_us_train)

In [12]:
#model.save('models/very_small_unet_no_aug_kernel_1_3_no_pooling.h5')