In [1]:
import argparse
import time
import urllib.request
from PIL import Image

import classify
import tflite_runtime.interpreter as tflite
import platform
from pathlib import Path

In [2]:
EDGETPU_SHARED_LIB = {
  'Linux': 'libedgetpu.so.1',
  'Darwin': 'libedgetpu.1.dylib',
  'Windows': 'edgetpu.dll'
}[platform.system()]

In [3]:
def load_labels(path, encoding='utf-8'):
    """Loads labels from file (with or without index numbers).

    Args:
    path: path to label file.
    encoding: label file encoding.
    Returns:
    Dictionary mapping indices to labels.
    """
    with open(path, 'r', encoding=encoding) as f:
        
        lines = f.readlines()
        if not lines:
            return {}

        if lines[0].split(' ', maxsplit=1)[0].isdigit():
            pairs = [line.split(' ', maxsplit=1) for line in lines]
            return {int(index): label.strip() for index, label in pairs}
        else:
            return {index: line.strip() for index, line in enumerate(lines)}

In [4]:
def make_interpreter(model_file):
    model_file, *device = model_file.split('@')
    return tflite.Interpreter(
        model_path=model_file,
        experimental_delegates=[
            tflite.load_delegate(EDGETPU_SHARED_LIB,{'device': device[0]} if device else {})])

In [5]:
def get_output(interpreter, top_k=1, score_threshold=0.0):
    """Returns no more than top_k classes with score >= score_threshold."""
    scores = output_tensor(interpreter)
    classes = [
      Class(i, scores[i])
      for i in np.argpartition(scores, -top_k)[-top_k:]
      if scores[i] >= score_threshold
    ]
    return sorted(classes, key=operator.itemgetter(1), reverse=True)

In [6]:
Path("models").mkdir(parents=True, exist_ok=True)
Path("images").mkdir(parents=True, exist_ok=True)

In [7]:
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/mobilenet_v2_1.0_224_inat_bird_quant.tflite'
urllib.request.urlretrieve(url, 'models/mobilenet_v2_1.0_224_inat_bird_quant.tflite')
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite'
urllib.request.urlretrieve(url, 'models/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite')
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/inat_bird_labels.txt'
urllib.request.urlretrieve(url, 'models/inat_bird_labels.txt')
url = 'https://github.com/google-coral/edgetpu/raw/master/test_data/parrot.jpg'
urllib.request.urlretrieve(url, 'images/parrot.jpg')

('images/parrot.jpg', <http.client.HTTPMessage at 0x2617e56b3c8>)

In [8]:
interpreter = make_interpreter('models/mobilenet_v2_1.0_224_inat_bird_quant.tflite')
#interpreter = make_interpreter('models/mobilenet_v2_1.0_224_inat_bird_quant_edgetpu.tflite')

In [9]:
interpreter.allocate_tensors()

In [10]:
size = classify.input_size(interpreter)

In [11]:
image = Image.open('images/parrot.jpg').convert('RGB').resize(size, Image.ANTIALIAS)

In [12]:
classify.set_input(interpreter, image)

In [13]:
print('----INFERENCE TIME----')
print('Note: The first inference on Edge TPU is slow because it includes',
    'loading the model into Edge TPU memory.')

----INFERENCE TIME----
Note: The first inference on Edge TPU is slow because it includes loading the model into Edge TPU memory.


In [14]:
for _ in range(5):
    start = time.perf_counter()
    interpreter.invoke()
    inference_time = time.perf_counter() - start
    classes = classify.get_output(interpreter, 1, 0.0)
    print('%.1fms' % (inference_time * 1000))

275.8ms
275.8ms
285.0ms
272.5ms
271.4ms


In [15]:
labels = load_labels('models/inat_bird_labels.txt')

In [16]:
print('-------RESULTS--------')
for klass in classes:
    print('%s: %.5f' % (labels.get(klass.id, klass.id), klass.score))

-------RESULTS--------
Ara macao (Scarlet Macaw): 0.78906
