# **Semantic Segmentation Using TensorFlow Keras**

Semantic segmentation lays the groundwork for advanced computer vision tasks such as object detection, shape recognition, autonomous driving, robotics, and virtual reality. It can be defined as pixel-level classification of an image into two or more object classes. This differs from image classification, which assigns a single label to the whole image. For instance, consider an image that consists mainly of a zebra, surrounded by grass fields, a tree and a flying bird. Image classification tells us that the image belongs to the ‘zebra’ class. It cannot tell us where the zebra is, or what its size or pose is. Semantic segmentation of the same image, however, can tell us that it contains a zebra, a grass field, a bird and a tree, and which pixels belong to which class.
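The difference between an image-level label and a pixel-level label map can be sketched with a toy array. This is a minimal illustration, not part of the pipeline below: the class ids, the `CLASSES` dictionary and the 4x6 `label_map` are made up for the example.

```python
import numpy as np

# Hypothetical class ids, chosen only for this illustration.
CLASSES = {0: "grass", 1: "zebra", 2: "tree", 3: "bird"}

# A toy 4x6 label map: one class id per pixel.
label_map = np.array([
    [2, 2, 0, 0, 3, 0],
    [2, 1, 1, 1, 0, 0],
    [0, 1, 1, 1, 1, 0],
    [0, 0, 0, 0, 0, 0],
])

# Image classification yields a single label for the whole image.
image_label = CLASSES[1]

# Semantic segmentation yields a label per pixel, so the location
# and extent of each class can be recovered from the label map.
zebra_pixels = np.argwhere(label_map == 1)   # (row, col) pairs
zebra_area = len(zebra_pixels)               # number of zebra pixels
top_left = zebra_pixels.min(axis=0)          # bounding box corner
bottom_right = zebra_pixels.max(axis=0)      # opposite corner

print("image-level label:", image_label)
print("zebra area in pixels:", zebra_area)
print("zebra bounding box:", tuple(top_left), "to", tuple(bottom_right))
```

The single string answers "what is in the image", while the label map also answers "where" and "how large".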

In this practice session, we will discuss semantic segmentation using TensorFlow Keras. Readers are expected to have a fundamental knowledge of deep learning, image classification and transfer learning. 

## **Implementation**




In [None]:
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn tensorflow keras opencv-python pillow scikit-image torch torchvision \
     tqdm --user -q --no-warn-script-location

!python -m pip install -q git+https://github.com/tensorflow/examples.git --user -q

import IPython
IPython.Application.instance().kernel.do_shutdown(True)


# Import necessary frameworks, libraries and modules

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import cv2
from scipy import io
from tensorflow_examples.models.pix2pix import pix2pix
import matplotlib.pyplot as plt

# Download and prepare data

We use the Clothing Co-Parsing public dataset as our supervised dataset. We use 1000 images of people (one person per image) from it, together with the 1000 pixel-level label images that correspond to them. The label images have 59 classes, such as hair, shirt, shoes, skin, sunglasses and cap.

In [None]:
!git clone https://github.com/bearpaw/clothing-co-parsing.git

In [None]:
!echo $PWD

In [None]:
!ls clothing-co-parsing/

Input images are in the photos directory, and labelled images are in the annotations directory. Let’s extract the input images from the respective source directory.

In [None]:
images = []
for i in range(1,1001):
    url = './clothing-co-parsing/photos/%04d.jpg'%(i)
    img = cv2.imread(url)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    images.append(tf.convert_to_tensor(img))

Let’s extract the labelled images (segmented images) from the respective source directory.

In [None]:
masks = []
for i in range(1,1001):
    url = './clothing-co-parsing/annotations/pixel-level/%04d.mat'%(i)
    file = io.loadmat(url)
    mask = tf.convert_to_tensor(file['groundtruth'])
    masks.append(mask)

How many examples do we have now?

In [None]:
len(images), len(masks)

As mentioned in the source files, there are 1000 images and 1000 labels. Visualize some images to get better insights.

In [None]:
plt.figure(figsize=(10,4))
for i in range(1,4):
    plt.subplot(1,3,i)
    img = images[i]
    # RGB images are drawn as-is, so no colormap or colorbar is needed
    plt.imshow(img)
    plt.axis('off')
plt.show()

Next, we look at the label masks. Each colour in the mask plots below refers to a specific class. We observe that the person and their clothing are segmented, leaving the surroundings unsegmented.

In [None]:
masks[0].numpy().min(), masks[0].numpy().max()
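
Beyond the minimum and maximum id, it can help to see which classes a mask actually contains and how much of the image each one covers. The following is a minimal sketch on a hypothetical 4x4 `toy_mask`; the ids 19 and 41 are arbitrary placeholders, and in the notebook the same calls could be applied to `masks[0].numpy()`.

```python
import numpy as np

# Hypothetical mask standing in for one label image; ids are illustrative.
toy_mask = np.array([
    [0,  0,  0, 0],
    [0, 19, 19, 0],
    [0, 41, 41, 0],
    [0, 41, 41, 0],
], dtype=np.uint8)

# Distinct class ids in the mask and the number of pixels of each.
class_ids, pixel_counts = np.unique(toy_mask, return_counts=True)

# Fraction of the image covered by each class.
coverage = dict(zip(class_ids.tolist(),
                    (pixel_counts / toy_mask.size).tolist()))

for cid, frac in coverage.items():
    print(f"class {cid:2d}: {frac:.1%} of pixels")
```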

In [None]:
plt.figure(figsize=(10,4))
for i in range(1,4):
    plt.subplot(1,3,i)
    img = masks[i]
    plt.imshow(img, cmap='jet')
    plt.colorbar()
    plt.axis('off')
plt.show()

# Build Downstack with a Pre-trained CNN

Load DenseNet121 from the built-in Keras applications.

In [None]:
base = keras.applications.DenseNet121(input_shape=[128,128,3], 
                                      include_top=False, 
                                      weights='imagenet')

In [None]:
len(base.layers)

The DenseNet121 model has 427 layers. We need to identify suitable layers whose output will be used for skip connections. Plot the entire model, along with the feature shapes.

In [None]:
keras.utils.plot_model(base, show_shapes=True)

We select the final ReLU activation layer for each feature map size, i.e. 4, 8, 16, 32, and 64, required for skip-connections. Write down the names of the selected ReLU layers in a list.

In [None]:
skip_names = ['conv1/relu', # size 64*64
             'pool2_relu',  # size 32*32
             'pool3_relu',  # size 16*16
             'pool4_relu',  # size 8*8
             'relu'        # size 4*4
             ]
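
The sizes in the comments above follow from simple arithmetic: each selected layer sits one stride-2 stage deeper than the previous one, so the spatial size halves each time. A minimal sketch of that bookkeeping, assuming the 128x128 input used in this notebook (the `INPUT_SIZE` constant and `sizes` dictionary are names introduced here for illustration):

```python
INPUT_SIZE = 128

skip_names = ['conv1/relu', 'pool2_relu', 'pool3_relu', 'pool4_relu', 'relu']

# Assumption: the k-th listed layer has passed through k+1 stride-2 stages,
# so its feature map is INPUT_SIZE / 2**(k+1) pixels on each side.
sizes = {name: INPUT_SIZE // 2 ** (k + 1) for k, name in enumerate(skip_names)}

for name, size in sizes.items():
    print(f"{name:12s} -> {size} x {size}")
```

If the input size is changed, the same formula gives the new skip-connection sizes, as long as the input is divisible by 32.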

Obtain the outputs of these layers.

In [None]:
skip_outputs = [base.get_layer(name).output for name in skip_names]
for i in range(len(skip_outputs)):
    print(skip_outputs[i])

Build the downstack with the above layers. We use the pre-trained model as is, with its weights frozen and no fine-tuning.

In [None]:
downstack = keras.Model(inputs=base.input,
                       outputs=skip_outputs)
downstack.trainable = False

# Build Upstack


Build the upstack using an upsampling template.

In [None]:
# Four upstack layers for upsampling sizes 
# 4->8, 8->16, 16->32, 32->64 
upstack = [pix2pix.upsample(512,3),
          pix2pix.upsample(256,3),
          pix2pix.upsample(128,3),
          pix2pix.upsample(64,3)]

We can explore the individual layers in each upstack block.

In [None]:
upstack[0].layers

# Build U-Net model with skip-connections

Build a U-Net model by merging downstack and upstack with skip-connections.

In [None]:
# define the input layer
inputs = keras.layers.Input(shape=[128,128,3])

# downsample
down = downstack(inputs)
# the deepest output (4x4) is the starting point for upsampling
out = down[-1]

# prepare skip-connections, deepest first (8x8, 16x16, 32x32, 64x64)
skips = reversed(down[:-1])

# upsample with skip-connections
for up, skip in zip(upstack,skips):
    out = up(out)
    out = keras.layers.Concatenate()([out,skip])
    
# define the final transpose conv layer
# image 128 by 128 with 59 classes
out = keras.layers.Conv2DTranspose(59, 3,
                                  strides=2,
                                  padding='same',
                                  )(out)
# complete unet model
unet = keras.Model(inputs=inputs, outputs=out)
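
The tensor shapes flowing through this U-Net can be checked by hand. The sketch below is pure bookkeeping, not model code. It assumes the standard DenseNet-121 layout (64 stem filters, growth rate 32, dense blocks of 6/12/24/16 layers, channel halving in each transition) to derive the encoder channel counts, and it assumes each upstack block doubles the spatial size, as a stride-2 transposed convolution with 'same' padding does.

```python
# Encoder outputs as (spatial size, channels), shallowest first.
stem_filters = 64
growth_rate = 32
block_layers = [6, 12, 24, 16]

encoder = [(64, stem_filters)]          # 'conv1/relu'
channels, size = stem_filters, 32
for n_layers in block_layers:
    channels += n_layers * growth_rate  # a dense block adds growth_rate per layer
    encoder.append((size, channels))    # ReLU output taken after the block
    channels //= 2                      # transition layer halves the channels
    size //= 2                          # and halves the spatial size

# Decoder: upsample, then concatenate with the matching skip output.
upstack_filters = [512, 256, 128, 64]
size, channels = encoder[-1]
trace = []
for filters, (skip_size, skip_channels) in zip(upstack_filters,
                                               reversed(encoder[:-1])):
    size *= 2                           # stride-2 upsampling doubles H and W
    assert size == skip_size            # shapes must match to concatenate
    channels = filters + skip_channels  # concatenation stacks channels
    trace.append((size, channels))

# Final transposed convolution: 64 -> 128 pixels, 59 class channels.
trace.append((size * 2, 59))

print("encoder:", encoder)
print("decoder:", trace)
```

The last entry confirms that the output has the same 128x128 resolution as the input, with one channel per class.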

Visualize our U-Net model.

In [None]:
keras.utils.plot_model(unet, show_shapes=True)

# Data Preprocessing

The model is ready. Before training, we need to preprocess the data. Since we have a limited number of images, we also prepare more training data through augmentation.

In [None]:
def resize_image(image):
    image = tf.cast(image, tf.float32)
    image = image/255.0
    # resize image
    image = tf.image.resize(image, (128,128))
    return image

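def resize_mask(mask):
    mask = tf.expand_dims(mask, axis=-1)
    # nearest-neighbour keeps every value a valid class id;
    # the default bilinear method would blend neighbouring ids
    mask = tf.image.resize(mask, (128,128), method='nearest')
    mask = tf.cast(mask, tf.uint8)
    return mask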
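
The interpolation method matters when resizing a label mask, because class ids are categories rather than intensities. The sketch below illustrates this with a 1-D analogy in NumPy: `np.interp` stands in for bilinear resizing and index rounding stands in for nearest-neighbour resizing. The row of ids is hypothetical.

```python
import numpy as np

# One row of class ids (hypothetical): background, class 41, class 19.
row = np.array([0, 0, 41, 41, 19, 19], dtype=np.float32)

src = np.arange(len(row))                 # original pixel positions
dst = np.linspace(0, len(row) - 1, 11)    # positions on a finer grid

# Linear interpolation blends neighbouring ids at class boundaries.
linear = np.interp(dst, src, row)

# Nearest-neighbour only ever copies an existing id.
nearest = row[np.rint(dst).astype(int)]

# Values that appear after resizing but are not ids present in the row.
invented = sorted(set(linear.tolist()) - set(row.tolist()))

print("linear :", linear.tolist())
print("nearest:", nearest.tolist())
print("ids invented by linear interpolation:", invented)
```

Casting the blended values to integers would silently turn boundary pixels into unrelated classes, which is why `tf.image.resize` is usually called with `method='nearest'` for masks.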

Resize images and masks to the size 128 by 128.

In [None]:
X = [resize_image(i) for i in images]
y = [resize_mask(m) for m in masks]
len(X), len(y)

In [None]:
images[0].dtype, masks[0].dtype, X[0].dtype, y[0].dtype

In [None]:
plt.imshow(X[0])
plt.colorbar()
plt.show()

plt.imshow(np.squeeze(y[0]), cmap='jet')
plt.colorbar()
plt.show()

# Functions for augmentation

We will hold out 20% of the 1000 examples for validation, which leaves 800 training examples. That is too few for training, so we define seven augmentation functions to generate more training examples.

In [None]:
def brightness(img, mask):
    img = tf.image.adjust_brightness(img, 0.1)
    return img, mask

def gamma(img, mask):
    img = tf.image.adjust_gamma(img, 0.1)
    return img, mask

def hue(img, mask):
    img = tf.image.adjust_hue(img, -0.1)
    return img, mask

def crop(img, mask):
    img = tf.image.central_crop(img, 0.7)
    img = tf.image.resize(img, (128,128))
    mask = tf.image.central_crop(mask, 0.7)
    # nearest-neighbour resizing keeps the mask values valid class ids
    mask = tf.image.resize(mask, (128,128), method='nearest')
    mask = tf.cast(mask, tf.uint8)
    return img, mask

def flip_hori(img, mask):
    img = tf.image.flip_left_right(img)
    mask = tf.image.flip_left_right(mask)
    return img, mask

def flip_vert(img, mask):
    img = tf.image.flip_up_down(img)
    mask = tf.image.flip_up_down(mask)
    return img, mask

def rotate(img, mask):
    img = tf.image.rot90(img)
    mask = tf.image.rot90(mask)
    return img, mask
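
The functions above fall into two groups: photometric changes (brightness, gamma, hue) touch only the image, while geometric changes (crop, flips, rotation) must be applied to the image and the mask together so that labels stay aligned with pixels. A minimal NumPy sketch of that rule, using a made-up 2x3 image and the hypothetical helpers `flip_pair` and `brighten_pair`:

```python
import numpy as np

# Toy 2x3 RGB image with one bright pixel in the top-left corner.
image = np.zeros((2, 3, 3), dtype=np.float32)
image[0, 0] = 1.0

# Matching mask: the bright pixel carries a hypothetical class id 5.
mask = np.zeros((2, 3, 1), dtype=np.uint8)
mask[0, 0] = 5

def flip_pair(img, msk):
    # Geometric change: flip image and mask left-right together.
    return img[:, ::-1], msk[:, ::-1]

def brighten_pair(img, msk, delta=0.1):
    # Photometric change: pixel values shift, labels are untouched.
    return img + delta, msk

flipped_img, flipped_mask = flip_pair(image, mask)
bright_img, bright_mask = brighten_pair(image, mask)

# After the flip, both the bright pixel and its label sit at column 2.
print("bright pixel column:", int(np.argmax(flipped_img[0, :, 0])))
print("label column:", int(np.argmax(flipped_mask[0, :, 0])))
```

Flipping only the image would leave the label at column 0 while the pixel moved to column 2, so the training target would no longer describe the input.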

# Split Data for training and validation

In [None]:
from sklearn.model_selection import train_test_split

train_X, val_X, train_y, val_y = train_test_split(X, y,
                                                  test_size=0.2,
                                                  random_state=0)
train_X = tf.data.Dataset.from_tensor_slices(train_X)
val_X = tf.data.Dataset.from_tensor_slices(val_X)

train_y = tf.data.Dataset.from_tensor_slices(train_y)
val_y = tf.data.Dataset.from_tensor_slices(val_y)

train_X.element_spec, train_y.element_spec, val_X.element_spec, val_y.element_spec

# Data Augmentation 

Apply augmentation to the data with the above functions. With 7 augmentation functions and 800 input examples, we can get 7*800 = 5600 new examples. Including original examples, we get 5600+800 = 6400 examples for training. That sounds good!

In [None]:
#Zip input images and ground truth masks. 
train = tf.data.Dataset.zip((train_X, train_y))
val = tf.data.Dataset.zip((val_X, val_y))

# perform augmentation on train data only

a = train.map(brightness)
b = train.map(gamma)
c = train.map(hue)
d = train.map(crop)
e = train.map(flip_hori)
f = train.map(flip_vert)
g = train.map(rotate)

train = train.concatenate(a)
train = train.concatenate(b)
train = train.concatenate(c)
train = train.concatenate(d)
train = train.concatenate(e)
train = train.concatenate(f)
train = train.concatenate(g)

Prepare data batches. Shuffle the train data.

In [None]:
BATCH = 64
AT = tf.data.AUTOTUNE
BUFFER = 1000

# the augmented training set holds 800 + 7*800 = 6400 examples
STEPS_PER_EPOCH = 6400//BATCH
VALIDATION_STEPS = 200//BATCH
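
When a dataset is repeated indefinitely, the number of steps per epoch decides how many examples the model sees before an epoch ends. The sketch below shows the arithmetic for one full pass over a dataset; `steps_for_full_pass` is a hypothetical helper, and the sizes are the ones quoted in this notebook.

```python
import math

def steps_for_full_pass(num_examples, batch_size, drop_remainder=True):
    """Number of batches needed to see every example once."""
    if drop_remainder:
        # Ignore the final partial batch.
        return num_examples // batch_size
    # Count the final partial batch as a step.
    return math.ceil(num_examples / batch_size)

BATCH = 64
originals = 800
augmented = 7 * originals          # seven augmentation functions
total = originals + augmented      # size of the augmented training set

print("originals only :", steps_for_full_pass(originals, BATCH))
print("with augmented :", steps_for_full_pass(total, BATCH))
print("validation     :", steps_for_full_pass(200, BATCH))
```

With fewer steps than a full pass requires, each epoch covers only part of the augmented data, and the remainder is seen in later epochs.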

In [None]:
train = train.cache().shuffle(BUFFER).batch(BATCH).repeat()
train = train.prefetch(buffer_size=AT)
val = val.batch(BATCH)

# Check for Model and Data compatibility

Let’s check whether the data and the model are compatible by sampling one batch and running a prediction on it with the untrained model.

In [None]:
example = next(iter(train))
preds = unet(example[0])
plt.imshow(example[0][60])
plt.colorbar()
plt.show()

In [None]:
pred_mask = tf.argmax(preds, axis=-1)
pred_mask = tf.expand_dims(pred_mask, -1)
# show the prediction for the same example (index 60) displayed above
plt.imshow(np.squeeze(pred_mask[60]))
plt.colorbar()
plt.show()

Compile the model with the RMSprop optimizer, the sparse categorical cross-entropy loss function and the accuracy metric. We then train the model for 50 epochs.

In [None]:
unet.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.RMSprop(learning_rate=0.001),
            metrics=['accuracy'])
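
To make the loss concrete, here is a minimal NumPy sketch of sparse categorical cross-entropy computed from logits. It is an illustration of the formula under the usual definition (log-softmax of the logits, then the negative log-probability of the true class, averaged over pixels), not the Keras implementation. The function name `sparse_cce_from_logits` and the tiny 2x2 example with 3 classes are assumptions made for this sketch.

```python
import numpy as np

def sparse_cce_from_logits(labels, logits):
    """labels: (H, W) integer class ids; logits: (H, W, C) raw scores."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)
    # Average the negative log-likelihood over all pixels.
    return float(-picked.mean())

labels = np.array([[0, 2],
                   [1, 1]])

# All-zero logits mean a uniform guess over 3 classes: loss = ln(3).
uniform_loss = sparse_cce_from_logits(labels, np.zeros((2, 2, 3)))

# Large logits on the correct class drive the loss towards zero.
confident = np.zeros((2, 2, 3))
np.put_along_axis(confident, labels[..., None], 10.0, axis=-1)
confident_loss = sparse_cce_from_logits(labels, confident)

print("uniform guess loss:", round(uniform_loss, 4))
print("confident loss    :", round(confident_loss, 6))
```

"Sparse" means the targets are integer class ids rather than one-hot vectors, which is why the masks can be fed to the model as they are.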

# Training

In [None]:
hist = unet.fit(train,
               validation_data=val,
               steps_per_epoch=STEPS_PER_EPOCH,
               validation_steps=VALIDATION_STEPS,
               epochs=50)

# Prediction

Let’s check how our model performs.

In [None]:
img, mask = next(iter(val))
pred = unet.predict(img)
# per-pixel class id for the first validation example
pred_mask = tf.argmax(pred[0], axis=-1)

plt.figure(figsize=(10,5))
plt.subplot(121)
plt.imshow(pred_mask, cmap='jet')
plt.axis('off')
plt.title('Prediction')
plt.subplot(122)
plt.imshow(np.squeeze(mask[0]), cmap='jet')
plt.axis('off')
plt.title('Ground Truth')
plt.show()
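
A visual comparison can be complemented with a number. The sketch below computes pixel accuracy between a predicted mask and the ground truth, optionally ignoring one class. The helper `pixel_accuracy`, the toy masks and the choice of id 0 as the unlabelled background are assumptions made for illustration.

```python
import numpy as np

def pixel_accuracy(pred_mask, true_mask, ignore_class=None):
    """Fraction of pixels whose predicted id equals the true id."""
    pred_mask = np.asarray(pred_mask)
    true_mask = np.asarray(true_mask)
    if ignore_class is None:
        keep = np.ones(true_mask.shape, dtype=bool)
    else:
        # Score only pixels whose true label is not the ignored class.
        keep = true_mask != ignore_class
    return float((pred_mask[keep] == true_mask[keep]).mean())

# Toy ground truth: a 2x2 patch of class 5 on a background of 0.
true_mask = np.array([[0, 0, 0, 0],
                      [0, 5, 5, 0],
                      [0, 5, 5, 0],
                      [0, 0, 0, 0]])

# Toy prediction that finds only one of the four class-5 pixels.
pred_mask = np.array([[0, 0, 0, 0],
                      [0, 5, 0, 0],
                      [0, 0, 0, 0],
                      [0, 0, 0, 0]])

overall = pixel_accuracy(pred_mask, true_mask)
foreground = pixel_accuracy(pred_mask, true_mask, ignore_class=0)

print("overall pixel accuracy   :", overall)
print("accuracy without class 0 :", foreground)
```

When most pixels belong to the background, the overall figure can look high even if the foreground classes are predicted poorly, so it is worth checking both.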

# Performance Curves

In [None]:
history = hist.history
acc=history['accuracy']
val_acc = history['val_accuracy']

plt.plot(acc, '-', label='Training Accuracy')
plt.plot(val_acc, '--', label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
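
The same history dictionary can also be inspected numerically, for example to find the epoch with the best validation accuracy and the final gap between training and validation accuracy. The sketch below uses a made-up `toy_history` with five epochs. The numbers are placeholders, not results from this model.

```python
# Hypothetical values shaped like hist.history; not real training results.
toy_history = {
    'accuracy':     [0.52, 0.61, 0.68, 0.73, 0.77],
    'val_accuracy': [0.50, 0.58, 0.63, 0.62, 0.61],
}

val_acc = toy_history['val_accuracy']

# Index (0-based) of the epoch with the highest validation accuracy.
best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)

# A growing gap between the two curves is a common sign of overfitting.
final_gap = toy_history['accuracy'][-1] - val_acc[-1]

print("best epoch (1-based):", best_epoch + 1)
print("best validation accuracy:", val_acc[best_epoch])
print("final train/validation gap:", round(final_gap, 2))
```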

