In [1]:
import glob
import imageio
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam, SGD

from importlib.util import find_spec
import sys

if find_spec("keras_unet") is None:
    sys.path.append('/scratch/cloned_repositories/keras-unet/')

import keras_unet
from keras_unet.utils import get_augmented
from keras_unet.models import custom_unet

from keras_unet.metrics import iou, iou_thresholded
from keras_unet.losses import jaccard_distance

from keras_unet.utils import plot_segm_history
from keras_unet.utils import plot_imgs

if find_spec("losses") is None:
    sys.path.append('..')
import losses

-----------------------------------------
keras-unet init: TF version is >= 2.0.0 - using `tf.keras` instead of `Keras`
-----------------------------------------


In [2]:
random_seed = 42
dataset_version = 'ds210324'

In [3]:
weights_dir = os.path.join('/scratch/fibro_arrhythm_data', dataset_version, 'weights')
print(weights_dir)

/scratch/fibro_arrhythm_data/ds210324/weights


In [4]:
image_filenames = glob.glob(os.path.join("/scratch/fibro_arrhythm_data", dataset_version, "textures/*.npy"))
label_filenames = glob.glob(os.path.join("/scratch/fibro_arrhythm_data", dataset_version, "labels/*.npy"))

image_filenames.sort()
label_filenames.sort()

In [5]:
images = []
labels = []

for image_filename, label_filename in zip(image_filenames, label_filenames):
    image = np.load(image_filename)
    image = np.pad(image, (1, 1), 'constant')  # TODO: check this or avoid this by saving the data differently
    images.append(image)
    
    label = np.load(label_filename)
    label = np.pad(label, (1, 1), 'constant')
    labels.append(label)
    
    if len(images) > 9250:
        break

images = np.array(images, dtype=np.float32)
labels = np.array(labels, dtype=np.float32)

images = np.expand_dims(images, axis=-1)
labels = np.expand_dims(labels, axis=-1)

print(images.min(), images.max(), "    ", labels.min(), labels.max())
print(images.shape, "    ", labels.shape)

0.0 1.0      0.0 1.0
(9251, 256, 256, 1)      (9251, 256, 256, 1)


In [6]:
x_train, x_test, y_train, y_test = train_test_split(images, labels, test_size=0.1, random_state=random_seed)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.111111, random_state=random_seed) # 0.111111 x 0.9 = 0.1

In [7]:
print("x_train: ", x_train.shape)
print("y_train: ", y_train.shape)
print("x_val: ", x_val.shape)
print("y_val: ", y_val.shape)
print("x_test: ", x_test.shape)
print("y_test: ", y_test.shape)

x_train:  (7400, 256, 256, 1)
y_train:  (7400, 256, 256, 1)
x_val:  (925, 256, 256, 1)
y_val:  (925, 256, 256, 1)
x_test:  (926, 256, 256, 1)
y_test:  (926, 256, 256, 1)


In [8]:
train_gen = get_augmented(
    x_train, y_train, batch_size=2,
    data_gen_args = dict(
        horizontal_flip=False,
        vertical_flip=False,
        fill_mode='constant'
    ))

In [9]:
input_shape = x_train[0].shape

model = custom_unet(
    input_shape,
    use_batch_norm=False,
    num_classes=1,
    filters=64,
    dropout=0.2,
    output_activation='sigmoid'
)

In [10]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 640         input_1[0][0]                    
__________________________________________________________________________________________________
spatial_dropout2d (SpatialDropo (None, 256, 256, 64) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 64) 36928       spatial_dropout2d[0][0]          
______________________________________________________________________________________________

In [11]:
model_version = 'custom_unet_v0.0.0.2_ftloss_adam_9250samples'
callback_checkpoint = ModelCheckpoint(
    weights_dir + '/weights_' + model_version + '/weights.{epoch:02d}-{val_loss:.2f}.h5', 
    verbose=1, 
    monitor='val_loss', 
    save_best_only=False,
)

In [12]:
def recall(y_true, y_pred):
    _, recall = losses.confusion(y_true, y_pred)
    return recall

In [13]:
model.compile(
    optimizer=Adam(), 
    #optimizer=SGD(lr=0.01, momentum=0.99),
    loss=losses.focal_tversky,
    #loss=jaccard_distance,
    metrics=[iou, iou_thresholded, recall]
)

In [None]:
history = model.fit_generator(
    train_gen,
    steps_per_epoch=x_train.shape[0],
    epochs=25,    
    validation_data=(x_val, y_val),
    callbacks=[callback_checkpoint]
)
model.save(weights_dir + '/' + model_version + '.h5')

Instructions for updating:
Please use Model.fit, which supports generators.
  ...
    to  
  ['...']
Train for 7400 steps, validate on 925 samples
Epoch 1/25
Epoch 00001: saving model to /scratch/fibro_arrhythm_data/ds210324/weights/weights_custom_unet_v0.0.0.2_ftloss_adam_9250samples/weights.01-0.93.h5
Epoch 2/25
 491/7400 [>.............................] - ETA: 3:17:17 - loss: 0.9235 - iou: 0.0352 - iou_thresholded: 0.0352 - recall: 1.0000

In [None]:
history.history

In [None]:
plot_segm_history(history)

In [None]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()

In [None]:
# model.load_weights(weights_dir + '/' + model_version + '.h5')

In [None]:
segm_beg = 10
segm_eng = 30
x_test_segment = x_test[segm_beg:segm_eng]
y_test_segment = y_test[segm_beg:segm_eng]
y_pred = model.predict(x_test_segment)

In [None]:
print(x_test_segment.shape)
print(y_test_segment.shape)
print(y_pred.shape)

In [None]:
plot_imgs(org_imgs=x_test, mask_imgs=y_test, pred_imgs=y_pred, nm_img_to_plot=20)

In [None]:
temp = y_pred[5,:,:,0]
print(temp.shape)

In [None]:
print(np.min(temp), np.max(temp))