In [3]:
!pip --version

pip 21.3.1 from /usr/local/lib/python3.6/dist-packages/pip (python 3.6)


In [4]:
import tensorflow as tf
import numpy as np
import argparse
import re
import time
from PIL import Image
from PIL import ImageDraw
import collections

In [46]:
def draw_objects(draw, objs, labels):
    """Draws the bounding box and label for each object."""
    for obj in objs:
        bbox = obj.bbox
        draw.rectangle([(bbox.xmin, bbox.ymin), (bbox.xmax, bbox.ymax)], outline="red", width=3)
        draw.text(
            (bbox.xmin + 10, bbox.ymin + 10),
            "%s\n%.2f" % (labels.get(obj.id, obj.id), obj.score),
            fill="red",
        )


def read_label_file(file_path):
    """Reads labels from a text file and returns it as a dictionary.
    This function supports label files with the following formats:
    + Each line contains id and description separated by colon or space.
        Example: ``0:cat`` or ``0 cat``.
    + Each line contains a description only. The returned label id's are based on
        the row number.
    Args:
        file_path (str): path to the label file.
    Returns:
        Dict of (int, string) which maps label id to description.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    ret = {}
    for row_number, content in enumerate(lines):
        pair = re.split(r"[:\s]+", content.strip(), maxsplit=1)
        if len(pair) == 2 and pair[0].strip().isdigit():
            ret[int(pair[0])] = pair[1].strip()
        else:
            ret[row_number] = content.strip()
    return ret


def input_details(interpreter, key):
    """Gets a model's input details by specified key.
    Args:
      interpreter: The ``tf.lite.Interpreter`` holding the model.
      key (int): The index position of an input tensor.
    Returns:
      The input details.
    """
    return interpreter.get_input_details()[0][key]


def input_size(interpreter):
    """Gets a model's input size as (width, height) tuple.
    Args:
        interpreter: The ``tf.lite.Interpreter`` holding the model.
      Returns:
        The input tensor size as (width, height) tuple.
    """
    _, height, width, _ = input_details(interpreter, "shape")
    return width, height


def input_tensor(interpreter):
    """Gets a model's input tensor view as numpy array of shape (height, width, 3).

    Args:
        interpreter: The ``tf.lite.Interpreter`` holding the model.
    Returns:
        The input tensor view as :obj:`numpy.array` (height, width, 3).
    """
    tensor_index = input_details(interpreter, "index")
    return interpreter.tensor(tensor_index)()[0]


def set_resized_input(interpreter, size, resize):
    """Copies a resized and properly zero-padded image to a model's input tensor.
    Args:
        interpreter: The ``tf.lite.Interpreter`` to update.
        size (tuple): The original image size as (width, height) tuple.
        resize: A function that takes a (width, height) tuple, and returns an
        image resized to those dimensions.
    Returns:
        The resized tensor with zero-padding as tuple
        (resized_tensor, resize_ratio).
    """
    width, height = input_size(interpreter)
    w, h = size
    scale = min(width / w, height / h)
    w, h = int(w * scale), int(h * scale)
    tensor = input_tensor(interpreter)
    tensor.fill(0)  # padding
    _, _, channel = tensor.shape
    print("CHANNEL", channel, (h, w, channel))
    result = resize((w, h))
    tensor[:h, :w] = np.reshape(result, (h, w, channel))
    return result, (scale, scale)


def output_tensor(interpreter, i):
    """Gets a model's ith output tensor.
    Args:
      interpreter: The ``tf.lite.Interpreter`` holding the model.
      i (int): The index position of an output tensor.
    Returns:
      The output tensor at the specified position.
    """
    return interpreter.tensor(interpreter.get_output_details()[i]['index'])()


class BBox(collections.namedtuple('BBox', ['xmin', 'ymin', 'xmax', 'ymax'])):
    """The bounding box for a detected object.
    .. py:attribute:: xmin
        X-axis start point
    .. py:attribute:: ymin
        Y-axis start point
    .. py:attribute:: xmax
        X-axis end point
    .. py:attribute:: ymax
        Y-axis end point
    """
    __slots__ = ()

    @property
    def width(self):
        """The bounding box width."""
        return self.xmax - self.xmin

    @property
    def height(self):
        """The bounding box height."""
        return self.ymax - self.ymin

    @property
    def area(self):
        """The bound box area."""
        return self.width * self.height

    @property
    def valid(self):
        """Indicates whether bounding box is valid or not (boolean).
        A valid bounding box has xmin <= xmax and ymin <= ymax (equivalent
        to width >= 0 and height >= 0).
        """
        return self.width >= 0 and self.height >= 0

    def scale(self, sx, sy):
        """Scales the bounding box.
        Args:
          sx (float): Scale factor for the x-axis.
          sy (float): Scale factor for the y-axis.
        Returns:
          A :obj:`BBox` object with the rescaled dimensions.
        """
        return BBox(
            xmin=sx * self.xmin,
            ymin=sy * self.ymin,
            xmax=sx * self.xmax,
            ymax=sy * self.ymax)

    def translate(self, dx, dy):
        """Translates the bounding box position.
        Args:
          dx (int): Number of pixels to move the box on the x-axis.
          dy (int): Number of pixels to move the box on the y-axis.
        Returns:
          A :obj:`BBox` object at the new position.
        """
        return BBox(
            xmin=dx + self.xmin,
            ymin=dy + self.ymin,
            xmax=dx + self.xmax,
            ymax=dy + self.ymax)

    def map(self, f):
        """Maps all box coordinates to a new position using a given function.
        Args:
          f: A function that takes a single coordinate and returns a new one.
        Returns:
          A :obj:`BBox` with the new coordinates.
        """
        return BBox(
            xmin=f(self.xmin),
            ymin=f(self.ymin),
            xmax=f(self.xmax),
            ymax=f(self.ymax))

    @staticmethod
    def intersect(a, b):
        """Gets a box representing the intersection between two boxes.
        Args:
          a: :obj:`BBox` A.
          b: :obj:`BBox` B.
        Returns:
          A :obj:`BBox` representing the area where the two boxes intersect
          (may be an invalid box, check with :func:`valid`).
        """
        return BBox(
            xmin=max(a.xmin, b.xmin),
            ymin=max(a.ymin, b.ymin),
            xmax=min(a.xmax, b.xmax),
            ymax=min(a.ymax, b.ymax))

    @staticmethod
    def union(a, b):
        """Gets a box representing the union of two boxes.
        Args:
          a: :obj:`BBox` A.
          b: :obj:`BBox` B.
        Returns:
          A :obj:`BBox` representing the unified area of the two boxes
          (always a valid box).
        """
        return BBox(
            xmin=min(a.xmin, b.xmin),
            ymin=min(a.ymin, b.ymin),
            xmax=max(a.xmax, b.xmax),
            ymax=max(a.ymax, b.ymax))

    @staticmethod
    def iou(a, b):
        """Gets the intersection-over-union value for two boxes.
        Args:
          a: :obj:`BBox` A.
          b: :obj:`BBox` B.
        Returns:
          The intersection-over-union value: 1.0 meaning the two boxes are
          perfectly aligned, 0 if not overlapping at all (invalid intersection).
        """
        intersection = BBox.intersect(a, b)
        if not intersection.valid:
            return 0.0
        area = intersection.area
        return area / (a.area + b.area - area)



def get_objects(interpreter,
                score_threshold=-float('inf'),
                image_scale=(1.0, 1.0)):
    """Gets results from a detection model as a list of detected objects.
    Args:
      interpreter: The ``tf.lite.Interpreter`` to query for results.
      score_threshold (float): The score threshold for results. All returned
        results have a score greater-than-or-equal-to this value.
      image_scale (float, float): Scaling factor to apply to the bounding boxes as
        (x-scale-factor, y-scale-factor), where each factor is from 0 to 1.0.
    Returns:
      A list of :obj:`Object` objects, which each contains the detected object's
      id, score, and bounding box as :obj:`BBox`.
    """
    print(interpreter.get_output_details())
    Object = collections.namedtuple('Object', ['id', 'score', 'bbox'])
    # If a model has signature, we use the signature output tensor names to parse
    # the results. Otherwise, we parse the results based on some assumption of the
    # output tensor order and size.
    # pylint: disable=protected-access
    signature_list = interpreter._get_full_signature_list()
    # pylint: enable=protected-access
    if signature_list:
        if len(signature_list) > 1:
            raise ValueError('Only support model with one signature.')
        print("Option 1")
        signature = signature_list[next(iter(signature_list))]
        count = int(interpreter.tensor(signature['outputs']['output_0'])()[0])
        scores = interpreter.tensor(signature['outputs']['output_1'])()[0]
        class_ids = interpreter.tensor(signature['outputs']['output_2'])()[0]
        boxes = interpreter.tensor(signature['outputs']['output_3'])()[0]
    elif output_tensor(interpreter, 3).size == 1:
        print("Option 2")
        boxes = output_tensor(interpreter, 0)[0]
        class_ids = output_tensor(interpreter, 1)[0]
        scores = output_tensor(interpreter, 2)[0]
        count = int(output_tensor(interpreter, 3)[0])
    else:
        print("Option 3")
        scores = output_tensor(interpreter, 0)[0]
        boxes = output_tensor(interpreter, 1)[0]
        count = (int)(output_tensor(interpreter, 2)[0])
        class_ids = output_tensor(interpreter, 3)[0]

    width, height = input_size(interpreter)
    image_scale_x, image_scale_y = image_scale
    sx, sy = width / image_scale_x, height / image_scale_y

    def make(i):
        ymin, xmin, ymax, xmax = boxes[i]
        return Object(
            id=int(class_ids[i]),
            score=float(scores[i]),
            bbox=BBox(xmin=xmin, ymin=ymin, xmax=xmax,
                      ymax=ymax).scale(sx, sy).map(int))

    return [make(i) for i in range(count) if scores[i] >= score_threshold]


In [47]:
class args:
    model = "ssd_mobilenet_v2_coco_quant_postprocess.tflite"
    input_img = "image_0_resized_300.jpg"
    labels = "coco_labels.txt"
    threshold = 0.5
    output_img = "processed.jpg"
    count = 5

In [48]:
def main():
    labels = read_label_file(args.labels) if args.labels else {}
    interpreter = tf.lite.Interpreter(args.model)
    interpreter.allocate_tensors()

    orig_image = Image.open(args.input_img)
    w, h = orig_image.size
    tiles = [
        orig_image
#         orig_image.crop((0, 0, w/2 + w/5, h/2 + h/5)), 
#              orig_image.crop((w/2 - w/5, 0, w, h/2 + h/5)),
#              orig_image.crop((0, h/2 - h/5, w/2 + w/5, h)),
#              orig_image.crop((w/2 - w/5, h/2 - h/5, w, h))
            ]
    all_objs = []
    
    for image in tiles:
        _, scale = set_resized_input(
            interpreter, image.size, lambda size: image.resize(size, Image.ANTIALIAS)
        )
        print(scale)

        print("----INFERENCE TIME----")
        print(
            "Note: The first inference is slow because it includes",
            "loading the model into Edge TPU memory.",
        )
        for _ in range(args.count):
            start = time.perf_counter()
            interpreter.invoke()
            inference_time = time.perf_counter() - start
            objs = get_objects(interpreter, args.threshold, scale)
            print("%.2f ms" % (inference_time * 1000))

        print("-------RESULTS--------")
        if not objs:
            print("No objects detected")

        for obj in objs:
            all_objs.append(obj)
    #         print(labels.get(obj.id, obj.id))
    #         print("  id:    ", obj.id)
    #         print("  score: ", obj.score)
    #         print("  bbox:  ", obj.bbox)

    if args.output_img:
        # image = orig_image.crop((0, 0, w/2 + w/5, h/2 + h/5)).convert("RGB")
        image = orig_image.convert("RGB")
        draw_objects(ImageDraw.Draw(image), all_objs, labels)
        image.save(args.output_img)
        image.show()

In [49]:
main()

CHANNEL 3 (300, 300, 3)
(1.0, 1.0)
----INFERENCE TIME----
Note: The first inference is slow because it includes loading the model into Edge TPU memory.
[{'name': 'TFLite_Detection_PostProcess', 'index': 259, 'shape': array([ 1, 20,  4], dtype=int32), 'shape_signature': array([ 1, 20,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 260, 'shape': array([ 1, 20], dtype=int32), 'shape_signature': array([ 1, 20], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 261, 'shape': array([ 1, 20], dtype=int32), 'shape_signature': 



[{'name': 'TFLite_Detection_PostProcess', 'index': 259, 'shape': array([ 1, 20,  4], dtype=int32), 'shape_signature': array([ 1, 20,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 260, 'shape': array([ 1, 20], dtype=int32), 'shape_signature': array([ 1, 20], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 261, 'shape': array([ 1, 20], dtype=int32), 'shape_signature': array([ 1, 20], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32)

Unescaped left brace in regex is deprecated, passed through in regex; marked by <-- HERE in m/%{ <-- HERE (.*?)}/ at /usr/bin/run-mailcap line 528.
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: www-browser: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: links2: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: elinks: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: links: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: lynx: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: w3m: not found
xdg-open: no method available for opening '/tmp/tmp_41dxa8p.PNG'


In [63]:
display_size = (500, 500)
def main():
    labels = read_label_file(args.labels) if args.labels else {}
    interpreter = tf.lite.Interpreter(args.model)
    interpreter.allocate_tensors()
    
    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Test the model on input data.
    orig_image = Image.open(args.input_img)
    input_shape = input_details[0]['shape']
    input_data = np.reshape(orig_image, (1, 300, 300, 3))
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    ###################START POSTPROCESSING##################### 
    def get_objects(interpreter,
                score_threshold=-float('inf'),
                image_scale=(1.0, 1.0)):
        """Gets results from a detection model as a list of detected objects.
        Args:
          interpreter: The ``tf.lite.Interpreter`` to query for results.
          score_threshold (float): The score threshold for results. All returned
            results have a score greater-than-or-equal-to this value.
          image_scale (float, float): Scaling factor to apply to the bounding boxes as
            (x-scale-factor, y-scale-factor), where each factor is from 0 to 1.0.
        Returns:
          A list of :obj:`Object` objects, which each contains the detected object's
          id, score, and bounding box as :obj:`BBox`.
        """
        Object = collections.namedtuple('Object', ['id', 'score', 'bbox'])

        boxes = interpreter.get_tensor(output_details[0]['index'])[0]
        class_ids = interpreter.get_tensor(output_details[1]['index'])[0]
        scores = interpreter.get_tensor(output_details[2]['index'])[0]
        count = int(interpreter.get_tensor(output_details[3]['index'])[0])
                
        _, width, height, _ = interpreter.get_input_details()[0]["shape"]
        image_scale_x, image_scale_y = image_scale
        sx, sy = width / image_scale_x, height / image_scale_y
        print(sx, sy)

        def make(i):
            ymin, xmin, ymax, xmax = boxes[i]
            return Object(
                id=int(class_ids[i]),
                score=float(scores[i]),
                bbox=BBox(xmin=xmin*sx, ymin=ymin*sy, xmax=xmax*sx,
                          ymax=ymax*sy).map(int))

        return [make(i) for i in range(count) if scores[i] >= score_threshold]

    objs = get_objects(interpreter, args.threshold)
    
    
    
    
    ###################END POSTPROCESSING#####################
    all_objs = []
    for obj in objs:
        all_objs.append(obj)
        print(labels.get(obj.id, obj.id))
        print("  id:    ", obj.id)
        print("  score: ", obj.score)
        print("  bbox:  ", obj.bbox)

    if args.output_img:
        # image = orig_image.crop((0, 0, w/2 + w/5, h/2 + h/5)).convert("RGB")
        orig_image = orig_image.resize((500, 500))
        image = orig_image.convert("RGB")
        draw_objects(ImageDraw.Draw(image), all_objs, labels)
        image.save(args.output_img)
        image.show()

main()

300.0 300.0
person
  id:     0
  score:  0.8046875
  bbox:   BBox(xmin=132, ymin=106, xmax=201, ymax=250)
person
  id:     0
  score:  0.72265625
  bbox:   BBox(xmin=214, ymin=86, xmax=300, ymax=227)
backpack
  id:     26
  score:  0.72265625
  bbox:   BBox(xmin=39, ymin=153, xmax=131, ymax=247)
person
  id:     0
  score:  0.55859375
  bbox:   BBox(xmin=0, ymin=87, xmax=120, ymax=252)


Unescaped left brace in regex is deprecated, passed through in regex; marked by <-- HERE in m/%{ <-- HERE (.*?)}/ at /usr/bin/run-mailcap line 528.
Error: no "view" rule for type "image/png" passed its test case
       (for more information, add "--debug=1" on the command line)
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: www-browser: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: links2: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: elinks: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: links: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: lynx: not found
/usr/bin/xdg-open: 778: /usr/bin/xdg-open: w3m: not found
xdg-open: no method available for opening '/tmp/tmps_pxybno.PNG'
