In [1]:
from __future__ import print_function
%matplotlib inline

In [2]:
import os, h5py
from skimage.transform import resize
from skimage.io import imsave, imshow
from skimage.exposure import equalize_adapthist

import numpy as np
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers.merge import concatenate
from keras.optimizers import Adam
from keras.layers.normalization import BatchNormalization

from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import backend as K

import matplotlib.pyplot as plt

Using TensorFlow backend.


In [3]:
from keras.layers import merge
from keras.layers.core import Lambda
from keras.models import Model

import tensorflow as tf

def make_parallel(model, gpu_count):
    """Replicate ``model`` over ``gpu_count`` GPUs (data parallelism).

    Every input batch is split along axis 0 into ``gpu_count`` slices.  Tower
    ``i`` runs ``model`` (sharing its weights) on slice ``i`` under
    ``/gpu:i``.  The per-tower outputs are then concatenated back along the
    batch axis on ``/cpu:0``.

    Args:
        model: a built Keras functional ``Model``.
        gpu_count: number of towers / GPUs to use.

    Returns:
        A new Keras ``Model`` with the same inputs as ``model``.  Its outputs
        are the batch-concatenated tower outputs, in ``model.outputs`` order.
    """
    def get_slice(data, idx, parts):
        """Return the ``idx``-th of ``parts`` slices of ``data`` along axis 0.

        ``idx`` and ``parts`` are Python ints, passed through the ``Lambda``
        ``arguments``.  Every slice has ``batch // parts`` rows, except the
        last one, which also takes the remainder.  This way no samples are
        dropped when the batch size is not divisible by ``parts``, e.g. on
        the last batch of an epoch.
        """
        shape = tf.shape(data)
        batch_size = shape[:1]
        chunk = batch_size // parts
        start = tf.concat([chunk * idx, shape[1:] * 0], axis=0)
        if idx == parts - 1:
            rows = batch_size - chunk * idx
        else:
            rows = chunk
        size = tf.concat([rows, shape[1:]], axis=0)
        return tf.slice(data, start, size)

    # One list of per-tower tensors for each output of the wrapped model.
    outputs_all = [[] for _ in model.outputs]

    # Place a copy of the model on each GPU, each getting a slice of the batch.
    for i in range(gpu_count):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                # Slice each input into a piece for processing on this GPU.
                inputs = []
                for x in model.inputs:
                    input_shape = tuple(x.get_shape().as_list())[1:]
                    slice_n = Lambda(get_slice, output_shape=input_shape,
                                     arguments={'idx': i, 'parts': gpu_count})(x)
                    inputs.append(slice_n)

                outputs = model(inputs)
                if not isinstance(outputs, list):
                    outputs = [outputs]

                # Save all the outputs for merging back together later.
                for out_idx, out in enumerate(outputs):
                    outputs_all[out_idx].append(out)

    # Merge outputs on the CPU.  ``concatenate`` is the Keras 2 replacement for
    # the legacy ``merge(..., mode='concat')``.  Concatenation needs at least
    # two tensors, so a single tower's output is used as-is.
    with tf.device('/cpu:0'):
        merged = [tower_outs[0] if len(tower_outs) == 1
                  else concatenate(tower_outs, axis=0)
                  for tower_outs in outputs_all]
        return Model(inputs=model.inputs, outputs=merged)

In [4]:
# Directory holding X_Train.h5 / Y_Train.h5 / X_Test.h5 / Y_Test.h5.
FolderPath = "../../dataset/ultrasound_nerve_segmentation/"

# Channels-last layout: image tensors are (batch, rows, cols, channels).
# Keras 2 spelling of the deprecated K.set_image_dim_ordering('tf').
K.set_image_data_format('channels_last')

# Network input size; the original images are 420x580 and get resized to this.
img_rows = 96
img_cols = 96

# NOTE(review): f_size is not referenced by any visible cell.
f_size = 1
# NOTE(review): 1e-2 is large for Adam on a U-Net; lower it if Dice oscillates.
learning_rate = 1e-2
activation = 'elu'

# Additive smoothing used by dice_coef (keeps the ratio defined for empty masks).
smooth = 1.


In [5]:
def load_train_data():
    """Read the training arrays from HDF5 files under ``FolderPath``.

    Returns:
        (imgs_train, imgs_mask_train): the full ``X_Train`` dataset of
        ``X_Train.h5`` and the full ``Y_Train`` dataset of ``Y_Train.h5``,
        loaded into memory.
    """
    with h5py.File(FolderPath + 'X_Train.h5', 'r') as x_file:
        images = x_file['X_Train'][:]
    with h5py.File(FolderPath + 'Y_Train.h5', 'r') as y_file:
        masks = y_file['Y_Train'][:]
    return images, masks


# In[6]:

def load_test_data():
    """Read the test arrays from HDF5 files under ``FolderPath``.

    Returns:
        (imgs_test, imgs_mask_test): the full ``X_Test`` dataset of
        ``X_Test.h5`` and the full ``Y_Test`` dataset of ``Y_Test.h5``,
        loaded into memory.
    """
    sources = (('X_Test.h5', 'X_Test'), ('Y_Test.h5', 'Y_Test'))
    loaded = []
    for filename, dataset in sources:
        with h5py.File(FolderPath + filename, 'r') as h5file:
            loaded.append(h5file[dataset][:])
    return tuple(loaded)

In [6]:
def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient over all elements of the batch.

    Both tensors are flattened, so a single score is produced for the whole
    batch rather than one per image.  The module-level ``smooth`` term is
    added to the numerator and the denominator, which keeps the ratio defined
    when both inputs sum to zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess) + smooth
    return (2. * overlap + smooth) / total


def dice_coef_loss(y_true, y_pred):
    """Training loss: ``dice_coef`` with its sign flipped.

    Minimising this loss maximises the overlap.
    """
    score = dice_coef(y_true, y_pred)
    return -score

In [7]:
def _unet_conv_block(x, filters):
    """Two same-padded 3x3 convolutions followed by batch normalisation."""
    x = Conv2D(filters, (3, 3), activation=activation, padding='same')(x)
    x = Conv2D(filters, (3, 3), activation=activation, padding='same')(x)
    return BatchNormalization()(x)


def get_unet(gpu_count=2):
    """Build and compile the U-Net used for segmentation.

    Input is a ``(img_rows, img_cols, 1)`` image.  Output is a same-sized
    single-channel sigmoid map.

    Architecture:
        * A 1x1 conv lifts the input to 3 channels.
        * Four encoder levels (32/64/128/256 filters), each a conv block
          followed by 2x2 max-pooling.
        * A 512-filter bottleneck.
        * Four decoder levels.  Each upsamples with a 2x2 stride-2 transposed
          conv, concatenates the matching encoder features on the channel
          axis, and applies a conv block.

    Layers are created in the same order as the original hand-written
    version.  Auto-generated layer names and saved-weight order are
    therefore unchanged.

    Args:
        gpu_count: number of GPUs to replicate the model over via
            ``make_parallel``.  Defaults to 2, the previously hard-coded
            value.  Pass 1 to get a plain single-device model.

    Returns:
        A compiled Keras model.  It uses Adam at ``learning_rate``, loss
        ``dice_coef_loss`` and metric ``dice_coef``.
    """
    inputs = Input((img_rows, img_cols, 1))
    x = Conv2D(3, (1, 1), activation=activation, padding='same')(inputs)

    # Encoder: remember each level's features for the skip connections.
    skips = []
    for filters in (32, 64, 128, 256):
        x = _unet_conv_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = _unet_conv_block(x, 512)

    # Decoder: upsample, concatenate the matching encoder level, convolve.
    for filters, skip in zip((256, 128, 64, 32), reversed(skips)):
        up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([up, skip], axis=3)
        x = _unet_conv_block(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    # Keras 2 keyword names (input=/output= are deprecated aliases).
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()
    if gpu_count > 1:
        model = make_parallel(model, gpu_count)
    model.compile(optimizer=Adam(lr=learning_rate), loss=dice_coef_loss, metrics=[dice_coef])

    return model

In [8]:
def exposure_image(X):
    """Apply ``equalize_adapthist`` (CLAHE) to each image along axis 0 of ``X``.

    Works on a copy made with ``np.array(X)``, so ``X`` is left untouched and
    the result keeps ``X``'s shape and dtype.
    """
    equalized = np.array(X)
    for idx, frame in enumerate(X):
        equalized[idx] = equalize_adapthist(frame)
    return equalized

def preprocess(X,y):
    """Scale images and masks by 1/255 to float32, then equalize the images.

    Assumes 8-bit inputs (values in [0, 255]) -- TODO confirm against the
    HDF5 files.

    Returns:
        (images, masks): equalized float32 images and scaled float32 masks.
    """
    images = (X / 255.).astype(np.float32)
    masks = (y / 255.).astype(np.float32)
    return exposure_image(images), masks

In [9]:
def resize_image(imgs, rows=None, cols=None):
    """Resize every image in ``imgs`` and append a trailing channel axis.

    Args:
        imgs: array indexed by image along axis 0.  Each ``imgs[i]`` is
            presumably a 2-D (rows, cols) image -- TODO confirm.
        rows: target height; defaults to the module-level ``img_rows``.
        cols: target width; defaults to the module-level ``img_cols``.

    Returns:
        float32 array of shape ``(imgs.shape[0], rows, cols, 1)``.  Input
        intensities are kept (``preserve_range=True``).
    """
    if rows is None:
        rows = img_rows
    if cols is None:
        cols = img_cols

    resized = np.empty((imgs.shape[0], rows, cols), dtype=np.float32)
    for i in range(imgs.shape[0]):
        # skimage's output_shape is (rows, cols).  The previous (cols, rows)
        # argument only worked because the target size was square.
        resized[i] = resize(imgs[i], (rows, cols), preserve_range=True)

    return resized[..., np.newaxis]

In [10]:
model = get_unet()
# fit() below holds out 20% of the training data (validation_split=0.2).
# Select the checkpoint and stop training on that held-out loss rather than
# on the training loss, which keeps falling even when the model overfits.
model_checkpoint = ModelCheckpoint('model_UNET.hdf5', monitor='val_loss', save_best_only=True)
# patience > 0 so a single noisy epoch does not end training immediately.
model_earlystopping = EarlyStopping(monitor='val_loss', patience=3)



____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 96, 96, 1)     0                                            
____________________________________________________________________________________________________
conv2d_1 (Conv2D)                (None, 96, 96, 3)     6           input_1[0][0]                    
____________________________________________________________________________________________________
conv2d_2 (Conv2D)                (None, 96, 96, 32)    896         conv2d_1[0][0]                   
____________________________________________________________________________________________________
conv2d_3 (Conv2D)                (None, 96, 96, 32)    9248        conv2d_2[0][0]                   
___________________________________________________________________________________________

  name=name)


In [None]:
imgs_train, imgs_mask_train = load_train_data()

In [None]:
imgs_train, imgs_mask_train = preprocess(imgs_train, imgs_mask_train)

In [None]:
imgs_train = resize_image(imgs_train)
imgs_mask_train = resize_image(imgs_mask_train)

In [None]:
imgs_train.shape, imgs_mask_train.shape

In [None]:
# The batch is split across the GPU towers built by get_unet (128 per tower).
# `epochs` is the Keras 2 name for the deprecated `nb_epoch`.
history = model.fit(imgs_train, imgs_mask_train, batch_size=128*2, epochs=20, verbose=1, shuffle=True,
                    validation_split=0.2,
                    callbacks=[model_checkpoint, model_earlystopping])

In [None]:
imgs_test, imgs_id_test = load_test_data()

In [None]:
imgs_test, imgs_id_test = preprocess(imgs_test, imgs_id_test)

In [None]:
imgs_test = resize_image(imgs_test)
imgs_id_test = resize_image(imgs_id_test)

In [None]:
model.load_weights('model_UNET.hdf5')

In [None]:
loss, accu = model.evaluate(imgs_test, imgs_id_test, verbose=1)
print("loss:{}%, accuracy:{}%".format(loss*100, accu*100))

In [None]:
imgs_pred_test = model.predict(imgs_test, verbose=1)

In [None]:
# Output folders.  test_dir receives plot_compare's figures; pred_dir is the
# target of the (commented-out) prediction export.  Each is created only if
# the path does not exist yet.
pred_dir = 'preds'
test_dir = 'tests'
for _out_dir in (pred_dir, test_dir):
    if not os.path.exists(_out_dir):
        os.mkdir(_out_dir)
del _out_dir

In [None]:
# for image_id, image in enumerate(imgs_pred_test):
#     image = (image[:,:,0] * 255.).astype(np.uint8)
#     imsave(os.path.join(pred_dir, str(image_id) + '_pred.png'), image)

In [None]:
np.shape(imgs_test)

In [None]:
def plot_compare(idx):
    """Plot, save, and show a three-panel comparison for test sample ``idx``.

    Panels:
        1. the ground-truth array ``imgs_id_test``;
        2. the test image with the prediction overlaid (``jet``, 50% alpha);
        3. the signed difference ``imgs_id_test - imgs_pred_test``.

    Reads the module-level arrays ``imgs_id_test``, ``imgs_test`` and
    ``imgs_pred_test``, each indexed as ``[idx, :, :, 0]``.  The figure is
    saved to ``<test_dir>/test<idx>.png``, displayed, and then closed, so
    calling this in a loop does not accumulate open figures.
    """
    mask = imgs_id_test[idx, :, :, 0]
    image = imgs_test[idx, :, :, 0]
    pred = imgs_pred_test[idx, :, :, 0]

    # Explicit axes instead of the pyplot state machine + skimage's imshow
    # wrapper, so each panel is drawn exactly where intended.
    fig, (ax_mask, ax_overlay, ax_diff) = plt.subplots(1, 3, dpi=40)
    ax_mask.imshow(mask, cmap='gray', interpolation='nearest')
    ax_mask.set_title('mask')
    ax_overlay.imshow(image, cmap='gray', interpolation='nearest')
    ax_overlay.imshow(pred, cmap='jet', alpha=0.5, interpolation='nearest')
    ax_overlay.set_title('image + prediction')
    ax_diff.imshow(mask - pred, cmap='gray', interpolation='nearest')
    ax_diff.set_title('mask - prediction')
    for ax in (ax_mask, ax_overlay, ax_diff):
        ax.axis('off')

    fig.savefig(os.path.join(test_dir, 'test' + str(idx) + '.png'))
    plt.show()
    plt.close(fig)

In [None]:
# Draw up to 10 *distinct* test indices.  randint could repeat an index
# (overwriting its saved figure) and raises on an empty test set.  Seeded so
# a re-run saves the same samples.
n_samples = min(10, imgs_test.shape[0])
ran_idx = np.random.RandomState(0).choice(imgs_test.shape[0], size=n_samples, replace=False)
for idx in ran_idx:
    plot_compare(idx)