<a href="https://colab.research.google.com/github/harjeevanmaan/SemanticSegmentation/blob/master/Semantic_Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install/upgrade first so the imports below load the new versions in this kernel.
!pip install git+https://github.com/tensorflow/examples.git
!pip install -U tfds-nightly

import getpass
import logging
import os
from urllib.parse import urlencode

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import tensorflow as tf
import tensorflow_datasets as tfds
from IPython.display import clear_output
from tensorflow_examples.models.pix2pix import pix2pix

# Keep notebook output quiet: TensorFlow only logs errors, and tfds draws no
# progress bars.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
tfds.disable_progress_bar()

!mkdir -p /content/downloads/manual/cityscapes    #need to store data here to make use of the tfds.load function
!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=my_username&password=my_password&submit=Login' https://www.cityscapes-dataset.com/login/
#need to use your own username and password to download the cityscapes dataset
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -O /content/downloads/manual/cityscapes/leftImg8bit_trainvaltest.zip
!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -O /content/downloads/manual/cityscapes/gtFine_trainvaltest.zip

# Build/load Cityscapes from the archives under /content; `info` supplies the
# split sizes used further down.
cscape, info = tfds.load('cityscapes', with_info=True, data_dir="/content")

# Training / model hyper-parameters.
BATCH_SIZE = 8
EPOCHS = 20
IMAGE_RES = 128        # images and masks are resized to IMAGE_RES x IMAGE_RES
OUTPUT_CHANNELS = 35   # logits per pixel; presumably sized for Cityscapes label ids 0-33 — confirm
VAL_SUBSPLITS = 5      # each epoch validates on roughly 1/VAL_SUBSPLITS of the held-out batches

def normalize(input_image, input_mask):
  """Cast the image to float32 and divide it by 255; return the mask unchanged.

  For 8-bit pixel intensities this maps [0, 255] onto [0, 1].
  """
  return tf.cast(input_image, tf.float32) / 255.0, input_mask

def load_image_train(datapoint):
  """Prepare one training example: resize, random horizontal flip, normalize.

  Args:
    datapoint: feature dict from the Cityscapes tfds dataset; reads the
      'image_left' image and the 'segmentation_label' mask.

  Returns:
    (image, mask), both resized to IMAGE_RES x IMAGE_RES. The image is float32
    divided by 255. The mask keeps discrete class ids.
  """
  input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))
  # Nearest-neighbour keeps every mask pixel an existing class id (and keeps the
  # mask's dtype). The default bilinear filter would blend neighbouring ids into
  # fractional, meaningless labels along every object boundary.
  input_mask = tf.image.resize(
      datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES),
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # Flip image and mask together so they stay pixel-aligned.
  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)
  return input_image, input_mask
  
def load_image_test(datapoint):
  """Prepare one evaluation example: resize and normalize, no augmentation.

  Args:
    datapoint: feature dict from the Cityscapes tfds dataset; reads the
      'image_left' image and the 'segmentation_label' mask.

  Returns:
    (image, mask), both resized to IMAGE_RES x IMAGE_RES. The image is float32
    divided by 255. The mask keeps discrete class ids.
  """
  input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))
  # Nearest-neighbour so resized labels remain valid class ids; bilinear
  # interpolation would invent fractional labels at object boundaries.
  input_mask = tf.image.resize(
      datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES),
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  input_image, input_mask = normalize(input_image, input_mask)
  return input_image, input_mask

# Split sizes, used to derive step counts for model.fit.
tr_num = info.splits['train'].num_examples
val_num = info.splits['validation'].num_examples
test_num = info.splits['test'].num_examples

AUTOTUNE = tf.data.experimental.AUTOTUNE
SHUFFLE_BUFFER = 1000  # examples held in memory, each already resized to IMAGE_RES

train = cscape['train'].map(load_image_train, num_parallel_calls=AUTOTUNE)

# Shuffle individual examples *before* batching so every epoch sees new batch
# compositions. Shuffling already-built batches with an 8-batch buffer only
# permutes a handful of fixed batches.
tr_batches = (train
              .shuffle(SHUFFLE_BUFFER)
              .batch(BATCH_SIZE)
              .repeat()
              .prefetch(AUTOTUNE))

# Held-out data comes from the 'validation' split. Cityscapes withholds the
# ground truth for its 'test' split (it is scored on the benchmark server), so
# the labels shipped for 'test' are not usable for evaluation. The name
# `test_batches` is kept because later cells refer to it.
test_batches = (cscape['validation']
                .map(load_image_test, num_parallel_calls=AUTOTUNE)
                .batch(BATCH_SIZE)
                .prefetch(AUTOTUNE))

STEPS_PER_EPOCH = tr_num // BATCH_SIZE

# Encoder: MobileNetV2 without its classification head (no `weights` argument,
# so the Keras default pretrained weights are used).
base_model = tf.keras.applications.MobileNetV2(
    include_top=False, input_shape=[IMAGE_RES, IMAGE_RES, 3])

# Activations tapped as skip connections, shallowest first. The last entry is
# the deepest feature map; unet_model starts decoding from it.
layer_names = [
    'block_1_expand_relu',
    'block_3_expand_relu',
    'block_6_expand_relu',
    'block_13_expand_relu',
    'block_16_expand_relu',
]
layers = [base_model.get_layer(layer_name).output for layer_name in layer_names]

# Multi-output feature extractor, frozen so only the decoder is trained.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False

# Decoder: one pix2pix upsampling stage (kernel size 3) per skip connection,
# with the filter count halving from stage to stage.
up_stack = [pix2pix.upsample(filters, 3) for filters in (512, 256, 128, 64)]

def unet_model(output_channels):
  """Build the segmentation U-Net on top of the frozen encoder `down_stack`.

  The deepest encoder feature map is passed through each `up_stack` stage in
  turn. After each stage the result is concatenated with the matching
  shallower encoder feature. A final stride-2 transposed convolution (no
  activation) produces `output_channels` raw scores per pixel.

  Args:
    output_channels: number of per-pixel output channels (logits).

  Returns:
    A tf.keras.Model taking [IMAGE_RES, IMAGE_RES, 3] images.
  """
  inputs = tf.keras.layers.Input(shape=[IMAGE_RES, IMAGE_RES, 3])

  *skip_features, features = down_stack(inputs)

  # Pair decoder stages with skip features from deepest to shallowest.
  for upsample, skip in zip(up_stack, skip_features[::-1]):
    upsampled = upsample(features)
    features = tf.keras.layers.Concatenate()([upsampled, skip])

  head = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2, padding='same')
  return tf.keras.Model(inputs=inputs, outputs=head(features))

# The decoder's final Conv2DTranspose has no activation, so the loss receives
# raw scores (from_logits=True). The sparse variant takes integer class ids as
# targets.
model = unet_model(OUTPUT_CHANNELS)
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

def create_mask(pred_mask):
  """Turn a batch of per-pixel class scores into a label mask.

  Takes the arg-max over the last (class) axis, re-adds a trailing axis of
  size 1, and returns the mask for the first example of the batch.
  """
  labels = tf.expand_dims(tf.argmax(pred_mask, axis=-1), axis=-1)
  return labels[0]

def _display(display_list):
  """Plot input image, true mask and predicted mask side by side."""
  titles = ['Input Image', 'True Mask', 'Predicted Mask']
  plt.figure(figsize=(15, 15))
  for idx, (title, item) in enumerate(zip(titles, display_list)):
    plt.subplot(1, len(display_list), idx + 1)
    plt.title(title)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
    plt.axis('off')
  plt.show()


def show_predictions(dataset=None, num=1, predictor=None):
  """Show the first example of `num` batches next to the model's prediction.

  Args:
    dataset: batched (image, mask) dataset; defaults to `test_batches`.
    num: number of batches to draw from `dataset`.
    predictor: Keras model used for inference; defaults to the global `model`.
  """
  # Defaults are resolved at call time because they are defined in other cells.
  dataset = test_batches if dataset is None else dataset
  predictor = model if predictor is None else predictor
  for image, mask in dataset.take(num):
    pred_mask = predictor.predict(image)
    _display([image[0], mask[0], create_mask(pred_mask)])


class DisplayCallback(tf.keras.callbacks.Callback):
  """Redraw a sample prediction at the end of every training epoch."""

  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    # self.model is the model being trained (attached by Keras during fit).
    show_predictions(predictor=self.model)
    print('\nSample Predictions after epoch {}\n'.format(epoch+1))

# Validate on roughly 1/VAL_SUBSPLITS of the validation split each epoch. The
# count is derived from the validation split's size (val_num), not the test
# split's, and is clamped to at least one step so tiny splits still validate.
VALIDATION_STEPS = max(1, val_num // BATCH_SIZE // VAL_SUBSPLITS)

model_history = model.fit(tr_batches,
                          epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])