-
Notifications
You must be signed in to change notification settings - Fork 45
/
bird_classifier_test.py
35 lines (26 loc) · 1.14 KB
/
bird_classifier_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# RUN: %PYTHON %s %config_flag
# XFAIL: *
import os
import urllib.request

import absl.testing
import numpy
from PIL import Image

import test_util
# TF Hub URL of the AIY vision birds classifier (V1, revision 3), requested
# in TFLite flatbuffer format. Passed to the TFLiteModelTest base class.
model_path = (
    "https://tfhub.dev/google/lite-model/aiy/vision/classifier/"
    "birds_V1/3?lite-format=tflite"
)
class BirdClassifierTest(test_util.TFLiteModelTest):
    """Compile the TF Hub birds classifier and compare IREE against TFLite.

    The model comes from the module-level ``model_path``. The input is a
    sample bird image downloaded from the google-coral test-data repository.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(model_path, *args, **kwargs)

    def compare_results(self, iree_results, tflite_results, details):
        """Run the base-class comparison, then check the first outputs agree.

        Elements of ``iree_results[0]`` and ``tflite_results[0]`` must satisfy
        ``numpy.isclose(..., atol=4.0)`` (numpy's default rtol applies too).
        """
        super().compare_results(iree_results, tflite_results, details)
        # NOTE(review): atol=4.0 is very loose, presumably because the model
        # outputs are quantized integers; confirm before tightening.
        self.assertTrue(
            numpy.isclose(iree_results[0], tflite_results[0], atol=4.0).all())

    def generate_inputs(self, input_details):
        """Download the sample bird image and shape it to the model's input.

        Args:
          input_details: TFLite input details; only
            ``input_details[0]["shape"]`` is read.

        Returns:
          A one-element list holding the image array reshaped to that shape.
        """
        img_url = "https://github.com/google-coral/test_data/raw/master/bird.bmp"
        local_path = os.path.join(self.workdir, "bird.bmp")
        urllib.request.urlretrieve(img_url, local_path)
        shape = input_details[0]["shape"]
        # The reshape below treats shape[1] as rows and shape[2] as columns.
        # numpy.array(img) is laid out (rows, cols, ...), while PIL's resize()
        # takes (width, height), so pass (shape[2], shape[1]); the swapped
        # order would scramble pixels for any non-square input.
        # `with` closes the file handle that Image.open keeps open.
        with Image.open(local_path) as img:
            pixels = numpy.array(img.resize((shape[2], shape[1])))
        return [pixels.reshape(shape)]

    def test_compile_tflite(self):
        """Compile the model with IREE and execute it against TFLite."""
        self.compile_and_execute()
if __name__ == '__main__':
    # Import the submodule explicitly: `import absl.testing` alone does not
    # guarantee the `absltest` attribute is bound on the package; that only
    # happens if some other module has already imported it.
    from absl.testing import absltest
    absltest.main()