In [1]:
import sys
sys.path.insert(1, './src')
from crfrnn_model import get_crfrnn_model_def
from crfrnn_layer import CrfRnnLayer

Using TensorFlow backend.


In [8]:
from keras.models import *
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, Activation, add, concatenate
from image_segmentation_keras.keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob
# Network input geometry (height, width, channels) and number of segmentation classes.
input_height = input_width = 512
channels = 3
n_classes = 2

shape = (input_height, input_width, channels)


def res_block(x, nb_filters, strides):
    """Pre-activation residual unit with a projected (1x1 conv) shortcut.

    The residual branch is two stages of BatchNorm -> ReLU -> 3x3 Conv ('same'
    padding). The shortcut is a 1x1 Conv followed by BatchNorm, so it always has
    the same channel count and stride as the residual branch and the two can be added.

    Args:
        x: input Keras tensor.
        nb_filters: sequence; ``nb_filters[0]`` / ``nb_filters[1]`` are the filter
            counts of the first / second 3x3 conv. ``nb_filters[1]`` is also the
            shortcut's (and therefore the output's) channel count.
        strides: sequence; ``strides[0]`` / ``strides[1]`` are the strides of the
            first / second 3x3 conv. ``strides[0]`` is also used by the shortcut conv
            so both branches change resolution together.

    Returns:
        Keras tensor: element-wise sum of the shortcut and residual branches.
    """
    residual = x
    # Build the two BN -> ReLU -> Conv stages in order (layer creation order is
    # kept identical, so auto-generated layer names are unchanged).
    for stage_filters, stage_stride in ((nb_filters[0], strides[0]), (nb_filters[1], strides[1])):
        residual = BatchNormalization()(residual)
        residual = Activation(activation='relu')(residual)
        residual = Conv2D(filters=stage_filters, kernel_size=(3, 3), padding='same',
                          strides=stage_stride)(residual)

    # Projection shortcut: match the residual branch's channels and stride for `add`.
    projected = Conv2D(filters=nb_filters[1], kernel_size=(1, 1), strides=strides[0])(x)
    projected = BatchNormalization()(projected)

    return add([projected, residual])


def encoder(x):
    """Residual encoder: a stem unit followed by two stride-2 residual units.

    Args:
        x: input image tensor (Keras tensor).

    Returns:
        list of three tensors, shallowest first, consumed by the decoder as skip
        connections:
          [0] 64 channels at input resolution (stem output),
          [1] 128 channels at 1/2 resolution,
          [2] 256 channels at 1/4 resolution.
    """
    # Stem: Conv -> BN -> ReLU -> Conv on the main branch (no pre-activation on the
    # raw image), 1x1 Conv -> BN on the shortcut, then sum.
    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(x)
    stem = BatchNormalization()(stem)
    stem = Activation(activation='relu')(stem)
    stem = Conv2D(filters=64, kernel_size=(3, 3), padding='same', strides=(1, 1))(stem)

    stem_shortcut = Conv2D(filters=64, kernel_size=(1, 1), strides=(1, 1))(x)
    stem_shortcut = BatchNormalization()(stem_shortcut)

    level_full = add([stem_shortcut, stem])

    # Each following unit halves the spatial resolution (first conv stride 2).
    level_half = res_block(level_full, [128, 128], [(2, 2), (1, 1)])
    level_quarter = res_block(level_half, [256, 256], [(2, 2), (1, 1)])

    return [level_full, level_half, level_quarter]


def decoder(x, from_encoder):
    """Upsample ``x`` three times, fusing the encoder skip tensors deepest-first.

    Each stage does a 2x UpSampling2D, concatenates the matching encoder tensor on
    axis 3 (the channel axis for these channels-last tensors), and applies a
    stride-1 residual unit.

    Args:
        x: bottleneck tensor (the bridge output).
        from_encoder: list returned by ``encoder``; indices 2, 1, 0 are consumed in
            that order.

    Returns:
        Keras tensor with 64 channels at the resolution of ``from_encoder[0]``.
    """
    # (skip index, filters of the residual unit) for each stage, deepest first.
    stages = ((2, 256), (1, 128), (0, 64))

    upsampled = x
    for skip_index, filters in stages:
        upsampled = UpSampling2D(size=(2, 2))(upsampled)
        upsampled = concatenate([upsampled, from_encoder[skip_index]], axis=3)
        upsampled = res_block(upsampled, [filters, filters], [(1, 1), (1, 1)])
    return upsampled


def build_res_unet(input_shape, num_classes=None):
    """Build a ResUNet segmentation network refined by a CRF-RNN layer.

    Pipeline:
      encoder (3 levels) -> stride-2 bridge -> decoder -> 1x1 conv (unary scores)
      -> CrfRnnLayer(unaries, input image) -> get_segmentation_model.

    Args:
        input_shape: ``(height, width, channels)`` of the input images
            (channels-last). Height and width must be divisible by 8. The encoder
            plus the bridge downsample by 2 three times, and the decoder's skip
            concatenations only line up when the upsampled sizes match exactly.
        num_classes: number of segmentation classes. When ``None`` (the default),
            the notebook-level ``n_classes`` constant is used, as before.

    Returns:
        The model produced by ``get_segmentation_model(inputs, crf_output)``.
        NOTE(review): that helper presumably reshapes the output and applies the
        final softmax. Confirm in image_segmentation_keras ``model_utils``.
    """
    if num_classes is None:
        num_classes = n_classes
    # Derive the CRF image dims from the actual input shape rather than from the
    # notebook globals, so a non-default input_shape can't silently mismatch.
    image_dims = (input_shape[0], input_shape[1])

    inputs = Input(shape=input_shape)

    to_decoder = encoder(inputs)

    # Bridge: one more stride-2 residual unit below the deepest encoder level.
    path = res_block(to_decoder[2], [512, 512], [(2, 2), (1, 1)])

    path = decoder(path, from_encoder=to_decoder)

    # Per-pixel unary scores for the CRF. They must stay linear (logits). The
    # previous activation='relu' clamped every negative score to 0, so many pixels
    # reached the CRF with identical (all-zero) scores for every class.
    unaries = Conv2D(filters=num_classes, kernel_size=(1, 1))(path)

    # The raw input image is passed as the CRF's second input (presumably for its
    # appearance/bilateral kernel). NOTE(review): the reference CRF-RNN layer is
    # believed to handle a single image per batch; confirm before raising batch_size.
    crf_output = CrfRnnLayer(image_dims=image_dims,
                             num_classes=num_classes,
                             theta_alpha=160.,
                             theta_beta=3.,
                             theta_gamma=3.,
                             num_iterations=10,
                             name='crfrnn')([unaries, inputs])
    return get_segmentation_model(inputs, crf_output)

In [9]:
# Instantiate the ResUNet + CRF-RNN model for the configured input shape.
model = build_res_unet(input_shape=shape)

In [10]:
# Print the layer-by-layer architecture table (output shapes, parameter counts, connections).
model.summary()

Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 512, 512, 3)  0                                            
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 512, 512, 64) 1792        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_41 (BatchNo (None, 512, 512, 64) 256         conv2d_46[0][0]                  
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 512, 512, 64) 256         input_4[0][0]                    
____________________________________________________________________________________________

In [11]:
import os

# Dataset root. It can be overridden with the BP_C_DATA_ROOT environment variable,
# so the notebook is not tied to one machine; otherwise the original path is used.
DATA_ROOT = os.environ.get("BP_C_DATA_ROOT", "/Users/mavaylon/Research/Research_Gambier/Data")

# Each split directory is defined exactly once. The trailing "" keeps the trailing
# slash that the original paths had.
TRAIN_IMG_DIR = os.path.join(DATA_ROOT, "BP_C_train", "img", "")
TRAIN_ANN_DIR = os.path.join(DATA_ROOT, "BP_C_train", "ann", "")
VAL_IMG_DIR = os.path.join(DATA_ROOT, "BP_C_test", "img", "")
VAL_ANN_DIR = os.path.join(DATA_ROOT, "BP_C_test", "ann", "")

# Step counts come from the same directories that are passed to train(). With a
# batch size of 1, one step per image file makes each epoch a single full pass.
n_train_images = len(glob(os.path.join(TRAIN_IMG_DIR, "*")))
n_val_images = len(glob(os.path.join(VAL_IMG_DIR, "*")))

model.train(
    train_images=TRAIN_IMG_DIR,
    train_annotations=TRAIN_ANN_DIR,
    epochs=20,
    steps_per_epoch=n_train_images,
    batch_size=1,
    validate=True,
    val_images=VAL_IMG_DIR,
    val_annotations=VAL_ANN_DIR,
    val_batch_size=1,
    val_steps_per_epoch=n_val_images
)


  1%|▏         | 11/818 [00:00<00:07, 104.19it/s]

Verifying training dataset


100%|██████████| 818/818 [00:07<00:00, 109.11it/s]
  5%|▌         | 11/206 [00:00<00:01, 102.66it/s]

Dataset verified! 
Verifying validation dataset


100%|██████████| 206/206 [00:02<00:00, 101.17it/s]


Dataset verified! 
correct
Epoch 1/20
  4/818 [..............................] - ETA: 7:19:36 - loss: 0.2813 - accuracy: 0.9116

KeyboardInterrupt: 

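Helpful reference on residual shortcuts: https://towardsdatascience.com/understanding-and-coding-a-resnet-in-keras-446d7ff84d33. Key point: `add` requires both inputs to have the same shape. In a generic ResNet the shortcut is the identity when the shapes already match. When a block changes the number of channels or downsamples (stride 2), the shortcut instead gets its own 1×1 `Conv2D` (followed by `BatchNormalization`) to project the input to the matching shape. In this notebook, `res_block` always applies that 1×1 projection.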