# Image Segmentation
Image Segmentation is a technique in which every pixel is given a class label. Thus, the task of image segmentation is to train a neural network to output a pixel-wise mask of the image. 

The dataset that we will use is the Oxford-IIIT Pet Dataset, created by Parkhi et al. Each pixel maps to one of three categories:

* Class 1 : Pixel belonging to the pet.
* Class 2 : Pixel bordering the pet.
* Class 3 : None of the above/ Surrounding pixel.



In [0]:
# Installing necessary packages
!pip install -q git+https://github.com/tensorflow/examples.git
!pip install -q -U tfds-nightly

In [0]:
# Importing libraries
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output
from datetime import datetime

tfds.disable_progress_bar()
%matplotlib inline

In [0]:
# Global vars & hyper-params
# BATCH_SIZE: examples per batch in both the training and test pipelines.
BATCH_SIZE = 64
# BUFFER_SIZE: size of the shuffle buffer used by the training pipeline.
BUFFER_SIZE = 1000
# LABELS: number of filters (one per pixel class) of the model's output layer.
LABELS = 3
# TRAIN_ITERATIONS: passed to model.fit as steps_per_epoch.
# NOTE(review): presumably train-set size // BATCH_SIZE -- confirm against the dataset info.
TRAIN_ITERATIONS = 57
# VALID_ITERATIONS: passed to model.fit as validation_steps.
VALID_ITERATIONS = 20
# EPOCHS: number of training epochs; also the x-axis length of the loss plot.
EPOCHS = 20

## Utility Functions

In [0]:
def show_results(images, save_fig=False):
  """Show the given images side by side in a single row.

  Args:
    images: Sequence of image-like arrays/tensors accepted by
      `tf.keras.preprocessing.image.array_to_img`. The first three are
      titled 'Input Image', 'True Mask' and 'Predicted Mask' in that order;
      any further images are drawn without a title.
    save_fig: When True, the figure is also written to the current working
      directory as `<timestamp>.png`.
  """
  titles = ['Input Image', 'True Mask', 'Predicted Mask']

  # The size has to be passed at creation time: assigning `fig.figsize`
  # afterwards only creates an unused attribute and leaves the default size.
  fig = plt.figure(figsize=(15, 15))

  n_cols = len(images)
  for i, image in enumerate(images):
    plt.subplot(1, n_cols, i + 1)
    # Guard against more images than known titles instead of raising.
    if i < len(titles):
      plt.title(titles[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(image))
    plt.axis('off')

  if save_fig:
    # Microseconds keep file names unique when several figures are saved
    # within the same second (e.g. when called in a loop).
    f_name = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
    # Save before showing: some backends release the figure on show().
    fig.savefig(f"{f_name}.png")

  plt.show()

In [0]:
def normalize(image, mask):
  """
  Scale the image into the 0-1 range and make the mask labels zero-based.

  The image is cast to float32 and divided by 255. The mask, expected to hold
  the labels {1, 2, 3}, has 1 subtracted so that it holds {0, 1, 2} instead.

  Returns the pair (normalized image, shifted mask).
  """

  # Shift the labels first; the two steps are independent of each other.
  mask -= 1
  scaled_image = tf.cast(image, tf.float32) / 255.0
  return scaled_image, mask

In [0]:
def resize_image(image, height, width):
  """Resize `image` to `height` x `width` with nearest-neighbour sampling."""
  target_size = [height, width]
  return tf.image.resize(
      image, target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [0]:
@tf.function
def preprocess_image(image, mask, flip=True):
  """Resize, normalize and optionally mirror an image/mask pair.

  Both inputs are resized to 128 x 128 and passed through `normalize`. When
  `flip` is true, the pair is flipped left-right together with probability
  0.5, so that image and mask stay aligned.

  Returns the pair (processed image, processed mask).
  """
  target_height, target_width = 128, 128

  # Resize both inputs, then scale the image and shift the mask labels.
  image_out, mask_out = normalize(
      resize_image(image, target_height, target_width),
      resize_image(mask, target_height, target_width))

  # Random horizontal flip applied to image and mask alike.
  if flip and tf.random.uniform(()) > 0.5:
    image_out = tf.image.flip_left_right(image_out)
    mask_out = tf.image.flip_left_right(mask_out)

  return image_out, mask_out

In [0]:
def prepare_training_data(datapoint):
  """Preprocess one training example, with the random flip enabled.

  Reads the "image" and "segmentation_mask" entries of `datapoint` and
  returns the (image, mask) pair produced by `preprocess_image`.
  """
  return preprocess_image(datapoint["image"], datapoint["segmentation_mask"])

In [0]:
def prepare_test_data(datapoint):
  """Preprocess one test example, with the random flip disabled.

  Reads the "image" and "segmentation_mask" entries of `datapoint` and
  returns the (image, mask) pair produced by `preprocess_image`.
  """
  return preprocess_image(
      datapoint["image"], datapoint["segmentation_mask"], False)

In [0]:
def create_mask(pred_mask):
  """Turn per-class scores into a label mask for the first batch element.

  Takes the argmax over the last axis, appends a trailing axis of size 1 and
  returns the entry at index 0 of the leading axis.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  class_ids = class_ids[..., tf.newaxis]
  first_mask = class_ids[0]
  return first_mask

In [0]:
class TestCallback(tf.keras.callbacks.Callback):
  """Callback that displays a sample prediction at the end of every epoch.

  A single (image, mask) pair is taken from the first batch of the dataset
  given at construction time and reused for every epoch, so the displayed
  predictions are comparable across epochs.
  """

  def __init__(self, test_dataset):
    """Store the first example of `test_dataset` as the fixed sample.

    Args:
      test_dataset: Batched dataset yielding (images, masks) pairs.
    """
    # Initialise the base-class state. The model is deliberately not taken
    # from a notebook global here: Keras attaches the model being trained to
    # the callback itself when the callback is passed to `fit`.
    super().__init__()
    for images, real_masks in test_dataset.take(1):
      self.test_img, self.test_mask = images[0], real_masks[0]

  def on_epoch_end(self, epoch, logs=None):
    """Predict the mask of the stored sample and display it."""
    clear_output(wait=True)
    # The model expects a batch, hence the added leading axis.
    predicted_mask = create_mask(self.model.predict(self.test_img[tf.newaxis, ...]))
    show_results([self.test_img, self.test_mask, predicted_mask])
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

## Preparing Input Pipeline

In [0]:
# Fetch the Oxford-IIIT Pet dataset (any 3.x.x version) together with its
# metadata object.
dataset, info = tfds.load(
    name='oxford_iiit_pet:3.*.*',
    with_info=True,
)

In [0]:
# Creating our training data
def _random_flip(image, mask):
  """Mirror an (image, mask) pair left-right together with probability 0.5."""
  if tf.random.uniform(()) > 0.5:
    image = tf.image.flip_left_right(image)
    mask = tf.image.flip_left_right(mask)
  return image, mask

# Only the deterministic work (resize + normalize, i.e. the same steps as the
# test preprocessing) goes in front of `cache()`. Anything mapped before the
# cache is computed once and replayed unchanged on every epoch, which would
# freeze the random flips after the first pass over the data.
train = dataset['train'].map(prepare_test_data, num_parallel_calls=
                             tf.data.experimental.AUTOTUNE)
train_dataset = (
    train.cache()
    .shuffle(BUFFER_SIZE)
    # Augment after the cache so every epoch draws fresh flips.
    .map(_random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE))

# Creating our test data
test = dataset['test'].map(prepare_test_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

## Model Definition

The model being used here is a modified U-Net. In order to learn robust features and reduce the number of trainable parameters, we will use a pretrained model for the encoder. Our encoder is therefore a pretrained MobileNetV2 model whose intermediate outputs are used as skip connections, and the decoder is built from the upsample block already implemented in TensorFlow Examples for the Pix2pix tutorial.

**NOTE**: Any pretrained network can be used for the encoder.


In [0]:
class Masker(tf.keras.Model):
  """U-Net style segmentation model built on a frozen, pretrained encoder.

  The encoder exposes five intermediate activations of the given base model.
  The last one is upsampled step by step by the decoder, and after each step
  the matching encoder activation is concatenated as a skip connection. A
  final transposed convolution produces `LABELS` output channels (logits).

  Attributes:
    loss_cal: Sparse categorical cross-entropy (from logits) meant to be
      passed to `compile` as the loss.
    encoder: Feature-extraction model; not trainable.
    decoder: List of upsample blocks, applied in order.
    skip_concat: Layer joining decoder output and skip connection.
    out: Final transposed convolution.
  """

  def __init__(self, base_encoder_model):
    """Build the model around `base_encoder_model`.

    Args:
      base_encoder_model: Pretrained Keras model providing the layers named
        in `_prepare_encoder`.
    """
    super().__init__()
    self.loss_cal = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    self._prepare_encoder(base_encoder_model)
    self.decoder = [
        pix2pix.upsample(512, 3),  # 4x4 -> 8x8
        pix2pix.upsample(256, 3),  # 8x8 -> 16x16
        pix2pix.upsample(128, 3),  # 16x16 -> 32x32
        pix2pix.upsample(64, 3),   # 32x32 -> 64x64
    ]

    # Created once here rather than on every forward pass inside `call`.
    self.skip_concat = tf.keras.layers.Concatenate()

    self.out = tf.keras.layers.Conv2DTranspose(LABELS, 3,
                                    strides=2,
                                    padding='same')  # 64x64 -> 128x128

  def call(self, inputs):
    """Run the encoder, then the decoder with skip connections.

    Returns the output of the final transposed convolution.
    """
    # Running encoder
    features = self.encoder(inputs)

    # The deepest activation seeds the decoder; the remaining ones are the
    # skip connections, consumed from deepest to shallowest.
    x = features[-1]
    skips = reversed(features[:-1])

    # Running decoder with skip connections
    for up, skip in zip(self.decoder, skips):
      x = up(x)
      x = self.skip_concat([x, skip])

    return self.out(x)

  def _prepare_encoder(self, base_encoder_model):
    """Create the frozen feature-extraction model stored in `self.encoder`."""
    # Using the activations of the below layers for our encoder network
    layer_names = [
        'block_1_expand_relu',   # 64x64
        'block_3_expand_relu',   # 32x32
        'block_6_expand_relu',   # 16x16
        'block_13_expand_relu',  # 8x8
        'block_16_project',      # 4x4
    ]
    layers = [base_encoder_model.get_layer(name).output for name in layer_names]

    # Creating the feature extraction model
    self.encoder = tf.keras.Model(inputs=base_encoder_model.input, outputs=layers)
    # Keep the pretrained weights fixed during training.
    self.encoder.trainable = False


## Let Training Begin

In [0]:
# Instantiate MobileNetV2 without its classification head, for 128x128 RGB
# inputs; it serves as the base of the encoder.
base_encoder_model = tf.keras.applications.MobileNetV2(
    input_shape=[128, 128, 3],
    include_top=False,
)

In [0]:
# Build the segmentation model on top of the pretrained encoder and compile
# it with the loss object the model itself provides.
model = Masker(base_encoder_model)
model.compile(
    optimizer='adam',
    loss=model.loss_cal,
    metrics=['accuracy'],
)

# Callback that shows a sample prediction after every epoch.
callback_ob = TestCallback(test_dataset)

In [0]:
# Train on the repeating training pipeline; each epoch is bounded by
# TRAIN_ITERATIONS steps and validated on VALID_ITERATIONS test batches.
model_history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=TRAIN_ITERATIONS,
    validation_data=test_dataset,
    validation_steps=VALID_ITERATIONS,
    callbacks=[callback_ob],
)

In [0]:
# Plot the per-epoch training and validation loss recorded by `fit`.
epochs = range(EPOCHS)
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
# Training loss as a red line, validation loss as blue dots.
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.title('Training and Validation Loss')
plt.ylim([0, 1])
plt.legend()
plt.show()

In [0]:
# Predict, display and save the first five examples of the first test batch.
for test_imgs, test_masks in test_dataset.take(1):
  for sample in range(5):
    test_img, test_mask = test_imgs[sample], test_masks[sample]
    # The model expects a batch, hence the added leading axis.
    prediction = model.predict(test_img[tf.newaxis, ...])
    predicted_mask = create_mask(prediction)
    print(predicted_mask[0,0,0])
    show_results([test_img, test_mask, predicted_mask], True)
