# SegNet

In [None]:
import _02c_read_datasets
import _02_evaluate_model

import os
import numpy as np
import itertools
import matplotlib.pyplot as plt
import tensorflow as tf
from datetime import datetime
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
from keras import backend as K
K.set_image_data_format('channels_last')
from tensorflow.keras.models import Model
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
import pandas as pd
from tensorflow import keras
from tensorflow.keras import *
from keras.layers.convolutional import Convolution2D
from keras.layers.core import Activation, Reshape
from keras.models import Model

In [None]:
# ----------- run configuration and data loading
epochs = 20
batches = 16
input_width = 256
# Square input tiles with 4 channels.
input_shape = (input_width, input_width, 4)
shuffled = True
augment = False  # {True, False}

# In this notebook the settings dict is only forwarded to the evaluation
# helper further down; load_datasets itself only receives the `augment` flag.
# NOTE(review): the numeric values are presumably per-transform
# probabilities/strengths -- confirm against _02c_read_datasets.
augmentation_settings = (
    {
        "flip_left_right": 0,
        "flip_up_down": 0,
        "gaussian_blur": 0.2,
        "random_noise": 0.0,
        "random_brightness": 0.5,
        "random_contrast": 0.5,
    }
    if augment
    else None
)

train_dataset, val_dataset, test_dataset = _02c_read_datasets.load_datasets(
    augmented=augment
)

In [3]:
# Layers
class MaxPoolingWithArgmax2D(Layer):
    """Max pooling that also returns where each maximum was found.

    Wraps ``tf.nn.max_pool_with_argmax`` so the pooling indices can later be
    consumed by ``MaxUnpooling2D``.

    Args:
        pool_size: ``(height, width)`` of the pooling window.
        strides: ``(height, width)`` step of the pooling window.
        padding: ``"same"`` or ``"valid"`` (case-insensitive).
        **kwargs: forwarded to ``Layer.__init__``.
    """

    def __init__(self, pool_size=(2, 2), strides=(2, 2), padding="same", **kwargs):
        super(MaxPoolingWithArgmax2D, self).__init__(**kwargs)
        self.padding = padding
        self.pool_size = pool_size
        self.strides = strides

    def call(self, inputs, **kwargs):
        """Pool ``inputs`` and return ``[pooled, argmax]``.

        Raises:
            NotImplementedError: if the Keras backend is not TensorFlow.
        """
        if K.backend() != "tensorflow":
            errmsg = "{} backend is not supported for layer {}".format(
                K.backend(), type(self).__name__
            )
            raise NotImplementedError(errmsg)

        # The TF op expects 4-element window/stride specs (batch, h, w, channel).
        ksize = [1, self.pool_size[0], self.pool_size[1], 1]
        strides = [1, self.strides[0], self.strides[1], 1]
        # Use the module-level `tf` import rather than `K.tf`, which is only an
        # undocumented attribute of the Keras backend module.
        output, argmax = tf.nn.max_pool_with_argmax(
            inputs, ksize=ksize, strides=strides, padding=self.padding.upper()
        )
        # The indices are returned as floats so both outputs share one dtype.
        # NOTE(review): float32 represents integers exactly only up to 2**24;
        # confirm the flattened index range of the largest feature map fits.
        argmax = K.cast(argmax, K.floatx())
        return [output, argmax]

    def compute_output_shape(self, input_shape):
        """Return the (identical) static shapes of the pooled tensor and argmax.

        Spatial sizes follow the TensorFlow pooling rules:
        ``ceil(dim / stride)`` for "same" padding and
        ``(dim - window) // stride + 1`` for "valid" padding.
        Unknown (``None``) dimensions stay ``None``.
        """
        windows = (1, self.pool_size[0], self.pool_size[1], 1)
        steps = (1, self.strides[0], self.strides[1], 1)
        valid = self.padding.upper() == "VALID"

        def pooled(dim, window, step):
            if dim is None:
                return None
            if valid:
                return (dim - window) // step + 1
            return -(-dim // step)  # integer ceil(dim / step)

        output_shape = tuple(
            pooled(dim, window, step)
            for dim, window, step in zip(tuple(input_shape), windows, steps)
        )
        return [output_shape, output_shape]

    def compute_mask(self, inputs, mask=None):
        """Neither output carries a Keras mask."""
        return 2 * [None]

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(MaxPoolingWithArgmax2D, self).get_config()
        config.update(
            {
                "pool_size": self.pool_size,
                "strides": self.strides,
                "padding": self.padding,
            }
        )
        return config


class MaxUnpooling2D(Layer):
    """Inverse of ``MaxPoolingWithArgmax2D``: scatter values back to argmax positions.

    Called on ``[updates, mask]`` where ``mask`` holds the flattened pooling
    indices. Every value of ``updates`` is written to the location given by its
    index in an otherwise zero tensor of the up-sampled shape.

    Args:
        size: ``(height, width)`` up-sampling factors.
        **kwargs: forwarded to ``Layer.__init__``.
    """

    def __init__(self, size=(2, 2), **kwargs):
        super(MaxUnpooling2D, self).__init__(**kwargs)
        self.size = size

    def call(self, inputs, output_shape=None):
        """Scatter ``inputs[0]`` into a tensor of ``output_shape``.

        Args:
            inputs: ``[updates, mask]``; ``mask`` is cast to int32 here.
            output_shape: optional explicit 4-element target shape. Defaults to
                the shape of ``updates`` with height/width multiplied by ``size``.
        """
        updates, mask = inputs[0], inputs[1]
        # No variables are created here, so a plain name scope is enough; `tf`
        # is the module-level import (not the undocumented `K.tf` attribute).
        with tf.name_scope(self.name):
            mask = K.cast(mask, "int32")
            input_shape = tf.shape(updates, out_type="int32")
            #  calculation new shape
            if output_shape is None:
                output_shape = (
                    input_shape[0],
                    input_shape[1] * self.size[0],
                    input_shape[2] * self.size[1],
                    input_shape[3],
                )
            # Kept for backward compatibility with code that may read it.
            self.output_shape1 = output_shape

            # calculation indices for batch, height, width and feature maps
            one_like_mask = K.ones_like(mask, dtype="int32")
            batch_shape = K.concatenate([[input_shape[0]], [1], [1], [1]], axis=0)
            batch_range = K.reshape(
                tf.range(output_shape[0], dtype="int32"), shape=batch_shape
            )
            b = one_like_mask * batch_range
            # Decode the flattened index as (y * width + x) * channels + c,
            # with width/channels taken from the target shape.
            y = mask // (output_shape[2] * output_shape[3])
            x = (mask // output_shape[3]) % output_shape[2]
            feature_range = tf.range(output_shape[3], dtype="int32")
            f = one_like_mask * feature_range

            # transpose indices & reshape update values to one dimension
            updates_size = tf.size(updates)
            indices = K.transpose(K.reshape(K.stack([b, y, x, f]), [4, updates_size]))
            values = K.reshape(updates, [updates_size])
            ret = tf.scatter_nd(indices, values, output_shape)
            return ret

    def compute_output_shape(self, input_shape):
        """Return the up-sampled static shape derived from the mask's shape.

        Unknown (``None``) dimensions are passed through instead of raising.
        """
        mask_shape = input_shape[1]

        def scaled(dim, factor):
            return None if dim is None else dim * factor

        return (
            mask_shape[0],
            scaled(mask_shape[1], self.size[0]),
            scaled(mask_shape[2], self.size[1]),
            mask_shape[3],
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(MaxUnpooling2D, self).get_config()
        config.update({"size": self.size})
        return config

In [4]:
def segnet(input_shape, n_labels, kernel=3, pool_size=(2, 2), output_mode="softmax"):
    """Build a SegNet-style encoder/decoder segmentation model.

    The encoder has five stages of conv -> batch-norm -> ReLU blocks, each
    closed by ``MaxPoolingWithArgmax2D``. The decoder mirrors it with five
    ``MaxUnpooling2D`` stages that reuse the encoder's pooling indices in
    reverse order. A 1x1 convolution + batch-norm + ``output_mode`` activation
    forms the head.

    Args:
        input_shape: shape of one input sample, without the batch axis.
            NOTE(review): height/width presumably need to be divisible by 32
            (five 2x poolings) -- not checked here.
        n_labels: number of output channels of the final 1x1 convolution.
        kernel: side length of the square convolution kernels.
        pool_size: window passed to every pooling / unpooling layer.
        output_mode: name of the final activation (e.g. "softmax", "sigmoid").

    Returns:
        The Keras ``Model`` named "SegNet" (not compiled here).
    """

    def conv_bn_relu(tensor, filters):
        # One encoder/decoder building block.
        tensor = Convolution2D(filters, (kernel, kernel), padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    # Filter counts of the conv blocks in each stage, in creation order.
    encoder_stages = (
        (64, 64),
        (128, 128),
        (256, 256, 256),
        (512, 512, 512),
        (512, 512, 512),
    )
    decoder_stages = (
        (512, 512, 512),
        (512, 512, 256),
        (256, 256, 128),
        (128, 64),
        (64,),
    )

    # encoder
    inputs = Input(shape=input_shape)
    features = inputs
    pooling_masks = []
    for stage_filters in encoder_stages:
        for filters in stage_filters:
            features = conv_bn_relu(features, filters)
        features, argmax = MaxPoolingWithArgmax2D(pool_size)(features)
        pooling_masks.append(argmax)
    print("Build enceder done..")

    # decoder: consume the pooling indices last-in, first-out
    for stage_filters in decoder_stages:
        features = MaxUnpooling2D(pool_size)([features, pooling_masks.pop()])
        for filters in stage_filters:
            features = conv_bn_relu(features, filters)

    # classification head
    head = Convolution2D(n_labels, (1, 1), padding="valid")(features)
    head = BatchNormalization()(head)
    outputs = Activation(output_mode)(head)
    print("Build decoder done..")

    return Model(inputs=inputs, outputs=outputs, name="SegNet")

# Single output channel with a sigmoid activation; the layer table is printed below.
model = segnet(
    input_shape,
    n_labels=1,
    kernel=3,
    pool_size=(2, 2),
    output_mode="sigmoid",
)
model.summary()

Build enceder done..
Build decoder done..
Model: "SegNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 4  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 256, 256, 64  2368        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 256, 256, 64  256        ['conv2d[0][0]']                 
 alization)                     )                  

 rmalization)                                                                                     
                                                                                                  
 activation_7 (Activation)      (None, 32, 32, 512)  0           ['batch_normalization_7[0][0]']  
                                                                                                  
 conv2d_8 (Conv2D)              (None, 32, 32, 512)  2359808     ['activation_7[0][0]']           
                                                                                                  
 batch_normalization_8 (BatchNo  (None, 32, 32, 512)  2048       ['conv2d_8[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_8 (Activation)      (None, 32, 32, 512)  0           ['batch_normalization_8[0][0]']  
          

 activation_16 (Activation)     (None, 32, 32, 512)  0           ['batch_normalization_16[0][0]'] 
                                                                                                  
 conv2d_17 (Conv2D)             (None, 32, 32, 512)  2359808     ['activation_16[0][0]']          
                                                                                                  
 batch_normalization_17 (BatchN  (None, 32, 32, 512)  2048       ['conv2d_17[0][0]']              
 ormalization)                                                                                    
                                                                                                  
 activation_17 (Activation)     (None, 32, 32, 512)  0           ['batch_normalization_17[0][0]'] 
                                                                                                  
 conv2d_18 (Conv2D)             (None, 32, 32, 256)  1179904     ['activation_17[0][0]']          
          

Total params: 29,459,525
Trainable params: 29,443,651
Non-trainable params: 15,874
__________________________________________________________________________________________________


In [None]:
# ----------- create directories
# Every run writes into its own timestamped folder under ../results/.
out_dir = '../results/' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + '_SegNet/'
# exist_ok=True (instead of one os.path.exists guard on out_dir) creates any
# missing sub-folder even when out_dir itself already exists.
for subdir in ('plots', 'weights', 'predictions', 'bestweights'):
    os.makedirs(os.path.join(out_dir, subdir), exist_ok=True)


# Define the path where you want to save the weights
checkpoint_path = out_dir + 'bestweights/'

# Define the ModelCheckpoint callback: keep only the weights of the epoch
# with the highest validation accuracy seen so far.
checkpoint = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
    verbose=1
)

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"]
)
# `history` is read later to plot the loss/accuracy curves.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[checkpoint],
)

In [None]:
# ---------------------- save results


# Load the saved, optimal  weights
model.load_weights(checkpoint_path)

# Compile the model with the same optimizer and loss function used during training
model.compile(optimizer=keras.optimizers.Adam(),
    loss='binary_crossentropy',
    metrics=["accuracy"])
model.save_weights(out_dir+'model.hdf5')


# ----------- plot the training and validation loss
# Start a fresh figure so the curves of the two plots do not pile up on the
# same implicit axes.
plt.figure()
plt.plot(history.history['loss'], label='train loss')
plt.plot(history.history['val_loss'], label='val loss')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'loss.png')

# ----------- plot the training and validation accuracy
# Separate figure: previously accuracy.png also contained the loss curves.
plt.figure()
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.legend()
plt.savefig(out_dir + '/plots/' + 'accuracy.png')

# ----------- save weights
# NOTE: saving the full model is disabled; only the weights are stored above.
#model.save(out_dir + '/weights/' + 'model.h5')

# ----------- save predictions
def visualize_predictions(index, test_dataset, out_dir):
    """Save a 2x2 comparison figure for one test sample.

    The sample is element ``index % batches`` of the ``index``-th batch of
    ``test_dataset`` (the dataset is cycled if it has fewer batches). The
    figure shows the input, the ground truth, the raw prediction and the
    prediction thresholded at 0.5, and is written to
    ``<out_dir>/predictions/comparison_<index>.png``.

    Relies on the module-level ``model`` and ``batches``.

    Args:
        index: sample number; selects both the batch and the element within it.
        test_dataset: iterable of ``(image_batch, label_batch)`` pairs whose
            elements expose ``.numpy()``.
        out_dir: run output directory; its ``predictions`` sub-folder must exist.
    """
    # Jump straight to batch number `index`; same element the manual
    # next()-loop reached, without the hand-written counting.
    image_batch, label_batch = next(
        itertools.islice(itertools.cycle(test_dataset), index, None)
    )

    # A final partial batch can hold fewer than `batches` samples; wrap again
    # so the lookup cannot run past its end.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()

    # Per-channel min-max stretch of the first three channels to 0..255.
    # NOTE(review): assumes channels 0-2 are meant to be shown as R, G, B.
    rgb = image[:, :, :3]
    lowest = rgb.min(axis=(0, 1), keepdims=True)
    span = rgb.max(axis=(0, 1), keepdims=True) - lowest
    # A constant channel has zero span; divide by 1 so it maps to 0, not NaN.
    span = np.where(span > 0, span, 1)
    image_rgb = ((rgb - lowest) * 255.0 / span).astype(np.uint8)

    prediction = np.squeeze(model.predict(np.expand_dims(image, axis=0))[0])
    ground_truth = np.squeeze(label_batch[wrapped_index].numpy())

    fig, ax = plt.subplots(2, 2, figsize=(10, 10))
    ax[0, 0].imshow(image_rgb)
    ax[0, 0].set_title("Input Image")
    ax[0, 1].imshow(ground_truth, cmap='gray')
    ax[0, 1].set_title("Ground Truth")
    ax[1, 0].imshow(prediction, cmap='gray')
    ax[1, 0].set_title("Prediction")
    ax[1, 1].imshow(prediction > 0.5, cmap='gray')
    ax[1, 1].set_title("Prediction (binary)")

    for panel in ax.ravel():
        panel.axis('off')

    fig.savefig(out_dir + '/predictions/' + 'comparison_' + str(index) + '.png')
    # Release the figure; otherwise one stays open per call.
    plt.close(fig)
# Write one comparison output per sample index 0..79.
for i in range(80):
    visualize_predictions(index=i, test_dataset=test_dataset, out_dir=out_dir)

# ----------- save metrics

# The run settings are handed to the project's evaluation helper together with
# the model; whatever it returns is tabulated and stored as metrics.csv.
# NOTE(review): the structure of the returned value is defined in
# _02_evaluate_model -- confirm it is DataFrame-compatible.
model_info = _02_evaluate_model.evaluate_model(
    "SegNet", test_dataset, model, input_shape, shuffled, batches, epochs,
    augmentation_settings=augmentation_settings,
    threshold=0.5,
)
df = pd.DataFrame(model_info)
df.to_csv(os.path.join(out_dir, 'metrics.csv'), index=False)

In [None]:
def visualize_predictions(index, test_dataset, out_dir, batches=16):
    """Save input, ground truth and predictions of one test sample as PNG files.

    The sample is element ``index % batches`` of the ``index``-th batch of
    ``test_dataset`` (the dataset is cycled if it has fewer batches). Files are
    written to ``<out_dir>/predictions/image_<index>/<kind>/<index>.png`` with
    ``kind`` in ``input_image``, ``ground_truth``, ``prediction`` and
    ``prediction_binary``.

    Relies on the module-level ``model``.

    Args:
        index: sample number; selects both the batch and the element within it.
        test_dataset: iterable of ``(image_batch, label_batch)`` pairs whose
            elements expose ``.numpy()``.
        out_dir: run output directory.
        batches: nominal batch size used to wrap ``index`` into a batch.
    """
    sample_dir = os.path.join(out_dir, 'predictions', "image_" + str(index))
    # exist_ok=True creates any missing sub-folder even if sample_dir exists.
    for kind in ('input_image', 'ground_truth', 'prediction', 'prediction_binary'):
        os.makedirs(os.path.join(sample_dir, kind), exist_ok=True)
    file_name = str(index) + '.png'

    # Jump straight to batch number `index` of the (cycled) dataset.
    image_batch, label_batch = next(
        itertools.islice(itertools.cycle(test_dataset), index, None)
    )

    # Honour the `batches` argument (it used to be ignored in favour of a
    # literal 16) and wrap again in case the final batch is shorter.
    wrapped_index = (index % batches) % int(image_batch.shape[0])
    image = image_batch[wrapped_index].numpy()

    # Per-channel min-max stretch of the first three channels to 0..255.
    # NOTE(review): assumes channels 0-2 are meant to be shown as R, G, B.
    rgb = image[:, :, :3]
    lowest = rgb.min(axis=(0, 1), keepdims=True)
    span = rgb.max(axis=(0, 1), keepdims=True) - lowest
    # A constant channel has zero span; divide by 1 so it maps to 0, not NaN.
    span = np.where(span > 0, span, 1)
    image_rgb = ((rgb - lowest) * 255.0 / span).astype(np.uint8)

    prediction = model.predict(np.expand_dims(image, axis=0))[0]
    plt.imsave(os.path.join(sample_dir, 'input_image', file_name), image_rgb)

    ground_truth = label_batch[wrapped_index].numpy()
    plt.imsave(
        os.path.join(sample_dir, 'ground_truth', file_name),
        np.squeeze(ground_truth),
        cmap='gray',
    )

    plt.imsave(
        os.path.join(sample_dir, 'prediction', file_name),
        np.squeeze(prediction),
        cmap='gray',
    )

    prediction_binary = np.where(prediction > 0.5, 1, 0)
    # Fix the colour range so a uniform mask (all 0 or all 1) is not
    # autoscaled: without it an all-ones mask would be written as black.
    plt.imsave(
        os.path.join(sample_dir, 'prediction_binary', file_name),
        np.squeeze(prediction_binary),
        cmap='gray',
        vmin=0,
        vmax=1,
    )

# Export the per-sample PNG files for sample indices 0..79.
for i in range(80):
    visualize_predictions(index=i, test_dataset=test_dataset, out_dir=out_dir)
