# Mask RCNN inference


# Imports

In [None]:
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# Uncomment the line below if this notebook is stored in the object_detection folder,
# so that the `object_detection` package can be imported from the parent directory.
# sys.path.append("..")
from object_detection.utils import ops as utils_ops

# Compare numeric (major, minor) components: comparing the version *strings*
# is lexicographic, so e.g. '1.10.0' < '1.4.0' would wrongly be True.
# NOTE(review): later cells use TF1-only APIs (tf.GraphDef, tf.gfile), so
# TF 2.x presumably needs tf.compat.v1 — confirm before running on TF 2.
_tf_major_minor = tuple(int(part) for part in tf.__version__.split('.')[:2])
if _tf_major_minor < (1, 4):
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')

from copy import copy
import cv2
from skimage.morphology import skeletonize

import time
import datetime
import pickle

# This is needed to display the images.
%matplotlib inline

from tensorflow.models.research.object_detection.utils import label_map_util
from tensorflow.models.research.object_detection.utils import visualization_utils as vis_util

from utility_functions import load_image_into_numpy_array
from utility_functions import run_inference_for_single_image

# Model preparation 

In [None]:
# Name of the model to run; it is also used as a sub-folder of the output path.
MODEL_NAME = 'inception_v2/fine_tuned_model_100k_final_sets'
# Path to frozen detection graph: <models root>/<MODEL_NAME>/frozen_inference_graph.pb
_models_root = './downloaded_models'
PATH_TO_CKPT = os.path.join(_models_root, MODEL_NAME, 'frozen_inference_graph.pb')

## Load a (frozen) TensorFlow model into memory.

In [None]:
# Create a dedicated graph and import the frozen GraphDef at PATH_TO_CKPT into it.
detection_graph = tf.Graph()
with detection_graph.as_default():
    # Read the whole serialized protobuf; the file is closed before parsing.
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def = tf.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    # name='' imports the nodes without adding a name-scope prefix.
    tf.import_graph_def(od_graph_def, name='')

## Loading label map


In [None]:
# Label map (.pbtxt) used to name the detected classes. The default is a
# machine-specific absolute path; set the WORM_LABEL_MAP environment variable
# to use the notebook on another machine without editing it.
PATH_TO_LABELS = os.environ.get(
    'WORM_LABEL_MAP',
    '/Users/daniel/Documents/UCL/Project/Code/tensorflow_Mask_RCNN/data/worm_label_map.pbtxt')
# The model has a single class (worms, per the label-map file name).
NUM_CLASSES = 1

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


# Load images for inference

In [None]:
# Maximum number of frames to run inference on.
NUM_IMAGES = 5

DATASET_DIR = './data/fullsize_images/'
# os.listdir returns entries in an arbitrary, filesystem-dependent order; sort
# so that indexing into `datasets` picks the same dataset on every machine.
# Only sub-directories are kept, because each dataset is itself listed as a
# folder of frames.
datasets = sorted(
    f for f in os.listdir(DATASET_DIR)
    if not f.startswith('.') and os.path.isdir(os.path.join(DATASET_DIR, f))
)
print(datasets)

In [None]:
# Position in `datasets` of the dataset to run inference on.
DATASET_INDEX = 4
try:
    dataset = datasets[DATASET_INDEX]
except IndexError:
    raise IndexError('DATASET_INDEX={} is out of range: only {} datasets found in {}'.format(
        DATASET_INDEX, len(datasets), DATASET_DIR)) from None
print("Running inference on dataset {}".format(dataset))

PATH_TO_TEST_IMAGES_DIR = os.path.join(DATASET_DIR, dataset)

# Each frame is a sub-folder named by its integer frame number; sort numerically
# (so '10' follows '9') and keep at most the first NUM_IMAGES frames.
FNAMES = [f for f in os.listdir(PATH_TO_TEST_IMAGES_DIR) if not f.startswith('.')]
FNAMES = sorted(FNAMES, key=int)[:NUM_IMAGES]

# Size of output images (matplotlib figsize, in inches).
IMAGE_SIZE = (40,40)

# Detection

In [None]:
output_dicts_list = []
# Per-image wall-clock inference time in seconds (float, sub-second resolution).
inference_times = []

visualise_outputs = True
save_anns_to_file = False
save_overlays_to_file = False

# Detections must score strictly above this to be kept.
SCORE_THRESHOLD = 0.5

# Per-detection arrays in the output dict; each is indexed by detection along
# axis 0, so they are all filtered by the same score mask.
_PER_DETECTION_KEYS = ('detection_boxes', 'detection_classes',
                       'detection_scores', 'detection_masks')

# The output location does not depend on the frame, so build it once.
OUTPUT_DIR_PATH = os.path.join('./data/inference_outputs', MODEL_NAME, dataset)


def _overlay_detections(image, detections, line_thickness):
    """Draw the boxes, labels and instance masks in `detections` onto `image` in place."""
    vis_util.visualize_boxes_and_labels_on_image_array(
        image,
        detections['detection_boxes'],
        detections['detection_classes'],
        detections['detection_scores'],
        category_index,
        instance_masks=detections.get('detection_masks'),
        use_normalized_coordinates=True,
        line_thickness=line_thickness)


for fName in FNAMES:

    image_path = os.path.join(PATH_TO_TEST_IMAGES_DIR, '{}/image/image_{}.png'.format(fName, fName))

    # The array-based representation of the image is also used later to draw the
    # result overlays. The `with` block closes the file handle once it is read.
    with Image.open(image_path) as image:
        image_np = load_image_into_numpy_array(image)

    # Actual detection. The helper receives the un-batched image array; presumably
    # it adds the batch dimension the model expects ([1, None, None, 3]) — TODO confirm.
    start = time.perf_counter()
    output_dict = run_inference_for_single_image(image_np, detection_graph)
    elapsed = time.perf_counter() - start
    print("Inference for one image: {:.3f}s".format(elapsed))
    inference_times.append(elapsed)

    if visualise_outputs:
        _overlay_detections(image_np, output_dict, line_thickness=2)
        plt.figure(figsize=IMAGE_SIZE)
        plt.title("Image {}".format(image_path[:-4]))
        plt.imshow(image_np)
        plt.show()
        plt.close()

    # Keep only worms which scored > SCORE_THRESHOLD. Keys missing from the output
    # are skipped rather than raising KeyError.
    keep = output_dict['detection_scores'] > SCORE_THRESHOLD
    for key in _PER_DETECTION_KEYS:
        if key in output_dict:
            output_dict[key] = output_dict[key][keep]
    output_dict['skeletons'] = [skeletonize(m).astype(np.uint8)
                                for m in output_dict.get('detection_masks', [])]
    output_dict['frame_num'] = fName

    if save_anns_to_file:
        # Save this frame's outputs to a pickle file.
        anns_dir = os.path.join(OUTPUT_DIR_PATH, 'annotations')
        os.makedirs(anns_dir, exist_ok=True)
        ANNS_OUTPUT_PATH = os.path.join(anns_dir, fName) + '.pickle'
        with open(ANNS_OUTPUT_PATH, 'wb') as fp:
            pickle.dump(output_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

    if save_overlays_to_file:
        # Save the image with annotations overlaid to file.
        images_dir = os.path.join(OUTPUT_DIR_PATH, 'images')
        os.makedirs(images_dir, exist_ok=True)
        IMG_OUTPUT_PATH = os.path.join(images_dir, fName) + '.png'

        plt.figure(figsize=IMAGE_SIZE)

        # If the image hasn't already been visualised, the masks and boxes still
        # need to be drawn onto it.
        if not visualise_outputs:
            _overlay_detections(image_np, output_dict, line_thickness=1)

        plt.imshow(image_np)
        plt.axis('off')
        plt.savefig(fname=IMG_OUTPUT_PATH, bbox_inches='tight', pad_inches=0)
        # Actually close the figure so one is not left open per saved frame.
        plt.close()

    output_dicts_list.append(output_dict)
    

In [None]:
# Average over the images actually processed. This can be fewer than NUM_IMAGES
# when the dataset has fewer frames. Guard the empty case to avoid a
# ZeroDivisionError.
if inference_times:
    mean_time = round(sum(inference_times) / len(inference_times), 2)
    print("Average inference time for {} images: {}s".format(len(inference_times), mean_time))
else:
    mean_time = None
    print("No images were processed, so there is no average inference time.")

# Skeletonisation


Visualise the last image and its skeletons as a sense-check

In [None]:
# Stack the skeletons of the last processed frame into one array, with one
# skeleton per kept detection along axis 0 (assuming all share the same shape).
last_frame_outputs = output_dicts_list[-1]
skeletons = np.array(last_frame_outputs['skeletons'])

In [None]:
# Side-by-side sense-check: the last processed frame (which may already carry the
# overlays drawn during inference) next to the sum of all of its skeletons.
fig, (ax_frame, ax_skeletons) = plt.subplots(1, 2, figsize=(80, 80))
ax_frame.imshow(image_np)
ax_skeletons.imshow(np.sum(skeletons, axis=0))
plt.show()
plt.close()

# Tracking using SORT

In [None]:
from sort import sort
import matplotlib.patches as patches

In [None]:
display = True
verbose = False
# One random RGB colour per track id (ids are wrapped modulo 32).
colours = np.random.rand(32,3)

# Detection boxes are normalised to [0, 1] and are scaled to pixels by the frame
# side length before being passed to SORT.
# NOTE(review): assumes square 2048x2048 frames — confirm.
FRAME_SIZE_PX = 2048

# Create an instance of SORT.
worm_tracker = sort.Sort(max_age=5, min_hits=0)

# Detections per frame: an (N, 5) array of [ymin, xmin, ymax, xmax, score], with the
# boxes in pixels. NOTE(review): SORT documents its input as [x1, y1, x2, y2, score];
# the y/x order here is swapped, but consistently (the drawing below reads it the same way).
all_detections = [
    np.hstack((i['detection_boxes'] * FRAME_SIZE_PX,
               np.expand_dims(i['detection_scores'], 1)))
    for i in output_dicts_list
]


for frame, detections in enumerate(all_detections):

    # Update SORT. Each returned row is a tracked box (in the input's coordinate
    # order) followed by its track id in the last column.
    ids = worm_tracker.update(detections)

    if verbose:
        print('Detections: {}'.format(detections))
        print('ids shape: {}'.format(ids.shape))
        print('ids: {}'.format(ids))

    print("Worm IDs: {}".format(sorted(ids[:,-1], key=int)))

    if display:
        # output_dicts_list (and hence all_detections) is built in FNAMES order.
        fName = FNAMES[frame]
        image_path = os.path.join(PATH_TO_TEST_IMAGES_DIR, '{}/image/image_{}.png'.format(fName, fName))

        with Image.open(image_path) as img:
            img_np = load_image_into_numpy_array(img)

        fig = plt.figure(figsize=IMAGE_SIZE)

        # Left panel: SORT track boxes, coloured and labelled by track id.
        ax = fig.add_subplot(121, aspect='equal')
        ax.imshow(img_np)
        # 'box-forced' was removed in matplotlib 3.0; 'box' is its replacement.
        ax.set_adjustable('box')
        for d in ids:
            d = d.astype(np.int32)
            ax.add_patch(patches.Rectangle((d[1], d[0]), d[3] - d[1], d[2] - d[0],
                                           fill=False,
                                           lw=3,
                                           ec=colours[d[4] % 32, :]))
            ax.text(d[1], d[0], d[-1], color='w')

        # Right panel: the score-filtered detections and masks for the same frame.
        ax_2 = fig.add_subplot(122, aspect='equal')
        vis_util.visualize_boxes_and_labels_on_image_array(
              img_np,
              output_dicts_list[frame]['detection_boxes'],
              output_dicts_list[frame]['detection_classes'],
              output_dicts_list[frame]['detection_scores'],
              category_index,
              instance_masks=output_dicts_list[frame].get('detection_masks'),
              use_normalized_coordinates=True,
              line_thickness=2)
        ax_2.imshow(img_np)

        plt.show()
        plt.close()
