In [17]:
import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as keras_backend
from tensorflow.python.keras.layers import Layer, Input, Dense, Conv2D, Activation, BatchNormalization, Reshape
import pandas as pd
from os import walk as walk


In [18]:
PATH_DATA = './ImageSegData/'
DIR_LABELS = 'labels'
DIR_IMAGES = 'images'

PATH_LABELS = PATH_DATA+DIR_LABELS
PATH_IMAGES = PATH_DATA+DIR_IMAGES


# SegNet
Link to paper: https://arxiv.org/abs/1511.00561

## Custom Layers


In [19]:

class MaxPoolingWithArgmax2D(Layer):       
    """Keras Wrapper around tf.nn.max_pool_with_argmax.
    Return the pooled output and the corresponding max indices
    """
    def __init__(self, pool_size=(2,2), strides=(2,2), padding='same', **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.strides = strides
        self.padding = padding
        
    def call(self, inputs, **kwargs):
        padding = self.padding
        pool_size = self.pool_size
        strides = self.strides
        ksize = [1, pool_size[0], pool_size[1], 1] # byxc
        padding = padding.upper()
        strides = [1, strides[0], strides[1], 1]
        output, argmax = tf.nn.max_pool_with_argmax(
                                                    inputs,
                                                    ksize=ksize,
                                                    strides=strides,
                                                    padding=padding)
        argmax = tf.cast(argmax, keras_backend.floatx())
        return [output, argmax]
    
class MaxUpsampling2D(Layer):       
    """Keras Wrapper around tf.nn.max_pool_with_argmax.
    Return the pooled output and the corresponding max indices
    """
    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUpsampling2D, self).__init__(**kwargs)
        self.size = size
        
    def call(self, inputs, output_shape=None):
        updates, mask = inputs[0], inputs[1]
        with tf.variable_scope(self.name):
            mask = keras_backend.cast(mask, 'int32')
            input_shape = tf.shape(updates, out_type='int32')
            #  Calculate new shape
            if output_shape is None:
                output_shape = (
                        input_shape[0],
                        input_shape[1]*self.size[0],
                        input_shape[2]*self.size[1],
                        input_shape[3])
            self.output_shape1 = output_shape

            # calculation indices for batch, height, width and feature maps
            one_like_mask = tf.ones_like(mask, dtype='int32')
            batch_shape = tf.concat(
                    [[input_shape[0]], [1], [1], [1]],
                    axis=0)
            batch_range = tf.reshape(
                    tf.range(output_shape[0], dtype='int32'),
                    shape=batch_shape)
            b = one_like_mask * batch_range
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype='int32')
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = tf.transpose(keras_backend.reshape(
                tf.stack([b, y, x, f]),
                [4, updates_size]))
            values = tf.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret
        

In [20]:
def build_segnet(input_shape, num_classes, kernel_size=3, pool_size=(2,2)):
    # Encoder pass
    inputs = Input(shape=input_shape)
    
    conv_1 = Conv2D(filters=64, kernel_size=kernel_size, padding='same')(inputs)
    batch_norm_1 = BatchNormalization()(conv_1)
    act_1 = Activation('relu')(batch_norm_1)
    conv_2 = Conv2D(filters=64, kernel_size=kernel_size, padding='same')(act_1)
    batch_norm_2 = BatchNormalization()(conv_2)
    act_2 = Activation('relu')(batch_norm_2)
    max_pool_1, max_indices_1 = MaxPoolingWithArgmax2D(pool_size)(act_2)
    
    conv_3 = Conv2D(filters=128, kernel_size=kernel_size, padding='same')(max_pool_1)
    batch_norm_3 = BatchNormalization()(conv_3)
    act_3 = Activation('relu')(batch_norm_3)
    conv_4 = Conv2D(filters=128, kernel_size=kernel_size, padding='same')(act_3)
    batch_norm_4 = BatchNormalization()(conv_4)
    act_4 = Activation('relu')(batch_norm_4)
    max_pool_2, max_indices_2 = MaxPoolingWithArgmax2D(pool_size)(act_4)
    
    conv_5 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(max_pool_2)
    batch_norm_5 = BatchNormalization()(conv_5)
    act_5 = Activation('relu')(batch_norm_5)
    conv_6 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(act_5)
    batch_norm_6 = BatchNormalization()(conv_6)
    act_6 = Activation('relu')(batch_norm_6)
    conv_7 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(act_6)
    batch_norm_7 = BatchNormalization()(conv_7)
    act_7 = Activation('relu')(batch_norm_7)
    max_pool_3, max_indices_3 = MaxPoolingWithArgmax2D(pool_size)(act_7)
    
    conv_8 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(max_pool_3)
    batch_norm_8 = BatchNormalization()(conv_8)
    act_8 = Activation('relu')(batch_norm_8)
    conv_9 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_8)
    batch_norm_9 = BatchNormalization()(conv_9)
    act_9 = Activation('relu')(batch_norm_9)
    conv_10 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_9)
    batch_norm_10 = BatchNormalization()(conv_10)
    act_10 = Activation('relu')(batch_norm_10)
    max_pool_4, max_indices_4 = MaxPoolingWithArgmax2D(pool_size)(act_10)
    
    conv_11 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(max_pool_4)
    batch_norm_11 = BatchNormalization()(conv_11)
    act_11 = Activation('relu')(batch_norm_11)
    conv_12 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_11)
    batch_norm_12 = BatchNormalization()(conv_12)
    act_12 = Activation('relu')(batch_norm_12)
    conv_13 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_12)
    batch_norm_13 = BatchNormalization()(conv_13)
    act_13 = Activation('relu')(batch_norm_13)
    max_pool_5, max_indices_5 = MaxPoolingWithArgmax2D(pool_size)(act_13)
    
    # Decoder pass
    upsampling_1 = MaxUpsampling2D(pool_size)([max_pool_5, max_indices_5])
    conv_14 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(upsampling_1)
    batch_norm_14 = BatchNormalization()(conv_14)
    act_14 = Activation('relu')(batch_norm_14)
    conv_15 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_14)
    batch_norm_15 = BatchNormalization()(conv_15)
    act_15 = Activation('relu')(batch_norm_15)
    conv_16 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_15)
    batch_norm_16 = BatchNormalization()(conv_16)
    act_16 = Activation('relu')(batch_norm_16)
    
    upsampling_2 = MaxUpsampling2D(pool_size)([act_16, max_indices_4])
    conv_17 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(upsampling_2)
    batch_norm_17 = BatchNormalization()(conv_17)
    act_17 = Activation('relu')(batch_norm_17)
    conv_18 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_17)
    batch_norm_18 = BatchNormalization()(conv_18)
    act_18 = Activation('relu')(batch_norm_18)
    conv_19 = Conv2D(filters=512, kernel_size=kernel_size, padding='same')(act_18)
    batch_norm_19 = BatchNormalization()(conv_19)
    act_19 = Activation('relu')(batch_norm_19)
    
    upsampling_3 = MaxUpsampling2D(pool_size)([act_19, max_indices_3])
    conv_20 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(upsampling_3)
    batch_norm_20 = BatchNormalization()(conv_20)
    act_20 = Activation('relu')(batch_norm_20)
    conv_21 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(act_20)
    batch_norm_21 = BatchNormalization()(conv_21)
    act_21 = Activation('relu')(batch_norm_21)
    conv_22 = Conv2D(filters=256, kernel_size=kernel_size, padding='same')(act_21)
    batch_norm_22 = BatchNormalization()(conv_22)
    act_22 = Activation('relu')(batch_norm_22)
    
    upsampling_4 = MaxUpsampling2D(pool_size)([act_22, max_indices_2])
    conv_23 = Conv2D(filters=128, kernel_size=kernel_size, padding='same')(upsampling_4)
    batch_norm_23 = BatchNormalization()(conv_23)
    act_23 = Activation('relu')(batch_norm_23)
    conv_24 = Conv2D(filters=128, kernel_size=kernel_size, padding='same')(act_23)
    batch_norm_24 = BatchNormalization()(conv_24)
    act_24 = Activation('relu')(batch_norm_24)
    
    upsampling_5 = MaxUpsampling2D(pool_size)([act_24, max_indices_1])
    conv_25 = Conv2D(filters=64, kernel_size=kernel_size, padding='same')(upsampling_3)
    batch_norm_25 = BatchNormalization()(conv_25)
    act_25 = Activation('relu')(batch_norm_25)
    conv_26 = Conv2D(filters=num_classes, kernel_size=(1,1), padding='valid')(act_25)
    batch_norm_26 = BatchNormalization()(conv_26)
    
    outputs = Reshape(target_shape=(input_shape[0]*input_shape[1], num_classes),
                      input_shape=(input_shape[0], input_shape[1], num_classes))(batch_norm_26)
    
    outputs = Activation('relu')(outputs)
    
    return Model(inputs=inputs, outputs=outputs, name="SegNet")
       

# Pipeline


In [21]:
batch_size=16
labels = []
for (dirpath, dirnames, filenames) in walk(PATH_LABELS):
    for name in filenames:
        path = dirpath+'/'+name
        labels.append(path)

images = []
for (dirpath, dirnames, filenames) in walk(PATH_IMAGES):
    for name in filenames:
        path = dirpath+'/'+name
        images.append(path)

# Create a dataset (image, label)-pairs
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def load_and_preprocess_from_paths(image_path, label_path):
    image_raw = tf.read_file(image_path)
    label_raw = tf.read_file(label_path)

    # TODO: PREPROCESSING (Normalization..)
    image = tf.image.decode_png(image_raw, channels=3)
    label = tf.image.decode_png(label_raw, channels=3) 
    return image, label

image_label_ds = ds.map(load_and_preprocess_from_paths)
image_label_ds = image_label_ds.shuffle(buffer_size=3)
image_label_ds = image_label_ds.repeat()
image_label_ds = image_label_ds.batch(batch_size)
#image_label_ds = image_label_ds.prefetch(buffer_size=AUTOTUNE)

# create general iterator
iterator = tf.data.Iterator.from_structure(image_label_ds.output_types,
                                       image_label_ds.output_shapes)

next_element = iterator.get_next()
training_init_op = iterator.make_initializer(image_label_ds)

# with tf.Session() as session:
#     session.run(tf.local_variables_initializer())
#     session.run(tf.global_variables_initializer())
#     session.run(training_init_op)
# 
#     for step in range(max_steps):
#         feed_dict = {self.inputs_pl: next_element}
#         print(session.run(next_element))       
         

# Train SegNet
                

In [24]:
segnet = build_segnet((640,480,1), 5)
segnet.summary()

ValueError: The channel dimension of the inputs should be defined. Found `None`.

In [25]:
segnet.fit(iterator, steps_per_epoch=len(images), epochs=1, verbose=1)


NameError: name 'segnet' is not defined