[tools] new api and support dynamic models in python benchmark app (#8582)

* Preprocessing API - base classes

Includes API definition for trivial mean/scale operations (which don't require layout)

Mean/scale with 'layout' support will be done under a separate task, together with Layout

Current test code coverage: 100%
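
For context, a minimal sketch of how the trivial mean/scale preprocessing is meant to be used from Python; module paths and method names follow the released 2.0 API and are assumptions for this development snapshot:

```python
from openvino.runtime import Core
from openvino.preprocess import PrePostProcessor  # import path assumed for this snapshot

core = Core()
model = core.read_model("model.xml")  # hypothetical model path

ppp = PrePostProcessor(model)
steps = ppp.input().preprocess()
steps.mean(127.5)   # element-wise mean subtraction, no layout required
steps.scale(127.5)  # element-wise scaling, no layout required
model = ppp.build()
```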

* Python bindings for base preprocessing API

* remove pre_post_process directory from ngraph/core

* remove files from ngraph/python dir

* move pyngraph pre_post_process files from ngraph/python to runtime

* remove pre_post_process test from CMakeList

* move include to the header

* update include path for pre_post_process

* style fix

* bind InputTensorInfo::set_layout

* cleaned test_preprocess

* fix test expected output

* remove duplicate test

* update description of set_element_type

* fix style

* move preprocess from pyngraph to pyopenvino/graph

* update test_preprocess imports and remove unnecessary test

* remove duplicate import

* update custom method

* update test

* update test

* create decorator that changes Node into Output<Node>

* create function that casts Node to Output<Node>

* update test_preprocess to use decorator for custom function

* change _cast_to_output -> _from_node
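
A sketch of the idea behind that decorator: a custom preprocessing callback may return a plain Node, which is converted into the Output<Node> the API expects. The `_from_node` helper below is a stand-in written for illustration; only its name comes from the commit messages.

```python
from functools import wraps
from openvino.runtime import Node, Output

def _from_node(node: Node) -> Output:
    # Stand-in for the binding-level helper: take the node's first output.
    return node.output(0)

def custom_preprocess_function(func):
    """Convert the Node returned by a custom preprocessing callback
    into an Output<Node>."""
    @wraps(func)
    def wrapper(output: Output) -> Output:
        result = func(output)
        return _from_node(result) if isinstance(result, Node) else result
    return wrapper
```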

* style fix

* add tests for scale and mean with vector input

* style fix

* add docstring for custom_preprocess_function

* bind InputInfo network method

* style fix

* bind OutputInfo

* fix description of preprocess submodule

* fix style

* update copyright year

* bind OutputTensorInfo

* bind OutputNetworkInfo and InputNetworkInfo

* Bind exec core ov (#50)

* Output const node python tests (#52)

* add python bindings tests for Output<const ov::Node>

* add proper tests

* add new line

* rename ie_version to version

* Pszmel/bind infer request (#51)

* remove set_batch, get_blob and set_blob

* update InferRequest class

* change InferenceEngine::InferRequest to ov::runtime::InferRequest

* update set_callback body

* update bindings to reflect ov::runtime::InferRequest

* bind set_input_tensor and get_input_tensor

* style fix

* clean ie_infer_queue.cpp

* Bind exec core ov (#50)

* bind core, exec_net classes

* rm unused function

* add new line

* rename ie_infer_request -> infer_request

* update imports

* update __init__.py

* update ie_api.py

* Replace old containers with the new one

* create impl for create_infer_request

* comment out infer_queue to avoid errors with old infer_request

* update infer_request bind to reflect new infer_request api

* comment out input_info from ie_network to avoid errors with old containers

* Register new containers and comment out InferQueue

* update infer request tests

* style fix

* remove unused imports

* remove unused imports and 2 methods

* add tests to cover all new methods from infer_request

* style fix

* add test

* remove registration of InferResults

* update name of exception_ptr parameter

* update the loops that iterate through inputs and outputs

* clean setCustomCallbacks

* style fix

* add Tensor import

* style fix

* update infer and normalize_inputs

* style fix

* rename startTime and endTime

* Create test for mixed keys as infer arguments

* update infer function

* update return type of infer

Co-authored-by: Bartek Szmelczynski <bartosz.szmelczynski@intel.com>
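
Taken together, the reworked request bindings are used roughly as follows; this sketch follows the released ov::runtime Python API, and names such as compile_model may differ slightly in this snapshot:

```python
import numpy as np
from openvino.runtime import Core

core = Core()
model = core.read_model("model.xml")         # hypothetical model
compiled = core.compile_model(model, "CPU")

request = compiled.create_infer_request()
data = np.zeros((1, 3, 224, 224), dtype=np.float32)

results = request.infer({0: data})   # synchronous call, returns output tensors
request.start_async({0: data})       # asynchronous call
request.wait()
```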

* fix get_version

* fix opaque issue

* some cosmetic changes

* fix codestyle in tests

* make tests green

* Extend python InferRequest

* Extend python Function

* Change return value of infer call

* Fix missing precision conversions in CPU plugin

* Rework of runtime for new tests

* Fixed onnx reading in python tests

* Edit compatibility tests

* Edit tests

* Add FLOAT_LIKE xfails

* bind ColorFormat and ResizeAlgorithm

* clean imports

* fix typo

* [Python API] bind ProfilingInfo (#55)

* bind ProfilingInfo

* Add tests

* Fix code style

* Add property

* fix codestyle
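
A small sketch of reading the bound ProfilingInfo entries from a finished request; field names mirror the C++ ov::ProfilingInfo structure and their Python spelling is assumed:

```python
def print_profiling(request) -> None:
    # Dump per-node performance counters collected during inference.
    for info in request.get_profiling_info():
        print(info.node_name, info.node_type, info.status,
              info.real_time, info.cpu_time, info.exec_type)
```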

* Infer new request method (#56)

* fix conflicts, add infer_new_request function

* remove redundant functions, fix style

* revert the unwanted changes

* revert removal of the Blob

* revert removal of isTblob
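
The infer_new_request helper from #56, as it would be called on an executable network (signature assumed from the 2.0 API):

```python
import numpy as np

def run_once(compiled_model, image: np.ndarray):
    # Creates a temporary request internally and returns its results.
    return compiled_model.infer_new_request({0: image})
```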

* add add_extension from path

* codestyle

* add PostProcessSteps to init

* bind PreProcessSteps

* create additional tests

* fix win build

* add inputs-outputs to function

* update infer queue
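
The updated infer queue is what the benchmark app builds on; a usage sketch of AsyncInferQueue, with the callback signature assumed from the released API:

```python
from openvino.runtime import AsyncInferQueue

def run_async(compiled_model, batches, jobs=4):
    latencies = []

    def on_done(request, userdata):
        # Called when a request finishes; userdata is whatever start_async received.
        latencies.append(request.latency)

    queue = AsyncInferQueue(compiled_model, jobs)
    queue.set_callback(on_done)
    for i, batch in enumerate(batches):
        queue.start_async({0: batch}, userdata=i)
    queue.wait_all()
    return latencies
```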

* fix code style

* Hot-fix CPU plugin with precision

* fix start_async

* add performance hint to time infer (#8480)
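
Presumably this maps onto the standard performance-hint configuration key; shown here only as an assumed illustration, not the exact change made in #8480:

```python
from openvino.runtime import Core

core = Core()
# PERFORMANCE_HINT is the standard config key; whether #8480 uses exactly
# this mechanism is an assumption here.
core.set_config({"PERFORMANCE_HINT": "THROUGHPUT"}, "CPU")
```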

* Updated common migration pipeline (#8176)

* Updated common migration pipeline

* Fixed merge issue

* Added new model and extended example

* Fixed typo

* Added v10-v11 comparison

* Avoid redundant graph nodes scans (#8415)

* Refactor work with env variables (#8208)

* del MO_ROOT

* del MO_ROOT from common_utils.py

* add MO_PATH to common_utils.py

* change mo_path

* [IE Sample Scripts] Use cmake to build samples (#8442)

* Use cmake to build samples

* Add the option to set custom build output folder

* Remove opset8 from compatibility ngraph python API (#8452)

* [GPU] OneDNN gpu submodule update to version 2.5 (#8449)

* [GPU] OneDNN gpu submodule update to version 2.5

* [GPU] Updated onednn submodule and added layout optimizer fix

* Install rules for static libraries case (#8384)

* Proper cmake install for static libraries case

* Added an ability to skip template plugin

* Added install rules for VPU / GPU

* Install more libraries

* Fixed absolute TBB include paths

* Disable GNA

* Fixed issue with linker

* Some fixes

* Fixed linkage issues in tests

* Disabled some tests

* Updated CI pipelines

* Fixed Windows linkage

* Fixed custom_opset test for static case

* Fixed CVS-70313

* Continue on error

* Fixed clang-format

* Try to fix Windows linker

* Fixed compilation

* Disable samples

* Fixed samples build with THREADING=SEQ

* Fixed link error on Windows

* Fixed ieFuncTests

* Added static Azure CI

* Revert "Fixed link error on Windows"

This reverts commit 78cca36.

* Merge static and dynamic linux pipelines

* Fixed Azure

* fix codestyle

* rename all methods in this class to snake_case

* some updates

* code style

* fix code style in tests

* update statistics reporting

* update filling inputs

* change ngraph.Type to ov.Type

* fix typo

* save work

* save work

* save work

* compute latency in callback

* save work

* Fix get_idle_request

* save work

* fix latency

* Fix code style

* update AppInputInfo

* add iteration to PartialShape

* fix rebasing

* bind result::get_layout()

* correct mistakes

* fix setup

* use parameters/results instead of inputs/outputs

* move _from_node to node_output.hpp

* add read_model from buffer

* update imports

* revert package struct

* add new line

* remove bad quotes

* update imports

* style fix

* add new line

* Fix preprocessing

* rename function args

* set NCHW layout to image as default

* Fix input fillings

* remove Type import

* update tests

* style fix

* test clean

* remove blank line

* Add tensor_shape

* fix comments

* update PrePostProcessor init and build methods

* create test with model; update tests with new PrePostProcessor init and build

* Change filling inputs

* fix preprocessing

* basic support for dynamic shapes

* fix legacy mode

* rename ie to core

* fix cpp code style

* fix input files parsing

* fix binary filling

* support dynamic batch size
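
A sketch of the dynamic-batch scenario being enabled here: the model gets a dynamic batch dimension and each request is fed tensors whose data shape fixes it at run time. The reshape/compile calls follow the released 2.0 API and are assumptions for this snapshot:

```python
import numpy as np
from openvino.runtime import Core, Dimension, PartialShape

core = Core()
model = core.read_model("model.xml")                          # hypothetical model
model.reshape({0: PartialShape([Dimension(), 3, 224, 224])})  # dynamic batch dimension
compiled = core.compile_model(model, "CPU")

request = compiled.create_infer_request()
for batch in (1, 4, 8):                                       # different data shapes
    data = np.zeros((batch, 3, 224, 224), dtype=np.float32)
    request.infer({0: data})
```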

* process images with original shapes if no tensor shapes were given

* fix fps and number of iterations

* Add new metrics

* support passing a path to a folder in input mapping

* add pcseq flag

* fix resolving conflicts

* dump statistics per group

* check for compatibility with partial shape

* revert statistic report names

* code refactoring

* update parameters

* enable legacy_mode if data size is less than nireq

* add serialize to offline_transformations
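
The serialize pass exposed through offline_transformations, roughly as it is called in the 2.0 Python API (module path assumed for this snapshot):

```python
from openvino.runtime import Core
from openvino.offline_transformations import serialize  # import path assumed

core = Core()
model = core.read_model("model.xml")        # hypothetical input model
serialize(model, "serialized.xml", "serialized.bin")
```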

* Fix preprocessing import

* change log output due to ci parsing

* fix layout

* allow passing batch size with undefined layout

* add serializer

* fix comments from jiwaszki

* Fix latency parsing for ci

* code style

* rename tensor_shape to data_shape

* add message if image is processed with original shape

* fix syntax warning

* remove default legacy_mode if requests cover all data

* rewrite all file parsing

* fix preprocessing

* Fix preprocessing #2

* Use layout instead of str

* Fix file extensions

* Fix image sizes filling

* sort input files

* [Python API] quick fix of packaging

* update tests

* fix setup.py

* small fix

* small fixes according to comments

* skip mo frontend tests

* full mode is default for dynamic models only

* backward compatibility

* Fix package

* set layout in runtime

* static mode for dynamic models with all equal data shapes

* use get_tensor instead of set_tensor in legacy mode

* benchmarking dynamic models is available in full mode only

* fix layout detection

* use batch_size * iteration instead of processed_frames in legacy mode

* fix tensor naming

* represent --inference_only

* refactoring main loop

* Fix number of iterations for full mode

Co-authored-by: Michael Nosov <mikhail.nosov@intel.com>
Co-authored-by: pszmel <piotr.szmelczynski@intel.com>
Co-authored-by: Bartek Szmelczynski <bartosz.szmelczynski@intel.com>
Co-authored-by: Anastasia Kuporosova <anastasia.kuporosova@intel.com>
Co-authored-by: jiwaszki <jan.iwaszkiewicz@intel.com>
Co-authored-by: Victor Kuznetsov <victor.kuznetsov@intel.com>
Co-authored-by: Ilya Churaev <ilya.churaev@intel.com>
Co-authored-by: Tomasz Jankowski <tomasz1.jankowski@intel.com>
Co-authored-by: Dmitry Pigasin <dmitry.pigasin@intel.com>
Co-authored-by: Artur Kulikowski <artur.kulikowski@intel.com>
Co-authored-by: Ilya Znamenskiy <ilya.znamenskiy@intel.com>
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
13 people committed Dec 2, 2021
1 parent 87ea55f commit f9bd740
Showing 9 changed files with 1,076 additions and 624 deletions.
19 changes: 19 additions & 0 deletions src/bindings/python/src/pyopenvino/graph/partial_shape.cpp
@@ -192,6 +192,25 @@ void regclass_graph_PartialShape(py::module m) {
},
py::is_operator());

shape.def("__len__", [](const ov::PartialShape& self) {
return self.size();
});

shape.def("__setitem__", [](ov::PartialShape& self, size_t key, ov::Dimension& d) {
self[key] = d;
});

shape.def("__getitem__", [](const ov::PartialShape& self, size_t key) {
return self[key];
});

shape.def(
"__iter__",
[](ov::PartialShape& self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>()); /* Keep vector alive while iterator is used */

shape.def("__str__", [](const ov::PartialShape& self) -> std::string {
std::stringstream ss;
ss << self;
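
On the Python side, the bindings added above make PartialShape behave like a sequence; a small usage sketch (not part of the diff, import path assumed):

```python
from openvino.runtime import Dimension, PartialShape

shape = PartialShape([1, 3, Dimension(), Dimension()])
assert len(shape) == 4          # __len__
shape[2] = Dimension(224)       # __setitem__
print(shape[2])                 # __getitem__
for dim in shape:               # __iter__
    print(dim)
```
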
9 changes: 8 additions & 1 deletion src/bindings/python/src/pyopenvino/graph/shape.cpp
@@ -11,6 +11,7 @@
#include <sstream>
#include <string>

#include "openvino/core/dimension.hpp" // ov::Dimension
#include "pyopenvino/graph/shape.hpp"

namespace py = pybind11;
@@ -24,7 +25,13 @@ void regclass_graph_Shape(py::module m) {
shape.def("__len__", [](const ov::Shape& v) {
return v.size();
});
shape.def("__getitem__", [](const ov::Shape& v, int key) {
shape.def("__setitem__", [](ov::Shape& self, size_t key, size_t d) {
self[key] = d;
});
shape.def("__setitem__", [](ov::Shape& self, size_t key, ov::Dimension d) {
self[key] = d.get_length();
});
shape.def("__getitem__", [](const ov::Shape& v, size_t key) {
return v[key];
});

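Correspondingly, Shape items can now be assigned from either an int or a Dimension; a small usage sketch (not part of the diff, import path assumed):

```python
from openvino.runtime import Dimension, Shape

s = Shape([1, 3, 224, 224])
s[0] = 8               # __setitem__ with a plain integer
s[1] = Dimension(4)    # __setitem__ with a static Dimension
print(s[0], s[1], len(s))
```
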
237 changes: 141 additions & 96 deletions tools/benchmark_tool/openvino/tools/benchmark/benchmark.py
@@ -4,7 +4,8 @@
import os
from datetime import datetime
from math import ceil
from openvino.inference_engine import IENetwork, IECore, get_version, StatusCode
from typing import Union
from openvino.runtime import Core, get_version, AsyncInferQueue

from .utils.constants import MULTI_DEVICE_NAME, HETERO_DEVICE_NAME, CPU_DEVICE_NAME, GPU_DEVICE_NAME, XML_EXTENSION, BIN_EXTENSION
from .utils.logging import logger
@@ -15,151 +16,195 @@ def percentile(values, percent):
return values[ceil(len(values) * percent / 100) - 1]

class Benchmark:
def __init__(self, device: str, number_infer_requests: int = None, number_iterations: int = None,
duration_seconds: int = None, api_type: str = 'async'):
def __init__(self, device: str, number_infer_requests: int = 0, number_iterations: int = None,
duration_seconds: int = None, api_type: str = 'async', inference_only = None):
self.device = device
self.ie = IECore()
self.nireq = number_infer_requests
self.core = Core()
self.nireq = number_infer_requests if api_type == 'async' else 1
self.niter = number_iterations
self.duration_seconds = get_duration_seconds(duration_seconds, self.niter, self.device)
self.api_type = api_type
self.inference_only = inference_only
self.latency_groups = []

def __del__(self):
del self.ie
del self.core

def add_extension(self, path_to_extension: str=None, path_to_cldnn_config: str=None):
if path_to_cldnn_config:
self.ie.set_config({'CONFIG_FILE': path_to_cldnn_config}, GPU_DEVICE_NAME)
self.core.set_config({'CONFIG_FILE': path_to_cldnn_config}, GPU_DEVICE_NAME)
logger.info(f'GPU extensions is loaded {path_to_cldnn_config}')

if path_to_extension:
self.ie.add_extension(extension_path=path_to_extension, device_name=CPU_DEVICE_NAME)
self.core.add_extension(extension_path=path_to_extension)
logger.info(f'CPU extensions is loaded {path_to_extension}')

def get_version_info(self) -> str:
logger.info(f"InferenceEngine:\n{'': <9}{'API version':.<24} {get_version()}")
version_string = 'Device info\n'
for device, version in self.ie.get_versions(self.device).items():
for device, version in self.core.get_versions(self.device).items():
version_string += f"{'': <9}{device}\n"
version_string += f"{'': <9}{version.description:.<24}{' version'} {version.major}.{version.minor}\n"
version_string += f"{'': <9}{'Build':.<24} {version.build_number}\n"
return version_string

def set_config(self, config = {}):
for device in config.keys():
self.ie.set_config(config[device], device)
self.core.set_config(config[device], device)

def set_cache_dir(self, cache_dir: str):
self.ie.set_config({'CACHE_DIR': cache_dir}, '')
self.core.set_config({'CACHE_DIR': cache_dir}, '')

def read_network(self, path_to_model: str):
def read_model(self, path_to_model: str):
model_filename = os.path.abspath(path_to_model)
head, ext = os.path.splitext(model_filename)
weights_filename = os.path.abspath(head + BIN_EXTENSION) if ext == XML_EXTENSION else ""
ie_network = self.ie.read_network(model_filename, weights_filename)
return ie_network

def load_network(self, ie_network: IENetwork, config = {}):
exe_network = self.ie.load_network(ie_network,
self.device,
config=config,
num_requests=1 if self.api_type == 'sync' else self.nireq or 0)
# Number of requests
self.nireq = len(exe_network.requests)

return exe_network

def load_network_from_file(self, path_to_model: str, config = {}):
exe_network = self.ie.load_network(path_to_model,
self.device,
config=config,
num_requests=1 if self.api_type == 'sync' else self.nireq or 0)
# Number of requests
self.nireq = len(exe_network.requests)

return exe_network

def import_network(self, path_to_file : str, config = {}):
exe_network = self.ie.import_network(model_file=path_to_file,
device_name=self.device,
config=config,
num_requests=1 if self.api_type == 'sync' else self.nireq or 0)
# Number of requests
self.nireq = len(exe_network.requests)
return exe_network

def first_infer(self, exe_network):
infer_request = exe_network.requests[0]

# warming up - out of scope
return self.core.read_model(model_filename, weights_filename)

def create_infer_requests(self, exe_network):
if self.api_type == 'sync':
infer_request.infer()
requests = [exe_network.create_infer_request()]
else:
infer_request.async_infer()
status = infer_request.wait()
if status != StatusCode.OK:
raise Exception(f"Wait for all requests is failed with status code {status}!")
return infer_request.latency
requests = AsyncInferQueue(exe_network, self.nireq)
self.nireq = len(requests)
return requests

def infer(self, exe_network, batch_size, latency_percentile, progress_bar=None):
def first_infer(self, requests):
if self.api_type == 'sync':
requests[0].infer()
return requests[0].latency
else:
id = requests.get_idle_request_id()
requests.start_async()
requests.wait_all()
return requests[id].latency

def update_progress_bar(self, progress_bar, exec_time, progress_count):
if self.duration_seconds:
# calculate how many progress intervals are covered by current iteration.
# depends on the current iteration time and time of each progress interval.
# Previously covered progress intervals must be skipped.
progress_interval_time = self.duration_seconds / progress_bar.total_num
new_progress = int(exec_time / progress_interval_time - progress_count)
progress_bar.add_progress(new_progress)
progress_count += new_progress
elif self.niter:
progress_bar.add_progress(1)
return progress_count

def sync_inference(self, request, data_queue, progress_bar):
progress_count = 0
infer_requests = exe_network.requests

exec_time = 0
iteration = 0
times = []
start_time = datetime.utcnow()
while (self.niter and iteration < self.niter) or \
(self.duration_seconds and exec_time < self.duration_seconds):
if self.inference_only == False:
request.set_input_tensors(data_queue.get_next_input())
request.infer()
times.append(request.latency)
iteration += 1

exec_time = (datetime.utcnow() - start_time).total_seconds()

if progress_bar:
progress_count = self.update_progress_bar(progress_bar, exec_time, progress_count)

total_duration_sec = (datetime.utcnow() - start_time).total_seconds()
return sorted(times), total_duration_sec, iteration

def async_inference_only(self, infer_queue, progress_bar):
progress_count = 0
exec_time = 0
iteration = 0
times = []
in_fly = set()
start_time = datetime.utcnow()
while (self.niter and iteration < self.niter) or \
(self.duration_seconds and exec_time < self.duration_seconds) or \
(iteration % self.nireq):
idle_id = infer_queue.get_idle_request_id()
if idle_id in in_fly:
times.append(infer_queue[idle_id].latency)
else:
in_fly.add(idle_id)
infer_queue.start_async()
iteration += 1

exec_time = (datetime.utcnow() - start_time).total_seconds()

if progress_bar:
progress_count = self.update_progress_bar(progress_bar, exec_time, progress_count)

infer_queue.wait_all()
total_duration_sec = (datetime.utcnow() - start_time).total_seconds()
for infer_request_id in in_fly:
times.append(infer_queue[infer_request_id].latency)
return sorted(times), total_duration_sec, iteration

def async_inference_full_mode(self, infer_queue, data_queue, progress_bar, pcseq):
progress_count = 0
processed_frames = 0
exec_time = 0
iteration = 0
times = []
num_groups = len(self.latency_groups)
in_fly = set()
# Start inference & calculate performance
# align the number of iterations to guarantee that the last infer requests are executed in the same conditions
start_time = datetime.utcnow()
while (self.niter and iteration < self.niter) or \
(self.duration_seconds and exec_time < self.duration_seconds) or \
(self.api_type == 'async' and iteration % self.nireq):
if self.api_type == 'sync':
infer_requests[0].infer()
times.append(infer_requests[0].latency)
(iteration % num_groups):
processed_frames += data_queue.get_next_batch_size()
idle_id = infer_queue.get_idle_request_id()
if idle_id in in_fly:
times.append(infer_queue[idle_id].latency)
if pcseq:
self.latency_groups[infer_queue.userdata[idle_id]].times.append(infer_queue[idle_id].latency)
else:
infer_request_id = exe_network.get_idle_request_id()
if infer_request_id < 0:
status = exe_network.wait(num_requests=1)
if status != StatusCode.OK:
raise Exception("Wait for idle request failed!")
infer_request_id = exe_network.get_idle_request_id()
if infer_request_id < 0:
raise Exception("Invalid request id!")
if infer_request_id in in_fly:
times.append(infer_requests[infer_request_id].latency)
else:
in_fly.add(infer_request_id)
infer_requests[infer_request_id].async_infer()
in_fly.add(idle_id)
group_id = data_queue.current_group_id
infer_queue[idle_id].set_input_tensors(data_queue.get_next_input())
infer_queue.start_async(userdata=group_id)
iteration += 1

exec_time = (datetime.utcnow() - start_time).total_seconds()

if progress_bar:
if self.duration_seconds:
# calculate how many progress intervals are covered by current iteration.
# depends on the current iteration time and time of each progress interval.
# Previously covered progress intervals must be skipped.
progress_interval_time = self.duration_seconds / progress_bar.total_num
new_progress = int(exec_time / progress_interval_time - progress_count)
progress_bar.add_progress(new_progress)
progress_count += new_progress
elif self.niter:
progress_bar.add_progress(1)

# wait the latest inference executions
status = exe_network.wait()
if status != StatusCode.OK:
raise Exception(f"Wait for all requests is failed with status code {status}!")
progress_count = self.update_progress_bar(progress_bar, exec_time, progress_count)

infer_queue.wait_all()
total_duration_sec = (datetime.utcnow() - start_time).total_seconds()
for infer_request_id in in_fly:
times.append(infer_requests[infer_request_id].latency)
times.sort()
latency_ms = percentile(times, latency_percentile)
fps = batch_size * 1000 / latency_ms if self.api_type == 'sync' else batch_size * iteration / total_duration_sec
times.append(infer_queue[infer_request_id].latency)
return sorted(times), total_duration_sec, processed_frames, iteration

def main_loop(self, requests, data_queue, batch_size, latency_percentile, progress_bar, pcseq):
if self.api_type == 'sync':
times, total_duration_sec, iteration = self.sync_inference(requests[0], data_queue, progress_bar)
elif self.inference_only:
times, total_duration_sec, iteration = self.async_inference_only(requests, progress_bar)
fps = len(batch_size) * iteration / total_duration_sec
else:
times, total_duration_sec, processed_frames, iteration = self.async_inference_full_mode(requests, data_queue, progress_bar, pcseq)
fps = processed_frames / total_duration_sec

median_latency_ms = percentile(times, latency_percentile)
avg_latency_ms = sum(times) / len(times)
min_latency_ms = times[0]
max_latency_ms = times[-1]

if self.api_type == 'sync':
fps = len(batch_size) * 1000 / median_latency_ms

if pcseq:
for group in self.latency_groups:
if group.times:
group.times.sort()
group.avg = sum(group.times) / len(group.times)
group.min = group.times[0]
group.max = group.times[-1]

if progress_bar:
progress_bar.finish()
return fps, latency_ms, total_duration_sec, iteration
return fps, median_latency_ms, avg_latency_ms, min_latency_ms, max_latency_ms, total_duration_sec, iteration
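
For reference, the latency statistics returned by the new main_loop are derived from the sorted per-request times like this (standalone sketch mirroring the code above):

```python
from math import ceil

def percentile(values, percent):
    return values[ceil(len(values) * percent / 100) - 1]

times = sorted([4.2, 3.9, 5.1, 4.0, 4.4])     # per-request latencies, ms
median_latency_ms = percentile(times, 50)
avg_latency_ms = sum(times) / len(times)
min_latency_ms, max_latency_ms = times[0], times[-1]
```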
