In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

In [None]:
# Frozen inference graph that Detector loads on construction.
PATH_TO_MODEL = './frozen_inference_graph.pb'

# Number of real classes; class ids are 1-based, so per-class lists below
# carry a placeholder at index 0.
NUM_CLASSES = 5

# Detections scoring below this value are ignored.
SCORE_THRESHOLD = 0.5

# COLORS[i] is the outline/label colour drawn for class id i.
COLORS = [
    '?',       # index 0: placeholder, never drawn
    'red',
    'purple',
    'green',
    'blue',
    'orange',
]

In [None]:
def get_label(class_index):
    """Return the display name for a detection class id.

    Class ids are 1-based; index 0 holds the '?' placeholder. An index past
    the end of the table raises IndexError.
    """
    label_names = ['?', 'arrows', 'test', 'ABC', 'squares', 'influenza']
    return label_names[class_index]

class Detector(object):
    """Run a frozen TensorFlow detection graph on an image and plot the result.

    The graph is read from PATH_TO_MODEL when the object is constructed and a
    session bound to that graph is kept open in ``self.sess`` until ``close()``
    is called.
    """

    def __init__(self):
        """Load the frozen graph, look up its I/O tensors and open a session."""
        # GraphDef / gfile / Session are TF 1.x names; TF 2.x only exposes them
        # under tf.compat.v1. Fall back to the top-level module on old 1.x
        # releases that have no compat.v1 namespace.
        tf1 = getattr(getattr(tf, 'compat', None), 'v1', tf)

        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf1.GraphDef()
            with tf1.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            # name='' keeps the tensor names unprefixed so the lookups below match.
            tf1.import_graph_def(od_graph_def, name='')
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0')
        self.sess = tf1.Session(graph=self.detection_graph)

    def close(self):
        """Release the TensorFlow session held by this detector."""
        self.sess.close()

    def run_detection(self, image):
        """Run the detection graph on a single image.

        Args:
            image: a PIL image, or anything ``np.asarray`` can turn into an
                array. PIL images that are not already in 'RGB' mode are
                converted first, because the model expects 3 channels.

        Returns:
            ``(boxes, scores, classes, num)`` for the single input image, i.e.
            the first batch entry of each output tensor.
        """
        # RGBA / grayscale / palette files would otherwise produce an array
        # that does not match the [1, None, None, 3] input. Objects without a
        # ``mode`` attribute (e.g. numpy arrays) are passed through untouched.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        img = np.asarray(image, dtype="int32")

        # Expand dimension since the model expects image to have shape [1, None, None, 3].
        img_expanded = np.expand_dims(img, axis=0)
        (boxes, scores, classes, num) = self.sess.run(
            [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
            feed_dict={self.image_tensor: img_expanded})

        # We only accept one input image (batch_size == 1), so only need to return the first result
        return boxes[0], scores[0], classes[0], num[0]

    def get_top_predictions(self, scores, classes, detections,
                            num_classes=None, score_threshold=None):
        """Pick the highest-scoring detection for every class.

        Args:
            scores: per-detection scores.
            classes: per-detection class ids (1-based, may be floats).
            detections: number of leading entries in ``scores``/``classes`` to
                consider.
            num_classes: largest valid class id; defaults to NUM_CLASSES.
            score_threshold: detections scoring below this are ignored;
                defaults to SCORE_THRESHOLD.

        Returns:
            A list of length ``num_classes + 1`` where entry ``c`` is the index
            of the best detection of class ``c``, or None if the class has no
            detection at or above the threshold.
        """
        if num_classes is None:
            num_classes = NUM_CLASSES
        if score_threshold is None:
            score_threshold = SCORE_THRESHOLD

        top_predictions = [None] * (num_classes + 1)  # 1 based label indexing

        # Never read past the arrays even if the reported count is larger.
        count = min(int(detections), len(scores), len(classes))
        for i in range(count):
            if scores[i] < score_threshold:
                continue

            label = int(classes[i])
            # A class id outside the table would raise IndexError (too large)
            # or silently wrap around (negative); skip such detections.
            if not 0 <= label <= num_classes:
                continue

            best = top_predictions[label]
            if best is None or scores[i] > scores[best]:
                top_predictions[label] = i

        return top_predictions

    def add_patch(self, axs, image, box, color, label, score):
        """Draw one bounding box and its "label score" caption on ``axs``.

        ``box`` is read as ``[ymin, xmin, ymax, xmax]`` fractions of the image
        size and scaled by ``image.height`` / ``image.width`` to pixels.
        """
        width = image.width
        height = image.height
        ymin = box[0] * height
        xmin = box[1] * width
        ymax = box[2] * height
        xmax = box[3] * width
        axs.add_patch(patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor=color,
            facecolor='none'
        ))
        # Caption sits just inside the box edge at y = ymax.
        axs.annotate(label + " " + str(round(score, 2)),
                     color=color,
                     fontsize=15,
                     xy=(xmin + 10, ymax - 10)
                    )

    def display_inference(self, image_path):
        """Open ``image_path``, run detection and show the best box per class.

        The label and raw box of every drawn detection are also printed.
        """
        # The context manager closes the file even if detection or plotting raises.
        with Image.open(image_path) as image:
            (boxes, scores, classes, detections) = self.run_detection(image)
            top_predictions = self.get_top_predictions(scores, classes, detections)

            fig, ax = plt.subplots(figsize=(20, 10), ncols=1, nrows=1)
            ax.imshow(image)
            ax.set_axis_off()

            # Index 0 is the placeholder slot of the 1-based class table.
            for i in range(1, len(top_predictions)):
                z = top_predictions[i]
                if z is not None:
                    print(get_label(i))
                    print(boxes[z])
                    color = COLORS[i]
                    self.add_patch(ax, image, boxes[z], color, get_label(i), scores[z])

            plt.show()

In [None]:
# Build the detector (this loads the frozen graph), then run it on the sample
# image and display the annotated result.
detector = Detector()
detector.display_inference(image_path="rdt.jpg")