In [1]:
import os
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy
from skonnxrt.sklapi import OnnxTransformer
model_file = "mobilenetv2-1.0.onnx"
if not os.path.exists(model_file):
    print("Download '{0}'...".format(model_file))
    import urllib.request
    url = "https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/mobilenetv2-1.0.onnx"
    urllib.request.urlretrieve(url, model_file)
    print("Done.")

class_names = "imagenet_class_index.json"
if not os.path.exists(class_names):
    print("Download '{0}'...".format(class_names))
    import urllib.request
    url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
    urllib.request.urlretrieve(url, class_names)
    print("Done.")

import json
with open(class_names, "r", encoding="utf-8") as f:
    content_classes = f.read()
labels = json.loads(content_classes)

with open(model_file, "rb") as f:
    model_bytes = f.read()

ot = OnnxTransformer(model_bytes)

img = Image.open('/home/data/test/img.jpg')
img2 = img.resize((224, 224))
X = numpy.asarray(img2).transpose((2, 0, 1))
X = X[numpy.newaxis, :, :, :] / 255.0
print(X.shape)
print(X.shape, X.min(), X.max())

pred = ot.fit_transform(X)[0, :]
print(pred.shape)

from heapq import nlargest
results = nlargest(10, range(pred.shape[0]), pred.take)
print(results)

import pandas
data=[{"index": i, "label": labels.get(str(i), ('?', '?'))[1], 'score': pred[i]} \
      for i in results]
df = pandas.DataFrame(data)
print(df)


Download 'mobilenetv2-1.0.onnx'...
Done.
Download 'imagenet_class_index.json'...
Done.
(1, 3, 224, 224)
(1, 3, 224, 224) 0.0 1.0
(1000,)
[985, 584, 892, 729, 818, 828, 916, 310, 451, 446]
   index       label      score
0    985       daisy  11.125216
1    584  hair_slide   7.783502
2    892  wall_clock   7.076154
3    729  plate_rack   7.066480
4    818   spotlight   7.032015
5    828    strainer   6.537438
6    916    web_site   6.018525
7    310         ant   5.946991
8    451    bolo_tie   5.880325
9    446      binder   5.861617
