8 changes: 2 additions & 6 deletions c_cxx/OpenVINO_EP/Windows/CMakeLists.txt
@@ -24,8 +24,8 @@ if(OPENCV_ROOTDIR)
set(OPENCV_FOUND true)
set(OPENCV_INCLUDE_DIRS "${OPENCV_ROOTDIR}/include")
set(OPENCV_LIBDIR "${OPENCV_ROOTDIR}/x64/vc16/lib")
file(GLOB OPENCV_DEBUG_LIBRARIES ${OPENCV_LIBDIR}/opencv_world470d.lib)
file(GLOB OPENCV_RELEASE_LIBRARIES ${OPENCV_LIBDIR}/opencv_world470.lib)
file(GLOB OPENCV_DEBUG_LIBRARIES "${OPENCV_LIBDIR}/opencv_world*d.lib")
file(GLOB OPENCV_RELEASE_LIBRARIES "${OPENCV_LIBDIR}/opencv_world*.lib")
list(FILTER OPENCV_RELEASE_LIBRARIES EXCLUDE REGEX ".*d\\.lib")
endif()

@@ -41,8 +41,4 @@ if(OPENCV_FOUND)
add_subdirectory(squeezenet_classification)
endif()

if(OPENCL_FOUND)
add_subdirectory(squeezenet_classification_io_buffer)
endif()

add_subdirectory(model-explorer)
80 changes: 32 additions & 48 deletions c_cxx/OpenVINO_EP/Windows/README.md
@@ -2,34 +2,43 @@

1. model-explorer

This sample application demonstrates how to use components of the experimental C++ API to query for model inputs/outputs and how to run inference using OpenVINO Execution Provider for ONNXRT on a model. The source code for this sample is available [here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx/OpenVINO_EP/Windows/model-explorer).
This sample application demonstrates how to use the **ONNX Runtime C++ API** with the OpenVINO Execution Provider (OVEP).
It loads an ONNX model, inspects the input/output node names and shapes, creates random input data, and runs inference.
The sample is useful for exploring model structure and verifying end-to-end execution with OVEP. The source code for this sample is available [here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx/OpenVINO_EP/Windows/model-explorer).
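
A minimal sketch of the OVEP setup, assuming a placeholder `model.onnx` in the working directory (the full sample linked above also inspects input/output shapes and feeds random data):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "model-explorer");
  Ort::SessionOptions session_options;

  // Route execution to the OpenVINO EP; "CPU" can be swapped for another supported device type.
  std::unordered_map<std::string, std::string> ov_options{{"device_type", "CPU"}};
  session_options.AppendExecutionProvider_OpenVINO_V2(ov_options);

  // On Windows the session takes a wide-character path; "model.onnx" is a placeholder.
  Ort::Session session(env, L"model.onnx", session_options);
  std::cout << "Inputs: " << session.GetInputCount()
            << ", Outputs: " << session.GetOutputCount() << std::endl;
  return 0;
}
```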

2. Squeezenet classification sample

The sample involves presenting an image to the ONNX Runtime (RT), which uses the OpenVINO Execution Provider for ONNXRT to run inference on various Intel hardware devices like Intel CPU, GPU, VPU and more. The sample uses OpenCV for image processing and ONNX Runtime OpenVINO EP for inference. After the sample image is inferred, the terminal will output the predicted label classes in order of their confidence. The source code for this sample is available [here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx/OpenVINO_EP/Windows/squeezenet_classification).

3. Squeezenet classification sample with IO Buffer feature

This sample is also doing the same process but with IO Buffer optimization enabled. With IO Buffer interfaces we can avoid any memory copy overhead when plugging OpenVINO™ inference into an existing GPU pipeline. It also enables OpenCL kernels to participate in the pipeline to become native buffer consumers or producers of the OpenVINO™ inference. Refer [here](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_GPU_RemoteTensor_API.html) for more details. This sample is for GPUs only. The source code for this sample is available [here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx/OpenVINO_EP/Windows/squeezenet_classification_io_buffer).
The sample presents an image to ONNX Runtime, which uses the OpenVINO Execution Provider for ONNXRT to run inference on various Intel hardware devices such as CPU, GPU, and NPU. The sample uses OpenCV for image processing and the ONNX Runtime OpenVINO EP for inference. After the sample image is inferred, the terminal will output the predicted label classes in order of their confidence. The source code for this sample is available [here](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx/OpenVINO_EP/Windows/squeezenet_classification).

## How to build

## Prerequisites
1. [The Intel<sup>®</sup> Distribution of OpenVINO toolkit](https://docs.openvinotoolkit.org/latest/index.html)
2. Use opencv
3. Use opencl for IO buffer sample (squeezenet_cpp_app_io.cpp).
4. Use any sample image as input to the sample.
5. Download the latest Squeezenet model from the ONNX Model Zoo.
This example was adapted from [ONNX Model Zoo](https://github.com/onnx/models).Download the latest version of the [Squeezenet](https://github.com/onnx/models/tree/master/validated/vision/classification/squeezenet) model from here.
1. [The Intel<sup>®</sup> Distribution of OpenVINO toolkit](https://docs.openvino.ai/2025/get-started/install-openvino.html)
2. Download and extract [OpenCV](https://opencv.org/releases/)
3. Use any sample image as input to the sample.
4. Download the latest Squeezenet model from the ONNX Model Zoo.
This example was adapted from [ONNX Model Zoo](https://github.com/onnx/models). Download the latest version of the [Squeezenet](https://github.com/onnx/models/tree/master/vision/classification/squeezenet) model from here.

#### Build ONNX Runtime
Open x64 Native Tools Command Prompt for VS 2019.
For running the sample with IO Buffer optimization feature, make sure you set the OpenCL paths. For example if you are setting the path from openvino source build folder, the paths will be like:
## Build ONNX Runtime with OpenVINO on Windows

Make sure you open **x64 Native Tools Command Prompt for VS 2019** before running the following steps.

### 1. Download OpenVINO package
Download the OpenVINO archive package from the official repository:
[OpenVINO Archive Packages](https://storage.openvinotoolkit.org/repositories/openvino/packages)

Extract the downloaded archive to a directory (e.g., `C:\openvino`).
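
For example, using the `tar` utility that ships with recent Windows versions (the archive name is a placeholder for the package you actually downloaded; adjust the `setupvars.bat` path in the next step if the archive extracts into a versioned subfolder):

```cmd
mkdir C:\openvino
tar -xf openvino_toolkit_windows_<version>.zip -C C:\openvino
```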

---

### 2. Set up OpenVINO environment
After extracting, run the following command to set up environment variables:

```cmd
"C:\openvino\setupvars.bat"
```
```
set OPENCL_LIBS=\path\to\openvino\folder\bin\intel64\Release\OpenCL.lib
set OPENCL_INCS=\path\to\openvino\folder\thirdparty\ocl\clhpp_headers\include
```

### 3. Build ONNX Runtime

```
build.bat --config RelWithDebInfo --use_openvino CPU --build_shared_lib --parallel --cmake_extra_defines CMAKE_INSTALL_PREFIX=c:\dev\ort_install --skip_tests
@@ -43,44 +52,23 @@ cd build\Windows\RelWithDebInfo
msbuild INSTALL.vcxproj /p:Configuration=RelWithDebInfo
```

#### Build the samples
### Build the samples

Open x64 Native Tools Command Prompt for VS 2019, Git clone the sample repo.
Open x64 Native Tools Command Prompt for VS 2022, then git clone the sample repo.
```
git clone https://github.com/microsoft/onnxruntime-inference-examples.git
```
Change your current directory to c_cxx\OpenVINO_EP\Windows, then run

```bat
mkdir build && cd build
cmake .. -A x64 -T host=x64 -Donnxruntime_USE_OPENVINO=ON -DONNXRUNTIME_ROOTDIR=c:\dev\ort_install -DOPENCV_ROOTDIR="path\to\opencv"
```
Set `OPENCV_ROOTDIR` to your OpenCV installation path. Skip the OpenCV flag if you don't want to build the squeezenet sample (see the example below).
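
For example, a configuration that skips OpenCV and builds only the model-explorer sample uses the same flags minus the OpenCV path:

```bat
mkdir build && cd build
cmake .. -A x64 -T host=x64 -Donnxruntime_USE_OPENVINO=ON -DONNXRUNTIME_ROOTDIR=c:\dev\ort_install
```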

To get the squeezenet sample with IO buffer feature enabled, pass opencl paths as well:
```bat
mkdir build && cd build
cmake .. -A x64 -T host=x64 -Donnxruntime_USE_OPENVINO=ON -DONNXRUNTIME_ROOTDIR=c:\dev\ort_install -DOPENCV_ROOTDIR="path\to\opencv" -DOPENCL_LIB=path\to\openvino\folder\bin\intel64\Release\ -DOPENCL_INCLUDE="path\to\openvino\folder\thirdparty\ocl\clhpp_headers\include;path\to\openvino\folder\thirdparty\ocl\cl_headers"
```

**Note:**
If you are using the opencv from openvino package, below are the paths:
* For openvino version 2022.1.0, run download_opencv.ps1 in \path\to\openvino\extras\script and the opencv folder will be downloaded at \path\to\openvino\extras.
* For older openvino version, opencv folder is available at openvino directory itself.
* The current cmake files are set up for the opencv folders that come with the openvino packages. Please make sure you update the opencv paths according to your custom builds.

For the squeezenet IO buffer sample:
Make sure you are creating the opencl context for the right GPU device in a multi-GPU environment.

Build the samples using msbuild with the Debug configuration. For a Release build, replace Debug with Release (an example follows the Debug command below).

```bat
msbuild onnxruntime_samples.sln /p:Configuration=Debug
```
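
For example, the equivalent Release build is:

```bat
msbuild onnxruntime_samples.sln /p:Configuration=Release
```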

### Note
To run the samples, make sure you source the OpenVINO environment variables using setupvars.bat.

To run the samples, download and install (extract) OpenCV from [download OpenCV](https://github.com/opencv/opencv/releases/download/4.7.0/opencv-4.7.0-windows.exe). Then copy the OpenCV DLL from "path\to\opencv\build\x64\vc16\bin" next to the application exe: opencv_world470.dll for a Release build, or opencv_world470d.dll for a Debug build.
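
For example (the OpenCV and output paths below are placeholders for your local layout):

```bat
:: Debug build
copy "path\to\opencv\build\x64\vc16\bin\opencv_world470d.dll" "path\to\folder\containing\run_squeezenet.exe"
:: Release build
copy "path\to\opencv\build\x64\vc16\bin\opencv_world470.dll" "path\to\folder\containing\run_squeezenet.exe"
```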

#### Run the sample

- To run the general sample
@@ -96,13 +84,9 @@
```
run_squeezenet.exe --use_cpu <path_to_onnx_model> <path_to_sample_image> <path_to_labels_file>
```
- To Run the sample for IO Buffer Optimization feature
```
run_squeezenet.exe <path_to_onnx_model> <path_to_sample_image> <path_to_labels_file>
```

## References:

[OpenVINO Execution Provider](https://www.intel.com/content/www/us/en/artificial-intelligence/posts/faster-inferencing-with-one-line-of-code.html)
[OpenVINO Execution Provider](https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html)

[Other ONNXRT Reference Samples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx)
93 changes: 68 additions & 25 deletions c_cxx/OpenVINO_EP/Windows/model-explorer/model-explorer.cpp
@@ -25,7 +25,7 @@
#include <iostream>
#include <sstream>
#include <vector>
#include <experimental_onnxruntime_cxx_api.h>
#include <onnxruntime_cxx_api.h>

// pretty prints a shape dimension vector
std::string print_shape(const std::vector<int64_t>& v) {
@@ -64,59 +64,102 @@ int main(int argc, char** argv) {
//Appending OpenVINO Execution Provider API
#ifdef USE_OPENVINO
// Using OPENVINO backend
OrtOpenVINOProviderOptions options;
options.device_type = "CPU";
std::cout << "OpenVINO device type is set to: " << options.device_type << std::endl;
session_options.AppendExecutionProvider_OpenVINO(options);
std::unordered_map<std::string, std::string> options;
options["device_type"] = "CPU";
std::cout << "OpenVINO device type is set to: " << options["device_type"] << std::endl;
session_options.AppendExecutionProvider_OpenVINO_V2(options);
#endif
Ort::Experimental::Session session = Ort::Experimental::Session(env, model_file, session_options); // access experimental components via the Experimental namespace

// print name/shape of inputs
std::vector<std::string> input_names = session.GetInputNames();
std::vector<std::vector<int64_t> > input_shapes = session.GetInputShapes();
cout << "Input Node Name/Shape (" << input_names.size() << "):" << endl;
for (size_t i = 0; i < input_names.size(); i++) {
Ort::Session session(env, model_file.c_str(), session_options);
Ort::AllocatorWithDefaultOptions allocator;

size_t num_input_nodes = session.GetInputCount();
std::vector<std::string> input_names;
std::vector<std::vector<int64_t>> input_shapes;

cout << "Input Node Name/Shape (" << num_input_nodes << "):" << endl;
for (size_t i = 0; i < num_input_nodes; i++) {
// Get input name
auto input_name = session.GetInputNameAllocated(i, allocator);
input_names.push_back(std::string(input_name.get()));

// Get input shape
Ort::TypeInfo input_type_info = session.GetInputTypeInfo(i);
auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> input_dims = input_tensor_info.GetShape();
input_shapes.push_back(input_dims);

cout << "\t" << input_names[i] << " : " << print_shape(input_shapes[i]) << endl;

}

// print name/shape of outputs
std::vector<std::string> output_names = session.GetOutputNames();
std::vector<std::vector<int64_t> > output_shapes = session.GetOutputShapes();
cout << "Output Node Name/Shape (" << output_names.size() << "):" << endl;
for (size_t i = 0; i < output_names.size(); i++) {
size_t num_output_nodes = session.GetOutputCount();
std::vector<std::string> output_names;
std::vector<std::vector<int64_t>> output_shapes;

cout << "Output Node Name/Shape (" << num_output_nodes << "):" << endl;
for (size_t i = 0; i < num_output_nodes; i++) {
// Get output name
auto output_name = session.GetOutputNameAllocated(i, allocator);
output_names.push_back(std::string(output_name.get()));

// Get output shape
Ort::TypeInfo output_type_info = session.GetOutputTypeInfo(i);
auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> output_dims = output_tensor_info.GetShape();
output_shapes.push_back(output_dims);

cout << "\t" << output_names[i] << " : " << print_shape(output_shapes[i]) << endl;

}

// Assume model has 1 input node and 1 output node.

assert(input_names.size() == 1 && output_names.size() == 1);

// Create a single Ort tensor of random numbers
auto input_shape = input_shapes[0];
int total_number_elements = calculate_product(input_shape);
std::vector<float> input_tensor_values(total_number_elements);
std::generate(input_tensor_values.begin(), input_tensor_values.end(), [&] { return rand() % 255; }); // generate random numbers in the range [0, 255]
std::generate(input_tensor_values.begin(), input_tensor_values.end(), [&] { return rand() % 255; });

// Create input tensor
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> input_tensors;
input_tensors.push_back(Ort::Experimental::Value::CreateTensor<float>(input_tensor_values.data(), input_tensor_values.size(), input_shape));
input_tensors.push_back(Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(),
input_tensor_values.size(), input_shape.data(),
input_shape.size()));

// double-check the dimensions of the input tensor
assert(input_tensors[0].IsTensor() &&
input_tensors[0].GetTensorTypeAndShapeInfo().GetShape() == input_shape);
cout << "\ninput_tensor shape: " << print_shape(input_tensors[0].GetTensorTypeAndShapeInfo().GetShape()) << endl;

// Create input/output name arrays for Run()
std::vector<const char*> input_names_char(input_names.size(), nullptr);
std::vector<const char*> output_names_char(output_names.size(), nullptr);

for (size_t i = 0; i < input_names.size(); i++) {
input_names_char[i] = input_names[i].c_str();
}
for (size_t i = 0; i < output_names.size(); i++) {
output_names_char[i] = output_names[i].c_str();
}

// pass data through model
cout << "Running model...";
try {
auto output_tensors = session.Run(session.GetInputNames(), input_tensors, session.GetOutputNames());
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names_char.data(),
input_tensors.data(), input_names_char.size(),
output_names_char.data(), output_names_char.size());
cout << "done" << endl;

// double-check the dimensions of the output tensors
// NOTE: the number of output tensors is equal to the number of output nodes specified in the Run() call
assert(output_tensors.size() == session.GetOutputNames().size() &&
output_tensors[0].IsTensor());
assert(output_tensors.size() == output_names.size() && output_tensors[0].IsTensor());
cout << "output_tensor_shape: " << print_shape(output_tensors[0].GetTensorTypeAndShapeInfo().GetShape()) << endl;

} catch (const Ort::Exception& exception) {
cout << "ERROR running model inference: " << exception.what() << endl;
exit(-1);
}

return 0;

}
@@ -13,13 +13,14 @@ if(OPENCV_LIBDIR)
endif()

# In the onnxruntime default install path, the required dlls are in the lib and bin folders
set(DLL_DIRS "${ONNXRUNTIME_ROOTDIR}/lib;${ONNXRUNTIME_ROOTDIR}/bin")
set(DLL_DIRS "${ONNXRUNTIME_ROOTDIR}/lib;${ONNXRUNTIME_ROOTDIR}/bin;${OPENCV_ROOTDIR}/x64/vc16/bin")

foreach(DLL_DIR IN LISTS DLL_DIRS)
file(GLOB ALL_DLLS ${DLL_DIR}/*.dll)
foreach(ORTDll IN LISTS ALL_DLLS)
foreach(DLLFile IN LISTS ALL_DLLS)
add_custom_command(TARGET run_squeezenet POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${ORTDll}"
$<TARGET_FILE_DIR:run_squeezenet>)
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${DLLFile}"
$<TARGET_FILE_DIR:run_squeezenet>)
endforeach()
endforeach()
@@ -218,10 +218,16 @@ int main(int argc, char* argv[])
//Appending OpenVINO Execution Provider API
if (useOPENVINO) {
// Using OPENVINO backend
OrtOpenVINOProviderOptions options;
options.device_type = "CPU";
std::cout << "OpenVINO device type is set to: " << options.device_type << std::endl;
sessionOptions.AppendExecutionProvider_OpenVINO(options);
std::unordered_map<std::string, std::string> options;
options["device_type"] = "CPU";
std::string config = R"({
"CPU": {
"INFERENCE_NUM_THREADS": "1"
}
})";
options["load_config"] = config;
std::cout << "OpenVINO device type is set to: " << options["device_type"] << std::endl;
sessionOptions.AppendExecutionProvider_OpenVINO_V2(options);
}

// Sets graph optimization level
2 changes: 1 addition & 1 deletion python/OpenVINO_EP/tiny_yolo_v2_object_detection/README.md
@@ -45,7 +45,7 @@ python3 tiny_yolov2_obj_detection_sample.py --h
```
## Running the ONNXRuntime OpenVINO™ Execution Provider sample
```bash
python3 tiny_yolov2_obj_detection_sample.py --video face-demographics-walking-and-pause.mp4 --model tinyyolov2.onnx --device CPU_FP32
python3 tiny_yolov2_obj_detection_sample.py --video face-demographics-walking-and-pause.mp4 --model tinyyolov2.onnx --device CPU
```

## To stop the sample from running