In [2]:
import numpy as np
import os
import tensorflow as tf
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
import glob
import random
import time

%matplotlib inline

In [80]:
def load_model(file_path):
    """Load a frozen TF1 inference graph from disk into a fresh ``tf.Graph``.

    Args:
        file_path: Path to a serialized ``GraphDef`` protobuf (for example a
            ``frozen_inference_graph.pb``).

    Returns:
        A new ``tf.Graph`` holding the imported definition. It is imported with
        ``name=''`` so tensors keep their original names (e.g. ``'image_tensor:0'``)
        instead of receiving an ``import/`` prefix.
    """
    graph = tf.Graph()

    # Parse the protobuf first; the GraphDef message does not depend on any graph.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(file_path, 'rb') as model_file:
        graph_def.ParseFromString(model_file.read())

    # Import into our own graph rather than the process-wide default one.
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    return graph
            
def load_image(image_path):
    """Read an image file into an RGB ``uint8`` array of shape (height, width, 3).

    Args:
        image_path: Path to any image format Pillow can open.

    Returns:
        A writable ``numpy.ndarray`` with dtype ``uint8`` and shape
        ``(height, width, 3)``. Callers may draw on it in place.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it once the pixels have been copied out.
    with Image.open(image_path) as image:
        # convert('RGB') normalises grayscale, palette, RGBA and CMYK inputs. Without it,
        # a hard-coded 3-channel reshape would fail for those modes.
        # np.array (not np.asarray) guarantees a writable copy for in-place drawing.
        return np.array(image.convert('RGB'), dtype=np.uint8)

def run_inference(sess, ops, image_tensor, image):
    """Run one forward pass of a detection graph and unpack the first batch element.

    Args:
        sess: Session-like object exposing ``run(fetches, feed_dict=...)``.
        ops: Fetch list ordered exactly as
            ``[num_detections, boxes, scores, classes]``. The unpacking below
            depends on this order.
        image_tensor: Input placeholder that ``image`` is fed into.
        image: Batched input. Only index 0 of every output is kept, so callers
            are expected to pass a batch of one.

    Returns:
        dict with keys:
            ``'num_detections'``: int;
            ``'detection_classes'``: int64 array of class ids;
            ``'detection_boxes'``, ``'detection_scores'``: first-batch arrays;
            ``'detection_time'``: seconds spent inside ``sess.run``.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    num_detections, boxes, scores, classes = sess.run(ops, feed_dict={image_tensor: image})
    elapsed = time.perf_counter() - start

    return {
        'num_detections': int(num_detections[0]),
        # int64 instead of uint8: uint8 silently wraps class ids above 255,
        # which would then look up the wrong entry in the category index.
        'detection_classes': classes[0].astype(np.int64),
        'detection_boxes': boxes[0],
        'detection_scores': scores[0],
        'detection_time': elapsed,
    }

def _detection_io(graph):
    """Return (input placeholder, fetch list ordered as run_inference expects)."""
    get = graph.get_tensor_by_name
    fetches = [
        get('num_detections:0'),
        get('detection_boxes:0'),
        get('detection_scores:0'),
        get('detection_classes:0'),
    ]
    return get('image_tensor:0'), fetches


def detect_and_visualize(name, graph, image_paths, category_index, save=False, cols=2):
    """Run a detection graph over images and plot the annotated results in a grid.

    Args:
        name: Model name. Used as the figure title and, when ``save`` is true,
            as the output file stem (``<name>.jpg`` in the working directory).
        graph: ``tf.Graph`` containing the tensors ``image_tensor``,
            ``detection_boxes``, ``detection_scores``, ``detection_classes``
            and ``num_detections``.
        image_paths: Sequence of image file paths; one subplot per image.
        category_index: Passed through to
            ``vis_util.visualize_boxes_and_labels_on_image_array``.
        save: If true, also write the figure to ``<name>.jpg``.
        cols: Number of subplot columns in the grid (default 2).

    Returns:
        None. The figure is left open so the notebook's inline backend renders
        it. Nothing is drawn when ``image_paths`` is empty.
    """
    n_images = len(image_paths)
    if n_images == 0:
        # Otherwise we would build a zero-height figure and save a degenerate image.
        return
    rows = -(-n_images // cols)  # integer ceiling division

    image_tensor, ops = _detection_io(graph)

    fig = plt.figure(figsize=(15, rows * 6))
    fig.suptitle('Model: {}'.format(name))

    with tf.Session(graph=graph) as sess:
        for i, image_path in enumerate(image_paths):
            image = load_image(image_path)
            # The model takes a batch, so add a leading batch axis of size 1.
            output_dict = run_inference(sess, ops, image_tensor, np.expand_dims(image, axis=0))
            # Annotates `image` in place (the return value is not used).
            vis_util.visualize_boxes_and_labels_on_image_array(
                image,
                output_dict['detection_boxes'],
                output_dict['detection_classes'],
                output_dict['detection_scores'],
                category_index,
                use_normalized_coordinates=True,
                line_thickness=3)

            ax = fig.add_subplot(rows, cols, i + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(image)

    fig.tight_layout()
    # Reserve headroom at the top so the suptitle does not overlap the first row.
    fig.subplots_adjust(top=0.97)

    if save:
        fig.savefig(name + '.jpg', bbox_inches='tight')

In [91]:
# Each exported model lives in MODELS_DIR/<model_name>/frozen_inference_graph.pb.
MODELS_DIR = os.path.join('..', 'models', 'exported')
LABELS_MAP_PATH = os.path.join('..', 'config', 'labels_map.pbtxt')

# Category index built once from the label map and shared by every model's visualisation.
CATEGORY_INDEX = label_map_util.create_category_index_from_labelmap(LABELS_MAP_PATH, use_display_name=True)

def analyze(data_dir, models, limit=6, seed=None):
    """Compare several exported detectors on the same random sample of images.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        models: Iterable of sub-directory names under ``MODELS_DIR``.
        limit: Maximum number of images to sample (default 6).
        seed: Optional seed for the image sample. ``None`` (default) uses the
            module-level ``random`` state, as before; an int makes the sample
            reproducible across runs.

    Raises:
        FileNotFoundError: If ``data_dir`` contains no ``*.jpg`` files.

    Side effects:
        For every model, plots a figure and saves it as ``<model_name>.jpg``.
    """
    # Sort first: glob order is filesystem-dependent, so a seeded shuffle is
    # only reproducible when it starts from a deterministic order.
    candidates = sorted(glob.glob(os.path.join(data_dir, '*.jpg')))
    if not candidates:
        raise FileNotFoundError('No .jpg images found in {!r}'.format(data_dir))

    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(candidates)
    # Sampled once, outside the loop, so every model sees identical images.
    test_images = candidates[:limit]

    for model_name in models:
        model_path = os.path.join(MODELS_DIR, model_name, 'frozen_inference_graph.pb')
        detection_graph = load_model(model_path)
        detect_and_visualize(model_name, detection_graph, test_images, CATEGORY_INDEX, save=True)

In [None]:
# Simulator frames: one random sample from here is run through every model listed below.
DATA_DIR = os.path.join('..', 'data', 'simulator')

# Sub-directories of MODELS_DIR to compare; each must contain frozen_inference_graph.pb.
models = [
    'ssd_inception_v2_sim',
    'ssd_inception_v2',
    'ssd_mobilenet_v1',
    'ssd_mobilenet_v2',
    'ssdlite_mobilenet_v2',
]

analyze(DATA_DIR, models)