In [12]:
from keras.layers import *
from keras.models import Model
from keras import layers
from keras.layers.merge import concatenate
import sys
sys.path.insert(1, '../src')
sys.path.insert(1, '../image_segmentation_keras')
from keras_segmentation.models.config import IMAGE_ORDERING

from keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob
from crfrnn_layer import CrfRnnLayer

In [13]:
input_height = 512
input_width = 512
n_classes = 2
channels = 3

In [None]:
#residual unet adapted from:
#https://github.com/nikhilroxtomar/Deep-Residual-Unet/blob/master/Deep%20Residual%20UNet.ipynb

In [14]:
def bn_act(x, act=True):
    x = BatchNormalization()(x)
    if act == True:
        x = Activation("relu")(x)
    return x

def conv_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = bn_act(x)
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(conv)
    return conv

def stem(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    conv = Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
    conv = conv_block(conv, filters, kernel_size=kernel_size, padding=padding, strides=strides)
    
    shortcut = Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)
    
    output = Add()([conv, shortcut])
    return output

def residual_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
    res = conv_block(x, filters, kernel_size=kernel_size, padding=padding, strides=strides)
    res = conv_block(res, filters, kernel_size=kernel_size, padding=padding, strides=1)
    
    shortcut = Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
    shortcut = bn_act(shortcut, act=False)
    
    output = Add()([shortcut, res])
    return output

def upsample_concat_block(x, xskip):
    u = UpSampling2D((2, 2))(x)
    c = Concatenate()([u, xskip])
    return c


In [24]:
img_input = Input(shape=(input_height,input_width, channels))
f = [16, 32, 64, 128, 256]
#inputs = keras.layers.Input((image_size, image_size, 3))

## Encoder
e0 = img_input
e1 = stem(e0, f[0])
e2 = residual_block(e1, f[1], strides=2)
e3 = residual_block(e2, f[2], strides=2)
e4 = residual_block(e3, f[3], strides=2)
e5 = residual_block(e4, f[4], strides=2)

## Bridge
b0 = conv_block(e5, f[4], strides=1)
b1 = conv_block(b0, f[4], strides=1)

## Decoder
u1 = upsample_concat_block(b1, e4)
d1 = residual_block(u1, f[4])

u2 = upsample_concat_block(d1, e3)
d2 = residual_block(u2, f[3])

u3 = upsample_concat_block(d2, e2)
d3 = residual_block(u3, f[2])

u4 = upsample_concat_block(d3, e1)
d4 = residual_block(u4, f[1])
outputs = Conv2D(n_classes, (1, 1), padding="same", activation="relu")(d4)
print(outputs)
print(n_classes)
crf_output = CrfRnnLayer(image_dims=(input_height, input_width),
                         num_classes=n_classes,
                         theta_alpha=160.,
                         theta_beta=3.,
                         theta_gamma=3.,
                         num_iterations=10,
                         name='crfrnn')([outputs, img_input])
model = get_segmentation_model(img_input, crf_output)
model.n_classes

Tensor("conv2d_267/Relu:0", shape=(None, 512, 512, 2), dtype=float32)
2


2

In [26]:
len(glob("../../data/BP_C_train/img/*"))

818

In [27]:
len(glob("../../data/BP_C_test/img/*"))

206

In [28]:
model.train(
    train_images = "../../data/BP_C_train/img/",
    train_annotations = "../../data/BP_C_train/ann/",
    epochs=20,
    steps_per_epoch=len(glob("../../data/BP_C_train/img/*")),
    batch_size=1,
    validate=True,
    val_images="../../data/BP_C_test/img/",
    val_annotations="../../data/BP_C_test/ann/",
    val_batch_size=1,
    val_steps_per_epoch=len(glob("../../data/BP_C_test/img/*"))
)

train.py n_classes from model.n_classes: 2


  1%|          | 10/818 [00:00<00:08, 92.25it/s]

Verifying training dataset


100%|██████████| 818/818 [00:08<00:00, 96.62it/s] 
  5%|▍         | 10/206 [00:00<00:02, 94.40it/s]

Dataset verified! 
Verifying validation dataset


100%|██████████| 206/206 [00:02<00:00, 87.15it/s]


Dataset verified! 
fit
Epoch 1/20
 74/818 [=>............................] - ETA: 1:46:03 - loss: 0.0825 - accuracy: 0.9711

KeyboardInterrupt: 