In [1]:
from __future__ import print_function
from ipywidgets import interact, interactive, fixed, interact_manual
import os
import random
import warnings
from glob import glob
import ipywidgets as widgets
%matplotlib inline

from lib import *

import numpy as np
import pandas as pd

from itertools import chain
import matplotlib.pyplot as plt
from skimage.transform import resize
from skimage.morphology import label
from skimage.io import imread, imshow, imread_collection, concatenate_images

from keras.models import Model, load_model
from keras.layers import Input
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.merge import concatenate
from keras.layers.pooling import MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K

import tensorflow as tf

# Set some parameters
IMG_WIDTH = 240
IMG_HEIGHT = 320
IMG_CHANNELS = 3

seed = 42
random.seed = seed
np.random.seed = seed

warnings.filterwarnings('ignore')

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


# KERAS model of U-net and learning

In [2]:
# Build U-Net model
inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
s = Lambda(lambda x: x / 255) (inputs)

c1 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (s)
c1 = Dropout(0.1) (c1)
c1 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c1)
p1 = MaxPooling2D((2, 2)) (c1)

c2 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p1)
c2 = Dropout(0.1) (c2)
c2 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c2)
p2 = MaxPooling2D((2, 2)) (c2)

c3 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p2)
c3 = Dropout(0.2) (c3)
c3 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c3)
p3 = MaxPooling2D((2, 2)) (c3)

c4 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p3)
c4 = Dropout(0.2) (c4)
c4 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c4)
p4 = MaxPooling2D(pool_size=(2, 2)) (c4)

c5 = Conv2D(256, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p4)
c5 = Dropout(0.3) (c5)
c5 = Conv2D(256, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c5)

u6 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same') (c5)
u6 = concatenate([u6, c4])
c6 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u6)
c6 = Dropout(0.2) (c6)
c6 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c6)

u7 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c6)
u7 = concatenate([u7, c3])
c7 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u7)
c7 = Dropout(0.2) (c7)
c7 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c7)

u8 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c7)
u8 = concatenate([u8, c2])
c8 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u8)
c8 = Dropout(0.1) (c8)
c8 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c8)

u9 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c8)
u9 = concatenate([u9, c1], axis=3)
c9 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u9)
c9 = Dropout(0.1) (c9)
c9 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c9)

outputs = Conv2D(1, (1, 1), activation='sigmoid') (c9)
model = Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics = ['accuracy'])
#model.summary()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [3]:
# Get train imgs and masks
X_train = Get_IMGs('data\\train\\') 
Y_train = Get_Masks('data\\train_mask\\')

100%|██████████████████████████████████████████████████████████████████████████████| 1315/1315 [00:22<00:00, 57.60it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 1315/1315 [00:07<00:00, 172.72it/s]


In [4]:


# Fit model
earlystopper = EarlyStopping(patience=5, verbose=1)
checkpointer = ModelCheckpoint('model_1.h5', verbose=1, save_best_only=True)
#results = model.fit(X_train, Y_train, validation_split=0.1, batch_size=16, epochs=50, 
#                    callbacks=[checkpointer], verbose=0)

In [5]:
datagen = ImageDataGenerator(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='constant'
)
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=10)

Instructions for updating:
Use tf.cast instead.
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10

KeyboardInterrupt: 

In [28]:


for e in range(10):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(X_train, Y_train, batch_size=32):
        results = model.fit(X_train, Y_train, validation_split=0.1,
                    callbacks=[checkpointer], verbose=0)
        batches += 1
        if batches >= len(X_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break

Epoch 0

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epoch 00001: val_loss did not improve from 0.13791

Epo

KeyboardInterrupt: 

# Results

Test model on validation images, save to csv and get Dice metric

In [25]:
model = load_model('model_1.h5', custom_objects={'mean_iou': mean_iou})
valid_imgs = Get_IMGs('data\\valid\\') #get np.array with images
valid_true_masks = Get_Masks('data\\valid_mask\\')#get np.array with masks

valid_pred_masks = (model.predict(valid_imgs)> 0.5).astype(np.uint8) #get predicted masks

# Save valid rle_masks to pred_valid_template.csv
df = pd.DataFrame({
        'id': [int(x.split('.')[0]) for x in next(os.walk('data\\valid\\'))[2]],
        'rle_mask': [encode_rle(mask) for mask in np.squeeze(valid_pred_masks)]
})
df.to_csv('data/pred_valid_template.csv',index=False)  #save to csv

#Check Dice metric
Dice_metric = get_dice(np.squeeze(valid_true_masks), np.squeeze(valid_pred_masks))
print('Dice metric for valid imgs:', Dice_metric)

100%|████████████████████████████████████████████████████████████████████████████████| 145/145 [00:02<00:00, 57.16it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 145/145 [00:00<00:00, 172.98it/s]


Dice metric for valid imgs: 0.9512201277679161


# Dice metric for valid imgs: 0.951

Show image, true mask and predicted mask

In [26]:
@interact(i=(0,len(valid_imgs),1))
def g(i=0):
    plt.figure(figsize=(10,4))
    
    plt.subplot(1,3,1)
    plt.title('image')
    plt.imshow(valid_imgs[i,:,:])
    
    plt.subplot(1,3,2)
    plt.title('true mask')
    plt.imshow(valid_true_masks[i,:,:,0])
    
    plt.subplot(1,3,3)
    plt.title('mask predicted')
    plt.imshow(valid_pred_masks[i,:,:,0])

interactive(children=(IntSlider(value=0, description='i', max=145), Output()), _dom_classes=('widget-interact'…

Get masks from test images and create examples.html 

In [6]:
test_imgs = Get_IMGs('data/test/') 

# Predict masks for test imgs
model = load_model('model_1.h5', custom_objects={'mean_iou': mean_iou})
test_pred_masks = model.predict(test_imgs)
test_pred_masks = np.squeeze(test_pred_masks > 0.5).astype(np.uint8)*255 
# without *255 masks, putted in example.html, will look like black rectangle

# Save to html
paths_to_imgs = sorted(glob("data\\test\\*"))
_ = get_html(paths_to_imgs, test_pred_masks, path_to_save="results\\example")

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 74.56it/s]
