<a href="https://colab.research.google.com/github/consequencesunintended/Petnet/blob/master/TernausNet_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## What is image segmentation?
So far you have seen image classification, where the task of the network is to assign a label or class to an input image. However, suppose you want to know where an object is located in the image, the shape of that object, which pixel belongs to which object, etc. In this case you will want to segment the image, i.e., each pixel of the image is given a label. Thus, the task of image segmentation is to train a neural network to output a pixel-wise mask of the image. This helps in understanding the image at a much lower level, i.e., the pixel level. Image segmentation has many applications in medical imaging, self-driving cars and satellite imaging to name a few.

The dataset used in this tutorial is the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), created by Parkhi *et al*. It consists of images, their corresponding labels, and pixel-wise masks. A mask assigns a label to every pixel, and each pixel belongs to one of three categories:

*   Class 1: Pixel belonging to the pet.
*   Class 2: Pixel bordering the pet.
*   Class 3: None of the above (surrounding pixel).

In [0]:
!pip install tensorflow_datasets

In [0]:
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

In [0]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from IPython.display import clear_output
import matplotlib.pyplot as plt

## Download the Oxford-IIIT Pets dataset

The dataset is already included in TensorFlow Datasets, so all that is needed is to download it. The segmentation masks are included in version 3+.

In [0]:
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

The following code performs a simple augmentation by randomly flipping an image horizontally. In addition, the image is normalized to [0, 1]. Finally, as mentioned above, each pixel in the segmentation mask is labeled 1, 2, or 3. For the sake of convenience, let's subtract 1 from the segmentation mask, so that the labels become {0, 1, 2}.
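
As a minimal sketch of these three steps on a toy example, the snippet below uses plain Python lists in place of tensors. The 2x2 values are made up for illustration and are not taken from the dataset.

```python
# Toy 2x2 grayscale "image" with 8-bit values and a mask with labels in {1, 2, 3}.
image = [[0, 51], [102, 255]]
mask = [[1, 2], [3, 1]]

# Flip left-right: the same flip has to be applied to the image and its mask,
# otherwise the labels would no longer line up with the pixels.
flipped_image = [row[::-1] for row in image]
flipped_mask = [row[::-1] for row in mask]

# Scale pixel values from [0, 255] to [0, 1].
normalized_image = [[value / 255.0 for value in row] for row in flipped_image]

# Shift labels from {1, 2, 3} to {0, 1, 2}.
shifted_mask = [[label - 1 for label in row] for row in flipped_mask]

print(normalized_image)
print(shifted_mask)
```

In the actual pipeline the flip is applied only when a random draw exceeds 0.5, and only to training examples.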

In [0]:
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

In [0]:
@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
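  # Nearest-neighbor keeps mask values as valid integer class labels;
  # the default bilinear interpolation would blend neighboring labels.
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                               method='nearest')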

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

In [0]:
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
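  # Nearest-neighbor keeps mask values as valid integer class labels;
  # the default bilinear interpolation would blend neighboring labels.
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                               method='nearest')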

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

The dataset already provides train and test splits, so let's use them as they are.

In [0]:
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 50
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

In [0]:
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

In [0]:
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

Let's take a look at an example image and its corresponding mask from the dataset.

In [0]:
def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

In [0]:
for image, mask in train.take(1):
  sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

In [0]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation, Reshape
from tensorflow.keras.layers import Convolution2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Convolution2DTranspose
from tensorflow.keras.layers import concatenate

## Define the model
The model used here is based on [TernausNet](https://github.com/ternaus/TernausNet), a modified U-Net. A U-Net consists of an encoder (downsampler) and a decoder (upsampler). TernausNet is a U-Net whose VGG11 encoder is pre-trained on ImageNet. For this project, however, I tried a completely untrained model (no pre-trained weights) to evaluate the network on its own.
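
The encoder/decoder layout can be sketched by tracking only the spatial size of the feature maps. The snippet below assumes the 128x128 input, five 2x2 max-pooling steps, and five stride-2 transposed convolutions used in the model cell that follows. `encoder_sizes` and `decoder_sizes` are illustrative names, not part of any library.

```python
INPUT_SIZE = 128   # side length of the input image
NUM_STAGES = 5     # number of pooling (and upsampling) steps

# Encoder: each 2x2 max-pool halves the height and width.
encoder_sizes = [INPUT_SIZE]
for _ in range(NUM_STAGES):
    encoder_sizes.append(encoder_sizes[-1] // 2)

# Decoder: each stride-2 transposed convolution with "same" padding doubles them.
decoder_sizes = [encoder_sizes[-1]]
for _ in range(NUM_STAGES):
    decoder_sizes.append(decoder_sizes[-1] * 2)

print("encoder:", encoder_sizes)  # [128, 64, 32, 16, 8, 4]
print("decoder:", decoder_sizes)  # [4, 8, 16, 32, 64, 128]

# A skip connection concatenates an upsampled decoder feature map with the
# encoder feature map of the same spatial size, so every decoder size
# must also occur in the encoder.
for size in decoder_sizes[1:]:
    assert size in encoder_sizes
```

Matching sizes are what make the `concatenate` calls in the decoder valid: each one joins two tensors with the same height and width along the channel axis.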

The network outputs three channels because there are three possible labels for each pixel. Think of this as multi-class classification where each pixel is classified into one of three classes.
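
To make the idea of three channels per pixel concrete, here is a small sketch in plain Python. The raw scores are invented for illustration, and `softmax` is a helper defined here, not a TensorFlow call.

```python
import math

def softmax(scores):
    """Turn a list of raw scores into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

# Raw scores for a 1x2 "image": one list of three scores per pixel,
# one score per class (0 = pet, 1 = border, 2 = surrounding).
scores = [[2.0, 0.5, -1.0],
          [-0.5, 0.0, 3.0]]

# Softmax is applied independently to each pixel's three channel values.
probabilities = [softmax(pixel) for pixel in scores]
for pixel in probabilities:
    print([round(p, 3) for p in pixel], "sum =", round(sum(pixel), 3))
```

Each pixel ends up with its own probability distribution over the three classes, which is what the final softmax activation in the model produces across the channel axis.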

In [0]:
n_labels = 3
input_shape = [128,128,3]
kernel=3

# encoder
inputs = Input(shape=input_shape)

conv_1 = Convolution2D(3, (kernel, kernel), padding="same")(inputs)
conv_1 = BatchNormalization()(conv_1)
conv_1 = Activation("relu")(conv_1)

conv_2 = Convolution2D(64, (kernel, kernel), padding="same")(conv_1)
conv_2 = BatchNormalization()(conv_2)
conv_2 = Activation("relu")(conv_2)

#128
pool_1 = MaxPooling2D((2, 2))(conv_2)
#64

conv_3 = Convolution2D(64, (kernel, kernel), padding="same")(pool_1)
conv_3 = BatchNormalization()(conv_3)
conv_3 = Activation("relu")(conv_3)

conv_4 = Convolution2D(128, (kernel, kernel), padding="same")(conv_3)
conv_4 = BatchNormalization()(conv_4)
conv_4 = Activation("relu")(conv_4)

#64
pool_2 = MaxPooling2D((2, 2))(conv_4)
#32

conv_5 = Convolution2D(128, (kernel, kernel), padding="same")(pool_2)
conv_5 = BatchNormalization()(conv_5)
conv_5 = Activation("relu")(conv_5)

conv_6 = Convolution2D(256, (kernel, kernel), padding="same")(conv_5)
conv_6 = BatchNormalization()(conv_6)
conv_6 = Activation("relu")(conv_6)

conv_7 = Convolution2D(256, (kernel, kernel), padding="same")(conv_6)
conv_7 = BatchNormalization()(conv_7)
conv_7 = Activation("relu")(conv_7)

#32
pool_3 = MaxPooling2D((2, 2))(conv_7)
#16

conv_8 = Convolution2D(256, (kernel, kernel), padding="same")(pool_3)
conv_8 = BatchNormalization()(conv_8)
conv_8 = Activation("relu")(conv_8)

conv_9 = Convolution2D(512, (kernel, kernel), padding="same")(conv_8)
conv_9 = BatchNormalization()(conv_9)
conv_9 = Activation("relu")(conv_9)

conv_10 = Convolution2D(512, (kernel, kernel), padding="same")(conv_9)
conv_10 = BatchNormalization()(conv_10)
conv_10 = Activation("relu")(conv_10)

#16
pool_4 = MaxPooling2D((2, 2))(conv_10)
#8

conv_11 = Convolution2D(512, (kernel, kernel), padding="same")(pool_4)
conv_11 = BatchNormalization()(conv_11)
conv_11 = Activation("relu")(conv_11)

conv_12 = Convolution2D(512, (kernel, kernel), padding="same")(conv_11)
conv_12 = BatchNormalization()(conv_12)
conv_12 = Activation("relu")(conv_12)

conv_13 = Convolution2D(512, (kernel, kernel), padding="same")(conv_12)
conv_13 = BatchNormalization()(conv_13)
conv_13 = Activation("relu")(conv_13)

#8
pool_5 = MaxPooling2D((2, 2))(conv_13)
#4

print("Build encoder done..")

#middle section

conv_14 = Convolution2D(512, (kernel, kernel), padding="same")(pool_5)
conv_14 = BatchNormalization()(conv_14)
conv_14 = Activation("relu")(conv_14)

conv_15 = Convolution2D(512, (kernel, kernel), padding="same")(conv_14)
conv_15 = BatchNormalization()(conv_15)
conv_15 = Activation("relu")(conv_15)

print("Build middle part done..")

# decoder

#4
unpool1 = Convolution2DTranspose(256, (kernel, kernel), strides=(2,2), padding="same")(conv_15)
#8

conv_16 = Convolution2D(256, (kernel, kernel), padding="same")(unpool1)
conv_16 = BatchNormalization()(conv_16)
conv_16 = Activation("relu")(conv_16)


conv_17 = concatenate([conv_16, conv_13])


conv_18 = Convolution2D(512, (3, 3), padding="same")(conv_17)
conv_18 = BatchNormalization()(conv_18)
conv_18 = Activation("relu")(conv_18)

#8
unpool2 = Convolution2DTranspose(256, (3, 3), strides=(2,2), padding="same")(conv_18)
#16

conv_19 = Convolution2D(256, (3, 3), padding="same")(unpool2)
conv_19 = BatchNormalization()(conv_19)
conv_19 = Activation("relu")(conv_19)


conv_20 = concatenate([conv_19, conv_10])


conv_21 = Convolution2D(512, (3, 3), padding="same")(conv_20)
conv_21 = BatchNormalization()(conv_21)
conv_21 = Activation("relu")(conv_21)

#16
unpool3 = Convolution2DTranspose(128, (3, 3), strides=(2,2), padding="same")(conv_21)
#32

conv_22 = Convolution2D(128, (3, 3), padding="same")(unpool3)
conv_22 = BatchNormalization()(conv_22)
conv_22 = Activation("relu")(conv_22)


conv_23 = concatenate([conv_22, conv_7])


conv_24 = Convolution2D(256, (3, 3), padding="same")(conv_23)
conv_24 = BatchNormalization()(conv_24)
conv_24 = Activation("relu")(conv_24)

#32
unpool4 = Convolution2DTranspose(64, (3, 3), strides=(2,2), padding="same")(conv_24)
#64

conv_25 = Convolution2D(64, (3, 3), padding="same")(unpool4)
conv_25 = BatchNormalization()(conv_25)
conv_25 = Activation("relu")(conv_25)


conv_26 = concatenate([conv_25, conv_4])


conv_27 = Convolution2D(128, (3, 3), padding="same")(conv_26)
conv_27 = BatchNormalization()(conv_27)
conv_27 = Activation("relu")(conv_27)

#64
unpool5 = Convolution2DTranspose(32, (3, 3), strides=(2,2), padding="same")(conv_27)
#128

conv_28 = Convolution2D(32, (3, 3), padding="same")(unpool5)
conv_28 = BatchNormalization()(conv_28)
conv_28 = Activation("relu")(conv_28)


conv_29 = concatenate([conv_28, conv_2])


conv_30 = Convolution2D(n_labels, (3, 3), padding="same")(conv_29)
conv_30 = BatchNormalization()(conv_30)
conv_30 = Activation("relu")(conv_30)

outputs = Activation("softmax")(conv_30)
print("Build decoder done..")


## Train the model
Now, all that is left to do is to compile and train the model. The loss used here is losses.sparse_categorical_crossentropy. This loss fits because the network assigns a label to each pixel, just like multi-class prediction. In the true segmentation mask, each pixel has a value in {0, 1, 2}. The network outputs three channels; essentially, each channel learns to predict one class, and losses.sparse_categorical_crossentropy is the recommended loss for such a scenario. The label assigned to a pixel is the channel with the highest value in the network's output. This is what the create_mask function, defined below, does.
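
As a rough worked example of both ideas, the sketch below computes the loss and the predicted labels by hand for two pixels. The probabilities and labels are made up, and the loss is averaged over pixels here, which is an assumption about the reduction rather than a statement about the Keras internals.

```python
import math

# Predicted class probabilities for a 1x2 "image" (three channels per pixel).
predicted = [[0.7, 0.2, 0.1],
             [0.1, 0.3, 0.6]]
# True labels for the same two pixels, already shifted to {0, 1, 2}.
true_labels = [0, 1]

# Sparse categorical cross-entropy: the negative log of the probability
# assigned to the true class, here averaged over the pixels.
per_pixel_loss = [-math.log(probs[label])
                  for probs, label in zip(predicted, true_labels)]
mean_loss = sum(per_pixel_loss) / len(per_pixel_loss)

# Predicted mask: the index of the channel with the highest value.
predicted_labels = [probs.index(max(probs)) for probs in predicted]

print(per_pixel_loss, mean_loss)
print(predicted_labels)  # [0, 2]
```

The first pixel is classified correctly and contributes a small loss; the second is assigned class 2 although its true label is 1, so it contributes a larger one.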

In [0]:
model = Model(inputs=inputs, outputs=outputs, name="TernausNet")

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

Let's try out the model to see what it predicts before training.

In [0]:
def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

In [0]:
def show_predictions(dataset=None, num=1):
  if dataset is not None:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

In [0]:
show_predictions()

Let's observe how the model improves while it is training. To accomplish this task, a callback function is defined below. 

In [0]:
class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

In [0]:
%%time
EPOCHS = 40
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])

In [0]:
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

epochs = range(EPOCHS)

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

## Make predictions

Let's make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results.

In [0]:
show_predictions(test_dataset, 3)