In [9]:
import argparse
import collections
import common
import cv2
import numpy as np
import os
from PIL import Image
import re
import tflite_runtime.interpreter as tflite

from pathlib import Path
import urllib.request

In [10]:
Object = collections.namedtuple('Object', ['id', 'score', 'bbox'])

In [19]:
def load_labels(path):
    p = re.compile(r'\s*(\d+)(.+)')
    with open(path, 'r', encoding='utf-8') as f:
       lines = (p.match(line).groups() for line in f.readlines())
       return {int(num): text.strip() for num, text in lines}

In [20]:
class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])):
    """Bounding box.
    Represents a rectangle which sides are either vertical or horizontal, parallel
    to the x or y axis.
    """
    __slots__ = ()

def get_output(interpreter, score_threshold, top_k, image_scale=1.0):
    """Returns list of detected objects."""
    boxes = common.output_tensor(interpreter, 0)
    class_ids = common.output_tensor(interpreter, 1)
    scores = common.output_tensor(interpreter, 2)
    count = int(common.output_tensor(interpreter, 3))

    def make(i):
        ymin, xmin, ymax, xmax = boxes[i]
        return Object(
            id=int(class_ids[i]),
            score=scores[i],
            bbox=BBox(xmin=np.maximum(0.0, xmin),
                      ymin=np.maximum(0.0, ymin),
                      xmax=np.minimum(1.0, xmax),
                      ymax=np.minimum(1.0, ymax)))

    return [make(i) for i in range(top_k) if scores[i] >= score_threshold]

In [21]:
Path("models").mkdir(parents=True, exist_ok=True)

In [22]:
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/mobilenet_ssd_v2_coco_quant_postprocess.tflite'
urllib.request.urlretrieve(url, 'models/mobilenet_ssd_v2_coco_quant_postprocess.tflite')
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/coco_labels.txt'
urllib.request.urlretrieve(url, 'models/coco_labels.txt')

('models/coco_labels.txt', <http.client.HTTPMessage at 0x28b4b700e08>)

In [23]:
interpreter = common.make_interpreter('models/mobilenet_ssd_v2_coco_quant_postprocess.tflite')

In [24]:
interpreter.allocate_tensors()

In [25]:
labels = load_labels('models/coco_labels.txt')

In [26]:
cap = cv2.VideoCapture(0)

In [28]:
def append_objs_to_img(cv2_im, objs, labels):
    height, width, channels = cv2_im.shape
    for obj in objs:
        x0, y0, x1, y1 = list(obj.bbox)
        x0, y0, x1, y1 = int(x0*width), int(y0*height), int(x1*width), int(y1*height)
        percent = int(100 * obj.score)
        label = '{}% {}'.format(percent, labels.get(obj.id, obj.id))

        cv2_im = cv2.rectangle(cv2_im, (x0, y0), (x1, y1), (0, 255, 0), 2)
        cv2_im = cv2.putText(cv2_im, label, (x0, y0+30),
                             cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), 2)
    return cv2_im

In [None]:
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    cv2_im = frame

    cv2_im_rgb = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
    pil_im = Image.fromarray(cv2_im_rgb)

    common.set_input(interpreter, pil_im)
    interpreter.invoke()
    objs = get_output(interpreter, score_threshold=0.1, top_k=3)
    cv2_im = append_objs_to_img(cv2_im, objs, labels)

    cv2.imshow('frame', cv2_im)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()