From 25c43c1fd66de00d3e7ebe58ecd4b37e060e5be5 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Fri, 2 May 2025 21:38:54 -0700 Subject: [PATCH 01/84] K quant (#24615) ### Description Integrate some neural compressor code since the ORT side in the repo is in maintenance mode. ### Motivation and Context Enable k-quant quantization. --- cmake/onnxruntime_python.cmake | 7 + .../quantization/matmul_nbits_quantizer.py | 49 +- .../neural_compressor/__init__.py | 1 + .../neural_compressor/onnx_model.py | 1264 +++++++++++++++++ .../quantization/neural_compressor/util.py | 80 ++ .../neural_compressor/weight_only.py | 932 ++++++++++++ .../quantization/test_op_matmul_4bits.py | 4 + setup.py | 1 + 8 files changed, 2330 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/python/tools/quantization/neural_compressor/__init__.py create mode 100644 onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py create mode 100644 onnxruntime/python/tools/quantization/neural_compressor/util.py create mode 100644 onnxruntime/python/tools/quantization/neural_compressor/weight_only.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c57a2a962303d..8f7a96e052fa1 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -468,6 +468,9 @@ file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py" ) +file(GLOB onnxruntime_python_quantization_neural_compressor_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/quantization/neural_compressor/*.py" +) file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) @@ -581,6 +584,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/fusions COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/execution_providers/qnn + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/neural_compressor COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models @@ -660,6 +664,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_quantization_ep_qnn_src} $/onnxruntime/quantization/execution_providers/qnn/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_quantization_neural_compressor_src} + $/onnxruntime/quantization/neural_compressor/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_src} $/onnxruntime/transformers/ diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index b1d58b713eea8..ef08b56cfe7ad 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -21,6 +21,7 @@ from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_matmul_8bits, quantize_qdq_matmul_4bits from .calibrate import CalibrationDataReader +from .neural_compressor import gptq_quantize, rtn_quantize from .onnx_model import ONNXModel from .quant_utils import QuantFormat, attribute_to_kwarg @@ -98,6 +99,40 @@ def __init__( self.ratios = ratios 
+class KQuantWeightOnlyQuantConfig(WeightOnlyQuantConfig):
+    def __init__(
+        self,
+        ratios=None,
+        quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
+        customized_weight_config: dict | None = None,
+    ):
+        """
+        This is a class for k-quant algorithm Weight Only Quant Configuration.
+
+        Args:
+            ratios:
+                percentile of clip. Defaults to {}.
+            quant_format (QuantFormat{QOperator, QDQ}, optional):
+                QOperator format quantizes the model with quantized operators directly.
+                QDQ format quantizes the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
+                Defaults to QuantFormat.QOperator.
+            op_types_to_quantize (optional):
+                set of operator types to quantize.
+        """
+        assert quant_format == QuantFormat.QOperator, "k-quant only supports QOperator format"
+
+        if ratios is None:
+            ratios = {}
+        super().__init__(
+            algorithm="k_quant",
+            quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
+            customized_weight_config=customized_weight_config,
+        )
+        self.ratios = ratios
+
+
 class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
     def __init__(
         self,
@@ -1258,14 +1293,12 @@ def inc_dataloader():
         algorithm = self.algo_config.algorithm
         logger.info(f"start to quantize model with {algorithm} algorithm...")
 
-        if algorithm == "RTN":
-            from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
-
+        if algorithm in ["RTN", "k_quant"]:
             kwargs["ratios"] = self.algo_config.ratios
+            kwargs["algorithm"] = algorithm
 
             """
-            neural-compressor uses fp32 to represent the node that skip quantization, it does not mean this node is fp32 type though.
-            https://github.com/intel/neural-compressor/blob/a617115b1490bbe6163c0024fb55bd260c8914df/neural_compressor/adaptor/ox_utils/weight_only.py#L343
+            We use fp32 to represent nodes that skip quantization; it does not mean these nodes are fp32 type though.
""" for n in self.nodes_to_exclude: weight_only_node_config[n] = "fp32" @@ -1276,8 +1309,6 @@ def inc_dataloader(): **kwargs, ) elif algorithm == "GPTQ": - from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize - kwargs["percdamp"] = self.algo_config.percdamp kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder @@ -1370,7 +1401,7 @@ def parse_args(): "--quant_method", default="default", type=str, - choices=["default", "hqq", "rtn", "gptq", "nvidia_awq"], + choices=["default", "hqq", "rtn", "k_quant", "gptq", "nvidia_awq"], help="the algorithm used to quantize weight, \nrtn and gptq leverage IntelĀ® Neural Compressor", ) parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") @@ -1500,6 +1531,8 @@ def parse_args(): ) elif args.quant_method == "rtn": quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize) + elif args.quant_method == "k_quant": + quant_config = KQuantWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize) elif args.quant_method == "gptq": quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize) elif args.quant_method == "nvidia_awq": diff --git a/onnxruntime/python/tools/quantization/neural_compressor/__init__.py b/onnxruntime/python/tools/quantization/neural_compressor/__init__.py new file mode 100644 index 0000000000000..08b9a38624c98 --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/__init__.py @@ -0,0 +1 @@ +from .weight_only import gptq_quantize, rtn_quantize # noqa: F401 diff --git a/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py b/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py new file mode 100644 index 0000000000000..f931045c4e349 --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/onnx_model.py @@ -0,0 +1,1264 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Class for ONNX model.""" + +import logging +import os +import sys +from pathlib import Path + +import onnx + +from .util import MAXIMUM_PROTOBUF, find_by_name + +logger = logging.getLogger("neural_compressor") + +# TODO: Check https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/onnx_model.py to see if we can integrate with it. + + +class ONNXModel: + """Build ONNX model.""" + + def __init__(self, model, **kwargs): + """Initialize an ONNX model. + + Args: + model (str or ModelProto): path to onnx model or loaded ModelProto model object. + ignore_warning (bool): ignore large model warning. Default is False. + load_external_data (bool): load external data for large model. Default is True. 
+ """ + self._model = model if not isinstance(model, str) else onnx.load(model, load_external_data=False) + self._model_path = None if not isinstance(model, str) else model + + self.check_is_large_model() + if self._is_large_model and self._model_path is None and not kwargs.get("ignore_warning", False): + logger.warning("Model size > 2GB. Please use model path instead of onnx model object to quantize") + + if self._is_large_model and isinstance(model, str) and kwargs.get("load_external_data", True): + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.dirname(self._model_path)) + + self._config = None + if isinstance(model, str) and os.path.exists(Path(model).parent.joinpath("config.json").as_posix()): + from transformers import AutoConfig + + self._config = AutoConfig.from_pretrained(Path(model).parent.as_posix()) + + self.node_name_counter = {} + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + self._graph_info = {} + self._get_graph_info() + self._q_config = None + + def check_is_large_model(self): + """Check model > 2GB.""" + init_size = 0 + for init in self._model.graph.initializer: + # if initializer has external data location, return True + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + self._is_large_model = True + return + # if raise error of initializer size > 2GB, return True + try: + init_bytes = init.SerializeToString() + init_size += sys.getsizeof(init_bytes) + except Exception as e: + if "exceeds maximum protobuf size of 2GB" in str(e): + self._is_large_model = True + return + else: # pragma: no cover + raise e + if init_size > MAXIMUM_PROTOBUF: + self._is_large_model = True + return + self._is_large_model = False + + @property + def is_large_model(self): + """Check the onnx model is over 2GB.""" + return self._is_large_model + + @property + def model_path(self): + """Return model path.""" + return self._model_path + + @model_path.setter + def model_path(self, path): + """Set model path.""" + self._model_path = path + + def framework(self): + """Return framework.""" + return "onnxruntime" + + @property + def q_config(self): + """Return q_config.""" + return self._q_config + + @q_config.setter + def q_config(self, q_config): + """Set q_config.""" + self._q_config = q_config + + @property + def hf_config(self): + """Return huggingface config if model is Transformer-based.""" + return self._config + + @property + def model(self): + """Return model itself.""" + return self._model + + @model.setter + def model(self, model): + """Set model itself.""" + self._model = model + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + def input(self): + """Return input of model.""" + return [i.name for i in self._model.graph.input] + + def output(self): + """Return output of model.""" + return [i.name for i in self._model.graph.output] + + def update(self): + """Update model info.""" + self._graph_info = {} + self._get_graph_info() + self._output_name_to_node = {} + self._input_name_to_nodes = {} + self._get_input_name_to_nodes(self._model.graph.node) + self._get_output_name_to_node(self._model.graph.node) + + @property + def graph_info(self): + """Return ORT Graph Info object 
holding information about backend graph.""" + return self._graph_info + + def _get_graph_info(self): + """Update graph info.""" + for node in self._model.graph.node: + self.graph_info.update({node.name: node.op_type}) + + def save(self, root): + """Save ONNX model.""" + if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]): + raise ValueError('"root" directory does not exists.') + if self.is_large_model: # pragma: no cover + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(self._model, os.path.split(self._model_path)[0]) + onnx.save_model( + self._model, + root, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=root.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + else: + onnx.save(self._model, root) + + if self._config is not None: + model_type = "" if not hasattr(self._config, "model_type") else self._config.model_type + self._config.__class__.model_type = model_type + output_config_file = Path(root).parent.joinpath("config.json").as_posix() + self._config.to_json_file(output_config_file, use_diff=False) + + def nodes(self): + """Return model nodes.""" + return self._model.graph.node + + def initializer(self): + """Return model initializer.""" + return self._model.graph.initializer + + def graph(self): + """Return model graph.""" + return self._model.graph + + def ir_version(self): + """Return model ir_version.""" + return self._model.ir_version + + def opset_import(self): + """Return model opset_import.""" + return self._model.opset_import + + def remove_node(self, node): + """Remove a node from model.""" + if node in self._model.graph.node: + self._model.graph.node.remove(node) + + def remove_nodes(self, nodes_to_remove): + """Remove nodes from model.""" + for node in nodes_to_remove: + self.remove_node(node) + + def add_node(self, node): + """Add a node to model.""" + self._model.graph.node.extend([node]) + + def add_nodes(self, nodes_to_add): + """Add nodes to model.""" + self._model.graph.node.extend(nodes_to_add) + + def add_initializer(self, tensor): + """Add a initializer to model.""" + if find_by_name(tensor.name, self._model.graph.initializer) is None: + self._model.graph.initializer.extend([tensor]) + + def add_initializers(self, tensors): + """Add initializers to model.""" + for tensor in tensors: + self.add_initializer(tensor) + + def get_initializer(self, name): + """Get an initializer by name.""" + for tensor in self._model.graph.initializer: + if tensor.name == name: + return tensor + return None + + def get_initializer_share_num(self, name): + """Get the number of shares of initializer.""" + num = 0 + if self.get_initializer(name) is None: + return num + + for node in self.nodes(): + if name in node.input: + num += 1 + return num + + def get_node(self, name): + """Get a node by name.""" + for node in self._model.graph.node: + if node.name == name: + return node + return None + + def remove_initializer(self, tensor): + """Remove an initializer from model.""" + if tensor in self._model.graph.initializer: + self._model.graph.initializer.remove(tensor) + + def remove_initializers(self, init_to_remove): + """Remove initializers from model.""" + for initializer in init_to_remove: + self.remove_initializer(initializer) + + def set_initializer(self, tensor, array, raw=False): + """Update initializer.""" + old_tensor = self.get_initializer(tensor) + self.remove_initializer(old_tensor) + dims = old_tensor.dims + data_type = old_tensor.data_type + 
new_tensor = ( + onnx.helper.make_tensor(tensor, data_type, dims, array.flatten().tolist()) + if not raw + else onnx.helper.make_tensor(tensor, data_type, dims, array.tostring(), raw=raw) + ) + self.add_initializer(new_tensor) + + @property + def input_name_to_nodes(self): + """Return input names of nodes.""" + return self._input_name_to_nodes + + def _get_input_name_to_nodes(self, nodes): + """Get input names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_input_name_to_nodes(attr.g.node) + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in self._input_name_to_nodes: + self._input_name_to_nodes[input_name] = [node] + else: + self._input_name_to_nodes[input_name].append(node) + + @property + def output_name_to_node(self): + """Return output names of nodes.""" + return self._output_name_to_node + + def _get_output_name_to_node(self, nodes): + """Get output names of nodes.""" + for node in nodes: + attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(attrs) > 0: + for attr in attrs: + self._get_output_name_to_node(attr.g.node) + for output_name in node.output: + if len(output_name.strip()) != 0: + self._output_name_to_node[output_name] = node + + def get_siblings(self, node): + """Get siblings nodes.""" + siblings = [] + for parent in self.get_parents(node): + for child in self.get_children(parent): + if child.name != node.name: + siblings.append(child) + return siblings + + def get_children(self, node, input_name_to_nodes=None): + """Get children nodes.""" + if input_name_to_nodes is None: + input_name_to_nodes = self._input_name_to_nodes + + children = [] + for output in node.output: + if output in input_name_to_nodes: + for child in input_name_to_nodes[output]: + children.append(child) # noqa: PERF402 + return children + + def get_parents(self, node, output_name_to_node=None): + """Get parents nodes.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + parents = [] + for input in node.input: + if input in output_name_to_node: + parents.append(output_name_to_node[input]) + return parents + + def get_parent(self, node, idx, output_name_to_node=None): + """Get parent node by idx.""" + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if len(node.input) <= idx: + return None + + input = node.input[idx] + if input not in output_name_to_node: + return None + + return output_name_to_node[input] + + def find_node_by_name(self, node_name, new_nodes_list, graph): + """Find out node by name.""" + graph_nodes_list = list(graph.node) # deep copy + graph_nodes_list.extend(new_nodes_list) + node = find_by_name(node_name, graph_nodes_list) + return node + + def find_nodes_by_initializer(self, graph, initializer): + """Find all nodes with given initializer as an input.""" + nodes = [] + for node in graph.node: + for node_input in node.input: + if node_input == initializer.name: + nodes.append(node) + return nodes + + def get_scale_zero(self, tensor): + """Help function to get scale and zero_point.""" + if not tensor.endswith("_quantized"): + logger.debug(f"Find {tensor} in the quantized graph is not quantized.") + return None, None + + def _searcher(tensor_name): + """Search scale and zero point tensor recursively.""" + node = 
self._input_name_to_nodes[tensor_name][0] + parent = self._output_name_to_node.get(tensor_name, None) + direct_int8 = ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "MaxPool", "Pad", "Split"] + if parent is not None and parent.op_type in direct_int8: + fp32_tensor_name = ( + parent.input[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + elif node.op_type in ["Gather"]: # pragma: no cover + fp32_tensor_name = ( + node.output[0] + .replace("_quantized", "") + .replace("_QuantizeLinear", "") + .replace("_QuantizeInput", "") + ) + else: + fp32_tensor_name = ( + tensor_name.replace("_quantized", "").replace("_QuantizeLinear", "").replace("_QuantizeInput", "") + ) + scale = fp32_tensor_name + "_scale" + scale_tensor = self.get_initializer(scale) + zo = fp32_tensor_name + "_zero_point" + zo_tensor = self.get_initializer(zo) + + if scale_tensor is None or zo_tensor is None: + if parent is not None: + scale_tensor, zo_tensor = _searcher(parent.input[0]) + return scale_tensor, zo_tensor + + node = self._input_name_to_nodes[tensor][0] + # TODO check if scale_tensor and zero_point is needed + # for bias of qlinearconv, scale and zero_point is not needed + if (node.op_type == "QLinearConv" and tensor == node.input[-1]) or ( + node.op_type == "QGemm" and tensor == node.input[-3] + ): + return None, None + else: + scale_tensor, zo_tensor = _searcher(tensor) + assert scale_tensor, f"missing scale for tensor {tensor}" + assert zo_tensor, f"missing zero point for tensor {tensor}" + return scale_tensor, zo_tensor + + def save_model_to_file(self, output_path, use_external_data_format=False): + """Save model to external data, which is needed for model size > 2GB.""" + from onnx.external_data_helper import convert_model_to_external_data + + if use_external_data_format: + convert_model_to_external_data( + self._model, all_tensors_to_one_file=True, location=Path(output_path).name + ".data" + ) + onnx.save_model(self._model, output_path) + + @staticmethod + def replace_node_input(node, old_input_name, new_input_name): + """Replace input of a node.""" + assert isinstance(old_input_name, str) and isinstance(new_input_name, str) + for j in range(len(node.input)): + if node.input[j] == old_input_name: + node.input[j] = new_input_name + + def replace_input_of_all_nodes(self, old_input_name, new_input_name, white_optype=None, black_optype=None): + """Replace inputs of all nodes.""" + if white_optype is None: + white_optype = [] + if black_optype is None: + black_optype = [] + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_input(node, old_input_name, new_input_name) + + @staticmethod + def replace_node_output(node, old_output_name, new_output_name): + """Replace output of a node.""" + assert isinstance(old_output_name, str) and isinstance(new_output_name, str) + for j in range(len(node.output)): + if node.output[j] == old_output_name: + node.output[j] = new_output_name + + def replace_output_of_all_nodes(self, old_output_name, new_output_name, white_optype=None, black_optype=None): + """Replace outputs of all nodes.""" + if white_optype is None: + white_optype = [] + if black_optype is None: + black_optype = [] + if len(white_optype) > 0: + for node in self.model.graph.node: + if node.op_type in white_optype: + 
ONNXModel.replace_node_output(node, old_output_name, new_output_name) + else: + for node in self.model.graph.node: + if node.op_type not in black_optype: + ONNXModel.replace_node_output(node, old_output_name, new_output_name) + + def remove_unused_nodes(self): + """Remove unused nodes.""" + unused_nodes = [] + nodes = self.nodes() + for node in nodes: + if ( + node.op_type == "Constant" + and node.output[0] not in self._model.graph.output + and node.output[0] not in self._input_name_to_nodes + ): + unused_nodes.append(node) + elif ( + node.op_type == "QuantizeLinear" + and len(self.get_children(node)) == 1 + and self.get_children(node)[0].op_type == "DequantizeLinear" + and node.input[0] not in self._output_name_to_node + and self.get_children(node)[0].output[0] not in self._input_name_to_nodes + ): + unused_nodes.append(node) + unused_nodes.extend(self.get_children(node)) + else: + # remove the node if it does not serve as the input or output of any other nodes + unused = True + for output in node.output: + if output in self._input_name_to_nodes or output in self.output(): + unused = False + break + for input in node.input: + if self.get_initializer(input) is not None: + continue + elif input in self._output_name_to_node or input in self.input(): + unused = False + break + if unused: + unused_nodes.append(node) + self.remove_nodes(unused_nodes) + + ununsed_weights = [] + for w in self._model.graph.initializer: + if w.name not in self._input_name_to_nodes and w.name not in self._model.graph.output: + ununsed_weights.append(w) + # Remove from graph.input + for graph_input in self.graph().input: + if graph_input.name == w.name: + self.graph().input.remove(graph_input) + + self.remove_initializers(ununsed_weights) + self.update() + + def topological_sort(self, enable_subgraph=False): + """Topological sort the model.""" + import copy + from collections import deque + + if not enable_subgraph: + input_name_to_nodes = {} + output_name_to_node = {} + for node in self.model.graph.node: + for input_name in node.input: + if len(input_name.strip()) != 0: + if input_name not in input_name_to_nodes: + input_name_to_nodes[input_name] = [node] + else: + input_name_to_nodes[input_name].append(node) + for output_name in node.output: + if len(output_name.strip()) != 0: + output_name_to_node[output_name] = node + else: # pragma: no cover + input_name_to_nodes = self._input_name_to_nodes + output_name_to_node = self._output_name_to_node + + all_nodes = {} + q = deque() + wait = deque() + for inp in self.model.graph.input: + q.extend(input_name_to_nodes[inp.name]) + for n in self.model.graph.node: + if all(i not in output_name_to_node and i not in self.input() for i in n.input): + q.append(n) + + while q: + n = q.popleft() + if not all(output_name_to_node[i].name in all_nodes for i in n.input if i in output_name_to_node): + if n not in wait: + wait.append(n) + continue + + all_nodes[n.name] = n + for out in n.output: + if out in input_name_to_nodes: + q.extend([i for i in input_name_to_nodes[out] if i.name not in all_nodes and i not in q]) + if len(q) == 0 and len(wait) != 0: + q = copy.deepcopy(wait) + wait.clear() + nodes = [i[1] for i in all_nodes.items()] + assert len(list({n.name for n in nodes})) == len(list({n.name for n in self.model.graph.node})) + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes) + + def get_nodes_chain(self, start, stop, result_chain=None): + """Get nodes chain with given start node and stop node.""" + from collections import deque + + from onnx import 
NodeProto + + if result_chain is None: + result_chain = [] + # process start node list + start_node = deque() + for node in start: + if isinstance(node, str): + start_node.append(node) + elif isinstance(node, NodeProto): + start_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 + + # process stop node list + stop_node = [] + for node in stop: + if isinstance(node, str): + stop_node.append(node) + elif isinstance(node, NodeProto): + stop_node.append(node.name) + else: + assert False, "'get_nodes_chain' function only support list[string]or list[NodeProto] params" # noqa: B011 + + while start_node: + node_name = start_node.popleft() + if node_name in stop_node: + continue + if node_name not in result_chain: + result_chain.append(node_name) + else: + continue + + node = find_by_name(node_name, list(self.model.graph.node)) + for parent in self.get_parents(node): + start_node.append(parent.name) + + return result_chain + + def find_split_node_for_layer_wise_quantization(self): + """Find split node for layer wise quantization.""" + # find split nodes of decoder blocks + # embed -> decoder.0 -(split_node)-> ... -(split_node)-> decoder.n -(split_node)-> norm -> head + # after split: embed -> decoder.0, + # decoder.1, + # decoder.2, + # ..., + # decoder.n, + # norm -> head + start_nodes = [] + for node in self._model.graph.node: + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "Mul", "Div", "Add"], + [None, 0, None, 0, None, 0], + ), + self.match_parent_path( + start_node, + ["MatMul", "Mul", "MatMul", "SimplifiedLayerNormalization", "Add"], + [None, 0, None, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + start_nodes.append(start_node) + return start_nodes + + def find_qkv_in_attention(self, find_all=False): + """Find qkv MatMul in Attention. + + Args: + find_all (bool, optional): find all qkv MatMul. 
Defaults to False + + Returns: + qkv (list): qkv MatMul list + """ + qkv = [] + for node in self._model.graph.node: + if node.op_type == "Attention": + qkv.append([node.name]) + continue + start_node, qkv_nodes_list = None, None + if node.op_type == "SkipLayerNormalization": + start_node = node + qkv_nodes_list = [ + self.match_parent_path( + start_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ), + ] + if node.op_type == "Add": + start_node = node + qkv_nodes_list = [ + # match base attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0], + ), + self.match_parent_path( + start_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], [1, None, 0, 0, 0] + ), + # match gpt attention no past structure + self.match_parent_path( + start_node, + ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + output_name_to_node=self.output_name_to_node, + return_indice=[], + ), + # match bart attention structure + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ), + self.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, None, 0, 0, 0, 0], + ), + ] + if not start_node: + continue + if not any(qkv_nodes_list): + continue + qkv_nodes = [qkv for qkv in qkv_nodes_list if qkv is not None][-1] + other_inputs = [] + for input in start_node.input: + if input not in self.output_name_to_node: + continue + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + continue + root_input = other_inputs[0] + input_name_to_nodes = self.input_name_to_nodes + children = input_name_to_nodes[root_input] + children_types = [child.op_type for child in children] + if children_types.count("MatMul") == 3: + qkv.append([child.name for child in children if child.op_type == "MatMul"]) + if not find_all: + break + return qkv + + def find_ffn_matmul(self, attention_index, attention_matmul_list, block_len): + """Find MatMul in FFN. 
+ + Args: + attention_index (list): index of Attention + attention_matmul_list (list): list of Attention and MatMul nodes + block_len (int): block length + + Returns: + list: list of MatMul in FFN + """ + ffn_matmul = [] + for idx in range(len(attention_index)): + if idx != len(attention_index) - 1: + index = attention_index[idx + 1] + if index - 2 >= 0: + ffn_matmul.append([attention_matmul_list[index - 2], attention_matmul_list[index - 1]]) + else: + index = attention_index[idx] + if index + block_len - 1 < len(attention_matmul_list): + ffn_matmul.append( + [attention_matmul_list[index + block_len - 2], attention_matmul_list[index + block_len - 1]] + ) + return ffn_matmul + + def export(self, save_path, conf): + """Export Qlinear to QDQ model.""" + from neural_compressor.config import ONNXQlinear2QDQConfig + from neural_compressor.utils.export import onnx_qlinear_to_qdq + + if isinstance(conf, ONNXQlinear2QDQConfig): + add_nodes, remove_nodes, inits = onnx_qlinear_to_qdq(self._model, self._input_name_to_nodes) + self.add_nodes(add_nodes) + self.remove_nodes(remove_nodes) + self.add_initializers(inits) + self.update() + self.remove_unused_nodes() + self.topological_sort() + self.save(save_path) + else: + logger.warning("Unsupported config for export, only ONNXQlinear2QDQConfig is supported!") + exit(0) + + def add_tensors_to_outputs(self, tensor_names): + """Add the tensors to the model outputs to gets their values. + + Args: + tensor_names: The names of tensors to be dumped. + """ + added_outputs = [] + for tensor in tensor_names: + if tensor not in self.output(): + added_tensor = onnx.helper.ValueInfoProto() + added_tensor.name = tensor + added_outputs.append(added_tensor) + self._model.graph.output.extend(added_outputs) # pylint: disable=no-member + + def remove_tensors_from_outputs(self, tensor_names): + """Remove the tensors from the model outputs. + + Args: + tensor_names: The names of tensors to be removed. + """ + removed_outputs = [] + for tensor in tensor_names: + if tensor in self.output(): + removed_outputs.append(self._model.graph.output[self.output().index(tensor)]) + for output in removed_outputs: + self._model.graph.output.remove(output) + + def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=None): + """Find parent node based on constraints on op_type. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + output_name_to_node (dict): dictionary with output name as key, and node as value. + exclude (list): list of nodes that are excluded (not allowed to match as parent). + + Returns: + parent: The matched parent node. None if not found. + index: The input index of matched parent node. None if not found. + """ + if exclude is None: + exclude = [] + for i, input in enumerate(node.input): + if input in output_name_to_node: + parent = output_name_to_node[input] + if parent.op_type == parent_op_type and parent not in exclude: + return parent, i + return None, None + + def match_parent( + self, + node, + parent_op_type, + input_index=None, + output_name_to_node=None, + exclude=None, + return_indice=None, + ): + """Find parent node based on constraints on op_type and index. + + Args: + node (str): current node name. + parent_op_type (str): constraint of parent node op_type. + input_index (int or None): only check the parent given input index of current node. + output_name_to_node (dict): dictionary with output name as key, and node as value. 
+ exclude (list): list of nodes that are excluded (not allowed to match as parent). + return_indice (list): a list to append the input index when input_index is None. + + Returns: + parent: The matched parent node. + """ + assert node is not None + assert input_index is None or input_index >= 0 + if exclude is None: + exclude = [] + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + if input_index is None: + parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude) + if return_indice is not None: + return_indice.append(index) + return parent + + if input_index >= len(node.input): + return None + + parent = self.get_parent(node, input_index, output_name_to_node) + if parent is not None and parent.op_type == parent_op_type and parent not in exclude: + return parent + + return None + + def match_parent_path( + self, + node, + parent_op_types, + parent_input_index, + output_name_to_node=None, + return_indice=None, + ): + """Find a sequence of input edges based on constraints on parent op_type and index. + + Args: + node (str): current node name. + parent_op_types (str): constraint of parent node op_type of each input edge. + parent_input_index (list): constraint of input index of each input edge. + None means no constraint. + output_name_to_node (dict): dictionary with output name as key, and node as value. + return_indice (list): a list to append the input index when there is + no constraint on input index of an edge. + + Returns: + parents: a list of matched parent node. + """ + assert len(parent_input_index) == len(parent_op_types) + + if output_name_to_node is None: + output_name_to_node = self._output_name_to_node + + current_node = node + matched_parents = [] + for i, op_type in enumerate(parent_op_types): + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i], + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) + if matched_parent is None: + return None + + matched_parents.append(matched_parent) + current_node = matched_parent + + return matched_parents + + def is_smoothquant_model(self): + """Check the model is smooth quantized or not. + + Returns: + bool: the model is smooth quantized or not. + """ + for init in self.model.graph.initializer: # noqa: SIM110 + if "_smooth_scale" in init.name: + return True + return False + + def find_split_nodes(self): + """Find split nodes for layer-wise quantization.""" + split_nodes = self.find_split_node_for_layer_wise_quantization() + return split_nodes + + def split_model_with_node( + self, split_node_name, path_of_model_to_split, shape_infer=True, save_both_split_models=True + ): + """Split model into two parts at a given node. + + Args: + split_node_name (str): name of the node where the model is split at> + path_of_model_to_split (str): path of model to be split. + shape_infer (bool): do shape inference. Default is True. + save_both_split_models (bool): whether to save the two split models. + False means only save the first split model. + True means save both the two split models. + Default id True. + + Returns: + tuple: the first split model, the second split model + """ + # origin model : ... -> node_1 -> split_node -> node_2 -> ... + # split model 1: ... -> node_1 -> split_node + # split model 2: node_2 -> ... 
+ + split_model_part_1 = onnx.ModelProto() + split_model_part_1.CopyFrom(self._model) + split_model_part_1.graph.ClearField("node") + + split_model_part_2 = onnx.ModelProto() + split_model_part_2.CopyFrom(self._model) + split_model_part_2.graph.ClearField("node") + + split_node_output = None + part_idx = 1 + for node in self._model.graph.node: + if part_idx == 1: + split_model_part_1.graph.node.append(node) + elif part_idx == 2: + split_model_part_2.graph.node.append(node) + + if node.name == split_node_name: + split_node_output = node.output + part_idx = 2 + + assert len(split_node_output) == 1, ( + f"Only support split at node with 1 output tensor, while current split node {split_node_name} has {len(split_node_output)} output tensors" + ) + split_tensor_name = split_node_output[0] + + # infer shape of the model to be split + if shape_infer: + try: + from neural_compressor.adaptor.ox_utils.util import infer_shapes + + self._model = infer_shapes(self._model, auto_merge=True, base_dir=os.path.dirname(self._model_path)) + except Exception as e: # pragma: no cover + logger.error( + "Shape infer fails for layer-wise quantization. " + "We would recommend checking the graph optimization level of your model " + "and setting it to 'DISABLE_ALL' or 'ENABLE_BASIC', " + "as this may help avoid this error." + ) + raise e + + split_tensor_type, split_tensor_shape = self._get_output_type_shape_by_tensor_name(split_tensor_name) + split_tensor = onnx.helper.make_tensor_value_info(split_tensor_name, split_tensor_type, split_tensor_shape) + + split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True) + split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True) + + # remove unused input & output + split_model_part_1._remove_unused_input_output() + split_model_part_2._remove_unused_input_output() + + split_model_part_1.model.graph.output.append(split_tensor) + split_model_part_2.model.graph.input.append(split_tensor) + + insert_output_for_model_1 = [] + insert_input_for_model_2 = [] + for output in split_model_part_1.output_name_to_node: + if output in split_model_part_2.input_name_to_nodes: + output_type, output_shape = self._get_output_type_shape_by_tensor_name(output) + output_tensor = onnx.helper.make_tensor_value_info(output, output_type, output_shape) + if output_tensor not in split_model_part_1.model.graph.output: + insert_output_for_model_1.append(output_tensor) + if output_tensor not in split_model_part_2.model.graph.input: + insert_input_for_model_2.append(output_tensor) + + # insert model 1 output + for output in insert_output_for_model_1: + split_model_part_1.model.graph.output.append(output) + + # insert model 2 input + for input in insert_input_for_model_2: + split_model_part_2.model.graph.input.append(input) + + # remove unused init + split_model_part_1.remove_unused_init() + split_model_part_2.remove_unused_init() + + split_model_part_1.update() + split_model_part_2.update() + + dir_of_model_to_split = os.path.dirname(path_of_model_to_split) + + split_model_part_1.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_1_path = os.path.join(dir_of_model_to_split, "split_model_part_1.onnx") + split_model_part_1.model_path = split_model_part_1_path + split_model_part_1._save_split_model(split_model_part_1_path) + split_model_part_1.check_is_large_model() + logger.debug(f"save split model part 1 to {split_model_part_1_path} for layer wise quantization") + + if save_both_split_models: + 
split_model_part_2.load_model_initializer_by_tensor(dir_of_model_to_split) + split_model_part_2_path = os.path.join(dir_of_model_to_split, "split_model_part_2.onnx") + split_model_part_2.model_path = split_model_part_2_path + split_model_part_2._save_split_model(split_model_part_2_path) + split_model_part_2.check_is_large_model() + logger.debug(f"save split model part 2 to {split_model_part_2_path} for layer wise quantization") + return split_model_part_1, split_model_part_2 + else: + return split_model_part_1, split_model_part_2 + + def _save_split_model(self, save_path): + """Save split model as external data for layer wise quantization. + + Args: + save_path (str): the path to save the split model + """ + if os.path.exists(save_path + "_data"): + os.remove(save_path + "_data") + onnx.save_model( + self._model, + save_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=save_path.split("/")[-1] + "_data", + size_threshold=1024, + convert_attribute=False, + ) + + def _get_output_type_shape_by_tensor_name(self, tensor_name): + """Get output type and shape with a tensor name. + + Args: + tensor_name (str): name of a tensor + + Returns: + tuple: output type and shape + """ + elem_type = onnx.TensorProto.FLOAT + shape = None + for output in self._model.graph.value_info: + if output.name == tensor_name: + elem_type = output.type.tensor_type.elem_type + shape = [ + dim.dim_value if dim.HasField("dim_value") else -1 for dim in output.type.tensor_type.shape.dim + ] + break + return elem_type, shape + + def _remove_unused_input_output(self): + """Remove unused input & output for split model.""" + remove_outputs = [] + remove_inputs = [] + for output in self._model.graph.output: + if output.name not in self.output_name_to_node: + remove_outputs.append(output) + + for input in self._model.graph.input: + if input.name not in self.input_name_to_nodes: + remove_inputs.append(input) + + for output in remove_outputs: + self._model.graph.output.remove(output) + for input in remove_inputs: + self._model.graph.input.remove(input) + + def remove_unused_init(self): + """Remove unused init.""" + remov_inits = [] + for init in self._model.graph.initializer: + if init.name not in self.input_name_to_nodes: + remov_inits.append(init) + self.remove_initializers(remov_inits) + + def load_model_initializer_by_tensor(self, data_path=None): + """Load model initializer by tensor. + + Args: + data_path (str, optional): the directory of saved initializer. Defaults to None. + """ + from onnx.external_data_helper import load_external_data_for_tensor + + if data_path is None: + data_path = os.path.dirname(self._model_path) + for init in self._model.graph.initializer: + if init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL: + load_external_data_for_tensor(init, data_path) + + def write_external_data_to_new_location(self, external_data_location="external.data", overwrite=False): + """Write external data of merged quantized model to new location to save memory. + + Args: + external_data_location (str, optional): external data location of merged quantized model. + Defaults to "external.data". + overwrite (bool, optional): if True, remove existed externa data. Defaults to False. 
+ """ + from onnx.external_data_helper import convert_model_to_external_data, write_external_data_tensors + + if overwrite and os.path.exists(os.path.join(os.path.dirname(self._model_path), external_data_location)): + os.remove(os.path.join(os.path.dirname(self._model_path), external_data_location)) + self.load_model_initializer_by_tensor() + convert_model_to_external_data(self._model, location=external_data_location) + # TODO : if init is already saved, skip write it + write_external_data_tensors(self._model, filepath=os.path.dirname(self._model_path)) + + def merge_split_models(self, to_merge_model): + """Merge two split model into final model.""" + to_merge_model.write_external_data_to_new_location() + self.add_nodes(list(to_merge_model.nodes())) + self.add_initializers(list(to_merge_model.initializer())) + self.update() + + # add new output + for output in to_merge_model.graph().output: + if output.name not in self.output(): + self._model.graph.output.append(output) + + # remove unused output + remove_output = [] + for output in self._model.graph.output: + if output.name in to_merge_model.input(): + remove_output.append(output) + for output in remove_output: + self._model.graph.output.remove(output) + + # add new input + for input in to_merge_model.graph().input: + if ( + input.name not in self.input() + and input.name not in self.output() + and input.name not in self.output_name_to_node + ): + self._model.graph.input.append(input) + + def re_org_output(self, origin_output): + """Re-org output of merged model for layer-wise quantization.""" + outputs = {} + tmp_remove = [] + for output in self._model.graph.output: + outputs[output.name] = output + tmp_remove.append(output) + + for output in tmp_remove: + self._model.graph.output.remove(output) + + for out_name in origin_output: + self._model.graph.output.append(outputs[out_name]) diff --git a/onnxruntime/python/tools/quantization/neural_compressor/util.py b/onnxruntime/python/tools/quantization/neural_compressor/util.py new file mode 100644 index 0000000000000..aae01b4defd1f --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/util.py @@ -0,0 +1,80 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Helper classes or functions for onnxrt adaptor.""" + +import importlib +import logging + +import numpy as np + +logger = logging.getLogger("neural_compressor") + + +MAXIMUM_PROTOBUF = 2147483648 + + +def simple_progress_bar(total, i): + """Progress bar for cases where tqdm can't be used.""" + progress = i / total + bar_length = 20 + bar = "#" * int(bar_length * progress) + spaces = " " * (bar_length - len(bar)) + percentage = progress * 100 + print(f"\rProgress: [{bar}{spaces}] {percentage:.2f}%", end="") + + +def find_by_name(name, item_list): + """Helper function to find item by name in a list.""" + items = [] + for item in item_list: + assert hasattr(item, "name"), f"{item} should have a 'name' attribute defined" # pragma: no cover + if item.name == name: + items.append(item) + if len(items) > 0: + return items[0] + else: + return None + + +def to_numpy(data): + """Convert to numpy ndarrays.""" + import torch + + if not isinstance(data, np.ndarray): + if not importlib.util.find_spec("torch"): + logger.error( + "Please install torch to enable subsequent data type check and conversion, " + "or reorganize your data format to numpy array." + ) + exit(0) + if isinstance(data, torch.Tensor): + if data.dtype is torch.bfloat16: # pragma: no cover + return data.detach().cpu().to(torch.float32).numpy() + if data.dtype is torch.chalf: # pragma: no cover + return data.detach().cpu().to(torch.cfloat).numpy() + return data.detach().cpu().numpy() + else: + try: + return np.array(data) + except Exception: + assert False, ( # noqa: B011 + f"The input data for onnx model is {type(data)}, which is not supported to convert to numpy ndarrays." + ) + else: + return data diff --git a/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py new file mode 100644 index 0000000000000..4eda7efc9b8fe --- /dev/null +++ b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py @@ -0,0 +1,932 @@ +# +# The implementation of this file is based on: +# https://github.com/intel/neural-compressor/tree/master/neural_compressor +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modifications: +# Add k-quant quantization method. +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""WeightOnly for onnxrt adaptor.""" + +import copy +import logging +import os +import sys + +import numpy as np +import onnx +from onnx import numpy_helper +from onnx.helper import np_dtype_to_tensor_dtype + +import onnxruntime as ort + +from .onnx_model import ONNXModel +from .util import simple_progress_bar + +logger = logging.getLogger("neural_compressor") + + +def make_matmul_weight_only_node( + node, + weight_shape, + num_bits, + group_size, + k_blocks, + q_weight, + scale, + zero_point, + accuracy_level=0, +): # pragma: no cover + """Build MatMulNBits node. 
+ + Args: + node: original matmul node + weight_shape: original weight shape + num_bits (int): num_bits + group_size (int): how many elements share one scale/zp + k_blocks (int): block number + q_weight (array): quantized weight + scale (array): scale + zero_point (array): zero point + accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8). + + Returns: + matmul_weight_only_node: MatMulNBits node + new_inits: initializers of the new node + """ + blob_size = group_size * num_bits // 8 + packed = np.zeros((q_weight.shape[0], blob_size), dtype="uint8") + q_weight_name = node.input[1] + f"_Q{num_bits!s}G{group_size!s}" + input_names = [node.input[0], q_weight_name] + new_inits = [] + kwargs = {} + + op_type = "MatMulNBits" + + # pack quantized weight + if num_bits == 4: + q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4 + packed[:, :] = q_weight_pairs[:, :blob_size] + elif num_bits == 8: + packed = q_weight + else: + logger.error(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.") + + packed = np.reshape(packed, (-1, k_blocks, blob_size)) + + # build scale tensor + scale = np.reshape(scale, (-1, k_blocks)) + assert scale.dtype == np.float32 or scale.dtype == np.float16 + scale_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_scale", + data_type=np_dtype_to_tensor_dtype(scale.dtype), + dims=scale.shape, + vals=scale.tobytes(), + raw=True, + ) + input_names.append(scale_tensor.name) + new_inits.append(scale_tensor) + + # build zero_point tensor + if zero_point is not None: + if num_bits == 8: + packed_zp = zero_point.astype("uint8") + elif num_bits == 4: + # For 4-bit case, the default zeros is 0x8. So it is 0x88 = 136 if we fill lower/higher 4 bits with 0x8. + packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8") + # create an index array + idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1) + # separate odd and even indices + even_idx = idx[::2] + odd_idx = idx[1::2] + # vectorized operation for even and odd indices + packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel() + packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4) + else: + raise ValueError(f"MatMulNBits does not have kernel support for num_bits = {num_bits}.") + + packed_zp = np.reshape(packed_zp, (weight_shape[1], -1)) + zp_tensor = onnx.helper.make_tensor( + name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True + ) + input_names.append(zp_tensor.name) + new_inits.append(zp_tensor) + + # set kwargs + kwargs["K"] = weight_shape[0] + kwargs["N"] = weight_shape[1] + kwargs["bits"] = num_bits + kwargs["block_size"] = group_size + if accuracy_level > 0: + # require onnxruntime > 1.16.3 + kwargs["accuracy_level"] = accuracy_level + + q_weight_tensor = onnx.helper.make_tensor( + name=q_weight_name, + data_type=2, + dims=packed.shape, + vals=packed.tobytes(), + raw=True, + ) + new_inits.append(q_weight_tensor) + + matmul_weight_only_node = onnx.helper.make_node( + op_type, + inputs=input_names, + outputs=node.output, + name=node.name + "_Q" + str(num_bits) if node.name else "_Q" + str(num_bits), + domain="com.microsoft", + **kwargs, + ) + return matmul_weight_only_node, new_inits + + +def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quantize tensor per group. + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. 
+ group_size (int, optional): how many elements share one scale/zp. Defaults to 4. + scheme (str, optional): quantization scheme. Defaults to "asym". + dtype (str, optional): data type. Defaults to "int". + ratio (float, optional): percentile of clip. Defaults to 1.0. + + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + data = np.reshape(data, (-1, group_size)) + if scheme == "asym" or dtype == "uint": + maxq = 2**num_bits - 1 + minq = 0 + elif scheme == "sym": + maxq = 2 ** (num_bits - 1) - 1 if num_bits != 1 else 0 + minq = -(2 ** (num_bits - 1)) if num_bits != 1 else -1 + + rmin = np.min(data, axis=1, keepdims=True) * ratio + rmax = np.max(data, axis=1, keepdims=True) * ratio + if scheme == "sym": + max_range = np.maximum(np.abs(rmin), np.abs(rmax)) + scale = np.ones(rmax.shape) + mask = max_range > 0 + scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq) + zero_point = ( + np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1)) + ) + else: + scale = np.ones(rmax.shape) + scale[rmin != rmax] = np.array( + [float(i) / (maxq - minq) for i in (rmax - rmin)[rmin != rmax].flatten().tolist()] + ) + zero_point = ( + ((np.zeros(scale.shape) - rmin) / scale).round() + if dtype == "int" + else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8") + ) + + q_weight = np.empty_like(data, dtype=scale.dtype) + np.divide(data, scale, out=q_weight) + np.add(q_weight, zero_point, out=q_weight) + np.round(q_weight, out=q_weight) + np.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight, scale, zero_point + + +def quant_tensor_k_quant_cpu(data, num_bits=4, group_size=32): + """Quantize tensor per group based on k quant. + + Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 32. 
+ + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + data = np.reshape(data, (-1, group_size)).astype(np.float32) # nb = data.shape[0], (nb, group_size) + maxq = 2**num_bits - 1 + minq = 0 + sum_x2 = np.sum(data**2, axis=1, keepdims=True) # (nb, 1) + av_x = np.sqrt(sum_x2 / group_size) # (nb, 1) + weights = np.add(av_x, np.abs(data)) # (nb, group_size) + rmin = np.min(data, axis=1, keepdims=True) # (nb, 1) + rmax = np.max(data, axis=1, keepdims=True) # (nb, 1) + sum_w = np.sum(weights, axis=1, keepdims=True) # (nb, 1) + sum_x = np.sum(weights * data, axis=1, keepdims=True) # (nb, group_size) + iscale = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + mask = rmin != rmax + iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask]) + scale = 1 / iscale + quant_data = np.clip(np.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size) + diff = scale * quant_data + rmin - data # (nb, group_size) + best_mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + nstep = 20 + rdelta = 0.1 + # nstep * rdelta = -2 * rrmin, maxq - minq = 2**num_bits - 1 + rrmin = -1 + for is_ in range(nstep): + iscale_new = np.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + factor = np.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0] + mask = rmin != rmax + iscale_new[mask] = factor / (rmax[mask] - rmin[mask]) + quant_data_new = np.clip(np.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size) + mul_weights_quant_data_new = weights * quant_data_new + sum_l = np.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_l2 = np.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_xl = np.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1) + D = np.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806 + + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1) + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1) + + diff = this_scale * quant_data_new + this_min - data # (nb, group_size) + mad = np.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + + mad_1 = np.array(mad) + best_mad_1 = np.array(best_mad) + idx_to_replace = np.where(mad_1 < best_mad_1)[0] + quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :] + best_mad[idx_to_replace] = mad[idx_to_replace] + scale[idx_to_replace] = this_scale[idx_to_replace] + rmin[idx_to_replace] = this_min[idx_to_replace] + + zero_point = np.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8") + scale = scale.astype(np.float64) + q_weight = np.empty_like(data, dtype=scale.dtype) + np.divide(data, scale, out=q_weight) + np.add(q_weight, zero_point, out=q_weight) + np.round(q_weight, out=q_weight) + np.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight, scale, zero_point + + +def quant_tensor_k_quant_cuda(data, num_bits=4, group_size=32): + """Quantize tensor per group based on k quant. + + Ref: https://github.com/ggml-org/llama.cpp/blob/64eda5deb9859e87a020e56bab5d2f9ca956f1de/ggml/src/ggml-quants.c + + Args: + data : input weight + num_bits (int, optional): num_bits. Defaults to 4. + group_size (int, optional): how many elements share one scale/zp. Defaults to 4. 
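For illustration only (not part of the diff): quant_tensor_k_quant_cpu returns (q_weight, scale, zero_point) in the same per-group layout as quant_tensor, so the two can be compared directly on reconstruction error. A rough comparison sketch, assuming the functions defined above are in scope:

import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32)).astype(np.float32)

q1, s1, zp1 = quant_tensor(w, num_bits=4, group_size=32, scheme="asym", dtype="uint")
q2, s2, zp2 = quant_tensor_k_quant_cpu(w, num_bits=4, group_size=32)

mse_rtn = np.mean((w - (s1 * (q1 - zp1)).reshape(w.shape)) ** 2)
mse_kq = np.mean((w - (s2 * (q2 - zp2)).reshape(w.shape)) ** 2)
print(f"plain RTN mse: {mse_rtn:.6f}, k-quant mse: {mse_kq:.6f}")  # k-quant error is often lower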
+ + Returns: + output: quantized weight + scale: scale + zero_point: zero point + """ + try: + import cupy as cp + import torch + + if torch.cuda.is_available(): + data = cp.asarray(data) + data = data.reshape((-1, group_size)).astype(cp.float32) # nb = data.shape[0], (nb, group_size) + maxq = 2**num_bits - 1 + minq = 0 + sum_x2 = cp.sum(data**2, axis=1, keepdims=True) # (nb, 1) + av_x = cp.sqrt(sum_x2 / group_size) # (nb, 1) + weights = cp.add(av_x, cp.abs(data)) # (nb, group_size) + rmin = cp.min(data, axis=1, keepdims=True) # (nb, 1) + rmax = cp.max(data, axis=1, keepdims=True) # (nb, 1) + sum_w = cp.sum(weights, axis=1, keepdims=True) # (nb, 1) + sum_x = cp.sum(weights * data, axis=1, keepdims=True) # (nb, group_size) + iscale = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + mask = rmin != rmax + iscale[mask] = (maxq - minq) / (rmax[mask] - rmin[mask]) + scale = 1 / iscale + quant_data = cp.clip(cp.round(iscale * (data - rmin)), minq, maxq) # (nb, group_size) + diff = scale * quant_data + rmin - data # (nb, group_size) + best_mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + nstep = 20 + rdelta = 0.1 + rrmin = -1 + for is_ in range(nstep): + iscale_new = cp.ones(rmax.shape, dtype=data.dtype) # (nb, 1) + factor = cp.array([rrmin + rdelta * is_ + maxq - minq]).astype(data.dtype)[0] + mask = rmin != rmax + iscale_new[mask] = factor / (rmax[mask] - rmin[mask]) + quant_data_new = cp.clip(cp.round(iscale_new * (data - rmin)), minq, maxq) # (nb, group_size) + mul_weights_quant_data_new = weights * quant_data_new + sum_l = cp.sum(mul_weights_quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_l2 = cp.sum(mul_weights_quant_data_new * quant_data_new, axis=1, keepdims=True) # (nb, 1) + sum_xl = cp.sum(mul_weights_quant_data_new * data, axis=1, keepdims=True) # (nb, 1) + D = cp.subtract(sum_w * sum_l2, sum_l**2) # noqa: N806 + + this_scale = (sum_w * sum_xl - sum_x * sum_l) / D # (nb, 1) + this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D # (nb, 1) + + diff = this_scale * quant_data_new + this_min - data # (nb, group_size) + mad = cp.sum(weights * diff**2, axis=1, keepdims=True) # (nb, 1) + + mad_1 = cp.array(mad) + best_mad_1 = cp.array(best_mad) + idx_to_replace = cp.where(mad_1 < best_mad_1)[0] + quant_data[idx_to_replace, :] = quant_data_new[idx_to_replace, :] + best_mad[idx_to_replace] = mad[idx_to_replace] + scale[idx_to_replace] = this_scale[idx_to_replace] + rmin[idx_to_replace] = this_min[idx_to_replace] + + zero_point = cp.clip(((-rmin) / scale).round(), 0, maxq).astype("uint8") + scale = scale.astype(cp.float64) + q_weight = cp.empty_like(data, dtype=scale.dtype) + cp.divide(data, scale, out=q_weight) + cp.add(q_weight, zero_point, out=q_weight) + cp.round(q_weight, out=q_weight) + cp.clip(q_weight, minq, maxq, out=q_weight) + + return q_weight.get(), scale.get(), zero_point.get() + else: + logger.warning( + "Try to use k-quant quantization on CUDA. However, CUDA is not available." + "Fall back to k-quant quantization on CPU." + ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) + except ImportError: + logger.info( + "Now we are using k-quant quantization on cpu, which is time consuming." + "Please consider install cupy to speed up on CUDA. See https://cupy.dev/" + "Please also install torch to check CUDA availability." + ) + return quant_tensor_k_quant_cpu(data, num_bits, group_size) + + +def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0): + """Quant dequant tensor per group. 
+
+    Args:
+        data : input weight
+        num_bits (int, optional): num_bits. Defaults to 4.
+        group_size (int, optional): how many elements share one scale/zp. Defaults to 32.
+        scheme (str, optional): quantization scheme. Defaults to "asym".
+        dtype (str, optional): data type. Defaults to "int".
+        ratio (float, optional): percentile of clip. Defaults to 1.0.
+
+    Returns:
+        output: quant-dequant weight
+    """
+    org_shape = data.shape
+    weight, scale, zp = quant_tensor(data, num_bits, group_size, scheme, dtype, ratio)
+    return np.reshape(scale * (weight - zp), org_shape)
+
+
+def pad_tensor(weight, group_size, k_blocks):
+    """Pad the tensor rows so that the row count is divisible by group_size.
+
+    Args:
+        weight (array): weight
+        group_size (int): how many elements share one scale/zp
+        k_blocks (int): the number of blocks
+
+    Returns:
+        weight: padded weight
+    """
+    if group_size == -1:
+        return weight
+
+    org_w_shape = weight.shape
+    padded_rows = k_blocks * group_size
+    pad_len = padded_rows - org_w_shape[0]
+
+    if pad_len > 0:
+        weight = np.pad(weight, ((0, pad_len), (0, 0)), "constant")
+
+    return weight
+
+
+def rtn_quantize(
+    model,
+    weight_config={},  # noqa: B006
+    num_bits=4,
+    group_size=32,
+    scheme="asym",
+    ratios={},  # noqa: B006
+    accuracy_level=0,
+    providers=["CPUExecutionProvider"],  # noqa: B006
+    algorithm="k_quant",
+):
+    """Quantize the model with the round-to-nearest (RTN) method.
+
+    Args:
+        model (ModelProto or ONNXModel): onnx model
+        weight_config (dict): quantization config
+            For example,
+            weight_config = {
+                'fc2':
+                    {
+                        'bits': 4,
+                        'group_size': 32,
+                        'scheme': 'sym',
+                        'algorithm': 'RTN'
+                    }
+            }
+        num_bits (int, optional): num_bits. Default is 4.
+        group_size (int, optional): how many elements share one scale/zp. Default is 32.
+        scheme (str, optional): sym or asym. Defaults to "asym".
+        ratios (dict, optional): percentile of clip. Defaults to {}.
+        accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8).
+ providers (list): providers to use + + Returns: + model: fake quantized ONNXModel + """ + model = ONNXModel(model) + base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" + new_nodes = [] + remove_nodes = [] + total_num = len([i for i in model.nodes() if i.op_type in ["MatMul"]]) + curr_id = 0 + for node in model.nodes(): + if node.op_type in ["MatMul"]: + curr_id += 1 + simple_progress_bar(total_num, curr_id) + if ( + node.op_type in ["MatMul"] + and model.get_initializer(node.input[1]) is not None + and weight_config.get(node.name, {}) != "fp32" + ): + weight_tensor = model.get_initializer(node.input[1]) + weight = numpy_helper.to_array(weight_tensor, base_dir=base_dir).copy() + if len(weight.shape) != 2: + continue + + dtype = weight.dtype + + if node.name in weight_config: + num_bits = weight_config[node.name]["bits"] + group_size = weight_config[node.name]["group_size"] + scheme = weight_config[node.name]["scheme"] + + org_w_shape = weight.shape # ic, oc + group_size = group_size if group_size != -1 else org_w_shape[0] + + k_blocks = (org_w_shape[0] - 1) // group_size + 1 + init_share_num = model.get_initializer_share_num(node.input[1]) + + weight = pad_tensor(weight, group_size, k_blocks) + + satisfy_MatMulNBits_condition = num_bits == 4 or num_bits == 8 # noqa: N806 + + if satisfy_MatMulNBits_condition: # pragma: no cover + if algorithm == "k_quant": + q_weight, scale, zp = quant_tensor_k_quant_cuda(weight.T, num_bits, group_size) + else: + q_weight, scale, zp = quant_tensor( + weight.T, num_bits, group_size, scheme, "uint", ratios.get(node.input[1], 1) + ) + + q_matmul_node, new_inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_w_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight.astype("uint8"), + scale=scale.astype(dtype), + zero_point=zp if scheme == "asym" else None, + accuracy_level=accuracy_level, + ) + + model.add_initializers(new_inits) + remove_nodes.append(node) + new_nodes.append(q_matmul_node) + else: + q_weight = qdq_tensor(weight.T, num_bits, group_size, scheme, "int", ratios.get(node.input[1], 1)) + q_weight = np.reshape(q_weight, (org_w_shape[1], -1)) + q_weight = np.transpose(q_weight) + q_weight = q_weight[: org_w_shape[0], :].astype(dtype) + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", + data_type=np_dtype_to_tensor_dtype(dtype), + dims=weight.shape, + vals=q_weight.tobytes(), + raw=True, + ) + model.add_initializer(q_weight_tensor) + node.input[1] = q_weight_tensor.name + if init_share_num == 1: + model.remove_initializer(weight_tensor) + + model.add_nodes(new_nodes) + model.remove_nodes(remove_nodes) + model.topological_sort() + return model + + +def get_weight_scale(weight, group_size): + """Get the scale of weight.""" + org_shape = weight.shape + weight = np.reshape(weight, (-1, group_size)) if group_size != -1 else weight + scale = np.mean(np.reshape(np.abs(weight) / np.max(np.abs(weight), axis=1, keepdims=True), org_shape), axis=0) + return scale + + +def prepare_inputs(model, n_samples, dataloader, providers): + """Prepare inputs for weight only quantization. + + Args: + model (ModelProto or ONNXModel): onnx model + n_samples (int, optional): calibration sample number. -1 means all samples. + dataloader (object): dataloader for calibration. + providers (list): providers to use + + Returns: + inputs: prepared inputs. 
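For illustration only (not part of the diff): the new k-quant path is reached by calling rtn_quantize with algorithm="k_quant", which replaces eligible MatMul initializers with packed MatMulNBits weights. A minimal usage sketch with a hypothetical model path:

import onnx

model = onnx.load("model_fp32.onnx")  # hypothetical input model
quantized = rtn_quantize(model, num_bits=4, group_size=32, algorithm="k_quant")
onnx.save(quantized.model, "model_int4_kquant.onnx")  # the returned ONNXModel wraps the ModelProto in .model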
+ so: session options + """ + from importlib.util import find_spec + + from .util import to_numpy + + so = ort.SessionOptions() + if sys.version_info < (3, 11) and find_spec("onnxruntime_extensions"): # pragma: no cover + from onnxruntime_extensions import get_library_path + + so.register_custom_ops_library(get_library_path()) + if model.is_large_model: + onnx.save_model( + model.model, + model.model_path + "_augment.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + + session = ( + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) + if not model.is_large_model + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) + ) + inputs_names = [i.name for i in session.get_inputs()] + del session + + inputs = [] + for i, data in enumerate(dataloader): + if n_samples != -1 and ((i + 1) * dataloader.batch_size) > n_samples: + break + if len(inputs_names) != 1 or isinstance(data[0], dict): + assert len(data[0]) == len(inputs_names), ( + f"Input number mismatch, require {len(inputs_names)} but get {len(data[0])}" + ) + + if isinstance(data[0], dict): + inputs.append(dict([(name, to_numpy(inp_data)) for name, inp_data in data[0].items()])) # noqa: C404 + elif isinstance(data[0], np.ndarray): # pragma: no cover + inputs.append(dict([(name, inp) for name, inp in zip(inputs_names, [data[0]], strict=False)])) # noqa: C404 + else: # pragma: no cover + inputs.append(dict([(name, to_numpy(inp)) for name, inp in zip(inputs_names, data[0], strict=False)])) # noqa: C404 + return inputs, so + + +def gptq( + W, + H, + num_bits=4, + group_size=32, + scheme="asym", + blocksize=128, + percdamp=0.01, + actorder=False, + mse=False, + perchannel=True, +): + """Quant the weight with GPTQ method. + + Args: + W (array): weight. + H (array): Hessian matrix. + num_bits (int, optional): num_bits. Default is 4. + group_size (int, optional): how many elements share one scale/zp. Default is 32. + scheme (str, optional): sym or asym. Defaults to "asym". + blocksize (int, optional): blocksize to quantize weight. + percdamp (float, optional): percent of the average Hessian diagonal to use for dampening. + actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): whether get scale and zero point with mse error. + perchannel (bool, optional): whether quantize weight per-channel. 
+ + Returns: + Q: fake quantized weight + """ + maxq = 2**num_bits - 1 + grid = 100 + maxshrink = 0.8 + norm = 2.4 + + def find_params(weight): + org_shape = weight.shape + # find zp, scale + if not perchannel: + weight = np.expand_dims(weight.flatten(), axis=1) + tmp = np.zeros(weight.shape[1]) + xmin = np.minimum(np.min(weight, axis=0), tmp) + xmax = np.maximum(np.max(weight, axis=0), tmp) + if scheme == "sym": + xmax = np.maximum(np.abs(xmin), xmax) + tmp = xmin < 0 + if np.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + scale = (xmax - xmin) / maxq + if scheme == "sym": + zero = np.ones(scale.shape) * (maxq + 1) / 2 + else: + zero = np.round(-xmin / scale) + if mse: + best = np.ones([weight.shape[1]]) * float("inf") + for i in range(int(maxshrink * grid)): + p = 1 - i / grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / maxq + zero1 = np.round(-xmin1 / scale1) if scheme != "sym" else zero + q = np.clip(np.round(weight / scale1) + zero1, 0, maxq) + q -= weight + q = np.power(np.abs(q), norm) + err = np.sum(q, 0) + tmp = err < best + if np.any(tmp): + best[tmp] = err[tmp] + scale[tmp] = scale1[tmp] + zero[tmp] = zero1[tmp] + if not perchannel: + tmp = org_shape[1] + scale = np.repeat(scale, tmp) + zero = np.repeat(zero, tmp) + shape = [-1] + [1] * (len(org_shape) - 1) + scale = np.reshape(scale, shape) + zero = np.reshape(zero, shape) + return scale, zero + + shape = W.shape + scale, zp = find_params(W) + dead = np.diag(H) == 0 + H[dead, dead] = 1 + W[dead, :] = 0 # such channel makes no contribution to quantization computation + + # rearrange considering the diag's value + if actorder: + perm = np.argsort(np.diag(H))[::-1] + W = W[perm, :] # noqa: N806 + H = H[perm, :][:, perm] # noqa: N806 + Losses = np.zeros_like(W) # noqa: N806 + Q = np.zeros_like(W) # noqa: N806 + damp = percdamp * np.mean(np.diag(H)) + diag = np.arange(shape[0]) + H[diag, diag] += damp # add a average value of + H = np.linalg.cholesky(np.linalg.inv(H)).T # noqa: N806 + Hinv = H # noqa: N806 + for i1 in range(0, shape[0], blocksize): + i2 = min(i1 + blocksize, shape[0]) + count = i2 - i1 + + W1 = copy.deepcopy(W[i1:i2, :]) # noqa: N806 + Q1 = np.zeros_like(W1) # noqa: N806 + Err1 = np.zeros_like(W1) # noqa: N806 + Losses1 = np.zeros_like(W1) # noqa: N806 + Hinv1 = Hinv[i1:i2, i1:i2] # noqa: N806 + + for i in range(count): # within a block, channel wise + w = W1[i, :] + d = Hinv1[i, i] + + if group_size != -1: + if (i1 + i) % group_size == 0: + scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :]) + + q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten() + Q1[i, :] = q + Losses1[i, :] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + W1[i:, :] -= np.matmul(np.expand_dims(Hinv1[i:, i], axis=1), np.expand_dims(err1, axis=0)) + Err1[i, :] = err1 + + Q[i1:i2, :] = Q1 + Losses[i1:i2, :] = Losses1 / 2 + + W[i2:, :] -= np.matmul(Hinv[i2:, i1:i2], Err1) + + if actorder: + invperm = np.argsort(perm) + Q = Q[invperm, :] # noqa: N806 + + Q = np.reshape(Q, W.shape) # noqa: N806 + del W + return Q + + +def gptq_quantize( + model, + dataloader, + weight_config={}, # noqa: B006 + num_bits=4, + group_size=32, + scheme="asym", + n_samples=128, + percdamp=0.01, + blocksize=128, + actorder=False, + mse=False, + perchannel=True, + accuracy_level=0, + providers=["CPUExecutionProvider"], # noqa: B006 +): + """Quant the model with GPTQ method. 
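For illustration only (not part of the diff): gptq_quantize accumulates a per-MatMul Hessian from calibration activations as a running average of 2 * x^T x, which is then handed to gptq() above. The accumulation can be sketched with toy data as:

import numpy as np

cols = 16
batches = [np.random.randn(8, cols) for _ in range(4)]  # toy stand-in for calibration activations

H = np.zeros((cols, cols))
nsamples = 0
for x in batches:
    x = np.reshape(x, (-1, x.shape[-1]))
    n = x.shape[0]
    H *= nsamples / (nsamples + n)   # rescale the running average before adding the new batch
    nsamples += n
    x = np.sqrt(2 / nsamples) * x
    H += x.T @ x                     # H now approximates (2 / nsamples) * sum_i x_i^T x_i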
+ + Args: + model (ModelProto or ONNXModel): onnx model + dataloader (object): dataloader for calibration. + weight_config (dict): quantization config + For example, + weight_config = { + 'fc2': + { + 'bits': 4, + 'group_size': 32, + 'scheme': 'sym', + 'algorithm': 'GPTQ' + } + } + num_bits (int, optional): num_bits. Default is 4. + group_size (int, optional): how many elements share one scale/zp. Default is 32. + scheme (str, optional): sym or asym. Defaults to "asym". + n_samples (int, optional): calibration sample number. + percdamp (float, optional): percent of the average Hessian diagonal to use for dampening. + blocksize (int, optional): blocksize to quantize weight. + actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. + mse (bool, optional): whether get scale and zero point with mse error. + perchannel (bool, optional): whether quantize weight per-channel. + accuracy_level (int): accuracy level. Support 0 (unset), 1(fp32), 2(fp16), 3(bf16), or 4(int8). + providers (list): providers to use + + Returns: + model: fake quantized ONNXModel + """ + model = ONNXModel(model) + base_dir = os.path.dirname(model.model_path) if model.model_path is not None else "" + + inputs, so = prepare_inputs(model, n_samples, dataloader, providers) + del dataloader + org_output = copy.deepcopy(model.model.graph.output) + model.remove_tensors_from_outputs([i.name for i in org_output]) + output_names = [] + for node in model.nodes(): + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + ): + output_names.append(node.input[0]) + output_names = list(set(output_names)) + model.add_tensors_to_outputs(output_names) + if model.is_large_model: + onnx.save_model( + model.model, + model.model_path + "_augment.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + + session = ( + ort.InferenceSession(model.model.SerializeToString(), so, providers=providers) + if not model.is_large_model + else ort.InferenceSession(model.model_path + "_augment.onnx", so, providers=providers) + ) + + for idx, input_name in enumerate(output_names): + simple_progress_bar(len(output_names), idx + 1) + node_list = [] + weights = [] + + for node in model.input_name_to_nodes[input_name]: + if ( + node.op_type in ["MatMul"] + and weight_config.get(node.name, {}) != "fp32" + and weight_config.get(node.name, {}).get("algorithm", "GPTQ") == "GPTQ" + and model.get_initializer(node.input[1]) is not None + ): + weight = numpy_helper.to_array( + model.get_initializer(model.get_node(node.name).input[1]), base_dir + ).copy() + if len(weight.shape) != 2: + continue + + weights.append(weight) + node_list.append(model.get_node(node.name)) + + if len(weights) == 0: + continue + + Hs = [np.zeros((i.shape[0], i.shape[0])) for i in weights] # noqa: N806 + nsamples = 0 + for data in inputs: + inp = session.run([input_name], data)[0] + tmp = inp.shape[0] + inp = np.reshape(inp, (-1, inp.shape[-1])) + Hs = [i * (nsamples / (nsamples + tmp)) for i in Hs] # noqa: N806 + nsamples += tmp + inp = np.sqrt(2 / nsamples) * inp + Hs = [i + np.matmul(inp.T, inp) for i in Hs] # noqa: N806 + + for ( + node, + weight, + H, # noqa: N806 + ) in zip(node_list, weights, Hs, strict=False): + if node.name in weight_config: + num_bits = weight_config[node.name]["bits"] + group_size = weight_config[node.name]["group_size"] + scheme = weight_config[node.name]["scheme"] + group_size = group_size 
if group_size != -1 else weight.shape[0] + dtype = weight.dtype + + q_weight = gptq( + weight, + H, + num_bits=num_bits, + group_size=group_size, + scheme=scheme, + blocksize=blocksize, + percdamp=percdamp, + actorder=actorder, + mse=mse, + perchannel=perchannel, + ) + + weight_tensor = model.get_initializer(node.input[1]) + init_share_num = model.get_initializer_share_num(node.input[1]) + + satisfy_MatMulNBits_condition = num_bits == 4 # noqa: N806 + + if satisfy_MatMulNBits_condition: # pragma: no cover + org_shape = weight.shape + k_blocks = (org_shape[0] + group_size - 1) // group_size + q_weight = pad_tensor(q_weight, group_size, k_blocks) + q_weight, scale, zp = quant_tensor(q_weight.T, num_bits, group_size, scheme, "uint") + q_matmul_node, new_inits = make_matmul_weight_only_node( + node=node, + weight_shape=org_shape, + num_bits=num_bits, + group_size=group_size, + k_blocks=k_blocks, + q_weight=q_weight.astype("uint8"), + scale=scale.astype(dtype), + zero_point=zp if scheme == "asym" else None, + accuracy_level=accuracy_level, + ) + + model.add_initializers(new_inits) + model.remove_node(node) + model.add_node(q_matmul_node) + else: + q_weight_tensor = onnx.helper.make_tensor( + name=node.input[1] + f"_Q{num_bits!s}G{group_size!s}", + data_type=np_dtype_to_tensor_dtype(dtype), + dims=q_weight.shape, + vals=q_weight.astype(dtype).tobytes(), + raw=True, + ) + model.add_initializer(q_weight_tensor) + node.input[1] = q_weight_tensor.name + if init_share_num == 1: + model.remove_initializer(weight_tensor) + + model.remove_tensors_from_outputs(output_names) + model.model.graph.output.MergeFrom(org_output) + + model.topological_sort() + + # reload external data to prevent external data file path errors + if model.is_large_model: + from onnx.external_data_helper import load_external_data_for_model + + load_external_data_for_model(model.model, os.path.split(model.model_path)[0]) + + return model diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 0e739055b1772..03f4791c580e6 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -360,6 +360,8 @@ def test_quantize_matmul_int4_offsets_qdq(self): def test_quantize_matmul_int4_using_rtn_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + if not find_spec("torch"): + self.skipTest("skip test_quantize_matmul_int4_using_rtn_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) @@ -371,6 +373,8 @@ def test_quantize_matmul_int4_using_rtn_algo(self): def test_quantize_matmul_int4_using_gptq_algo(self): if not find_spec("neural_compressor"): self.skipTest("skip test_smooth_quant since neural_compressor is not installed") + if not find_spec("torch"): + self.skipTest("skip test_quantize_matmul_int4_using_gptq_algo since torch is not installed") model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) self.construct_model_matmul(model_fp32_path, symmetric=False) data_reader = self.input_feeds(1, {"input": (100, 52)}) diff --git a/setup.py b/setup.py index 1e426ea8e060b..c45657c0c2873 100644 --- a/setup.py +++ b/setup.py @@ -517,6 +517,7 @@ def 
finalize_options(self): "onnxruntime.quantization.CalTableFlatBuffers", "onnxruntime.quantization.fusions", "onnxruntime.quantization.execution_providers.qnn", + "onnxruntime.quantization.neural_compressor", "onnxruntime.transformers", "onnxruntime.transformers.models.bart", "onnxruntime.transformers.models.bert", From b4f7a905b0d636b71bd486c0ef702eb5a44eadf2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 3 May 2025 16:48:51 +1000 Subject: [PATCH 02/84] Selection policy. Device discovery updates. Bug fixes. (#24625) ### Description Add initial selection policy implementations. Update device discovery - get vendor and vendor id for CPU from cpuid_info - trim metadata to known useful fields - NPU detection via dxcore only Bug fixes/updates from PRs for C# and python bindings Add some tests for selection policy - TODO: Add more tests ### Motivation and Context Desire to boil oceans. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 44 ++- .../core/session/onnxruntime_cxx_api.h | 12 +- .../core/session/onnxruntime_cxx_inline.h | 7 + onnxruntime/core/common/cpuid_info.cc | 10 + onnxruntime/core/common/cpuid_info.h | 7 + .../core/framework/graph_partitioner.cc | 8 +- onnxruntime/core/framework/session_options.h | 15 + .../core/platform/windows/device_discovery.cc | 85 +++-- .../core/session/abi_key_value_pairs.h | 2 + .../core/session/abi_session_options.cc | 11 + .../core/session/ep_library_internal.cc | 8 +- onnxruntime/core/session/inference_session.h | 4 +- .../core/session/model_compilation_options.cc | 4 +- onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 4 + .../core/session/provider_policy_context.cc | 347 ++++++++++++++++++ .../core/session/provider_policy_context.h | 79 ++++ onnxruntime/core/session/utils.cc | 17 +- .../test/autoep/test_autoep_selection.cc | 122 +++++- 19 files changed, 724 insertions(+), 65 deletions(-) create mode 100644 onnxruntime/core/session/provider_policy_context.cc create mode 100644 onnxruntime/core/session/provider_policy_context.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9a5891f9e236d..cef5eab9a505e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -425,6 +425,32 @@ typedef enum OrtExecutionProviderDevicePolicy { OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, } OrtExecutionProviderDevicePolicy; +/** \brief Delegate to allow providing custom OrtEpDevice selection logic + * + * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. + * The user can use this to select OrtEpDevice instances from the list of available devices. + * + * \param ep_devices The list of available devices. + * \param num_devices The number of available devices. + * \param model_metadata The model metadata. + * \param runtime_metadata The runtime metadata. May be nullptr. + * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. + * \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array. + Currently the maximum is 8. + * \param num_ep_devices The number of selected devices. + * + * \return OrtStatus* Selection status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. 
+ * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -5073,7 +5099,8 @@ struct OrtApi { ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); - /** \brief Append execution provider to the session options by name. + /** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances + * to the session options. * * \param[in] session_options Session options to add execution provider to. * \param[in] env Environment that execution providers were registered with. @@ -5098,6 +5125,21 @@ struct OrtApi { _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); + /** \brief Set the execution provider selection policy for the session. + * + * Allows users to specify a device selection policy for automatic execution provider (EP) selection, + * or provide a delegate callback for custom selection logic. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). + * \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy. + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate); + /** \brief Get the hardware device type. * * \param[in] device The OrtHardwareDevice instance to query. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0ecc27c59dc28..6c175c606b4a1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1085,19 +1085,27 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. 
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const KeyValuePairs& ep_options); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const std::unordered_map& ep_options); + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy + SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, + EpSelectionDelegate* delegate = nullptr); + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48b3b80cced55..1fdb8f16d9600 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1149,6 +1149,13 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, + EpSelectionDelegate* delegate) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 97766028cfe12..ee7782e3c8763 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -106,6 +106,7 @@ void CPUIDInfo::X86Init() { GetCPUID(0, data); vendor_ = GetX86Vendor(data); + vendor_id_ = GetVendorId(vendor_); int num_IDs = data[0]; if (num_IDs >= 1) { @@ -151,6 +152,14 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { #endif // defined(CPUIDINFO_ARCH_X86) +uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { + if (vendor == "GenuineIntel") return 0x8086; + if (vendor == "GenuineAMD") return 0x1022; + if (vendor.find("Qualcomm") == 0) return 'Q' << 24 | 'C' << 16 | 'O' << 8 | 'M'; + if (vendor.find("NV") == 0) return 0x10DE; + return 0; +} + #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -204,6 +213,7 @@ void CPUIDInfo::ArmLinuxInit() { void CPUIDInfo::ArmWindowsInit() { // Get the ARM vendor string from the registry vendor_ = GetArmWindowsVendor(); + vendor_id_ = GetVendorId(vendor_); // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 4d6e7e8b9105e..b820fa2ab1af7 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ 
b/onnxruntime/core/common/cpuid_info.h @@ -19,6 +19,10 @@ class CPUIDInfo { return vendor_; } + uint32_t GetCPUVendorId() const { + return vendor_id_; + } + bool HasAMX_BF16() const { return has_amx_bf16_; } bool HasAVX() const { return has_avx_; } bool HasAVX2() const { return has_avx2_; } @@ -123,6 +127,9 @@ class CPUIDInfo { bool has_arm_neon_bf16_{false}; std::string vendor_; + uint32_t vendor_id_; + + uint32_t GetVendorId(const std::string& vendor); #if defined(CPUIDINFO_ARCH_X86) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e3e54be3f7c21..8ed5eeaa8d44f 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -806,7 +806,13 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes, "Compiled model does not contain any EPContext nodes. " "Check that the session EPs support compilation and can execute at least one model subgraph."); - return Status::OK(); + + LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. " + "Either the session EPs do not support compilation or " + "no subgraphs were able to be compiled."; + + // we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not + // EPContext nodes. } auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 94ff2bb55a055..8f8a3d6634a7e 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -90,6 +90,15 @@ struct EpContextModelGenerationOptions { size_t output_external_initializer_size_threshold = 0; }; +struct EpSelectionPolicy { + // flag to detect that a policy was set by the user. + // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered + // and no selection policy was explicitly provided. + bool enable{false}; + OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT; + EpSelectionDelegate* delegate{}; +}; + /** * Configuration information for a session. */ @@ -222,6 +231,11 @@ struct SessionOptions { // copied internally and the flag needs to be accessible across all copies. std::shared_ptr load_cancellation_flag = std::make_shared(false); + // Policy to guide Execution Provider selection + EpSelectionPolicy ep_selection_policy = {false, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_DEFAULT, + nullptr}; + // Options for generating compile EPContext models were previously stored in session_option.configs as // string key/value pairs. To support more advanced options, such as setting input/output buffers, we // now have to store EPContext options in a struct of type EpContextModelGenerationOptions. 
@@ -253,6 +267,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " use_per_session_threads:" << session_options.use_per_session_threads << " thread_pool_allow_spinning:" << session_options.thread_pool_allow_spinning << " use_deterministic_compute:" << session_options.use_deterministic_compute + << " ep_selection_policy:" << session_options.ep_selection_policy.policy << " config_options: { " << session_options.config_options << " }" //<< " initializers_to_share_map:" << session_options.initializers_to_share_map #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 88fbec37c8075..5a5b5041a5912 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -119,7 +119,13 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde uint32_t vendor_id = get_id(buffer, L"VEN_"); uint32_t device_id = get_id(buffer, L"DEV_"); - // won't always have a vendor id from an ACPI entry. need at least a device id to identify the hardware + + // Processor ID should come from CPUID mapping. + if (vendor_id == 0 && guid == GUID_DEVCLASS_PROCESSOR) { + vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId(); + } + + // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { continue; } @@ -138,8 +144,8 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde entry = &device_info[key]; entry->vendor_id = vendor_id; entry->device_id = device_id; - // put the first hardware id string in the metadata. ignore the other lines. - entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); + // put the first hardware id string in the metadata. ignore the other lines. not sure if this is of value. + // entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); } else { // need valid ids continue; @@ -156,14 +162,14 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde (PBYTE)buffer, sizeof(buffer), &size)) { std::wstring desc{buffer}; - // Should we require the NPU to be found by DXCore or do we want to allow this vague matching? - // Probably depends on whether we always attempt to run DXCore or not. - const auto possible_npu = [](const std::wstring& desc) { - return (desc.find(L"NPU") != std::wstring::npos || - desc.find(L"Neural") != std::wstring::npos || - desc.find(L"AI Engine") != std::wstring::npos || - desc.find(L"VPU") != std::wstring::npos); - }; + // For now, require dxcore to identify an NPU. + // If we want to try and infer it from the description this _may_ work but is untested. 
+ // const auto possible_npu = [](const std::wstring& desc) { + // return (desc.find(L"NPU") != std::wstring::npos || + // desc.find(L"Neural") != std::wstring::npos || + // desc.find(L"AI Engine") != std::wstring::npos || + // desc.find(L"VPU") != std::wstring::npos); + // }; // use description if no friendly name if (entry->description.empty()) { @@ -171,7 +177,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } uint64_t npu_key = GetDeviceKey(*entry); - bool is_npu = npus.count(npu_key) > 0 || possible_npu(desc); + bool is_npu = npus.count(npu_key) > 0; // rely on dxcore to determine if something is an NPU if (guid == GUID_DEVCLASS_DISPLAY) { entry->type = OrtHardwareDeviceType_GPU; @@ -194,22 +200,17 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde continue; } - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, - (PBYTE)buffer, sizeof(buffer), &size)) { - entry->vendor = std::wstring(buffer, wcslen(buffer)); + if (entry->type == OrtHardwareDeviceType_CPU) { + // get 12 byte string from CPUID. easier for a user to match this if they are explicitly picking a device. + std::string_view cpuid_vendor = CPUIDInfo::GetCPUIDInfo().GetCPUVendor(); + entry->vendor = std::wstring(cpuid_vendor.begin(), cpuid_vendor.end()); } - // Add the UI number if GPU. Helpful if user has integrated and discrete GPUs - if (entry->type == OrtHardwareDeviceType_GPU) { - DWORD ui_number = 0; - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_UI_NUMBER, nullptr, - (PBYTE)&ui_number, sizeof(ui_number), &size)) { - // use value read. - } else { - // infer it as 0 if not set. + if (entry->vendor.empty()) { + if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, + (PBYTE)buffer, sizeof(buffer), &size)) { + entry->vendor = std::wstring(buffer, wcslen(buffer)); } - - entry->metadata.emplace(L"SPDRP_UI_NUMBER", std::to_wstring(ui_number)); } } @@ -252,9 +253,7 @@ std::unordered_map GetDeviceInfoD3D12() { info.description = std::wstring(desc.Description); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); - info.metadata[L"VideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SharedSystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; + info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; } // iterate by high-performance GPU preference to add that info @@ -272,7 +271,7 @@ std::unordered_map GetDeviceInfoD3D12() { auto it = device_info.find(key); if (it != device_info.end()) { DeviceInfo& info = it->second; - info.metadata[L"HighPerformanceIndex"] = std::to_wstring(i); + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); } } @@ -405,25 +404,40 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - std::wstring_convert > converter; // wstring to string - const auto device_to_ortdevice = [&converter]( + // our log output to std::wclog breaks with UTF8 chars that are not supported by the current code page. + // e.g. (TM) symbol. that stops ALL logging working on at least arm64. + // safest way to avoid that is to keep it to single byte chars. + // process the OrtHardwareDevice values this way so it can be safely logged. + // only the 'description' metadata is likely to be affected and that is mainly for diagnostic purposes. 
+ const auto to_safe_string = [](const std::wstring& wstr) -> std::string { + std::string str(wstr.size(), ' '); + std::transform(wstr.begin(), wstr.end(), str.begin(), [](wchar_t wchar) { + if (wchar >= 0 && wchar <= 127) { + return static_cast(wchar); + } + return ' '; + }); + return str; + }; + + const auto device_to_ortdevice = [&to_safe_string]( DeviceInfo& device, std::unordered_map* extra_metadata = nullptr) { - OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, converter.to_bytes(device.vendor)}; + OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)}; if (!device.description.empty()) { - ortdevice.metadata.Add("Description", converter.to_bytes(device.description)); + ortdevice.metadata.Add("Description", to_safe_string(device.description)); } for (auto& [key, value] : device.metadata) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } if (extra_metadata) { // add any extra metadata from the dxcore info for (auto& [key, value] : *extra_metadata) { if (device.metadata.find(key) == device.metadata.end()) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } } } @@ -431,6 +445,7 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor std::ostringstream oss; oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id << ", device_id:0x" << ortdevice.device_id + << ", vendor:" << ortdevice.vendor << ", type:" << std::dec << static_cast(ortdevice.type) << ", metadata: ["; for (auto& [key, value] : ortdevice.metadata.entries) { diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 3242be817881a..150575b3a9efc 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -57,6 +57,8 @@ struct OrtKeyValuePairs { keys.erase(key_iter); values.erase(values.begin() + idx); } + + entries.erase(iter); } } diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 0b116c2fa64f6..b1c0467da642e 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -366,6 +366,17 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = policy; + options->value.ep_selection_policy.delegate = delegate; + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, _In_ bool is_cancel) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index c515195c7e6bf..aa032f24f13c0 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -183,14 +183,14 @@ std::vector> EpLibraryInternal::CreateInterna // CPU EP internal_eps.push_back(CreateCpuEp()); -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - #if 
defined(USE_WEBGPU) internal_eps.push_back(CreateWebGpuEp()); #endif +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + return internal_eps; } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index a21388d1e9918..ba9812a59fec3 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -350,8 +350,8 @@ class InferenceSession { /** * Initializes a previously loaded ONNX model. Initialization includes but is not - * limited to graph transformations, construction of kernels, etc. - * This method assumes that a method has been loaded previously. + * limited to graph transformations, construction of kernels, EP policy decisions, etc. + * This method assumes that a model has been loaded previously. * This API is thread-safe. * @return OK if success */ diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 80ef18de5cfa3..c4a7c5262d03d 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -19,7 +19,9 @@ ModelCompilationOptions::ModelCompilationOptions(const OrtEnv& env, const OrtSes session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); session_options_.value.ep_context_gen_options.enable = true; session_options_.value.ep_context_gen_options.overwrite_existing_output_file = true; - session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = true; + // defaulting to false to support wider usage. will log WARNING if compiling model with no context nodes. + // TODO: Add ability for user to explicitly set this. + session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = false; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. 
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b5c271594055a..c70075b234faf 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3012,6 +3012,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::UnregisterExecutionProviderLibrary, &OrtApis::GetEpDevices, &OrtApis::SessionOptionsAppendExecutionProvider_V2, + &OrtApis::SessionOptionsSetEpSelectionPolicy, &OrtApis::HardwareDevice_Type, &OrtApis::HardwareDevice_VendorId, @@ -3061,7 +3062,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 315, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0033eb0d604f2..7928f9b822cf0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -575,6 +575,10 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate); + // OrtHardwareDevice accessors. ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device); ORT_API(uint32_t, HardwareDevice_VendorId, _In_ const OrtHardwareDevice* device); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc new file mode 100644 index 0000000000000..565891fe2cdfb --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/provider_policy_context.h" + +#include + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/ep_factory_internal.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +namespace { +bool MatchesEpVendor(const OrtEpDevice* d) { + // TODO: Would be better to match on Id. Should the EP add that in EP metadata? 
+ return d->device->vendor == d->ep_vendor; +} + +bool IsDiscreteDevice(const OrtEpDevice* d) { + if (d->device->type != OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + const auto& entries = d->device->metadata.entries; + if (auto it = entries.find("Discrete"); it != entries.end()) { + return it->second == "1"; + } + + return false; +} + +bool IsDefaultCpuEp(const OrtEpDevice* d) { + return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU && + d->ep_vendor == "Microsoft"; +} + +// Sort devices. NPU -> GPU -> CPU +// Within in type, vendor owned, not. +// Default CPU EP is last +std::vector OrderDevices(const std::vector& devices) { + std::vector sorted_devices(devices.begin(), devices.end()); + std::sort(sorted_devices.begin(), sorted_devices.end(), [](const OrtEpDevice* a, const OrtEpDevice* b) { + auto aDeviceType = a->device->type; + auto bDeviceType = b->device->type; + if (aDeviceType != bDeviceType) { + // NPU -> GPU -> CPU + // std::sort is ascending order, so NPU < GPU < CPU + + // one is an NPU + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return false; + } + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + // this shouldn't be reachable as it would imply both are CPU + ORT_THROW("Unexpected combination of devices"); + } + + // both devices are the same + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + bool aDiscrete = IsDiscreteDevice(a); + bool bDiscrete = IsDiscreteDevice(b); + if (aDiscrete != bDiscrete) { + return aDiscrete == true; // prefer discrete + } + + // both discrete or both integrated + } + + // prefer device matching platform vendor + bool aVendor = MatchesEpVendor(a); + bool bVendor = MatchesEpVendor(b); + if (aVendor != bVendor) { + return aVendor == true; // prefer the device that matches the EP vendor + } + + // default CPU EP last + bool aIsDefaultCpuEp = IsDefaultCpuEp(a); + bool bIsDefaultCpuEp = IsDefaultCpuEp(b); + if (!aIsDefaultCpuEp && !bIsDefaultCpuEp) { + // neither are default CPU EP. both do/don't match vendor. + // TODO: implement tie-breaker for this scenario. arbitrarily prefer the first for now + return true; + } + + // one is the default CPU EP + return aIsDefaultCpuEp == false; // prefer the one that is not the default CPU EP + }); + + return sorted_devices; +} +} // namespace + +// Select execution providers based on the device policy and available devices and add to session +Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, + InferenceSession& sess) { + ORT_ENFORCE(options.value.ep_selection_policy.delegate == nullptr, + "EP selection delegate support is not implemented yet."); + + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? 
+ const auto& execution_devices = OrderDevices(env.GetOrtEpDevices()); + + // The list of devices selected by policies + std::vector devices_selected; + + // Run the delegate if it was passed in lieu of any other policy + if (options.value.ep_selection_policy.delegate) { + auto policy_fn = options.value.ep_selection_policy.delegate; + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); + std::array selected_devices{nullptr}; + + size_t num_selected = 0; + auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(), + nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected); + + // return or fall-through for both these cases + // going with explicit failure for now so it's obvious to user what is happening + if (status != nullptr) { + std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy + OrtApis::ReleaseStatus(status); + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); + } + + if (num_selected == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); + } + } else { + // Create the selector for the chosen policy + std::unique_ptr selector; + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + selector = std::make_unique(); + break; + } + + // Execute policy + + selector->SelectProvidersForDevices(execution_devices, devices_selected); + } + + // Fail if we did not find any device matches + if (devices_selected.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "No execution providers selected. Please check the device policy and available devices."); + } + + // Configure the session options for the devices. This updates the SessionOptions in the InferenceSession with any + // EP options that have not been overridden by the user. + ORT_RETURN_IF_ERROR(AddEpDefaultOptionsToSession(sess, devices_selected)); + + // Create OrtSessionOptions for the CreateEp call. + // Once the InferenceSession is created, its SessionOptions is the source of truth and contains all the values from + // the user provided OrtSessionOptions. We do a copy for simplicity. The OrtSessionOptions instance goes away + // once we exit this function so an EP implementation should not use OrtSessionOptions after it returns from + // CreateEp. 
+ auto& session_options = sess.GetMutableSessionOptions(); + OrtSessionOptions ort_so; + ort_so.value = session_options; + const auto& session_logger = sess.GetLogger(); + const OrtLogger& api_session_logger = *session_logger->ToExternal(); + + // Remove the ORT CPU EP if configured to do so + bool disable_ort_cpu_ep = ort_so.value.config_options.GetConfigEntry(kOrtSessionOptionsDisableCPUEPFallback) == "1"; + if (disable_ort_cpu_ep) { + RemoveOrtCpuDevice(devices_selected); + } + + // Fold the EPs into a single structure per factory + std::vector eps_selected; + FoldSelectedDevices(devices_selected, eps_selected); + + // Iterate through the selected EPs and create them + for (size_t idx = 0; idx < eps_selected.size(); ++idx) { + std::unique_ptr ep = nullptr; + ORT_RETURN_IF_ERROR(CreateExecutionProvider(env, ort_so, api_session_logger, eps_selected[idx], ep)); + if (ep != nullptr) { + ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::FoldSelectedDevices(std::vector devices_selected, + std::vector& eps_selected) { + while (devices_selected.size() > 0) { + const auto ep_name = std::string(devices_selected[0]->ep_name); + SelectionInfo info; + info.ep_factory = devices_selected[0]->ep_factory; + + do { + auto iter = std::find_if(devices_selected.begin(), devices_selected.end(), [&ep_name](const OrtEpDevice* d) { + return d->ep_name == ep_name; + }); + + if (iter != devices_selected.end()) { + info.devices.push_back((*iter)->device); + info.ep_metadata.push_back(&(*iter)->ep_metadata); + devices_selected.erase(iter); + } else { + break; + } + } while (true); + + eps_selected.push_back(info); + } +} + +Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, + const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep) { + EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(info.ep_factory); + + if (internal_factory) { + // this is a factory we created and registered internally for internal and provider bridge EPs + OrtStatus* status = internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), + info.devices.size(), &options, &logger, + &ep); + if (status != nullptr) { + return ToStatus(status); + } + } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + // but first we need OrtEp and the OrtEpApi to be implemented. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + // OrtEp* api_ep = nullptr; + //// add the ep_options to session options but leave any existing entries (user provided overrides) untouched. 
+ // auto status = info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), info.ep_metadata.data(), + // info.devices.size(), &options, &logger, + // &api_ep); + // if (status != nullptr) { + // return ToStatus(status); + // } + } + + return Status::OK(); +} + +Status ProviderPolicyContext::AddEpDefaultOptionsToSession(InferenceSession& sess, + std::vector devices) { + auto& config_options = sess.GetMutableSessionOptions().config_options; + for (auto device : devices) { + const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(device->ep_name.c_str()); + for (const auto& [key, value] : device->ep_options.entries) { + const std::string option_key = ep_options_prefix + key; + // preserve user-provided options as they override any defaults the EP factory specified earlier + if (config_options.configurations.find(option_key) == config_options.configurations.end()) { + // use AddConfigEntry for the error checking it does + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry(option_key.c_str(), value.c_str())); + } + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::RemoveOrtCpuDevice(std::vector& devices) { + // Remove the Microsoft CPU EP. always last if available. + if (IsDefaultCpuEp(devices.back())) { + devices.pop_back(); + } +} + +void DefaultEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Default policy is prefer CPU + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferCpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first CPU device from sorted devices + auto first_cpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU; + }); + + ORT_ENFORCE(first_cpu != sorted_devices.end(), "No CPU based execution providers were found."); + selected.push_back(*first_cpu); + + // add ORT CPU EP as the final option to ensure maximum coverage of opsets and operators + if (!IsDefaultCpuEp(*first_cpu) && IsDefaultCpuEp(sorted_devices.back())) { + selected.push_back(sorted_devices.back()); + } +} + +void PreferNpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first NPU if there is one. 
+ if (sorted_devices.front()->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + selected.push_back(sorted_devices.front()); + } + + // CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferGpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first GPU device + auto first_gpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU; + }); + + if (first_gpu != sorted_devices.end()) { + selected.push_back(*first_gpu); + } + + // Add a CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h new file mode 100644 index 0000000000000..185f9523312ba --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.h @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/abi_session_options_impl.h" +#include "core/session/environment.h" +#include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy + +namespace onnxruntime { + +struct SelectionInfo { + OrtEpFactory* ep_factory; + std::vector devices; + std::vector ep_metadata; +}; + +class IEpPolicySelector { + public: + /// + /// Select the OrtEpDevice instances to use. + /// Selection is in priority order. Highest priority first. + /// + /// Ordered devices. + /// Type order is NPU -> GPU -> CPU + /// Within a type: Discrete -> Integrated if GPU, EP vendor matches device vendor, vendor does not match + /// ORT CPU EP is always last if available. 
+ /// + /// + virtual void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) = 0; + + virtual ~IEpPolicySelector() = default; +}; + +class ProviderPolicyContext { + public: + ProviderPolicyContext() = default; + + Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); + Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); + void RemoveOrtCpuDevice(std::vector& devices); + Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep); + void FoldSelectedDevices(std::vector devices_selected, // copy + std::vector& eps_selected); + + private: +}; + +class DefaultEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferCpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferNpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferGpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +} // namespace onnxruntime + +#endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index adb019fdde86d..d17514e54a945 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -18,6 +18,7 @@ #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/provider_policy_context.h" using namespace onnxruntime; #if !defined(ORT_MINIMAL_BUILD) @@ -71,6 +72,11 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return ToStatus(status); } } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + /* OrtEp* api_ep = nullptr; auto status = ep_device->ep_factory->CreateEp( ep_device->ep_factory, devices.data(), ep_metadata.data(), devices.size(), @@ -79,10 +85,7 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con if (status != nullptr) { return ToStatus(status); } - - // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, - // and we would add that IExecutionProvider to the InferenceSession. 
- ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + */ } ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); @@ -175,6 +178,12 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, if (auto_select_ep_name) { ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env->GetEnvironment(), *sess, *auto_select_ep_name)); } + + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection + if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env->GetEnvironment(), *options, *sess)); + } #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index 619f0a4bcda33..04b1b2ea0bdc4 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -64,7 +64,10 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod const std::vector& expected_dims_y, const std::vector& expected_values_y, bool auto_select = true, // auto select vs SessionOptionsAppendExecutionProvider_V2 + // manual select using functor const std::function&)>& select_devices = nullptr, + // auto select using policy + std::optional policy = std::nullopt, bool test_session_creation_only = false) { Ort::SessionOptions session_options; @@ -74,16 +77,20 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } if (auto_select) { - // manually specify EP to select for now - session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); - - // add the provider options to the session options with the required prefix - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); - std::vector keys, values; - ep_options.GetKeyValuePairs(keys, values); - for (size_t i = 0, end = keys.size(); i < end; ++i) { - // add the default value with prefix - session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + if (policy) { + session_options.SetEpSelectionPolicy(*policy); + } else { + // manually specify EP to select + session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); + + // add the provider options to the session options with the required prefix + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); + std::vector keys, values; + ep_options.GetKeyValuePairs(keys, values); + for (size_t i = 0, end = keys.size(); i < end; ++i) { + // add the default value with prefix + session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + } } } else { std::vector devices; @@ -188,7 +195,7 @@ TEST(AutoEpSelection, DmlEP) { devices.push_back(ep_device); } else { // if this is available, 0 == best performance - auto* perf_index = c_api->GetKeyValue(kvps, "HighPerformanceIndex"); + auto* perf_index = c_api->GetKeyValue(kvps, "DxgiHighPerformanceIndex"); if (perf_index && strcmp(perf_index, "0") == 0) { devices[0] = ep_device; // replace as this is the higher performance device } @@ -213,20 +220,27 @@ TEST(AutoEpSelection, WebGpuEP) { TEST(AutoEpSelection, MiscApiTests) { const OrtApi* c_api = &Ort::GetApi(); - // nullptr and empty input to OrtKeyValuePairs + // nullptr and empty input to 
OrtKeyValuePairs. also test RemoveKeyValuePair { OrtKeyValuePairs* kvps = nullptr; c_api->CreateKeyValuePairs(&kvps); c_api->AddKeyValuePair(kvps, "key1", nullptr); // should be ignored c_api->AddKeyValuePair(kvps, nullptr, "value1"); // should be ignored c_api->RemoveKeyValuePair(kvps, nullptr); // should be ignored - - c_api->AddKeyValuePair(kvps, "", "value2"); // empty key should be ignored + c_api->AddKeyValuePair(kvps, "", "value2"); // should be ignored ASSERT_EQ(c_api->GetKeyValue(kvps, ""), nullptr); + c_api->AddKeyValuePair(kvps, "key1", "value1"); c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is allowed ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + c_api->RemoveKeyValuePair(kvps, "key1"); + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + c_api->GetKeyValuePairs(kvps, &keys, &values, &num_entries); + ASSERT_EQ(num_entries, 1); + c_api->ReleaseKeyValuePairs(kvps); } @@ -259,6 +273,86 @@ TEST(AutoEpSelection, MiscApiTests) { } } +TEST(AutoEpSelection, PreferCpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_CPU); +} + +// this should fallback to CPU if no GPU +TEST(AutoEpSelection, PreferGpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_GPU); +} + +// this should fallback to CPU if no NPU +TEST(AutoEpSelection, PreferNpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU); +} + namespace { struct ExamplePluginInfo { const std::filesystem::path library_path = From c51d67b400e162b28c85c1dfe648635cb38714a3 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 3 May 2025 22:56:53 +1000 Subject: [PATCH 03/84] C# API updates for auto ep selection and 
compile API (#24604) ### Description C# API updates for auto ep selection and the compilation API. Also includes bugfix to OrtKeyValuePairs::Remove. ### Motivation and Context --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../CompileModel.shared.cs | 130 ++++++++ .../NativeCompileApiMethods.shared.cs | 152 ++++++++++ .../NativeMethods.shared.cs | 277 +++++++++++++++++- .../Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs | 63 ++++ .../OrtEpDevice.shared.cs | 98 +++++++ .../OrtHardwareDevice.shared.cs | 116 ++++++++ .../OrtKeyValuePairs.shared.cs | 192 ++++++++++++ .../OrtValue.shared.cs | 4 +- .../SessionOptions.shared.cs | 107 ++++++- .../CompileApiTests.cs | 67 +++++ .../OrtAutoEpTests.cs | 159 ++++++++++ .../OrtKeyValuePairsTests.cs | 39 +++ ...oft.ML.OnnxRuntime.Tests.NetCoreApp.csproj | 3 +- onnxruntime/core/session/ort_apis.h | 5 +- 14 files changed, 1402 insertions(+), 10 deletions(-) create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs create mode 100644 csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs new file mode 100644 index 0000000000000..9f42bf2247529 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Runtime.InteropServices; + + /// + /// This class is used to set options for model compilation, and to produce a compiled model using those options. + /// See https://onnxruntime.ai/docs/api/c/ for further details of various options. + /// + public class OrtModelCompilationOptions : SafeHandle + { + /// + /// Create a new OrtModelCompilationOptions object from SessionOptions. + /// + /// SessionOptions instance to read settings from. + public OrtModelCompilationOptions(SessionOptions sessionOptions) + : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtCreateModelCompilationOptionsFromSessionOptions( + OrtEnv.Instance().Handle, sessionOptions.Handle, out handle)); + } + + /// + /// Compile the model using the options set in this object. + /// + public void CompileModel() + { + NativeApiStatus.VerifySuccess(NativeMethods.CompileApi.OrtCompileModel(OrtEnv.Instance().Handle, handle)); + } + + + /// + /// Set the input model to compile. + /// + /// Path to ONNX model to compile. + public void SetInputModelPath(string path) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelPath(handle, platformPath)); + } + + /// + /// Set the input model to compile to be a byte array. + /// The input bytes are NOT copied and must remain valid while in use by ORT. 
+ /// + /// Input model bytes. + public void SetInputModelFromBuffer(byte[] buffer) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetInputModelFromBuffer( + handle, buffer, (UIntPtr)buffer.Length)); + } + + /// + /// Set the path to write the compiled ONNX model to. + /// + /// Path to write compiled model to. + public void SetOutputModelPath(string path) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(path); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelPath(handle, platformPath)); + + } + + /// + /// Set the path to a file to write initializers as external data to, + /// and the threshold that determines when to write an initializer to the external data file. + /// + /// Path to file to write external data to. + /// Size at which an initializer will be written to external data. + public void SetOutputModelExternalInitializersFile(string filePath, ulong threshold) + { + var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(filePath); + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + handle, platformPath, new UIntPtr(threshold))); + } + + // TODO: In order to use this to create an InferenceSession without copying bytes we need more infrastructure. + // - Need something that wraps the allocator, pointer and size and is SafeHandle based. + // - When it is disposed we need to use the allocator to release the native buffer. + // - Need the 4 InferenceSession ctors that take byte[] for the model to be duplicated to handle this new + // wrapper type. + // Due to that making this API internal so we can test it. We can make it public when the other infrastructure + // is in place as it will change the signature of the API. + internal void SetOutputModelBuffer(OrtAllocator allocator, + ref IntPtr outputModelBufferPtr, ref UIntPtr outputModelBufferSizePtr) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetOutputModelBuffer( + handle, allocator.Pointer, ref outputModelBufferPtr, ref outputModelBufferSizePtr)); + } + + /// + /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute + /// of EPContext nodes. + /// + /// Enable if true. Default is false. + public void SetEpContextEmbedMode(bool embed) + { + NativeApiStatus.VerifySuccess( + NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextEmbedMode(handle, embed)); + } + + internal IntPtr Handle => handle; + + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// Release the native instance of OrtModelCompilationOptions. + /// + /// true + protected override bool ReleaseHandle() + { + NativeMethods.CompileApi.OrtReleaseModelCompilationOptions(handle); + handle = IntPtr.Zero; + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs new file mode 100644 index 0000000000000..3a87f87d124e9 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
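Before moving on to the marshaling layer, here is a minimal sketch of how the OrtModelCompilationOptions class above is intended to be used. The file names are placeholders, and any EP-specific configuration is assumed to already have been applied to the SessionOptions instance.

using Microsoft.ML.OnnxRuntime;

// Minimal compile flow sketch: read settings from SessionOptions, point at an input model,
// choose an output path, embed the EP context binary data, and compile.
using var sessionOptions = new SessionOptions();

using var compileOptions = new OrtModelCompilationOptions(sessionOptions);
compileOptions.SetInputModelPath("model.onnx");        // placeholder input path
compileOptions.SetOutputModelPath("model_ctx.onnx");   // placeholder output path
compileOptions.SetEpContextEmbedMode(true);
compileOptions.CompileModel();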
+ +namespace Microsoft.ML.OnnxRuntime.CompileApi; + +using System; +using System.Runtime.InteropServices; + +// NOTE: The order of the APIs in this struct should match exactly that in OrtCompileApi +// See onnxruntime/core/session/compile_api.cc. +[StructLayout(LayoutKind.Sequential)] +public struct OrtCompileApi +{ + public IntPtr ReleaseModelCompilationOptions; + public IntPtr CreateModelCompilationOptionsFromSessionOptions; + public IntPtr ModelCompilationOptions_SetInputModelPath; + public IntPtr ModelCompilationOptions_SetInputModelFromBuffer; + public IntPtr ModelCompilationOptions_SetOutputModelPath; + public IntPtr ModelCompilationOptions_SetOutputModelExternalInitializersFile; + public IntPtr ModelCompilationOptions_SetOutputModelBuffer; + public IntPtr ModelCompilationOptions_SetEpContextEmbedMode; + public IntPtr CompileModel; +} + +internal class NativeMethods +{ + private static OrtCompileApi _compileApi; + + // + // Define the delegate signatures, and a static member for each to hold the marshaled function pointer. + // + // We populate the static members in the constructor of this class. + // + // The C# code will call the C++ API through the delegate instances in the static members. + // + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseModelCompilationOptions(IntPtr /* OrtModelCompilationOptions* */ options); + public DOrtReleaseModelCompilationOptions OrtReleaseModelCompilationOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateModelCompilationOptionsFromSessionOptions( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtSessionOptions* */ sessionOptions, + out IntPtr /* OrtModelCompilationOptions** */ outOptions); + public DOrtCreateModelCompilationOptionsFromSessionOptions + OrtCreateModelCompilationOptionsFromSessionOptions; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ inputModelPath); + public DOrtModelCompilationOptions_SetInputModelPath OrtModelCompilationOptions_SetInputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetInputModelFromBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const void* */ inputModelData, + UIntPtr /* size_t */ inputModelDataSize); + public DOrtModelCompilationOptions_SetInputModelFromBuffer + OrtModelCompilationOptions_SetInputModelFromBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelPath( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ outputModelPath); + public DOrtModelCompilationOptions_SetOutputModelPath OrtModelCompilationOptions_SetOutputModelPath; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile( + IntPtr /* OrtModelCompilationOptions* */ options, + byte[] /* const ORTCHAR_T* */ externalInitializersFilePath, + UIntPtr /* size_t */ externalInitializerSizeThreshold); + public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate 
IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetOutputModelBuffer( + IntPtr /* OrtModelCompilationOptions* */ options, + IntPtr /* OrtAllocator* */ allocator, + ref IntPtr /* void** */ outputModelBufferPtr, + ref UIntPtr /* size_t* */ outputModelBufferSizePtr); + public DOrtModelCompilationOptions_SetOutputModelBuffer OrtModelCompilationOptions_SetOutputModelBuffer; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextEmbedMode( + IntPtr /* OrtModelCompilationOptions* */ options, + bool embedEpContextInModel); + public DOrtModelCompilationOptions_SetEpContextEmbedMode OrtModelCompilationOptions_SetEpContextEmbedMode; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCompileModel( + IntPtr /* const OrtEnv* */ env, + IntPtr /* const OrtModelCompilationOptions* */ modelOptions); + public DOrtCompileModel OrtCompileModel; + + internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi) + { + +#if NETSTANDARD2_0 + IntPtr compileApiPtr = getCompileApi(); + _compileApi = (OrtCompileApi)Marshal.PtrToStructure(compileApiPtr, typeof(OrtCompileApi)); +#else + _compileApi = (OrtCompileApi)getCompileApi(); +#endif + + OrtReleaseModelCompilationOptions = + (DOrtReleaseModelCompilationOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.ReleaseModelCompilationOptions, + typeof(DOrtReleaseModelCompilationOptions)); + + OrtCreateModelCompilationOptionsFromSessionOptions = + (DOrtCreateModelCompilationOptionsFromSessionOptions)Marshal.GetDelegateForFunctionPointer( + _compileApi.CreateModelCompilationOptionsFromSessionOptions, + typeof(DOrtCreateModelCompilationOptionsFromSessionOptions)); + + OrtModelCompilationOptions_SetInputModelPath = + (DOrtModelCompilationOptions_SetInputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelPath, + typeof(DOrtModelCompilationOptions_SetInputModelPath)); + + OrtModelCompilationOptions_SetInputModelFromBuffer = + (DOrtModelCompilationOptions_SetInputModelFromBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetInputModelFromBuffer, + typeof(DOrtModelCompilationOptions_SetInputModelFromBuffer)); + + OrtModelCompilationOptions_SetOutputModelPath = + (DOrtModelCompilationOptions_SetOutputModelPath)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelPath, + typeof(DOrtModelCompilationOptions_SetOutputModelPath)); + + OrtModelCompilationOptions_SetOutputModelExternalInitializersFile = + (DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelExternalInitializersFile, + typeof(DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile)); + + OrtModelCompilationOptions_SetOutputModelBuffer = + (DOrtModelCompilationOptions_SetOutputModelBuffer)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetOutputModelBuffer, + typeof(DOrtModelCompilationOptions_SetOutputModelBuffer)); + + OrtModelCompilationOptions_SetEpContextEmbedMode = + (DOrtModelCompilationOptions_SetEpContextEmbedMode)Marshal.GetDelegateForFunctionPointer( + _compileApi.ModelCompilationOptions_SetEpContextEmbedMode, + typeof(DOrtModelCompilationOptions_SetEpContextEmbedMode)); + + OrtCompileModel = + (DOrtCompileModel)Marshal.GetDelegateForFunctionPointer( + _compileApi.CompileModel, + 
typeof(DOrtCompileModel)); + } +} diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 77c35aac65b92..620c13b8641b5 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -336,12 +336,43 @@ public struct OrtApi public IntPtr GetModelEditorApi; public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; public IntPtr SessionOptionsSetLoadCancellationFlag; + + public IntPtr GetCompileApi; + + public IntPtr CreateKeyValuePairs; + public IntPtr AddKeyValuePair; + public IntPtr GetKeyValue; + public IntPtr GetKeyValuePairs; + public IntPtr RemoveKeyValuePair; + public IntPtr ReleaseKeyValuePairs; + + public IntPtr RegisterExecutionProviderLibrary; + public IntPtr UnregisterExecutionProviderLibrary; + + public IntPtr GetEpDevices; + + public IntPtr SessionOptionsAppendExecutionProvider_V2; + public IntPtr SessionOptionsSetEpSelectionPolicy; + + public IntPtr HardwareDevice_Type; + public IntPtr HardwareDevice_VendorId; + public IntPtr HardwareDevice_Vendor; + public IntPtr HardwareDevice_DeviceId; + public IntPtr HardwareDevice_Metadata; + + public IntPtr EpDevice_EpName; + public IntPtr EpDevice_EpVendor; + public IntPtr EpDevice_EpMetadata; + public IntPtr EpDevice_EpOptions; + public IntPtr EpDevice_Device; } internal static class NativeMethods { static OrtApi api_; + static internal CompileApi.NativeMethods CompileApi; + #if NETSTANDARD2_0 [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtGetApi(UInt32 version); @@ -582,6 +613,85 @@ static NativeMethods() typeof(DReleaseLoraAdapter)); OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); + + OrtGetCompileApi = (DOrtGetCompileApi)Marshal.GetDelegateForFunctionPointer( + api_.GetCompileApi, typeof(DOrtGetCompileApi)); + + // populate the CompileApi struct now that we have the delegate to get the compile API pointer. 
+ CompileApi = new CompileApi.NativeMethods(OrtGetCompileApi); + + OrtCreateKeyValuePairs = (DOrtCreateKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.CreateKeyValuePairs, typeof(DOrtCreateKeyValuePairs)); + + OrtAddKeyValuePair = (DOrtAddKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.AddKeyValuePair, typeof(DOrtAddKeyValuePair)); + + OrtGetKeyValue = (DOrtGetKeyValue)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValue, typeof(DOrtGetKeyValue)); + + OrtGetKeyValuePairs = (DOrtGetKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.GetKeyValuePairs, typeof(DOrtGetKeyValuePairs)); + + OrtRemoveKeyValuePair = (DOrtRemoveKeyValuePair)Marshal.GetDelegateForFunctionPointer( + api_.RemoveKeyValuePair, typeof(DOrtRemoveKeyValuePair)); + + OrtReleaseKeyValuePairs = (DOrtReleaseKeyValuePairs)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseKeyValuePairs, typeof(DOrtReleaseKeyValuePairs)); + + OrtHardwareDevice_Type = (DOrtHardwareDevice_Type)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Type, typeof(DOrtHardwareDevice_Type)); + + OrtHardwareDevice_VendorId = (DOrtHardwareDevice_VendorId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_VendorId, typeof(DOrtHardwareDevice_VendorId)); + + OrtHardwareDevice_Vendor = (DOrtHardwareDevice_Vendor)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Vendor, typeof(DOrtHardwareDevice_Vendor)); + + OrtHardwareDevice_DeviceId = (DOrtHardwareDevice_DeviceId)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_DeviceId, typeof(DOrtHardwareDevice_DeviceId)); + + OrtHardwareDevice_Metadata = (DOrtHardwareDevice_Metadata)Marshal.GetDelegateForFunctionPointer( + api_.HardwareDevice_Metadata, typeof(DOrtHardwareDevice_Metadata)); + + + OrtEpDevice_EpName = (DOrtEpDevice_EpName)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpName, typeof(DOrtEpDevice_EpName)); + + OrtEpDevice_EpVendor = (DOrtEpDevice_EpVendor)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpVendor, typeof(DOrtEpDevice_EpVendor)); + + OrtEpDevice_EpMetadata = (DOrtEpDevice_EpMetadata)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpMetadata, typeof(DOrtEpDevice_EpMetadata)); + + OrtEpDevice_EpOptions = (DOrtEpDevice_EpOptions)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_EpOptions, typeof(DOrtEpDevice_EpOptions)); + + OrtEpDevice_Device = (DOrtEpDevice_Device)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_Device, typeof(DOrtEpDevice_Device)); + + OrtRegisterExecutionProviderLibrary = + (DOrtRegisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.RegisterExecutionProviderLibrary, + typeof(DOrtRegisterExecutionProviderLibrary)); + + OrtUnregisterExecutionProviderLibrary = + (DOrtUnregisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( + api_.UnregisterExecutionProviderLibrary, + typeof(DOrtUnregisterExecutionProviderLibrary)); + + OrtGetEpDevices = (DOrtGetEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetEpDevices, + typeof(DOrtGetEpDevices)); + + OrtSessionOptionsAppendExecutionProvider_V2 = + (DOrtSessionOptionsAppendExecutionProvider_V2)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsAppendExecutionProvider_V2, + typeof(DOrtSessionOptionsAppendExecutionProvider_V2)); + + OrtSessionOptionsSetEpSelectionPolicy = + (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicy, + typeof(DSessionOptionsSetEpSelectionPolicy)); } internal class NativeLib @@ -823,7 
+933,7 @@ internal class NativeLib IntPtr /* (OrtEnv*) */ environment, //[MarshalAs(UnmanagedType.LPStr)]string modelPath byte[] modelPath, - IntPtr /* (OrtSessionOptions*) */ sessopnOptions, + IntPtr /* (OrtSessionOptions*) */ sessionOptions, out IntPtr /**/ session); public static DOrtCreateSession OrtCreateSession; @@ -1350,7 +1460,7 @@ out IntPtr lora_adapter #endregion - #region RunOptions API +#region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -2153,7 +2263,168 @@ out IntPtr lora_adapter #endregion -#region Misc API +#region Compile API + +#if NETSTANDARD2_0 + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtGetCompileApi(); +#else + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate ref CompileApi.OrtCompileApi DOrtGetCompileApi(); +#endif + public static DOrtGetCompileApi OrtGetCompileApi; +#endregion + +#region Auto EP API related + // + // OrtKeyValuePairs + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key, + byte[] /* const char* */ value); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps, + out IntPtr /* const char* const** */ keys, + out IntPtr /* const char* const** */ values, + out UIntPtr /* size_t* */ numEntries); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, + byte[] /* const char* */ key); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps); + + + public static DOrtCreateKeyValuePairs OrtCreateKeyValuePairs; + public static DOrtAddKeyValuePair OrtAddKeyValuePair; + public static DOrtGetKeyValue OrtGetKeyValue; + public static DOrtGetKeyValuePairs OrtGetKeyValuePairs; + public static DOrtRemoveKeyValuePair OrtRemoveKeyValuePair; + public static DOrtReleaseKeyValuePairs OrtReleaseKeyValuePairs; + + + // + // OrtHardwareDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate int /* OrtHardwareDeviceType */ DOrtHardwareDevice_Type( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_VendorId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtHardwareDevice_Vendor( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate uint /* uint32_t */ DOrtHardwareDevice_DeviceId( + IntPtr /* const OrtHardwareDevice* */ device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtHardwareDevice_Metadata( + IntPtr /* const OrtHardwareDevice* */ device); + + + public static DOrtHardwareDevice_Type OrtHardwareDevice_Type; + 
public static DOrtHardwareDevice_VendorId OrtHardwareDevice_VendorId; + public static DOrtHardwareDevice_Vendor OrtHardwareDevice_Vendor; + public static DOrtHardwareDevice_DeviceId OrtHardwareDevice_DeviceId; + public static DOrtHardwareDevice_Metadata OrtHardwareDevice_Metadata; + + // + // OrtEpDevice + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpName(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const char* */ DOrtEpDevice_EpVendor(IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpMetadata( + IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtKeyValuePairs* */ DOrtEpDevice_EpOptions( + IntPtr /* const OrtEpDevice* */ ep_device); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtHardwareDevice* */ DOrtEpDevice_Device( + IntPtr /* const OrtEpDevice* */ ep_device); + + + public static DOrtEpDevice_EpName OrtEpDevice_EpName; + public static DOrtEpDevice_EpVendor OrtEpDevice_EpVendor; + public static DOrtEpDevice_EpMetadata OrtEpDevice_EpMetadata; + public static DOrtEpDevice_EpOptions OrtEpDevice_EpOptions; + public static DOrtEpDevice_Device OrtEpDevice_Device; + + // + // Auto Selection EP registration and selection customization + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name, + byte[] /* const ORTCHAR_T* */ path); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name); + + public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; + public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices( + IntPtr /* const OrtEnv* */ env, + out IntPtr /* const OrtEpDevice* const** */ ep_devices, + out UIntPtr /* size_t* */ num_ep_devices); + + public static DOrtGetEpDevices OrtGetEpDevices; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2( + IntPtr /* OrtSessionOptions* */ sess_options, + IntPtr /* OrtEnv* */ env, + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, + IntPtr /* const char* const* */ ep_option_keys, // use OrtKeyValuePairs.GetKeyValuePairHandles + IntPtr /* const char* const* */ ep_option_vals, + UIntPtr /* size_t */ num_ep_options); + + public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr DOrtEpSelectionDelegate( + IntPtr /* OrtEpDevice** */ epDevices, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadata, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadata, + IntPtr /* OrtEpDevice** */ selected, + uint maxSelected, + out UIntPtr numSelected + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* 
OrtStatus* */ DSessionOptionsSetEpSelectionPolicy( + IntPtr /* OrtSessionOptions* */ session_options, + int /* OrtExecutionProviderDevicePolicy */ policy, + IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate); + public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy; + + + #endregion + #region Misc API /// /// Queries all the execution providers supported in the native onnxruntime shared library diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index f4b2649f8d055..5c70808b82be1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime @@ -376,6 +377,68 @@ public OrtLoggingLevel EnvLogLevel } } + /// + /// Register an execution provider library with the OrtEnv instance. + /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. + /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can + /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. + /// + /// Coming: A selection policy can be specified and ORT will automatically select the best execution providers + /// and devices for the model. + /// + /// The name to register the library under. + /// The path to the library to register. + /// + /// + public void RegisterExecutionProviderLibrary(string registrationName, string libraryPath) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + var pathUtf8 = NativeOnnxValueHelper.GetPlatformSerializedString(libraryPath); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtRegisterExecutionProviderLibrary(handle, registrationNameUtf8, pathUtf8)); + } + + /// + /// Unregister an execution provider library from the OrtEnv instance. + /// + /// The name the library was registered under. + public void UnregisterExecutionProviderLibrary(string registrationName) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtUnregisterExecutionProviderLibrary(handle, registrationNameUtf8)); + } + + /// + /// Get the list of all execution provider and device combinations that are available. + /// These can be used to select the execution provider and device for a session. + /// + /// + /// + /// + public IReadOnlyList GetEpDevices() + { + IntPtr epDevicesPtr; + UIntPtr numEpDevices; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtGetEpDevices(handle, out epDevicesPtr, out numEpDevices)); + + int count = (int)numEpDevices; + var epDevices = new List(count); + + IntPtr[] epDevicePtrs = new IntPtr[count]; + Marshal.Copy(epDevicesPtr, epDevicePtrs, 0, count); + + foreach (var ptr in epDevicePtrs) + { + epDevices.Add(new OrtEpDevice(ptr)); + } + + return epDevices.AsReadOnly(); + } + #endregion #region SafeHandle overrides diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs new file mode 100644 index 0000000000000..e3947d900214e --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.OnnxRuntime +{ + /// + /// Represents the combination of an execution provider and a hardware device + /// that the execution provider can utilize. + /// + public class OrtEpDevice : SafeHandle + { + /// + /// Construct an OrtEpDevice from an existing native OrtEpDevice instance. + /// + /// Native OrtEpDevice handle. + internal OrtEpDevice(IntPtr epDeviceHandle) + : base(epDeviceHandle, ownsHandle: false) + { + } + + internal IntPtr Handle => handle; + + /// + /// The name of the execution provider. + /// + public string EpName + { + get + { + IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr); + } + } + + /// + /// The vendor who owns the execution provider. + /// + public string EpVendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Execution provider metadata. + /// + public OrtKeyValuePairs EpMetadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(handle)); + } + } + + /// + /// Execution provider options. + /// + public OrtKeyValuePairs EpOptions + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(handle)); + } + } + + /// + /// The hardware device that the execution provider can utilize. + /// + public OrtHardwareDevice HardwareDevice + { + get + { + IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(handle); + return new OrtHardwareDevice(devicePtr); + } + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// No-op. OrtEpDevice is always read-only as the instance is owned by native ORT. + /// + /// True + protected override bool ReleaseHandle() + { + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs new file mode 100644 index 0000000000000..8e7caae90ff79 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Runtime.InteropServices; + + /// + /// Represents the type of hardware device. + /// Matches OrtHardwareDeviceType in the ORT C API. + /// + public enum OrtHardwareDeviceType + { + CPU = 0, + GPU = 1, + NPU = 2, + } + + /// + /// Represents a hardware device that is available on the current system. + /// + public class OrtHardwareDevice : SafeHandle + { + + /// + /// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance. + /// + /// Native OrtHardwareDevice handle. + internal OrtHardwareDevice(IntPtr deviceHandle) + : base(deviceHandle, ownsHandle: false) + { + } + + /// + /// Get the type of hardware device. + /// + public OrtHardwareDeviceType Type + { + get + { + return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(handle); + } + } + + /// + /// Get the vendor ID of the hardware device if known. + /// + /// + /// For PCIe devices the vendor ID is the PCIe vendor ID. See https://pcisig.com/membership/member-companies. 
+ /// + public uint VendorId + { + get + { + return NativeMethods.OrtHardwareDevice_VendorId(handle); + } + } + + /// + /// The vendor (manufacturer) of the hardware device. + /// + public string Vendor + { + get + { + IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(handle); + return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); + } + } + + /// + /// Get the device ID of the hardware device if known. + /// + /// + /// This is the identifier of the device model. + /// PCIe device IDs can be looked up at https://www.pcilookup.com/ when combined with the VendorId. + /// It is NOT a unique identifier for the device in the current system. + /// + public uint DeviceId + { + get + { + return NativeMethods.OrtHardwareDevice_DeviceId(handle); + } + } + + /// + /// Get device metadata. + /// This may include information such as whether a GPU is discrete or integrated. + /// The available metadata will differ by platform and device type. + /// + public OrtKeyValuePairs Metadata + { + get + { + return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(handle)); + } + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// No-op. OrtHardwareDevice is always read-only as the instance is owned by native ORT. + /// + /// True + protected override bool ReleaseHandle() + { + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs new file mode 100644 index 0000000000000..6a8d1037d9017 --- /dev/null +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +namespace Microsoft.ML.OnnxRuntime +{ + using System; + using System.Collections.Generic; + using System.Runtime.InteropServices; + + /// + /// Class to manage key-value pairs. + /// These are most often used for options and metadata. + /// + /// + /// + /// + public class OrtKeyValuePairs : SafeHandle + { + private readonly bool _createdHandle; + + // cache the values here for convenience. + // we could force a call to the C API every time in case something was changed in the background. + private Dictionary _keyValuePairs; + + /// + /// Create a new OrtKeyValuePairs instance. + /// + /// + /// A backing native instance is created and kept in sync with the C# content. + /// + public OrtKeyValuePairs() + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(); + } + + /// + /// Create a new OrtKeyValuePairs instance from an existing native OrtKeyValuePairs handle. + /// + /// Native OrtKeyValuePairs handle. + /// + /// The instance is read-only, so calling Add or Remove will throw an InvalidOperationError. + /// + internal OrtKeyValuePairs(IntPtr constHandle) + : base(constHandle, ownsHandle: false) + { + _createdHandle = false; + _keyValuePairs = GetLatest(); + } + + /// + /// Create a new OrtKeyValuePairs instance from a dictionary. + /// + /// Key-value pairs to add. + /// + /// A backing native instance is created and kept in sync with the C# content. 
+ /// + public OrtKeyValuePairs(IReadOnlyDictionary keyValuePairs) + : base(IntPtr.Zero, ownsHandle: true) + { + NativeMethods.OrtCreateKeyValuePairs(out handle); + _createdHandle = true; + _keyValuePairs = new Dictionary(keyValuePairs != null ? keyValuePairs.Count : 0); + + if (keyValuePairs != null && keyValuePairs.Count > 0) + { + foreach (var kvp in keyValuePairs) + { + Add(kvp.Key, kvp.Value); + } + } + } + + /// + /// Current key-value pair entries. + /// + /// + /// Call Refresh() to update the cached values with the latest from the backing native instance. + /// In general that should not be required as it's not expected an OrtKeyValuePairs instance would be + /// updated by both native and C# code. + /// + public IReadOnlyDictionary Entries => _keyValuePairs; + + /// + /// Adds a key-value pair. Overrides any existing value for the key. + /// + /// Key to add. Must not be null or empty. + /// Value to add. May be empty. Must not be null. + public void Add(string key, string value) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Add)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + var valuePtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(value); + NativeMethods.OrtAddKeyValuePair(handle, keyPtr, valuePtr); + _keyValuePairs[key] = value; // update the cached value + } + + /// + /// Update the cached values with the latest from the backing native instance as that is the source of truth. + /// + public void Refresh() + { + // refresh the cached values. + _keyValuePairs = GetLatest(); + } + + /// + /// Removes a key-value pair by key. Ignores keys that do not exist. + /// + /// Key to remove. + public void Remove(string key) + { + if (!_createdHandle) + { + throw new InvalidOperationException($"{nameof(Remove)} can only be called on instances you created."); + } + + var keyPtr = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(key); + NativeMethods.OrtRemoveKeyValuePair(handle, keyPtr); + + _keyValuePairs.Remove(key); // update the cached value + } + + // for internal usage to pass into the call to OrtSessionOptionsAppendExecutionProvider_V2 + // from SessionOptions::AppendExecutionProvider + internal void GetKeyValuePairHandles(out IntPtr keysHandle, out IntPtr valuesHandle, out UIntPtr numEntries) + { + if (IsInvalid) + { + throw new InvalidOperationException($"{nameof(GetKeyValuePairHandles)}: Invalid instance."); + } + + NativeMethods.OrtGetKeyValuePairs(handle, out keysHandle, out valuesHandle, out numEntries); + } + + /// + /// Fetch all the key/value pairs to make sure we are in sync with the C API. + /// + private Dictionary GetLatest() + { + var dict = new Dictionary(); + if (IsInvalid) + { + return dict; + } + + IntPtr keys, values; + UIntPtr numEntries; + NativeMethods.OrtGetKeyValuePairs(handle, out keys, out values, out numEntries); + + ulong count = numEntries.ToUInt64(); + int offset = 0; + for (ulong i = 0; i < count; i++, offset += IntPtr.Size) + { + IntPtr keyPtr = Marshal.ReadIntPtr(keys, offset); + IntPtr valuePtr = Marshal.ReadIntPtr(values, offset); + var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyPtr); + var value = NativeOnnxValueHelper.StringFromNativeUtf8(valuePtr); + dict.Add(key, value); + } + + return dict; + } + + /// + /// Indicates whether the native handle is invalid. + /// + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// + /// Release the native instance of OrtKeyValuePairs if we own it. 
+ /// + /// true + protected override bool ReleaseHandle() + { + if (_createdHandle) + { + NativeMethods.OrtReleaseKeyValuePairs(handle); + handle = IntPtr.Zero; + } + + return true; + } + } +} \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index 7a5c3aaa19eac..f3c0287d2bf9d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -689,8 +689,8 @@ public static OrtValue CreateTensorValueFromMemory(T[] data, long[] shape) wh /// The method will attempt to pin managed memory so no copying occurs when data is passed down /// to native code. /// - /// Tensor object - /// discovered tensor element type + /// + /// Tensor object /// And instance of OrtValue constructed on top of the object [Experimental("SYSLIB5001")] public static OrtValue CreateTensorValueFromSystemNumericsTensorObject(SystemNumericsTensors.Tensor tensor) where T : unmanaged diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 9b0f183f03681..de6189e105f78 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -32,6 +32,21 @@ public enum ExecutionMode ORT_PARALLEL = 1, } + /// + /// Controls the execution provider selection when using automatic EP selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + public enum ExecutionProviderDevicePolicy + { + DEFAULT = 0, + PREFER_CPU = 1, + PREFER_NPU, + PREFER_GPU, + MAX_PERFORMANCE, + MAX_EFFICIENCY, + MIN_OVERALL_POWER, + } + /// /// Holds the options for creating an InferenceSession /// It forces the instantiation of the OrtEnv singleton. @@ -408,6 +423,82 @@ public void AppendExecutionProvider(string providerName, Dictionary + /// Select execution providers from the list of available execution providers and devices returned by + /// GetEpDevices. + /// + /// One or more OrtEpDevice instances may be provided in epDevices, but must all be for the same + /// execution provider. + /// + /// Make multiple calls to AppendExecutionProvider if you wish to use multiple execution providers. + /// + /// e.g. + /// - if execution provider 'A' has an OrtEpDevice for NPU and one for GPU and you wish to use it for + /// both devices, pass the two OrtEpDevice instances in the epDevices list in one call. + /// - if you wish to use execution provider 'B' for GPU and execution provider 'C' for CPU, + /// make two calls to AppendExecutionProvider, with one OrtEpDevice in the epDevices list in each call. + /// + /// The priority of the execution providers is set by the order in which they are appended. + /// Highest priority is first. + /// + /// OrtEnv that provided the OrtEpDevice instances via a call to GetEpDevices. + /// One or more OrtEpDevice instances to append. + /// These must all have the save EpName value. + /// Optional options to configure the execution provider. May be null. + /// epDevices was empty. 
+ /// + public void AppendExecutionProvider(OrtEnv env, IReadOnlyList epDevices, + IReadOnlyDictionary epOptions) + { + if (epDevices == null || epDevices.Count == 0) + { + throw new ArgumentException("No execution provider devices were specified."); + } + + // Convert EpDevices to native pointers + IntPtr[] epDevicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; i++) + { + epDevicePtrs[i] = epDevices[i].Handle; + } + + if (epOptions != null && epOptions.Count > 0) + { + // this creates an OrtKeyValuePairs instance with a backing native instance + using var kvps = new OrtKeyValuePairs(epOptions); + + // get the native key/value handles so we can pass those straight through to the C API + // and not have to do any special marshaling here. + IntPtr epOptionsKeys, epOptionsValues; + UIntPtr epOptionsCount; + kvps.GetKeyValuePairHandles(out epOptionsKeys, out epOptionsValues, out epOptionsCount); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + epOptionsKeys, + epOptionsValues, + epOptionsCount)); + } + else + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsAppendExecutionProvider_V2( + handle, + env.Handle, + epDevicePtrs, + (UIntPtr)epDevices.Count, + IntPtr.Zero, // EP options keys + IntPtr.Zero, // EP options values + UIntPtr.Zero)); // EP options count + } + + } + #endregion //ExecutionProviderAppends #region Public Methods @@ -452,8 +543,8 @@ public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHand // End result of that is // SessionOptions.RegisterCustomOpLibrary calls NativeMethods.OrtRegisterCustomOpsLibrary_V2 // SessionOptions.RegisterCustomOpLibraryV2 calls NativeMethods.OrtRegisterCustomOpsLibrary - var utf8Path = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); - NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, utf8Path, + var platformPath = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath); + NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, platformPath, out libraryHandle)); } @@ -536,6 +627,18 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) var utf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName); NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, utf8, dimValue)); } + + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Policy to use. + public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy, IntPtr.Zero)); + } + #endregion internal IntPtr Handle diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs new file mode 100644 index 0000000000000..72c165df56418 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Globalization; +using System.Runtime.InteropServices; +using Xunit; + + +public class CompileApiTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void BasicUsage() + { + var so = new SessionOptions(); + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.SetInputModelPath("model.onnx"); + compileOptions.SetOutputModelPath("compiled_model.onnx"); + + compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512); + compileOptions.SetEpContextEmbedMode(true); + + } + + // setup a new instance as SetOutputModelExternalInitializersFile is incompatible with SetOutputModelBuffer + using (var compileOptions = new OrtModelCompilationOptions(so)) + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + compileOptions.SetInputModelFromBuffer(model); + + // SetOutputModelBuffer updates the user provided IntPtr and size when it allocates data post-compile. + // Due to that we need to allocate an IntPtr and UIntPtr here. + IntPtr bytePtr = new IntPtr(); + UIntPtr bytesSize = new UIntPtr(); + var allocator = OrtAllocator.DefaultInstance; + compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize); + + compileOptions.CompileModel(); + + Assert.NotEqual(IntPtr.Zero, bytePtr); + Assert.NotEqual(UIntPtr.Zero, bytesSize); + + byte[] compiledBytes = new byte[bytesSize.ToUInt64()]; + Marshal.Copy(bytePtr, compiledBytes, 0, (int)bytesSize.ToUInt32()); + + // Check the compiled model is valid + using (var session = new InferenceSession(compiledBytes, so)) + { + Assert.NotNull(session); + } + + allocator.FreeMemory(bytePtr); + } + } +} + +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs new file mode 100644 index 0000000000000..1aa4db15d275c --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using System.IO; +using System.Runtime.InteropServices; +using Xunit; +using System.Collections.Generic; + +/// +/// Tests for auto ep selection/registration. +/// Includes testing of OrtHardwareDevice and OrtEpDevice as those only come from auto ep related code and we only +/// get read-only access to them (i.e. we can't directly create instances of them to test). 
+/// +public class OrtAutoEpTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private void ReadHardwareDeviceValues(OrtHardwareDevice device) + { + Assert.True(device.Type == OrtHardwareDeviceType.CPU || + device.Type == OrtHardwareDeviceType.GPU || + device.Type == OrtHardwareDeviceType.NPU); + if (device.Type == OrtHardwareDeviceType.CPU) + { + Assert.NotEmpty(device.Vendor); + } + else + { + Assert.True(device.VendorId != 0); + Assert.True(device.DeviceId != 0); + } + + var metadata = device.Metadata; + Assert.NotNull(metadata); + foreach (var kvp in metadata.Entries) + { + Assert.NotEmpty(kvp.Key); + // Assert.NotEmpty(kvp.Value); this is allowed + } + } + + [Fact] + public void GetEpDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + foreach (var ep_device in epDevices) + { + Assert.NotEmpty(ep_device.EpName); + Assert.NotEmpty(ep_device.EpVendor); + var metadata = ep_device.EpMetadata; + Assert.NotNull(metadata); + var options = ep_device.EpOptions; + Assert.NotNull(options); + ReadHardwareDeviceValues(ep_device.HardwareDevice); + } + } + + [Fact] + public void RegisterUnregisterLibrary() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + // register. shouldn't throw + ortEnvInstance.RegisterExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.True(found); + + // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + } + } + + [Fact] + public void AppendToSessionOptionsV2() + { + var runTest = (Func> getEpOptions) => + { + SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't break. + List selectedEpDevices = epDevices.Where(d => d.EpName == "CPUExecutionProvider").ToList(); + + Dictionary epOptions = getEpOptions(); + sessionOptions.AppendExecutionProvider(ortEnvInstance, selectedEpDevices, epOptions); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model)) + { + Assert.NotNull(session); + } + }; + + runTest(() => + { + // null options + return null; + }); + + runTest(() => + { + // empty options + return new Dictionary(); + }); + + runTest(() => + { + // dummy options + return new Dictionary + { + { "random_key", "value" }, + }; + }); + } + + [Fact] + public void SetEpSelectionPolicy() + { + SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. 
should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicy(ExecutionProviderDevicePolicy.PREFER_GPU); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model)) + { + Assert.NotNull(session); + } + } +} +#endif \ No newline at end of file diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs new file mode 100644 index 0000000000000..b89b970688d5f --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtKeyValuePairsTests.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using Xunit; + +namespace Microsoft.ML.OnnxRuntime.Tests; + +public class OrtKeyValuePairsTests +{ + private OrtEnv ortEnvInstance = OrtEnv.Instance(); + + + [Fact] + public void CRUD() + { + using var kvp = new OrtKeyValuePairs(); + kvp.Add("key1", "value1"); + kvp.Add("key2", "value2"); + kvp.Add("key3", ""); // allowed + + Assert.Equal("value1", kvp.Entries["key1"]); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + + kvp.Remove("key1"); + Assert.False(kvp.Entries.ContainsKey("key1")); + + kvp.Remove("invalid_key"); // shouldn't break + + Assert.Equal(2, kvp.Entries.Count); + + // refresh from the C API to make sure everything is in sync + kvp.Refresh(); + Assert.Equal(2, kvp.Entries.Count); + Assert.Equal("value2", kvp.Entries["key2"]); + Assert.Equal("", kvp.Entries["key3"]); + } +} diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj index a8abcd2b4aa1c..ee3c8c69aa2ae 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp.csproj @@ -70,7 +70,8 @@ + $(NativeBuildOutputDir)\custom_op_library*.dll; + $(NativeBuildOutputDir)\example_plugin_ep.dll"> PreserveNewest false diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 7928f9b822cf0..7be518a39480f 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -562,8 +562,9 @@ ORT_API(void, GetKeyValuePairs, _In_ const OrtKeyValuePairs* kvps, ORT_API(void, RemoveKeyValuePair, _In_ OrtKeyValuePairs* kvps, _In_ const char* key); ORT_API(void, ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs*); -ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* ep_name, const ORTCHAR_T* path); -ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* ep_name); +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, const char* registration_name, + const ORTCHAR_T* path); +ORT_API_STATUS_IMPL(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); ORT_API_STATUS_IMPL(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); From 6fa8ba107f461d72073ee4842830a9d0a96d2546 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Sat, 3 May 2025 14:33:57 -0700 Subject: [PATCH 04/84] Remove neural_compressor dependency 
in MatMulNBits (#24627) ### Description As titled. ### Motivation and Context Dependency no need. --- .../quantization/matmul_nbits_quantizer.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py index ef08b56cfe7ad..0297472d0738c 100644 --- a/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py @@ -8,7 +8,6 @@ import argparse import copy -import importlib import logging import os @@ -16,7 +15,6 @@ import numpy.typing as npt import onnx from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto -from packaging import version from onnxruntime.capi._pybind_state import quantize_matmul_4bits, quantize_matmul_8bits, quantize_qdq_matmul_4bits @@ -1356,21 +1354,7 @@ def process(self): self.model = ONNXModel(self.model) # Ensure the model is wrapped back into ONNXModel self.model.clean_initializers() else: - # use IntelĀ® Neural Compressor for RTN or GPTQ weight-only quantize algorithm - try: - importlib.import_module("neural_compressor") - except Exception as e: - logging.error(f"{e}.") - raise RuntimeError( - "neural-compressor is not correctly installed. Please check your environment." - ) from e - - import neural_compressor - - assert version.parse(neural_compressor.__version__) >= version.parse("2.3.2"), ( - "Require neural-compressor >= 2.3.2 to support weight only quantization!" - ) - + # RTN or GPTQ weight-only quantize algorithm self.int4_quant_algo() From ade008ed4c02bf3c6b2df96bf6d30d8ec6cc22ce Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Sun, 4 May 2025 00:57:40 -0700 Subject: [PATCH 05/84] [QNN EP] Enable automatic selection of QNN EP for PREFER_NPU policy (#24629) ### Description - Enables automatic selection of QNN EP for PREFER_NPU policy - Fixes cpuid vendor id for Qualcomm to be `'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24);` Sample code from unit test: ```c++ // Tests autoEP feature to automatically select an EP that supports the NPU. // Currently only works on Windows. TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider, ORT_TSTR("onnxruntime_providers_qnn.dll"))); Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_NPU); const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx"; Ort::Session session(*ort_env, ort_model_path, so); EXPECT_TRUE(SessionHasEp(session, kQnnExecutionProvider)); ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider)); } ``` ### Motivation and Context A recent feature allows ORT to automatically select an EP according to policies set by the user (e.g., prefer npu or prefer gpu). This PR allows QNN EP to be potentially selected when the user sets the `PREFER_NPU` policy. 
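For reference, a small standalone sketch (illustrative only, not part of this PR's diff) of why the corrected packing is the desired one: storing `'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)` in a 32-bit value places 'Q' in the lowest byte, so the id reads back as "QCOM" when viewed as little-endian bytes, whereas the previous `'Q' << 24 | ...` form spelled the characters in reverse byte order.

```c++
// Illustration only: check that the packed Qualcomm vendor id reads back as "QCOM"
// when the 32-bit value is inspected byte-by-byte on a little-endian machine.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t vendor_id = 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24);
  char bytes[4];
  std::memcpy(bytes, &vendor_id, sizeof(bytes));
  assert(std::memcmp(bytes, "QCOM", sizeof(bytes)) == 0);  // holds on little-endian targets
  return 0;
}
```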
--- onnxruntime/core/common/cpuid_info.cc | 2 +- .../providers/qnn/qnn_provider_factory.cc | 138 ++++++++++++++++++ onnxruntime/core/providers/qnn/symbols.def | 2 + .../test/providers/qnn/qnn_basic_test.cc | 59 +++++--- 4 files changed, 181 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index ee7782e3c8763..91961bf22ce1e 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -155,7 +155,7 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { if (vendor == "GenuineIntel") return 0x8086; if (vendor == "GenuineAMD") return 0x1022; - if (vendor.find("Qualcomm") == 0) return 'Q' << 24 | 'C' << 16 | 'O' << 8 | 'M'; + if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); if (vendor.find("NV") == 0) return 0x10DE; return 0; } diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 7b92a23e428eb..b2f289448b013 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -79,6 +79,27 @@ struct QNN_Provider : Provider { return std::make_shared(*provider_options, config_options); } + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t num_devices, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + if (num_devices != 1) { + return Status(common::ONNXRUNTIME, ORT_EP_FAIL, "QNN EP only supports one device."); + } + + const ConfigOptions* config_options = &session_options.GetConfigOptions(); + + std::array configs_array = {&provider_options, config_options}; + const void* arg = reinterpret_cast(&configs_array); + auto ep_factory = CreateExecutionProviderFactory(arg); + ep = ep_factory->CreateProvider(session_options, logger); + + return Status::OK(); + } + void Initialize() override {} void Shutdown() override {} } g_provider; @@ -93,4 +114,121 @@ ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } } + +#include "core/framework/error_code_helper.h" + +// OrtEpApi infrastructure to be able to use the QNN EP as an OrtEpFactory for auto EP selection. +struct QnnEpFactory : OrtEpFactory { + QnnEpFactory(const OrtApi& ort_api_in, + const char* ep_name, + OrtHardwareDeviceType hw_type, + const char* qnn_backend_type) + : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} { + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + + // Returns the name for the EP. Each unique factory configuration must have a unique name. + // Ex: a factory that supports NPU should have a different than a factory that supports GPU. + static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->ep_name.c_str(); + } + + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast(this_ptr); + return factory->vendor.c_str(); + } + + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. 
+ // An EP created with this factory is expected to be able to execute a model with *all* supported + // hardware devices at once. A single instance of QNN EP is not currently setup to partition a model among + // multiple different QNN backends at once (e.g, npu, cpu, gpu), so this factory instance is set to only + // support one backend: npu. To support a different backend, like gpu, create a different factory instance + // that only supports GPU. + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_options); + factory->ort_api.AddKeyValuePair(ep_options, "backend_type", factory->qnn_backend_type.c_str()); + ORT_API_RETURN_IF_ERROR( + factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; + } + + static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "QNN EP factory does not support this method."); + } + + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + // no-op as we never create an EP here. + } + + const OrtApi& ort_api; + const std::string ep_name; // EP name + const std::string vendor{"Microsoft"}; // EP vendor name + + // Qualcomm vendor ID. Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List + const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}; + const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice + const std::string qnn_backend_type; // QNN backend type for OrtHardwareDevice +}; + +extern "C" { +// +// Public symbols +// +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + + // Factory could use registration_name or define its own EP name. + auto factory_npu = std::make_unique(*ort_api, + onnxruntime::kQnnExecutionProvider, + OrtHardwareDeviceType_NPU, "htp"); + + // If want to support GPU, create a new factory instance because QNN EP is not currently setup to partition a single model + // among heterogeneous devices. + // std::unique_ptr factory_gpu = std::make_unique(*ort_api, "QNNExecutionProvider_GPU", OrtHardwareDeviceType_GPU, "gpu"); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. 
Need at least one."); + } + + factories[0] = factory_npu.release(); + *num_factories = 1; + + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} +} #endif // !BUILD_QNN_EP_STATIC_LIB diff --git a/onnxruntime/core/providers/qnn/symbols.def b/onnxruntime/core/providers/qnn/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/qnn/symbols.def +++ b/onnxruntime/core/providers/qnn/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index f736abcd3006d..0212dacadbced 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -16,6 +16,7 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "test/providers/qnn/qnn_test_utils.h" +#include "test/util/include/api_asserts.h" #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -37,24 +38,24 @@ namespace test { // TODO: When we need QNN in a minimal build we should add an ORT format version of the model #if !defined(ORT_MINIMAL_BUILD) +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + // Tests that the QNN EP is registered when added via the public C++ API. // Loads a simple ONNX model that adds floats. TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { - auto session_has_qnn_ep = [](Ort::Session& session) -> bool { - // Access the underlying InferenceSession. 
- const OrtSession* ort_session = session; - const InferenceSession* s = reinterpret_cast(ort_session); - bool have_qnn_ep = false; - - for (const auto& provider : s->GetRegisteredProviderTypes()) { - if (provider == kQnnExecutionProvider) { - have_qnn_ep = true; - break; - } - } - return have_qnn_ep; - }; - onnxruntime::ProviderOptions options; #if defined(_WIN32) options["backend_path"] = "QnnHtp.dll"; @@ -77,8 +78,9 @@ TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { so.AppendExecutionProvider("QNN", options); Ort::Session session(*ort_env, ort_model_path, so); - ASSERT_TRUE(session_has_qnn_ep(session)) << "QNN EP was not found in registered providers for session " - << "when added to session with name 'QNN'"; + ASSERT_TRUE(SessionHasEp(session, kQnnExecutionProvider)) + << "QNN EP was not found in registered providers for session " + << "providers for session when added to session with name 'QNN'"; } { @@ -92,8 +94,9 @@ TEST_F(QnnHTPBackendTests, TestAddEpUsingPublicApi) { so.AppendExecutionProvider(kQnnExecutionProvider, options); Ort::Session session(*ort_env, ort_model_path, so); - ASSERT_TRUE(session_has_qnn_ep(session)) << "QNN EP was not found in registered providers for session " - << "when added to session with name '" << kQnnExecutionProvider << "'"; + ASSERT_TRUE(SessionHasEp(session, kQnnExecutionProvider)) + << "QNN EP was not found in registered providers for session " + << "when added to session with name '" << kQnnExecutionProvider << "'"; } } @@ -1265,6 +1268,24 @@ TEST_F(QnnHTPBackendTests, LoadingAndUnloadingOfQnnLibrary_FixSegFault) { } #endif // !BUILD_QNN_EP_STATIC_LIB +#if defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB +// Tests autoEP feature to automatically select an EP that supports the NPU. +// Currently only works on Windows. +TEST_F(QnnHTPBackendTests, AutoEp_PreferNpu) { + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider, + ORT_TSTR("onnxruntime_providers_qnn.dll"))); + + Ort::SessionOptions so; + so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_NPU); + + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx"; + Ort::Session session(*ort_env, ort_model_path, so); + EXPECT_TRUE(SessionHasEp(session, kQnnExecutionProvider)); + + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, kQnnExecutionProvider)); +} +#endif // defined(WIN32) && !BUILD_QNN_EP_STATIC_LIB + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) From 74dcf7e296639095dfa55d31336998b6f719ed76 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 5 May 2025 15:40:22 +1000 Subject: [PATCH 06/84] =?UTF-8?q?Fix=20OrtEpDevices=20sort.=20Debug=20asse?= =?UTF-8?q?rtion=20when=20there=20are=20two=20devices=20of=20the=20same=20?= =?UTF-8?q?type=E2=80=A6=20(#24633)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Fix debug assertion when there are two devices of the same type that don't match the vendor. e.g. WebGPU and DML. 
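The comparator itself is not shown in full here, but the failure mode being fixed is the usual strict-weak-ordering violation: if a vendor-preference predicate can return true for both `comp(a, b)` and `comp(b, a)` (which happens when neither device matches the preferred vendor), MSVC's debug STL raises its "invalid comparator" assertion. A hypothetical sketch of the pattern and a well-formed alternative (not the actual ORT code):

```c++
// Hypothetical example only -- not the ORT comparator. Demonstrates how a
// vendor-preference predicate can violate strict weak ordering when neither
// device matches the preferred vendor.
#include <algorithm>
#include <string>
#include <vector>

struct Device { std::string vendor; };

int main() {
  const std::string preferred = "vendor_c";
  std::vector<Device> devices{{"vendor_a"}, {"vendor_b"}};

  // BAD: ignores b, so comp(a, b) and comp(b, a) are both true when neither
  // device matches the preferred vendor -> debug "invalid comparator" assertion.
  auto bad = [&](const Device& a, const Device& /*b*/) { return a.vendor != preferred; };

  // OK: a orders before b only when a matches the preferred vendor and b does not.
  auto good = [&](const Device& a, const Device& b) {
    return a.vendor == preferred && b.vendor != preferred;
  };

  (void)bad;  // kept only for contrast; sorting with it can assert in debug builds
  std::sort(devices.begin(), devices.end(), good);
  return 0;
}
```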
### Motivation and Context --- onnxruntime/core/session/provider_policy_context.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 565891fe2cdfb..4ce13fe36ea86 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -95,7 +95,7 @@ std::vector OrderDevices(const std::vector Date: Mon, 5 May 2025 18:19:45 -0700 Subject: [PATCH 07/84] fix checks for metal when running under wasm (#24637) When under wasm we can't check for metal by looking at backend because it will always be WEBGPU. Because of this we'll take the DP4A path on metal that results in sub-optimal performance. Use vendor to check for metal instead. --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 3 ++- .../webgpu/quantization/subgroup_matrix_matmul_nbits.cc | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index e3353c921094a..1411c07e97b2a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -513,8 +513,9 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, bool has_zero_points) { // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + // Use 'vendor' to check for metal; 'backend' is always WEBGPU when running under wasm. bool use_dp4a = context.HasFeature(wgpu::FeatureName::Subgroups) && - context.AdapterInfo().backendType != wgpu::BackendType::Metal; + context.AdapterInfo().vendor != std::string_view{"apple"}; return (accuracy_level == 4 && block_size % 32 == 0 && batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 && !has_zero_points && use_dp4a); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 09650be9358d0..674473a173445 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -248,8 +248,8 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, // FP322 is around 7s. 
- return context.AdapterInfo().backendType == wgpu::BackendType::Metal && - has_subgroup_matrix && + return has_subgroup_matrix && + context.AdapterInfo().vendor == std::string_view{"apple"} && accuracy_level == 4 && block_size == 32 && batch_count == 1 && From 8bf5362a24300f60f5a6344fb1f493a1d847e94c Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 5 May 2025 20:52:23 -0700 Subject: [PATCH 08/84] Enable use_vcpkg for QNN Nuget package build and Python arm64ec build (#24640) ### Description enable use_vcpkg for QNN Nuget package build and Python arm64ec build --- .../github/azure-pipelines/templates/py-win-arm64ec-qnn.yml | 2 +- tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 6df46bfc8e1b0..1a00d67bdbb2a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -91,7 +91,7 @@ jobs: --use_qnn --qnn_home $(QnnSDKRootDir) --enable_pybind - --parallel --update --arm64ec + --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --update --arm64ec $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index e4888ffd62df3..6bfc00b5b46eb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -20,7 +20,7 @@ stages: name: ${{ parameters.qnn_ep_build_pool_name }} variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} - commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_binskim_compliant_compile_flags ' + commonBuildArgs: '--compile_no_warning_as_error --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ' steps: - template: set-version-number-variables-step.yml From 1ef7b1b7af5b905441f190a9602e1613a618ffeb Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 5 May 2025 21:00:47 -0700 Subject: [PATCH 09/84] Publish debug symbols for windows (#24643) --- ...artifacts-package-and-publish-steps-windows.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index f15a2992e0d00..8c3a9eba82356 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -31,6 +31,20 @@ parameters: default: true steps: + - task: PublishSymbols@2 + displayName: 'Publish Build Debug Symbols' + condition: and(succeeded(), or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))) + inputs: + SymbolsFolder: 
'$(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\' + SearchPattern: '*.pdb' + IndexSources: true + PublishSymbols: true + SymbolServerType: 'TeamServices' + SymbolExpirationInDays: '36530' + IndexableFileFormats: 'Default' + DetailedLog: true + SymbolsArtifactName: 'Symbols_${{parameters.buildConfig}}' + - task: CmdLine@2 displayName: 'Copy build artifacts for zipping' inputs: From 9e7fcef4988ff3d53f75735d1d4ad7181a410912 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 5 May 2025 21:02:51 -0700 Subject: [PATCH 10/84] Add Python bindings for "compile" and "auto EP" APIs (#24614) ### Description Python API updates for auto ep selection and the compilation API. - Adds Python API `SessionOptions.add_provider()` (equivalent to C API's `SessionOptionsAppendExecutionProvider`) - Adds Python API `SessionOptions.add_provider_for_devices()` (equivalent to C API's `SessionOptionsAppendExecutionProvider_V2`) - Adds Python API `SessionOptions.set_provider_selection_policy()` (equivalent to C API's `SessionOptionsSetEpSelectionPolicy`) - Adds Python API class `ModelCompiler` to compile models (wraps C API's `OrtModelCompilationOptions` and `CompileModel()`) - TODO: Finish delegate callback. Need to add a `void*` parameter to delegate function. ### Sample program that uses autoep APIs Adapted from a unit test. ```python def test_cuda_prefer_gpu_and_inference(self): """ Test selecting CUDA EP via the PREFER_GPU policy and running inference. """ ep_lib_path = "onnxruntime_providers_cuda.dll" ep_registration_name = "CUDAExecutionProvider" if sys.platform != "win32": self.skipTest("Skipping test because device discovery is only supported on Windows") if not os.path.exists(ep_lib_path): self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") onnxrt.register_execution_provider_library(ep_registration_name, os.path.realpath(ep_lib_path)) # Set a policy to prefer GPU. Cuda should be selected. sess_options = onnxrt.SessionOptions() sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) self.assertTrue(sess_options.has_providers()) # Run sample model and check output sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) input_name = sess.get_inputs()[0].name res = sess.run([], {input_name: x}) output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) ``` ### Sample program that uses compile APIs Adapted from a unit test that compiles using EP selection policy. ```python def test_compile_with_files_prefer_npu_policy(self): """ Tests compiling a model (to/from files) using an EP selection policy (PREFER_NPU). 
""" ep_lib_path = "onnxruntime_providers_qnn.dll" ep_registration_name = "QNNExecutionProvider" onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path) input_model_path = get_name("nhwc_resize_scales_opset18.onnx") output_model_path = os.path.join(self._tmp_dir_path, "model.compiled0.onnx") session_options = onnxrt.SessionOptions() session_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU) model_compiler = onnxrt.ModelCompiler( session_options, input_model_path, embed_compiled_data_into_model=True, external_initializers_file_path=None, ) model_compiler.compile_to_file(output_model_path) self.assertTrue(os.path.exists(output_model_path)) onnxrt.unregister_execution_provider_library(ep_registration_name) ``` Adapted from a unit test that compiles using explicit EPs. ```python def test_compile_with_input_and_output_files(self): """ Tests compiling a model (to/from files) using explicit EP. """ provider = None provider_options = dict() if "QNNExecutionProvider" in available_providers: provider = "QNNExecutionProvider" provider_options["backend_type"] = "htp" # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) input_model_path = get_name("nhwc_resize_scales_opset18.onnx") output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") session_options = onnxrt.SessionOptions() if provider: session_options.add_provider(provider, provider_options) model_compiler = onnxrt.ModelCompiler( session_options, input_model_path, embed_compiled_data_into_model=True, external_initializers_file_path=None, ) model_compiler.compile_to_file(output_model_path) self.assertTrue(os.path.exists(output_model_path)) ``` ### Motivation and Context --- onnxruntime/__init__.py | 8 + onnxruntime/core/session/compile_api.cc | 32 +- .../core/session/model_compilation_options.cc | 11 +- .../core/session/model_compilation_options.h | 10 +- onnxruntime/core/session/onnxruntime_c_api.cc | 53 +- onnxruntime/core/session/utils.cc | 131 ++++- onnxruntime/core/session/utils.h | 32 ++ .../onnxruntime_inference_collection.py | 126 ++++- .../python/onnxruntime_pybind_exceptions.cc | 6 + .../python/onnxruntime_pybind_exceptions.h | 6 + .../onnxruntime_pybind_model_compiler.cc | 80 +++ .../onnxruntime_pybind_model_compiler.h | 78 +++ .../python/onnxruntime_pybind_state.cc | 468 ++++++++++++++++-- .../python/onnxruntime_pybind_state_common.h | 23 +- onnxruntime/test/python/autoep_helper.py | 67 +++ .../test/python/onnxruntime_test_python.py | 19 +- .../python/onnxruntime_test_python_autoep.py | 167 +++++++ .../python/onnxruntime_test_python_backend.py | 16 +- .../onnxruntime_test_python_compile_api.py | 185 +++++++ tools/ci_build/build.py | 9 + 20 files changed, 1376 insertions(+), 151 deletions(-) create mode 100644 onnxruntime/python/onnxruntime_pybind_model_compiler.cc create mode 100644 onnxruntime/python/onnxruntime_pybind_model_compiler.h create mode 100644 onnxruntime/test/python/autoep_helper.py create mode 100644 onnxruntime/test/python/onnxruntime_test_python_autoep.py create mode 100644 onnxruntime/test/python/onnxruntime_test_python_compile_api.py diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index f8e13f7629e6b..6ef0707f4b7c6 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -30,6 +30,10 @@ NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 + OrtEpDevice, # noqa: F401 + OrtExecutionProviderDevicePolicy, # noqa: F401 + OrtHardwareDevice, # noqa: F401 + 
OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 OrtMemType, # noqa: F401 OrtSparseFormat, # noqa: F401 @@ -44,11 +48,14 @@ get_available_providers, # noqa: F401 get_build_info, # noqa: F401 get_device, # noqa: F401 + get_ep_devices, # noqa: F401 get_version_string, # noqa: F401 has_collective_ops, # noqa: F401 + register_execution_provider_library, # noqa: F401 set_default_logger_severity, # noqa: F401 set_default_logger_verbosity, # noqa: F401 set_seed, # noqa: F401 + unregister_execution_provider_library, # noqa: F401 ) import_capi_exception = None @@ -64,6 +71,7 @@ AdapterFormat, # noqa: F401 InferenceSession, # noqa: F401 IOBinding, # noqa: F401 + ModelCompiler, # noqa: F401 OrtDevice, # noqa: F401 OrtValue, # noqa: F401 SparseTensor, # noqa: F401 diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index a3f6addd100ad..ad128fee6cc3d 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -8,11 +8,13 @@ #include #include "core/common/common.h" +#include "core/session/allocator_adapters.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" #include "core/session/model_compilation_options.h" #include "core/session/ort_apis.h" +#include "core/session/ort_env.h" #include "core/session/utils.h" #else #include "core/common/common.h" @@ -43,7 +45,8 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::CreateModelCompilationOptionsFromSessionOptio return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "The session_options argument must be a non-null pointer"); } - auto model_compile_options = std::make_unique(*env, *session_options); + auto model_compile_options = std::make_unique(env->GetEnvironment(), + *session_options); *out = reinterpret_cast(model_compile_options.release()); return nullptr; #else @@ -150,7 +153,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExterna ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelCompilationOptions* ort_model_compile_options, - _Inout_ OrtAllocator* allocator, void** output_model_data_ptr, size_t* output_model_data_size_ptr) { + _Inout_ OrtAllocator* ort_allocator, void** output_model_data_ptr, size_t* output_model_data_size_ptr) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); @@ -163,17 +166,18 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output model buffer: size pointer is null"); } - if (allocator == nullptr) { + if (ort_allocator == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid allocator for output model buffer: allocator pointer is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetOutputModelBuffer(allocator, + onnxruntime::AllocatorPtr allocator = std::make_shared(ort_allocator); + ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetOutputModelBuffer(std::move(allocator), output_model_data_ptr, output_model_data_size_ptr)); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); - ORT_UNUSED_PARAMETER(allocator); + ORT_UNUSED_PARAMETER(ort_allocator); ORT_UNUSED_PARAMETER(output_model_data_ptr); ORT_UNUSED_PARAMETER(output_model_data_size_ptr); return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); @@ -202,23 +206,7 @@ 
ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env, API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); - ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->Check()); - - std::unique_ptr session; - const OrtSessionOptions* session_options = &model_compile_options->GetSessionOptions(); - - if (model_compile_options->InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options->GetInputModelPath()); - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(session_options, env, - input_model_path.c_str(), - nullptr, 0, session)); - } else { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(session_options, env, nullptr, - model_compile_options->GetInputModelData(), - model_compile_options->GetInputModelDataSize(), session)); - } - - ORT_API_RETURN_IF_ERROR(InitializeSession(session_options, *session)); + ORT_API_RETURN_IF_STATUS_NOT_OK(onnxruntime::CompileModel(env->GetEnvironment(), *model_compile_options)); return nullptr; #else ORT_UNUSED_PARAMETER(env); diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index c4a7c5262d03d..d0cb092f78843 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -8,12 +8,12 @@ #include #include -#include "core/session/allocator_adapters.h" +#include "core/framework/allocator.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/ort_env.h" +#include "core/session/environment.h" namespace onnxruntime { -ModelCompilationOptions::ModelCompilationOptions(const OrtEnv& env, const OrtSessionOptions& session_options) +ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& env, const OrtSessionOptions& session_options) : env_(env), session_options_(session_options) { session_options_.value.has_explicit_ep_context_gen_options = true; session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); @@ -86,15 +86,14 @@ void ModelCompilationOptions::SetOutputModelExternalInitializersFile(const std:: external_initializer_size_threshold; } -Status ModelCompilationOptions::SetOutputModelBuffer(OrtAllocator* allocator, +Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) { ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = - std::make_shared(allocator); + session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); return Status::OK(); } diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 5ee64d48c3060..9238264003645 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -8,11 +8,13 @@ #include #include "core/common/status.h" #include "core/common/path_string.h" +#include "core/framework/allocator.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include 
"core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { +class Environment; /// /// Stores options to compile ONNX models into "EPContext" models. @@ -23,9 +25,9 @@ class ModelCompilationOptions { /// Creates an instance with the session options to use for model compilation. /// The session options are expected to have execution providers that compile. /// - /// Reference to OrtEnv + /// Reference to Environment /// Reference to session options - ModelCompilationOptions(const OrtEnv& env, const OrtSessionOptions& session_options); + ModelCompilationOptions(const onnxruntime::Environment& env, const OrtSessionOptions& session_options); /// /// Sets the file path to the input ONNX model to compile. @@ -67,7 +69,7 @@ class ModelCompilationOptions { /// Pointer to the buffer that will contain the compiled model /// Set to the size of the buffer /// Status indicating potential error - Status SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, + Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); /// @@ -122,7 +124,7 @@ class ModelCompilationOptions { Status CheckInputModelSettings() const; Status CheckOutputModelSettings() const; - const OrtEnv& env_; + const onnxruntime::Environment& env_; OrtSessionOptions session_options_; std::string input_model_path_; const void* input_model_data_ = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index c70075b234faf..304966605c9cf 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2465,49 +2465,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options) { API_IMPL_BEGIN - if (num_ep_devices > 1) { - const auto& ep_name = ep_devices[0]->ep_name; - bool all_match = std::all_of(ep_devices + 1, ep_devices + num_ep_devices, - [&ep_name](const OrtEpDevice* ep_device) { return ep_device->ep_name == ep_name; }); - if (!all_match) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "All OrtEpDevice values in ep_devices must have the same execution provider."); - } - } - - EpFactoryInternal* internal_factory = nullptr; - for (size_t i = 0; i < num_ep_devices; ++i) { - const OrtEpDevice* entry = ep_devices[i]; - - // we expect the internal factory to be available for internal and provider bridge EPs, which is all we support. - internal_factory = env->GetEnvironment().GetEpFactoryInternal(entry->ep_factory); - if (!internal_factory) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "EP is not currently supported by this API"); - } - - // add the options to the session options with the EP prefix. 
- // first add the default values with prefix followed by user specified values so those win - const auto prefix = OrtSessionOptions::GetProviderOptionPrefix(entry->ep_name.c_str()); - auto& config_options = session_options->value.config_options; - for (const auto& [key, value] : entry->ep_options.entries) { - ORT_API_RETURN_IF_STATUS_NOT_OK(config_options.AddConfigEntry((prefix + key).c_str(), value.c_str())); - } - - for (size_t j = 0; j < num_ep_options; ++j) { - if (ep_option_keys[j] == nullptr) { - continue; - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), - ep_option_vals[j])); - } - } - - if (internal_factory) { - session_options->provider_factories.push_back( - std::make_unique( - *internal_factory, std::vector(ep_devices, ep_devices + num_ep_devices))); - } + std::unique_ptr provider_factory = nullptr; + + ORT_API_RETURN_IF_STATUS_NOT_OK(CreateIExecutionProviderFactoryForEpDevices( + env->GetEnvironment(), + session_options->value, + gsl::span(ep_devices, num_ep_devices), + gsl::span(ep_option_keys, num_ep_options), + gsl::span(ep_option_vals, num_ep_options), + /*output*/ provider_factory)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; API_IMPL_END diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d17514e54a945..c05394039d8c7 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -10,15 +10,18 @@ #include "core/session/environment.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" -#include "core/session/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/session/onnxruntime_session_options_config_keys.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/ep_factory_internal.h" +#include "core/session/ep_library_plugin.h" +#include "core/session/ep_library_provider_bridge.h" +#include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" +#endif // !defined(ORT_MINIMAL_BUILD) using namespace onnxruntime; #if !defined(ORT_MINIMAL_BUILD) @@ -120,13 +123,14 @@ common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); } -// provider either model_path, or modal_data + model_data_length. -OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { +// Internal function that creates an InferenceSession and loads the model. +// Caller should provide either model_path, or modal_data + model_data_length. +static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* options, + const onnxruntime::Environment& env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { // quick check here to decide load path. InferenceSession will provide error message for invalid values. 
// TODO: Could move to a helper const Env& os_env = Env::Default(); // OS environment (!= ORT environment) @@ -155,12 +159,12 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, if (model_path != nullptr) { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), + env, model_path); } else { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), + env, model_data, static_cast(model_data_length)); } #else @@ -169,20 +173,20 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, } else { sess = std::make_unique( options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); + env); } #if !defined(ORT_MINIMAL_BUILD) // TEMPORARY for testing. Manually specify the EP to select. auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); if (auto_select_ep_name) { - ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env->GetEnvironment(), *sess, *auto_select_ep_name)); + ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env, *sess, *auto_select_ep_name)); } // if there are no providers registered, and there's an ep selection policy set, do auto ep selection if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { ProviderPolicyContext context; - ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env->GetEnvironment(), *options, *sess)); + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env, *options, *sess)); } #endif // !defined(ORT_MINIMAL_BUILD) @@ -209,6 +213,17 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, return nullptr; } +// Creates an InferenceSession and loads the model. +// Caller should provide either model_path, or modal_data + model_data_length. 
+OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + return CreateSessionAndLoadModelImpl(options, env->GetEnvironment(), model_path, model_data, model_data_length, sess); +} + OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { @@ -245,6 +260,28 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options) { + ORT_RETURN_IF_ERROR(model_compile_options.Check()); + + std::unique_ptr session; + const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); + + if (model_compile_options.InputModelComesFromFile()) { + PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, + input_model_path.c_str(), + nullptr, 0, session))); + } else { + ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, nullptr, + model_compile_options.GetInputModelData(), + model_compile_options.GetInputModelDataSize(), + session))); + } + + ORT_RETURN_IF_ERROR(ToStatus(InitializeSession(session_options, *session))); + return Status::OK(); +} + Status LoadPluginOrProviderBridge(const std::string& registration_name, const ORTCHAR_T* library_path, std::unique_ptr& ep_library, @@ -283,5 +320,65 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, return Status::OK(); } -#endif + +Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, + SessionOptions& session_options, + gsl::span ep_devices, + gsl::span ep_option_keys, + gsl::span ep_option_vals, + /*output*/ std::unique_ptr& out) { + if (ep_devices.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Must provide one or more OrtEpDevice instances."); + } + + const size_t num_ep_options = ep_option_keys.size(); + if (ep_option_vals.size() != num_ep_options) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Must provide the same number of keys and values for EP options."); + } + + const auto& ep_name = ep_devices[0]->ep_name; + bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), + [&ep_name](const OrtEpDevice* ep_device) { return ep_device->ep_name == ep_name; }); + if (!all_match) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "All OrtEpDevice values in ep_devices must have the same execution provider."); + } + + EpFactoryInternal* internal_factory = nullptr; + for (const OrtEpDevice* ep_device : ep_devices) { + // we expect the internal factory to be available for internal and provider bridge EPs, which is all we support. + internal_factory = env.GetEpFactoryInternal(ep_device->ep_factory); + if (!internal_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); + } + + // add the options to the session options with the EP prefix. 
+ // first add the default values with prefix followed by user specified values so those win + const std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_device->ep_name.c_str()); + auto& config_options = session_options.config_options; + for (const auto& [key, value] : ep_device->ep_options.entries) { + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + key).c_str(), value.c_str())); + } + + for (size_t j = 0; j < num_ep_options; ++j) { + if (ep_option_keys[j] == nullptr) { + continue; + } + + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); + } + } + + if (!internal_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EP is not currently supported by this API"); + } + + out = std::make_unique(*internal_factory, + std::vector(ep_devices.begin(), + ep_devices.end())); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 535a7b041609d..5a5dcae9165ed 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -3,7 +3,10 @@ #pragma once +#include +#include #include +#include #include "core/common/common.h" #include "core/session/onnxruntime_c_api.h" @@ -14,9 +17,18 @@ struct OrtStatus; struct OrtPrepackedWeightsContainer; namespace onnxruntime { class InferenceSession; +class ModelCompilationOptions; +} // namespace onnxruntime + +#if !defined(ORT_MINIMAL_BUILD) +namespace onnxruntime { +class Environment; class EpLibrary; class EpFactoryInternal; +struct IExecutionProviderFactory; +struct SessionOptions; } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, _In_ const OrtEnv* env, @@ -29,8 +41,18 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, _In_ onnxruntime::InferenceSession& sess, _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); +#if !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { +/// +/// Compiles an ONNX model into a model with EPContext nodes. Each EPContext node represents a subgraph compiled for +/// a specific execution provider. +/// +/// A reference to the Environment instance. +/// An object specifying the compilation options. +/// A Status indicating an error or success. +Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options); + // load a library that is added using RegisterExecutionProviderLibrary. // infer whether it's a provider bridge library or plugin library Status LoadPluginOrProviderBridge(const std::string& registration_name, @@ -38,4 +60,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, std::unique_ptr& ep_library, std::vector& internal_factories); +// Creates an IExecutionProviderFactory instance for a list of OrtEpDevices that all refer to the same EP. +// Adds all provider options to the OrtSessionOptions configuration. 
+Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, + SessionOptions& session_options, + gsl::span ep_devices, + gsl::span ep_options_keys, + gsl::span ep_options_vals, + /*output*/ std::unique_ptr& out); + } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index ed0298a85b8e7..e7dc4294f3672 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -172,10 +172,10 @@ class Session: This is the main class used to run a model. """ - def __init__(self): + def __init__(self, enable_fallback: bool = True): # self._sess is managed by the derived class and relies on bindings from C.InferenceSession self._sess = None - self._enable_fallback = True + self._enable_fallback = enable_fallback def get_session_options(self) -> onnxruntime.SessionOptions: "Return the session options. See :class:`onnxruntime.SessionOptions`." @@ -446,7 +446,7 @@ def __init__( means execute a node using `CUDAExecutionProvider` if capable, otherwise execute using `CPUExecutionProvider`. """ - super().__init__() + super().__init__(enable_fallback=int(kwargs.get("enable_fallback", 1)) == 1) if isinstance(path_or_bytes, (str, os.PathLike)): self._model_path = os.fspath(path_or_bytes) @@ -459,7 +459,6 @@ def __init__( self._sess_options = sess_options self._sess_options_initial = sess_options - self._enable_fallback = True if "read_config_from_model" in kwargs: self._read_config_from_model = int(kwargs["read_config_from_model"]) == 1 else: @@ -542,6 +541,16 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options, available_providers ) + # Print a warning if user passed providers to InferenceSession() but the SessionOptions instance + # already has provider information (e.g., via add_provider_for_devices()). The providers specified + # here will take precedence. + if self._sess_options is not None and (providers or provider_options) and self._sess_options.has_providers(): + warnings.warn( + "Specified 'providers'/'provider_options' when creating InferenceSession but SessionOptions has " + "already been configured with providers. InferenceSession will only use the providers " + "passed to InferenceSession()." + ) + session_options = self._sess_options if self._sess_options else C.get_default_session_options() self._register_ep_custom_ops(session_options, providers, provider_options, available_providers) @@ -609,6 +618,115 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options, C.register_nv_tensorrt_rtx_plugins_as_custom_ops(session_options, providers[i][1]) +class ModelCompiler: + """ + This class is used to compile an ONNX model. A compiled ONNX model has EPContext nodes that each + encapsulates a subgraph compiled/optimized for a specific execution provider. 
+ + Refer to the EPContext design document for more information about EPContext models: + https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html + + :: + + sess_options = onnxruntime.SessionOptions() + sess_options.add_provider("SomeExecutionProvider", {"option1": "value1"}) + # Alternatively, allow ONNX Runtime to select the provider automatically given a policy: + # sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU) + + model_compiler = onnxruntime.ModelCompiler(sess_options, "input_model.onnx") + model_compiler.compile_to_file("output_model.onnx") + """ + + def __init__( + self, + sess_options: onnxruntime.SessionOptions, + input_model_path_or_bytes: str | os.PathLike | bytes, + embed_compiled_data_into_model: bool = False, + external_initializers_file_path: str | os.PathLike | None = None, + external_initializers_size_threshold: int = 1024, + ): + """ + Creates a ModelCompiler instance. + + :param sess_options: Session options containing the providers for which the model will be compiled. + Refer to SessionOptions.add_provider() and SessionOptions.set_provider_selection_policy(). + :param input_model_path_or_bytes: The path to the input model file or bytes representing a serialized + ONNX model. + :param embed_compiled_data_into_model: Defaults to False. Set to True to embed compiled binary data into + EPContext nodes in the compiled model. + :param external_initializers_file_path: Defaults to None. Set to a path for a file that will store the + initializers for non-compiled nodes. + :param external_initializers_size_threshold: Defaults to 1024. Ignored if `external_initializers_file_path` + is None or empty. Initializers larger than this threshold are stored in the external initializers file. + """ + input_model_path: str | os.PathLike | None = None + input_model_bytes: bytes | None = None + if isinstance(input_model_path_or_bytes, (str, os.PathLike)): + if not input_model_path_or_bytes: + raise ValueError("Input model path is empty") + input_model_path = os.fspath(input_model_path_or_bytes) + elif isinstance(input_model_path_or_bytes, bytes): + if len(input_model_path_or_bytes) == 0: + raise ValueError("Input model bytes array is empty") + input_model_bytes = input_model_path_or_bytes + else: + raise TypeError(f"Unable to load from type '{type(input_model_path_or_bytes)}'") + + if external_initializers_file_path: + if not isinstance(external_initializers_file_path, (str, os.PathLike)): + arg_type = type(external_initializers_file_path) + raise TypeError(f"Output external initializer filepath is of unexpected type '{arg_type}'") + external_initializers_file_path = os.fspath(external_initializers_file_path) + else: + external_initializers_file_path = "" + + if input_model_path: + self._model_compiler = C.ModelCompiler( + sess_options, + input_model_path, + True, # is path + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold, + ) + else: + self._model_compiler = C.ModelCompiler( + sess_options, + input_model_bytes, + False, # is bytes + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold, + ) + + def compile_to_file(self, output_model_path: str | None = None): + """ + Compiles to an output file. If an output file path is not provided, + the output file path is generated based on the input model path by replacing + '.onnx' with '_ctx.onnx'. 
Ex: The generated output file is 'model_ctx.onnx' for + an input model with path 'model.onnx'. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :param output_model_path: Defaults to None. The path for the output/compiled model. + """ + if output_model_path: + if not isinstance(output_model_path, (str, os.PathLike)): + raise TypeError(f"Output model's filepath is of unexpected type '{type(output_model_path)}'") + output_model_path = os.fspath(output_model_path) + self._model_compiler.compile_to_file(output_model_path) + + def compile_to_bytes(self) -> bytes: + """ + Compiles to bytes representing the serialized compiled ONNX model. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :return: A bytes object representing the compiled ONNX model. + """ + return self._model_compiler.compile_to_bytes() + + class IOBinding: """ This class provides API to bind input/output to a specified device, e.g. GPU. diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 1a1aae6f48ad1..8f3b97c8c7786 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -35,6 +35,8 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "NotImplemented"); pybind11::register_exception(m, "InvalidGraph"); pybind11::register_exception(m, "EPFail"); + pybind11::register_exception(m, "ModelLoadCanceled"); + pybind11::register_exception(m, "ModelRequiresCompilation"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -61,6 +63,10 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw InvalidGraph(std::move(msg)); case onnxruntime::common::StatusCode::EP_FAIL: throw EPFail(std::move(msg)); + case onnxruntime::common::StatusCode::MODEL_LOAD_CANCELED: + throw ModelLoadCanceled(std::move(msg)); + case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: + throw ModelRequiresCompilation(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index bc7d31ff2be2d..86bc4a5da8d46 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -44,6 +44,12 @@ struct InvalidGraph : std::runtime_error { struct EPFail : std::runtime_error { explicit EPFail(const std::string& what) : std::runtime_error(what) {} }; +struct ModelLoadCanceled : std::runtime_error { + explicit ModelLoadCanceled(const std::string& what) : std::runtime_error(what) {} +}; +struct ModelRequiresCompilation : std::runtime_error { + explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc new file mode 100644 index 0000000000000..8bb7ee2098caf --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. 
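As a quick illustration of the new Python compile API shown above, the following is a minimal usage sketch; the provider name, provider options, and file paths are placeholders rather than part of this change::

    import onnxruntime as onnxrt

    sess_options = onnxrt.SessionOptions()
    # Hypothetical provider/options; use a provider available in your ORT build.
    sess_options.add_provider("QNNExecutionProvider", {"backend_path": "QnnHtp.dll"})

    # Compile from a file path and write the compiled EPContext model to disk.
    compiler = onnxrt.ModelCompiler(sess_options, "model.onnx", embed_compiled_data_into_model=True)
    compiler.compile_to_file("model_ctx.onnx")

    # Or compile an in-memory model and receive the compiled model as bytes.
    with open("model.onnx", "rb") as f:
        compiled_bytes = onnxrt.ModelCompiler(sess_options, f.read()).compile_to_bytes()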
+#include "python/onnxruntime_pybind_model_compiler.h" + +#include +#include +#include +#include "core/common/common.h" +#include "core/framework/error_code_helper.h" +#include "core/session/utils.h" + +namespace onnxruntime { +namespace python { + +onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr& out, + std::shared_ptr env, + const PySessionOptions& sess_options, + std::string&& input_model_path_or_bytes, bool input_model_is_path, + bool embed_compiled_data_into_model, + const std::string& external_initializers_file_path, + size_t external_initializers_size_threshold) { + auto model_compiler = std::make_unique(env, sess_options, PrivateConstructorTag{}); + ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; + + if (input_model_is_path) { + compile_options.SetInputModelPath(input_model_path_or_bytes); + } else { + model_compiler->input_model_bytes_ = std::move(input_model_path_or_bytes); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_compiler->input_model_bytes_.data()), + model_compiler->input_model_bytes_.size()); + } + + ORT_RETURN_IF_ERROR(compile_options.SetEpContextEmbedMode(embed_compiled_data_into_model)); + + if (!external_initializers_file_path.empty()) { + compile_options.SetOutputModelExternalInitializersFile(external_initializers_file_path, + external_initializers_size_threshold); + } + + out = std::move(model_compiler); + return Status::OK(); +} + +onnxruntime::Status PyModelCompiler::CompileToFile(const std::string& output_model_path) { + ORT_RETURN_IF_ERROR(model_compile_options_.SetOutputModelPath(output_model_path)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + return Status::OK(); +} + +onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) { + if (!output_buffer.empty()) { + // Opt to return an error if the output buffer is not empty instead of just calling output_buffer.clear() + // because the C++ standard does not explicitly require that capacity is unchanged by a call to clear(). + // Don't want to reallocate a large buffer an extra time unnecessarily. So, we'll consider this an internal + // ORT error. + // Refer to: https://en.cppreference.com/w/cpp/string/basic_string/clear + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output buffer should be empty."); + } + + onnxruntime::AllocatorPtr allocator = std::make_shared(); + + void* buffer_data = nullptr; + size_t buffer_size = 0; + ORT_RETURN_IF_ERROR(model_compile_options_.SetOutputModelBuffer(allocator, &buffer_data, &buffer_size)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + + // Copy into output buffer. + output_buffer.reserve(buffer_size); + gsl::span src(reinterpret_cast(buffer_data), buffer_size); + std::copy(src.begin(), src.end(), std::back_inserter(output_buffer)); + return Status::OK(); +} + +PyModelCompiler::PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, + PrivateConstructorTag) + : env_(env), model_compile_options_(*env, sess_options) { +} +} // namespace python +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h new file mode 100644 index 0000000000000..6c9f48fa00ba6 --- /dev/null +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// Licensed under the MIT License. +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/common/status.h" +#include "core/session/model_compilation_options.h" +#include "python/onnxruntime_pybind_state_common.h" + +namespace onnxruntime { +class Environment; + +namespace python { +/// +/// Class exposed to Python that enables compiling ONNX models. +/// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. +/// +class PyModelCompiler { + private: + // private tag to pass to constructor to ensure that constructor cannot be directly called externally + struct PrivateConstructorTag {}; + + public: + /// + /// Static class function that creates a unique_ptr with the given settings. + /// + /// Output parameter for the result + /// The Environment instance + /// The SessionOptions from which to initialize compilation options. + /// An r-value string that could be the input model's path or bytes + /// True if 'input_model_path_or_bytes' is a path, and false if its bytes. + /// True to embed compiled binary data into EPContext nodes. + /// The file into which to store initializers for non-compiled + /// nodes. + /// Ignored if 'external_initializers_file_path' is empty. + /// Initializers with a size greater than this threshold are dumped into the external file. + /// A Status indicating error or success. + static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, + std::shared_ptr env, + const PySessionOptions& sess_options, + std::string&& input_model_path_or_bytes, bool input_model_is_path, + bool embed_compiled_data_into_model = false, + const std::string& external_initializers_file_path = {}, + size_t external_initializers_size_threshold = 1024); + + // Note: Creation should be done via Create(). This constructor is public so that it can be called from + // std::make_shared(). + PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, + PrivateConstructorTag); + + /// + /// Compiles the input model and saves the result to an output file. + /// If the 'output_model_path' is not specified, + /// it is generated based on the input model's path by replacing '.onnx' with '_ctx.onnx'. + /// + /// The path into which to save the compiled model. + /// A Status indicating error or success. + onnxruntime::Status CompileToFile(const std::string& output_model_path = {}); + + /// + /// Compiles the input model and stores the result into a buffer. + /// + /// A reference to the output buffer into which to store the + /// serialized ONNX model bytes. + /// A Status indicating error or success. + onnxruntime::Status CompileToBytes(std::string& output_buffer); + + private: + std::shared_ptr env_; + onnxruntime::ModelCompilationOptions model_compile_options_; + std::string input_model_bytes_; +}; +} // namespace python +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 0f15c5fbbdba0..c29a8b0497b1f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2,10 +2,15 @@ // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
+#include #include "python/onnxruntime_pybind_exceptions.h" #include "python/onnxruntime_pybind_mlvalue.h" #include "python/onnxruntime_pybind_state_common.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "python/onnxruntime_pybind_model_compiler.h" +#endif // !defined(ORT_MINIMAL_BUILD) + #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API #include "python/numpy_helper.h" @@ -26,6 +31,7 @@ #include "core/graph/graph_viewer.h" #include "core/platform/env.h" #include "core/providers/get_execution_providers.h" +#include "core/providers/providers.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" #include "core/session/IOBinding.h" #include "core/session/abi_session_options_impl.h" @@ -34,6 +40,13 @@ #include "core/session/lora_adapters.h" +#if !defined(ORT_MINIMAL_BUILD) +#include "core/session/abi_devices.h" +#include "core/session/ep_factory_internal.h" +#include "core/session/provider_policy_context.h" +#include "core/session/utils.h" +#endif + #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op_executor.h" #endif @@ -402,7 +415,7 @@ py::object AddTensorAsPyObj(const OrtValue& val, const DataTransferManager* data return GetPyObjFromTensor(val, data_transfer_manager, mem_cpy_to_host_functions); } -static std::unique_ptr LoadExecutionProvider( +static std::shared_ptr LoadExecutionProviderFactory( const std::string& ep_shared_lib_path, const ProviderOptions& provider_options = {}, const std::string& entry_symbol_name = "GetProvider") { @@ -417,8 +430,7 @@ static std::unique_ptr LoadExecutionProvider( OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider)); Provider* provider = PGetProvider(); - std::shared_ptr ep_factory = provider->CreateExecutionProviderFactory(&provider_options); - return ep_factory->CreateProvider(); + return provider->CreateExecutionProviderFactory(&provider_options); } #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) @@ -539,14 +551,21 @@ void RegisterNvTensorRTRtxPluginsAsCustomOps(PySessionOptions& so, const Provide } #endif -std::unique_ptr CreateExecutionProviderInstance( +/** + * Creates an IExecutionProviderFactory instance of the specified type. + * @param session_options The session options. + * @param type The execution provider type (e.g., CUDAExecutionProvider). + * @param provider_options_map A map of provider options. + * + * @return A shared_ptr with the factory instance, or null if unable to create it. + */ +static std::shared_ptr CreateExecutionProviderFactoryInstance( const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map) { if (type == kCpuExecutionProvider) { return onnxruntime::CPUProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) - ->CreateProvider(); + session_options.enable_cpu_mem_arena); } else if (type == kTensorrtExecutionProvider) { #if defined(USE_TENSORRT) || defined(USE_TENSORRT_PROVIDER_INTERFACE) // If the environment variable 'ORT_TENSORRT_UNAVAILABLE' exists, then we do not load TensorRT. 
This is set by _ld_preload for the manylinux case @@ -869,11 +888,11 @@ std::unique_ptr CreateExecutionProviderInstance( } } if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) { - return tensorrt_provider_factory->CreateProvider(); + return tensorrt_provider_factory; } } else { if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(cuda_device_id)) { - return tensorrt_provider_factory->CreateProvider(); + return tensorrt_provider_factory; } } } @@ -892,11 +911,11 @@ std::unique_ptr CreateExecutionProviderInstance( ProviderOptions info = it->second; if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create( info, &session_options)) { - return nv_tensorrt_rtx_provider_factory->CreateProvider(); + return nv_tensorrt_rtx_provider_factory; } } else { if (std::shared_ptr nv_tensorrt_rtx_provider_factory = onnxruntime::NvProviderFactoryCreator::Create(cuda_device_id)) { - return nv_tensorrt_rtx_provider_factory->CreateProvider(); + return nv_tensorrt_rtx_provider_factory; } } } @@ -1024,12 +1043,12 @@ std::unique_ptr CreateExecutionProviderInstance( } if (std::shared_ptr migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(¶ms)) { - return migraphx_provider_factory->CreateProvider(); + return migraphx_provider_factory; } } else { if (std::shared_ptr migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) { - return migraphx_provider_factory->CreateProvider(); + return migraphx_provider_factory; } } #endif @@ -1048,7 +1067,7 @@ std::unique_ptr CreateExecutionProviderInstance( // hence we must try to initialize it here if we can since FromProviderOptions might contain // external CUDA allocator. external_allocator_info = info.external_allocator_info; - return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return cuda_provider_info->CreateExecutionProviderFactory(info); } } #if defined(USE_CUDA) @@ -1081,7 +1100,7 @@ std::unique_ptr CreateExecutionProviderInstance( // however they still exist and are in-use. Nevertheless, it is used to return ROCMAllocator, hence we must // try to initialize it here if we can since FromProviderOptions might contain external ROCM allocator. 
external_allocator_info = info.external_allocator_info; - return rocm_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return rocm_provider_info->CreateExecutionProviderFactory(info); } else { if (!Env::Default().GetEnvironmentVar("ROCM_PATH").empty()) { ORT_THROW( @@ -1118,7 +1137,7 @@ std::unique_ptr CreateExecutionProviderInstance( #endif // !defined(DNNL_ORT_THREAD) dnnl_options.use_arena = session_options.enable_cpu_mem_arena; - return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options)->CreateProvider(); + return onnxruntime::DnnlProviderFactoryCreator::Create(&dnnl_options); #endif } else if (type == kOpenVINOExecutionProvider) { #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) @@ -1189,10 +1208,9 @@ std::unique_ptr CreateExecutionProviderInstance( } if (std::shared_ptr openvino_provider_factory = onnxruntime::OpenVINOProviderFactoryCreator::Create( &OV_provider_options_map, &session_options)) { - auto p = openvino_provider_factory->CreateProvider(); // Reset global variables config to avoid it being accidentally passed on to the next session openvino_device_type.clear(); - return p; + return openvino_provider_factory; } else { if (!Env::Default().GetEnvironmentVar("INTEL_OPENVINO_DIR").empty()) { ORT_THROW("INTEL_OPENVINO_DIR is set but OpenVINO library wasn't able to be loaded. Please install a supported version of OpenVINO as mentioned in the requirements page (https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#requirements), ensure dependency libraries are in the PATH and your hardware is supported."); @@ -1210,7 +1228,7 @@ std::unique_ptr CreateExecutionProviderInstance( } info["session_options"] = std::to_string((uintptr_t)(void*)&session_options); if (auto vitisai_factory = onnxruntime::VitisAIProviderFactoryCreator::Create(info); vitisai_factory) { - return vitisai_factory->CreateProvider(); + return vitisai_factory; } LOGS_DEFAULT(WARNING) << "Failed to create " << type @@ -1238,21 +1256,18 @@ std::unique_ptr CreateExecutionProviderInstance( } } } - return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math) - ->CreateProvider(); + return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math); #endif } else if (type == kArmNNExecutionProvider) { #ifdef USE_ARMNN return onnxruntime::ArmNNProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) - ->CreateProvider(); + session_options.enable_cpu_mem_arena); #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML auto cit = provider_options_map.find(type); return onnxruntime::DMLProviderFactoryCreator::CreateFromProviderOptions( - session_options.config_options, cit == provider_options_map.end() ? ProviderOptions{} : cit->second, true) - ->CreateProvider(); + session_options.config_options, cit == provider_options_map.end() ? 
ProviderOptions{} : cit->second, true); #endif } else if (type == kNnapiExecutionProvider) { #if defined(USE_NNAPI) @@ -1261,15 +1276,15 @@ std::unique_ptr CreateExecutionProviderInstance( #endif const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); - return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider(); + return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list); #endif } else if (type == kVSINPUExecutionProvider) { #ifdef USE_VSINPU - return onnxruntime::VSINPUProviderFactoryCreator::Create()->CreateProvider(); + return onnxruntime::VSINPUProviderFactoryCreator::Create(); #endif } else if (type == kRknpuExecutionProvider) { #ifdef USE_RKNPU - return onnxruntime::RknpuProviderFactoryCreator::Create()->CreateProvider(); + return onnxruntime::RknpuProviderFactoryCreator::Create(); #endif } else if (type == kCoreMLExecutionProvider) { #if defined(USE_COREML) @@ -1300,36 +1315,35 @@ std::unique_ptr CreateExecutionProviderInstance( } } else { // read from provider_options - return onnxruntime::CoreMLProviderFactoryCreator::Create(options)->CreateProvider(); + return onnxruntime::CoreMLProviderFactoryCreator::Create(options); } } - return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); + return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags); #endif } else if (type == kXnnpackExecutionProvider) { #if defined(USE_XNNPACK) auto cit = provider_options_map.find(type); return onnxruntime::XnnpackProviderFactoryCreator::Create( - cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options) - ->CreateProvider(); + cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options); #endif } else if (type == kWebGpuExecutionProvider) { #if defined(USE_WEBGPU) - return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider(); + return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN if (auto* cann_provider_info = TryGetProviderInfo_CANN()) { const CANNExecutionProviderInfo info = GetCannExecutionProviderInfo(cann_provider_info, provider_options_map); - return cann_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); + return cann_provider_info->CreateExecutionProviderFactory(info); } else { ORT_THROW("create CANN ExecutionProvider fail"); } #endif } else if (type == kAzureExecutionProvider) { #ifdef USE_AZURE - return onnxruntime::AzureProviderFactoryCreator::Create({})->CreateProvider(); + return onnxruntime::AzureProviderFactoryCreator::Create({}); #endif } else if (type == kQnnExecutionProvider) { #if defined(USE_QNN) || defined(USE_QNN_PROVIDER_INTERFACE) @@ -1337,7 +1351,7 @@ std::unique_ptr CreateExecutionProviderInstance( auto qnn_factory = onnxruntime::QNNProviderFactoryCreator::Create( cit == provider_options_map.end() ? 
ProviderOptions{} : cit->second, &session_options); if (qnn_factory) { - return qnn_factory->CreateProvider(); + return qnn_factory; } LOGS_DEFAULT(WARNING) << "Failed to create " << type @@ -1362,7 +1376,7 @@ std::unique_ptr CreateExecutionProviderInstance( provider_options.insert(option); } } - return LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol); + return LoadExecutionProviderFactory(shared_lib_path_it->second, provider_options, entry_symbol); } } // unknown provider @@ -1371,18 +1385,56 @@ std::unique_ptr CreateExecutionProviderInstance( return nullptr; } +/** + * Create an IExecutionProvider instance of the specified type. Note: this is called by orttraining code. + * @param session_options The session options. + * @param type The execution provider type (e.g., CUDAExecutionProvider). + * @param provider_options_map A map of provider options. + * + * @return A unique_ptr with the execution provider instance, or null if unable to create it. + */ +std::unique_ptr CreateExecutionProviderInstance(const SessionOptions& session_options, + const std::string& type, + const ProviderOptionsMap& provider_options_map) { + auto ep_factory = CreateExecutionProviderFactoryInstance(session_options, type, provider_options_map); + if (ep_factory) { + return ep_factory->CreateProvider(); + } + return nullptr; +} + /* * Register execution provider with options. */ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector& provider_types, const ProviderOptionsMap& provider_options_map) { - ORT_UNUSED_PARAMETER(provider_options_map); - for (const std::string& type : provider_types) { auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map); - if (ep) + if (ep) { OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep))); + } + } +} + +/** + * Adds an explicit execution provider factory to the session options. + * + * @param py_sess_options The session options. + * @param provider_type The type of the provider to add. + * @param provider_options The options for the execution provider as a map of string key/value pairs. + * + * @return A Status indicating an error or success. + */ +static Status AddExplicitEpFactory(PySessionOptions& py_sess_options, const std::string& provider_type, + const ProviderOptions& provider_options) { + const ProviderOptionsMap provider_options_map = {{provider_type, provider_options}}; + auto ep_factory = CreateExecutionProviderFactoryInstance(py_sess_options.value, provider_type, provider_options_map); + if (!ep_factory) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to add provider of type '", + provider_type, "' to SessionOptions. Provider configuration is not supported."); } + py_sess_options.provider_factories.push_back(std::move(ep_factory)); + return Status::OK(); } /** @@ -1426,6 +1478,83 @@ static void RegisterCustomOpDomains(PyInferenceSession* sess, const PySessionOpt } #endif +#if !defined(ORT_MINIMAL_BUILD) +/** + * Add the execution provider that is responsible for the selected OrtEpDevice instances to the session options. + * + * @param py_sess_options The session options. + * @param provider_type The type of the provider to add. + * @param provider_options The options for the execution provider as a map of string key/value pairs. + * + * @return A Status indicating an error or success. 
+ */ +static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, + const std::vector& ep_devices, + const ProviderOptions& provider_options) { + std::shared_ptr env = GetEnv(); + const size_t num_ep_options = provider_options.size(); + std::vector ep_option_keys; + std::vector ep_option_vals; + + ep_option_keys.reserve(num_ep_options); + ep_option_vals.reserve(num_ep_options); + for (const auto& [key, val] : provider_options) { + ep_option_keys.push_back(key.c_str()); + ep_option_vals.push_back(val.c_str()); + } + + std::unique_ptr provider_factory = nullptr; + ORT_RETURN_IF_ERROR(CreateIExecutionProviderFactoryForEpDevices(*env, + py_sess_options.value, + ep_devices, + ep_option_keys, + ep_option_vals, + /*output*/ provider_factory)); + py_sess_options.provider_factories.push_back(std::move(provider_factory)); + return Status::OK(); +} + +/** + * Initializes the inference session using EPs specified in the session options. + * + * @param py_sess The inference session. + * @param disabled_optimizer_names Set of optimizers to disable. + * @return A Status indicating error or success. + */ +static Status InitializeSessionEpsFromSessionOptions(PyInferenceSession& py_sess, + const std::unordered_set& disabled_optimizer_names) { + ORT_RETURN_IF(py_sess.GetSessionHandle() == nullptr, "Invalid Python InferenceSession handle"); + InferenceSession& sess = *py_sess.GetSessionHandle(); + + const logging::Logger* sess_logger = sess.GetLogger(); + ORT_RETURN_IF(sess_logger == nullptr, "Invalid InferenceSession logger handle"); + + const OrtSessionOptions& ort_session_options = py_sess.GetOrtSessionOptions(); + + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection + if (ort_session_options.provider_factories.empty() && ort_session_options.value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_RETURN_IF_ERROR(context.SelectEpsForSession(*GetEnv(), ort_session_options, sess)); + } else { + for (const auto& provider_factory : ort_session_options.provider_factories) { + std::unique_ptr ep = provider_factory->CreateProvider(ort_session_options, + *(sess_logger->ToExternal())); + if (ep) { + ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); + } + } + } + + if (!disabled_optimizer_names.empty()) { + ORT_RETURN_IF_ERROR(sess.FilterEnabledOptimizers({disabled_optimizer_names.cbegin(), + disabled_optimizer_names.cend()})); + } + + ORT_RETURN_IF_ERROR(sess.Initialize()); + return Status::OK(); +} +#endif // !defined(ORT_MINIMAL_BUILD) + void InitializeSession(InferenceSession* sess, ExecutionProviderRegistrationFn ep_registration_fn, const std::vector& provider_types, @@ -1528,6 +1657,43 @@ void addGlobalMethods(py::module& m) { throw std::runtime_error("Error when creating and registering allocator in create_and_register_allocator_v2: " + st.ErrorMessage()); } }); + m.def( + "register_execution_provider_library", + [](const std::string& registration_name, const PathString& library_path) -> void { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + OrtPybindThrowIfError(env->RegisterExecutionProviderLibrary(registration_name, library_path.c_str())); +#else + ORT_UNUSED_PARAMETER(registration_name); + ORT_UNUSED_PARAMETER(library_path); + ORT_THROW("Execution provider libraries are not supported in this build."); +#endif + }, + R"pbdoc(Register an execution provider library with ONNX Runtime.)pbdoc"); + m.def( + "unregister_execution_provider_library", + [](const std::string& registration_name) 
-> void { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + OrtPybindThrowIfError(env->UnregisterExecutionProviderLibrary(registration_name)); +#else + ORT_UNUSED_PARAMETER(registration_name); + ORT_THROW("Execution provider libraries are not supported in this build."); +#endif + }, + R"pbdoc(Unregister an execution provider library from ONNX Runtime.)pbdoc"); + m.def( + "get_ep_devices", + []() -> const std::vector& { +#if !defined(ORT_MINIMAL_BUILD) + std::shared_ptr env = GetEnv(); + return env->GetOrtEpDevices(); +#else + ORT_THROW("OrtEpDevices are not supported in this build"); +#endif + }, + R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", + py::return_value_policy::reference); #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( @@ -1632,6 +1798,39 @@ void addGlobalMethods(py::module& m) { #endif } +// TODO(adrianlizarraga): C API's delegate function needs a void* param to store state. +// using PyEpSelectionDelegate = std::function(const std::vector& ep_devices, +// const std::unordered_map& model_metadata, +// const std::unordered_map& runtime_metadata)>; +// +// static OrtStatus* PyDelegateWrapper(void* delegate_state, +// _In_ const OrtEpDevice** ep_devices, +// _In_ size_t num_devices, +// _In_ const OrtKeyValuePairs* model_metadata, +// _In_opt_ const OrtKeyValuePairs* runtime_metadata, +// _Inout_ const OrtEpDevice** selected, +// _In_ size_t max_selected, +// _Out_ size_t* num_selected) { +// PyEpSelectionDelegate* actual_delegate = reinterpret_cast(delegate_state); +// std::vector py_ep_devices(ep_devices, ep_devices + num_devices); +// std::unordered_map py_model_metadata = +// model_metadata ? model_metadata->entries : std::unordered_map{}; +// std::unordered_map py_runtime_metadata = +// runtime_metadata ? runtime_metadata->entries : std::unordered_map{}; +// +// std::vector py_selected = (*actual_delegate)(py_ep_devices, py_model_metadata, py_runtime_metadata); +// +// // TODO: Check max_selected and return OrtStatus if necessary. 
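The module-level bindings registered above (register_execution_provider_library, unregister_execution_provider_library, get_ep_devices) can be exercised roughly as follows; the registration name and library path are placeholders for an actual plugin EP library, and the calls raise in builds without support::

    import onnxruntime as onnxrt

    # Hypothetical plugin EP library path.
    onnxrt.register_execution_provider_library("example_ep", "/path/to/example_ep.so")

    for ep_device in onnxrt.get_ep_devices():
        print(ep_device.ep_name, ep_device.ep_vendor, ep_device.device.type)

    onnxrt.unregister_execution_provider_library("example_ep")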
+// assert(py_selected.size() <= max_selected); +// +// *num_selected = py_selected.size(); +// for (size_t i = 0; i < py_selected.size(); ++i) { +// selected[i] = py_selected[i]; +// } +// +// return nullptr; +// }; + void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { py::enum_(m, "GraphOptimizationLevel") .value("ORT_DISABLE_ALL", GraphOptimizationLevel::ORT_DISABLE_ALL) @@ -1672,6 +1871,75 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def_static("webgpu", []() { return OrtDevice::GPU; }) .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; }); + py::enum_(m, "OrtExecutionProviderDevicePolicy") + .value("DEFAULT", OrtExecutionProviderDevicePolicy_DEFAULT) + .value("PREFER_CPU", OrtExecutionProviderDevicePolicy_PREFER_CPU) + .value("PREFER_NPU", OrtExecutionProviderDevicePolicy_PREFER_NPU) + .value("PREFER_GPU", OrtExecutionProviderDevicePolicy_PREFER_GPU) + .value("MAX_PERFORMANCE", OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE) + .value("MAX_EFFICIENCY", OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY) + .value("MIN_OVERALL_POWER", OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER); + + py::enum_(m, "OrtHardwareDeviceType") + .value("CPU", OrtHardwareDeviceType_CPU) + .value("GPU", OrtHardwareDeviceType_GPU) + .value("NPU", OrtHardwareDeviceType_NPU); + + py::class_ py_hw_device(m, "OrtHardwareDevice", R"pbdoc(ONNX Runtime hardware device information.)pbdoc"); + py_hw_device.def_property_readonly( + "type", + [](OrtHardwareDevice* hw_device) -> OrtHardwareDeviceType { return hw_device->type; }, + R"pbdoc(Hardware device's type.)pbdoc") + .def_property_readonly( + "vendor_id", + [](OrtHardwareDevice* hw_device) -> uint32_t { return hw_device->vendor_id; }, + R"pbdoc(Hardware device's vendor identifier.)pbdoc") + .def_property_readonly( + "vendor", + [](OrtHardwareDevice* hw_device) -> std::string { return hw_device->vendor; }, + R"pbdoc(Hardware device's vendor name.)pbdoc") + .def_property_readonly( + "device_id", + [](OrtHardwareDevice* hw_device) -> uint32_t { return hw_device->device_id; }, + R"pbdoc(Hardware device's unique identifier.)pbdoc") + .def_property_readonly( + "metadata", + [](OrtHardwareDevice* hw_device) -> std::unordered_map { + return hw_device->metadata.entries; + }, + R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); + + py::class_ py_ep_device(m, "OrtEpDevice", + R"pbdoc(Represents a hardware device that an execution provider supports +for model inference.)pbdoc"); + py_ep_device.def_property_readonly( + "ep_name", + [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + R"pbdoc(The execution provider's name.)pbdoc") + .def_property_readonly( + "ep_vendor", + [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + R"pbdoc(The execution provider's vendor name.)pbdoc") + .def_property_readonly( + "ep_metadata", + [](OrtEpDevice* ep_device) -> std::unordered_map { + return ep_device->ep_metadata.entries; + }, + R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") + .def_property_readonly( + "ep_options", + [](OrtEpDevice* ep_device) -> std::unordered_map { + return ep_device->ep_options.entries; + }, + R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") + .def_property_readonly( + "device", + [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + return *ep_device->device; + }, + R"pbdoc(The OrtHardwareDevice 
instance for the OrtEpDevice.)pbdoc", + py::return_value_policy::reference_internal); + py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. // This constructor kept for backwards compatibility, key-value pair constructor overload exposes all options @@ -1736,6 +2004,59 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); sess .def(py::init()) + .def( + // Equivalent to the C API's SessionOptionsAppendExecutionProvider. + "add_provider", + [](PySessionOptions* sess_options, + const std::string& provider_name, + const ProviderOptions& provider_options = {}) { + OrtPybindThrowIfError(AddExplicitEpFactory(*sess_options, provider_name, provider_options)); + }, + R"pbdoc(Adds an explicit execution provider.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsAppendExecutionProvider_V2. + "add_provider_for_devices", + [](PySessionOptions* sess_options, + const std::vector& ep_devices, + const ProviderOptions& provider_options = {}) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(AddEpFactoryFromEpDevices(*sess_options, + ep_devices, + provider_options)); +#else + ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(ep_devices); + ORT_UNUSED_PARAMETER(provider_options); + ORT_THROW("OrtEpDevices are not supported in this build"); +#endif + }, + R"pbdoc(Adds the execution provider that is responsible for the selected OrtEpDevice instances. All OrtEpDevice instances +must refer to the same execution provider.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsSetEpSelectionPolicy. + "set_provider_selection_policy", + [](PySessionOptions* sess_options, + OrtExecutionProviderDevicePolicy policy) { +#if !defined(ORT_MINIMAL_BUILD) + sess_options->value.ep_selection_policy.enable = true; + sess_options->value.ep_selection_policy.policy = policy; + sess_options->value.ep_selection_policy.delegate = nullptr; // TODO: need a void* param in delegate. +#else + ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(policy); + ORT_THROW("EP selection policies are not supported in this build"); +#endif + }, + R"pbdoc(Sets the execution provider selection policy for the session. 
Allows users to specify a +selection policy for automatic execution provider (EP) selection, or provide a delegate callback +for custom selection logic.)pbdoc") + .def( + "has_providers", + [](PySessionOptions* sess_options) -> bool { + return !sess_options->provider_factories.empty() || sess_options->value.ep_selection_policy.enable; + }, + R"pbdoc(Returns true if the SessionOptions has been configured with providers, OrtEpDevices, or +policies that will run the model.)pbdoc") .def_property( "enable_cpu_mem_arena", [](const PySessionOptions* options) -> bool { return options->value.enable_cpu_mem_arena; }, @@ -2132,11 +2453,18 @@ including arg name, arg type (contains both type and shape).)pbdoc") const std::vector& provider_types = {}, const ProviderOptionsVector& provider_options = {}, const std::unordered_set& disabled_optimizer_names = {}) { - InitializeSession(sess->GetSessionHandle(), - ep_registration_fn, - provider_types, - provider_options, - disabled_optimizer_names); + // If the user did not explicitly specify providers when creating InferenceSession and the SessionOptions + // has provider information (i.e., either explicit EPs or an EP selection policy), then use the information + // in the session options to initialize the session. + if (provider_types.empty() && sess->HasProvidersInSessionOptions()) { + OrtPybindThrowIfError(InitializeSessionEpsFromSessionOptions(*sess, disabled_optimizer_names)); + } else { + InitializeSession(sess->GetSessionHandle(), + ep_registration_fn, + provider_types, + provider_options, + disabled_optimizer_names); + } }, R"pbdoc(Load a model saved in ONNX or ORT format.)pbdoc") .def("run", @@ -2395,6 +2723,58 @@ including arg name, arg type (contains both type and shape).)pbdoc") .value("kNextPowerOfTwo", onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo) .value("kSameAsRequested", onnxruntime::ArenaExtendStrategy::kSameAsRequested) .export_values(); + + py::class_(m, "ModelCompiler", + R"pbdoc(This is the class used to compile an ONNX model.)pbdoc") + .def(py::init([](const PySessionOptions& sess_options, + std::string path_or_bytes, + bool is_path, + bool embed_compiled_data_into_model = false, + std::string external_initializers_file_path = {}, + size_t external_initializers_size_threshold = 1024) { +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr result; + OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options, + std::move(path_or_bytes), is_path, + embed_compiled_data_into_model, + external_initializers_file_path, + external_initializers_size_threshold)); + return result; +#else + ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(path_or_bytes); + ORT_UNUSED_PARAMETER(is_path); + ORT_UNUSED_PARAMETER(embed_compiled_data_into_model); + ORT_UNUSED_PARAMETER(external_initializers_file_path); + ORT_UNUSED_PARAMETER(external_initializers_size_threshold); + ORT_THROW("Compile API is not supported in this build."); +#endif + })) + .def( + "compile_to_file", + [](PyModelCompiler* model_compiler, std::string output_model_path = {}) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToFile(output_model_path)); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(output_model_path); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into a file.)pbdoc") + .def( + "compile_to_bytes", + [](PyModelCompiler* model_compiler) -> py::bytes { +#if !defined(ORT_MINIMAL_BUILD) + std::string output_bytes; + 
OrtPybindThrowIfError(model_compiler->CompileToBytes(output_bytes)); + return py::bytes(output_bytes); +#else + ORT_UNUSED_PARAMETER(model_compiler); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); } bool CreateInferencePybindStateModule(py::module& m) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 3ae5c0d289c21..a964f199d43b3 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -234,13 +234,13 @@ using PySessionOptions = OrtSessionOptions; // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user struct PyInferenceSession { PyInferenceSession(std::shared_ptr env, const PySessionOptions& so) - : env_(std::move(env)) { + : env_(std::move(env)), session_options_(so) { sess_ = std::make_unique(so.value, *env_); } #if !defined(ORT_MINIMAL_BUILD) PyInferenceSession(std::shared_ptr env, const PySessionOptions& so, const std::string& arg, bool is_arg_file_name) - : env_(std::move(env)) { + : env_(std::move(env)), session_options_(so) { if (is_arg_file_name) { // Given arg is the file path. Invoke the corresponding ctor(). sess_ = std::make_unique(so.value, *env_, arg); @@ -252,6 +252,24 @@ struct PyInferenceSession { } #endif + // Returns true if the session options have provider information from either + // setting explicit providers, setting a provider that supports a OrtEpDevice(s), or + // setting a selection policy (e.g., prefer gpu). + bool HasProvidersInSessionOptions() const { + return !session_options_.provider_factories.empty() || + session_options_.value.ep_selection_policy.enable; + } + + // Returns (and updates) a reference to the OrtSessionOptions for this inference session. + OrtSessionOptions& GetOrtSessionOptions() { + if (sess_) { + // Copy internal value from InferenceSession as it is the source of truth + // and the option configurations may have changed. + session_options_.value = sess_->GetSessionOptions(); + } + return session_options_; + } + InferenceSession* GetSessionHandle() const { return sess_.get(); } virtual ~PyInferenceSession() = default; @@ -264,6 +282,7 @@ struct PyInferenceSession { private: std::shared_ptr env_; std::unique_ptr sess_; + OrtSessionOptions session_options_; }; inline const PySessionOptions& GetDefaultCPUSessionOptions() { diff --git a/onnxruntime/test/python/autoep_helper.py b/onnxruntime/test/python/autoep_helper.py new file mode 100644 index 0000000000000..e3b214afa6e62 --- /dev/null +++ b/onnxruntime/test/python/autoep_helper.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import tempfile +import unittest + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import Fail + + +class AutoEpTestCase(unittest.TestCase): + """ + Base class for TestCase classes that need to register and unregister EP libraries. + Because EP libraries are registered with the ORT environment and all unit tests share + the same environment, this class tracks which libraries have already been registered + so that they are not erroneously registered or unregistered. + + Derived classes must use 'self.register_execution_provider_library()' and + 'self.unregister_execution_provider_library()' to benefit from these utilities. 
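Putting the new SessionOptions provider APIs together with the InferenceSession changes above, a session can now be configured entirely through its options before construction; the device filter, policy, and model path below are placeholder values::

    import onnxruntime as onnxrt

    sess_options = onnxrt.SessionOptions()

    # Either bind the provider to explicitly selected OrtEpDevice instances...
    cpu_devices = [d for d in onnxrt.get_ep_devices() if d.ep_name == "CPUExecutionProvider"]
    if cpu_devices:
        sess_options.add_provider_for_devices(cpu_devices, {})
    else:
        # ...or let ONNX Runtime pick devices according to a selection policy.
        sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_CPU)

    assert sess_options.has_providers()

    # No 'providers' argument is needed; the session initializes its EPs from the options.
    sess = onnxrt.InferenceSession("model.onnx", sess_options=sess_options)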
+ """ + + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.autoep_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + # Track registered EP libraries across all tests. + cls._registered_providers = set() + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def register_execution_provider_library(self, ep_registration_name: str, ep_lib_path: os.PathLike | str): + if ep_registration_name in self._registered_providers: + return # Already registered + + try: + onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path) + except Fail as onnxruntime_error: + if "already registered" in str(onnxruntime_error): + pass # Allow register to fail if the EP library was previously registered. + else: + raise onnxruntime_error + + # Add this EP library to set of registered EP libraries. + # If the unit test itself does not unregister the library, tearDown() will try. + self._registered_providers.add(ep_registration_name) + + def unregister_execution_provider_library(self, ep_registration_name: str): + if ep_registration_name not in self._registered_providers: + return # Not registered + + try: + onnxrt.unregister_execution_provider_library(ep_registration_name) + except Fail as onnxruntime_error: + if "was not registered" in str(onnxruntime_error): + pass # Allow unregister to fail if the EP library was never registered. + else: + raise onnxruntime_error + + self._registered_providers.remove(ep_registration_name) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index f3ebc92409f77..5631e49069ae3 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1360,13 +1360,20 @@ def test_register_custom_ops_library(self): ) def test_ort_value(self): + providers_to_test = onnxrt.get_available_providers() numpy_arr_input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) numpy_arr_output = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) - def test_session_with_ortvalue_input(ortvalue): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + def test_session_with_ortvalue_input(ortvalue, providers): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=providers) res = sess.run(["Y"], {"X": ortvalue}) - self.assertTrue(np.array_equal(res[0], numpy_arr_output)) + + if "QNNExecutionProvider" in providers: + # QNN runs float32 with fp16 precision, so relax accuracy expectations + np.testing.assert_allclose(numpy_arr_output, res[0], rtol=1e-04, atol=1e-06) + else: + self.assertTrue(np.array_equal(res[0], numpy_arr_output)) + vect = sess._sess.run_with_ort_values({"X": ortvalue._get_c_value()}, ["Y"], RunOptions()) self.assertIsInstance(vect, OrtValueVector) @@ -1378,7 +1385,7 @@ def test_session_with_ortvalue_input(ortvalue): self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input)) # Pass in the constructed OrtValue to a session via Run() and check results - test_session_with_ortvalue_input(ortvalue1) + test_session_with_ortvalue_input(ortvalue1, providers_to_test) # The constructed OrtValue should still be valid after being used in a session self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input)) @@ -1392,7 +1399,7 @@ def 
test_session_with_ortvalue_input(ortvalue): self.assertEqual(float_tensor_data_type, ort_value_with_type.element_type()) self.assertEqual([3, 2], ort_value_with_type.shape()) - if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + if "CUDAExecutionProvider" in providers_to_test: ortvalue2 = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input, "cuda", 0) self.assertEqual(ortvalue2.device_name(), "cuda") self.assertEqual(ortvalue2.shape(), [3, 2]) @@ -1401,7 +1408,7 @@ def test_session_with_ortvalue_input(ortvalue): self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input)) # Pass in the constructed OrtValue to a session via Run() and check results - test_session_with_ortvalue_input(ortvalue2) + test_session_with_ortvalue_input(ortvalue2, providers_to_test) # The constructed OrtValue should still be valid after being used in a session self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input)) diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py new file mode 100644 index 0000000000000..1d7dba5662257 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import platform +import sys +import unittest + +import numpy as np +from autoep_helper import AutoEpTestCase +from helper import get_name + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + +available_providers = list(onnxrt.get_available_providers()) + + +class TestAutoEP(AutoEpTestCase): + def test_cuda_ep_register_and_inference(self): + """ + Test registration of CUDA EP, adding its OrtDevice to the SessionOptions, and running inference. 
+ """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_registration_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_registration_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_registration_name, ep_lib_path) + + ep_devices = onnxrt.get_ep_devices() + has_cpu_ep = False + cuda_ep_device = None + for ep_device in ep_devices: + ep_name = ep_device.ep_name + if ep_name == "CPUExecutionProvider": + has_cpu_ep = True + if ep_name == ep_registration_name: + cuda_ep_device = ep_device + + self.assertTrue(has_cpu_ep) + self.assertIsNotNone(cuda_ep_device) + self.assertEqual(cuda_ep_device.ep_vendor, "Microsoft") + + hw_device = cuda_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Add CUDA's OrtEpDevice to session options + sess_options = onnxrt.SessionOptions() + sess_options.add_provider_for_devices([cuda_ep_device], {"prefer_nhwc": "1"}) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + # TODO(adrianlizarraga): Unregistering CUDA EP library causes issues. Investigate. + # self.unregister_execution_provider_library(ep_registration_name) + + def test_cuda_prefer_gpu_and_inference(self): + """ + Test selecting CUDA EP via the PREFER_GPU policy and running inference. + """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_registration_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_registration_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_registration_name, ep_lib_path) + + # Set a policy to prefer GPU. Cuda should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + # TODO(adrianlizarraga): Unregistering CUDA EP library causes issues. Investigate. + # self.unregister_execution_provider_library(ep_registration_name) + + def test_example_plugin_ep_devices(self): + """ + Test registration of an example EP plugin and retrieval of its OrtEpDevice. 
+ """ + if sys.platform != "win32": + self.skipTest("Skipping test because it device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_registration_name = "example_ep" + self.register_execution_provider_library(ep_registration_name, os.path.realpath(ep_lib_path)) + + ep_devices = onnxrt.get_ep_devices() + has_cpu_ep = False + test_ep_device = None + for ep_device in ep_devices: + ep_name = ep_device.ep_name + + if ep_name == "CPUExecutionProvider": + has_cpu_ep = True + if ep_name == ep_registration_name: + test_ep_device = ep_device + + self.assertTrue(has_cpu_ep) + self.assertIsNotNone(test_ep_device) + + # Test the OrtEpDevice getters. Expected values are from /onnxruntime/test/autoep/library/example_plugin_ep.cc + self.assertEqual(test_ep_device.ep_vendor, "Contoso") + + ep_metadata = test_ep_device.ep_metadata + self.assertEqual(ep_metadata["version"], "0.1") + + ep_options = test_ep_device.ep_options + self.assertEqual(ep_options["run_really_fast"], "true") + + # The CPU hw device info will vary by machine so check for the common values. + hw_device = test_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.CPU) + self.assertGreaterEqual(hw_device.vendor_id, 0) + self.assertGreaterEqual(hw_device.device_id, 0) + self.assertGreater(len(hw_device.vendor), 0) + + hw_metadata = hw_device.metadata + self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + + # Test adding this EP plugin's OrtEpDevice to the SessionOptions. + sess_options = onnxrt.SessionOptions() + with self.assertRaises(InvalidArgument) as context: + # Will raise InvalidArgument because ORT currently only supports provider bridge APIs. + # Actual plugin EPs will be supported in the future. + sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) + self.assertIn("EP is not currently supported", str(context.exception)) + + self.unregister_execution_provider_library(ep_registration_name) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend.py b/onnxruntime/test/python/onnxruntime_test_python_backend.py index 1f6cd78f28334..6ed7dfe59b1f6 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend.py @@ -40,18 +40,28 @@ def test_allocation_plan_works_with_only_execute_path_to_fetches_option(self): This case is handled specifically in ExecutionFrame::AllocateAsPerAllocationPlan(). This test is to ensure that the case is covered. """ + providers = onnxrt.get_available_providers() + has_qnn_ep = "QNNExecutionProvider" in providers name = get_name("alloc_tensor_reuse.onnx") - sess = onnxrt.InferenceSession(name, providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession(name, providers=providers) run_options = onnxrt.RunOptions() run_options.only_execute_path_to_fetches = True inp0, inp1 = np.ones((10,), dtype=np.float32), np.ones((10,), dtype=np.float32) session_run_results = sess.run(["outp0"], {"inp0": inp0, "inp1": inp1}, run_options) - assert_allclose(session_run_results[0], -(inp0 + inp1)) + if has_qnn_ep: + # QNN EP runs fp32 with fp16 precision, so relax tolerance. 
+ assert_allclose(session_run_results[0], -(inp0 + inp1), rtol=1e-6, atol=1e-6) + else: + assert_allclose(session_run_results[0], -(inp0 + inp1)) session_run_results = sess.run(["outp1"], {"inp0": inp0, "inp1": inp1}, run_options) - assert_allclose(session_run_results[0], -(inp0 - inp1)) + if has_qnn_ep: + # QNN EP runs fp32 with fp16 precision, so relax tolerance. + assert_allclose(session_run_results[0], -(inp0 - inp1), rtol=1e-6, atol=1e-6) + else: + assert_allclose(session_run_results[0], -(inp0 - inp1)) if __name__ == "__main__": diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py new file mode 100644 index 0000000000000..f5f23e2da1e43 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -0,0 +1,185 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import os +import platform +import sys +import unittest + +import onnx +from autoep_helper import AutoEpTestCase +from helper import get_name + +import onnxruntime as onnxrt +from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + +available_providers = list(onnxrt.get_available_providers()) + + +class TestCompileApi(AutoEpTestCase): + def test_compile_with_files_prefer_npu_policy(self): + """ + Tests compiling a model (to/from files) using an EP selection policy (PREFER_NPU). + """ + if "QNNExecutionProvider" not in available_providers: + self.skipTest("Skipping test because it needs to run on QNN EP") + + if sys.platform != "win32": + self.skipTest("Skipping test because provider selection policies are only supported on Windows") + + ep_lib_path = "onnxruntime_providers_qnn.dll" + ep_registration_name = "QNNExecutionProvider" + self.register_execution_provider_library(ep_registration_name, ep_lib_path) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled0.onnx") + + session_options = onnxrt.SessionOptions() + session_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_NPU) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + self.unregister_execution_provider_library(ep_registration_name) + + def test_compile_with_input_and_output_files(self): + """ + Tests compiling a model (to/from files) using explicit EP. 
+ """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled1.onnx") + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_to_file_with_input_model_in_buffer(self): + """ + Tests compiling an input model that is stored in a buffer. The output is saved to a file. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_onnx_model = onnx.load(get_name("nhwc_resize_scales_opset18.onnx")) + input_model_bytes = input_onnx_model.SerializeToString() + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled2.onnx") + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_bytes, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + + def test_compile_from_buffer_to_buffer(self): + """ + Tests compiling an input model that is stored in a buffer. The output is stored in a buffer too. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_onnx_model = onnx.load(get_name("nhwc_resize_scales_opset18.onnx")) + input_model_bytes = input_onnx_model.SerializeToString() + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_bytes, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + output_model_bytes = model_compiler.compile_to_bytes() + self.assertTrue(isinstance(output_model_bytes, bytes)) + self.assertGreater(len(output_model_bytes), 0) + + def test_fail_load_uncompiled_model_and_then_compile(self): + """ + Tests compiling scenario: + - Load uncompiled model into session that disables JIT compilation. + - Expect an error (ModelRequiresCompilation) + - Compile model and retry creating an inference session successfully. + """ + if "QNNExecutionProvider" not in available_providers: + self.skipTest("Skipping test because it needs to run on a compiling EP") + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + session_options = onnxrt.SessionOptions() + session_options.add_session_config_entry("session.disable_model_compile", "1") # Disable JIT model compilation! 
+ session_options.add_provider("QNNExecutionProvider", {"backend_type": "htp"}) + + # Session creation should fail with error ORT_MODEL_REQUIRES_COMPILATION because the input model + # is not compiled and we disabled JIT compilation for this session. + with self.assertRaises(ModelRequiresCompilation) as context: + onnxrt.InferenceSession( + input_model_path, + sess_options=session_options, + enable_fallback=False, + ) + self.assertIn("needs to compile", str(context.exception)) + + # Try to compile the model now. + compiled_model_path = os.path.join(self._tmp_dir_path, "model.compiled3.onnx") + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path="external_weights.bin", + external_initializers_size_threshold=128, + ) + model_compiler.compile_to_file(compiled_model_path) + + self.assertTrue(os.path.exists(compiled_model_path)) + self.assertEqual(session_options.get_session_config_entry("session.disable_model_compile"), "1") + self.assertTrue(session_options.has_providers()) + + # Creating the session with the compiled model should not fail. + sess = onnxrt.InferenceSession(compiled_model_path, sess_options=session_options) + self.assertIsNotNone(sess) + + +if __name__ == "__main__": + unittest.main(verbosity=1) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 03b51790e0ef6..f4b7a78fe4c99 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1667,6 +1667,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) + log.info("Testing AutoEP feature") + run_subprocess([sys.executable, "onnxruntime_test_python_autoep.py"], cwd=cwd, dll_path=dll_path) + if not args.disable_contrib_ops: run_subprocess([sys.executable, "onnxruntime_test_python_sparse_matmul.py"], cwd=cwd, dll_path=dll_path) @@ -1761,6 +1764,12 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): cwd=cwd, ) + if not args.disable_contrib_ops: + log.info("Testing Python Compile API") + run_subprocess( + [sys.executable, "onnxruntime_test_python_compile_api.py"], cwd=cwd, dll_path=dll_path + ) + if not args.skip_onnx_tests: run_subprocess([os.path.join(cwd, "onnx_test_runner"), "test_models"], cwd=cwd) if config != "Debug": From bb5d2c22b8f8be3566e95323683013fab4076194 Mon Sep 17 00:00:00 2001 From: David Fan <30608893+jiafatom@users.noreply.github.com> Date: Tue, 6 May 2025 00:12:40 -0700 Subject: [PATCH 11/84] k_quant should have zero_point (#24647) ### Description As titled. 
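In practice this is a one-line change in `rtn_quantize`: the `zero_point` initializer is now emitted whenever the k_quant algorithm is used, not only for the asymmetric scheme. A minimal sketch of the new condition follows (argument names match the diff below; the helper name is illustrative, as the expression appears inline in the source, and the surrounding call that packs `q_weight`, `scale`, and `zero_point` is unchanged):

```python
# Sketch of the updated condition in rtn_quantize (see the diff below).
# k_quant always carries its zero point, even when the caller requested
# the default symmetric ("sym") scheme.
def select_zero_point(zp, scheme: str, algorithm: str):
    return zp if scheme == "asym" or algorithm == "k_quant" else None
```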
--- .../python/tools/quantization/neural_compressor/weight_only.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py index 4eda7efc9b8fe..558415f028c7b 100644 --- a/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py +++ b/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py @@ -506,7 +506,7 @@ def rtn_quantize( k_blocks=k_blocks, q_weight=q_weight.astype("uint8"), scale=scale.astype(dtype), - zero_point=zp if scheme == "asym" else None, + zero_point=zp if scheme == "asym" or algorithm == "k_quant" else None, accuracy_level=accuracy_level, ) From 337839e33623abd220ab82e5c84d3db3bb5672e0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 6 May 2025 01:03:36 -0700 Subject: [PATCH 12/84] Bump vite from 6.2.6 to 6.3.4 in /js/web/test/e2e/exports/testcases/vite-default (#24607) Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 6.2.6 to 6.3.4.
Release notes

Sourced from vite's releases.

v6.2.7 through v6.3.4 (including v6.3.0-beta.0/1/2, create-vite@6.3.0 and create-vite@6.3.1): please refer to CHANGELOG.md for details.
Changelog

Sourced from vite's changelog.

6.3.4 (2025-04-30)

  • fix: check static serve file inside sirv (#19965) (c22c43d), closes #19965
  • fix(optimizer): return plain object when using require to import externals in optimized dependencies (#19940) (efc5eab), closes #19940
  • refactor: remove duplicate plugin context type (#19935) (d6d01c2), closes #19935

6.3.3 (2025-04-24)

  • fix: ignore malformed uris in tranform middleware (#19853) (e4d5201), closes #19853
  • fix(assets): ensure ?no-inline is not included in the asset url in the production environment (#19496) (16a73c0), closes #19496
  • fix(css): resolve relative imports in sass properly on Windows (#19920) (ffab442), closes #19920
  • fix(deps): update all non-major dependencies (#19899) (a4b500e), closes #19899
  • fix(ssr): fix execution order of re-export (#19841) (ed29dee), closes #19841
  • fix(ssr): fix live binding of default export declaration and hoist exports getter (#19842) (80a91ff), closes #19842
  • perf: skip sourcemap generation for renderChunk hook of import-analysis-build plugin (#19921) (55cfd04), closes #19921
  • test(ssr): test ssrTransform re-export deps and test stacktrace with first line (#19629) (9399cda), closes #19629

6.3.2 (2025-04-18)

6.3.1 (2025-04-17)

6.3.0 (2025-04-16)

6.3.0-beta.2 (2025-04-11)

... (truncated)

Commits
  • b040d54 release: v6.3.4
  • c22c43d fix: check static serve file inside sirv (#19965)
  • efc5eab fix(optimizer): return plain object when using require to import externals ...
  • d6d01c2 refactor: remove duplicate plugin context type (#19935)
  • db9eb97 release: v6.3.3
  • e4d5201 fix: ignore malformed uris in tranform middleware (#19853)
  • 55cfd04 perf: skip sourcemap generation for renderChunk hook of import-analysis-build...
  • ffab442 fix(css): resolve relative imports in sass properly on Windows (#19920)
  • 16a73c0 fix(assets): ensure ?no-inline is not included in the asset url in the produc...
  • 9399cda test(ssr): test ssrTransform re-export deps and test stacktrace with first ...
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=vite&package-manager=npm_and_yarn&previous-version=6.2.6&new-version=6.3.4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../testcases/vite-default/package-lock.json | 58 +++++++++++++++++-- .../testcases/vite-default/package.json | 2 +- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index c9da59b4b0021..48f0a8f3e9d5c 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -12,7 +12,7 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.6" + "vite": "^6.3.5" } }, "node_modules/@babel/helper-string-parser": { @@ -944,6 +944,21 @@ "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", "license": "MIT" }, + "node_modules/fdir": { + "version": "6.4.4", + "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.4.4.tgz", + "integrity": "sha512-1NZP+GK4GfuAv3PqKvxQRDMjdSRZjnkq7KfhlNrCNNlZ0ygQFpebfrnfnq/W7fpUnAv9aGWmY1zKx7FYL3gwhg==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "picomatch": "^3 || ^4" + }, + "peerDependenciesMeta": { + "picomatch": { + "optional": true + } + } + }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -992,6 +1007,19 @@ "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", "license": "ISC" }, + "node_modules/picomatch": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.2.tgz", + "integrity": "sha512-M7BAV6Rlcy5u+m6oPhAPFgJTzAioX/6B0DxyvDlo9l8+T3nLKbrczg2WLUyzd45L8RqfUMyGPzekbMvX2Ldkwg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, "node_modules/postcss": { "version": "8.5.3", "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.5.3.tgz", @@ -1068,16 +1096,36 @@ "node": ">=0.10.0" } }, + "node_modules/tinyglobby": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.13.tgz", + "integrity": "sha512-mEwzpUgrLySlveBwEVDMKk5B57bhLPYovRfPAXD5gA/98Opn0rCDj3GtLwFvCvH5RK9uPCExUROW5NjDwvqkxw==", + "dev": true, + "license": "MIT", + "dependencies": { + "fdir": "^6.4.4", + "picomatch": "^4.0.2" + }, + "engines": { + "node": ">=12.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/SuperchupuDev" + } + }, "node_modules/vite": { - "version": "6.2.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.6.tgz", - "integrity": "sha512-9xpjNl3kR4rVDZgPNdTL0/c6ao4km69a/2ihNQbcANz8RuCOK3hQBmLSJf3bRKVQjVMda+YvizNE8AwvogcPbw==", + "version": "6.3.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.3.5.tgz", + "integrity": "sha512-cZn6NDFE7wdTpINgs++ZJ4N49W2vRp8LCKrn3Ob1kYNtOo21vfDoaV5GzBfLU4MovSAB8uNRm4jgzVQZ+mBzPQ==", "dev": true, "license": "MIT", "dependencies": { "esbuild": "^0.25.0", + "fdir": "^6.4.4", + "picomatch": "^4.0.2", "postcss": "^8.5.3", - "rollup": "^4.30.1" + "rollup": "^4.34.9", + "tinyglobby": "^0.2.13" }, "bin": { "vite": "bin/vite.js" diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json b/js/web/test/e2e/exports/testcases/vite-default/package.json index 5169734074299..f7d5751354905 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package.json +++ 
b/js/web/test/e2e/exports/testcases/vite-default/package.json @@ -13,6 +13,6 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.6" + "vite": "^6.3.5" } } From 01aef5f71cc1d8d35ef64cad56a1235263772dc9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 May 2025 01:08:23 -0700 Subject: [PATCH 13/84] add check in NPM packaging pipeline for @dev package version (#24641) ### Description This PR adds a check for the package version for dev channel. This PR should be able to help avoid publishing packages like "-rc.*" to dev channel automatically. ### Motivation and Context --- tools/ci_build/github/js/validate-npm-packages.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tools/ci_build/github/js/validate-npm-packages.py b/tools/ci_build/github/js/validate-npm-packages.py index b009330764973..37b5b3d9807a3 100644 --- a/tools/ci_build/github/js/validate-npm-packages.py +++ b/tools/ci_build/github/js/validate-npm-packages.py @@ -129,6 +129,16 @@ if RELEASE_WEB and RELEASE_REACT_NATIVE and ort_web_ver != ort_react_native_ver: raise Exception("version number is different for onnxruntime-web and onnxruntime-react-native") +# @dev build has to match the following pattern: +# "x.y.z-dev.*" +if tag == "dev": + if RELEASE_NODE and "-dev" not in ort_node_ver: + raise Exception(f'@dev build version should contain "-dev". ort_node_ver={ort_node_ver}') + if RELEASE_WEB and "-dev" not in ort_web_ver: + raise Exception(f'@dev build version should contain "-dev". ort_web_ver={ort_web_ver}') + if RELEASE_REACT_NATIVE and "-dev" not in ort_react_native_ver: + raise Exception(f'@dev build version should contain "-dev". ort_react_native_ver={ort_react_native_ver}') + print("====== validated versions ======") print(f"source_branch={source_branch}") print(f"tag={tag}") From 1f4156c6146d04999e5e419df4ae6628e928aaad Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 6 May 2025 18:29:05 +1000 Subject: [PATCH 14/84] Add support for selection policy delegate (#24635) ### Description Add support for selection policy delegate - split API function into one for the policy enum and one for the delegate - add `void*` for user state - required to wire up using the delegate in other languages. Add C# support for specifying the selection policy delegate. Address comments from initial C# autoep support PR. 
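For reference, the new C# surface added in this diff lets callers plug in their own selection logic. Below is a minimal sketch of how the delegate might be wired up; the specific policy (pick the first device and fall back to the CPU EP, which is always listed last) and the `model.onnx` path are illustrative, while the actual signatures come from `SessionOptions.shared.cs` and the `OrtAutoEpTests.cs` test in this change:

```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;

// Custom selection logic: keep the highest-priority device and,
// if allowed, also keep the ORT CPU EP (always the last entry).
static List<OrtEpDevice> SelectFirstAndCpu(IReadOnlyList<OrtEpDevice> epDevices,
                                           OrtKeyValuePairs modelMetadata,
                                           OrtKeyValuePairs runtimeMetadata,
                                           uint maxSelections)
{
    var selected = new List<OrtEpDevice> { epDevices[0] };
    if (maxSelections > 2 && epDevices.Count > 1)
    {
        selected.Add(epDevices.Last());
    }
    return selected;
}

using var sessionOptions = new SessionOptions();
sessionOptions.SetEpSelectionPolicyDelegate(SelectFirstAndCpu);
using var session = new InferenceSession("model.onnx", sessionOptions);
```

Under the hood the C# wrapper pins the managed delegate and passes it to the new `SessionOptionsSetEpSelectionPolicyDelegate` C API together with a `GCHandle`-backed `state` pointer, which is why the `void*` state argument was added to the native signature.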
### Motivation and Context --- .../NativeMethods.shared.cs | 99 +++++++++++- .../OrtEpDevice.shared.cs | 30 ++-- .../OrtHardwareDevice.shared.cs | 30 ++-- .../SessionOptions.shared.cs | 151 +++++++++++++++++- .../OrtAutoEpTests.cs | 49 +++++- .../core/session/onnxruntime_c_api.h | 38 +++-- .../core/session/onnxruntime_cxx_api.h | 6 +- .../core/session/onnxruntime_cxx_inline.h | 11 +- onnxruntime/core/framework/session_options.h | 3 +- .../core/session/abi_session_options.cc | 16 +- onnxruntime/core/session/inference_session.cc | 5 + onnxruntime/core/session/inference_session.h | 2 + onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 7 +- .../core/session/provider_policy_context.cc | 52 ++++-- onnxruntime/core/session/utils.cc | 46 +++--- .../test/autoep/test_autoep_selection.cc | 149 ++++++++++++++++- 17 files changed, 592 insertions(+), 105 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 620c13b8641b5..c543414ca13a9 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -353,6 +353,7 @@ public struct OrtApi public IntPtr SessionOptionsAppendExecutionProvider_V2; public IntPtr SessionOptionsSetEpSelectionPolicy; + public IntPtr SessionOptionsSetEpSelectionPolicyDelegate; public IntPtr HardwareDevice_Type; public IntPtr HardwareDevice_VendorId; @@ -692,6 +693,11 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicy, typeof(DSessionOptionsSetEpSelectionPolicy)); + + OrtSessionOptionsSetEpSelectionPolicyDelegate = + (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicyDelegate, + typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); } internal class NativeLib @@ -2278,28 +2284,49 @@ out IntPtr lora_adapter #region Auto EP API related // // OrtKeyValuePairs + + /// + /// Create an OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps); + /// + /// Add/replace a key-value pair in the OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key, byte[] /* const char* */ value); + /// + /// Get the value for the provided key. + /// + /// Value. Returns IntPtr.Zero if key was not found. [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key); + /// + /// Get all the key-value pairs in the OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps, out IntPtr /* const char* const** */ keys, out IntPtr /* const char* const** */ values, out UIntPtr /* size_t* */ numEntries); + /// + /// Remove a key-value pair from the OrtKeyValuePairs instance. + /// Ignores keys that are not present. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key); + /// + /// Release the OrtKeyValuePairs instance. 
+ /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps); @@ -2370,12 +2397,27 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, // // Auto Selection EP registration and selection customization + + /// + /// Register an execution provider library. + /// The library must implement CreateEpFactories and ReleaseEpFactory. + /// + /// Environment to add the EP library to. + /// Name to register the library under. + /// Absolute path to the library. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary( IntPtr /* OrtEnv* */ env, byte[] /* const char* */ registration_name, byte[] /* const ORTCHAR_T* */ path); + /// + /// Unregister an execution provider library. + /// + /// The environment to unregister the library from. + /// The name the library was registered under. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary( IntPtr /* OrtEnv* */ env, @@ -2384,6 +2426,11 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; + /// + /// Get the OrtEpDevices that are available. + /// These are all the possible execution provider and device pairs. + /// + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices( IntPtr /* const OrtEnv* */ env, @@ -2392,6 +2439,20 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; + /// + /// Add execution provider devices to the session options. + /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. + /// All OrtEpDevice instances in ep_devices must be for the same execution provider. + /// e.g. selecting OpenVINO for GPU and NPU would have an OrtEpDevice for GPU and NPU. + /// + /// SessionOptions to add to. + /// Environment that the OrtEpDevice instances came from by calling GetEpDevices + /// One or more OrtEpDevice instances. + /// Number of OrtEpDevice instances. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// Number of user overrides for execution provider options. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2( IntPtr /* OrtSessionOptions* */ sess_options, @@ -2404,6 +2465,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2; + /// + /// Delegate to do custom execution provider selection. + /// + /// Available OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Metadata from the ONNX model. + /// Runtime metadata. May be IntPtr.Zero. + /// OrtEpDevices that were selected. Pre-allocated array for delegate to update. + /// Maximum number of OrtEpDevices that can be selected. + /// Number of OrtEpDevices that were selected. + /// State that was provided in when the delegate was registered. 
+ /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtEpSelectionDelegate( IntPtr /* OrtEpDevice** */ epDevices, @@ -2412,16 +2485,36 @@ public delegate IntPtr DOrtEpSelectionDelegate( IntPtr /* OrtKeyValuePairs* */ runtimeMetadata, IntPtr /* OrtEpDevice** */ selected, uint maxSelected, - out UIntPtr numSelected + out UIntPtr numSelected, + IntPtr /* void* */ state ); + /// + /// Set the execution provider selection policy. + /// + /// SessionOptions to set the policy for. + /// Selection policy. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicy( IntPtr /* OrtSessionOptions* */ session_options, - int /* OrtExecutionProviderDevicePolicy */ policy, - IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate); + int /* OrtExecutionProviderDevicePolicy */ policy); public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy; + /// + /// Set the execution provider selection policy delegate. + /// + /// SessionOptions to set the policy for. + /// Selection policy delegate. + /// State that is passed through to the selection delegate. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicyDelegate( + IntPtr /* OrtSessionOptions* */ session_options, + IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate, + IntPtr /* void* */ state); + public static DSessionOptionsSetEpSelectionPolicyDelegate OrtSessionOptionsSetEpSelectionPolicyDelegate; + #endregion #region Misc API diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs index e3947d900214e..0318e08519128 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -10,18 +10,18 @@ namespace Microsoft.ML.OnnxRuntime /// Represents the combination of an execution provider and a hardware device /// that the execution provider can utilize. ///
- public class OrtEpDevice : SafeHandle + public class OrtEpDevice { /// /// Construct an OrtEpDevice from an existing native OrtEpDevice instance. /// /// Native OrtEpDevice handle. internal OrtEpDevice(IntPtr epDeviceHandle) - : base(epDeviceHandle, ownsHandle: false) { + _handle = epDeviceHandle; } - internal IntPtr Handle => handle; + internal IntPtr Handle => _handle; /// /// The name of the execution provider. @@ -30,7 +30,7 @@ public string EpName { get { - IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(handle); + IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr); } } @@ -42,7 +42,7 @@ public string EpVendor { get { - IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(handle); + IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); } } @@ -54,7 +54,7 @@ public OrtKeyValuePairs EpMetadata { get { - return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(_handle)); } } @@ -65,7 +65,7 @@ public OrtKeyValuePairs EpOptions { get { - return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(_handle)); } } @@ -76,23 +76,11 @@ public OrtHardwareDevice HardwareDevice { get { - IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(handle); + IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(_handle); return new OrtHardwareDevice(devicePtr); } } - /// - /// Indicates whether the native handle is invalid. - /// - public override bool IsInvalid => handle == IntPtr.Zero; - - /// - /// No-op. OrtEpDevice is always read-only as the instance is owned by native ORT. - /// - /// True - protected override bool ReleaseHandle() - { - return true; - } + private readonly IntPtr _handle; } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs index 8e7caae90ff79..af7115a92285e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs @@ -21,16 +21,16 @@ public enum OrtHardwareDeviceType /// /// Represents a hardware device that is available on the current system. /// - public class OrtHardwareDevice : SafeHandle + public class OrtHardwareDevice { /// /// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance. /// /// Native OrtHardwareDevice handle. 
- internal OrtHardwareDevice(IntPtr deviceHandle) - : base(deviceHandle, ownsHandle: false) + internal OrtHardwareDevice(IntPtr deviceHandle) { + _handle = deviceHandle; } /// @@ -40,7 +40,7 @@ public OrtHardwareDeviceType Type { get { - return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(handle); + return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(_handle); } } @@ -54,7 +54,7 @@ public uint VendorId { get { - return NativeMethods.OrtHardwareDevice_VendorId(handle); + return NativeMethods.OrtHardwareDevice_VendorId(_handle); } } @@ -65,7 +65,7 @@ public string Vendor { get { - IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(handle); + IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); } } @@ -82,7 +82,7 @@ public uint DeviceId { get { - return NativeMethods.OrtHardwareDevice_DeviceId(handle); + return NativeMethods.OrtHardwareDevice_DeviceId(_handle); } } @@ -95,22 +95,10 @@ public OrtKeyValuePairs Metadata { get { - return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(_handle)); } } - /// - /// Indicates whether the native handle is invalid. - /// - public override bool IsInvalid => handle == IntPtr.Zero; - - /// - /// No-op. OrtHardwareDevice is always read-only as the instance is owned by native ORT. - /// - /// True - protected override bool ReleaseHandle() - { - return true; - } + private readonly IntPtr _handle; } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index de6189e105f78..d60bf75ccbd7c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -636,9 +636,32 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy) { NativeApiStatus.VerifySuccess( - NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy, IntPtr.Zero)); + NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy)); } + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Delegate that implements the custom selection policy. + public void SetEpSelectionPolicyDelegate(EpSelectionDelegate selectionDelegate = null) + { + _epSelectionPolicyConnector = new EpSelectionPolicyConnector(selectionDelegate); + _epSelectionPolicyDelegate = new NativeMethods.DOrtEpSelectionDelegate( + EpSelectionPolicyConnector.EpSelectionPolicyWrapper); + + // make sure these stay alive. 
not sure if this is necessary when they're class members though + _epSelectionPolicyConnectorHandle = GCHandle.Alloc(_epSelectionPolicyConnector); + _epSelectionPolicyDelegateHandle = GCHandle.Alloc(_epSelectionPolicyDelegate); + + IntPtr funcPtr = Marshal.GetFunctionPointerForDelegate(_epSelectionPolicyDelegate); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicyDelegate( + handle, + funcPtr, + GCHandle.ToIntPtr(_epSelectionPolicyConnectorHandle))); + } #endregion internal IntPtr Handle @@ -914,7 +937,98 @@ public void SetLoadCancellationFlag(bool value) { NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); } + #endregion + + #region Selection Policy Delegate helpers + /// + /// Delegate to select execution provider devices from a list of available devices. + /// + /// OrtEpDevices to select from. + /// Model metadata. + /// Runtime metadata. + /// Maximum number of devices that can be selected. + /// Selected devices. Ordered by priority. Highest priority first. + public delegate List EpSelectionDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections); + + /// + /// Class to bridge the C# and native worlds for the EP selection policy delegate + /// + internal class EpSelectionPolicyConnector + { + private readonly EpSelectionDelegate _csharpDelegate; + + internal EpSelectionPolicyConnector(EpSelectionDelegate selectionDelegate) + { + _csharpDelegate = selectionDelegate; + } + + /// + /// Delegate to convert between the C and C# worlds + /// + /// OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Model metadata. + /// Runtime metadata. + /// Pre-allocated OrtEpDevice buffer to update with selected devices. + /// Number of entries in selectedOut. + /// Number of OrtEpDevies that were selected. + /// Opaque state. + /// nullptr for OrtStatus* to indicate success. + /// Currently we don't have a way to create an OrtStatus instance from the C# bindings. + /// Can add if we need to return an explicit error message. 
+ /// + public static IntPtr EpSelectionPolicyWrapper(IntPtr /* OrtEpDevice** */ epDevicesIn, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadataIn, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadataIn, + IntPtr /* OrtEpDevice** */ selectedOut, + uint maxSelected, + out UIntPtr numSelected, + IntPtr state) + { + Span epDevicesIntPtrs; + Span selectedDevicesIntPtrs; + EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target; + + unsafe + { + void* ptr = epDevicesIn.ToPointer(); + epDevicesIntPtrs = new Span(ptr, checked((int)numDevices)); + } + + List epDevices = new List(); + for (int i = 0; i < numDevices; i++) + { + + epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i])); + } + + OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn); + OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn); + var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected); + + numSelected = (UIntPtr)selected.Count; + + unsafe + { + void* ptr = selectedOut.ToPointer(); + selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected); + } + + int idx = 0; + foreach (var epDevice in selected) + { + selectedDevicesIntPtrs[idx] = epDevice.Handle; + idx++; + } + + return IntPtr.Zero; + } + } #endregion #region Private Methods @@ -1000,8 +1114,43 @@ protected override bool ReleaseHandle() { NativeMethods.OrtReleaseSessionOptions(handle); handle = IntPtr.Zero; + + if (_epSelectionPolicyConnectorHandle.IsAllocated) + { + _epSelectionPolicyConnectorHandle.Free(); + _epSelectionPolicyConnector = null; + } + + if (_epSelectionPolicyDelegateHandle.IsAllocated) + { + _epSelectionPolicyDelegateHandle.Free(); + _epSelectionPolicyDelegate = null; + } + + return true; } #endregion + + /// + /// Helper class to connect C and C# usage of the EP selection policy delegate. + /// + EpSelectionPolicyConnector _epSelectionPolicyConnector = null; + + /// + /// Handle to the EP selection policy connector that is passed to the C API as state for the + /// EP selection policy delegate. + /// + GCHandle _epSelectionPolicyConnectorHandle = default; + + /// + /// Delegate instance that is provided to the C API. + /// + NativeMethods.DOrtEpSelectionDelegate _epSelectionPolicyDelegate = null; + + /// + /// Handle to the EP selection policy delegate that is passed to the C API. 
+ /// + GCHandle _epSelectionPolicyDelegateHandle = default; } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs index 1aa4db15d275c..d95a649bd95c5 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -93,7 +93,7 @@ public void AppendToSessionOptionsV2() { var runTest = (Func> getEpOptions) => { - SessionOptions sessionOptions = new SessionOptions(); + using SessionOptions sessionOptions = new SessionOptions(); sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; var epDevices = ortEnvInstance.GetEpDevices(); @@ -138,7 +138,7 @@ public void AppendToSessionOptionsV2() [Fact] public void SetEpSelectionPolicy() { - SessionOptions sessionOptions = new SessionOptions(); + using SessionOptions sessionOptions = new SessionOptions(); sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; var epDevices = ortEnvInstance.GetEpDevices(); @@ -150,7 +150,50 @@ public void SetEpSelectionPolicy() var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); // session should load successfully - using (var session = new InferenceSession(model)) + using (var session = new InferenceSession(model, sessionOptions)) + { + Assert.NotNull(session); + } + } + + private static List SelectionPolicyDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + Assert.NotEmpty(modelMetadata.Entries); + Assert.True(epDevices.Count > 0); + + // select first device and last (if there are more than one). + var selected = new List(); + + selected.Add(epDevices[0]); + + // add ORT CPU EP which is always last. + if (maxSelections > 2 && epDevices.Count > 1) + { + selected.Add(epDevices.Last()); + } + + return selected; + } + + [Fact] + public void SetEpSelectionPolicyDelegate() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegate); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model, sessionOptions)) { Assert.NotNull(session); } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cef5eab9a505e..6c7d910b4963b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -438,18 +438,20 @@ typedef enum OrtExecutionProviderDevicePolicy { * \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array. Currently the maximum is 8. * \param num_ep_devices The number of selected devices. + * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. * * \return OrtStatus* Selection status. Return nullptr on success. * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. * ORT will release the OrtStatus* if not null. 
*/ -typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, - _In_ size_t num_devices, - _In_ const OrtKeyValuePairs* model_metadata, - _In_opt_ const OrtKeyValuePairs* runtime_metadata, - _Inout_ const OrtEpDevice** selected, - _In_ size_t max_selected, - _Out_ size_t* num_selected); +typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state); /** \brief Algorithm to use for cuDNN Convolution Op */ @@ -5127,18 +5129,30 @@ struct OrtApi { /** \brief Set the execution provider selection policy for the session. * - * Allows users to specify a device selection policy for automatic execution provider (EP) selection, - * or provide a delegate callback for custom selection logic. + * Allows users to specify a device selection policy for automatic execution provider (EP) selection. + * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead. * * \param[in] session_options The OrtSessionOptions instance. * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). - * \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy. * * \since Version 1.22 */ ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate); + _In_ OrtExecutionProviderDevicePolicy policy); + + /** \brief Set the execution provider selection policy delegate for the session. + * + * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] delegate Delegate callback for custom selection. + * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required. + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* delegate_state); /** \brief Get the hardware device type. 
* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 6c175c606b4a1..bc6f381bb82a0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1103,8 +1103,10 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::unordered_map& ep_options); /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy - SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, - EpSelectionDelegate* delegate = nullptr); + SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy); + + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicyDelegate + SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr); SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 1fdb8f16d9600..94ad2118fa4d6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1150,9 +1150,14 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( } template -inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, - EpSelectionDelegate* delegate) { - ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate)); +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state)); return *this; } diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8f8a3d6634a7e..89a43c4f71ee6 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -96,7 +96,8 @@ struct EpSelectionPolicy { // and no selection policy was explicitly provided. 
bool enable{false}; OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT; - EpSelectionDelegate* delegate{}; + EpSelectionDelegate delegate{}; + void* state{nullptr}; // state for the delegate }; /** diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index b1c0467da642e..c205e05baadb9 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -367,12 +367,24 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate) { + _In_ OrtExecutionProviderDevicePolicy policy) { API_IMPL_BEGIN options->value.ep_selection_policy.enable = true; options->value.ep_selection_policy.policy = policy; + options->value.ep_selection_policy.delegate = nullptr; + options->value.ep_selection_policy.state = nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* options, + _In_opt_ EpSelectionDelegate delegate, + _In_opt_ void* state) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT; options->value.ep_selection_policy.delegate = delegate; + options->value.ep_selection_policy.state = state; return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 8ec7312cc6354..df70856a64e99 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3266,6 +3266,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod // save model metadata model_metadata_.producer_name = model.ProducerName(); + model_metadata_.producer_version = model.ProducerVersion(); model_metadata_.description = model.DocString(); model_metadata_.graph_description = model.GraphDocString(); model_metadata_.domain = model.Domain(); @@ -3430,6 +3431,10 @@ const Model& InferenceSession::GetModel() const { return *model_; } +const Environment& InferenceSession::GetEnvironment() const { + return environment_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ba9812a59fec3..51350390a0456 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -80,6 +80,7 @@ struct ModelMetadata { ModelMetadata& operator=(const ModelMetadata&) = delete; std::string producer_name; + std::string producer_version; std::string graph_name; std::string domain; std::string description; @@ -603,6 +604,7 @@ class InferenceSession { #endif const Model& GetModel() const; + const Environment& GetEnvironment() const; protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 304966605c9cf..d03b98a9c1eb5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2980,6 +2980,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetEpDevices, 
&OrtApis::SessionOptionsAppendExecutionProvider_V2, &OrtApis::SessionOptionsSetEpSelectionPolicy, + &OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, &OrtApis::HardwareDevice_Type, &OrtApis::HardwareDevice_VendorId, @@ -3029,7 +3030,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 7be518a39480f..47d1a543b5a31 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -577,8 +577,11 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt size_t num_ep_options); ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate); + _In_ OrtExecutionProviderDevicePolicy policy); + +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* sess_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* state); // OrtHardwareDevice accessors. ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 4ce13fe36ea86..f706bd05d8494 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -94,8 +94,8 @@ std::vector OrderDevices(const std::vectorep_name < b->ep_name; } // one is the default CPU EP @@ -104,31 +104,57 @@ std::vector OrderDevices(const std::vector GPU -> NPU // TODO: Should environment.cc do the ordering? - const auto& execution_devices = OrderDevices(env.GetOrtEpDevices()); + std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies std::vector devices_selected; // Run the delegate if it was passed in lieu of any other policy if (options.value.ep_selection_policy.delegate) { - auto policy_fn = options.value.ep_selection_policy.delegate; + auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
+ std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); std::array selected_devices{nullptr}; - size_t num_selected = 0; - auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(), - nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected); + + EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; + auto* status = delegate(delegate_devices.data(), delegate_devices.size(), + &model_metadata, &runtime_metadata, + selected_devices.data(), selected_devices.size(), &num_selected, + options.value.ep_selection_policy.state); // return or fall-through for both these cases // going with explicit failure for now so it's obvious to user what is happening @@ -142,6 +168,12 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const if (num_selected == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); } + + // Copy the selected devices to the output vector + devices_selected.reserve(num_selected); + for (size_t i = 0; i < num_selected; ++i) { + devices_selected.push_back(selected_devices[i]); + } } else { // Create the selector for the chosen policy std::unique_ptr selector; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index c05394039d8c7..8ca4ef6af1f44 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -176,20 +176,6 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op env); } -#if !defined(ORT_MINIMAL_BUILD) - // TEMPORARY for testing. Manually specify the EP to select. - auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); - if (auto_select_ep_name) { - ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env, *sess, *auto_select_ep_name)); - } - - // if there are no providers registered, and there's an ep selection policy set, do auto ep selection - if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { - ProviderPolicyContext context; - ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env, *options, *sess)); - } -#endif // !defined(ORT_MINIMAL_BUILD) - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Add custom domains if (options && !options->custom_op_domains_.empty()) { @@ -231,22 +217,38 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_ENFORCE(session_logger != nullptr, "Session logger is invalid, but should have been initialized during session construction."); - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { + const bool has_provider_factories = options != nullptr && !options->provider_factories.empty(); + + if (has_provider_factories) { + std::vector> provider_list; for (auto& factory : options->provider_factories) { auto provider = factory->CreateProvider(*options, *session_logger->ToExternal()); provider_list.push_back(std::move(provider)); } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } } +#if !defined(ORT_MINIMAL_BUILD) + else { + // TEMPORARY for testing. Manually specify the EP to select. 
+ auto auto_select_ep_name = sess.GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); + if (auto_select_ep_name) { + ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(sess.GetEnvironment(), sess, *auto_select_ep_name)); + } - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection. + // note: the model has already been loaded so model metadata should be available to the policy delegate callback. + if (options != nullptr && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(sess.GetEnvironment(), *options, sess)); } } +#endif // !defined(ORT_MINIMAL_BUILD) if (prepacked_weights_container != nullptr) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index 04b1b2ea0bdc4..cea1299adc26f 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -68,6 +68,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod const std::function&)>& select_devices = nullptr, // auto select using policy std::optional policy = std::nullopt, + std::optional delegate = std::nullopt, bool test_session_creation_only = false) { Ort::SessionOptions session_options; @@ -77,7 +78,9 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } if (auto_select) { - if (policy) { + if (delegate) { + session_options.SetEpSelectionPolicy(*delegate, nullptr); + } else if (policy) { session_options.SetEpSelectionPolicy(*policy); } else { // manually specify EP to select @@ -353,6 +356,150 @@ TEST(AutoEpSelection, PreferNpu) { OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU); } +static OrtStatus* ORT_API_CALL PolicyDelegate(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + if (max_selected <= 2) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Expected to be able to select 2 devices."); + } + + if (model_metadata->entries.empty()) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Model metadata was empty."); + } + + selected[0] = ep_devices[0]; + *num_selected = 1; + if (num_devices > 1) { + // CPU EP is always last. 
+ selected[1] = ep_devices[num_devices - 1]; + *num_selected = 2; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateSelectNone(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateReturnError(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Selection error."); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegate) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegate); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegateSelectsNothing) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateSelectNone, + /*test_session_creation_only*/ true), + Ort::Exception); +} + +TEST(AutoEpSelection, PolicyDelegateReturnsError) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateReturnError, + /*test_session_creation_only*/ true), + Ort::Exception); +} + namespace { struct ExamplePluginInfo { const std::filesystem::path library_path = From 5160c67a3bdc1bc8de80be2afa1da3affb84ad46 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 7 May 2025 00:38:05 +0800 Subject: [PATCH 15/84] [webgpu] Add Matmul8bits Support (#24546) ### 
Description This PR adds the support for 8-bit quantization in the `MatMulNBits` operation in WebGPU. It does below things: 1. Unify to use `MatMulNBitsProgram` as the fallback path which is the original generation path for block size = 32. Now make it support any blocks size without limitations. And remove the original complicated programs. 2. Enable `MatMulNBitsWideTileProgram` for all platforms. --- .../webgpu/quantization/matmul_nbits.cc | 863 ++++++------------ .../webgpu/quantization/matmul_nbits.h | 47 +- .../test/contrib_ops/matmul_8bits_test.cc | 12 +- 3 files changed, 303 insertions(+), 619 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index a29cdea6b4af9..65ecdff44acd6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include "contrib_ops/webgpu/quantization/matmul_nbits.h" @@ -18,17 +19,42 @@ namespace webgpu { namespace { -std::string QuantizedDataType(int components) { - switch (components) { - case 1: - return "array"; - case 2: - return "mat4x2"; - case 4: - return "mat2x4"; - default: - return "array"; +std::string ReadZeroPoint(uint32_t nbits, bool has_zero_points) { + ORT_ENFORCE(nbits == 8 || nbits == 4, "Only 4/8 bits are supported for webgpu matmulnbits"); + std::stringstream ss; + if (has_zero_points) { + ss << "const elements_in_uint32 = " << (32 / nbits) << "u;\n" + << "const bits = " << nbits << "u;\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { + if (row < r_dim && col < c_dim) { + let offset = row * c_dim + col; + + // u32 holds elements_in_uint32 packed nbits. + let array_index = offset / elements_in_uint32; + let component_index = offset % elements_in_uint32; + let packed_value = zero_points[array_index]; + + // Extract the nbits component + let shift_amount = component_index * bits; +)"; + ss << " let masked_value = (packed_value >> shift_amount) & " << (nbits == 4 ? "0xFu" : "0xFF") << ";\n"; + ss << R"( + return output_element_t(masked_value); } + return output_element_t(0); +} +)"; + } else { + ss << "const default_zero_point = " << (nbits == 4 ? 8 : 128) << ";\n"; + ss << R"( +fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_element_t { + // The default zero point is 8. 
+ return output_element_t(default_zero_point); +} +)"; + } + return ss.str(); } constexpr unsigned int kMinMForTileOptimization = 4; @@ -46,483 +72,6 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); -Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); - const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - - if (block_size_ == 32) { - const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); - const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. - const uint32_t a_length_per_tile = tile_size / a.NumComponents(); - const uint32_t blocks_per_tile = tile_size / block_size_; - if (tile_m_ > 1 && use_subgroup_) { - ORT_ENFORCE(a.NumComponents() == 4, "input a's components must be equal to 4."); - ORT_ENFORCE(components_b_ == 4, "input b's components must be equal to 4."); - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_b: array, " << WorkgroupSizeY() << ">;\n" - << "var sub_scale: array, " << WorkgroupSizeY() << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; - shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m_ << ";\n" - << " let batch = workgroup_id.z;\n"; - shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" - // Loop over shared dimension. - << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" - << " // load one tile B/scale data into shared memory.\n" - // Each thread processes one block. 
- " let b_col = col + local_id.y;\n" - << " let block = tile * " << blocks_per_tile << " + local_id.x;\n" - << " if (b_col < uniforms.input_b_shape[0] && block < n_blocks_per_col) {\n" - << " sub_b[local_id.y][local_id.x] = " << b.GetByIndices("input_b_indices_t(b_col, block, 0)") << ";\n" - << " sub_scale[local_id.y][local_id.x] = " << scales.GetByOffset("b_col * n_blocks_per_col + block") << ";\n" - << " } else {\n" - " sub_b[local_id.y][local_id.x] = input_b_value_t(0);\n" - " sub_scale[local_id.y][local_id.x] = output_value_t(0);\n" - " }\n" - " workgroupBarrier();\n" - << " var in_y = (local_idx % 32) / 4;\n" - " var in_x = (local_idx / 32) * 4 + local_idx % 4;\n" - << " var word_offset = (local_idx % 4) * " << block_size_ / a.NumComponents() << ";\n" - << " if (sg_size == 8u) {\n" - " in_y = local_idx % 8;\n" - " in_x = local_idx / 8;\n" - << " word_offset = 0u;\n" - " } else if (sg_size == 16u) {\n" - " in_y = (local_idx % 16) / 2;\n" - " in_x = (local_idx / 16) * 2 + local_idx % 2;\n" - << " word_offset = (local_idx % 2) * " << block_size_ / a.NumComponents() << ";\n" - << " } else if (sg_size == 32u) {\n" - " in_y = (local_idx % 32) / 4;\n" - " in_x = (local_idx / 32) * 4 + local_idx % 4;\n" - << " word_offset = (local_idx % 4) * " << block_size_ / a.NumComponents() << ";\n" - << " } else if (sg_size == 64u) {\n" - " in_y = local_idx / 8;\n" - " in_x = local_idx % 8;\n" - << " word_offset = (local_idx % 8) * " << block_size_ / a.NumComponents() << ";\n" - << " }\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " let zero_point_byte_count = b_col * zero_point_bytes_per_col + (block >> 0x1u);\n" - " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; - } else { - // The default zero point is 8 for unsigned 4-bit quantization. 
- shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - } - shader.MainFunctionBody() << " let scale = sub_scale[in_y][in_x];\n" - " let b_data = sub_b[in_y][in_x];\n"; - shader.MainFunctionBody() << " let a_col_start = tile * " << a_length_per_tile << ";\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - shader.MainFunctionBody() << " let a_data" << i << " = mm_readA(batch, row + " << i << ", a_col_start + local_idx);\n"; - } - - shader.MainFunctionBody() << " if (sg_size == 8u) {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", i * 2);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", i * 2 + 1);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " } else if (sg_size == 16u) {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", i * 2);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a00 = subgroupShuffle(a_data" << i << ", i * 2 + 8);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", i * 2 + 1);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a11 = subgroupShuffle(a_data" << i << ", i * 2 + 9);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(select(a00, a0, local_idx % 2 == 0), b_dequantized_values[0]) + dot(select(a11, a1, local_idx % 2 == 0), b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " word_offset += " << 8 / 
a.NumComponents() << ";\n" - << " }\n"; - shader.MainFunctionBody() << " } else {\n"; - shader.MainFunctionBody() << " for (var i: u32 = 0; i < 4; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data[i];\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4(zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point, zero_point)) * scale;\n"; - for (uint32_t i = 0; i < tile_m_; i++) { - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a0 = subgroupShuffle(a_data" << i << ", word_offset);\n"; - if (i == 0) { - shader.MainFunctionBody() << " var "; - } - shader.MainFunctionBody() << " a1 = subgroupShuffle(a_data" << i << ", word_offset + 1);\n"; - shader.MainFunctionBody() << " inter_results[" << i << "][in_y][in_x] += dot(a0, b_dequantized_values[0]) + dot(a1, b_dequantized_values[1]);\n"; - } - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " workgroupBarrier();\n"; - - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() * tile_m_ << ") {\n" - << " let inner_row = local_idx / " << WorkgroupSizeY() << ";\n" - << " let inner_col = local_idx % " << WorkgroupSizeY() << ";\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[inner_row][inner_col][b];\n" - " }\n" - " if (row + inner_row < uniforms.output_shape[1] && col + inner_col < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row + inner_row, col + inner_col)", "output_value") << ";\n" - << " }\n" - " }\n"; - } else { - if (tile_m_ == 1) { - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var sub_a: array;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">;\n"; - std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY()); - shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" - << " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n"; - } else { - ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY."); - ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY."); - - shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" - << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" - << " } else {\n" - " return input_a_value_t(0);\n" - " }\n" - "}\n" - << "var 
sub_a: array," << tile_m_ << ">;\n" - << "var inter_results: array, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n"; - shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n" - << " let row = workgroup_id.y * " << tile_m_ << ";\n" - << " let batch = workgroup_id.z;\n"; - } - shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - << " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n" - // Loop over shared dimension. - << " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n" - << " let a_col_start = tile * " << a_length_per_tile << ";\n" - << " // load one tile A data into shared memory.\n" - << " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n" - << " let a_col = a_col_start + a_offset;\n"; - if (tile_m_ == 1) { - shader.MainFunctionBody() << " sub_a[a_offset] = mm_readA(batch, row, a_col);\n"; - } else { - for (uint32_t i = 0; i < tile_m_; i++) { - shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n"; - } - } - shader.MainFunctionBody() << " }\n" - " workgroupBarrier();\n" - // Each thread processes one block. - " let b_row = col + local_id.y;\n" - << " let block = tile * " << blocks_per_tile << " + local_id.x;\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n" - " let zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point = output_element_t((zero_point_word) & 0xFu);\n"; - } else { - // The default zero point is 8 for unsigned 4-bit quantization. 
- shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - } - shader.MainFunctionBody() << " var scale = output_element_t(0);\n" - " var b_data = input_b_value_t(0);\n" - << " if (block < n_blocks_per_col) {\n" - << " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n" - << " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n" - << " }\n" - << " var word_offset = local_id.x * " << block_size_ / a.NumComponents() << ";\n" - << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - shader.MainFunctionBody() << " let b_value = b_data"; - if (components_b_ > 1) { - shader.MainFunctionBody() << "[i]"; - } - shader.MainFunctionBody() << ";\n" - " let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n" - " let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n" - " let b_quantized_values = mat2x4(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - " let b_dequantized_values = (b_quantized_values - mat2x4("; - for (int i = 0; i < 8; i++) { - shader.MainFunctionBody() << "zero_point"; - if (i < 7) { - shader.MainFunctionBody() << ", "; - } - } - shader.MainFunctionBody() << ")) * scale;\n"; - if (tile_m_ == 1) { - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; - } - } else { - for (uint32_t i = 0; i < tile_m_; i++) { - switch (a.NumComponents()) { - case 1: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n"; - break; - case 2: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n"; - break; - case 4: - shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n"; - break; - default: - break; - } - } - } - shader.MainFunctionBody() << " word_offset += " << 8 / 
a.NumComponents() << ";\n" - << " }\n" - " workgroupBarrier();\n" - " }\n"; - if (tile_m_ == 1) { - shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_idx][b];\n" - " }\n" - " if (col + local_idx < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n" - << " }\n" - " }\n"; - } else { - shader.MainFunctionBody() << " if (local_id.y < " << tile_m_ << ") {\n" - << " var output_value = output_value_t(0);\n" - << " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n" - << " output_value += inter_results[local_id.y][local_id.x][b];\n" - " }\n" - " if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n" - << " }\n" - " }\n"; - } - } - } else { - const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); - const int output_element_number = y.NumComponents() * onnxruntime::narrow(output_number_); - - const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; - std::string offset = "workgroup_idx * " + std::to_string(output_number_); - shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; - shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n" - << " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - " let blob_size = uniforms.input_b_shape[2];\n" - " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" - << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; - - // prepare scale and zero point - shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; - if (has_zero_points_) { - const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - " var zero_point_byte_count: u32;\n" - " var zero_point_word_index: u32;\n" - " var zero_point_byte_offset: u32;\n" - " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - " var zero_point_bits_offset: u32;\n" - " var zero_point_word: u32;\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" - << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" - " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" - << " col_index += 1;\n"; - } - } else { - shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" - << " col_index += 
1;\n"; - } - } - - shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; - - // prepare b data - shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" - << " col_index += 1;\n"; - } - shader.MainFunctionBody() << " var b_value : u32;\n" - " let b_mask : u32 = 0x0F0F0F0Fu;\n" - " var b_value_lower : vec4;\n" - " var b_value_upper : vec4;\n" - << " var b_quantized_values : " << quantized_data_type << ";\n" - << " var b_dequantized_values : " << quantized_data_type << ";\n"; - - shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - - // process one word - shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" - << " var a_data: " << quantized_data_type << ";\n" - << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" - << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" - << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" - << " input_offset++;\n" - " } else {\n" - " a_data[j] = input_a_value_t(0);\n" - " }\n" - " }\n"; - for (int c = 0; c < output_element_number; c++) { - shader.MainFunctionBody() << " b_value = b" << c << "_data"; - if (components_b_ > 1) { - shader.MainFunctionBody() << "[i]"; - } - shader.MainFunctionBody() << ";\n" - " b_value_lower = unpack4xU8(b_value & b_mask);\n" - " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" - << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - << " b_dequantized_values = "; - if (a.NumComponents() == 1) { - if (has_zero_points_) { - shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; - } else { - shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " - << "(b_quantized_values[1] - zero_point) * scale" << c << "," - << "(b_quantized_values[2] - zero_point) * scale" << c << "," - << "(b_quantized_values[3] - zero_point) * scale" << c << "," - << "(b_quantized_values[4] - zero_point) * scale" << c << "," - << "(b_quantized_values[5] - zero_point) * scale" << c << "," - << "(b_quantized_values[6] - zero_point) * scale" << c << "," - << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; - } - } else { - shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; - for (int i = 0; i < 8; i++) { - if (has_zero_points_) { - shader.MainFunctionBody() << "zero_point" << c; - 
} else { - shader.MainFunctionBody() << "zero_point"; - } - if (i < 7) { - shader.MainFunctionBody() << ", "; - } - } - shader.MainFunctionBody() << ")) * scale" << c << ";\n"; - } - - shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; - if (y.NumComponents() > 1) { - shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; - } - shader.MainFunctionBody() << " += "; - if (a.NumComponents() == 1) { - shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " - "a_data[1] * b_dequantized_values[1] + " - "a_data[2] * b_dequantized_values[2] + " - "a_data[3] * b_dequantized_values[3] + " - "a_data[4] * b_dequantized_values[4] + " - "a_data[5] * b_dequantized_values[5] + " - "a_data[6] * b_dequantized_values[6] + " - "a_data[7] * b_dequantized_values[7];\n"; - } else if (a.NumComponents() == 2) { - shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " - "dot(a_data[1], b_dequantized_values[1]) + " - "dot(a_data[2], b_dequantized_values[2]) + " - "dot(a_data[3], b_dequantized_values[3]);\n"; - } else if (a.NumComponents() == 4) { - shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " - "dot(a_data[1], b_dequantized_values[1]);\n"; - } - } - - shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n" - << " }\n" - " }\n" - " }\n" - " workgroupBarrier();\n" - << " if (local_id.x < " << output_number_ << ") {\n" - << " var output_value = output_value_t(0);\n" - " var workgroup_shared_offset = local_id.x;\n" - << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" - << " for (var b = 0u; b < blocks_num; b++) {\n" - " output_value += workgroup_shared[workgroup_shared_offset];\n" - << " workgroup_shared_offset += " << output_number_ << ";\n" - << " }\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" - << " }\n"; - } - - return Status::OK(); -} - Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); @@ -541,65 +90,21 @@ Status MatMulNBitsWideTileProgram::GenerateShaderCode(ShaderHelper& shader) cons // memory read/write helpers shader.AdditionalImplementation() << "fn mm_read_a(batch : u32, row : u32, col : u32) -> input_a_value_t {\n" - << " if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" + << " if (batch < uniforms.input_a_shape[0] && row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n" << " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n" << " }\n" << " return input_a_value_t(0);\n" << "}\n"; + if (nbits_ == 4) { + shader.AdditionalImplementation() << "\n" + << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n" + << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" + << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << ";\n" + << " }\n" + << " return input_b_value_t(0);\n" + << "}\n"; - shader.AdditionalImplementation() << "\n" - << "fn mm_read_b(row : u32, col : u32) -> input_b_value_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return " << b.GetByIndices("input_b_indices_t(row, col, 0)") << 
";\n" - << " }\n" - << " return input_b_value_t(0);\n" - << "}\n"; - - shader.AdditionalImplementation() << "\n" - << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n" - << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" - << " return scales[row * uniforms.input_b_shape[1] + col];\n" - << " }\n" - << " return output_element_t(0);\n" - << "}\n"; - - if (has_zero_points_) { shader.AdditionalImplementation() << R"( -fn mm_read_zero(row : u32, col : u32) -> output_element_t { - if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) { - let offset = row * uniforms.input_b_stride[0] + col * uniforms.input_b_stride[1]; - - // u32 holds 8 packed uint4. - let array_index = offset / 8u; - let component_index = offset % 8u; - let packed_value = zero_points[array_index]; - - // Extract the uint4 component - let shift_amount = component_index * 4u; - let masked_value = (packed_value >> shift_amount) & 0xFu; - - return output_element_t(masked_value); - } - return output_element_t(0); -} -)"; - } else { - shader.AdditionalImplementation() << R"( -fn mm_read_zero(row : u32, col : u32) -> output_element_t { - // The default zero point is 8. - return output_element_t(8); -} -)"; - } - - shader.AdditionalImplementation() << "\n" - << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" - << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n" - << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n" - << " }\n" - << "}\n"; - - shader.AdditionalImplementation() << R"( fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scale : output_element_t) -> mat2x4 { let lower_values: vec4 = unpack4xU8(packed_value & 0x0F0F0F0Fu); let upper_values: vec4 = unpack4xU8((packed_value >> 4u) & 0x0F0F0F0Fu); @@ -620,6 +125,23 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal return dequantized_values; } )"; + } + + shader.AdditionalImplementation() << "\n" + << "fn mm_read_scale(row : u32, col : u32) -> output_element_t {\n" + << " if (row < uniforms.input_b_shape[0] && col < uniforms.input_b_shape[1]) {\n" + << " return scales[row * uniforms.input_b_shape[1] + col];\n" + << " }\n" + << " return output_element_t(0);\n" + << "}\n" + << ReadZeroPoint(nbits_, has_zero_points_); + + shader.AdditionalImplementation() << "\n" + << "fn mm_write_y(batch : u32, row : u32, col : u32, value : output_value_t) {\n" + << " if (row < uniforms.output_shape[1] && col < uniforms.output_shape[2]) {\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col)", "value") << "\n" + << " }\n" + << "}\n"; // declare const variables shader.AdditionalImplementation() << "\n" @@ -635,9 +157,9 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal // main shader.MainFunctionBody() << R"MAIN_FN( - let batch = workgroup_id.z; - let row = workgroup_id.y * kTileM; - let col = workgroup_id.x * kTileN; + let batch = workgroup_idx / (uniforms.num_M_tile * uniforms.num_N_tile); + let row = ((workgroup_idx / uniforms.num_N_tile) % uniforms.num_M_tile) * kTileM; + let col = (workgroup_idx % uniforms.num_N_tile) * kTileN; let a_elements_per_col = uniforms.input_a_shape[2]; let a_blocks_per_col = (a_elements_per_col + kAComponentsForBlock32 - 1) / kAComponentsForBlock32; @@ -655,10 +177,13 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal let b_row = col + local_idx; let b_col = a_block_idx; - let b_data = 
mm_read_b(b_row, b_col); let scale = mm_read_scale(b_row, b_col); - let zero_point = mm_read_zero(b_row, b_col); + let zero_point = mm_read_zero(b_row, b_col, uniforms.input_b_shape[0], uniforms.zero_blocks_per_col); +)MAIN_FN"; + if (nbits_ == 4) { + shader.MainFunctionBody() << R"MAIN_FN( + let b_data = mm_read_b(b_row, b_col); // `b` component size is 4. for (var b_idx = 0u; b_idx < 4u; b_idx++) { let b_dequantized = dequantize_packed8xU4(b_data[b_idx], zero_point, scale); @@ -669,10 +194,37 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal results[m_idx] += f32(dot(a_data0, b_dequantized[0])) + f32(dot(a_data1, b_dequantized[1])); } } +)MAIN_FN"; + } else { + shader.MainFunctionBody() << " var b_data0 = vec4(0);\n" + " var b_data1 = vec4(0);\n" + " if (b_row < uniforms.input_b_shape[0] && b_col < uniforms.input_b_shape[1]) {\n" + << " b_data0 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 0)") << ";\n" + << " b_data1 = " << b.GetByIndices("input_b_indices_t(b_row, b_col, 1)") << ";\n" + " }" + << R"MAIN_FN( + for (var b_idx = 0u; b_idx < 4u; b_idx++) { + let b_dequantized0 = (vec4(unpack4xU8(b_data0[b_idx])) - vec4(zero_point)) * scale; + let b_dequantized1 = (vec4(unpack4xU8(b_data1[b_idx])) - vec4(zero_point)) * scale; + for (var m_idx = 0u; m_idx < kTileM; m_idx++) { + let a_data0 = a_data_tile[m_idx][b_idx]; + let a_data1 = a_data_tile[m_idx][b_idx + 4u]; + + results[m_idx] += f32(dot(a_data0, b_dequantized0)) + f32(dot(a_data1, b_dequantized1)); + } + } +)MAIN_FN"; + } + + shader.MainFunctionBody() << R"MAIN_FN( workgroupBarrier(); } + if (batch >= uniforms.input_a_shape[0]) { + return; + } + // Write the results. for (var m_idx = 0u; m_idx < kTileM; m_idx++) { mm_write_y(batch, row + m_idx, col + local_idx, output_value_t(results[m_idx])); @@ -682,6 +234,156 @@ fn dequantize_packed8xU4(packed_value : u32, zero_point : output_element_t, scal return Status::OK(); } +// Apply similar idea with DP4AMatMulNBitsSmallMProgram algorithm. +Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias); + const auto& b = shader.AddInput("input_b"); + shader.AddInput("scales_b"); + if (has_zero_points_) { + shader.AddInput("zero_points", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = b.NumComponents() / 4; // b is stored as uint32 which includs 4 uint8. 
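
For readers following the shader generation, a CPU-side sketch of the packing convention may help. It assumes only the bit layout used above: a `u32` of input B holds `32 / nbits` quantized values, element `i` sits at bits `[i * nbits, (i + 1) * nbits)`, and the default zero point is 8 for 4-bit data and 128 for 8-bit data. The helper names are illustrative reference code, not part of the kernel.

```cpp
#include <cstdint>

// Extract and dequantize element `index_in_word` from one packed u32 of B.
inline float DequantizeElement(uint32_t packed, uint32_t index_in_word,
                               uint32_t nbits, float scale, float zero_point) {
  const uint32_t mask = (nbits == 4) ? 0xFu : 0xFFu;
  const uint32_t shift = index_in_word * nbits;
  const uint32_t q = (packed >> shift) & mask;
  return (static_cast<float>(q) - zero_point) * scale;
}

// Reference dot product of one packed word of B against the matching span of A.
inline float DotOneWord(const float* a, uint32_t packed_b, uint32_t nbits,
                        float scale, float zero_point) {
  const uint32_t elements_per_word = 32 / nbits;  // 8 for 4-bit, 4 for 8-bit
  float sum = 0.0f;
  for (uint32_t i = 0; i < elements_per_word; ++i) {
    sum += a[i] * DequantizeElement(packed_b, i, nbits, scale, zero_point);
  }
  return sum;
}
```
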
+ constexpr uint32_t tile_size_k_vec = 16; + uint32_t elements_in_value_b = components_b * (32 / nbits_); + uint32_t tile_k_size = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_k_size / components_a; + + shader.AdditionalImplementation() << "const a_length_per_tile = " << a_length_per_tile << "u;\n" + << "const tile_size_k_vec = " << tile_size_k_vec << ";\n" + << "const tile_size_k = " << tile_k_size << "u;\n" + << "const tile_size = " << tile_size_ << "u;\n" + << "const elements_in_value_b = " << elements_in_value_b << "u;\n" + << "const sub_tile_count = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n" + << "const component_a = " << components_a << "u;\n" + << "const component_b = " << components_b << "u;\n"; + shader.AdditionalImplementation() << R"ADDNL_FN( + // Shared memory + var tile_A : array; + var inter_results: array, tile_size>; + fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32) + { + let k_offset = kidx / component_a + col; + if (batch < uniforms.batch_count && k_offset < uniforms.K_of_a) { + tile_A[col] = input_a[batch * uniforms.M * uniforms.K_of_a + a_global * uniforms.K_of_a + k_offset]; + } else { + tile_A[col] = input_a_value_t(0); + } + } +)ADDNL_FN" + << ReadZeroPoint(nbits_, has_zero_points_); + + shader.MainFunctionBody() << R"MAIN_FN( + let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile); + let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) + { + for (var id = local_idx; id < a_length_per_tile; id += workgroup_size_x) + { + loadSHMA(batch, a_global, kidx, id); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) + { + var b_global = b_global_base + local_row_offset + idy; + var k_offset = kidx / elements_in_value_b + idx; + if (b_global < uniforms.N && k_offset < uniforms.K_of_b) + { + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + let scale_b = scales_b[b_global * uniforms.blocks_per_col + block_idx]; + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + var b_value = input_b[b_global * uniforms.K_of_b + k_offset]; +)MAIN_FN"; + + if (nbits_ == 4) { + shader.MainFunctionBody() << R"MAIN_FN( + var sum = output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +)MAIN_FN"; + switch (components_a) { + case 1: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +" + " dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);\n" + " a_offset += 8;\n"; + break; + case 2: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) +" + "dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);\n" + " a_offset += 4;\n"; + break; + case 4: + shader.MainFunctionBody() 
<< " sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);\n" + " a_offset += 2;\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " }\n"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( + var sum = output_element_t(0); + var a_offset = idx * (4 / component_a) * component_b; + for (var i = 0u; i < component_b; i++) { + let b_value = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; +)MAIN_FN"; + switch (components_a) { + case 1: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);\n" + " a_offset += 4;\n"; + break; + case 2: + shader.MainFunctionBody() << " sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value);\n" + " a_offset += 2;\n"; + break; + case 4: + shader.MainFunctionBody() << " sum += dot(tile_A[a_offset], b_value);\n" + " a_offset += 1;\n"; + break; + default: + break; + } + shader.MainFunctionBody() << " }\n"; + } + + shader.MainFunctionBody() << R"MAIN_FN( + inter_results[local_row_offset + idy][idx] += sum; + } + } + workgroupBarrier(); + } + + if (batch >= uniforms.batch_count) { + return; + } + + if (local_idx < tile_size) { + var output_value = output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + output_value += inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx; + let output_idx = batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global; + if (b_global < uniforms.N) { + output[output_idx] = output_value; + } + } +)MAIN_FN"; + + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -724,20 +426,18 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context } // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. - if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || nbits == 8 || + if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, nbits, context, y); } - // TODO: Remvoe it once the 8bits is supported for the non-dp4 path. - ORT_ENFORCE(nbits == 4, "Only 4 bits are supported for the non-dp4 path for webgpu matmulnbits"); + // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. + uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. - // TODO: loosen restrictions on vendor. - const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization && - context.AdapterInfo().vendor == std::string_view{"intel"}; + const bool use_wide_tile_program = block_size == 32 && components_a == 4 && components_b == 4 && M >= kMinMForTileOptimization; if (use_wide_tile_program) { // Enforce output components to 1. 
components = 1; @@ -745,8 +445,10 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t workgroup_size = 128; constexpr uint32_t tile_m = workgroup_size / 8; constexpr uint32_t tile_n = workgroup_size; + uint32_t num_N_tile = (N + tile_n - 1) / tile_n; + uint32_t num_M_tile = (M + tile_m - 1) / tile_m; - MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n}; + MatMulNBitsWideTileProgram program{has_zero_points, tile_m, tile_n, nbits}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_n - 1) / tile_n, (M + tile_m - 1) / tile_m, @@ -762,7 +464,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, onnxruntime::narrow(components_b * 4)}, {scales, ProgramTensorMetadataDependency::None}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, onnxruntime::narrow(components)}) - .AddUniformVariable({block_size}); + .AddUniformVariables({{block_size}, {zero_blocks_per_col}, {num_N_tile}, {num_M_tile}}) + .CacheHint(nbits, has_zero_points); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } @@ -770,47 +473,21 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return context.RunProgram(program); } - // Generic program - // TODO: Support output_number > 1. Some cases are failed when output_number > 1. - constexpr uint32_t output_number = 1; - const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; - const bool has_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups); - const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; - MatMulNBitsProgram program{output_number, block_size, tile_m, static_cast(components_b), has_zero_points, use_subgroup}; - if (M > kMinMForTileOptimization && block_size == 32) { - components = 1; - constexpr uint32_t workgroup_size = 64; - constexpr uint32_t workgroup_y = 8; - constexpr uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y, - (M + tile_m - 1) / tile_m, - batch_count); - program.CacheHint("T_M" + std::to_string(tile_m) + "Subgroup" + std::to_string(use_subgroup)); - } else if (block_size == 32) { - components = 1; - // TODO: Tune the workgroup size when `M=1`. - constexpr uint32_t workgroup_size = 128; - const uint32_t workgroup_y = N % 8 == 0 ? 
8 : 1; - const uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize(data_size / components / workgroup_y); - program.CacheHint("T_M" + std::to_string(tile_m)); - } else { - program.SetDispatchGroupSize(data_size / components / output_number); - program.CacheHint("O_N" + std::to_string(output_number)); - } - - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; - + constexpr uint32_t workgroup_size = 128; + constexpr uint32_t tile_size = 8; + constexpr uint32_t kU32Components = 4; + uint32_t components_b_with_u32 = components_b * kU32Components; + uint32_t num_N_tile = (N + tile_size - 1) / tile_size; + MatMulNBitsProgram program{tile_size, nbits, has_zero_points}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, static_cast(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, static_cast(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(components)}) - .AddUniformVariable({block_size}); + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{M}, {N}, {K}, {K / components_a}, {n_blocks_per_col * blob_size / components_b_with_u32}, {block_size}, {n_blocks_per_col}, {zero_blocks_per_col}, {num_N_tile}, {batch_count}}) + .CacheHint(nbits, has_zero_points); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index d5e4bc68fc33a..807576c91752b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -12,41 +12,44 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class MatMulNBitsProgram final : public Program { +class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool use_subgroup) : Program{"MatMulNBits"}, - output_number_{output_number}, - block_size_{block_size}, - tile_m_{tile_m}, - components_b_{components_b}, - has_zero_points_{has_zero_points}, - use_subgroup_(use_subgroup) { - } + MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) + : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", 
ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"num_M_tile", ProgramUniformVariableDataType::Uint32}); private: - uint32_t output_number_; - uint32_t block_size_; - uint32_t tile_m_; - int components_b_; bool has_zero_points_; - bool use_subgroup_; + uint32_t tile_m_; + uint32_t tile_n_; + uint32_t nbits_; }; -class MatMulNBitsWideTileProgram final : public Program { +class MatMulNBitsProgram final : public Program { public: - MatMulNBitsWideTileProgram(bool has_zero_points, uint32_t tile_m, uint32_t tile_n) - : Program{"MatMulNBitsWideTileProgram"}, has_zero_points_{has_zero_points}, tile_m_(tile_m), tile_n_(tile_n) {} - + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points) : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"block_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"zero_blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}); private: + uint32_t tile_size_; + uint32_t nbits_; bool has_zero_points_; - uint32_t tile_m_; - uint32_t tile_n_; }; class MatMulNBits final : public WebGpuKernel { diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index b29fc5181eb46..257d3b3efdf9c 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#ifndef ORT_MINIMAL_BUILD -#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) +#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU) #include @@ -186,6 +186,10 @@ void RunTest8Bits(const TestOptions8Bits& opts) { std::vector> execution_providers; #ifdef USE_CUDA execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_WEBGPU + execution_providers.emplace_back(DefaultWebGpuExecutionProvider()); +#endif +#if defined(USE_CUDA) || defined(USE_WEBGPU) test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); execution_providers.clear(); @@ -226,8 +230,8 @@ void TestMatMul8BitsTyped() { RunTest8Bits(opts); } -// CUDA does not support bias for MatMulNBits -#if not defined(USE_CUDA) +// CUDA/WEBGPU does not support bias for MatMulNBits +#if !defined(USE_CUDA) && !defined(USE_WEBGPU) { TestOptions8Bits opts = base_opts; opts.has_bias = true; @@ -279,7 +283,7 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { TestMatMul8BitsTyped(); } -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_WEBGPU) TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); From 3a618f21a6a9efc327be0a2a0c77459978190ac0 Mon Sep 17 00:00:00 2001 From: Kee Date: Wed, 7 May 2025 00:58:51 +0800 Subject: [PATCH 16/84] [VSINPU EP]Fix gather OP with scalar indice issue (#24603) ### Description If indices is a scalar(0 dimensional tensor) , gather OP produces incorrect output shape. Fix the gather op bug in VSINPU EP. ### Motivation and Context Signed-off-by: Kee --- .../vsinpu/builders/impl/gather_op_builder.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h index bd91bbd81e1fa..0d3374bce325b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/gather_op_builder.h @@ -59,21 +59,33 @@ class GatherOpBuilder : public BaseOpBuilder { std::vector>& outputs, const NodeUnit& node_unit) override { LOGS_DEFAULT(VERBOSE) << "Creating Gather Op."; + auto indices = node_unit.Inputs()[1]; + int8_t is_scalar_indices = 0; NodeAttrHelper helper(node_unit.GetNode()); auto axis = helper.Get("axis", 0); axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); auto op = graph_ep->GetGraph()->CreateOperation(axis, 0); + auto indices_shape_proto = indices.node_arg.Shape(); + if (indices_shape_proto != nullptr) { + if (indices_shape_proto->dim_size() == 0) { + is_scalar_indices = 1; + } + } else { + is_scalar_indices = 1; + } bool is_i64_indices = inputs[1]->GetDataType() == tim::vx::DataType::INT64; if (!is_i64_indices) { + inputs[1]->SetScalar(is_scalar_indices); (*op).BindInputs(inputs).BindOutputs(outputs); } else { std::vector origin_data(inputs[1]->GetSpec().GetElementNum()); inputs[1]->CopyDataFromTensor(origin_data.data()); std::vector transformed_data(origin_data.begin(), origin_data.end()); - tim::vx::TensorSpec ts = inputs[1]->GetSpec().SetAttribute(tim::vx::TensorAttribute::INPUT); + tim::vx::TensorSpec ts = inputs[1]->GetSpec(); ts.SetDataType(tim::vx::DataType::INT32); auto transformed_indices = graph_ep->GetGraph()->CreateTensor(ts, transformed_data.data()); + transformed_indices->SetScalar(is_scalar_indices); 
(*op).BindInput(inputs[0]).BindInput(transformed_indices).BindOutput(outputs[0]); } graph_ep->GetOps().push_back(std::move(op)); From 65b4c3775040aaf4b7c9fe2e8e317c55cac4d12c Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 6 May 2025 10:49:54 -0700 Subject: [PATCH 17/84] [Native WebGPU] Fixed type mismatch. (#24655) ### Description Fix type mismatch using float in place of unsigned int. ### Motivation and Context --- .../core/providers/webgpu/quantization/quantize_linear.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc index e2b5d73168935..0305049e9b789 100644 --- a/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/webgpu/quantization/quantize_linear.cc @@ -79,7 +79,7 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { if (packed_) { shader.MainFunctionBody() << "let zero_point_index = " << output.IndicesGet("output_indices", "uniforms.axis") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("zero_point_index / 4") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_index / 4)") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_index % 4];\n"; } else { @@ -92,7 +92,7 @@ Status DequantizeLinearProgram::GenerateShaderCode(ShaderHelper& shader) const { if (packed_) { shader.MainFunctionBody() << "let zero_point_offset = " << scale.GetByIndices("scale_indices") << ";\n" - << "let zero_point_input = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" + << "let zero_point_input = " << zero_point.GetByOffset("u32(zero_point_offset / 4)") << ";\n" << "let zero_point_vec = " << unpack << ";\n" << "let zero_point_value = zero_point_vec[zero_point_offset % 4];\n"; } else { From 7942b0caa991ef03fc6d266a72971ad4d3354da3 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 6 May 2025 11:29:21 -0700 Subject: [PATCH 18/84] [webgpu] fix compile errors in instancenorm (#24639) fix shader compile; don't know how this made it past ci --- onnxruntime/core/providers/webgpu/nn/instance_norm.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc index f3bccec4872fc..7b39980f85605 100644 --- a/onnxruntime/core/providers/webgpu/nn/instance_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/instance_norm.cc @@ -88,13 +88,13 @@ Status InstanceNormProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& channel_scale_shift = shader.AddInput("channel_scale_shift", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << "let outputIndices = " << output.OffsetToIndices("global_idx") + << "let outputIndices = " << output.OffsetToIndices("global_idx") << ";\n" << "let batch = outputIndices[0];\n" << "let channel = outputIndices[1];\n" << "let channel_scale_shift_indices = channel_scale_shift_indices_t(batch, channel, 0);\n" << "let channel_scale_shift = " << channel_scale_shift.GetByIndices("channel_scale_shift_indices") << ";\n" << "let input_value = " << 
input.GetByOffset("global_idx") << ";\n" - << "let output_value = input_value * output_value_t(channel_scale_sift.x) + output_value_t(channel_scale_shift.y);\n" + << "let output_value = input_value * output_value_t(channel_scale_shift.x) + output_value_t(channel_scale_shift.y);\n" << output.SetByOffset("global_idx", "output_value") << ";\n"; return Status::OK(); } From cab3c421ead7631a080c692f2de850cc76fdc1c9 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 6 May 2025 11:30:36 -0700 Subject: [PATCH 19/84] Fix source name in CUDA publishing pipeline configuration (#24645) ### Description Python Cuda Publishing pipeline references old test pipeline --- .../github/azure-pipelines/py-cuda-publishing-pipeline.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml index 230c391c00ebd..016c09e6c01da 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-publishing-pipeline.yml @@ -1,7 +1,7 @@ resources: pipelines: - pipeline: build - source: 'Python CUDA12 Package Test Pipeline' + source: 'Python CUDA Package Test Pipeline' trigger: branches: include: @@ -37,4 +37,4 @@ extends: stages: - template: stages/py-cuda-publishing-stage.yml parameters: - artifact_feed: $(ArtifactFeed) \ No newline at end of file + artifact_feed: $(ArtifactFeed) From 1f4ca889744857d151e60e5867f683876caa5355 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 May 2025 12:08:24 -0700 Subject: [PATCH 20/84] allow upload log on failure for further investigating (#24649) ### Description The random failure on Web CI is hard to investigate because it's not reproducible. Add this step to upload the log to help investigate the issue. --- .github/workflows/windows-web-ci-workflow.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index 57f687d8502ff..b45663a6145e3 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -205,6 +205,14 @@ jobs: log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log is_chromium_log: true + # this step is added to help investigate the shader validation failure which is hard to reproduce + - name: Upload WebGPU shader validation log on failure + if: ${{ failure() && inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} + uses: actions/upload-artifact@v4 + with: + name: webgpu-shader-validation-logs + path: ${{ runner.temp }}\web\test\07\chrome_debug.log + - name: E2E package consuming test if: ${{ inputs.build_config == 'Release' }} run: npm run test:e2e -- --browser=Chrome_default From 8aa0b2851729ea99a7664c88bb5b476dfd46b669 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 6 May 2025 12:59:27 -0700 Subject: [PATCH 21/84] [JSEP] Fix outputSize calculation causing duplicate indices. (#24650) ### Description Fix the outputSize computation causing duplicate indices. The outputSize should be the size of indices tensor without counting the last dimension. 
### Motivation and Context Fix the issue https://github.com/microsoft/onnxruntime/issues/24070 --- js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts | 100 ++++++------------ 1 file changed, 33 insertions(+), 67 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts index ec1d23e4887d5..286984c15feca 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -78,46 +78,15 @@ const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: } }; -const calcDataOffsetSnippet = (dataRank: number, parallel: boolean) => - `${ - dataRank === 1 - ? ` - let element_count_dim = uniforms.output_strides; - let dim_value = uniforms.output_shape;` - : ` - let element_count_dim = uniforms.output_strides[${parallel ? 'i - indices_start' : 'i'}]; - let dim_value = uniforms.output_shape[${parallel ? 'i - indices_start' : 'i'} + uniforms.last_index_dimension];` - } - - if (index >= 0) { - if (index >= i32(dim_value)) { - index = i32(dim_value - 1); - } - } else { - if (index < -i32(dim_value)) { - index = 0; - } else { - index += i32(dim_value); - } - } - data_offset += u32((u32(index) * element_count_dim));`; - -const updateElementsSnippet = (attributes: ScatterNDAttributes, outputTypeValue: ReductionType, parallel: boolean) => - `for (var i = 0u; i < uniforms.num_updates_elements; i++) { - let value = updates[uniforms.num_updates_elements * ${parallel ? 'global_idx' : 'idx'} + i]; - ${atomicReductionSnippet(attributes.reduction, 'output[data_offset + i]', 'value', outputTypeValue)} - }`; - const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { const inputShape = inputs[0].dims; const indicesShape = inputs[1].dims; const outputShape = inputShape; // TODO: support bool with components 4. const components = 1; - const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); + const outputSize = Math.ceil(ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1) / components); const lastIndexDimension = indicesShape[indicesShape.length - 1]; const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); - const numIndicesElements = ShapeUtil.sizeFromDimension(indicesShape, 0) / lastIndexDimension; const programUniforms: ProgramUniform[] = [ { type: DataType.uint32, data: outputSize }, @@ -142,48 +111,45 @@ const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: S .declareVariables(indices, updates, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} - var hasDuplicates = false; - if (${attributes.reduction === 'none'}) { - for (var i = 0; i < ${numIndicesElements}; i = i + 1) { - for (var j = i + 1; j < ${numIndicesElements}; j = j + 1) { - var index_i = i32(indices[i].x); - var index_j = i32(indices[j].x); - if (index_i == index_j) { - hasDuplicates = true; - break; - } + var data_offset = 0u; + let indices_start = uniforms.last_index_dimension * global_idx; + let indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${ + inputs[0].dims.length === 1 + ? 
` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[i - indices_start]; + let dim_value = uniforms.output_shape[i - indices_start];` + } + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); } - if (hasDuplicates) { - break; + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); } } + data_offset += u32((u32(index) * element_count_dim)); } - if (${attributes.reduction === 'none'} && hasDuplicates) { - if (global_idx != 0u) { - return; - } - // Process each index-update pair individually when duplicates exist - for (var idx = 0u; idx < ${numIndicesElements}u; idx++) { - var data_offset = 0u; - for (var i = 0u; i < uniforms.last_index_dimension; i++) { - var index = i32(indices[idx * uniforms.last_index_dimension + i].x); - ${calcDataOffsetSnippet(inputShape.length, false)} - } - ${updateElementsSnippet(attributes, output.type.value as ReductionType, false)} - } - return; + for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * global_idx + i]; + ${atomicReductionSnippet( + attributes.reduction, + 'output[data_offset + i]', + 'value', + output.type.value as ReductionType, + )} } - var data_offset = 0u; - var indices_start = uniforms.last_index_dimension * global_idx; - var indices_end = indices_start + uniforms.last_index_dimension; - for (var i = indices_start; i < indices_end; i++) { - var index = i32(indices[i].x); - ${calcDataOffsetSnippet(inputShape.length, true)} - } - ${updateElementsSnippet(attributes, output.type.value as ReductionType, true)} - }`; + }`; }; return { name: 'ScatterND', From 7bec521e9bbf5ad88a5c8c7eb7898360923f7023 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 May 2025 13:08:57 -0700 Subject: [PATCH 22/84] fix header include in webgpu_context.cc (#24648) ### Description header file "dawn/dawn_proc.h" is only used in a non-monolithic build of dawn. 
--- onnxruntime/core/providers/webgpu/webgpu_context.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 17f4fa1bd44b3..27380645baf57 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -10,7 +10,9 @@ #endif #if !defined(__wasm__) +#if !defined(BUILD_DAWN_MONOLITHIC_LIBRARY) #include "dawn/dawn_proc.h" +#endif #if !defined(USE_EXTERNAL_DAWN) #include "dawn/native/DawnNative.h" #endif From cdff2c1179e3d29c25ad8ae790a089ee65945ed0 Mon Sep 17 00:00:00 2001 From: xhcao Date: Wed, 7 May 2025 04:31:07 +0800 Subject: [PATCH 23/84] [webgpu]: optimize pool operators (#24598) The patch optimizes pool operators when output size is small and kernel size is big ### Description ### Motivation and Context --- onnxruntime/core/providers/webgpu/nn/pool.cc | 74 +++++++++++++++++--- onnxruntime/core/providers/webgpu/nn/pool.h | 6 +- 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/nn/pool.cc b/onnxruntime/core/providers/webgpu/nn/pool.cc index 12c135dbbf46d..d650392b71fb5 100644 --- a/onnxruntime/core/providers/webgpu/nn/pool.cc +++ b/onnxruntime/core/providers/webgpu/nn/pool.cc @@ -95,6 +95,9 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { var_decl_code = SS_GET(var_decl_ss); sampling_code = " value = max(value, x_val);\n"; + if (are_small_output_big_kernel_) { + downsampling_code = " sum_or_max_shared[local_idx] = value;\n"; + } } else { SS(var_decl_ss, kStringInitialSize); var_decl_ss << " var value = " << (is_float16_ ? "f16(0)" : "f32(0)") << ";\n"; @@ -113,7 +116,12 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { sampling_code = SS_GET(sampling_ss); SS(downsampling_ss, kStringInitialSize); - downsampling_ss << " value /= " << (is_float16_ ? "f16" : "f32") << "(count);\n"; + if (are_small_output_big_kernel_) { + downsampling_ss << " sum_or_max_shared[local_idx] = value;\n" + << " count_shared[local_idx] = count;\n"; + } else { + downsampling_ss << " value /= " << (is_float16_ ? "f16" : "f32") << "(count);\n"; + } downsampling_code = SS_GET(downsampling_ss); } @@ -125,13 +133,54 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { auto data_dim_end = input.Rank(); data_dim_end = is_nhwc_ ? data_dim_end - 1 : data_dim_end; + std::string sum_or_max_shared; + if (are_small_output_big_kernel_) { + shader.AdditionalImplementation() + << "var sum_or_max_shared : array<" << (is_float16_ ? "f16" : "f32") << ",workgroup_size_x >;\n" + << (!is_max_pool_ ? 
"var count_shared : array;\n" : ""); + + SS(shared_ss, 512); + std::string sum_or_max_shared_op; + std::string count_shared_op; + if (is_max_pool_) { + sum_or_max_shared_op = "sum_or_max_shared[local_idx] = max(sum_or_max_shared[local_idx], sum_or_max_shared[local_idx + reduce_size]);\n"; + } else { + sum_or_max_shared_op = "sum_or_max_shared[local_idx] += sum_or_max_shared[local_idx + reduce_size];\n"; + count_shared_op = "count_shared[local_idx] += count_shared[local_idx + reduce_size];\n"; + } + + shared_ss << " workgroupBarrier();\n" + << " var reduce_size : u32 = workgroup_size_x;\n" + << " for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (local_idx < curr_size) {\n" + << " " << sum_or_max_shared_op + << " " << count_shared_op + << " }\n" + << " workgroupBarrier();\n" + << " }\n"; + sum_or_max_shared = SS_GET(shared_ss); + } + std::string kernel_loop_decl_code = are_small_output_big_kernel_ ? " for (var i: u32 = local_idx; i < uniforms.kernel_size; i += workgroup_size_x) {\n" : " for (var i: u32 = 0; i < uniforms.kernel_size; i++) {\n"; + + SS(output_ss, kStringInitialSize); + if (are_small_output_big_kernel_) { + output_ss << " if (local_idx == 0) {\n" + << " value = sum_or_max_shared[0]" << (!is_max_pool_ ? (is_float16_ ? " / f16(count_shared[0])" : " / f32(count_shared[0])") : "") << ";\n" + << " " << output.SetByOffset("workgroup_idx", "value") << ";\n" + << " }\n"; + } else { + output_ss << " " << output.SetByOffset("global_idx", "value") << ";\n"; + } + std::string output_code = SS_GET(output_ss); + auto& body = shader.MainFunctionBody(); - body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << " let y_indices = " << output.OffsetToIndices("global_idx") << ";\n" + body << (are_small_output_big_kernel_ ? "" : shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")) + << " let y_indices = " << output.OffsetToIndices((are_small_output_big_kernel_ ? "workgroup_idx" : "global_idx")) << ";\n" << " var x_indices = y_indices;\n" << " var k_indices: array;\n" << var_decl_code - << " for (var i: u32 = 0; i < uniforms.kernel_size; i++) {\n" + << kernel_loop_decl_code << " var offset = i;\n" // ---- Compute offset to indices in pooling window. 
<< " for (var j = 0; j < " << kernel_rank << "; j++) {\n" @@ -162,7 +211,8 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << " }\n" << downsampling_code - << " " << output.SetByOffset("global_idx", "value") << ";\n"; + << sum_or_max_shared + << output_code; return Status::OK(); } @@ -225,7 +275,6 @@ Status Pool::ComputeInternal(ComputeContext& context) const { } bool is_float16 = X->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; bool count_include_pad = pool_attrs_.count_include_pad; - PoolProgram program{is_max_pool, is_nhwc, kernel_shape, is_float16, count_include_pad}; // Number of elements uint32_t output_size = gsl::narrow_cast(Y->Shape().Size()); @@ -235,16 +284,25 @@ Status Pool::ComputeInternal(ComputeContext& context) const { const auto strides_u32 = NarrowToU32(strides); const auto dilations_u32 = NarrowToU32(dilations); - program.CacheHint(kernel_shape.size(), is_max_pool, is_nhwc, is_float16, count_include_pad) + bool are_small_output_big_kernel = output_size <= 128 && kernel_size >= 128; + PoolProgram program{is_max_pool, is_nhwc, kernel_shape, is_float16, count_include_pad, are_small_output_big_kernel}; + + program.CacheHint(kernel_shape.size(), is_max_pool, is_nhwc, is_float16, count_include_pad, are_small_output_big_kernel) .AddInputs({{X, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutputs({{Y}}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({output_size, kernel_size, gsl::span(kernel_strides.data(), kernel_strides.size()), gsl::span(pads_u32.data(), pads_u32.size()), gsl::span(strides_u32.data(), strides_u32.size()), gsl::span(dilations_u32.data(), dilations_u32.size())}); + if (are_small_output_big_kernel) { + program.SetWorkgroupSize(128) + .SetDispatchGroupSize(output_size); + } else { + program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + } + return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/nn/pool.h b/onnxruntime/core/providers/webgpu/nn/pool.h index c1716542e5549..57bdc64954acd 100644 --- a/onnxruntime/core/providers/webgpu/nn/pool.h +++ b/onnxruntime/core/providers/webgpu/nn/pool.h @@ -14,13 +14,14 @@ namespace webgpu { class PoolProgram final : public Program { public: PoolProgram(bool is_max_pool, bool is_nhwc, const TensorShapeVector& kernel_shape, bool is_float16, - bool count_include_pad) + bool count_include_pad, bool are_small_output_big_kernel) : Program{"Pool"}, is_max_pool_{is_max_pool}, is_nhwc_{is_nhwc}, kernel_shape_{kernel_shape}, is_float16_{is_float16}, - count_include_pad_{count_include_pad} {} + count_include_pad_{count_include_pad}, + are_small_output_big_kernel_{are_small_output_big_kernel} {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -39,6 +40,7 @@ class PoolProgram final : public Program { const TensorShapeVector kernel_shape_; const bool is_float16_; const bool count_include_pad_; + const bool are_small_output_big_kernel_; }; template From 3c6fa0efe4caf8e39b3754d75897dd98b73a63fb Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 6 May 2025 16:42:20 -0700 Subject: [PATCH 24/84] Add support for EP selection delegate callback to Python bindings (#24634) ### Description Follow up to https://github.com/microsoft/onnxruntime/pull/24614 Example Python program (adapted from unit tests) that specifies a custom EP selection function to select a OrtEpDevice(s) for compiling: ```python def test_compile_with_ep_selection_delegate(self): # ... 
# User's custom EP selection function. def my_delegate( ep_devices: Sequence[onnxrt.OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], max_selections: int, ) -> Sequence[onnxrt.OrtEpDevice]: self.assertTrue(len(model_metadata) > 0) self.assertTrue(ep_devices and max_selections > 0) # Select the first and last devices (if there are more than one) selected_devices = [ep_devices[0]] if max_selections > 2 and len(ep_devices) > 1: selected_devices.append(ep_devices[-1]) # ORT CPU EP is always last return selected_devices session_options = onnxrt.SessionOptions() session_options.set_provider_selection_policy_delegate(my_delegate) model_compiler = onnxrt.ModelCompiler( session_options, input_model_path, embed_compiled_data_into_model=True, external_initializers_file_path=None, ) model_compiler.compile_to_file(output_model_path) ``` How to raise an exception from the Python EP selection function: ```python # User's custom EP selection function. custom_error_message = "MY ERROR" def my_delegate_that_fails( ep_devices: Sequence[onnxrt.OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], max_selections: int, ) -> Sequence[onnxrt.OrtEpDevice]: self.assertTrue(len(ep_devices) >= 1) raise ValueError(custom_error_message) sess_options = onnxrt.SessionOptions() sess_options.set_provider_selection_policy_delegate(my_delegate_that_fails) # Create session and expect ORT to raise a Fail exception that contains our message. with self.assertRaises(Fail) as context: onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) self.assertIn(custom_error_message, str(context.exception)) ``` ### Motivation and Context --- .../core/session/provider_policy_context.cc | 6 + .../python/onnxruntime_pybind_state.cc | 136 +++++++++++++----- .../python/onnxruntime_pybind_state_common.h | 17 ++- .../python/onnxruntime_test_python_autoep.py | 86 ++++++++++- .../onnxruntime_test_python_compile_api.py | 41 ++++++ 5 files changed, 241 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index f706bd05d8494..a4e0c16b411a1 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -169,6 +169,12 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); } + if (num_selected > selected_devices.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EP selection delegate selected too many EP devices (", num_selected, "). 
", + "The limit is ", selected_devices.size(), " EP devices."); + } + // Copy the selected devices to the output vector devices_selected.reserve(num_selected); for (size_t i = 0; i < num_selected; ++i) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index c29a8b0497b1f..d5f8c7c181960 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -23,6 +23,7 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" +#include "core/framework/error_code_helper.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -1798,38 +1799,69 @@ void addGlobalMethods(py::module& m) { #endif } -// TODO(adrianlizarraga): C API's delegate function needs a void* param to store state. -// using PyEpSelectionDelegate = std::function(const std::vector& ep_devices, -// const std::unordered_map& model_metadata, -// const std::unordered_map& runtime_metadata)>; -// -// static OrtStatus* PyDelegateWrapper(void* delegate_state, -// _In_ const OrtEpDevice** ep_devices, -// _In_ size_t num_devices, -// _In_ const OrtKeyValuePairs* model_metadata, -// _In_opt_ const OrtKeyValuePairs* runtime_metadata, -// _Inout_ const OrtEpDevice** selected, -// _In_ size_t max_selected, -// _Out_ size_t* num_selected) { -// PyEpSelectionDelegate* actual_delegate = reinterpret_cast(delegate_state); -// std::vector py_ep_devices(ep_devices, ep_devices + num_devices); -// std::unordered_map py_model_metadata = -// model_metadata ? model_metadata->entries : std::unordered_map{}; -// std::unordered_map py_runtime_metadata = -// runtime_metadata ? runtime_metadata->entries : std::unordered_map{}; -// -// std::vector py_selected = (*actual_delegate)(py_ep_devices, py_model_metadata, py_runtime_metadata); -// -// // TODO: Check max_selected and return OrtStatus if necessary. -// assert(py_selected.size() <= max_selected); -// -// *num_selected = py_selected.size(); -// for (size_t i = 0; i < py_selected.size(); ++i) { -// selected[i] = py_selected[i]; -// } -// -// return nullptr; -// }; +#if !defined(ORT_MINIMAL_BUILD) +/** + * Calls the user's Python EP selection function and coverts the results to a format that can be used + * by ORT to select OrtEpDevice instances. The user's function is set by calling + * SessionOptions.set_provider_selection_policy_delegate() on the Python side. The result of this wrapper + * function is used in core/session/provider_policy_context.cc. + * + * @param ep_devices OrtEpDevices to select from. + * @param num_devices Number of OrtEpDevices to select from. + * @param model_metadata Model's metadata. + * @param runtime_metadata Runtime metadata. + * @param selected Pre-allocated OrtEpDevice buffer to update with selected devices. + * @param max_selected Maximum number of entries in the pre-allocated 'selected' buffer. + * @param state Opaque state that holds a pointer to the user's Python function. + * + * @return nullptr OrtStatus* to indicate success. 
+ */ +static OrtStatus* ORT_API_CALL PyEpSelectionPolicyWrapper(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state) { + PyEpSelectionDelegate* actual_delegate = reinterpret_cast(state); + std::vector py_ep_devices(ep_devices, ep_devices + num_devices); + std::unordered_map py_model_metadata = + model_metadata ? model_metadata->entries : std::unordered_map{}; + std::unordered_map py_runtime_metadata = + runtime_metadata ? runtime_metadata->entries : std::unordered_map{}; + + *num_selected = 0; + std::vector py_selected; + OrtStatus* status = nullptr; + + // Call the Python delegate function and convert any exceptions to a status. + ORT_TRY { + py_selected = (*actual_delegate)(py_ep_devices, py_model_metadata, py_runtime_metadata, max_selected); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (status != nullptr) { + return status; + } + + if (py_selected.size() > max_selected) { + return ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "selected too many EP devices (", py_selected.size(), "). ", + "The limit is ", max_selected, " EP devices.")); + } + + *num_selected = py_selected.size(); + for (size_t i = 0; i < py_selected.size(); ++i) { + selected[i] = py_selected[i]; + } + + return nullptr; +} +#endif // !defined(ORT_MINIMAL_BUILD) void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registration_fn) { py::enum_(m, "GraphOptimizationLevel") @@ -2035,21 +2067,47 @@ must refer to the same execution provider.)pbdoc") .def( // Equivalent to the C API's SessionOptionsSetEpSelectionPolicy. "set_provider_selection_policy", - [](PySessionOptions* sess_options, + [](PySessionOptions* py_sess_options, OrtExecutionProviderDevicePolicy policy) { #if !defined(ORT_MINIMAL_BUILD) - sess_options->value.ep_selection_policy.enable = true; - sess_options->value.ep_selection_policy.policy = policy; - sess_options->value.ep_selection_policy.delegate = nullptr; // TODO: need a void* param in delegate. + py_sess_options->py_ep_selection_delegate = nullptr; + + py_sess_options->value.ep_selection_policy.enable = true; + py_sess_options->value.ep_selection_policy.policy = policy; + py_sess_options->value.ep_selection_policy.delegate = nullptr; + py_sess_options->value.ep_selection_policy.state = nullptr; #else - ORT_UNUSED_PARAMETER(sess_options); + ORT_UNUSED_PARAMETER(py_sess_options); ORT_UNUSED_PARAMETER(policy); ORT_THROW("EP selection policies are not supported in this build"); #endif }, R"pbdoc(Sets the execution provider selection policy for the session. Allows users to specify a -selection policy for automatic execution provider (EP) selection, or provide a delegate callback -for custom selection logic.)pbdoc") +selection policy for automatic execution provider (EP) selection.)pbdoc") + .def( + // Equivalent to the C API's SessionOptionsSetEpSelectionPolicyDelegate. 
+ "set_provider_selection_policy_delegate", + [](PySessionOptions* py_sess_options, + PyEpSelectionDelegate delegate_fn) { +#if !defined(ORT_MINIMAL_BUILD) + py_sess_options->py_ep_selection_delegate = delegate_fn; // Store python callback in PySessionOptions + + py_sess_options->value.ep_selection_policy.enable = true; + py_sess_options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT; + py_sess_options->value.ep_selection_policy.delegate = PyEpSelectionPolicyWrapper; + py_sess_options->value.ep_selection_policy.state = + reinterpret_cast(&py_sess_options->py_ep_selection_delegate); +#else + ORT_UNUSED_PARAMETER(py_sess_options); + ORT_UNUSED_PARAMETER(delegate_fn); + ORT_THROW("EP selection policies are not supported in this build"); +#endif + }, + R"pbdoc(Sets the execution provider selection policy delegate for the session. Allows users to specify a +custom selection policy function for automatic execution provider (EP) selection. The delegate must return a list of +selected OrtEpDevice instances. The signature of the delegate is +def custom_delegate(ep_devices: Sequence[OrtEpDevice], model_metadata: dict[str, str], runtime_metadata: dict[str, str], +max_selections: int) -> Sequence[OrtEpDevice])pbdoc") .def( "has_providers", [](PySessionOptions* sess_options) -> bool { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index a964f199d43b3..4114bd4078799 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -4,6 +4,8 @@ #pragma once +#include + #include "core/common/logging/logging.h" #include "core/common/logging/sinks/cerr_sink.h" #include "core/common/optional.h" @@ -229,7 +231,20 @@ extern OrtDevice::DeviceId cuda_device_id; // TODO remove deprecated global config extern size_t gpu_mem_limit; -using PySessionOptions = OrtSessionOptions; +#if !defined(ORT_MINIMAL_BUILD) +using PyEpSelectionDelegate = std::function(const std::vector& ep_devices, + const std::unordered_map& model_metadata, + const std::unordered_map& runtime_metadata, + size_t max_selections)>; +#endif + +// Thin wrapper over internal C OrtSessionOptions to store additional state. +struct PySessionOptions : public OrtSessionOptions { +#if !defined(ORT_MINIMAL_BUILD) + // Callback function from Python application that allows the user to specify custom EP selection logic. + PyEpSelectionDelegate py_ep_selection_delegate; +#endif // !defined(ORT_MINIMAL_BUILD) +}; // Thin wrapper over internal C++ InferenceSession to accommodate custom op library management for the Python user struct PyInferenceSession { diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index 1d7dba5662257..61dc0ff221318 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -6,13 +6,14 @@ import platform import sys import unittest +from collections.abc import Sequence import numpy as np from autoep_helper import AutoEpTestCase from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, InvalidArgument # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. 
if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -68,8 +69,8 @@ def test_cuda_ep_register_and_inference(self): output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - # TODO(adrianlizarraga): Unregistering CUDA EP library causes issues. Investigate. - # self.unregister_execution_provider_library(ep_registration_name) + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_registration_name) def test_cuda_prefer_gpu_and_inference(self): """ @@ -100,8 +101,83 @@ def test_cuda_prefer_gpu_and_inference(self): output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - # TODO(adrianlizarraga): Unregistering CUDA EP library causes issues. Investigate. - # self.unregister_execution_provider_library(ep_registration_name) + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_registration_name) + + def test_cuda_ep_selection_delegate_and_inference(self): + """ + Test selecting CUDA EP via the custom EP selection delegate function and then run inference. + """ + ep_lib_path = "onnxruntime_providers_cuda.dll" + ep_registration_name = "CUDAExecutionProvider" + + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + if ep_registration_name not in available_providers: + self.skipTest("Skipping test because it needs to run on CUDA EP") + + self.register_execution_provider_library(ep_registration_name, ep_lib_path) + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 2) + self.assertGreaterEqual(max_selections, 2) + + cuda_ep_device = next((d for d in ep_devices if d.ep_name == ep_registration_name), None) + self.assertIsNotNone(cuda_ep_device) + + # Select the CUDA device and the ORT CPU EP device (should always be last) + return [cuda_ep_device, ep_devices[-1]] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + del sess # Delete session before unregistering library + self.unregister_execution_provider_library(ep_registration_name) + + def test_custom_ep_selection_delegate_that_raises_error(self): + """ + Test a custom EP selection delegate function that raises a Python exception. ORT should re-raise as FAIL. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + # User's custom EP selection function. 
+ custom_error_message = "MY ERROR" + + def my_delegate_that_fails( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreaterEqual(len(ep_devices), 1) + raise ValueError(custom_error_message) + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate_that_fails) + + # Create session and expect ORT to raise a Fail exception that contains our message. + with self.assertRaises(Fail) as context: + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + self.assertIn(custom_error_message, str(context.exception)) def test_example_plugin_ep_devices(self): """ diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index f5f23e2da1e43..866388f3aa226 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -6,6 +6,7 @@ import platform import sys import unittest +from collections.abc import Sequence import onnx from autoep_helper import AutoEpTestCase @@ -52,6 +53,46 @@ def test_compile_with_files_prefer_npu_policy(self): self.assertTrue(os.path.exists(output_model_path)) self.unregister_execution_provider_library(ep_registration_name) + def test_compile_with_ep_selection_delegate(self): + """ + Tests compiling a model (to/from files) using an EP selection delegate callback. + """ + if sys.platform != "win32": + self.skipTest("Skipping test because provider selection policies are only supported on Windows") + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.delegate.onnx") + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(ep_devices), 0) + self.assertGreater(len(model_metadata), 0) + self.assertGreater(max_selections, 0) + + # Select the first and last devices (if there are more than one) + selected_devices = [ep_devices[0]] + if max_selections > 2 and len(ep_devices) > 1: + selected_devices.append(ep_devices[-1]) # ORT CPU EP is always last + + return selected_devices + + session_options = onnxrt.SessionOptions() + session_options.set_provider_selection_policy_delegate(my_delegate) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_file(output_model_path) + self.assertTrue(os.path.exists(output_model_path)) + def test_compile_with_input_and_output_files(self): """ Tests compiling a model (to/from files) using explicit EP. From c1ef02f74b0d648cc7e8558805fa90846ea11a35 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 6 May 2025 17:52:55 -0700 Subject: [PATCH 25/84] [Gen C\C++ API docs] Fix documentation warnings for new autoEP/compile APIs (#24661) ### Description Fixes documentation errors in comments within onnxruntime_c_api.h and onnxruntime__cxx_api.h. 
### Motivation and Context The [Generate C/C++ API docs](https://github.com/microsoft/onnxruntime/actions/runs/14855108283/job/41706460753#logs) action is failing with error: ```shell Run mkdir -p build/doxygen /mnt/vss/_work/onnxruntime/onnxruntime/include/onnxruntime/core/session/onnxruntime_cxx_api.h:775: error: explicit link request to 'OrtKeyValuePair' could not be resolved (warning treated as error, aborting now) ``` --- include/onnxruntime/core/session/onnxruntime_c_api.h | 12 ++++++------ .../onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6c7d910b4963b..a4a32fbea630a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -435,9 +435,9 @@ typedef enum OrtExecutionProviderDevicePolicy { * \param model_metadata The model metadata. * \param runtime_metadata The runtime metadata. May be nullptr. * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. - * \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array. - Currently the maximum is 8. - * \param num_ep_devices The number of selected devices. + * \param max_selected The maximum number of devices that can be selected in the pre-allocated array. + Currently the maximum is 8. + * \param num_selected The number of selected devices. * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. * * \return OrtStatus* Selection status. Return nullptr on success. @@ -6116,7 +6116,7 @@ struct OrtEpFactory { * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the * session. This will include ep_options from GetSupportedDevices as well as any * user provided overrides. - * Execution provider options will have been added with a prefix of 'ep..'. + * Execution provider options will have been added with a prefix of 'ep.[ep name].'. * The OrtSessionOptions instance will NOT be valid after this call and should not be * stored for later use. * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. @@ -6124,7 +6124,7 @@ struct OrtEpFactory { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version . This is a placeholder. + * \since Version [coming soon]. This is a placeholder. */ OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, _In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -6138,7 +6138,7 @@ struct OrtEpFactory { * \param[in] this_ptr The OrtEpFactory instance. * \param[in] ep The OrtEp instance to release. * - * \since Version . This is a placeholder. + * \since Version [coming soon]. This is a placeholder. 
*/ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bc6f381bb82a0..39c20e237b02c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -772,7 +772,7 @@ struct KeyValuePairsImpl : Ort::detail::Base { // Const object holder that does not own the underlying object using ConstKeyValuePairs = detail::KeyValuePairsImpl>; -/** \brief Wrapper around ::OrtKeyValuePair */ +/** \brief Wrapper around ::OrtKeyValuePairs */ struct KeyValuePairs : detail::KeyValuePairsImpl { explicit KeyValuePairs(std::nullptr_t) {} ///< No instance is created /// Take ownership of a pointer created by C API From 6fef0693192d879e6adae4412d9164b823441309 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Tue, 6 May 2025 18:51:12 -0700 Subject: [PATCH 26/84] [Native WebGPU] Added ScatterND (#24623) ### Description Added ScatterND operator to Native WebGPU EP. ### Motivation and Context Required to increase coverage. --- .../core/providers/webgpu/shader_helper.cc | 19 +- .../providers/webgpu/tensor/scatter_nd.cc | 222 ++++++++++++++++++ .../core/providers/webgpu/tensor/scatter_nd.h | 60 +++++ .../webgpu/webgpu_execution_provider.cc | 10 + 4 files changed, 305 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/scatter_nd.h diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 59855e6117641..36f6b512a0a93 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -125,7 +125,8 @@ namespace { // Validate if the tensor element type matches the program variable data type Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type, bool is_atomic = false) { if (is_atomic) { - ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32, + // float32 is not a valid data type for atomic. However the data may be bitcast-ed to i32 and used to simulate atomic operation using atomicCompareExchangeWeak. 
+ ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || var_type == ProgramVariableDataType::Uint32 || var_type == ProgramVariableDataType::Float32, "Unexpected program variable type ", int(var_type), " for atomic variable"); } @@ -422,11 +423,17 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha bool is_atomic = program_.Outputs()[i].is_atomic; ss << "@group(0) @binding(" << input_vars_.size() + i << ") var " << output->name_ << ": array<"; if (is_atomic) { - ss << "atomic<"; - } - ss << output->StorageType(); - if (is_atomic) { - ss << ">"; + if (output->type_ == ProgramVariableDataType::Float32) { + ss << "atomic"; + } else if (output->type_ == ProgramVariableDataType::Uint32) { + ss << "atomic"; + } else if (output->type_ == ProgramVariableDataType::Int32) { + ss << "atomic"; + } else { + ORT_RETURN_IF(true, "Unsupported atomic type: ", int(output->type_)); + } + } else { + ss << output->StorageType(); } ss << ">;\n"; } diff --git a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc new file mode 100644 index 0000000000000..986255ea1f185 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.cc @@ -0,0 +1,222 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "scatter_nd.h" + +namespace onnxruntime { +namespace webgpu { + +Status ScatterNDProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + shader.AddInput("updates", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseShapeAndStride); + const auto output_rank = static_cast(output.Rank()); + auto atomic_reduction_snippet = [](ScatterNDReduction reduction, const std::string& ptr, const std::string& value, const std::string& data_type) -> std ::string { + std::ostringstream ss; + bool is_32_bit_integer = data_type == "i32" || data_type == "u32"; + bool is_unsigned_integer = data_type == "u32"; + std::ostringstream ss_float_start; + ss_float_start << " {\n" + << " var oldValue = 0" << (is_unsigned_integer ? "u" : "") << ";\n" + << " loop {\n" + << " let newValueF32 = "; + std::ostringstream ss_float_end; + ss_float_end << ";\n" + << " let newValue = bitcast<" << (is_unsigned_integer ? "u32" : "i32") << ">(newValueF32);\n" + << " let res = atomicCompareExchangeWeak(&" << ptr << ", oldValue, newValue);\n" + << " if res.exchanged {\n" + << " break;\n" + << " }\n" + << " oldValue = res.old_value;\n" + << " }\n" + << " }\n"; + switch (reduction) { + case ScatterNDReduction::None: + ss << " " << ptr << " = " << value << ";\n"; + break; + case ScatterNDReduction::Add: + if (is_32_bit_integer) { + ss << " atomicAdd(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicAdd only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. 
+ ss << ss_float_start.str() << "bitcast<" << data_type << ">(oldValue) + (" << value << ")" << ss_float_end.str() + << "\n"; + } + break; + case ScatterNDReduction::Max: + if (is_32_bit_integer) { + ss << " atomicMax(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicMax only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "max(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str(); + } + break; + case ScatterNDReduction::Min: + if (is_32_bit_integer) { + ss << " atomicMin(&" << ptr << ", bitcast<" << data_type << ">(" << value << "));\n"; + } else { + // atomicMin only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "min(bitcast<" << data_type << ">(oldValue), (" << value << "))" << ss_float_end.str(); + } + break; + case ScatterNDReduction::Mul: + // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate. + ss << ss_float_start.str() << "(bitcast<" << data_type << ">(oldValue) * (" << value << "))" << ss_float_end.str(); + break; + default: + ORT_THROW("Unsupported reduction type: ", static_cast(reduction)); + // The controlflow should never reach here. + } + return ss.str(); + }; + + auto calc_data_offset_snippet = [](size_t output_rank) -> std::string { + std::ostringstream ss; + if (output_rank < 2) { + ss << " let element_count_dim = 1u;\n"; + } else { + ss << " let element_count_dim = select(" << GetElementAt("uniforms.output_stride", "i - indices_start", output_rank - 1) << ", 1u, i - indices_start == " << (output_rank - 1) << ");\n"; + } + ss << " let dim_value = " << GetElementAt("uniforms.output_shape", "i - indices_start", output_rank) << ";\n" + << " if (index >= 0) {\n" + << " if (index >= i32(dim_value)) {\n" + << " index = i32(dim_value - 1);\n" + << " }\n" + << " } else {\n" + << " if (index < -i32(dim_value)) {\n" + << " index = 0;\n" + << " } else {\n" + << " index += i32(dim_value);\n" + << " }\n" + << " }\n" + << " data_offset += u32((u32(index) * element_count_dim));\n"; + return ss.str(); + }; + + auto update_elements_snippet = [atomic_reduction_snippet](ScatterNDReduction reduction, const std::string& data_type) -> std::string { + std::ostringstream ss; + ss << " for (var i = 0u; i < uniforms.num_updates_elements; i++) {\n" + << " let value = updates[uniforms.num_updates_elements * global_idx + i];\n" + << atomic_reduction_snippet(reduction, "output[data_offset + i]", "value", data_type) << "\n" + << " }\n"; + return ss.str(); + }; + std::string data_type_str; + bool reducible = false; + if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "i32"; + } else if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "u32"; + } else if (data_type_ == DataTypeImpl::GetType()) { + reducible = true; + data_type_str = "f32"; + } else { + // Default value. 
+ data_type_str = "output_element_t"; + } + if (reduction_ != ScatterNDReduction::None && !reducible) { + ORT_THROW("ScatterND: Reduction is not supported for data type ", data_type_str); + } + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var data_offset = 0u;\n" + << " var indices_start = uniforms.last_index_dimension * global_idx;\n" + << " var indices_end = indices_start + uniforms.last_index_dimension;\n" + << " for (var i = indices_start; i < indices_end; i++) {\n" + << " var index = i32(indices[i].x);\n" + << calc_data_offset_snippet(output_rank) + << " }\n" + << update_elements_snippet(reduction_, data_type_str); + return Status::OK(); +} + +Status ScatterND::ComputeInternal(ComputeContext& context) const { + const Tensor* input = context.Input(0); + const auto* indices = context.Input(1); + const auto* updates = context.Input(2); + const auto& input_shape = input->Shape(); + const auto& indices_shape = indices->Shape(); + auto indices_rank = indices_shape.NumDimensions(); + auto last_index_dimension = static_cast(indices_shape[indices_rank - 1]); + auto num_updates_elements = static_cast(input_shape.SizeFromDimension(last_index_dimension)); + // TODO: support bool with components 4. + const size_t components = 1; + auto output_size = static_cast((indices_shape.SizeToDimension(indices_rank - 1) + components - 1) / components); + auto* output = context.Output(0, input_shape); + MLDataType data_type = input->DataType(); + const void* source = input->DataRaw(); + void* target = output->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input, *output)); + } + ScatterNDProgram program(reduction_, data_type); + program + .CacheHint(static_cast(reduction_)) + .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}, + {updates, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({output_size, last_index_dimension, num_updates_elements}); + if (reduction_ != ScatterNDReduction::None && (data_type == DataTypeImpl::GetType() || data_type == DataTypeImpl::GetType() || + data_type == DataTypeImpl::GetType())) { + program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, ProgramOutput::Atomic}); + } else { + program.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank}); + } + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + ScatterND, + kOnnxDomain, + 18, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 16, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 13, + 15, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 11, + 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .MayInplace(0, 0), + ScatterND); +} // namespace webgpu +} // namespace onnxruntime diff --git 
a/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h new file mode 100644 index 0000000000000..40bcbadebf65d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/scatter_nd.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace webgpu { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +class ScatterNDProgram final : public Program { + public: + ScatterNDProgram(ScatterNDReduction reduction, MLDataType data_type) : Program{"ScatterND"}, reduction_(reduction), data_type_(data_type) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"last_index_dimension", ProgramUniformVariableDataType::Uint32}, + {"num_updates_elements", ProgramUniformVariableDataType::Uint32}); + ScatterNDReduction reduction_; + MLDataType data_type_; +}; + +class ScatterND : public WebGpuKernel { + public: + ScatterND(const OpKernelInfo& info) : WebGpuKernel(info) { + std::string reduction = info.GetAttrOrDefault("reduction", "none"); + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + reduction_ = ScatterNDReduction::None; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on webgpu when opset <= 18."); + } + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 928e48d78d7e5..9ea79e4cf28a3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -400,6 +400,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, DequantizeLinear); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterND); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -732,6 +737,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { From 
8e7c0ac3acbcfb98ee0371579b642434c5ecc403 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 6 May 2025 20:28:36 -0700 Subject: [PATCH 27/84] Use build id when publishing symbols (#24662) --- .../c-api-artifacts-package-and-publish-steps-windows.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index 8c3a9eba82356..1f71e8c0e0125 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -43,7 +43,7 @@ steps: SymbolExpirationInDays: '36530' IndexableFileFormats: 'Default' DetailedLog: true - SymbolsArtifactName: 'Symbols_${{parameters.buildConfig}}' + SymbolsArtifactName: 'Symbols_${{parameters.artifactNameNoVersionString}}_$(Build.BuildId)' - task: CmdLine@2 displayName: 'Copy build artifacts for zipping' From f14fd59fe9bfe419376357128163e687bf88b6d6 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 7 May 2025 13:33:48 +1000 Subject: [PATCH 28/84] More error checking for user provided C# policy selection delegate (#24666) ### Description Handle user selection policy delegate throwing or returning too many selections in C# code and create error message. ### Motivation and Context --- .../NativeMethods.shared.cs | 8 ++ .../SessionOptions.shared.cs | 82 +++++++++++------- .../OrtAutoEpTests.cs | 84 +++++++++++++++++++ 3 files changed, 144 insertions(+), 30 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index c543414ca13a9..664a77ceab037 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -407,12 +407,15 @@ static NativeMethods() api_ = (OrtApi)OrtGetApi(ORT_API_VERSION); OrtGetVersionString = (DOrtGetVersionString)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetVersionString, typeof(DOrtGetVersionString)); #endif + OrtCreateStatus = (DOrtCreateStatus)Marshal.GetDelegateForFunctionPointer( + api_.CreateStatus, typeof(DOrtCreateStatus)); OrtCreateEnv = (DOrtCreateEnv)Marshal.GetDelegateForFunctionPointer(api_.CreateEnv, typeof(DOrtCreateEnv)); OrtCreateEnvWithCustomLogger = (DOrtCreateEnvWithCustomLogger)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLogger, typeof(DOrtCreateEnvWithCustomLogger)); OrtCreateEnvWithGlobalThreadPools = (DOrtCreateEnvWithGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithGlobalThreadPools, typeof(DOrtCreateEnvWithGlobalThreadPools)); OrtCreateEnvWithCustomLoggerAndGlobalThreadPools = (DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLoggerAndGlobalThreadPools, typeof(DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)); OrtReleaseEnv = (DOrtReleaseEnv)Marshal.GetDelegateForFunctionPointer(api_.ReleaseEnv, typeof(DOrtReleaseEnv)); + OrtEnableTelemetryEvents = (DOrtEnableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.EnableTelemetryEvents, typeof(DOrtEnableTelemetryEvents)); OrtDisableTelemetryEvents = (DOrtDisableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.DisableTelemetryEvents, typeof(DOrtDisableTelemetryEvents)); @@ -933,6 +936,11 @@ internal class 
NativeLib #endregion Status API #region InferenceSession API + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateStatus( + uint /* OrtErrorCode */ code, + byte[] /* const char* */ msg); + public static DOrtCreateStatus OrtCreateStatus; [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateSession( diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index d60bf75ccbd7c..9794d2c184d5d 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -989,41 +989,63 @@ public static IntPtr EpSelectionPolicyWrapper(IntPtr /* OrtEpDevice** */ epDevic out UIntPtr numSelected, IntPtr state) { - Span epDevicesIntPtrs; - Span selectedDevicesIntPtrs; - EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target; + numSelected = UIntPtr.Zero; - unsafe + try { - void* ptr = epDevicesIn.ToPointer(); - epDevicesIntPtrs = new Span(ptr, checked((int)numDevices)); - } - - List epDevices = new List(); - for (int i = 0; i < numDevices; i++) - { - - epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i])); - } - OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn); - OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn); - - var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected); - - numSelected = (UIntPtr)selected.Count; - - unsafe - { - void* ptr = selectedOut.ToPointer(); - selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected); + Span epDevicesIntPtrs; + Span selectedDevicesIntPtrs; + EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target; + + unsafe + { + void* ptr = epDevicesIn.ToPointer(); + epDevicesIntPtrs = new Span(ptr, checked((int)numDevices)); + } + + List epDevices = new List(); + for (int i = 0; i < numDevices; i++) + { + + epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i])); + } + + OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn); + OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn); + + var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected); + + if (selected.Count > maxSelected) + { + var error = $"The number of selected devices ({selected.Count}) returned by " + + $"the C# selection delegate exceeds the maximum ({maxSelected})."; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; + } + + numSelected = (UIntPtr)selected.Count; + + unsafe + { + void* ptr = selectedOut.ToPointer(); + selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected); + } + + int idx = 0; + foreach (var epDevice in selected) + { + selectedDevicesIntPtrs[idx] = epDevice.Handle; + idx++; + } } - - int idx = 0; - foreach (var epDevice in selected) + catch (Exception ex) { - selectedDevicesIntPtrs[idx] = epDevice.Handle; - idx++; + var error = $"The C# selection delegate threw an exception: {ex.Message}"; + IntPtr status = NativeMethods.OrtCreateStatus((uint)ErrorCode.Fail, + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(error)); + return status; } return IntPtr.Zero; diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs 
index d95a649bd95c5..9368f9d8bc298 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -198,5 +198,89 @@ public void SetEpSelectionPolicyDelegate() Assert.NotNull(session); } } + + // select max + 1, starting with all devices + private static List SelectionPolicyDelegateTooMany(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + Assert.NotEmpty(modelMetadata.Entries); + Assert.True(epDevices.Count > 0); + var selected = new List(epDevices); + + while (selected.Count < (maxSelections + 1)) + { + selected.Add(epDevices.Last()); + } + + return selected; + } + + [Fact] + public void SetEpSelectionPolicyDelegateTooMany() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // select too many devices + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegateTooMany); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should fail + try + { + using var session = new InferenceSession(model, sessionOptions); + Assert.Fail("Should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + // Current C++ max is 8. We copy all devices and keep adding until we exceed that. + const int max = 8; + var numSelected = epDevices.Count > max ? epDevices.Count : (max + 1); + var expected = "[ErrorCode:Fail] EP selection delegate failed: The number of selected devices " + + $"({numSelected}) returned by the C# selection delegate exceeds the maximum ({max})"; + Assert.Contains(expected, ex.Message); + } + } + + // throw exception in user provided delegate + private static List SelectionPolicyDelegateThrows(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + throw new ArgumentException("Test exception"); + } + + [Fact] + public void SetEpSelectionPolicyDelegateThrows() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegateThrows); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + try + { + using var session = new InferenceSession(model, sessionOptions); + Assert.Fail("Should have thrown an exception"); + } + catch (OnnxRuntimeException ex) + { + var expected = "[ErrorCode:Fail] EP selection delegate failed: " + + "The C# selection delegate threw an exception: Test exception"; + Assert.Contains(expected, ex.Message); + } + } } #endif \ No newline at end of file From ef3caaf34b4ff8677ec3cb010faa61b22abe3b54 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 6 May 2025 21:24:31 -0700 Subject: [PATCH 29/84] [Clean up] Fix ep_name vs ep_registration_name usage in autoEP Python unit tests (#24667) ### Description Cleans up the usage of `ep_name` and `ep_registration_name` in the autoEP Python unit tests. ### Motivation and Context Addresses comments from a previous PR: https://github.com/microsoft/onnxruntime/pull/24634 > nit: the registration name and EP names don't need to match. 
could we call this 'ep_name' to avoid potentially creating an assumption that they always do? --- .../python/onnxruntime_pybind_state.cc | 2 +- .../python/onnxruntime_test_python_autoep.py | 43 +++++++++---------- .../onnxruntime_test_python_compile_api.py | 6 +-- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d5f8c7c181960..aa2c0cc6a0f86 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1801,7 +1801,7 @@ void addGlobalMethods(py::module& m) { #if !defined(ORT_MINIMAL_BUILD) /** - * Calls the user's Python EP selection function and coverts the results to a format that can be used + * Calls the user's Python EP selection function and converts the results to a format that can be used * by ORT to select OrtEpDevice instances. The user's function is set by calling * SessionOptions.set_provider_selection_policy_delegate() on the Python side. The result of this wrapper * function is used in core/session/provider_policy_context.cc. diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index 61dc0ff221318..417a6e27fb7b2 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -28,24 +28,23 @@ def test_cuda_ep_register_and_inference(self): Test registration of CUDA EP, adding its OrtDevice to the SessionOptions, and running inference. """ ep_lib_path = "onnxruntime_providers_cuda.dll" - ep_registration_name = "CUDAExecutionProvider" + ep_name = "CUDAExecutionProvider" if sys.platform != "win32": self.skipTest("Skipping test because device discovery is only supported on Windows") - if ep_registration_name not in available_providers: + if ep_name not in available_providers: self.skipTest("Skipping test because it needs to run on CUDA EP") - self.register_execution_provider_library(ep_registration_name, ep_lib_path) + self.register_execution_provider_library(ep_name, ep_lib_path) ep_devices = onnxrt.get_ep_devices() has_cpu_ep = False cuda_ep_device = None for ep_device in ep_devices: - ep_name = ep_device.ep_name - if ep_name == "CPUExecutionProvider": + if ep_device.ep_name == "CPUExecutionProvider": has_cpu_ep = True - if ep_name == ep_registration_name: + if ep_device.ep_name == ep_name: cuda_ep_device = ep_device self.assertTrue(has_cpu_ep) @@ -70,22 +69,22 @@ def test_cuda_ep_register_and_inference(self): np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) del sess # Delete session before unregistering library - self.unregister_execution_provider_library(ep_registration_name) + self.unregister_execution_provider_library(ep_name) def test_cuda_prefer_gpu_and_inference(self): """ Test selecting CUDA EP via the PREFER_GPU policy and running inference. """ ep_lib_path = "onnxruntime_providers_cuda.dll" - ep_registration_name = "CUDAExecutionProvider" + ep_name = "CUDAExecutionProvider" if sys.platform != "win32": self.skipTest("Skipping test because device discovery is only supported on Windows") - if ep_registration_name not in available_providers: + if ep_name not in available_providers: self.skipTest("Skipping test because it needs to run on CUDA EP") - self.register_execution_provider_library(ep_registration_name, ep_lib_path) + self.register_execution_provider_library(ep_name, ep_lib_path) # Set a policy to prefer GPU. 
Cuda should be selected. sess_options = onnxrt.SessionOptions() @@ -102,22 +101,22 @@ def test_cuda_prefer_gpu_and_inference(self): np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) del sess # Delete session before unregistering library - self.unregister_execution_provider_library(ep_registration_name) + self.unregister_execution_provider_library(ep_name) def test_cuda_ep_selection_delegate_and_inference(self): """ Test selecting CUDA EP via the custom EP selection delegate function and then run inference. """ ep_lib_path = "onnxruntime_providers_cuda.dll" - ep_registration_name = "CUDAExecutionProvider" + ep_name = "CUDAExecutionProvider" if sys.platform != "win32": self.skipTest("Skipping test because device discovery is only supported on Windows") - if ep_registration_name not in available_providers: + if ep_name not in available_providers: self.skipTest("Skipping test because it needs to run on CUDA EP") - self.register_execution_provider_library(ep_registration_name, ep_lib_path) + self.register_execution_provider_library(ep_name, ep_lib_path) # User's custom EP selection function. def my_delegate( @@ -130,7 +129,7 @@ def my_delegate( self.assertGreaterEqual(len(ep_devices), 2) self.assertGreaterEqual(max_selections, 2) - cuda_ep_device = next((d for d in ep_devices if d.ep_name == ep_registration_name), None) + cuda_ep_device = next((d for d in ep_devices if d.ep_name == ep_name), None) self.assertIsNotNone(cuda_ep_device) # Select the CUDA device and the ORT CPU EP device (should always be last) @@ -150,7 +149,7 @@ def my_delegate( np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) del sess # Delete session before unregistering library - self.unregister_execution_provider_library(ep_registration_name) + self.unregister_execution_provider_library(ep_name) def test_custom_ep_selection_delegate_that_raises_error(self): """ @@ -192,18 +191,16 @@ def test_example_plugin_ep_devices(self): except FileNotFoundError: self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") - ep_registration_name = "example_ep" - self.register_execution_provider_library(ep_registration_name, os.path.realpath(ep_lib_path)) + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) ep_devices = onnxrt.get_ep_devices() has_cpu_ep = False test_ep_device = None for ep_device in ep_devices: - ep_name = ep_device.ep_name - - if ep_name == "CPUExecutionProvider": + if ep_device.ep_name == "CPUExecutionProvider": has_cpu_ep = True - if ep_name == ep_registration_name: + if ep_device.ep_name == ep_name: test_ep_device = ep_device self.assertTrue(has_cpu_ep) @@ -236,7 +233,7 @@ def test_example_plugin_ep_devices(self): sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) self.assertIn("EP is not currently supported", str(context.exception)) - self.unregister_execution_provider_library(ep_registration_name) + self.unregister_execution_provider_library(ep_name) if __name__ == "__main__": diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 866388f3aa226..7a410d4bbeb6a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -34,8 +34,8 @@ def test_compile_with_files_prefer_npu_policy(self): self.skipTest("Skipping test because provider selection policies are only supported on Windows") 
ep_lib_path = "onnxruntime_providers_qnn.dll" - ep_registration_name = "QNNExecutionProvider" - self.register_execution_provider_library(ep_registration_name, ep_lib_path) + ep_name = "QNNExecutionProvider" + self.register_execution_provider_library(ep_name, ep_lib_path) input_model_path = get_name("nhwc_resize_scales_opset18.onnx") output_model_path = os.path.join(self._tmp_dir_path, "model.compiled0.onnx") @@ -51,7 +51,7 @@ def test_compile_with_files_prefer_npu_policy(self): ) model_compiler.compile_to_file(output_model_path) self.assertTrue(os.path.exists(output_model_path)) - self.unregister_execution_provider_library(ep_registration_name) + self.unregister_execution_provider_library(ep_name) def test_compile_with_ep_selection_delegate(self): """ From 76ae65a7cda306e6e6601bcb4d9e53e32c84c333 Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Tue, 6 May 2025 21:34:51 -0700 Subject: [PATCH 30/84] [QNN EP] Fix Resize Op support translation (#24657) - Use ResizeNearestNeighbor Op for Resize with interpolation_mode=Nearest and rank-4 inputs. - Add a Unit test to verify the modified translation. ### Description ResizeNearestNeighbor Op is faster for Resize with interpolation_mode=Nearest and rank-4 inputs. ### Motivation and Context This commit matches Resize Op behavior in QNN-EP with QNN Offline converter path. This fix also improves inference time. --- .../qnn/builder/opbuilder/resize_op_builder.cc | 17 +++++++---------- onnxruntime/test/providers/qnn/resize_test.cc | 9 +++++++++ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 347f0651069dc..85844721b1f2c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -185,18 +185,15 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); #endif - const bool use_resize_nn_op = nearest_mode == "floor"; + // Use ResizeNearestNeighbor for rank-4 inputs. + const bool use_resize_nn_op = input_rank == 4; // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode // is not supported. - ORT_RETURN_IF(use_resize_nn_op && transformation_mode == "pytorch_half_pixel", + ORT_RETURN_IF(!use_resize_nn_op && nearest_mode == "floor" && transformation_mode == "pytorch_half_pixel", "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", " and coordinate_transformation_mode == 'pytorch_half_pixel'."); - // QNN's ResizeNearestNeighbor requires rank 4 inputs. - ORT_RETURN_IF(use_resize_nn_op && input_rank != 4, - "QNN EP: Resize on the NPU with nearest_mode == 'floor' requires an input with rank 4."); - #if QNN_API_VERSION_MAJOR >= 2 && QNN_API_VERSION_MINOR >= 14 // QNN's Resize only supports "round_prefer_ceil" if transformation_mode is "align_corners". 
ORT_RETURN_IF(!use_resize_nn_op && transformation_mode != "align_corners", @@ -267,11 +264,11 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); std::string qnn_op_type = "Resize"; - if (is_npu_backend && input_rank == 4 && interp_mode == "nearest" && nearest_mode == "floor") { + if (is_npu_backend && input_rank == 4 && interp_mode == "nearest") { // Translate Resize with - // {input_rank: 4, mode: "nearest", nearest_mode: "floor", coordinate_transformation_mode: XXX} to - // QNN's ResizeNearestNeighbor operator on the HTP backend. This combination of parameters is not supported on HTP - // via QNN's Resize operator. Note that QNN's ResizeNearestNeighbor operator always uses "floor" rounding. + // {input_rank: 4, mode: "nearest", coordinate_transformation_mode: XXX} to + // QNN's ResizeNearestNeighbor operator on the HTP backend. QNN ResizeNearestNeighbor + // seems to be faster than QNN Resize. qnn_op_type = "ResizeNearestNeighbor"; // 'align_corners' diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index fbd729fa998d9..702d4e6eddb1b 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -399,6 +399,15 @@ TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor_Unsupport ExpectedEPNodeAssignment::None); // No longer supported as of QNN SDK 2.21 } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_Ceil" +// Maps to QNN's ResizeNearesetNeighbor operator. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferCeil) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_ceil", + ExpectedEPNodeAssignment::All); +} + // Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "align_corners", nearest_mode: "round_prefer_ceil" // Maps to QNN's Resize operator. // UPDATE: "round_prefer_ceil" is supported as of QNN SDK 2.21 if using "align_corners". (Unsupported in QNN SDK 2.19). From dc09448cf757c02a6c70b3acf422bab22abf1cc4 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 6 May 2025 21:46:23 -0700 Subject: [PATCH 31/84] Fix improper vector iterator handling and a couple of warnings (#24664) ### Description When erasing elements from a vector while iterating over it, the iterator must be updated to the new position after the erase operation as a return value, otherwise results in decrement of end operator which is undefined behavior. Fix a couple of warnings as well. ### Motivation and Context Causes test assert. 
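As an editor's illustration of the fix (a minimal standalone sketch; the container, element type, and threshold below are stand-ins, not the actual TensorRT EP types):

```cpp
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Stand-in for supported_nodes_vector: each entry pairs a "subgraph" (node indices) with metadata.
  std::vector<std::pair<std::vector<size_t>, bool>> subgraphs = {
      {{0, 1}, true}, {{2}, true}, {{3, 4, 5}, true}};
  const size_t min_subgraph_size = 2;

  // Correct idiom: std::vector::erase returns the iterator to the element after the one removed,
  // so assign it back; only increment when nothing was erased.
  for (auto it = subgraphs.begin(); it != subgraphs.end();) {
    if (it->first.size() < min_subgraph_size) {
      it = subgraphs.erase(it);
    } else {
      ++it;
    }
  }

  // The previous form, erase(it--), decrements an iterator that may already be begin(),
  // which is undefined behavior and what triggered the test assert.
  std::cout << "remaining subgraphs: " << subgraphs.size() << "\n";  // prints 2
  return 0;
}
```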
--- onnxruntime/core/framework/session_state_utils.cc | 2 +- .../core/providers/tensorrt/tensorrt_execution_provider.cc | 7 ++++--- .../tensorrt/tensorrt_execution_provider_custom_ops.h | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index cef902e506075..cacd772b61d76 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -369,7 +369,7 @@ common::Status SaveInitializedTensors( if (memory_profile_func) memory_profile_func(planner); - for (auto i : planned_initializers_memory_sizes_in_byte) { + for (const auto& i : planned_initializers_memory_sizes_in_byte) { LOGS(logger, INFO) << "[Memory] SessionStateInitializer statically allocates " << i.second << " bytes for " << i.first.ToString() << std::endl; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index ded135bf50ec8..72eb2579e9d42 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2587,11 +2587,12 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, supported_nodes_vector.clear(); } - // Remove subgraphs if its size is less than the predefined minimal size - for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end();) { const size_t subgraph_size = it->first.size(); if (subgraph_size < min_subgraph_size_) { - supported_nodes_vector.erase(it--); + it = supported_nodes_vector.erase(it); + } else { + ++it; } } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index a72de6ed75399..8d4dc19690eac 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -78,7 +78,7 @@ struct TensorRTCustomOp : Ort::CustomOpBase Date: Wed, 7 May 2025 00:31:50 -0700 Subject: [PATCH 32/84] Use GUID_DEVCLASS_COMPUTEACCELERATOR for NPU discovery (#24660) ### Description Updated device discovery to use `GUID_DEVCLASS_COMPUTEACCELERATOR` instead of `GUID_DEVCLASS_SYSTEM` when querying SetupAPI for potential NPU devices. This provides a more specific and accurate class for identifying compute accelerators like NPUs. This change also saves us an average of 5 milliseconds by not looping through unnecessary system devices. ### Motivation and Context When looking for NPUs, the previous code used `GUID_DEVCLASS_SYSTEM` as the class to query for potential devices and didn't return the Qualcomm NPU. 
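For context, a minimal sketch of how a SetupAPI class GUID scopes the enumeration (error handling trimmed; the GUID value below is the compute-accelerator class GUID from this change, everything else is illustrative rather than the EP's actual code):

```cpp
#include <windows.h>
#include <setupapi.h>
#include <iostream>

#pragma comment(lib, "setupapi.lib")

int main() {
  // GUID_DEVCLASS_COMPUTEACCELERATOR ({f01a9d53-3ff6-48d2-9f97-c8a7004be10c}), declared locally as in the patch.
  const GUID compute_accelerator_class = {
      0xf01a9d53, 0x3ff6, 0x48d2, {0x9f, 0x97, 0xc8, 0xa7, 0x00, 0x4b, 0xe1, 0x0c}};

  // Only devices in this setup class are returned, so the loop never visits the
  // unrelated entries that GUID_DEVCLASS_SYSTEM used to pull in.
  HDEVINFO dev_info = SetupDiGetClassDevsW(&compute_accelerator_class, nullptr, nullptr, DIGCF_PRESENT);
  if (dev_info == INVALID_HANDLE_VALUE) {
    return 1;
  }

  SP_DEVINFO_DATA dev_data = {};
  dev_data.cbSize = sizeof(dev_data);
  for (DWORD i = 0; SetupDiEnumDeviceInfo(dev_info, i, &dev_data); ++i) {
    wchar_t hardware_id[1024] = {};
    if (SetupDiGetDeviceRegistryPropertyW(dev_info, &dev_data, SPDRP_HARDWAREID, nullptr,
                                          reinterpret_cast<PBYTE>(hardware_id), sizeof(hardware_id), nullptr)) {
      // Typically a string such as L"ACPI\\VEN_xxxx&DEV_yyyy&..." that is then parsed for VEN_/DEV_.
      std::wcout << L"candidate NPU: " << hardware_id << L"\n";
    }
  }

  SetupDiDestroyDeviceInfoList(dev_info);
  return 0;
}
```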
--- onnxruntime/core/platform/windows/device_discovery.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 5a5b5041a5912..1f8600d6ca4a6 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -75,11 +75,12 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde const GUID local_DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML = {0xb71b0d41, 0x1088, 0x422f, 0xa2, 0x7c, 0x2, 0x50, 0xb7, 0xd3, 0xa9, 0x88}; const GUID local_DXCORE_HARDWARE_TYPE_ATTRIBUTE_NPU = {0xd46140c4, 0xadd7, 0x451b, 0x9e, 0x56, 0x6, 0xfe, 0x8c, 0x3b, 0x58, 0xed}; + const GUID local_GUID_DEVCLASS_COMPUTEACCELERATOR = {0xf01a9d53, 0x3ff6, 0x48d2, 0x9f, 0x97, 0xc8, 0xa7, 0x00, 0x4b, 0xe1, 0x0c}; std::array guids = { GUID_DEVCLASS_DISPLAY, GUID_DEVCLASS_PROCESSOR, - GUID_DEVCLASS_SYSTEM, + local_GUID_DEVCLASS_COMPUTEACCELERATOR, }; for (auto guid : guids) { @@ -183,9 +184,9 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde entry->type = OrtHardwareDeviceType_GPU; } else if (guid == GUID_DEVCLASS_PROCESSOR) { entry->type = is_npu ? OrtHardwareDeviceType_NPU : OrtHardwareDeviceType_CPU; - } else if (guid == GUID_DEVCLASS_SYSTEM) { + } else if (guid == local_GUID_DEVCLASS_COMPUTEACCELERATOR) { if (!is_npu) { - // we're only iterating system devices to look for NPUs so drop anything else + // we're only iterating compute accelerator devices to look for NPUs so drop anything else device_info.erase(key); continue; } From 143d8be6df2a36cb56f3b3189fb6f2f1c4e83e0d Mon Sep 17 00:00:00 2001 From: Yuduo Wu <6426433+1duo@users.noreply.github.com> Date: Wed, 7 May 2025 10:00:21 -0700 Subject: [PATCH 33/84] [QNN-EP] Einsum QDQ tests followup: node selector (#24659) [QNN-EP] Einsum QDQ tests followup: node selector. 
Verified now Einsum QDQ tests are running with `QNN_DATATYPE_UFIXED_POINT_8`: ``` 2025-05-06 11:22:09.738247549 [V:onnxruntime:qdq_model_logger, qnn_model_wrapper.cc:305 ComposeQnnGraph] Qnn_OpConfig node name: node_token_15 package_name: qti.aisw QNN_op_type: MatMul num_of_inputs: 2 num_of_outputs: 1 num_of_params: 2 node_inputs: name=node id=2 version=1 type=QNN_TENSOR_TYPE_NATIVE dataFormat=0 dataType=QNN_DATATYPE_UFIXED_POINT_8 rank=2 dimensions=(2 3 ) memType=QNN_TENSORMEMTYPE_RAW quantizeParams: encodingDefinition=QNN_DEFINITION_DEFINED quantizationEncoding=QNN_QUANTIZATION_ENCODING_SCALE_OFFSET scale=0.000980392 offset=-102 name=node_token_6 id=4 version=1 type=QNN_TENSOR_TYPE_NATIVE dataFormat=0 dataType=QNN_DATATYPE_UFIXED_POINT_8 rank=2 dimensions=(3 4 ) memType=QNN_TENSORMEMTYPE_RAW quantizeParams: encodingDefinition=QNN_DEFINITION_DEFINED quantizationEncoding=QNN_QUANTIZATION_ENCODING_SCALE_OFFSET scale=0.00215686 offset=-46 node_outputs: name=node_token_16 id=5 version=1 type=QNN_TENSOR_TYPE_NATIVE dataFormat=0 dataType=QNN_DATATYPE_UFIXED_POINT_8 rank=2 dimensions=(2 4 ) memType=QNN_TENSORMEMTYPE_RAW quantizeParams: encodingDefinition=QNN_DEFINITION_DEFINED quantizationEncoding=QNN_QUANTIZATION_ENCODING_SCALE_OFFSET scale=0.000441176 offset=-40 node_params: type=QNN_PARAMTYPE_SCALAR name=transpose_in0 value=0 type=QNN_PARAMTYPE_SCALAR name=transpose_in1 value=0 ``` --- .../selectors_actions/qdq_selectors.cc | 31 +++++++++++++++++++ .../selectors_actions/qdq_selectors.h | 16 ++++++++++ .../selectors_actions/shared/utils.cc | 11 +++++++ 3 files changed, 58 insertions(+) diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 05627dd25857f..6515661a2ee6a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -401,6 +401,37 @@ void ConvSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const { builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex); } +bool EinsumNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, /*num_dq_inputs=*/-1, + /*is_empty_q_nodes_allowed=*/true)) { + return false; + } + size_t num_dq_inputs = dq_nodes.size(); + for (size_t i = 0; i < num_dq_inputs; ++i) { + int32_t dt_input = dq_nodes[i]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (!allow_int8_ && dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + return false; + } + if (!allow_16bit_ && Is16BitIntType(dt_input)) { + return false; + } + if (!allow_4bit_ && Is4BitIntType(dt_input)) { + return false; + } + } + if (!q_nodes.empty()) { + int32_t dt_input0 = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input0 != dt_output) { + return false; + } + } + return true; +} + bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const Node* redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h 
b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 36e04146040db..e4f4844fb88ad 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -182,6 +182,22 @@ class PadNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// one ore more DQ nodes for each input -> node -> Q +class EinsumNodeGroupSelector : public NodeGroupSelector { + public: + explicit EinsumNodeGroupSelector(bool allow_int8 = true, bool allow_16bit = true, bool allow_4bit = true) + : allow_int8_(allow_int8), allow_16bit_(allow_16bit), allow_4bit_(allow_4bit) {} + + private: + bool Check(const GraphViewer& graph_viewer, + const Node& node, const Node* redundant_clip_node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; + bool allow_int8_; + bool allow_16bit_; + bool allow_4bit_; +}; + // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not // The lack of a trailing Q isn't really a QDQ node group, so we default support for that to off. class MatMulNodeGroupSelector : public NodeGroupSelector { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index e531d19d4c643..d3957a34dcfca 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -113,6 +113,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetConvOpVersionsMap() { static const OpVersionsAndSelector::OpVersionsMap GetConvTransposeOpVersionsMap() { return {{"ConvTranspose", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetEinsumOpVersionsMap() { + return {{"Einsum", {}}}; +} static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() { return {{"MatMul", {}}}; } @@ -202,6 +205,13 @@ void RegisterConvTransposeSelector(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterEinsumSelector(Selectors& qdq_selectors) { + /* register selector for einsum op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetEinsumOpVersionsMap(), + std::move(selector)); +} + void RegisterMatMulSelector(Selectors& qdq_selectors) { /* register selector for matmul op */ std::unique_ptr selector = std::make_unique(); @@ -267,6 +277,7 @@ void SelectorManager::CreateSelectors() { RegisterSplitSelector(qdq_selectors_); RegisterConvSelector(qdq_selectors_); RegisterConvTransposeSelector(qdq_selectors_); + RegisterEinsumSelector(qdq_selectors_); RegisterMatMulSelector(qdq_selectors_); RegisterGemmSelector(qdq_selectors_); RegisterInstanceAndLayerNormalizationSelector(qdq_selectors_); From 523f486a5b56d4b1139a8cbcb35ced78e3794096 Mon Sep 17 00:00:00 2001 From: Ashrit Shetty Date: Wed, 7 May 2025 17:11:12 -0700 Subject: [PATCH 34/84] Align hardware ID parsing with DXCore representation (#24682) ### Description Modified the `get_id` lambda to parse 4-character hardware ID components (e.g., from VEN_xxxx or DEV_yyyy) by converting the ASCII string representation directly to a uint32_t using `WStringToUint32Id`. This replaces the previous hexadecimal string-to-integer conversion and aligns with how DXCore reports these vendor and device IDs, ensuring consistent ID interpretation. 
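To make the byte order concrete, a small standalone sketch (the `WStringToUint32Id` body mirrors the helper added in this patch; the "QCOM" value is only an illustrative VEN_ component):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

// Same little-endian packing as the patch: the first character lands in the least significant byte.
uint32_t WStringToUint32Id(const std::wstring& id) {
  uint32_t value = 0;
  for (size_t i = 0; i < 4 && i < id.size(); ++i) {
    value |= static_cast<uint32_t>(id[i] & 0xFF) << (i * 8);
  }
  return value;
}

int main() {
  // "QCOM": 'Q'=0x51 -> byte 0, 'C'=0x43 -> byte 1, 'O'=0x4F -> byte 2, 'M'=0x4D -> byte 3.
  std::wcout << std::hex << WStringToUint32Id(L"QCOM") << L"\n";  // prints 4d4f4351
  // The old hex parse would have rejected "QCOM" (not all hex digits) and returned 0,
  // so it could never match the integer ID reported by DXCore.
  return 0;
}
```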
### Motivation and Context Before this change, we were assuming the hardware ID components were hexadecimal strings and converting them to integers. This assumption is incorrect and was leading to incorrect interpretations of the vendor and device IDs. Down the line when we compare the vendor and device id with the ids we receive from DXCORE, there are no matches, and we fail to copy the information collected in `GetDeviceInfoSetupApi`. ### Testing Tested with all the sample apps and stepped through the code to verify the Vendor and Device IDs match the IDs from DXCORE on both Qualcomm and AMD. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/platform/windows/device_discovery.cc | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 1f8600d6ca4a6..5c4c41a5799c8 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -68,6 +68,20 @@ uint64_t GetLuidKey(LUID luid) { return (uint64_t(luid.HighPart) << 32) | luid.LowPart; } +// Converts a wide string (up to 4 characters) representing a hardware ID component (e.g., "ABCD" from "VEN_ABCD") +// into a uint32_t. The conversion is done in a little-endian manner, meaning the first character +// of the string becomes the least significant byte of the integer, and the fourth character +// becomes the most significant byte. +uint32_t WStringToUint32Id(const std::wstring& vendor_name) { + uint32_t vendor_id = 0; + for (size_t i = 0; i < 4 && i < vendor_name.size(); ++i) { + // For little-endian, place each character at the appropriate byte position + // First character goes into lowest byte, last character into highest byte + vendor_id |= static_cast(vendor_name[i] & 0xFF) << (i * 8); + } + return vendor_id; +} + // returns info for display and processor entries. key is (vendor_id << 32 | device_id) // npus: (vendor_id << 32 | device_id) for devices we think are NPUs from DXCORE std::unordered_map GetDeviceInfoSetupApi(const std::unordered_set& npus) { @@ -104,28 +118,33 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde //// Get hardware ID (contains VEN_xxxx&DEV_xxxx) if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_HARDWAREID, ®DataType, (PBYTE)buffer, sizeof(buffer), &size)) { + uint32_t vendor_id = 0; + uint32_t device_id = 0; + // PCI\VEN_xxxx&DEV_yyyy&... // ACPI\VEN_xxxx&DEV_yyyy&... if we're lucky. // ACPI values seem to be very inconsistent, so we check fairly carefully and always require a device id. const auto get_id = [](const std::wstring& hardware_id, const std::wstring& prefix) -> uint32_t { if (auto idx = hardware_id.find(prefix); idx != std::wstring::npos) { auto id = hardware_id.substr(idx + prefix.size(), 4); - if (std::all_of(id.begin(), id.end(), iswxdigit)) { - return std::stoul(id, nullptr, 16); + if (id.size() == 4) { + // DXCore reports vendor and device IDs as 32-bit integer representations of the ASCII string. + return WStringToUint32Id(id); } } return 0; }; - uint32_t vendor_id = get_id(buffer, L"VEN_"); - uint32_t device_id = get_id(buffer, L"DEV_"); - // Processor ID should come from CPUID mapping. 
- if (vendor_id == 0 && guid == GUID_DEVCLASS_PROCESSOR) { + if (guid == GUID_DEVCLASS_PROCESSOR) { vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId(); + } else { + vendor_id = get_id(buffer, L"VEN_"); } + device_id = get_id(buffer, L"DEV_"); + // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { continue; From b7e7b6ff9a3df4a002fffb1305e24ff7a38bd858 Mon Sep 17 00:00:00 2001 From: "genmingz@AMD" Date: Thu, 8 May 2025 08:22:32 +0800 Subject: [PATCH 35/84] [VitisAI] fix graph_save dump unsorted Model (#24678) ### Description When calling graph_save, a model with unsorted nodes will be saved, which will cause errors when using the model later. ### Motivation and Context The resolve function is called in the graph_save process, which causes the nodes to be sorted. Co-authored-by: genmingz --- onnxruntime/core/providers/vitisai/imp/graph.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index e9db58143577b..88bdbaed40c73 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -178,6 +178,8 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri *model_proto->mutable_graph() = *graph_proto_subgraph; auto& logger = logging::LoggingManager::DefaultLogger(); auto model = Model::Create(std::move(*model_proto), ToPathString(filename), nullptr, logger); + auto status = model->MainGraph().Resolve(); + vai_assert(status.IsOK(), "graph resolve error:" + status.ErrorMessage()); if (initializer_size_threshold == std::numeric_limits::max()) { model_proto = model->ToProto(); } else { From 0e0002b032627aa9a30fba1cc6b2dc2d15723a3f Mon Sep 17 00:00:00 2001 From: mingyue <131847423+mingyueliuh@users.noreply.github.com> Date: Thu, 8 May 2025 08:50:48 +0800 Subject: [PATCH 36/84] [Fix] compare OrtDevice error (#24677) ### Description Fix compare OrtDevice when Debug mode Related #24371 ### Motivation and Context add compare device alignment in OrtDevice compare function --- include/onnxruntime/core/framework/ortdevice.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 472575d1998f5..2a377238e0e27 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -103,7 +103,7 @@ struct OrtDevice { }; inline bool operator==(const OrtDevice& left, const OrtDevice& other) { - return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type(); + return left.Id() == other.Id() && left.MemType() == other.MemType() && left.Type() == other.Type() && left.GetAlignment() == other.GetAlignment(); } inline bool operator!=(const OrtDevice& left, const OrtDevice& other) { From 0aaccafd41eca1580ec409d4ccd32cd1288c7e05 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Thu, 8 May 2025 02:37:36 -0400 Subject: [PATCH 37/84] [MIGraphX EP] [ROCm EP] Update CI to use ROCm 6.4 (#23535) ### Description Update Onnxruntime CI to use latest ROCm release ROCm 6.3.2 ### Motivation and Context Ensure CI is testing up to date with recent Onnxruntime. 
AMD validates changes based off the latest ROCm release and adds additional features on top to validate changes prior to being pushed to our internal testing branches and then up streamed to Microsoft/Onnxruntime:main . Right now Onnxruntime is testing things a few releases back for their CI --------- Co-authored-by: Ted Themistokleous --- .../github/azure-pipelines/linux-migraphx-ci-pipeline.yml | 2 +- .../ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml | 2 +- .../github/linux/docker/migraphx-ci-pipeline-env.Dockerfile | 2 +- .../github/linux/docker/rocm-ci-pipeline-env.Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index de5df97d37d3d..c6ebb80f98e12 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -37,7 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.2.3 + value: 6.4 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index af5a8d1decb6e..7388ed6d5a1e9 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -37,7 +37,7 @@ variables: - name: render value: 109 - name: RocmVersion - value: 6.3.2 + value: 6.4 jobs: - job: Linux_Build diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile index 667ffe03d8922..7b02a5e658d31 100644 --- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.2.3 +ARG ROCM_VERSION=6.4 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' diff --git a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile index 1cd5f289dd1c9..83a4e04435b95 100644 --- a/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/linux/docker/rocm-ci-pipeline-env.Dockerfile @@ -1,7 +1,7 @@ # Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete FROM ubuntu:22.04 -ARG ROCM_VERSION=6.3.2 +ARG ROCM_VERSION=6.4 ARG AMDGPU_VERSION=${ROCM_VERSION} ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600' From 05793bd2ce1a7f348b32660a3553a3439aecf0d0 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 8 May 2025 10:51:43 -0700 Subject: [PATCH 38/84] Qnn nuget package update for arm64x (#24690) ### Description Update the folder name from win-arm64x to win-arm64 since it is invalid RID: https://learn.microsoft.com/en-us/dotnet/core/rid-catalog#windows-rids --- .../targets/netstandard/props_qnn.xml | 20 +++++++++---------- .../azure-pipelines/templates/qnn-ep-win.yml | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml 
b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml index fa0e957418fab..83ffb22ccf6b2 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml +++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml @@ -12,7 +12,7 @@ - $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + $(MSBuildThisFileDirectory)../../runtimes/win-arm64/native/onnxruntime.lib;%(AdditionalDependencies) @@ -24,7 +24,7 @@ - $(MSBuildThisFileDirectory)../../runtimes/win-arm64x/native/onnxruntime.lib;%(AdditionalDependencies) + $(MSBuildThisFileDirectory)../../runtimes/win-arm64/native/onnxruntime.lib;%(AdditionalDependencies) @@ -36,7 +36,7 @@ x86 - arm64x + arm64 arm $(Platform) @@ -47,31 +47,31 @@ - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll')"> onnxruntime.dll PreserveNewest false - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll')"> onnxruntime_providers_shared.dll PreserveNewest false - onnxruntime.dll PreserveNewest false - + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll')"> onnxruntime_providers_shared.dll PreserveNewest false diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 6bfc00b5b46eb..d739724f8744a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -98,7 +98,7 @@ stages: solution: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj' platform: 'Any CPU' configuration: ${{ parameters.build_config }} - msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64x' + msbuildArguments: '-t:CreatePackage -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId) -p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} -p:TargetArchitecture=arm64' workingDirectory: '$(Build.SourcesDirectory)\csharp' - task: CopyFiles@2 From 3dc91e6c31e008983dbc36ef63f0fbb253741dc7 Mon Sep 17 00:00:00 2001 From: Ishwar Raut Date: Thu, 8 May 2025 17:29:32 -0700 Subject: [PATCH 39/84] Removed the dependencies from cuda ep dll (#24656) ### Remove the dependencies from CUDA EP DLL- 1. Copied the CUDA allocator from CUDA EP to NV RTX EP 2. Copied data transfer from CUDA EP to NV RTX EP 3. 
Implemented the CUDA error handling in NV RTX EP ### Motivation and Context @ankan-ban @gedoensmax @chilo-ms to review --------- Co-authored-by: iraut --- .../providers/nv_tensorrt_rtx/nv_allocator.cc | 100 ++++++++++++ .../providers/nv_tensorrt_rtx/nv_allocator.h | 65 ++++++++ .../providers/nv_tensorrt_rtx/nv_cuda_call.cc | 145 ++++++++++++++++++ .../nv_tensorrt_rtx/nv_data_transfer.cc | 92 +++++++++++ .../nv_tensorrt_rtx/nv_data_transfer.h | 25 +++ .../nv_tensorrt_rtx/nv_execution_provider.cc | 21 +-- .../provider_bridge_provider.cc | 2 +- tools/ci_build/build.py | 2 +- 8 files changed, 436 insertions(+), 16 deletions(-) create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc new file mode 100644 index 0000000000000..4e8179d86fd73 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "nv_allocator.h" +#include "nv_includes.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +namespace onnxruntime { + +void CUDAAllocator::CheckDevice(bool throw_when_fail) const { +#ifndef NDEBUG + // check device to match at debug build + // if it's expected to change, call cudaSetDevice instead of the check + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + ORT_ENFORCE(current_device == Info().id); + } else if (throw_when_fail) { + CUDA_CALL_THROW(cuda_err); + } +#else + ORT_UNUSED_PARAMETER(throw_when_fail); +#endif +} + +void CUDAAllocator::SetDevice(bool throw_when_fail) const { + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + int allocator_device_id = Info().id; + if (current_device != allocator_device_id) { + cuda_err = cudaSetDevice(allocator_device_id); + } + } + + if (cuda_err != cudaSuccess && throw_when_fail) { + CUDA_CALL_THROW(cuda_err); + } +} + +void* CUDAAllocator::Alloc(size_t size) { + SetDevice(true); + CheckDevice(true); + void* p = nullptr; + if (size > 0) { + // BFCArena was updated recently to handle the exception and adjust the request size + CUDA_CALL_THROW(cudaMalloc((void**)&p, size)); + } + return p; +} + +void CUDAAllocator::Free(void* p) { + SetDevice(false); + CheckDevice(false); // ignore CUDA failure when free + cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown +} + +void* CUDAExternalAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + p = alloc_(size); + + // review(codemzs): ORT_ENFORCE does not seem appropriate. 
+ ORT_ENFORCE(p != nullptr); + } + + return p; +} + +void CUDAExternalAllocator::Free(void* p) { + free_(p); + std::lock_guard lock(lock_); + auto it = reserved_.find(p); + if (it != reserved_.end()) { + reserved_.erase(it); + if (empty_cache_) empty_cache_(); + } +} + +void* CUDAExternalAllocator::Reserve(size_t size) { + void* p = Alloc(size); + if (!p) return nullptr; + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + return p; +} + +void* CUDAPinnedAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + CUDA_CALL_THROW(cudaMallocHost((void**)&p, size)); + } + return p; +} + +void CUDAPinnedAllocator::Free(void* p) { + CUDA_CALL_THROW(cudaFreeHost(p)); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h new file mode 100644 index 0000000000000..a3f05bded5de9 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_allocator.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include + +namespace onnxruntime { + +class CUDAAllocator : public IAllocator { + public: + CUDAAllocator(OrtDevice::DeviceId device_id, const char* name) + : IAllocator( + OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), + device_id, OrtMemTypeDefault)) {} + void* Alloc(size_t size) override; + void Free(void* p) override; + + private: + void CheckDevice(bool throw_when_fail) const; + void SetDevice(bool throw_when_fail) const; +}; + +class CUDAExternalAllocator : public CUDAAllocator { + typedef void* (*ExternalAlloc)(size_t size); + typedef void (*ExternalFree)(void* p); + typedef void (*ExternalEmptyCache)(); + + public: + CUDAExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + : CUDAAllocator(device_id, name) { + alloc_ = reinterpret_cast(alloc); + free_ = reinterpret_cast(free); + empty_cache_ = reinterpret_cast(empty_cache); + } + + void* Alloc(size_t size) override; + void Free(void* p) override; + void* Reserve(size_t size) override; + + private: + mutable std::mutex lock_; + ExternalAlloc alloc_; + ExternalFree free_; + ExternalEmptyCache empty_cache_; + InlinedHashSet reserved_; +}; + +// TODO: add a default constructor +class CUDAPinnedAllocator : public IAllocator { + public: + CUDAPinnedAllocator(const char* name) + : IAllocator( + OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, 0 /*CPU device always with id 0*/), + 0, OrtMemTypeCPUOutput)) {} + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc new file mode 100644 index 0000000000000..8e9ea1257cdd2 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_cuda_call.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include + +#ifdef _WIN32 +#else // POSIX +#include +#include +#endif + +namespace onnxruntime { + +using namespace common; + +template +const char* CudaErrString(ERRTYPE) { + ORT_NOT_IMPLEMENTED(); +} + +#define CASE_ENUM_TO_STR(x) \ + case x: \ + return #x + +template <> +const char* CudaErrString(cudaError_t x) { + cudaDeviceSynchronize(); + return cudaGetErrorString(x); +} + +#ifndef USE_CUDA_MINIMAL +template <> +const char* CudaErrString(cublasStatus_t e) { + cudaDeviceSynchronize(); + switch (e) { + CASE_ENUM_TO_STR(CUBLAS_STATUS_SUCCESS); + CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_ALLOC_FAILED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_INVALID_VALUE); + CASE_ENUM_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH); + CASE_ENUM_TO_STR(CUBLAS_STATUS_MAPPING_ERROR); + CASE_ENUM_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR); + CASE_ENUM_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED); + CASE_ENUM_TO_STR(CUBLAS_STATUS_LICENSE_ERROR); + default: + return "(look for CUBLAS_STATUS_xxx in cublas_api.h)"; + } +} + +template <> +const char* CudaErrString(curandStatus) { + cudaDeviceSynchronize(); + return "(see curand.h & look for curandStatus or CURAND_STATUS_xxx)"; +} + +template <> +const char* CudaErrString(cudnnStatus_t e) { + cudaDeviceSynchronize(); + return cudnnGetErrorString(e); +} + +template <> +const char* CudaErrString(cufftResult e) { + cudaDeviceSynchronize(); + switch (e) { + CASE_ENUM_TO_STR(CUFFT_SUCCESS); + CASE_ENUM_TO_STR(CUFFT_ALLOC_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_VALUE); + CASE_ENUM_TO_STR(CUFFT_INTERNAL_ERROR); + CASE_ENUM_TO_STR(CUFFT_SETUP_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_SIZE); + default: + return "Unknown cufft error status"; + } +} +#endif + +#ifdef ORT_USE_NCCL +template <> +const char* CudaErrString(ncclResult_t e) { + cudaDeviceSynchronize(); + return ncclGetErrorString(e); +} +#endif + +template +int GetErrorCode(ERRTYPE err) { + return static_cast(err); +} + +template +std::conditional_t CudaCall( + ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, + const char* file, const int line) { + if (retCode != successCode) { + try { +#ifdef _WIN32 + std::string hostname_str = GetEnvironmentVar("COMPUTERNAME"); + if (hostname_str.empty()) { + hostname_str = "?"; + } + const char* hostname = hostname_str.c_str(); +#else + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX) != 0) + strcpy(hostname, "?"); +#endif + int currentCudaDevice = -1; + cudaGetDevice(¤tCudaDevice); + cudaGetLastError(); // clear last CUDA error + static char str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", + libName, GetErrorCode(retCode), CudaErrString(retCode), currentCudaDevice, + hostname, + file, line, exprString, msg); + if constexpr (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + LOGS_DEFAULT(ERROR) << str; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, str); + } + } catch (const std::exception& e) { // catch, log, and rethrow since CUDA code sometimes hangs in destruction, + // so we'd never get to see the error + if constexpr (THRW) { + ORT_THROW(e.what()); + } else { + LOGS_DEFAULT(ERROR) << e.what(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what()); + } + } + } + if constexpr (!THRW) { + return Status::OK(); + } +} + +template Status 
CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); +template void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc new file mode 100644 index 0000000000000..4779ddd1a9556 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.cc @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/shared_library/provider_api.h" + +#include "nv_data_transfer.h" + +#include "core/providers/cuda/shared_inc/cuda_call.h" +#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) +namespace onnxruntime { +bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::CUDA_PINNED || + dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::CUDA_PINNED; +} + +common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + // for the sync version of memcpy, launch to cuda default stream + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } else { + // copy from other CPU memory to GPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + if (src_device.MemType() != OrtDevice::MemType::CUDA_PINNED) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. 
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } + } else if (src_device.Type() == OrtDevice::GPU) { + // copying from GPU to CPU memory, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } else { + // copying between cpu memory + ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + + return Status::OK(); +} + +common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const { + size_t bytes = src.SizeInBytes(); + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::CPU) { + // copy from pinned or non-pinned CPU memory to GPU + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, static_cast(stream.GetHandle()))); + } else if (src_device.Type() == OrtDevice::GPU) { + // copying between GPU, this is non-blocking + if (dst_data != src_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); + } + } + } else if (src_device.Type() == OrtDevice::GPU) { + if (dst_device.Type() == OrtDevice::CPU) { + // copy from GPU to pinned or non-pinned CPU memory. + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, static_cast(stream.GetHandle()))); + } + } else { + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(stream.GetHandle()))); + } + + ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h new file mode 100644 index 0000000000000..272ea367ac7e4 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_data_transfer.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "nv_includes.h" +#include "core/framework/data_transfer.h" + +namespace onnxruntime { + +class GPUDataTransfer : public IDataTransfer { + public: + GPUDataTransfer() = default; + ~GPUDataTransfer() = default; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + // Dumpen MSVC warning about not fully overriding + using IDataTransfer::CopyTensor; + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + common::Status CopyTensorAsync(const Tensor& src, Tensor& dst, Stream& stream) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 25c130a849793..469d139ed03bb 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // Licensed under the MIT License. #include #include @@ -12,10 +13,11 @@ #include "nv_execution_provider.h" #include "nv_execution_provider_utils.h" #include "nv_execution_provider_custom_ops.h" +#include "nv_allocator.h" +#include "nv_data_transfer.h" #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" -#include "core/providers/cuda/gpu_data_transfer.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include @@ -113,16 +115,6 @@ void Impl_Cast( } } // namespace cuda -template <> -Status CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line) { - return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line); -} - -template <> -void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line) { - return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); -} - #if NV_TENSORRT_MAJOR >= 10 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { @@ -1311,13 +1303,14 @@ void NvExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, + [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, narrow(device_id_)); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { ORT_UNUSED_PARAMETER(device_id); - return CreateCUDAPinnedAllocator(onnxruntime::CUDA_PINNED); + return std::make_unique(CUDA_PINNED); + ; }, 0); @@ -1325,7 +1318,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { } std::unique_ptr NvExecutionProvider::GetDataTransfer() const { - return onnxruntime::CreateGPUDataTransfer(); + return std::make_unique(); } Status NvExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index eee6a05f12729..afabc1fa9b1c9 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -347,7 +347,7 @@ common::Status IExecutionProvider::Compile(const std::vector& return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs); } -#if defined(USE_TENSORRT) || defined(USE_NV) +#if defined(USE_TENSORRT) std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) { return g_host->CreateCUDAAllocator(device_id, name); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index f4b7a78fe4c99..06e48a36942e2 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2198,7 +2198,7 @@ def main(): cmake_extra_defines = normalize_arg_list(args.cmake_extra_defines) - if args.use_tensorrt or args.use_nv_tensorrt_rtx: + if args.use_tensorrt: args.use_cuda = True if args.build_wheel or args.gen_doc or args.enable_training: From 
d684d90cb1fee0722600511dcf36f95b0dd126fb Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Fri, 9 May 2025 08:02:28 -0700 Subject: [PATCH 40/84] [QNN EP] Fix Weight Bias Quantization implementation (#24693) - ORT is populating all node attributes in GetAttributes() API call with keeping default values for unspecified attributes in NodeProto. - Check the possibility of default values and don't skip weight_bias_quantization on Conv operator if the weight has DeQuantize node with attribute block_size=0 is read. ### Description GetAttributes() API call is populating all attributes in node definition with assigning default values for unspecified attributes in the model. Weight Bias Quantization is being skipped when block_size attribute present in the DequantizeLinear node producing the weight for a Conv operator. Gracefully handle the default value of block_size i.e., 0 and apply the Weight Bias Quantization as the default value '0' has no significance. ### Motivation and Context Applying Weight Bias Quantization on Conv operator enables ORT QDQ transformer to fold the DQ-->Conv-->Q pattern into Conv operator. This improves inference time for some QDQ ONNX models. --- .../optimizer/qdq_transformer/weight_bias_quantization.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 83c5d7bc8d92a..58e90ea3c71c2 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -89,7 +89,9 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph } const auto& dq_attrs = dq_1->GetAttributes(); - if (dq_attrs.find("block_size") != dq_attrs.end()) { + auto attr_it = dq_attrs.find("block_size"); + // Default value of block_size=0 has no significance. Don't skip weight_bias_quantization. + if (attr_it != dq_attrs.end() && attr_it->second.i() != 0) { continue; } From 1920ba16c6c9749f9028785ee7918dd8fbdb1fa9 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 9 May 2025 08:40:29 -0700 Subject: [PATCH 41/84] Implement Public API GetTensorSizeInBytes (#24686) ### Description Implement Public API GetTensorSizeInBytes ### Motivation and Context Addresses https://github.com/microsoft/onnxruntime/issues/24680 --- .../core/session/onnxruntime_c_api.h | 15 +++++++++ .../core/session/onnxruntime_cxx_api.h | 12 +++++++ .../core/session/onnxruntime_cxx_inline.h | 7 ++++ onnxruntime/core/session/onnxruntime_c_api.cc | 28 ++++++++++++++++ onnxruntime/core/session/ort_apis.h | 3 ++ onnxruntime/test/shared_lib/test_inference.cc | 33 +++++++++++++++++++ 6 files changed, 98 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a4a32fbea630a..25b6d72394e0c 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5251,6 +5251,21 @@ struct OrtApi { * \since Version 1.22. */ const OrtEpApi*(ORT_API_CALL* GetEpApi)(); + + /** \brief Compute total size in bytes of the tensor data contained in an OrtValue. + * + * Returns the total number of bytes used to store the tensor data. For numeric tensors, + * this is sizeof(element_type) * total_element_count. OrtValues that are not tensors or + * that are tensors that contain strings will cause an error to be returned. 
+ * + * \param[in] ort_value OrtValue instance containing a tensor + * \param[out] size The total size of the tensor data in bytes + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 39c20e237b02c..8876c40fe9e6c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1738,6 +1738,18 @@ struct ConstValueImpl : Base { /// byte length for the specified string element size_t GetStringTensorElementLength(size_t element_index) const; + /// + /// Returns the total size of the tensor data in bytes. + /// + /// The total size of the tensor data in bytes + /// Throws an exception if the OrtValue does not contain a tensor or + /// if it contains a tensor that contains strings + /// + /// For numeric tensors, this is sizeof(element_type) * total_element_count. + /// + /// + size_t GetTensorSizeInBytes() const; ///< Wraps OrtApi::GetTensorSizeInBytes + #if !defined(DISABLE_SPARSE_TENSORS) /// /// The API returns the sparse data format this OrtValue holds in a sparse tensor. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 94ad2118fa4d6..0d0b3198a8736 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1823,6 +1823,13 @@ inline size_t ConstValueImpl::GetStringTensorElementLength(size_t element_ind return out; } +template +inline size_t ConstValueImpl::GetTensorSizeInBytes() const { + size_t out; + ThrowOnError(GetApi().GetTensorSizeInBytes(this->p_, &out)); + return out; +} + template template inline const R* ConstValueImpl::GetTensorData() const { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index d03b98a9c1eb5..868fab767fa7b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1206,6 +1206,33 @@ ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElementLength, _In_ const OrtValue* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetTensorSizeInBytes, _In_ const OrtValue* value, _Out_ size_t* size) { + API_IMPL_BEGIN + + if (value == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Input `value` argument must not be null"); + } + + if (size == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output `size` argument must not be null"); + } + + if (!value->IsAllocated() || !value->IsTensor()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue is expected to contain a tensor"); + } + + const auto& tensor = value->Get(); + + // Check if this is a string tensor + if (tensor.IsDataTypeString()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "String tensors are not supported by this API"); + } + + *size = tensor.SizeInBytes(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len) { API_IMPL_BEGIN @@ -2996,6 +3023,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more 
information) + &OrtApis::GetTensorSizeInBytes, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 47d1a543b5a31..81af6694f6273 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -598,4 +598,7 @@ ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_ ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtEpApi*, GetEpApi); + +ORT_API_STATUS_IMPL(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); + } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0a4ea71933724..6460e3cb3aec4 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -425,6 +425,39 @@ TEST_P(CApiTestWithProvider, simple) { nullptr, nullptr); } +template +void TestGetTensorSizeInBytes(Ort::ConstMemoryInfo cpu_meminfo) { + constexpr const size_t expected_size_in_bytes = sizeof(T) * element_count_to_create; + constexpr const std::array dims = {1, static_cast(element_count_to_create)}; + std::array data; + std::fill(data.begin(), data.end(), T{1}); + + auto value = Ort::Value::CreateTensor(cpu_meminfo, data.data(), + data.size(), dims.data(), dims.size()); + + auto type_info = value.GetTypeInfo(); + ASSERT_EQ(type_info.GetONNXType(), ONNX_TYPE_TENSOR); + auto tensor_type_info = type_info.GetTensorTypeAndShapeInfo(); + const auto element_count = tensor_type_info.GetElementCount(); + ASSERT_EQ(expected_size_in_bytes / sizeof(T), element_count); + ASSERT_EQ(expected_size_in_bytes, value.GetTensorSizeInBytes()); +} + +TEST(CApiTest, TestGetTensorSizeInBytes) { + Ort::MemoryInfo cpu_meminfo("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); + TestGetTensorSizeInBytes(cpu_meminfo.GetConst()); +} + TEST(CApiTest, dim_param) { Ort::SessionOptions session_options; Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); From 67f95573694b86cf34ac2174022714f05a8f8e68 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 9 May 2025 08:40:53 -0700 Subject: [PATCH 42/84] Fix improper iterator usage (#24698) ### Description This pull request includes a minor but important fix to the iteration logic in the `NvExecutionProvider::GetCapability` method to ensure correctness and avoid potential undefined behavior when removing elements from a container during iteration. ### Iteration logic fix: * In `onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc`, the loop for removing subgraphs smaller than the minimal size was updated to use `it = supported_nodes_vector.erase(it)` instead of `erase(it--)`. This ensures the iterator remains valid and avoids decrementing it unnecessarily, improving code safety and readability. ### Motivation and Context Prevent possible memory corruption. 
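To make the hazard concrete, a minimal sketch of the corrected loop follows; the container and element types are simplified stand-ins, not the actual SubGraphCollection_t definition.

```cpp
#include <utility>
#include <vector>

using SubGraph = std::pair<std::vector<size_t>, bool>;  // simplified stand-in

void RemoveSmallSubgraphs(std::vector<SubGraph>& subgraphs, size_t min_subgraph_size) {
  for (auto it = subgraphs.begin(); it != subgraphs.end();) {
    if (it->first.size() < min_subgraph_size) {
      // erase() returns an iterator to the element after the erased one,
      // so the loop can continue safely without a separate increment.
      it = subgraphs.erase(it);
    } else {
      ++it;
    }
  }
  // The previous pattern, subgraphs.erase(it--), decrements the iterator
  // before erasing; when the matching element is the first one, decrementing
  // begin() is undefined behavior, and the pattern also fights the fact that
  // erase() invalidates iterators at or after the erased position.
}
```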
Similar code was addressed in tensorrt EP. --- .../core/providers/nv_tensorrt_rtx/nv_execution_provider.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 469d139ed03bb..b2950cd2c5da3 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -2014,10 +2014,12 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph, } // Remove subgraphs if its size is less than the predefined minimal size - for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end();) { const size_t subgraph_size = it->first.size(); if (subgraph_size < min_subgraph_size_) { - supported_nodes_vector.erase(it--); + it = supported_nodes_vector.erase(it); + } else { + ++it; } } From 0d9f4e70df88da00c84522a450e85b4cc126130f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 9 May 2025 09:09:38 -0700 Subject: [PATCH 43/84] [web CI] fixes shader key validation failure (#24701) ### Description loose the regex for the log prefix. this fixes the shader key validation failure and makes it less sensitive to future log format change. ### Motivation and Context with chrome update, the chrome_debug logging format changed a little bit. --- .../webgpu-validate-shader-key/parse-chromium-debug-log.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js b/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js index 2342381f03934..d34ff95192089 100644 --- a/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js +++ b/.github/actions/webgpu-validate-shader-key/parse-chromium-debug-log.js @@ -16,7 +16,7 @@ async function processChromiumDebugLog() { for await (const line of rl) { const result = - /^\[.+INFO:CONSOLE\(\d+\)]\ "(?.+)",\ source:\ [^"]+?\(\d+\)$/.exec( + /^\[.+INFO:CONSOLE.+?]\ "(?.+)",\ source:\ [^"]+?\(\d+\)$/.exec( line ); if (!result) { From 8a97463aeac9f2be8a3bffdcac6b8a18c7b19bb1 Mon Sep 17 00:00:00 2001 From: Qiujiao Date: Sat, 10 May 2025 00:10:01 +0800 Subject: [PATCH 44/84] Add trace event control for ORT Web performance profiling (#23393) ### Description Add trace event control to better profile ORT web performance ### Motivation and Context ORT Web's current tracing implementation lacks interfaces for performance profiling using about://tracing. This PR introduces these interfaces, enabling performance bottleneck identification in ORT Web and adding several trace events for WebNN. 
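For context on how the C++ (WebNN) side of this gating looks, here is a condensed, hypothetical sketch of the pattern the model.cc and data_transfer.cc changes follow; the RAII wrapper and the label string are illustrative and not part of the change itself.

```cpp
#include <emscripten/val.h>

// Emits console.time()/console.timeEnd() pairs around a region so it shows
// up in DevTools / about://tracing, but only when tracing was enabled from
// JavaScript via Module["webnnEnableTraceEvent"].
class ScopedTraceEvent {
 public:
  explicit ScopedTraceEvent(const char* label)
      : enabled_(emscripten::val::module_property("webnnEnableTraceEvent").as<bool>()),
        label_(label) {
    if (enabled_) {
      emscripten::val::global("console").call<void>("time", emscripten::val(label_));
    }
  }
  ~ScopedTraceEvent() {
    if (enabled_) {
      emscripten::val::global("console").call<void>("timeEnd", emscripten::val(label_));
    }
  }

 private:
  bool enabled_;
  const char* label_;
};

// Hypothetical call site:
//   ScopedTraceEvent trace("ORT::Dispatch::jsepEnsureTensor");
//   ... work being measured ...
```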
--- js/common/lib/inference-session-impl.ts | 6 ++++- js/common/lib/trace.ts | 22 +++++++++++++++++++ js/web/lib/wasm/jsep/init.ts | 2 ++ js/web/lib/wasm/wasm-core-impl.ts | 8 ++++++- js/web/lib/wasm/wasm-types.ts | 1 + .../core/providers/webnn/builders/model.cc | 8 +++++++ .../core/providers/webnn/data_transfer.cc | 11 ++++++++++ onnxruntime/wasm/pre-jsep.js | 1 + 8 files changed, 57 insertions(+), 2 deletions(-) diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 797dba8b94089..877c595bffd15 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -6,7 +6,7 @@ import { InferenceSessionHandler } from './backend.js'; import { InferenceSession as InferenceSessionInterface } from './inference-session.js'; import { OnnxValue } from './onnx-value.js'; import { Tensor } from './tensor.js'; -import { TRACE_FUNC_BEGIN, TRACE_FUNC_END } from './trace.js'; +import { TRACE_FUNC_BEGIN, TRACE_FUNC_END, TRACE_EVENT_BEGIN, TRACE_EVENT_END } from './trace.js'; type SessionOptions = InferenceSessionInterface.SessionOptions; type RunOptions = InferenceSessionInterface.RunOptions; @@ -22,6 +22,7 @@ export class InferenceSession implements InferenceSessionInterface { run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise; async run(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise { TRACE_FUNC_BEGIN(); + TRACE_EVENT_BEGIN('InferenceSession.run'); const fetches: { [name: string]: OnnxValue | null } = {}; let options: RunOptions = {}; // check inputs @@ -120,6 +121,7 @@ export class InferenceSession implements InferenceSessionInterface { } } } + TRACE_EVENT_END('InferenceSession.run'); TRACE_FUNC_END(); return returnValue; } @@ -144,6 +146,7 @@ export class InferenceSession implements InferenceSessionInterface { arg3?: SessionOptions, ): Promise { TRACE_FUNC_BEGIN(); + TRACE_EVENT_BEGIN('InferenceSession.create'); // either load from a file or buffer let filePathOrUint8Array: string | Uint8Array; let options: SessionOptions = {}; @@ -207,6 +210,7 @@ export class InferenceSession implements InferenceSessionInterface { // resolve backend, update session options with validated EPs, and create session handler const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options); const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs); + TRACE_EVENT_END('InferenceSession.create'); TRACE_FUNC_END(); return new InferenceSession(handler); } diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 25d178f15a29d..0f20dd39935ac 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -51,3 +51,25 @@ export const TRACE_FUNC_END = (extraMsg?: string) => { } TRACE_FUNC('END', extraMsg); }; + +/** + * @ignore + */ +export const TRACE_EVENT_BEGIN = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { + return; + } + // eslint-disable-next-line no-console + console.time(`ORT::${extraMsg}`); +}; + +/** + * @ignore + */ +export const TRACE_EVENT_END = (extraMsg?: string) => { + if (typeof env.trace === 'undefined' ? 
!env.wasm.trace : !env.trace) { + return; + } + // eslint-disable-next-line no-console + console.timeEnd(`ORT::${extraMsg}`); +}; diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 8ab6b054bf8a7..463e26d0208e5 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -291,6 +291,8 @@ export const init = async ( }, // jsepDownloadTensor async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer), + // jsepEnableTraceEvent + !!env.trace, ]); } }; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index f42a224ed2e85..cfdc0053b3485 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -6,7 +6,7 @@ // https://github.com/webmachinelearning/webnn/issues/677 /// -import { Env, InferenceSession, Tensor } from 'onnxruntime-common'; +import { Env, InferenceSession, Tensor, TRACE_EVENT_BEGIN, TRACE_EVENT_END } from 'onnxruntime-common'; import { SerializableInternalBuffer, @@ -711,6 +711,7 @@ export const run = async ( try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); + TRACE_EVENT_BEGIN('wasm prepareInputOutputTensor'); // create input tensors for (let i = 0; i < inputCount; i++) { await prepareInputOutputTensor( @@ -736,6 +737,7 @@ export const run = async ( enableGraphCapture, ); } + TRACE_EVENT_END('wasm prepareInputOutputTensor'); for (let i = 0; i < inputCount; i++) { wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*'); @@ -755,6 +757,7 @@ export const run = async ( ); } + TRACE_EVENT_BEGIN('wasm bindInputsOutputs'); // process inputs for (let i = 0; i < inputCount; i++) { const index = inputIndices[i]; @@ -788,6 +791,7 @@ export const run = async ( } } } + TRACE_EVENT_END('wasm bindInputsOutputs'); activeSessions.set(sessionId, [ sessionHandle, inputNamesUTF8Encoded, @@ -830,6 +834,7 @@ export const run = async ( const output: TensorMetadata[] = []; const outputPromises: Array> = []; + TRACE_EVENT_BEGIN('wasm ProcessOutputTensor'); for (let i = 0; i < outputCount; i++) { const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); if (tensor === outputTensorHandles[i]) { @@ -1028,6 +1033,7 @@ export const run = async ( for (const [index, data] of await Promise.all(outputPromises)) { output[index][2] = data; } + TRACE_EVENT_END('wasm ProcessOutputTensor'); return output; } finally { wasm.webnnOnRunEnd?.(sessionHandle); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index f2d051927b1d5..29a4028ae46cc 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -71,6 +71,7 @@ export declare namespace JSEP { ensureTensor: EnsureTensorFunction, uploadTensor: UploadTensorFunction, downloadTensor: DownloadTensorFunction, + enableTraceEvent: boolean, ], ): void; } diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 40fdfc609e6a1..ef829e82823d0 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -158,6 +158,11 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap& outputs) { auto webnnEnsureTensor = emscripten::val::module_property("webnnEnsureTensor"); auto promises = emscripten::val::array(); + bool trace = emscripten::val::module_property("webnnEnableTraceEvent").as(); + emscripten::val console = emscripten::val::global("console"); + if (trace) { + console.call("time", 
emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + } for (const auto& [_, tensor] : inputs) { emscripten::val shape = emscripten::val::array(); for (const auto& dim : tensor.tensor_info.shape) { @@ -176,6 +181,9 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", ml_tensor); } + if (trace) { + console.call("timeEnd", emscripten::val("ORT::Dispatch::jsepEnsureTensor")); + } auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); for (const auto& [name, _] : inputs) { wnn_inputs_.set(name, ml_tensors.call("shift")); diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index aa85277b72453..17369e6fbc75d 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -20,6 +20,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the data. return Status::OK(); } + bool trace = emscripten::val::module_property("webnnEnableTraceEvent").as(); + emscripten::val console = emscripten::val::global("console"); + if (trace) { + console.call("time", emscripten::val("ORT::DataTransfer::CopyTensor")); + } size_t bytes = src.SizeInBytes(); if (bytes > 0) { @@ -30,10 +35,16 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); + if (trace) { + console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnUploadTensor")); + } } else { auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor"); auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); webnnDownloadTensor(reinterpret_cast(src_data), subarray).await(); + if (trace) { + console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnDownloadTensor")); + } } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index df09920bddebd..8232a286d4480 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -104,6 +104,7 @@ Module["jsepInit"] = (name, params) => { Module["webnnEnsureTensor"], Module.webnnUploadTensor, Module["webnnDownloadTensor"], + Module["webnnEnableTraceEvent"], ] = params.slice(1); // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. From 17cffcb936febb5d9b166574b1803e77cc83182d Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 9 May 2025 16:47:01 -0700 Subject: [PATCH 45/84] Update AnyChart library to version 8.13.0 in perf_view HTML (#24684) ### Description The perf view uses outdated any chart version, which results in perf_view not showing any summary. Previously happened: https://github.com/microsoft/onnxruntime/issues/16873 --- tools/perf_view/ort_perf_view.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/perf_view/ort_perf_view.html b/tools/perf_view/ort_perf_view.html index 509fe5593f6a1..9cd65f06d9337 100644 --- a/tools/perf_view/ort_perf_view.html +++ b/tools/perf_view/ort_perf_view.html @@ -3,9 +3,9 @@ Onnxruntime Perf View - + - +