Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AC: add automatic model search for OpenCV launcher #3443

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def provide_precision_and_layout(launchers, input_precisions, input_layouts):


def provide_model_type(launcher, arguments):
if 'model_type' in arguments:
if 'model_type' in arguments and arguments.model_type is not None:
launcher['_model_type'] = arguments.model_type
if launcher['framework'] in ['dlsdk', 'openvino', 'g-api'] and 'model_is_blob' in arguments:
launcher['_model_is_blob'] = arguments.model_is_blob
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import re
from collections import OrderedDict
from pathlib import Path
import numpy as np
import cv2

from ..config import PathField, StringField, ConfigError, ListInputsField
from ..logging import print_info
from .launcher import Launcher, LauncherConfigValidator
from ..utils import get_or_parse_value
from ..utils import get_or_parse_value, get_path

DEVICE_REGEX = r'(?P<device>cpu$|gpu|gpu_fp16)?'
BACKEND_REGEX = r'(?P<backend>ocv|ie)?'
Expand Down Expand Up @@ -63,8 +64,11 @@ class OpenCVLauncher(Launcher):
def parameters(cls):
parameters = super().parameters()
parameters.update({
'model': PathField(description="Path to model file."),
'weights': PathField(description="Path to weights file.", optional=True, default='', check_exists=False),
'model': PathField(description="Path to model file.", file_or_directory=True),
'weights': PathField(
description="Path to weights file.", optional=True,
check_exists=False, file_or_directory=True
),
'device': StringField(
regex=DEVICE_REGEX, choices=OpenCVLauncher.TARGET_DEVICES.keys(),
description="Device name: {}".format(', '.join(OpenCVLauncher.TARGET_DEVICES.keys()))
Expand Down Expand Up @@ -100,8 +104,10 @@ def __init__(self, config_entry: dict, *args, **kwargs):
raise ConfigError('{} is not supported device'.format(selected_device))

if not self._delayed_model_loading:
self.model = self.get_value_from_config('model')
self.weights = self.get_value_from_config('weights')
self.model, self.weights = self.automatic_model_search(self._model_name,
self.get_value_from_config('model'), self.get_value_from_config('weights'),
self.get_value_from_config('_model_type')
)
self.network = self.create_network(self.model, self.weights)
self._inputs_shapes = self.get_inputs_from_config(self.config)
self.network.setInputsNames(list(self._inputs_shapes.keys()))
Expand Down Expand Up @@ -130,6 +136,71 @@ def batch(self):
def output_blob(self):
return next(iter(self.output_names))

def automatic_model_search(self, model_name, model_cfg, weights_cfg, model_type=None):
model_type_ext = {
'xml': 'xml',
'blob': 'blob',
'onnx': 'onnx',
'caffe': 'prototxt',
'tf': 'pb'
}
def get_model_by_suffix(model_name, model_dir, suffix):
model_list = list(Path(model_dir).glob('{}.{}'.format(model_name, suffix)))
if not model_list:
model_list = list(Path(model_dir).glob('*.{}'.format(suffix)))
if not model_list:
model_list = list(Path(model_dir).parent.rglob('*.{}'.format(suffix)))
return model_list

def get_model():
model = Path(model_cfg)
if not model.is_dir():
accepted_suffixes = list(model_type_ext.values())
if model.suffix[1:] not in accepted_suffixes:
raise ConfigError('Models with following suffixes are allowed: {}'.format(accepted_suffixes))
print_info('Found model {}'.format(model))
return model, model.suffix == '.blob'
model_list = []
if model_type is not None:
model_list = get_model_by_suffix(model_name, model, model_type_ext[model_type])
else:
for ext in model_type_ext.values():
model_list = get_model_by_suffix(model_name, model, ext)
if model_list:
break
if not model_list:
raise ConfigError('suitable model is not found')
if len(model_list) != 1:
raise ConfigError('More than one model matched, please specify explicitly')
model = model_list[0]
print_info('Found model {}'.format(model))
return model, model.suffix == '.blob'

model, is_blob = get_model()
if is_blob:
return model, None
weights = weights_cfg
if (weights is None or Path(weights).is_dir()) and model.suffix != '.onnx':
weights_dir = weights or model.parent
weights_list = []
if model.suffix == '.xml':
weights = Path(weights_dir) / model.name.replace('xml', 'bin')
else:
if model.suffix == '.prototxt':
weights_list = list(Path(weights_dir).glob('*.{}'.format('caffemodel')))
if not weights_list:
raise ConfigError('Suitable weights is not detected')
if len(weights_list) != 1:
raise ConfigError('Several suitable weights found, please specify required explicitly')
weights = weights_list[0]
if weights is not None:
accepted_weights_suffixes = ['.bin', '.caffemodel']
if weights.suffix not in accepted_weights_suffixes:
raise ConfigError('Weights with following suffixes are allowed: {}'.format(accepted_weights_suffixes))
print_info('Found weights {}'.format(get_path(weights)))

return model, weights

def predict(self, inputs, metadata=None, **kwargs):
"""
Args:
Expand Down