-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathdetect_image.py
108 lines (90 loc) · 3.65 KB
/
detect_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Lint as: python3
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Example using PyCoral to detect objects in a given image.
To run this code, you must have an Edge TPU attached to the host and
install the Edge TPU runtime (`libedgetpu.so`) and `tflite_runtime`. For
device setup instructions, see coral.ai/docs/setup.
Example usage:
```
bash examples/install_requirements.sh detect_image.py
python3 examples/detect_image.py \
--model test_data/ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite \
--labels test_data/coco_labels.txt \
--input test_data/grace_hopper.bmp \
--output ${HOME}/grace_hopper_processed.bmp
```
"""
import argparse
import time
from PIL import Image
from PIL import ImageDraw
from pycoral.adapters import common
from pycoral.adapters import detect
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
def draw_objects(draw, objs, labels):
    """Annotate a drawing context with a red box and caption per detection.

    Args:
      draw: Drawing context exposing ``rectangle`` and ``text`` (in this
        script, a ``PIL.ImageDraw.ImageDraw``).
      objs: Detections; each must provide ``bbox`` (with ``xmin``, ``ymin``,
        ``xmax``, ``ymax``), ``id`` and ``score`` attributes.
      labels: Mapping from class id to display label. Ids missing from the
        mapping are captioned with the raw id instead.
    """
    colour = 'red'
    for detection in objs:
        box = detection.bbox
        top_left = (box.xmin, box.ymin)
        bottom_right = (box.xmax, box.ymax)
        draw.rectangle([top_left, bottom_right], outline=colour)
        # Caption is "<label>\n<score>", placed slightly inside the box corner.
        caption = '%s\n%.2f' % (labels.get(detection.id, detection.id),
                                detection.score)
        draw.text((box.xmin + 10, box.ymin + 10), caption, fill=colour)
def main():
    """Command-line entry point: detect objects in a single image.

    Parses the command-line flags, loads the model (and, optionally, a label
    file), runs inference ``--count`` times while printing the latency of
    each run, then prints the detections from the last run. When
    ``--output`` is given, the detections are drawn onto an RGB copy of the
    input image, which is saved to that path and then shown.

    Exits via ``argparse`` (status 2) if ``--count`` is less than 1.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-m', '--model', required=True,
                        help='File path of .tflite file')
    parser.add_argument('-i', '--input', required=True,
                        help='File path of image to process')
    parser.add_argument('-l', '--labels', help='File path of labels file')
    parser.add_argument('-t', '--threshold', type=float, default=0.4,
                        help='Score threshold for detected objects')
    parser.add_argument('-o', '--output',
                        help='File path for the result image with annotations')
    parser.add_argument('-c', '--count', type=int, default=5,
                        help='Number of times to run inference')
    args = parser.parse_args()
    # The reported results come from the last inference run, so at least one
    # run is required. Without this check, a count < 1 skipped the loop and
    # crashed later with an UnboundLocalError on `objs`.
    if args.count < 1:
        parser.error('--count must be at least 1')

    labels = read_label_file(args.labels) if args.labels else {}
    interpreter = make_interpreter(args.model)
    interpreter.allocate_tensors()

    image = Image.open(args.input)
    # The callback resizes the image to whatever size set_resized_input asks
    # for; the returned `scale` is passed to get_objects below (presumably to
    # map boxes back onto the original image -- see pycoral docs).
    _, scale = common.set_resized_input(
        interpreter, image.size, lambda size: image.resize(size, Image.LANCZOS))

    print('----INFERENCE TIME----')
    print('Note: The first inference is slow because it includes',
          'loading the model into Edge TPU memory.')
    for _ in range(args.count):
        start = time.perf_counter()
        interpreter.invoke()
        inference_time = time.perf_counter() - start
        objs = detect.get_objects(interpreter, args.threshold, scale)
        print('%.2f ms' % (inference_time * 1000))

    print('-------RESULTS--------')
    if not objs:
        print('No objects detected')

    for obj in objs:
        print(labels.get(obj.id, obj.id))
        print(' id: ', obj.id)
        print(' score: ', obj.score)
        print(' bbox: ', obj.bbox)

    if args.output:
        # Convert first so the red annotations can be drawn on any input mode.
        image = image.convert('RGB')
        draw_objects(ImageDraw.Draw(image), objs, labels)
        image.save(args.output)
        image.show()


if __name__ == '__main__':
    main()