# TensorFlow - Advanced Computer Vision Module 3 Exercise

This notebook contains the solution to the Module 3 exercise.

**Question**

Build a model that predicts the segmentation masks (pixel-wise label maps) of MNIST handwritten digits.
The model should be trained on the [M2NIST dataset](https://www.kaggle.com/farhanhubble/multimnistm2nist).

M2NIST is multi-digit MNIST. Each image contains up to 3 MNIST digits, and the corresponding labels file contains the segmentation masks.

You can train a CNN from scratch for the downsampling path and use FCN-8 upsampling to produce the pixel-wise label map.

## Imports

In [None]:
import os, re, time, json
import PIL.Image, PIL.ImageFont, PIL.ImageDraw
import numpy as np
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
from matplotlib import pyplot as plt
import tensorflow_datasets as tfds
import seaborn as sns
from sklearn.model_selection import train_test_split
import os

print("Tensorflow version " + tf.__version__)

## Parameters

In [None]:
BATCH_SIZE = 32

# There are 11 classes (digits 0 to 9 + a background class).
n_classes = 11

# Seeded generator so the per-class plotting palette is identical on every run
# (the unseeded global RNG gave different colours after every kernel restart).
SEED = 42
_palette_rng = np.random.default_rng(SEED)

# One random RGB colour per class; integers in 0..255 scaled into [0, 1].
colors = [tuple(_palette_rng.integers(256, size=3) / 255.0) for _ in range(n_classes)]

## Download Dataset

[M2NIST](https://www.kaggle.com/farhanhubble/multimnistm2nist) is **multi-digit** [MNIST](http://yann.lecun.com/exdb/mnist/).
Each image contains up to 3 MNIST digits, and the corresponding labels file contains the segmentation masks.

The dataset is available on [Kaggle](https://www.kaggle.com).
Link to the dataset: https://www.kaggle.com/farhanhubble/multimnistm2nist

To make it easier for you, we're hosting it on Google Cloud so you can download it without Kaggle credentials.


In [None]:
!wget --no-check-certificate \
    https://storage.googleapis.com/laurencemoroney-blog.appspot.com/m2nist.zip \
    -O /tmp/m2nist.zip
import os
import zipfile



local_zip = '/tmp/m2nist.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp/training')
zip_ref.close()

## Load and Preprocess Dataset

This dataset is easy to preprocess because it is available as **NumPy array files (.npy)**:

1. **/tmp/training/combined.npy** contains the images, each showing multiple MNIST digits. Each image is of size **64 x 84**.

2. **/tmp/training/segmented.npy** contains the corresponding segmentation masks. Each segmentation mask is also of size **64 x 84**, with one one-hot channel per class.

This dataset has **5000** samples. You can make appropriate training and validation splits as required for the problem.

In [None]:
# Training, validation and test sizes.
# These mirror the two 80/20 splits in load_images_and_segments (5000 samples ->
# 4000 train, and the 1000 held out -> 800 validation / 200 test). They document the
# split only; the split itself uses the literal fractions, not these names.
train_size = 4000
val_size = 800
test_size = 200

def read_image_and_annotation(image, annotation):
  '''Map one (image, mask) pair to model-ready tensors.

  The image is cast to float32, rescaled with x / 127.5 - 1 (so 0..255 maps to
  [-1, 1]) and given a trailing channel axis of size 1. The annotation is only
  cast to int32.
  '''
  height, width = image.shape[0], image.shape[1]
  pixels = tf.cast(image, dtype=tf.float32) / 127.5 - 1
  pixels = tf.reshape(pixels, (height, width, 1))
  return pixels, tf.cast(annotation, dtype=tf.int32)

def get_training_dataset(images, annos, batch_size=None, shuffle_buffer=512):
  '''Build the (infinitely repeating) training pipeline.

  Args:
    images: training images, one per leading-axis entry.
    annos: matching segmentation masks.
    batch_size: examples per batch; defaults to the notebook-wide BATCH_SIZE.
    shuffle_buffer: shuffle buffer size; the order is reshuffled every epoch.

  Returns:
    A shuffled, batched, repeated and prefetched tf.data.Dataset of
    (image, annotation) pairs produced by read_image_and_annotation.
  '''
  if batch_size is None:
    batch_size = BATCH_SIZE
  autotune = tf.data.experimental.AUTOTUNE

  training_dataset = tf.data.Dataset.from_tensor_slices((images, annos))
  # Shuffle the raw slices before mapping: the map is per-element and stateless, so
  # the result is equivalent, and the buffer holds raw slices rather than mapped ones.
  training_dataset = training_dataset.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
  # Parallel map; tf.data keeps element order by default.
  training_dataset = training_dataset.map(read_image_and_annotation, num_parallel_calls=autotune)
  training_dataset = training_dataset.batch(batch_size)
  training_dataset = training_dataset.repeat()
  # AUTOTUNE is the named form of the -1 that was passed to prefetch before.
  training_dataset = training_dataset.prefetch(autotune)

  return training_dataset


def get_validation_dataset(images, annos, batch_size=None):
  '''Build the (infinitely repeating, unshuffled) validation pipeline.

  `repeat()` lets model.fit draw `validation_steps` batches every epoch.

  Args:
    images: validation images, one per leading-axis entry.
    annos: matching segmentation masks.
    batch_size: examples per batch; defaults to the notebook-wide BATCH_SIZE.

  Returns:
    A batched, repeated and prefetched tf.data.Dataset of (image, annotation) pairs.
  '''
  if batch_size is None:
    batch_size = BATCH_SIZE
  autotune = tf.data.experimental.AUTOTUNE

  validation_dataset = tf.data.Dataset.from_tensor_slices((images, annos))
  validation_dataset = validation_dataset.map(read_image_and_annotation, num_parallel_calls=autotune)
  validation_dataset = validation_dataset.batch(batch_size)
  validation_dataset = validation_dataset.repeat()
  # Same prefetching as the training pipeline so validation overlaps with compute.
  validation_dataset = validation_dataset.prefetch(autotune)

  return validation_dataset


def get_test_dataset(images, annos, batch_size=None, drop_remainder=True):
  '''Build the finite, unshuffled test pipeline.

  Args:
    images: test images, one per leading-axis entry.
    annos: matching segmentation masks.
    batch_size: examples per batch; defaults to the notebook-wide BATCH_SIZE.
    drop_remainder: if True (the default, as before) the trailing partial batch
      is discarded, e.g. 200 samples at batch size 32 -> 6 batches = 192 samples.

  Returns:
    A batched tf.data.Dataset of (image, annotation) pairs.
  '''
  if batch_size is None:
    batch_size = BATCH_SIZE

  test_dataset = tf.data.Dataset.from_tensor_slices((images, annos))
  test_dataset = test_dataset.map(read_image_and_annotation,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
  test_dataset = test_dataset.batch(batch_size, drop_remainder=drop_remainder)

  return test_dataset

def load_images_and_segments(data_dir='/tmp/training', random_state=42):
  '''Load the M2NIST arrays and split them into train / validation / test sets.

  The first split holds out 20% of the samples; that hold-out is split again 80/20
  into validation and test, i.e. 80% / 16% / 4% overall (4000 / 800 / 200 for the
  5000-sample dataset).

  Args:
    data_dir: directory containing combined.npy (images) and segmented.npy (masks).
    random_state: seed for both splits so split membership is reproducible across
      runs; pass None for a fresh random split on every call.

  Returns:
    ((train_images, train_annos), (val_images, val_annos), (test_images, test_annos))
  '''
  images = np.load(os.path.join(data_dir, 'combined.npy'))
  segments = np.load(os.path.join(data_dir, 'segmented.npy'))

  train_images, val_images, train_annos, val_annos = train_test_split(
      images, segments, test_size=0.2, shuffle=True, random_state=random_state)
  val_images, test_images, val_annos, test_annos = train_test_split(
      val_images, val_annos, test_size=0.2, shuffle=True, random_state=random_state)

  return (train_images, train_annos), (val_images, val_annos), (test_images, test_annos)


In [None]:
# Load the raw arrays; each split is an (images, annotations) pair.
train_slices, val_slices, test_slices = load_images_and_segments()

# Wrap each split in its tf.data pipeline.
training_dataset = get_training_dataset(*train_slices)
validation_dataset = get_validation_dataset(*val_slices)
test_Dataset = get_test_dataset(*test_slices)


In [None]:
#@title Plot Utilities [RUN ME]
def fuse_with_pil(images):
  '''Paste a sequence of arrays left to right onto one RGB PIL canvas.

  The canvas is as wide as all images combined and as tall as the tallest one.
  Each array is converted with np.uint8 and pasted along the top edge.
  '''
  total_width = sum(im.shape[1] for im in images)
  max_height = max(im.shape[0] for im in images)
  canvas = PIL.Image.new('RGB', (total_width, max_height))

  left = 0
  for im in images:
    canvas.paste(PIL.Image.fromarray(np.uint8(im)), (left, 0))
    left += im.shape[1]

  return canvas

def give_color_to_annotation(annotation, palette=None):
  '''Turn a 2-D class-id map into an RGB image for plotting.

  Args:
    annotation: 2-D array of integer class ids.
    palette: sequence of (r, g, b) tuples with components in [0, 1], indexed by
      class id; defaults to the notebook-wide `colors`.

  Returns:
    float64 array of shape (H, W, 3) with values in [0, 255]. Pixels whose id
    has no palette entry stay black.
  '''
  if palette is None:
    palette = colors
  annotation = np.asarray(annotation)

  seg_img = np.zeros((annotation.shape[0], annotation.shape[1], 3), dtype=float)
  # Each pixel holds a single id, so one masked assignment per class yields the same
  # image as accumulating every channel separately.
  for class_id, rgb in enumerate(palette):
    seg_img[annotation == class_id] = np.asarray(rgb, dtype=float) * 255.0
  return seg_img

def show_annotation_and_prediction(image, annotation, prediction, iou_list, dice_score_list):
  '''Plot input image, predicted mask and ground-truth mask side by side.

  Args:
    image: normalized image in [-1, 1] with a trailing channel axis of size 1.
    annotation: one-hot ground-truth mask (class axis is axis 2).
    prediction: 2-D predicted class-id map.
    iou_list, dice_score_list: per-class metrics. Classes with index < 10 and a
      positive IOU are listed under the prediction panel, highest IOU first.
  '''
  true_img = give_color_to_annotation(np.argmax(annotation, axis=2))
  pred_img = give_color_to_annotation(prediction)

  # Undo the [-1, 1] normalization and drop the channel axis for display.
  pixels = np.uint8(np.reshape((image + 1) * 127.5, (image.shape[0], image.shape[1])))
  panels = [pixels, np.uint8(pred_img), np.uint8(true_img)]

  metrics_by_id = sorted(
      ((idx, iou, dice) for idx, (iou, dice) in enumerate(zip(iou_list, dice_score_list))
       if iou > 0.0 and idx < 10),
      key=lambda tup: tup[1],
      reverse=True)
  display_string = "\n".join(
      "{}: IOU: {} Dice Score: {}".format(idx, iou, dice) for idx, iou, dice in metrics_by_id)

  plt.figure(figsize=(15, 4))
  for position, panel in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    if position == 2:
      plt.xlabel(display_string)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(panel)


  



def show_annotation_and_image(image, annotation):
  '''Draw a normalized image next to its colorized one-hot mask in the current axes.

  `image` is in [-1, 1] with a trailing channel axis of size 1. `annotation` is
  one-hot with its class axis on axis 2.
  '''
  seg_img = give_color_to_annotation(np.argmax(annotation, axis=2))
  restored = np.uint8(np.reshape((image + 1) * 127.5, (image.shape[0], image.shape[1])))
  plt.imshow(fuse_with_pil([restored, seg_img]))


def list_show_annotation(dataset, num_examples=9, ncols=3):
  '''Plot the first `num_examples` (image, mask) pairs of a batched dataset.

  Args:
    dataset: batched tf.data.Dataset of (image, one-hot annotation) pairs.
    num_examples: how many examples to draw (default 9, as before).
    ncols: number of columns in the subplot grid; rows are added as needed.
  '''
  ds = dataset.unbatch()
  nrows = -(-num_examples // ncols)  # ceiling division

  fig = plt.figure(figsize=(20, 15))
  # suptitle instead of plt.title: plt.title would attach to an implicit
  # full-figure Axes that sits underneath (or is removed by) the subplot grid.
  fig.suptitle("Images And Annotations")
  plt.subplots_adjust(bottom=0.1, top=0.9, hspace=0.05)

  for idx, (image, annotation) in enumerate(ds.take(num_examples)):
    plt.subplot(nrows, ncols, idx + 1)
    plt.yticks([])
    plt.xticks([])
    show_annotation_and_image(image.numpy(), annotation.numpy())

 

### Let's Take a Look at the Dataset

In [None]:
# Preview a few training images next to their colorized ground-truth masks.
list_show_annotation(training_dataset)

In [None]:
# Preview a few validation images next to their colorized ground-truth masks.
list_show_annotation(validation_dataset)

## Define the Model

Like any FCN, our model will have two paths:
1. **Downsampling Path** - Custom CNN from scratch
2. **Upsampling Path** - FCN-8

### Define Basic Convolution Block

In [None]:
IMAGE_ORDERING = 'channels_last'

def conv_block(x, filters, strides, pooling_size, pool_strides):
  '''Basic downsampling block of the encoder CNN.

  Layer order: Conv2D -> LeakyReLU -> Conv2D -> MaxPooling2D -> LeakyReLU
  -> BatchNormalization.

  Args:
    x: input feature map.
    filters: number of filters in both Conv2D layers.
    strides: despite its name, this is the Conv2D *kernel size*. Both
      convolutions use the default stride of 1 with 'same' padding.
    pooling_size: MaxPooling2D pool size.
    pool_strides: MaxPooling2D strides.

  Returns:
    The block's output feature map.
  '''
  layers = tf.keras.layers
  x = layers.Conv2D(filters=filters, kernel_size=strides, padding='same', data_format=IMAGE_ORDERING)(x)
  x = layers.LeakyReLU()(x)
  x = layers.Conv2D(filters=filters, kernel_size=strides, padding='same', data_format=IMAGE_ORDERING)(x)
  x = layers.MaxPooling2D(pool_size=pooling_size, strides=pool_strides, data_format=IMAGE_ORDERING)(x)
  x = layers.LeakyReLU()(x)
  return layers.BatchNormalization()(x)

### Define Downsampling Path

In [None]:
def FCN8(inputs, nClasses ,  input_height=64, input_width=84):
    '''Build the downsampling (encoder) path: five conv_blocks on a zero-padded input.

    Despite its name, this returns only the encoder features; the FCN-8 upsampling
    is done by fcn8_decoder.

    Args:
        inputs: unused. NOTE(review): kept only for call compatibility; a fresh
            Input layer is always created and returned.
        nClasses: unused (the encoder does not depend on the class count).
        input_height, input_width: spatial size of the single-channel input images.

    Returns:
        ((f1, f2, f3, f4, f5), img_input): the output of each block, shallowest
        first (each block max-pools 2x2 with stride 2), and the model's Input layer.
    '''
    img_input = tf.keras.layers.Input(shape=(input_height, input_width, 1))

    # The five stride-2 poolings divide each spatial dim by 32, so pad bottom/right up
    # to the next multiple of 32 (64x84 -> 64x96, as before). The previous hard-coded
    # `96 - input_width` went negative for widths above 96.
    pad_h = -input_height % 32
    pad_w = -input_width % 32
    x = tf.keras.layers.ZeroPadding2D(((0, pad_h), (0, pad_w)))(img_input)

    features = []
    for filters in (32, 64, 128, 256, 256):
        x = conv_block(x, filters, 3, 2, 2)
        features.append(x)

    f1, f2, f3, f4, f5 = features
    return (f1, f2, f3, f4, f5), img_input

### Define FCN-8 Decoder

In [None]:
def fcn8_decoder(convs, n_classes):
  '''FCN-8 upsampling path (decoder).

  Args:
    convs: the five encoder feature maps (f1, ..., f5), shallowest first, as
      returned by FCN8. Only f3, f4 and f5 are used.
    n_classes: number of output classes (channels of the label map).

  Returns:
    Per-pixel class probabilities (softmax over the last axis), cropped back to the
    84-pixel input width.
  '''
  f1, f2, f3, f4, f5 = convs
  n = 512

  # Classifier layers expressed as convolutions on top of the deepest features.
  o = tf.keras.layers.Conv2D(n, (7, 7), activation='relu', padding='same', name="conv6", data_format=IMAGE_ORDERING)(f5)
  o = tf.keras.layers.Dropout(0.5)(o)
  o = tf.keras.layers.Conv2D(n, (1, 1), activation='relu', padding='same', name="conv7", data_format=IMAGE_ORDERING)(o)
  o = tf.keras.layers.Dropout(0.5)(o)

  # Per-class score map at f5's resolution.
  o = tf.keras.layers.Conv2D(n_classes, (1, 1), activation='relu', padding='same', data_format=IMAGE_ORDERING)(o)

  # Upsample x2, crop one pixel per side to match f4, then fuse f4's class scores.
  o = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o)
  o = tf.keras.layers.Cropping2D(cropping=(1, 1), data_format=IMAGE_ORDERING)(o)
  o2 = tf.keras.layers.Conv2D(n_classes, (1, 1), activation='relu', padding='same', data_format=IMAGE_ORDERING)(f4)
  o = tf.keras.layers.Add()([o, o2])

  # Same again to reach f3's resolution, then fuse f3's class scores.
  o = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(4, 4), strides=(2, 2), use_bias=False, data_format=IMAGE_ORDERING)(o)
  o2 = tf.keras.layers.Conv2D(n_classes, (1, 1), activation='relu', padding='same', data_format=IMAGE_ORDERING)(f3)
  o = tf.keras.layers.Cropping2D(cropping=(1, 1), data_format=IMAGE_ORDERING)(o)
  o = tf.keras.layers.Add()([o, o2])

  # Upsample x8 back to the padded input size, then drop the 96 - 84 = 12 columns of
  # width padding added by the encoder.
  o = tf.keras.layers.Conv2DTranspose(n_classes, kernel_size=(8, 8), strides=(8, 8), use_bias=False, data_format=IMAGE_ORDERING)(o)
  o = tf.keras.layers.Cropping2D(((0, 0), (0, 96 - 84)))(o)

  # Softmax, not sigmoid: each pixel belongs to exactly one of n_classes, which is
  # what the categorical_crossentropy loss used at compile time expects.
  o = tf.keras.layers.Activation('softmax')(o)

  return o

In [None]:
# FCN8 creates and returns its own Input layer (its `inputs` argument is not used),
# so the Model is built from the Input it hands back. The old pre-built Input was
# dead code, and the arguments were passed in swapped order.
convs, img_input = FCN8(inputs=None, nClasses=n_classes)
dec_op = fcn8_decoder(convs, n_classes)
model = tf.keras.Model(inputs=img_input, outputs=dec_op)


In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary()

## Compile Model

In [None]:
# Adam with learning rate 1e-3; categorical cross-entropy against the one-hot masks.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

## Train Model

In [None]:
EPOCHS = 70

# Derive steps from the actual split sizes instead of re-typing 4000 / 800 / 200,
# so they stay correct if the split changes.
steps_per_epoch = len(train_slices[0]) // BATCH_SIZE
validation_steps = len(val_slices[0]) // BATCH_SIZE
# The test pipeline drops its last partial batch, so only full batches are predicted.
test_steps = len(test_slices[0]) // BATCH_SIZE

history = model.fit(training_dataset,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=validation_dataset,
                    validation_steps=validation_steps,
                    epochs=EPOCHS)

## Evaluate Model

### Make Predictions

In [None]:
# Per-pixel class scores for the first `test_steps` batches of the test set.
results = model.predict(test_Dataset, steps=test_steps)


In [None]:
# Collapse per-class scores (N, H, W, C) to a class-id map (N, H, W). The guard
# keeps this cell safe to re-run: a second argmax on the 3-D map would raise.
if np.ndim(results) == 4:
  results = np.argmax(results, axis=-1)

### Visualize Predictions

In [None]:
def class_wise_metrics(y_true, y_pred, num_classes=None):
  '''Per-class IOU and Dice score between two class-id maps.

  Args:
    y_true: array of ground-truth class ids.
    y_pred: array of predicted class ids, same shape as y_true.
    num_classes: number of classes to score (ids 0 .. num_classes - 1); defaults
      to the notebook-wide n_classes.

  Returns:
    (class_wise_iou, class_wise_dice_score): two lists with one entry per class.
    A class absent from both maps scores 0 on both metrics.
  '''
  if num_classes is None:
    num_classes = n_classes

  # Small constant that keeps both ratios finite when a class is absent from both maps.
  smoothing_factor = 0.00001

  class_wise_iou = []
  class_wise_dice_score = []
  for i in range(num_classes):
    true_mask = (y_true == i)
    pred_mask = (y_pred == i)
    intersection = np.sum(true_mask & pred_mask)
    combined_area = np.sum(true_mask) + np.sum(pred_mask)

    # union = combined_area - intersection
    iou = intersection / (combined_area - intersection + smoothing_factor)
    class_wise_iou.append(iou)

    dice_score = 2 * (intersection / (combined_area + smoothing_factor))
    class_wise_dice_score.append(dice_score)

  return class_wise_iou, class_wise_dice_score



In [None]:
#@title Visualize Output [RUN ME]
integer_slider = 173 #@param {type:"slider", min:0, max:191, step:1}

# Gather the test samples that were actually predicted as one batch, so they line
# up index-for-index with `results`.
ds = test_Dataset.unbatch()
ds = ds.batch(len(results))
images, y_true_segments = next(iter(ds))

assert 0 <= integer_slider < len(results), "integer_slider is outside the predicted test samples"

iou, dice_score = class_wise_metrics(np.argmax(y_true_segments[integer_slider], axis=2), results[integer_slider])
show_annotation_and_prediction(images[integer_slider], y_true_segments[integer_slider], results[integer_slider], iou, dice_score)



### Compute IOU Score and Dice Score

In [None]:
cls_wise_iou, cls_wise_dice_score = class_wise_metrics(np.argmax(y_true_segments, axis=3), results)

# The last class is the background, so only the digit classes are reported.
digit_metrics = zip(cls_wise_iou[:-1], cls_wise_dice_score[:-1])
for digit, (iou, dice_score) in enumerate(digit_metrics):
  print("Digit {}: IOU: {} Dice Score: {}".format(digit, iou, dice_score))
