In [1]:
import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import dataset as dd # custom dataset class
import unet

# Re-import edited project modules (e.g. dataset, unet, unetblocks) before each
# cell runs, so changes to those files take effect without restarting the kernel.
%load_ext autoreload 
%autoreload 2
# Interactive matplotlib figures in the classic Notebook UI.
# NOTE(review): this backend does not work in JupyterLab / Notebook 7; there the
# replacement is `%matplotlib widget` (ipympl).
%matplotlib notebook

# Drop arrays left over from a previous run of this cell before reloading, so two
# copies are never held at once. X_*/y_* are aliases of the im_* arrays, so every
# name has to be removed for the memory to actually be released. Names that do
# not exist yet (fresh kernel) are simply skipped instead of swallowing errors.
_stale_names = ('X_train', 'y_train', 'X_test', 'y_test',
                'im_ref', 'im_us', 'im_ref_test', 'im_us_test')
_cleared = False
for _name in _stale_names:
    if _name in globals():
        del globals()[_name]
        _cleared = True
if _cleared:
    print('Clear previously loaded data.')

# Dataset ids passed to dd.get_dataset.
# FIXME(review): both ids are 1, so the "test" set is presumably the same data as
# the training set; any test-set evaluation below then measures training fit,
# not generalization. Point TEST_DATASET_ID at a held-out dataset.
TRAIN_DATASET_ID = 1
TEST_DATASET_ID = 1

# dd.get_dataset presumably returns (reference, undersampled) images — confirm in
# dataset.py. The network is trained to map im_us (input X) to im_ref (target y).
im_ref, im_us = dd.get_dataset(TRAIN_DATASET_ID)
X_train, y_train = im_us, im_ref

im_ref_test, im_us_test = dd.get_dataset(TEST_DATASET_ID)
X_test, y_test = im_us_test, im_ref_test



In [2]:
# Report the array shapes of the inputs and targets for both splits.
for _caption, _array in (('Training data shape: ', X_train),
                         ('Training labels shape: ', y_train),
                         ('Test data shape: ', X_test),
                         ('Test labels shape: ', y_test)):
    print(_caption, _array.shape)

('Training data shape: ', (320, 320, 256, 8))
('Training labels shape: ', (320, 320, 256, 1))
('Test data shape: ', (320, 320, 256, 8))
('Test labels shape: ', (320, 320, 256, 1))


In [15]:
from unetblocks import res_block, gen_conv, gen_conv_bn_relu

# Per-sample input shape: every axis of the training array except the first.
input_shape = X_train.shape[1:]
inputs = tf.keras.layers.Input(shape=input_shape)

# Shared 3x3 stride-1 'same' linear conv settings; only the filter count differs
# between stages. Each derived dict is an independent shallow copy of f16.
f16 = {'filters': 16, 'kernel_size': (3, 3), 'strides': (1, 1),
       'padding': 'same', 'activation': 'linear'}
f32, f64, f128 = (dict(f16, filters=width) for width in (32, 64, 128))

# Layer factory handed to res_block (project-defined in unetblocks).
gen_fn = gen_conv_bn_relu

# Outer 16-filter residual block applied to `inputs`; `resblocks` lists the
# 32/64/128-filter blocks (how res_block uses them is defined in unetblocks).
res_out = res_block(gen_fn(**f16), gen_fn(**f16))(
    inputs,
    resblocks=[
        res_block(gen_fn(**f32), gen_fn(**f32)),
        res_block(gen_fn(**f64), gen_fn(**f64)),
        # NOTE(review): this stage gets a single conv while the others get two —
        # confirm that is intended.
        res_block(gen_fn(**f128)),
    ])

# Dense acts on the last axis only, projecting the features to one output
# channel per position.
out = tf.keras.layers.Dense(1)(res_out)
model = tf.keras.models.Model(inputs=inputs, outputs=out)
## example from https://keras.io/callbacks/
class LossHistory(tf.keras.callbacks.Callback):
    """Keras callback that records the training loss after every batch.

    `losses` is reset when training begins and then holds one `logs['loss']`
    value per batch, accumulated across all epochs of that `fit` call.
    """

    def on_train_begin(self, logs=None):
        # `None` default instead of `{}` avoids a shared mutable default argument.
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))

history = LossHistory()
# TensorBoard log directory. Set UNET_TB_LOG_DIR to override it instead of
# editing a user-specific absolute path; the old location remains the default.
TB_LOG_DIR = os.environ.get('UNET_TB_LOG_DIR', '/home/pkllee/tmp/')
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=TB_LOG_DIR)

# Adam with the standard betas. Legacy Keras `decay` shrinks the learning rate on
# every update as lr / (1 + decay * iterations), so 0.01 decays it quickly.
adam_optimizer = tf.keras.optimizers.Adam(lr=0.01, beta_1=0.9, beta_2=0.999, decay=0.01)

model.compile(optimizer=adam_optimizer, loss='mean_squared_error', metrics=['mse'])

In [None]:
# Train on the full training set; `history` records per-batch losses and
# `tb_callback` writes TensorBoard logs.
model.fit(
    x=X_train,
    y=y_train,
    epochs=200,
    callbacks=[history, tb_callback],
)

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200


In [None]:
# Per-batch training loss collected by the LossHistory callback (all epochs
# concatenated, so the x axis counts batches, not epochs).
fig, ax = plt.subplots()
ax.plot(history.losses)
ax.set_xlabel('batch (cumulative over epochs)')
ax.set_ylabel('MSE loss')
ax.set_title('Training loss per batch')
plt.show()

In [None]:
# Run the trained network over every test sample.
pred = model.predict(x=X_test)

In [None]:
# Index along the first axis of the test arrays that is visualized below.
slice_to_show = 100

def sos(im, axis):
    """Root-sum-of-squares combination of `im` along `axis`.

    Uses |x|**2 so complex-valued inputs are combined by magnitude (x**2 would
    mix phases and give a complex, wrong result); for real inputs the result is
    identical to sqrt(sum(x**2)).

    Args:
        im: array-like of real or complex values.
        axis: int or tuple of ints, the axis/axes to reduce over.

    Returns:
        Real ndarray with `axis` removed.
    """
    return np.sqrt(np.sum(np.abs(im) ** 2, axis=axis))

im1 = im_ref_test[slice_to_show, :, :, 0]  # reference image, channel 0
# Network input for the same sample, channels combined by root-sum-of-squares.
im2 = sos(im_us_test[slice_to_show, :, :, :], axis=2)
im3 = pred[slice_to_show, :, :, 0]  # network prediction, channel 0

# Side-by-side comparison; hstack puts all three on one shared gray scale.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(np.hstack((im3, im1, im2)), cmap='gray')
ax.set_title('pred | ref | us')
ax.axis('off')
plt.show()

# Signed prediction error; the colorbar makes its magnitude readable.
fig, ax = plt.subplots(figsize=(5, 5))
err_plot = ax.imshow(im3 - im1, cmap='gray')
ax.set_title('pred - ref')
ax.axis('off')
fig.colorbar(err_plot, ax=ax)
plt.show()