In [6]:
from keras.layers import Conv2D, MaxPooling2D, Input, ZeroPadding2D, Input, Dropout, Conv2DTranspose, Cropping2D, Add, UpSampling2D, BatchNormalization, Activation
from keras.models import Model
from keras.layers.merge import concatenate
from image_segmentation_keras.keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob

import sys
sys.path.insert(1, './src')
from crfrnn_layer import CrfRnnLayer

In [7]:
import keras
import tensorflow
from tensorflow import keras as k

# Report standalone Keras, TensorFlow, and tf.keras versions side by side,
# since this notebook mixes `keras.*` imports with a TensorFlow backend.
print(*(module.__version__ for module in (keras, tensorflow, k)))

2.3.1 2.2.0 2.3.0-tf


In [8]:
def conv_block(inputs, filters, pool=True):
    """Apply two 3x3 same-padded convolutions, each followed by ReLU then BatchNormalization.

    Args:
        inputs: Keras tensor to convolve.
        filters: Number of filters used by both convolutions.
        pool: If truthy, also apply a 2x2 max-pool. Encoder stages need both the
            pre-pool features (for the skip connection) and the pooled tensor
            (input to the next stage).

    Returns:
        ``(features, pooled)`` when ``pool`` is truthy, otherwise ``features`` alone.
    """
    x = inputs
    # Same layer order as two hand-written stanzas: Conv -> ReLU -> BN, twice.
    for _ in range(2):
        x = Conv2D(filters, 3, padding="same")(x)
        # In this variant BatchNormalization comes *after* the activation.
        x = Activation("relu")(x)
        x = BatchNormalization()(x)

    if pool:
        return x, MaxPooling2D((2, 2))(x)
    return x
def build_unet(shape, num_classes, input_dropout=0.0):
    """Build a 4-level U-Net whose per-pixel scores are refined by a CRF-as-RNN layer.

    Args:
        shape: ``(height, width, channels)`` of the input image. Height and width
            should be divisible by 16 so that each stride-2 transposed conv output
            matches its skip connection (four 2x2 poolings, four 2x upsamplings).
        num_classes: Number of segmentation classes (channels of the score map).
        input_dropout: Dropout rate applied to the input image before the encoder.
            Defaults to 0.0, meaning no input dropout layer is added.

    Returns:
        The model produced by ``get_segmentation_model(img_input, crf_output)``.
    """
    img_input = Input(shape)
    # The CRF needs the image's spatial size; take it from `shape` instead of
    # notebook globals so this function works on a fresh kernel and any shape.
    input_height, input_width = shape[0], shape[1]

    encoder_input = Dropout(input_dropout)(img_input) if input_dropout else img_input

    # Encoder: keep each level's pre-pool features for the skip connections.
    x1, p1 = conv_block(encoder_input, 64, pool=True)
    x2, p2 = conv_block(p1, 128, pool=True)
    x3, p3 = conv_block(p2, 256, pool=True)
    x4, p4 = conv_block(p3, 512, pool=True)

    # Bridge
    b1 = conv_block(p4, 1024, pool=False)

    # Decoder: upsample 2x, concatenate with the matching encoder level along
    # the channel axis (axis=3, channels-last), then convolve.
    u1 = Conv2DTranspose(512, (2, 2), strides=2, padding='same')(b1)
    c1 = concatenate([u1, x4], axis=3)
    x5 = conv_block(c1, 512, pool=False)

    u2 = Conv2DTranspose(256, (2, 2), strides=2, padding='same')(x5)
    c2 = concatenate([u2, x3], axis=3)
    x6 = conv_block(c2, 256, pool=False)

    # NOTE(review): u3 and u4 use 256 filters, whereas the halving pattern
    # (512, 256, ...) suggests 128 and 64. Kept as-is so existing checkpoints
    # still load; confirm whether this was intentional.
    u3 = Conv2DTranspose(256, (2, 2), strides=2, padding='same')(x6)
    c3 = concatenate([u3, x2], axis=3)
    x7 = conv_block(c3, 128, pool=False)

    u4 = Conv2DTranspose(256, (2, 2), strides=2, padding='same')(x7)
    c4 = concatenate([u4, x1], axis=3)
    x8 = conv_block(c4, 64, pool=False)

    # Per-pixel class scores, fed to the CRF as unaries.
    # NOTE(review): ReLU clips negative scores before the CRF; CRF-as-RNN
    # typically takes unbounded (linear) scores here — confirm this is intended.
    output = Conv2D(num_classes, 1, padding="same", activation="relu")(x8)

    # The CRF also receives the raw input image (not the dropout-augmented one).
    crf_output = CrfRnnLayer(
        image_dims=(input_height, input_width),
        num_classes=num_classes,
        theta_alpha=160.,
        theta_beta=3.,
        theta_gamma=3.,
        num_iterations=10,
        name='crfrnn',
    )([output, img_input])

    return get_segmentation_model(img_input, crf_output)


In [9]:
# Model configuration: 256x256 inputs with 3 channels, 3 segmentation classes.
shape = (256, 256, 3)
num_classes = 3
# Derive the spatial dims from `shape` instead of repeating the literals,
# so the two cannot drift apart when the input size changes.
input_height, input_width = shape[:2]
model = build_unet(shape, num_classes)


In [5]:
# Training run on the shuffled dataset.
# NOTE(review): absolute, user-specific path — move to a configurable DATA_DIR.
DATA_ROOT = "/Users/mavaylon/Research/Data1"
TRAIN_IMG_DIR = f"{DATA_ROOT}/train/img/"
TRAIN_ANN_DIR = f"{DATA_ROOT}/train/ann/"
VAL_IMG_DIR = f"{DATA_ROOT}/test/img/"
VAL_ANN_DIR = f"{DATA_ROOT}/test/ann/"
BATCH_SIZE = 1
VAL_BATCH_SIZE = 1

n_train_images = len(glob(TRAIN_IMG_DIR + "*"))
n_val_images = len(glob(VAL_IMG_DIR + "*"))
# Fail early with a clear message if a directory is empty or the path is wrong.
assert n_train_images > 0, f"no training images found in {TRAIN_IMG_DIR}"
assert n_val_images > 0, f"no validation images found in {VAL_IMG_DIR}"

model.train(
    train_images=TRAIN_IMG_DIR,
    train_annotations=TRAIN_ANN_DIR,
    epochs=20,
    # ceil(images / batch) steps, i.e. roughly one pass over the split per epoch
    steps_per_epoch=(n_train_images + BATCH_SIZE - 1) // BATCH_SIZE,
    batch_size=BATCH_SIZE,
    validate=True,
    val_images=VAL_IMG_DIR,
    val_annotations=VAL_ANN_DIR,
    val_batch_size=VAL_BATCH_SIZE,
    val_steps_per_epoch=(n_val_images + VAL_BATCH_SIZE - 1) // VAL_BATCH_SIZE,
)

Verifying training dataset


100%|██████████| 5912/5912 [00:17<00:00, 337.18it/s]
  2%|▏         | 29/1478 [00:00<00:05, 283.65it/s]

Dataset verified! 
Verifying validation dataset


100%|██████████| 1478/1478 [00:04<00:00, 317.43it/s]


Dataset verified! 
correct
Epoch 1/20

Epoch 00001: val_accuracy improved from -inf to 0.54860, saving model to pet_class_crf.h5
Epoch 2/20

Epoch 00002: val_accuracy improved from 0.54860 to 0.69533, saving model to pet_class_crf.h5
Epoch 3/20

Epoch 00003: val_accuracy did not improve from 0.69533
Epoch 4/20
  73/5912 [..............................] - ETA: 4:42:41 - loss: 0.3610 - accuracy: 0.8678

KeyboardInterrupt: 

In [10]:
# NOTE(review): this load fails (ValueError below: the file has 32 layers, the
# model has 42), so the checkpoint was presumably saved from a different
# build_unet variant — TODO confirm which. Absolute local path; make configurable.
model.load_weights('/Users/mavaylon/Research/pet_weights/unet_petcrf/unet__shuffled_pet_class_crf_bn_after_bothconv.h5')

ValueError: You are trying to load a weight file containing 32 layers into a model with 42 layers.

!pip list

# NOTE(review): confirm this checkpoint was saved from the same architecture as
# `model`; otherwise load_weights raises the layer-count ValueError seen above.
model.load_weights("/Users/mavaylon/Research/LBNL_Segmentation_crf/unet_pet_class.h5")

model.trai