In [1]:
from keras.layers import *
from keras.models import Model
from keras import layers
from keras.layers.merge import concatenate
import sys
sys.path.insert(1, '../src')
sys.path.insert(1, '../image_segmentation_keras')
from keras_segmentation.models.config import IMAGE_ORDERING

from keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob
from crfrnn_layer import CrfRnnLayer

Using TensorFlow backend.


In [2]:
input_height=256
input_width=256 
n_classes = 3
channels = 3

In [3]:
#residual unet adapted from:
#https://github.com/nikhilroxtomar/Deep-Residual-Unet/blob/master/Deep%20Residual%20UNet.ipynb

In [4]:
def bn_act(x, act=True):
    x = BatchNormalization()(x)
    if act == True:
        x = Activation("relu")(x)
    return x

def conv_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = bn_act(x)
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(conv)
    return conv

def stem(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    conv = conv_block(conv, filters, kernel_size=kernel_size, padding=padding, strides=strides)
    
    shortcut = Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)
    
    output = Add()([conv, shortcut])
    return output

def residual_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    res = conv_block(x, filters, kernel_size=kernel_size, padding=padding, strides=strides)
    res = conv_block(res, filters, kernel_size=kernel_size, padding=padding, strides=1)
    
    shortcut = Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)
    
    output = Add()([shortcut, res])
    return output

def upsample_concat_block(x, xskip):
    u = UpSampling2D((2, 2))(x)
    c = Concatenate()([u, xskip])
    return c


In [5]:
img_input = Input(shape=(input_height,input_width, channels))
f = [32, 64, 128, 256, 512]
#inputs = keras.layers.Input((image_size, image_size, 3))

## Encoder
e0 = img_input
e1 = stem(e0, f[0])
e2 = residual_block(e1, f[1], strides=2)
e3 = residual_block(e2, f[2], strides=2)
e4 = residual_block(e3, f[3], strides=2)
e5 = residual_block(e4, f[4], strides=2)

## Bridge
b0 = conv_block(e5, f[4], strides=1)
b1 = conv_block(b0, f[4], strides=1)

## Decoder
u1 = upsample_concat_block(b1, e4)
d1 = residual_block(u1, f[4])

u2 = upsample_concat_block(d1, e3)
d2 = residual_block(u2, f[3])

u3 = upsample_concat_block(d2, e2)
d3 = residual_block(u3, f[2])

u4 = upsample_concat_block(d3, e1)
d4 = residual_block(u4, f[1])
outputs = Conv2D(n_classes, (1, 1), padding="same", activation="relu")(d4)
print(outputs)
print(n_classes)
crf_output = CrfRnnLayer(image_dims=(input_height, input_width),
                         num_classes=n_classes,
                         theta_alpha=160.,
                         theta_beta=3.,
                         theta_gamma=3.,
                         num_iterations=10,
                         name='crfrnn')([outputs, img_input])
model = get_segmentation_model(img_input, crf_output)
model.n_classes

Tensor("conv2d_30/Relu:0", shape=(None, 256, 256, 3), dtype=float32)
3


3

In [6]:
model.train(
    train_images =  "/Users/mavaylon/Research/Data1/train/img/",
    train_annotations = "/Users/mavaylon/Research/Data1/train/ann/",
    epochs=20,
    steps_per_epoch=len(glob("/Users/mavaylon/Research/Data1/train/img/*")),
    batch_size=1,
    validate=True,
    val_images="/Users/mavaylon/Research/Data1/test/img/",
    val_annotations="/Users/mavaylon/Research/Data1/test/ann/",
    val_batch_size=1,
    val_steps_per_epoch=len(glob("/Users/mavaylon/Research/Data1/test/img/*"))
)

  0%|          | 0/5912 [00:00<?, ?it/s]

Verifying training dataset


100%|██████████| 5912/5912 [00:17<00:00, 338.61it/s]
  2%|▏         | 33/1478 [00:00<00:04, 322.08it/s]

Dataset verified! 
Verifying validation dataset


100%|██████████| 1478/1478 [00:04<00:00, 348.79it/s]


Dataset verified! 
fit
Epoch 1/20

Epoch 00001: val_accuracy improved from -inf to 0.74829, saving model to pet_class_crf.h5
Epoch 2/20

Epoch 00002: val_accuracy improved from 0.74829 to 0.75722, saving model to pet_class_crf.h5
Epoch 3/20

Epoch 00003: val_accuracy improved from 0.75722 to 0.78183, saving model to pet_class_crf.h5
Epoch 4/20

Epoch 00004: val_accuracy improved from 0.78183 to 0.79738, saving model to pet_class_crf.h5
Epoch 5/20

Epoch 00005: val_accuracy improved from 0.79738 to 0.81205, saving model to pet_class_crf.h5
Epoch 6/20

Epoch 00006: val_accuracy did not improve from 0.81205
Epoch 7/20

Epoch 00007: val_accuracy did not improve from 0.81205
Epoch 8/20

Epoch 00008: val_accuracy improved from 0.81205 to 0.81842, saving model to pet_class_crf.h5
Epoch 9/20

Epoch 00009: val_accuracy improved from 0.81842 to 0.83036, saving model to pet_class_crf.h5
Epoch 10/20

Epoch 00010: val_accuracy did not improve from 0.83036
Epoch 11/20

Epoch 00011: val_accuracy did 