# Image Segmentation

## Description

## Exploratory Data Analysis

In [None]:
import tensorflow as tf
from pycocotools.coco import COCO
import os

# Work from /root so the relative dataset paths used throughout the notebook resolve.
os.chdir('/root')
# Project directory, relative to /root. NOTE(review): this shadows the builtin dir().
dir = '../tf/notebooks/portfolio/image_segmentation'

# Load the COCO 2017 instance annotations for the training and validation splits.
train_annotations = COCO('datasets/coco2017/annotations/instances_train2017.json')
val_annotations = COCO('datasets/coco2017/annotations/instances_val2017.json')

In [None]:
# Inspect every category defined in the training annotations.
catIDs = train_annotations.getCatIds()
cats = train_annotations.loadCats(catIDs)

print(f"Number of Unique Categories: {len(catIDs)}")
print("Categories Names:")
# Last expression of the cell, so the notebook displays the list of category dicts.
cats

In [None]:
# Restrict exploration to a few categories of interest.
filterClasses = ['laptop', 'tv', 'cat']

# Fetch the category IDs corresponding to filterClasses only.
catIds = train_annotations.getCatIds(catNms=filterClasses)
# Get the images annotated with these category IDs (pycocotools intersects the
# per-category image sets, i.e. images containing all of the classes).
imgIds = train_annotations.getImgIds(catIds=catIds)
# Print the filtered IDs (not the full category list, which the previous cell covers).
print(catIds)
print("Number of images containing all the  classes:", len(imgIds))
print('imgIds:', imgIds)

In [None]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
%matplotlib inline

# Pick one image from the filtered set and visualize it three ways: raw, with the
# COCO annotation overlay, and as an instance-label mask.
# NOTE(review): assumes the filtered set holds at least 11 images.
imgId = imgIds[10]
print('Image ID:', imgId)

img_meta = train_annotations.loadImgs(imgId)[0]
img = Image.open(f'datasets/coco2017/images/train2017/{img_meta["file_name"]}')

# iscrowd=None: do not filter the annotations on their crowd flag.
ann_ids = train_annotations.getAnnIds(imgIds=[imgId], iscrowd=None)
anns = train_annotations.loadAnns(ann_ids)

print('Example annotation for one object in scene:')
print(anns[0])

# Show image
plt.imshow(img)
plt.show()

# Show bounding boxes and segmentation
plt.imshow(img)
train_annotations.showAnns(anns, draw_bbox=True)
plt.show()

# Show masks: the k-th annotation (1-based) is painted with value k; where
# instances overlap, the higher label wins because of np.maximum.
mask = np.zeros((img_meta['height'], img_meta['width']))
for pixel_value, ann in enumerate(anns, start=1):
    mask = np.maximum(train_annotations.annToMask(ann) * pixel_value, mask)
plt.imshow(mask)
plt.show()

In [None]:
# Number of training images containing each category, in the same order as `cats`.
class_counts = [len(train_annotations.getImgIds(catIds=cat['id'])) for cat in cats]
class_names = [cat['name'] for cat in cats]
# One bar per category; labels rotated so all category names stay readable.
figure(figsize=(16,8), dpi=80)
plt.bar(class_names, class_counts)
plt.xticks(rotation=90)
plt.show()

## Methods

In [None]:
import cv2
import gc

# Create a data generator to generate batches of image/segmentation pairs
# The ground truth for each image is a rank-3 array (height, width, channels): one binary mask channel per COCO category (80 for COCO 2017) plus a final background channel, i.e. 81 channels in total

class CategoryMappingHelper():
  """Translate between COCO category ids, category names and mask channel indices.

  A category's "filter index" is its position in ``self.categories``: the
  categories returned by ``coco_annotations.getCatIds()`` (in that order),
  followed by a synthetic ``{'id': 0, 'name': 'background'}`` entry in the last
  position.
  """

  def __init__(self, coco_annotations):
    """Build the lookup tables from a pycocotools-style ``COCO`` object.

    Args:
      coco_annotations: object exposing ``getCatIds()`` and ``loadCats(ids)``;
        the latter must return dicts with at least ``'id'`` and ``'name'`` keys.
    """
    self.catIds = coco_annotations.getCatIds()
    self.categories = [*coco_annotations.loadCats(self.catIds), {'id': 0, 'name': 'background'}]
    self.filter_idx_to_category_id = [cat['id'] for cat in self.categories]
    self.filter_idx_to_category_name = [cat['name'] for cat in self.categories]
    self.category_id_to_category_name = {cat['id']: cat['name'] for cat in self.categories}
    # Original (misspelled) attribute name, kept as an alias for existing callers.
    self.categry_id_to_category_name = self.category_id_to_category_name
    self.category_id_to_filter_idx = {cat_id: idx for idx, cat_id in enumerate(self.filter_idx_to_category_id)}
    self.category_name_to_filter_idx = {name: idx for idx, name in enumerate(self.filter_idx_to_category_name)}
    self.category_name_to_category_id = {cat['name']: cat['id'] for cat in self.categories}
    # (in_key, out_key) -> lookup table; `get` dispatches through this mapping.
    self._lookups = {
      ('filter_idx', 'category_id'): self.filter_idx_to_category_id,
      ('filter_idx', 'category_name'): self.filter_idx_to_category_name,
      ('category_id', 'category_name'): self.category_id_to_category_name,
      ('category_id', 'filter_idx'): self.category_id_to_filter_idx,
      ('category_name', 'filter_idx'): self.category_name_to_filter_idx,
      ('category_name', 'category_id'): self.category_name_to_category_id,
    }

  def get(self, value, in_key, out_key):
    """Map ``value`` of kind ``in_key`` to the corresponding value of kind ``out_key``.

    Args:
      value: a filter index, category id or category name, matching ``in_key``.
      in_key: one of ``'filter_idx'``, ``'category_id'``, ``'category_name'``.
      out_key: the kind of value to return (same vocabulary as ``in_key``).

    Returns:
      The looked-up filter index, category id or category name.

    Raises:
      ValueError: if the ``(in_key, out_key)`` pair is not supported.
      IndexError, KeyError: if ``value`` is not a known index/id/name.
    """
    table = self._lookups.get((in_key, out_key))
    if table is None:
      valid_pairs = '\n'.join(f'{src}:{dst}' for src, dst in self._lookups)
      raise ValueError(f"in_key:out_key pair not supported. Valid pairs: \n{valid_pairs}")
    return table[value]


def cocoDataGenerator(coco_annotations, image_folder, catMapper, batch_size, shuffle=False, input_size=(128,128), output_size=(128,128)):
  """Yield ``(images, masks)`` batches built from COCO instance annotations.

  Args:
    coco_annotations: pycocotools ``COCO`` object for the split to read.
    image_folder: directory containing the split's image files.
    catMapper: ``CategoryMappingHelper`` mapping category ids to mask channels.
    batch_size: number of samples per yielded batch.
    shuffle: shuffle the image order once before iterating.
    input_size: (height, width) of the yielded images.
    output_size: (height, width) of the yielded masks.

  Yields:
    ``(X, y)`` arrays of shape (batch_size, *input_size, 3) and
    (batch_size, *output_size, numCategories + 1). Channel ``k`` of ``y``
    (``k < numCategories``) is the union of all annotation masks whose category
    maps to filter index ``k``; the last channel is 1 where no annotation covers
    the pixel.

  Images whose resized array is not (*input_size, 3) (e.g. single-channel images)
  are skipped. The generator makes a single pass over the images and drops the
  final partial batch.
  """
  numCategories = len(catMapper.catIds)
  imageIds = coco_annotations.getImgIds()
  if shuffle:
    np.random.shuffle(imageIds)

  # cv2.resize expects dsize as (width, height); the sizes here are (height, width),
  # matching the np.zeros(...) masks and the shape check below.
  input_dsize = (input_size[1], input_size[0])
  output_dsize = (output_size[1], output_size[0])

  X_batch = []
  y_batch = []

  for imageId in imageIds:
    img_meta = coco_annotations.loadImgs(imageId)[0]
    ann_ids = coco_annotations.getAnnIds(imgIds=[imageId], iscrowd=None)
    anns = coco_annotations.loadAnns(ann_ids)

    # Close the file handle as soon as the pixels are copied into an array.
    with Image.open(f'{image_folder}/{img_meta["file_name"]}') as image:
      X = np.array(image)
    X = cv2.resize(X, input_dsize)

    masks = [np.zeros(output_size) for _ in range(numCategories)]
    # Union of every annotation mask; its complement becomes the background channel.
    covered = np.zeros(output_size)
    for ann in anns:
      new_mask = cv2.resize(coco_annotations.annToMask(ann), output_dsize)
      covered = np.maximum(new_mask, covered)
      filter_idx = catMapper.get(ann['category_id'], 'category_id', 'filter_idx')
      masks[filter_idx] = np.maximum(new_mask, masks[filter_idx])
    masks.append(np.logical_not(covered).astype(int))
    y = np.dstack(masks)

    if X.shape == (*input_size, 3) and y.shape == (*output_size, numCategories + 1):
      X_batch.append(X)
      y_batch.append(y)

    if len(X_batch) == batch_size:
      yield np.array(X_batch), np.array(y_batch)
      X_batch = []
      y_batch = []
      gc.collect()

In [None]:
# test out generator
# Pull two single-image batches from the unshuffled training set and plot the
# input image plus a few ground-truth channels for each.
catMapper = CategoryMappingHelper(train_annotations)
test_generator = cocoDataGenerator(train_annotations, 'datasets/coco2017/images/train2017', catMapper, batch_size=1, shuffle=False)
# NOTE(review): test_iterator is never used below; next() is called on the
# generator itself (iter() of a generator returns the same object anyway).
test_iterator = iter(test_generator)

X_batch, y_batch = next(test_generator)
X, y = (X_batch[0], y_batch[0])
plt.imshow(X)
plt.show()
# NOTE(review): the category channels plotted below look hand-picked for these
# particular first images; they are only informative if the images contain them.
plt.title('motorcycle mask from ground truth')
plt.imshow(y[:,:,catMapper.get('motorcycle', 'category_name', 'filter_idx')])
plt.show()
plt.title('person mask from ground truth')
plt.imshow(y[:,:,catMapper.get('person', 'category_name', 'filter_idx')])
plt.show()
# The last channel is the background mask appended by cocoDataGenerator.
plt.title('background mask from ground truth')
plt.imshow(y[:,:,-1])
plt.show()

X_batch, y_batch = next(test_generator)
X, y = (X_batch[0], y_batch[0])
plt.imshow(X)
plt.show()
plt.title('person mask from ground truth')
plt.imshow(y[:,:,catMapper.get('person', 'category_name', 'filter_idx')])
plt.show()
plt.title('knife mask from ground truth')
plt.imshow(y[:,:,catMapper.get('knife', 'category_name', 'filter_idx')])
plt.show()
plt.title('background mask from ground truth')
plt.imshow(y[:,:,-1])
plt.show()

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def augment(data_gen, imageDataGeneratorArgs=None, seed=None):
  """Apply the same random spatial augmentation to each image and its mask.

  Args:
    data_gen: iterable of ``(X_batch, y_batch)`` numpy arrays, e.g. from
      ``cocoDataGenerator``; assumes both are (batch, height, width, channels).
    imageDataGeneratorArgs: keyword arguments for Keras ``ImageDataGenerator``.
      ``brightness_range`` is applied to the images only, never to the masks.
    seed: seed for the stream of per-sample augmentation seeds; a random one is
      drawn when omitted.

  Yields:
    ``(X_aug, y_aug)`` arrays with the input shapes, in the generators' dtype.
  """
  if imageDataGeneratorArgs is None:
    imageDataGeneratorArgs = {}
  rng = np.random.default_rng(seed if seed is not None else np.random.choice(range(9999)))

  X_gen = ImageDataGenerator(**imageDataGeneratorArgs)
  # Remove the brightness argument for the binary masks but keep spatial augmentations.
  imageDataGeneratorArgs_mask = dict(imageDataGeneratorArgs)
  imageDataGeneratorArgs_mask.pop('brightness_range', None)
  y_gen = ImageDataGenerator(**imageDataGeneratorArgs_mask)

  for X_batch, y_batch in data_gen:
    X_aug = np.empty(X_batch.shape, dtype=X_gen.dtype)
    y_aug = np.empty(y_batch.shape, dtype=y_gen.dtype)
    for j in range(X_batch.shape[0]):
      # Draw the transform once per sample and apply it to both image and mask.
      # Running .flow() separately on each array desynchronizes for batch > 1:
      # with brightness_range set, the image generator consumes one extra random
      # number per sample, so later samples get different spatial transforms.
      params = X_gen.get_random_transform(X_batch[j].shape, seed=int(rng.integers(9999)))
      mask_params = dict(params, brightness=None)
      x = X_gen.apply_transform(X_batch[j].astype(X_gen.dtype), params)
      X_aug[j] = X_gen.standardize(x)
      m = y_gen.apply_transform(y_batch[j].astype(y_gen.dtype), mask_params)
      y_aug[j] = y_gen.standardize(m)

    yield X_aug, y_aug
  

In [None]:
# show that augmentation works on test_generator
# Spatial augmentations (rotation, shift, shear, zoom, horizontal flip) plus a
# brightness range; augment() applies the brightness shift to images only.
test_augment = dict(featurewise_center = False, 
                      samplewise_center = False,
                      rotation_range = 10, 
                      width_shift_range = 0.01, 
                      height_shift_range = 0.01, 
                      brightness_range = (0.8,1.2),
                      shear_range = 0.01,
                      zoom_range = [1, 1.25],  
                      horizontal_flip = True, 
                      vertical_flip = False,
                      fill_mode = 'reflect',
                      data_format = 'channels_last')
# NOTE: this continues from test_generator, which already yielded two batches above.
aug_test_generator = augment(test_generator, test_augment) 
aug_test_iterator = iter(aug_test_generator)

X_batch, y_batch = next(aug_test_iterator)
X, y = (X_batch[0], y_batch[0])
# Presumably the augmented image comes back as floats in 0-255; casting to int
# makes imshow treat it as 8-bit RGB instead of clipping to [0, 1].
plt.imshow(X.astype(int))
plt.show()
plt.title('person mask from ground truth')
plt.imshow(y[:,:,catMapper.get('person', 'category_name', 'filter_idx')])
plt.show()

In [None]:
# val_generator = augment(cocoDataGenerator(val_annotations, 'datasets/coco2017/images/val2017', catMapper, batch_size=8, shuffle=True))

# train_generator = cocoDataGenerator(train_annotations, 'datasets/coco2017/images/train2017', catMapper, batch_size=8, shuffle=True)
# train_augment = dict(featurewise_center = False, 
#                       samplewise_center = False,
#                       rotation_range = 10, 
#                       width_shift_range = 0.01, 
#                       height_shift_range = 0.01, 
#                       brightness_range = (0.8,1.2),
#                       shear_range = 0.01,
#                       zoom_range = [1, 1.25],  
#                       horizontal_flip = True, 
#                       vertical_flip = False,
#                       fill_mode = 'reflect',
#                       data_format = 'channels_last')
# train_aug_generator = augment(train_generator, train_augment)                 

In [None]:
# Train for a few epochs as a sanity check

# from architectures import u_net_x5
# import importlib
# importlib.reload(u_net_x5)

# model = u_net_x5.define_unet()
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), # 5 started working without label smoothing
#               loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
#               metrics=['mse'])

In [None]:
# !apt-get -y update
# !apt-get -y install graphviz
# model.summary()
# tf.keras.utils.plot_model(model, show_shapes=True)

In [None]:
# model_history = model.fit(train_aug_generator, epochs=5,
#                           steps_per_epoch=100,
#                           validation_steps=20,
#                           validation_data=val_generator)

In [None]:
def show_argmax_masks(img_path, model, input_size=(128,128), binary=False):
  img = np.asarray(Image.open(img_path))
  original_img_dims = (img.shape[0], img.shape[1])
  X = np.array([cv2.resize(img, input_size)])
  y = model.predict(X)
  y = np.squeeze(y)
  # y = np.squeeze(y)[:,:,:-1]

  

  if binary:
    # out = y[:,:,1]
    # out = y[:,:,1] > 0.3
    out = tf.argmax(y, axis=-1)

    plt.imshow(np.squeeze(X))
    plt.show()
    plt.imshow(out)
    plt.show()
  else:
    # print(y)
    out = np.apply_along_axis(np.argmax, 2, y)

    cat_pixel_counts = np.bincount(out.flatten())
    above_zero = np.count_nonzero(cat_pixel_counts > 0)
    top_n = above_zero if above_zero < 5 else 5
    most_freq_idxs = np.argpartition(cat_pixel_counts, -top_n)[-top_n:]
    print(most_freq_idxs)
  
    plt.imshow(np.squeeze(X))
    plt.show()

    masks = np.zeros(input_size)
    for i, filter_idx in enumerate(most_freq_idxs[:-1]):
        pixel_value = i + 1
        masks = np.maximum((out == filter_idx)*pixel_value, masks)
    plt.imshow(masks)
    plt.show()

    for i in most_freq_idxs[::-1]:
      plt.title(catMapper.get(i, 'filter_idx', 'category_name'))
      plt.imshow(out == i)
      plt.show()



## Results

Full results from the training script, along with TensorBoard output, and predictions from the best model on validation images and on never-before-seen images.

In [None]:
# load well trained model and run show functions on it

# Project directory relative to /root (same value as in the setup cell).
dir = '../tf/notebooks/portfolio/image_segmentation'

# compile=False: the model is only used for predict() below, so the training
# configuration (optimizer, loss) is not restored.
final_model = tf.keras.models.load_model(f'{dir}/models/test_model7.h5', compile=False)

In [None]:
# Qualitative check of the final model on a handful of COCO val2017 images.
for _val_image_path in (
    'datasets/coco2017/images/val2017/000000150265.jpg',
    'datasets/coco2017/images/val2017/000000148707.jpg',
    'datasets/coco2017/images/val2017/000000290843.jpg',
    'datasets/coco2017/images/val2017/000000291861.jpg',
    'datasets/coco2017/images/val2017/000000291551.jpg',
    'datasets/coco2017/images/val2017/000000149406.jpg',
):
    show_argmax_masks(_val_image_path, final_model, input_size=(224, 224))


In [None]:
# Qualitative check of the final model on images from outside the COCO dataset.
for _asset_image_path in (
    f'{dir}/assets/dog.jpg',
    f'{dir}/assets/cat.jpg',
    f'{dir}/assets/frisbee.jpg',
    f'{dir}/assets/hot_dog.jpg',
    f'{dir}/assets/tennis.jpg',
):
    show_argmax_masks(_asset_image_path, final_model, input_size=(224, 224))

In [None]:
# !ls datasets/coco2017/images/val2017

In [None]:
# plt.imshow(Image.open('datasets/coco2017/images/val2017/000000150265.jpg'))

In [None]:
# plt.imshow(Image.open('datasets/coco2017/images/val2017/000000438304.jpg'))

In [None]:
# plt.imshow(Image.open('datasets/coco2017/images/val2017/000000148707.jpg'))