In [9]:
from keras.layers import *
from keras.models import Model
from keras import layers
from keras.layers.merge import concatenate
import sys
sys.path.insert(1, '../src')
sys.path.insert(1, '../image_segmentation_keras')
from keras_segmentation.models.config import IMAGE_ORDERING

from keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob
from crfrnn_layer import CrfRnnLayer

In [10]:
# Notebook-wide image geometry: channel count, then spatial height and width.
channels = 3
height = 256
width = 256


In [11]:
def unet_conv_block(inputs, filters, pool=True, batch_norm_first=True):
    """Apply two 3x3 same-padded convolutions, each with batch norm and ReLU.

    Args:
        inputs: Tensor fed to the first convolution.
        filters: Number of filters used by both convolutions.
        pool: When truthy, additionally apply 2x2 max pooling to the block
            output.
        batch_norm_first: When truthy, each stage runs
            conv -> batch norm -> ReLU; otherwise conv -> ReLU -> batch norm.

    Returns:
        ``[x, p]`` (the block output and its max-pooled version) when ``pool``
        is truthy, otherwise only the block output ``x``.
    """
    x = inputs
    # Both stages are identical, so build them in a loop instead of
    # duplicating the layer stack per ordering.
    for _ in range(2):
        x = Conv2D(filters, 3, padding="same")(x)
        # Plain truthiness (rather than `== True` / `== False`) guarantees one
        # of the two orderings is always applied, so `x` can never be left
        # unassigned for flag values such as None.
        if batch_norm_first:
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
        else:
            x = Activation("relu")(x)
            x = BatchNormalization()(x)

    if pool:
        return [x, MaxPooling2D((2, 2))(x)]
    return x

In [12]:
def _unet(n_classes, encoder, l1_skip_conn=True, input_height=416,
          input_width=608):

  
    img_input, levels = encoder(
        input_height=input_height, input_width=input_width)
    [f1, f2, f3, f4, f5, p5] = levels
    
    print("f5",f5.shape)

    x = p5
    
    """ Bridge """
    x = unet_conv_block(x, 1024, pool=False)
    
    x = UpSampling2D((2, 2))(x)
    x = concatenate([x, f5], axis=3)
    x = unet_conv_block(x, 512, pool=False, batch_norm_first=True)
    
    x = UpSampling2D((2, 2))(x)
    x = concatenate([x, f4], axis=3)
    x = unet_conv_block(x, 512, pool=False, batch_norm_first=True)

    x = UpSampling2D((2, 2))(x)
    x = concatenate([x, f3], axis=3)
    x = unet_conv_block(x, 256, pool=False, batch_norm_first=True)

    x = UpSampling2D((2, 2))(x)
    x = concatenate([x, f2], axis=3)
    x = unet_conv_block(x, 128, pool=False, batch_norm_first=True)

    x = UpSampling2D((2, 2))(x)
    x = concatenate([x, f1], axis=3)
    x = unet_conv_block(x, 64, pool=False, batch_norm_first=True)

    x = Conv2D(n_classes, (1, 1), padding='same')(x)

    crf_output = CrfRnnLayer(image_dims=(height, width),
                         num_classes=n_classes,
                         theta_alpha=160.,
                         theta_beta=3.,
                         theta_gamma=3.,
                         num_iterations=10,
                         name='crfrnn')([x, img_input])
    model = get_segmentation_model(img_input, crf_output)

    return model

In [13]:
# Axis index matching the configured data layout: 1 for 'channels_first',
# -1 for 'channels_last'. No value is assigned for any other layout string.
if IMAGE_ORDERING in ('channels_first', 'channels_last'):
    MERGE_AXIS = 1 if IMAGE_ORDERING == 'channels_first' else -1
def get_vgg_encoder(input_height=224,  input_width=224, pretrained='imagenet'):
    """Assemble five VGG-style convolution blocks, each followed by 2x2 pooling.

    Args:
        input_height: Input height; must be a multiple of 32.
        input_width: Input width; must be a multiple of 32.
        pretrained: Accepted for interface compatibility; this function never
            reads it.

    Returns:
        ``(img_input, [f1, f2, f3, f4, f5, p5])`` where ``f1``..``f5`` are the
        outputs of the last convolution in blocks 1-5 and ``p5`` is the
        pooled output of block 5.

    Raises:
        AssertionError: If either dimension is not divisible by 32.
    """
    # Five stride-2 poolings halve each spatial dimension five times.
    assert input_height % 32 == 0
    assert input_width % 32 == 0

    if IMAGE_ORDERING == 'channels_first':
        img_input = Input(shape=(3, input_height, input_width))
    elif IMAGE_ORDERING == 'channels_last':
        img_input = Input(shape=(input_height, input_width, 3))

    # (filters, number of stacked 3x3 convolutions) for blocks 1 through 5.
    block_specs = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    stage_outputs = []
    tensor = img_input
    for block_no, (filters, depth) in enumerate(block_specs, start=1):
        for conv_no in range(1, depth + 1):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same',
                            name='block%d_conv%d' % (block_no, conv_no),
                            data_format=IMAGE_ORDERING)(tensor)
        # Keep the pre-pooling feature map; the pooled tensor feeds the next
        # block.
        stage_outputs.append(tensor)
        tensor = MaxPooling2D((2, 2), strides=(2, 2),
                              name='block%d_pool' % block_no,
                              data_format=IMAGE_ORDERING)(tensor)

    # After the loop `tensor` is the pooled output of block 5 (p5).
    stage_outputs.append(tensor)
    return img_input, stage_outputs

In [14]:
def vgg_unet(n_classes, input_height=416, input_width=608, encoder_level=3):
    """Build the U-Net + CRF-RNN model using the VGG encoder defined above.

    Args:
        n_classes: Forwarded to ``_unet``.
        input_height: Forwarded to ``_unet``.
        input_width: Forwarded to ``_unet``.
        encoder_level: Accepted for interface compatibility; this function
            never reads it.

    Returns:
        The model produced by ``_unet`` with ``model_name`` set to
        ``"vgg_unet"``.
    """
    model = _unet(n_classes, get_vgg_encoder,
                  input_height=input_height, input_width=input_width)
    model.model_name = "vgg_unet"
    # The original final line held two `return model` statements on one line,
    # which is a syntax error; a single return is all that is needed.
    return model

In [15]:
# Instantiate the 3-class network for 256x256 inputs and print its layer table.
model = vgg_unet(
    n_classes=3,
    input_height=256,
    input_width=256,
)
model.summary()
# Optional: restore previously saved weights.
# model.load_weights('/Users/mavaylon/Research/pet_weights/VGG_CRF_PET/pet_class_crf.h5')

f5 (None, 16, 16, 512)
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 256, 256, 64) 1792        input_2[0][0]                    
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 256, 256, 64) 36928       block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_pool (MaxPooling2D)      (None, 128, 128, 64) 0           block1_conv2[0][0]               
_____________________________________________________________________

In [8]:
# Dataset directories, built from one root so the image/annotation folders and
# the step counts below always refer to the same locations.
data_root = "/Users/mavaylon/Research/Data1/"
train_img_dir = data_root + "train/img/"
train_ann_dir = data_root + "train/ann/"
val_img_dir = data_root + "test/img/"
val_ann_dir = data_root + "test/ann/"

# With a batch size of 1, one step per file found in each image directory.
train_steps = len(glob(train_img_dir + "*"))
val_steps = len(glob(val_img_dir + "*"))

model.train(
    train_images=train_img_dir,
    train_annotations=train_ann_dir,
    epochs=20,
    steps_per_epoch=train_steps,
    batch_size=1,
    validate=True,
    val_images=val_img_dir,
    val_annotations=val_ann_dir,
    val_batch_size=1,
    val_steps_per_epoch=val_steps,
)

Verifying training dataset


100%|██████████| 5912/5912 [00:18<00:00, 311.94it/s]
  0%|          | 0/1478 [00:00<?, ?it/s]

Dataset verified! 
Verifying validation dataset


100%|██████████| 1478/1478 [00:05<00:00, 281.29it/s]


Dataset verified! 
fit
Epoch 1/20

Epoch 00001: val_accuracy improved from -inf to 0.86069, saving model to pet_class_crf.h5
Epoch 2/20

Epoch 00002: val_accuracy improved from 0.86069 to 0.87200, saving model to pet_class_crf.h5
Epoch 3/20

Epoch 00003: val_accuracy improved from 0.87200 to 0.87277, saving model to pet_class_crf.h5
Epoch 4/20

Epoch 00004: val_accuracy did not improve from 0.87277
Epoch 5/20

Epoch 00005: val_accuracy did not improve from 0.87277
Epoch 6/20

Epoch 00006: val_accuracy did not improve from 0.87277
Epoch 7/20

KeyboardInterrupt: 