# Image Segmentation with Layer

[![Open in Layer](https://development.layer.co/assets/badge.svg)](https://app.layer.ai/layer/image-segmentation) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/layerai/examples/blob/main/image-segmentation/segmentation.ipynb) [![Layer Examples Github](https://badgen.net/badge/icon/github?icon=github&label)](https://github.com/layerai/examples/tree/main/image-segmentation)

In this project, we focus on image segmentation with a modified [U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).
The project is adapted from the TensorFlow [Image Segmentation tutorial](https://www.tensorflow.org/tutorials/images/segmentation).


## What is image segmentation?

In an image classification task, the network assigns a label (or class) to each input image. However, suppose you want to know the shape of an object in the image, or which pixel belongs to which object. In this case, you will want to assign a class to each pixel of the image. This task is known as segmentation. A segmentation model returns much more detailed information about the image. Image segmentation has many applications in medical imaging, self-driving cars, and satellite imaging, to name a few.

This tutorial uses the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/) ([Parkhi et al., 2012](https://www.robots.ox.ac.uk/~vgg/publications/2012/parkhi12a/parkhi12a.pdf)). The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the training and test splits). Each image includes the corresponding labels and pixel-wise masks. The masks are class labels for each pixel. Each pixel is given one of three categories:

- Class 1: Pixel belonging to the pet.
- Class 2: Pixel bordering the pet.
- Class 3: None of the above/a surrounding pixel.
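
To make the labeling concrete, here is a minimal sketch with a hand-made 4x4 mask. The mask values and the `CLASS_NAMES` mapping are invented for illustration and are not taken from the dataset. It counts how many pixels fall into each of the three classes.

```python
from collections import Counter

# Hypothetical 4x4 mask: 1 = pet, 2 = border, 3 = background.
toy_mask = [
    [3, 3, 3, 3],
    [3, 2, 2, 3],
    [3, 2, 1, 2],
    [3, 3, 2, 3],
]

# Illustrative names only; the dataset itself stores just the integers.
CLASS_NAMES = {1: "pet", 2: "border", 3: "background"}

# Flatten the mask and count the pixels assigned to each class.
pixel_counts = Counter(label for row in toy_mask for label in row)

for label in sorted(pixel_counts):
    print(f"class {label} ({CLASS_NAMES[label]}): {pixel_counts[label]} pixels")
```

A classifier would produce a single label for this whole image, while the mask carries one label per pixel.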

## Install Layer

Let's start with installing Layer.

In [None]:
!pip install layer -U

## Log in to Layer

Next, log in to Layer.

In [None]:
import layer
layer.login()

## Initialize Layer Project
Now we are ready to initialize our project. A Layer Project is an ML repository hosted on Layer where you can store your datasets, models, and metrics.

In [None]:
layer.init("image-segmentation")

## Download the Oxford-IIIT Pets dataset

The dataset is [available from TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet). The segmentation masks are included in version 3+.

In [None]:
import tensorflow as tf
from IPython.display import clear_output
import matplotlib.pyplot as plt

In [None]:
def get_dataset():
  import tensorflow_datasets as tfds  
  return tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

The image color values are normalized to the `[0,1]` range. As mentioned above, each pixel in the segmentation mask is labeled with one of {1, 2, 3}. For the sake of convenience, subtract 1 from the segmentation mask, resulting in the labels {0, 1, 2}.

In [None]:
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
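
The effect of `normalize` can be checked on a tiny example. The sketch below uses NumPy as a stand-in for the TensorFlow ops (an assumption, so that it runs without downloading the dataset), and the pixel and mask values are made up.

```python
import numpy as np

# Made-up 2x2 RGB image with uint8 values and a matching mask with labels 1..3.
toy_image = np.array([[[0, 128, 255], [64, 64, 64]],
                      [[255, 255, 255], [0, 0, 0]]], dtype=np.uint8)
toy_mask = np.array([[[1], [2]],
                     [[3], [1]]], dtype=np.uint8)

# Same arithmetic as normalize(): scale colors to [0, 1], shift labels to start at 0.
scaled_image = toy_image.astype(np.float32) / 255.0
shifted_mask = toy_mask - 1

print(scaled_image.min(), scaled_image.max())
print(np.unique(shifted_mask))
```

Zero-based labels matter later: a loss that takes integer class ids expects them to index the model's output channels, which start at 0.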

In [None]:
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
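  # Nearest-neighbor resizing keeps mask values as valid class labels;
  # the default bilinear method would blend neighboring labels.
  input_mask = tf.image.resize(
      datapoint['segmentation_mask'], (128, 128),
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)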

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

The dataset already contains the required training and test splits, so continue to use the same splits.

The following class performs a simple augmentation by randomly flipping an image.
Go to the [Image augmentation](https://www.tensorflow.org/tutorials/images/data_augmentation) tutorial to learn more.


In [None]:
class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
  
  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels
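
The comment about the shared seed is the key point: an image and its mask must receive the same flip, otherwise the labels no longer line up with the pixels. The sketch below illustrates the idea with Python's `random` module. It is not how `RandomFlip` is implemented internally, and `flip_decisions` is a hypothetical helper.

```python
import random

def flip_decisions(seed, num_batches):
    """Return one flip/no-flip decision per batch from a seeded generator."""
    rng = random.Random(seed)
    return [rng.random() < 0.5 for _ in range(num_batches)]

# Same seed for images and masks: the decisions match step for step.
image_flips = flip_decisions(seed=42, num_batches=8)
mask_flips = flip_decisions(seed=42, num_batches=8)
print(image_flips == mask_flips)

# Flipping both rows keeps each pixel paired with its own label.
row_pixels = ["p0", "p1", "p2", "p3"]
row_labels = [0, 1, 1, 2]
flipped_pixels = row_pixels[::-1]
flipped_labels = row_labels[::-1]
print(list(zip(flipped_pixels, flipped_labels)))
```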

Build the input pipeline, applying the augmentation after batching the inputs.

In [None]:
def get_batches():
  dataset, info = get_dataset()

  TRAIN_LENGTH = info.splits['train'].num_examples
  BATCH_SIZE = 64
  BUFFER_SIZE = 1000
  STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

  train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
  test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

  train_batches = (
      train_images
      .cache()
      .shuffle(BUFFER_SIZE, seed=20)
      .batch(BATCH_SIZE)
      .repeat()
      .map(Augment())
      .prefetch(buffer_size=tf.data.AUTOTUNE))

  test_batches = test_images.batch(BATCH_SIZE)
  return train_batches, test_batches
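
The `STEPS_PER_EPOCH` arithmetic in `get_batches` can be worked through by hand. The training-split size of 3,680 used below is an assumption for illustration; the pipeline reads the real value from `info.splits['train'].num_examples`.

```python
# Assumed size of the training split; get_batches() reads the real value
# from info.splits['train'].num_examples.
TRAIN_LENGTH = 3680
BATCH_SIZE = 64

# Integer division drops the final partial batch from the count.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
leftover_examples = TRAIN_LENGTH - STEPS_PER_EPOCH * BATCH_SIZE

print(STEPS_PER_EPOCH)      # 57
print(leftover_examples)    # 32
```

Because the training pipeline calls `.repeat()`, the dataset never runs out on its own, so the number of steps per epoch has to be passed to `fit` explicitly.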

Visualize an image example and its corresponding mask from the dataset.

In [None]:
def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

## Define the model
The model being used here is a modified [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). In order to learn robust features and reduce the number of trainable parameters, you will use a pretrained model - MobileNetV2 - as the encoder. For the decoder, you will use the upsample block, which is already implemented in the [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) example in the TensorFlow Examples repo. (Check out the [pix2pix: Image-to-image translation with a conditional GAN](https://www.tensorflow.org/tutorials/generative/pix2pix) tutorial in a notebook.)


As mentioned, the encoder will be a pretrained MobileNetV2 model which is prepared and ready to use in `tf.keras.applications`. The encoder consists of specific outputs from intermediate layers in the model. Note that the encoder will not be trained during the training process.

The following `upsample` function is from the [TensorFlow Examples](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) repo.

In [None]:
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  """Upsamples an input.
  Conv2DTranspose => Batchnorm => Dropout => Relu
  Args:
    filters: number of filters
    size: filter size
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
    apply_dropout: If True, adds the dropout layer
  Returns:
    Upsample Sequential Model
  """

  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

  if norm_type.lower() == 'batchnorm':
    result.add(tf.keras.layers.BatchNormalization())
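  elif norm_type.lower() == 'instancenorm':
    # InstanceNormalization is a class defined in the linked pix2pix.py;
    # copy or import it before using norm_type='instancenorm'.
    result.add(InstanceNormalization())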

  if apply_dropout:
    result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result

In [None]:
def unet_model(output_channels:int):
  base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

  # Use the activations of these layers
  layer_names = [
      'block_1_expand_relu',   # 64x64
      'block_3_expand_relu',   # 32x32
      'block_6_expand_relu',   # 16x16
      'block_13_expand_relu',  # 8x8
      'block_16_project',      # 4x4
  ]
  base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

  # Create the feature extraction model
  down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

  down_stack.trainable = False


  up_stack = [
    upsample(512, 3),  # 4x4 -> 8x8
    upsample(256, 3),  # 8x8 -> 16x16
    upsample(128, 3),  # 16x16 -> 32x32
    upsample(64, 3),   # 32x32 -> 64x64
  ]


  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Note that the number of filters on the last layer is set to the number of `output_channels`. This will be one output channel per class.
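
The size comments in `unet_model` follow from one rule: a transposed convolution with `strides=2` and `padding='same'` doubles the spatial size. The sketch below traces the decoder resolutions with plain arithmetic. `transposed_conv_output_size` is a hypothetical helper, and channel counts are left out on purpose.

```python
def transposed_conv_output_size(input_size, stride):
    """Spatial size after a transposed convolution with padding='same'."""
    return input_size * stride

size = 4  # spatial size of the deepest encoder output (block_16_project)
upsample_filters = [512, 256, 128, 64]  # filters used in up_stack

sizes = [size]
for filters in upsample_filters:
    size = transposed_conv_output_size(size, stride=2)
    sizes.append(size)
    print(f"upsample({filters}, 3) -> {size}x{size}")

# The final Conv2DTranspose also uses strides=2 and restores the input resolution.
size = transposed_conv_output_size(size, stride=2)
sizes.append(size)
print(f"last layer -> {size}x{size}")
```

Each intermediate size matches one of the encoder activations listed in `layer_names`, which is what allows the skip connections to be concatenated.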

## Train the model

Now, all that is left to do is to compile and train the model. 

Since this is a multiclass classification problem, use the `tf.keras.losses.SparseCategoricalCrossentropy` loss function with the `from_logits` argument set to `True`, since the labels are scalar integers instead of vectors of scores for each pixel of every class.

When running inference, the label assigned to the pixel is the channel with the highest value. This is what the `create_mask` function is doing.
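
For a single pixel, the loss and the predicted label can be worked out by hand. The sketch below is a plain-Python illustration of cross-entropy computed from raw scores (logits) and an integer label. It is not the Keras implementation, and the logit values are made up.

```python
import math

def sparse_categorical_crossentropy_from_logits(logits, label):
    """Cross-entropy for one pixel: -log(softmax(logits)[label])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exp_shifted = [math.exp(value - max_logit) for value in logits]
    total = sum(exp_shifted)
    probabilities = [value / total for value in exp_shifted]
    return -math.log(probabilities[label])

pixel_logits = [2.0, 0.5, -1.0]  # made-up raw scores for classes 0, 1, 2
true_label = 0                    # integer label, not a one-hot vector

loss = sparse_categorical_crossentropy_from_logits(pixel_logits, true_label)
predicted_label = pixel_logits.index(max(pixel_logits))
print(round(loss, 4), predicted_label)
```

The loss is small when the channel of the true label has the highest score, and it grows as the score of the true label falls behind the others.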

In [None]:
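def show_predictions(model, dataset=None, num=1):
  # Without a dataset, this falls back to `sample_image` and `sample_mask`,
  # which must be defined beforehand (this notebook does not define them).
  if dataset: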
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])

The callback defined below logs the training metrics to Layer at the end of every epoch, so you can observe how the model improves while it is training.

In [None]:
class LayerCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    layer.log(logs, epoch)

In [None]:
def array_to_img(arr):
  return tf.keras.utils.array_to_img(arr).resize((512,512))

def predict(model, img):
  return create_mask(model.predict(img[tf.newaxis, ...]))

In [None]:
def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
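
The shape changes inside `create_mask` are easy to lose track of. The sketch below repeats the same three steps with NumPy in place of TensorFlow (an assumption, so that it runs standalone) on made-up logits for a batch containing one 2x2 image.

```python
import numpy as np

# Made-up logits for a batch of one 2x2 image with 3 classes: shape (1, 2, 2, 3).
toy_logits = np.array([[[[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]],
                        [[-0.2, 0.1, 0.9], [1.5, 1.4, 1.3]]]])

class_ids = np.argmax(toy_logits, axis=-1)   # shape (1, 2, 2)
class_ids = class_ids[..., np.newaxis]       # shape (1, 2, 2, 1)
toy_pred_mask = class_ids[0]                 # shape (2, 2, 1)

print(toy_pred_mask.shape)
print(toy_pred_mask[..., 0])
```

The trailing axis of size 1 is kept so that the mask has the height, width, channel layout that image utilities such as `array_to_img` expect.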

## Train your model with Layer

You can train, register and share your model with Layer. Layer handles all the overhead, the pain, and the complexity of infrastructure. In this example we will train our mask prediction model on a remote GPU instance from our local notebook.

In [None]:
from layer.decorators import model, fabric
@model("mask_predictor")
@fabric("f-gpu-small")
def train():
    train_batches, test_batches = get_batches()
    parameters = {"EPOCHS": 20, "STEPS_PER_EPOCH" : 50, 
                "OUTPUT_CLASSES": 3, "VALIDATION_STEPS":10}
    layer.log(parameters)

    images, masks = next(test_batches.as_numpy_iterator())
    layer.log({"original_image": array_to_img(images[0])})
    layer.log({"original_mask": array_to_img(masks[0])})
    mask_model = unet_model(output_channels=parameters["OUTPUT_CLASSES"])
    mask_model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
    model_history = mask_model.fit(train_batches, epochs=parameters["EPOCHS"],
                            steps_per_epoch=parameters["STEPS_PER_EPOCH"],
                            validation_steps=parameters["VALIDATION_STEPS"],
                            validation_data=test_batches, callbacks=[LayerCallback()]
                            )
    layer.log({"predicted_mask": array_to_img(predict(mask_model, images[0]))})  

    return mask_model

layer.run([train])

## Make predictions

Now, let's make some predictions by fetching the model from Layer. We can fetch any version of our trained model and use it for predictions. Here we fetch version `3.1`:

In [None]:
import layer
my_model = layer.get_model('layer/image-segmentation/models/mask_predictor:3.1').get_train()

In [None]:
_, test_batches = get_batches()
images, masks = next(test_batches.as_numpy_iterator())
display([images[0], masks[0], predict(my_model, images[0])])
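
Beyond looking at the masks side by side, a prediction can be scored by the fraction of pixels it gets right, which is what the `accuracy` metric reports during training. The sketch below is a minimal NumPy version with made-up masks. `pixel_accuracy` is a hypothetical helper, not part of Layer or TensorFlow.

```python
import numpy as np

def pixel_accuracy(true_mask, pred_mask):
    """Fraction of pixels whose predicted class equals the true class."""
    true_mask = np.asarray(true_mask)
    pred_mask = np.asarray(pred_mask)
    return float(np.mean(true_mask == pred_mask))

# Made-up 3x3 masks with class ids 0, 1, 2; two pixels disagree.
toy_true = np.array([[0, 0, 2], [1, 0, 2], [2, 2, 2]])
toy_pred = np.array([[0, 0, 2], [0, 0, 2], [2, 1, 2]])
print(pixel_accuracy(toy_true, toy_pred))
```

The same function could be applied to `masks[0]` and `predict(my_model, images[0])`, assuming both hold integer class ids and have the same shape.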

## Where to go from here?

Now that you have run your Layer Project, you can:

- Join our [Slack Community](https://bit.ly/layercommunityslack)
- Visit [Layer Examples Repo](https://github.com/layerai/examples) for more examples
- Browse [Trending Layer Projects](https://layer.ai) on our main page
- Check out [Layer Documentation](https://docs.app.layer.ai) to learn more