In [3]:
import math
import os
import sys
from glob import glob

from keras.layers import *
from keras.models import Model
from keras import layers
from keras.layers.merge import concatenate

# Prepend the sibling source checkouts so the project imports below resolve.
sys.path.insert(1, '../src')
sys.path.insert(1, '../image_segmentation_keras')
from keras_segmentation.models.config import IMAGE_ORDERING
from keras_segmentation.models.model_utils import get_segmentation_model
from crfrnn_layer import CrfRnnLayer

In [4]:
# Model input geometry and label space used when building the network.
input_height, input_width = 256, 256  # spatial size (H, W) of img_input
n_classes = 3  # number of per-pixel segmentation classes
channels = 3  # channels of img_input, which also feeds the output CRF-RNN layer

In [5]:
def unet_conv_block(inputs, filters, pool=True, batch_norm_first=True):
    if batch_norm_first == True:
        x = Conv2D(filters, 3, padding="same")(inputs)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)

        x = Conv2D(filters, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    elif batch_norm_first == False:
        x = Conv2D(filters, 3, padding="same")(inputs)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)

        x = Conv2D(filters, 3, padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization()(x)

    if pool == True:
        p = MaxPooling2D((2, 2))(x)
        return [x, p]
    else:
        return x

In [12]:
# U-Net (2 pooling levels) with CRF-RNN refinement of the class scores.
# Every conv stage below is exactly what unet_conv_block builds (Conv-BN-ReLU x2),
# so the helper is used instead of hand-copied layers; layer creation order is unchanged.
img_input = Input(shape=(input_height, input_width, channels))

# Encoder: full- and half-resolution stages, each pooled 2x.
x1, p1 = unet_conv_block(img_input, 64)
x2, p2 = unet_conv_block(p1, 128)

# Bottleneck at quarter resolution (no pooling).
x3 = unet_conv_block(p2, 256, pool=False)

# Decoder stage 1: upsample to half resolution and concatenate the x2 skip.
x4 = UpSampling2D((2, 2))(x3)
x4 = concatenate([x4, x2], axis=3)
x4 = unet_conv_block(x4, 128, pool=False)

# CRF-RNN hyper-parameters shared by both CRF layers.
crf_hparams = dict(theta_alpha=160.,
                   theta_beta=3.,
                   theta_gamma=3.,
                   num_iterations=10)

# NOTE(review): crf_mid is never consumed below, so it is NOT part of `model`;
# either wire it into the decoder or delete it. It has its own layer name so it
# cannot collide with the output CRF layer ('crfrnn') once wired in.
# NOTE(review): the output CRF receives the input image as its second input,
# while this one receives the 128-channel skip tensor x2 — confirm intended.
crf_mid = CrfRnnLayer(image_dims=(input_height // 2, input_width // 2),
                      num_classes=128,  # matches the 128 channels of x4
                      name='crfrnn_mid',
                      **crf_hparams)([x4, x2])

# Decoder stage 2: upsample to full resolution and concatenate the x1 skip.
x5 = UpSampling2D((2, 2))(x4)
x5 = concatenate([x5, x1], axis=3)
x5 = unet_conv_block(x5, 64, pool=False)

# 1x1 projection to per-class scores, batch-normalized before the CRF.
x5 = Conv2D(n_classes, (1, 1), padding='same')(x5)
x5 = BatchNormalization()(x5)

# Output CRF-RNN over the class scores, with the input image as its second input.
crf_output = CrfRnnLayer(image_dims=(input_height, input_width),
                         num_classes=n_classes,
                         name='crfrnn',
                         **crf_hparams)([x5, img_input])
model = get_segmentation_model(img_input, crf_output)

In [13]:
# Print the layer-by-layer architecture: output shapes, parameter counts, and connections.
model.summary()

Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 256, 256, 64) 1792        input_6[0][0]                    
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 256, 256, 64) 256         conv2d_38[0][0]                  
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 256, 256, 64) 0           batch_normalization_36[0][0]     
____________________________________________________________________________________________

In [14]:
# Dataset layout: <DATA_DIR>/{train,test}/{img,ann}/ — point DATA_DIR at your copy.
DATA_DIR = "/Users/mavaylon/Research/Data1"
BATCH_SIZE = 1  # NOTE(review): confirm CrfRnnLayer supports batch sizes > 1 before raising
VAL_BATCH_SIZE = 1

# Trailing "" keeps the trailing separator the original paths had.
train_img_dir = os.path.join(DATA_DIR, "train", "img", "")
train_ann_dir = os.path.join(DATA_DIR, "train", "ann", "")
val_img_dir = os.path.join(DATA_DIR, "test", "img", "")
val_ann_dir = os.path.join(DATA_DIR, "test", "ann", "")

# Count from the same directories passed to train() so the two cannot drift apart.
n_train = len(glob(os.path.join(train_img_dir, "*")))
n_val = len(glob(os.path.join(val_img_dir, "*")))
if n_train == 0 or n_val == 0:
    # A wrong path would otherwise yield 0 steps per epoch.
    raise FileNotFoundError(
        f"No images found (train={n_train} in {train_img_dir}, val={n_val} in {val_img_dir})"
    )

model.train(
    train_images=train_img_dir,
    train_annotations=train_ann_dir,
    epochs=20,
    # One full pass over the images per epoch, for any batch size.
    steps_per_epoch=math.ceil(n_train / BATCH_SIZE),
    batch_size=BATCH_SIZE,
    validate=True,
    val_images=val_img_dir,
    val_annotations=val_ann_dir,
    val_batch_size=VAL_BATCH_SIZE,
    val_steps_per_epoch=math.ceil(n_val / VAL_BATCH_SIZE),
)

Verifying training dataset


 55%|█████▌    | 3254/5912 [00:21<00:17, 154.16it/s]


KeyboardInterrupt: 