Skip to content
Permalink
Browse files

feat(encoder): add onnxruntime suport for image encoder

Signed-off-by: raccoonliukai <903896015@qq.com>
  • Loading branch information...
raccoonliukai committed Aug 1, 2019
1 parent 8150cf1 commit f03e6fc20125d6d877f80ae462546cdde96a62a0
@@ -0,0 +1,69 @@
# Tencent is pleased to support the open source community by making GNES available.
#
# Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Callable

import numpy as np

from ..base import BaseImageEncoder
from ...helper import batching


class BaseONNXImageEncoder(BaseImageEncoder):

def __init__(self, model_name: str,
model_dir: str,
batch_size: int = 64,
use_cuda: bool = False,
*args, **kwargs):
super().__init__(*args, **kwargs)

self.batch_size = batch_size
self.model_dir = model_dir
self.model_name = model_name
self._use_cuda = use_cuda

def post_init(self):
import onnxruntime as ort

self.sess = ort.InferenceSession(self.model_dir + '/' + self.model_name)
inputs_info = self.sess.get_inputs()

if len(inputs_info) != 1:
raise ValueError('Now only support encoder with one input')
else:
self.input_name = inputs_info[0].name
self.input_shape = inputs_info[0].shape
self.input_type = inputs_info[0].type
self.batch_size = self.input_shape[0]

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
pad_batch = 0
if len(img) != self.input_shape[0]:
pad_batch = self.input_shape[0] - len(img)
for i in range(pad_batch):
img.append(np.zeros_like(img[0]))

img_ = np.array(img, dtype=np.float32).transpose(0, 3, 1, 2)
if list(img_.shape) != self.input_shape:
raise ValueError('Map size not match net, expect', self.input_shape, ',got', img_.shape)

result_npy = self.sess.run(None, {self.input_name: img_})
if pad_batch != 0:
return result_npy[0][0:len(img)]
else:
return result_npy[0]
@@ -52,7 +52,8 @@
'vision': ['opencv-python>=4.0.0', 'imagehash>=4.0'],
'leveldb': ['plyvel>=1.0.5'],
'test': ['pylint', 'memory_profiler>=0.55.0', 'psutil>=5.6.1', 'gputil>=1.4.0'],
'http': ['flask', 'flask-compress', 'flask-cors', 'flask-json', 'aiohttp==3.5.4']
'http': ['flask', 'flask-compress', 'flask-cors', 'flask-json', 'aiohttp==3.5.4'],
'onnx': ['onnxruntime']
}


@@ -0,0 +1,89 @@
import copy
import os
import unittest
import zipfile

from gnes.encoder.image.onnx import BaseONNXImageEncoder
from gnes.preprocessor.base import BaseUnaryPreprocessor
from gnes.preprocessor.image.sliding_window import VanillaSlidingPreprocessor
from gnes.proto import gnes_pb2, blob2array


def img_process_for_test(dirname):
zipfile_ = zipfile.ZipFile(os.path.join(dirname, 'imgs/test.zip'), "r")
all_bytes = [zipfile_.open(v).read() for v in zipfile_.namelist()]
test_img = []
for raw_bytes in all_bytes:
d = gnes_pb2.Document()
d.raw_bytes = raw_bytes
test_img.append(d)

test_img_all_preprocessor = []
for preprocessor in [BaseUnaryPreprocessor(doc_type=gnes_pb2.Document.IMAGE),
VanillaSlidingPreprocessor()]:
test_img_copy = copy.deepcopy(test_img)
for img in test_img_copy:
preprocessor.apply(img)
test_img_all_preprocessor.append([blob2array(chunk.blob)
for img in test_img_copy for chunk in img.chunks])
return test_img_all_preprocessor


class TestONNXImageEncoder(unittest.TestCase):

def setUp(self):
dirname = os.path.dirname(__file__)
self.dump_path = os.path.join(dirname, 'model.bin')
self.test_img = img_process_for_test(dirname)
self.vgg_yaml = os.path.join(dirname, 'yaml', 'onnx-vgg-encoder.yml')
self.res_yaml = os.path.join(dirname, 'yaml', 'onnx-resnet-encoder.yml')
self.inception_yaml = os.path.join(dirname, 'yaml', 'onnx-inception-encoder.yml')
self.mobilenet_yaml = os.path.join(dirname, 'yaml', 'onnx-mobilenet-encoder.yml')

def test_vgg_encoding(self):
self.encoder = BaseONNXImageEncoder.load_yaml(self.vgg_yaml)
for test_img in self.test_img:
vec = self.encoder.encode(test_img)
print("the length of data now is:", len(test_img))
self.assertEqual(vec.shape[0], len(test_img))
self.assertEqual(vec.shape[1], 1000)

def test_resnet_encoding(self):
self.encoder = BaseONNXImageEncoder.load_yaml(self.res_yaml)
for test_img in self.test_img:
vec = self.encoder.encode(test_img)
print("the length of data now is:", len(test_img))
self.assertEqual(vec.shape[0], len(test_img))
self.assertEqual(vec.shape[1], 1000)

def test_inception_encoding(self):
self.encoder = BaseONNXImageEncoder.load_yaml(self.inception_yaml)
for test_img in self.test_img:
vec = self.encoder.encode(test_img)
print("the length of data now is:", len(test_img))
self.assertEqual(vec.shape[0], len(test_img))
self.assertEqual(vec.shape[1], 1000)

def test_mobilenet_encoding(self):
self.encoder = BaseONNXImageEncoder.load_yaml(self.mobilenet_yaml)
for test_img in self.test_img:
vec = self.encoder.encode(test_img)
print("the length of data now is:", len(test_img))
self.assertEqual(vec.shape[0], len(test_img))
self.assertEqual(vec.shape[1], 1000)

def test_dump_load(self):
self.encoder = BaseONNXImageEncoder.load_yaml(self.inception_yaml)

self.encoder.dump(self.dump_path)

vgg_encoder2 = BaseONNXImageEncoder.load(self.dump_path)

for test_img in self.test_img:
vec = vgg_encoder2.encode(test_img)
self.assertEqual(vec.shape[0], len(test_img))
self.assertEqual(vec.shape[1], 1000)

def tearDown(self):
if os.path.exists(self.dump_path):
os.remove(self.dump_path)
@@ -0,0 +1,6 @@
!BaseONNXImageEncoder
parameter:
model_dir: ${INCEPTION_MODEL}
model_name: inception_v2.onnx
gnes_config:
is_trained: true
@@ -0,0 +1,6 @@
!BaseONNXImageEncoder
parameter:
model_dir: ${MOBILENET_MODEL}
model_name: mobilenetv2.onnx
gnes_config:
is_trained: true
@@ -0,0 +1,6 @@
!BaseONNXImageEncoder
parameter:
model_dir: ${RESNET_MODEL}
model_name: resnet50.onnx
gnes_config:
is_trained: true
@@ -0,0 +1,6 @@
!BaseONNXImageEncoder
parameter:
model_dir: ${VGG_MODEL}
model_name: vgg16.onnx
gnes_config:
is_trained: true

0 comments on commit f03e6fc

Please sign in to comment.
You can’t perform that action at this time.