From 54534057d24eea574f492f78262e09761ccfec06 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 28 Mar 2025 15:01:38 +0100 Subject: [PATCH 01/24] first draft to migrate to newer version of transformers --- .../models/llama/convert_to_onnx.py | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 89fd613ecbbc2..f2a74a8b88388 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -5,6 +5,12 @@ # -------------------------------------------------------------------------- from __future__ import annotations +import pprint +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.torch_test_helper import replace_string_by_dynamic +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes + import argparse import logging import os @@ -16,7 +22,7 @@ import onnx import torch -from benchmark_helper import Precision, prepare_environment, setup_logger +from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs @@ -141,9 +147,37 @@ def run_dynamo_export( ) temp_dir = tempfile.TemporaryDirectory() temp_path = os.path.join(temp_dir.name, "temp.onnx") - torch.onnx.dynamo_export( - llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True) - ).save(temp_path) + + input_names = ["input_ids", "attention_mask", "position_ids"] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_model_dynamic_axes(input_names, output_names) + + model_args = (input_ids, attn_mask, pos_ids, past_kv) + model_args, model_kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes( + llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"} + ) + + with bypass_export_some_errors(patch_transformers=True): + torch.export.export( + llama, + (), + kwargs=model_kwargs, + dynamic_shapes=replace_string_by_dynamic(dynamic_shapes), + ) + torch.onnx.export( + llama, + (), + temp_path, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + verbose=args.verbose, + ) # Check decoder_with_past_model.onnx and save all external data to one file onnx.checker.check_model(temp_path) @@ -330,6 +364,7 @@ def run_torchscript_merged_export( temp_dir = f"./temp_{rank}" _prepare_dir(temp_dir) temp_path = os.path.join(temp_dir, "temp.onnx") + torch.onnx.export( llama, args=decoder_merged_inputs, @@ -338,9 +373,11 @@ def run_torchscript_merged_export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, + dynamic_shapes=dynamic_shapes, opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, + dynamo=args.dynamo, ) # Check decoder_merged_model.onnx and save all external data to one file From 31e82a98dba95762cc4cf98b92c380854cb04a94 Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 28 
Mar 2025 16:09:11 +0100 Subject: [PATCH 02/24] add patches --- .../models/llama/convert_to_onnx.py | 21 +- .../models/torch_export_patches/__init__.py | 110 ++++ .../onnx_export_errors.py | 471 ++++++++++++++++ .../onnx_export_serialization.py | 131 +++++ .../torch_export_patches/patch_inputs.py | 168 ++++++ .../torch_export_patches/patches/__init__.py | 0 .../patches/patch_torch.py | 331 ++++++++++++ .../patches/patch_transformers.py | 510 ++++++++++++++++++ 8 files changed, 1729 insertions(+), 13 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/patches/__init__.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index f2a74a8b88388..ff67696ae0f6d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -5,12 +5,6 @@ # -------------------------------------------------------------------------- from __future__ import annotations -import pprint -from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_test_helper import replace_string_by_dynamic -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors -from onnx_diagnostic.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes - import argparse import logging import os @@ -23,7 +17,7 @@ import onnx import torch from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger -from convert_generation import replace_mha_with_gqa +from onnxruntime.transformers.convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check @@ -36,6 +30,13 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer +# to patch transformers before exporting for transformers >= 4.45 +from onnxruntime.transformers.models.torch_export_patches import bypass_export_some_errors +from onnxruntime.transformers.models.torch_export_patches.patch_inputs import ( + convert_dynamic_axes_into_dynamic_shapes, +) + + torch_export_onnx_opset_version = 14 logger = logging.getLogger("") init_dist() @@ -163,12 +164,6 @@ def run_dynamo_export( ) with bypass_export_some_errors(patch_transformers=True): - torch.export.export( - llama, - (), - kwargs=model_kwargs, - dynamic_shapes=replace_string_by_dynamic(dynamic_shapes), - ) torch.onnx.export( llama, (), diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py 
b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py new file mode 100644 index 0000000000000..a8f32918ef72d --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -0,0 +1,110 @@ +from typing import Any, List, Tuple +import packaging.version as pv +import torch +import transformers +from onnxruntime.transformers.models.torch_export_patches.onnx_export_errors import ( + bypass_export_some_errors, + register_additional_serialization_functions, +) + + +def is_torchdynamo_exporting() -> bool: + "Tells if torch is exporting a model." + import torch + + if not hasattr(torch.compiler, "is_exporting"): + # torch.compiler.is_exporting requires torch>=2.7 + return False + + try: + return torch.compiler.is_exporting() + except Exception: + try: + import torch._dynamo as dynamo + + return dynamo.is_exporting() # type: ignore + except Exception: + return False + + +def string_type(anything, **args): + # too long + return str(anything) + + +if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): + + def make_dynamic_cache( + key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + ''' + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers >= 4.50``. + + :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + + Example: + + :: + + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(n_layers) + ] + ) + print(string_type(past_key_values, with_shape=True)) + ''' + return transformers.cache_utils.DynamicCache(key_value_pairs) + +else: + + def make_dynamic_cache( + key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> transformers.cache_utils.DynamicCache: + ''' + Creates an instance of :class:`transformers.cache_utils.DynamicCache`. + This version is valid for ``transformers < 4.50``. + + :param key_value_pairs: list of pairs of (key, values) + :return: :class:`transformers.cache_utils.DynamicCache` + + Example: + + :: + + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + + past_key_values = make_dynamic_cache( + [ + ( + torch.randn(bsize, nheads, slen, dim), + torch.randn(bsize, nheads, slen, dim), + ) + for i in range(n_layers) + ] + ) + print(string_type(past_key_values, with_shape=True)) + ''' + cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) + for i, (key, value) in enumerate(key_value_pairs): + cache.update(key, value, i) + return cache + + +def make_encoder_decoder_cache( + self_attention_cache: transformers.cache_utils.DynamicCache, + cross_attention_cache: transformers.cache_utils.DynamicCache, +) -> transformers.cache_utils.EncoderDecoderCache: + "Creates an EncoderDecoderCache." 
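+    # Hedged usage sketch (hypothetical shapes): both caches can be built
+    # with make_dynamic_cache, e.g.:
+    #   self_cache = make_dynamic_cache([(torch.randn(2, 4, 3, 8), torch.randn(2, 4, 3, 8))])
+    #   cross_cache = make_dynamic_cache([(torch.randn(2, 4, 6, 8), torch.randn(2, 4, 6, 8))])
+    #   cache = make_encoder_decoder_cache(self_cache, cross_cache)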
+    return transformers.cache_utils.EncoderDecoderCache(
+        self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
+    )
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
new file mode 100644
index 0000000000000..b9463303e4bf7
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
@@ -0,0 +1,471 @@
+import contextlib
+import pprint
+from typing import Any, Callable, Dict
+from .onnx_export_serialization import (
+    flatten_with_keys_dynamic_cache,
+    flatten_dynamic_cache,
+    unflatten_dynamic_cache,
+    flatten_mamba_cache,
+    flatten_with_keys_mamba_cache,
+    unflatten_mamba_cache,
+)
+from onnxruntime.transformers.models.torch_export_patches.patches import patch_transformers as patch_transformers_list
+
+
+def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[str, Callable]]:
+    """
+    Applies all patches defined in classes prefixed by ``patched_``.
+    ``cls._PATCHED_CLASS_`` defines the class to patch,
+    ``cls._PATCHES_`` defines the methods to patch.
+    The returned information needs to be passed to :func:`unpatch_module`
+    to revert the changes.
+    """
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+
+    res = {}
+    for cls in to_patch:
+        original = cls._PATCHED_CLASS_
+        methods = cls._PATCHES_
+        if verbose:
+            print(f"[patch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
+
+        keep = {n: getattr(original, n, None) for n in methods}
+        for n in methods:
+            setattr(original, n, getattr(cls, n))
+        res[cls] = keep
+
+    return res
+
+
+def unpatch_module(mod, info: Dict[type, Dict[str, Callable]], verbose: int = 0):
+    """Reverts modifications made by :func:`patch_module`."""
+    to_patch = []
+    for k in dir(mod):
+        if k.startswith("patched_"):
+            v = getattr(mod, k)
+            if hasattr(v, "_PATCHED_CLASS_") and hasattr(v, "_PATCHES_"):
+                to_patch.append(v)
+    set_patch = set(to_patch)
+
+    for cls, methods in info.items():
+        assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})"
+        if verbose:
+            print(f"[unpatch_module] {mod.__name__} - {cls.__name__}: {', '.join(methods)}")
+        original = cls._PATCHED_CLASS_
+        for n, v in methods.items():
+            if v is None:
+                # The method did not exist. We remove it.
+                delattr(original, n)
+            else:
+                setattr(original, n, v)
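+
+
+# Hedged sketch of the convention patch_module/unpatch_module rely on
+# (illustrative only, not one of the registered patches):
+#
+#     class patched_SomeClass:
+#         _PATCHED_CLASS_ = SomeClass   # class whose methods are replaced
+#         _PATCHES_ = ["forward"]       # names of the methods to replace
+#
+#         def forward(self, *args):     # replacement implementation
+#             ...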
+
+
+def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
+    # Cache serialization: to be moved into appropriate packages
+    import torch
+
+    try:
+        from transformers.cache_utils import DynamicCache
+    except ImportError:
+        DynamicCache = None
+
+    try:
+        from transformers.cache_utils import MambaCache
+    except ImportError:
+        MambaCache = None
+
+    # MambaCache
+    unregistered_mamba_cache = True
+    if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose > 1:
+            print(f"[_register_cache_serialization] {MambaCache} already registered")
+        # It is already registered because bypass_export_some_errors was called
+        # within a section already calling bypass_export_some_errors or transformers
+        # has updated its code to do it.
+        # No need to register and unregister then.
+        unregistered_mamba_cache = False
+    else:
+        if verbose:
+            print("[_register_cache_serialization] register MambaCache")
+        torch.utils._pytree.register_pytree_node(
+            MambaCache,
+            flatten_mamba_cache,
+            unflatten_mamba_cache,
+            serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_mamba_cache,
+        )
+
+    # DynamicCache
+    unregistered_dynamic_cache = True
+    if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose > 1:
+            print(f"[_register_cache_serialization] {DynamicCache} already registered")
+        unregistered_dynamic_cache = False
+    else:
+        if verbose:
+            print("[_register_cache_serialization] register DynamicCache")
+        torch.utils._pytree.register_pytree_node(
+            DynamicCache,
+            flatten_dynamic_cache,
+            unflatten_dynamic_cache,
+            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
+        )
+        torch.fx._pytree.register_pytree_flatten_spec(
+            DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
+        )
+
+    # check
+    from . import make_dynamic_cache  # defined in this package's __init__
+
+    cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
+    values, spec = torch.utils._pytree.tree_flatten(cache)
+    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+    # torch.fx._pytree.tree_flatten(cache)
+    assert len(cache2.key_cache) == 1
+
+    return dict(DynamicCache=unregistered_dynamic_cache, MambaCache=unregistered_mamba_cache)
+
+
+def _unregister(cls: type, verbose: int = 0):
+    import optree
+    import torch
+
+    # torch.fx._pytree._deregister_pytree_flatten_spec(cls)
+    if cls in torch.fx._pytree.SUPPORTED_NODES:
+        del torch.fx._pytree.SUPPORTED_NODES[cls]
+    if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
+        del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
+    if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
+        # torch >= 2.7
+        torch.utils._pytree._deregister_pytree_node(cls)
+    optree.unregister_pytree_node(cls, namespace="torch")
+    if cls in torch.utils._pytree.SUPPORTED_NODES:
+        import packaging.version as pv
+
+        if pv.Version(torch.__version__) < pv.Version("2.7.0"):
+            del torch.utils._pytree.SUPPORTED_NODES[cls]
+    assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
+        f"{cls} was not successfully unregistered "
+        f"from torch.utils._pytree.SUPPORTED_NODES="
+        f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}"
+    )
+    if verbose:
+        print(f"[_unregister_cache_serialization] unregistered {cls.__name__}")
+
+
+def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
+    if undo.get("MambaCache", False):
+        from transformers.cache_utils import MambaCache
+
+        _unregister(MambaCache, verbose)
+    elif verbose > 1:
+        print("[_unregister_cache_serialization] skip unregister MambaCache")
+
+    if undo.get("DynamicCache", False):
+        from transformers.cache_utils import DynamicCache
+
+        _unregister(DynamicCache, verbose)
+    elif verbose > 1:
+        print("[_unregister_cache_serialization] skip unregister DynamicCache")
+
+
+@contextlib.contextmanager
+def register_additional_serialization_functions(
+    patch_transformers: bool = False, verbose: int = 0
+) -> Callable:
+    """Registers the serialization functions needed to run the fx graph."""
+    fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
+    done = _register_cache_serialization(verbose=verbose)
+    try:
+        yield fct_callable
+    finally:
+        _unregister_cache_serialization(done, verbose=verbose)
+
+
+@contextlib.contextmanager
+def bypass_export_some_errors(
+    patch_sympy: bool = True,
+    patch_torch: bool = True,
+    patch_transformers: bool = False,
+    catch_constraints: bool = True,
+    stop_if_static: bool = False,
+    verbose: int = 0,
+    patch: bool = True,
+) -> Callable:
+    """
+    Tries to bypass some situations :func:`torch.export.export` does not support.
+
+    :param patch_sympy: fix missing method ``name`` for IntegerConstant
+    :param patch_torch: patches :epkg:`torch` with supported implementation
+    :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param catch_constraints: catch constraints related to dynamic shapes,
+        as a result, some dynamic dimensions may turn into static ones,
+        the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
+        can be set to stop at that stage.
+    :param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
+        to stop the export as soon as an issue is detected with dynamic shapes
+        and show a stack trace indicating the exact location of the issue
+    :param patch: if False, disable all patches except the registration of
+        serialization functions
+    :param verbose: to show which patches are applied
+
+    The list of available patches:
+
+    * ``torch.jit.isinstance``
+    * ``torch._dynamo.mark_static_address``
+    * ``torch._subclasses.fake_impls.infer_size``
+    * fix missing method ``name`` for ``sympy.S.IntegerConstant``
+    * ``AttentionMaskConverter._make_causal_mask``
+    * Serialization of ``MambaCache`` (in :epkg:`transformers`)
+    * Serialization of ``DynamicCache`` (in :epkg:`transformers`)
+    * reduce errors due to shape inference
+    * fixes some transformers classes
+
+    Serialization issues happen when a module takes an input or returns
+    an output with a type :func:`torch.export.export` cannot serialize.
+
+    Examples:
+
+    ::
+
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            onx = to_onnx(..., inputs, ...)
+
+    ::
+
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            onx = torch.onnx.export(..., inputs, ...)
+
+    It can be used as well to fix the torch export:
+
+    ::
+
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(..., inputs, ...)
+
+    When running the model through the exported program, only the
+    serialization functions need to be restored:
+
+    ::
+
+        with register_additional_serialization_functions() as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(..., inputs, ...)
+
+    When exporting a model with a cache, the following error message
+    may appear: ``AssertionError: Mutating module attribute _seen_tokens during export.``.
+    It can be avoided by setting ``strict=False`` when calling :func:`torch.export.export`.
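+
+    For example (same pattern as above, the call itself is elided)::
+
+        with bypass_export_some_errors(patch_transformers=True) as modificator:
+            inputs = modificator(inputs)
+            ep = torch.export.export(..., inputs, ..., strict=False)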
+ """ + if not patch: + fct_callable = lambda x: x # noqa: E731 + done = _register_cache_serialization(verbose=verbose) + try: + yield fct_callable + finally: + _unregister_cache_serialization(done, verbose=verbose) + else: + import torch + import torch._export.non_strict_utils # produce_guards_and_solve_constraints + import torch.jit + + if verbose: + print( + "[bypass_export_some_errors] replace torch.jit.isinstance, " + "torch._dynamo.mark_static_address" + ) + + ######## + # caches + ######## + + cache_done = _register_cache_serialization(verbose=verbose) + + ############# + # patch sympy + ############# + + if patch_sympy: + import sympy + + f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None) + + if verbose: + print("[bypass_export_some_errors] patch sympy") + + sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}" + + ############### + # patch pytorch + ############### + + if patch_torch: + from .patches.patch_torch import ( + patched_infer_size, + patched__broadcast_shapes, + _catch_produce_guards_and_solve_constraints, + patch__check_input_constraints_for_graph, + ) + + if verbose: + print("[bypass_export_some_errors] patch pytorch") + + # torch.jit.isinstance + f_jit_isinstance = torch.jit.isinstance + torch.jit.isinstance = isinstance + + # torch._dynamo.mark_static_address + f_mark_static_address = torch._dynamo.mark_static_address + torch._dynamo.mark_static_address = lambda *_, **y_: None + + # torch._subclasses.fake_impls.infer_size + f_infer_size = torch._subclasses.fake_impls.infer_size + torch._subclasses.fake_impls.infer_size = patched_infer_size + + # torch._refs._broadcast_shapes + f__broadcast_shapes = torch._refs._broadcast_shapes + torch._refs._broadcast_shapes = patched__broadcast_shapes + torch._meta_registrations._broadcast_shapes = patched__broadcast_shapes + + # torch._export.non_strict_utils.produce_guards_and_solve_constraints + if catch_constraints: + if verbose: + print("[bypass_export_some_errors] modifies shape constraints") + f_produce_guards_and_solve_constraints = ( + torch._export.non_strict_utils.produce_guards_and_solve_constraints + ) + f__check_input_constraints_for_graph = ( + torch._export.utils._check_input_constraints_for_graph + ) + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints( + f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs + ) + ) + torch._export.utils._check_input_constraints_for_graph = ( + lambda *args, **kwargs: patch__check_input_constraints_for_graph( + f__check_input_constraints_for_graph, *args, verbose=verbose, **kwargs + ) + ) + + if stop_if_static: + if verbose: + print( + "[bypass_export_some_errors] assert when a dynamic dimension turns static" + ) + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + from .patches.patch_torch import patched_ShapeEnv + + f_shape_env__set_replacement = ShapeEnv._set_replacement + ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement + + #################### + # patch transformers + #################### + + if patch_transformers: + revert_patches_info = patch_module(patch_transformers_list, verbose=verbose) + + ######## + # export + ######## + + fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) + + if verbose: + print("[bypass_export_some_errors] done patching") + + try: + yield fct_callable + finally: + ####### + # sympy + ####### + + if verbose: + 
print("[bypass_export_some_errors] remove patches") + + if patch_sympy: + + # tracked by https://github.com/pytorch/pytorch/issues/143494 + if f_sympy_name: + sympy.core.numbers.IntegerConstant.name = f_sympy_name + else: + delattr(sympy.core.numbers.IntegerConstant, "name") + + if verbose: + print("[bypass_export_some_errors] restored sympy functions") + + ####### + # torch + ####### + + if patch_torch: + # this should disappear when torch.jit is removed + torch.jit.isinstance = f_jit_isinstance + torch._dynamo.mark_static_address = f_mark_static_address + # tracked by https://github.com/pytorch/pytorch/issues/143495 + torch._subclasses.fake_impls.infer_size = f_infer_size + torch._refs._broadcast_shapes = f__broadcast_shapes + torch._meta_registrations._broadcast_shapes = f__broadcast_shapes + + if verbose: + print("[bypass_export_some_errors] restored pytorch functions") + + if stop_if_static: + if verbose: + print("[bypass_export_some_errors] restored ShapeEnv._set_replacement") + + ShapeEnv._set_replacement = f_shape_env__set_replacement + + if catch_constraints: + # to catch or skip dynamic_shapes issues + torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( + f_produce_guards_and_solve_constraints + ) + torch._export.utils._check_input_constraints_for_graph = ( + f__check_input_constraints_for_graph + ) + if verbose: + print("[bypass_export_some_errors] restored shape constraints") + + ############## + # transformers + ############## + + if patch_transformers: + unpatch_module(patch_transformers_list, revert_patches_info, verbose=verbose) + + ######## + # caches + ######## + + _unregister_cache_serialization(cache_done, verbose=verbose) + + +def replacement_before_exporting(args: Any) -> Any: + """ + Does replacements on the given inputs if needed. 
+ """ + if args is None: + return None + if isinstance(args, (int, float)): + return args + if isinstance(args, dict): + return {k: replacement_before_exporting(v) for k, v in args.items()} + if isinstance(args, tuple): + return tuple(replacement_before_exporting(v) for v in args) + if isinstance(args, list): + return [replacement_before_exporting(v) for v in args] + + return args diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py new file mode 100644 index 0000000000000..66b447fd3566e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, List, Tuple +import torch +import transformers + +############ +# MambaCache +############ + + +# self.conv_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.conv_kernel_size, +# device=device, +# dtype=dtype, +# ) +# self.ssm_states: torch.Tensor = torch.zeros( +# config.num_hidden_layers, +# self.max_batch_size, +# self.intermediate_size, +# self.ssm_state_size, +# device=device, +# dtype=dtype, +# ) +def flatten_mamba_cache( + mamba_cache: transformers.cache_utils.MambaCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + flat = [ + (k, getattr(mamba_cache, k)) + for k in [ + # "max_batch_size", # new in transformers==4.47 + # "intermediate_size", + # "ssm_state_size", + # "conv_kernel_size", + "conv_states", + "ssm_states", + ] + if hasattr(mamba_cache, k) + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def unflatten_mamba_cache( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.MambaCache: + """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" + conv_states, ssm_states = values + + class _config: + def __init__(self): + if isinstance(conv_states, list): + self.intermediate_size = conv_states[0].shape[1] + self.state_size = ssm_states[0].shape[2] + self.conv_kernel = conv_states[0].shape[2] + self.num_hidden_layers = len(conv_states) + else: + self.intermediate_size = conv_states.shape[2] + self.state_size = ssm_states.shape[3] + self.conv_kernel = conv_states.shape[3] + self.num_hidden_layers = conv_states.shape[0] + + from transformers.cache_utils import MambaCache + + cache = MambaCache( + _config(), + max_batch_size=1, + dtype=values[-1][0].dtype, + device="cpu" if values[-1][0].get_device() < 0 else "cuda", + ) + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + +def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" + import torch + + values, context = flatten_mamba_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +############## +# DynamicCache +############## + + +def flatten_dynamic_cache( + dynamic_cache: transformers.cache_utils.DynamicCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + flat = [ + (k, getattr(dynamic_cache, k)) 
+ for k in ["key_cache", "value_cache"] + if hasattr(dynamic_cache, k) + ] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + import torch + + values, context = flatten_dynamic_cache(d) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_dynamic_cache( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> transformers.cache_utils.DynamicCache: + """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" + from transformers.cache_utils import DynamicCache + + cache = DynamicCache() + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py new file mode 100644 index 0000000000000..6318cc746275a --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -0,0 +1,168 @@ +import inspect +from typing import Any, Dict, Optional, Tuple +import torch +import transformers +from onnxruntime.transformers.models.torch_export_patches import make_dynamic_cache, string_type + + +def _process_cache(k: str, v): + assert k != "position_ids" or isinstance( + k, torch.Tensor + ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}" + if ( + isinstance(v, list) + and all(isinstance(i, tuple) for i in v) + and set(len(t) for t in v) == {2} + ): + # A dynamicCache + cache = make_dynamic_cache(v) + return cache + if isinstance(v, torch.Tensor): + return v + raise NotImplementedError( + f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}" + ) + + +def _make_shape(subset: Dict, cls: type, value: Any) -> Any: + if cls is transformers.cache_utils.DynamicCache: + assert subset, "DynamicCache cannot be empty" + values = set(map(str, subset.values())) + assert len(values) == 1, ( + f"Inconsistencies in subset={subset}, found={values}, " + f"it cannot be a {cls}, value={string_type(value)}" + ) + cache_length = len(value.key_cache) + for v in subset.values(): + axes = v + break + new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]] + return new_shape + raise NotImplementedError( + f"_make_shape not implemented for cls={cls}, " + f"subset={subset}, value={string_type(value)}" + ) + + +def convert_dynamic_axes_into_dynamic_shapes( + model: torch.nn.Module, + args: Optional[Tuple[Any, ...]] = None, + kwargs: Optional[Dict[str, Any]] = None, + dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, + prefix_mapping: Optional[Dict[str, str]] = None, + verbose: int = 0, +) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]: + """ + Converts the input from an export to something :func:`torch.export.export` can handle. 
+
+
+def convert_dynamic_axes_into_dynamic_shapes(
+    model: torch.nn.Module,
+    args: Optional[Tuple[Any, ...]] = None,
+    kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
+    prefix_mapping: Optional[Dict[str, str]] = None,
+    verbose: int = 0,
+) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
+    """
+    Converts inputs and ``dynamic_axes`` given to :func:`torch.onnx.export`
+    into inputs and ``dynamic_shapes`` :func:`torch.export.export` can handle.
+
+    :param model: model to convert (used to extract the signature)
+    :param args: positional arguments
+    :param kwargs: named arguments
+    :param dynamic_axes: dynamic axes
+    :param prefix_mapping: prefix mapping
+    :param verbose: verbosity
+    :return: (args, kwargs, dynamic shapes)
+    """
+    new_kwargs = {}
+    if args:
+        assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
+        plus = 0 if isinstance(model, torch.nn.Module) else 1
+        if verbose:
+            print(
+                f"[convert_dynamic_axes_into_dynamic_shapes] "
+                f"mapping args to kwargs for model="
+                f"{model if plus else model.__class__.__name__}"
+            )
+        pars = inspect.signature(model.forward).parameters
+        assert len(pars) >= len(
+            args
+        ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"
+
+        for i, p in enumerate(pars):
+            if i < plus:
+                continue
+            if i - plus >= len(args):
+                break
+            if verbose:
+                print(
+                    f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
+                    f"to {p!r} ({string_type(args[i-plus])})"
+                )
+            new_kwargs[p] = args[i - plus]
+
+    if kwargs:
+        for k, v in kwargs.items():
+            assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
+            new_kwargs[k] = v
+
+    # process
+    updated_kwargs = {}
+    changes = {}
+    for k, v in new_kwargs.items():
+        if isinstance(v, torch.Tensor):
+            updated_kwargs[k] = v
+            continue
+        if isinstance(v, list):
+            # cache?
+            updated_kwargs[k] = _process_cache(k, v)
+            if type(updated_kwargs[k]) is not type(v):
+                # A cache was introduced.
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
+                        f"{k!r} was changed into {type(updated_kwargs[k])}"
+                    )
+                changes[k] = type(updated_kwargs[k])
+            continue
+        raise NotImplementedError(
+            f"Unexpected type {type(v)} for parameter {k!r} "
+            f"({string_type(v, with_shape=True)})"
+        )
+
+    # process dynamic axes
+    dynamic_shapes: Dict[str, Any] = {}  # initialized here so the return below always works
+    if changes:
+        done = set()
+        for k, v in dynamic_axes.items():
+            if k not in changes and k in updated_kwargs and isinstance(v, dict):
+                dynamic_shapes[k] = v
+                continue
+            if "." in k:
+                # something like present.0.key
+                prefix = k.split(".")[0]
+                if prefix in done:
+                    continue
+                args_prefix = (
+                    prefix_mapping[prefix]
+                    if prefix_mapping and prefix in prefix_mapping
+                    else prefix
+                )
+                if args_prefix in updated_kwargs and args_prefix in changes:
+                    # A cache.
+                    cls = changes[args_prefix]
+                    dynamic_shapes[args_prefix] = _make_shape(
+                        {
+                            _: __
+                            for _, __ in dynamic_axes.items()
+                            if _.startswith(f"{prefix}.")
+                        },
+                        cls,
+                        updated_kwargs[args_prefix],
+                    )
+                    done.add(prefix)
+                    continue
+            if k not in updated_kwargs:
+                # dynamic axes not in the given inputs, should we raise an exception?
+                if verbose:
+                    print(
+                        f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
+                        f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
+                    )
+                continue
+            raise NotImplementedError(
+                f"Unable to process dynamic axes {k!r}, axes={v}, "
+                f"value={string_type(updated_kwargs[k], with_shape=True)}, "
+                f"dynamic axes={dynamic_axes}, "
+                f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
+            )
+
+    return (), updated_kwargs, dynamic_shapes
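+
+
+# Hedged usage sketch (illustrative names): turn a torch.onnx.export-style
+# dynamic_axes specification into dynamic_shapes for torch.export.export:
+#
+#     args, kwargs, dynamic_shapes = convert_dynamic_axes_into_dynamic_shapes(
+#         model,
+#         args=(input_ids, attention_mask, position_ids, past_key_values),
+#         dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}},
+#         prefix_mapping={"present": "past_key_values"},
+#     )
+#     ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)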
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py
new file mode 100644
index 0000000000000..082d8f3a1f886
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py
@@ -0,0 +1,331 @@
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+import torch
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+
+def _catch_produce_guards_and_solve_constraints(
+    previous_function: Callable,
+    fake_mode: "FakeTensorMode",  # noqa: F821
+    gm: "torch.fx.GraphModule",  # noqa: F821
+    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
+    equalities_inputs: "EqualityConstraint",  # noqa: F821
+    original_signature: inspect.Signature,
+    _is_torch_jit_trace: bool = False,
+    verbose: int = 0,
+):
+    try:
+        return previous_function(
+            fake_mode=fake_mode,
+            gm=gm,
+            dynamic_shapes=dynamic_shapes,
+            equalities_inputs=equalities_inputs,
+            original_signature=original_signature,
+            _is_torch_jit_trace=_is_torch_jit_trace,
+        )
+    except Exception as e:
+        if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
+            raise
+        if verbose:
+            print(
+                f"[_catch_produce_guards_and_solve_constraints] ERROR: "
+                f"produce_guards_and_solve_constraints failed, "
+                f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
+                f"fake_mode={fake_mode}\n"
+                f"dynamic_shapes={dynamic_shapes}\n"
+                f"equalities_inputs={equalities_inputs}\n"
+                f"original_signature={original_signature}\n"
+                f"_is_torch_jit_trace={_is_torch_jit_trace}\n"
+                f"exc={e}\ngm={gm}"
+            )
+
+
+def patch__check_input_constraints_for_graph(
+    previous_function: Callable,
+    input_placeholders: list[torch.fx.Node],
+    flat_args_with_path,
+    range_constraints,
+    verbose: int = 0,
+) -> None:
+    try:
+        return previous_function(input_placeholders, flat_args_with_path, range_constraints)
+    except Exception as e:
+        if not int(os.environ.get("SKIP_SOLVE_CONSTRAINTS", "1")):
+            raise
+        if verbose:
+            print(
+                f"[_check_input_constraints_for_graph] ERROR: "
+                f"_check_input_constraints_for_graph failed, "
+                f"use SKIP_SOLVE_CONSTRAINTS=0 to avoid skipping\n"
+                f"input_placeholders={input_placeholders}\n"
+                f"range_constraints={range_constraints}\n"
+                f"exc={e}"
+            )
+
+
+def patched_infer_size(a, b):
+    """Patches ``torch._subclasses.fake_impls.infer_size``."""
+    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+    dimsA = len(a)
+    dimsB = len(b)
+    ndim = max(dimsA, dimsB)
+    expandedSizes = [0] * ndim
+    for i in range(ndim - 1, -1, -1):
+        offset = ndim - 1 - i
+        dimA = dimsA - 1 - offset
+        dimB = dimsB - 1 - offset
+        sizeA = a[dimA] if dimA >= 0 else 1
+        sizeB = b[dimB] if dimB >= 0 else 1
+
+        # NB: It is very important to test for
broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + try: + b1 = guard_size_oblivious(sizeA == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b1 = False + try: + b2 = guard_size_oblivious(sizeB == 1) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b2 = False + try: + b3 = guard_size_oblivious(sizeA == sizeB) + except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: + b3 = False + if b1 or b2 or b3: + expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA + else: + # In this case, the current implementation of torch fails (17/12/2024). + # Try model SmolLM. + expandedSizes[i] = torch.sym_max(sizeA, sizeB) + return tuple(expandedSizes) + + +def patched__broadcast_shapes(*_shapes): + """Patches ``torch._refs._broadcast_shapes``.""" + from functools import reduce + from torch._prims_common import IntLike + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape = [ # List[Union[int, torch.SymInt]] + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for _arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" + ) + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) + + return common_shape + + +class patched_ShapeEnv: + + def _set_replacement( + self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str # noqa: F821 + ) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. 
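+
+        For example (hypothetical symbols), once the guard ``Eq(s0, 2*s1)``
+        is solved, the replacement ``s0 -> 2*s1`` is recorded so that later
+        guards are expressed with a single symbol.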
+ """ + if tgt == self.replacements.get(a, None): + return + + if a in tgt.free_symbols: + return + + import sympy + from torch._logging import structured + from torch.utils._traceback import CapturedTraceback + from torch._logging import trace_structured + from torch._guards import TracingContext + from torch.utils._sympy.functions import FloorToInt, CeilToInt + from torch.utils._sympy.solve import try_solve + from torch.fx.experimental.symbolic_shapes import ( + _is_supported_equivalence, + ValueRanges, + ) + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if self.allow_complex_guards_as_runtime_asserts and not _is_supported_equivalence(tgt): + # continuing leads to placeholder shapes + # having complex expressions that we can't resolve + return + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug( + "set_replacement: solve for %s in %s == %s gives %s", + b, + a, + tgt, + r, + ) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges( + CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper) + ) + self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset( + src_bound + ), f"{tgt_bound=} not a subset of {src_bound=}" + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. 
+ # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s]", + a, + tgt, + msg, + tgt_bound, + src_bound, + ) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug( + "skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", + a, + tgt, + msg, + tgt_bound_so, + src_bound_so, + ) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + "user_stack": (structured.from_traceback(user_tb) if user_tb else None), + }, + ) + + # if config.print_specializations: + # self.log.warning( + # "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt + # ) + # self.log.debug("SPECIALIZATION", stack_info=True) + assert msg != "range_refined_to_singleton", ( + f"A dynamic dimension becomes static! " + f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" + ) + # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + # NB: the replacement may get refined, but the user will find the + # FIRST one most useful (TODO: Maybe we could consider tracking all of + # them) + if a not in self.replacements_slocs: + self.replacements_slocs[a] = self._get_sloc() + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. 
+ self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py new file mode 100644 index 0000000000000..76bd62561d55f --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -0,0 +1,510 @@ +import inspect +import sys +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple +import torch +import transformers +import transformers.modeling_attn_mask_utils +from transformers.cache_utils import StaticCache, Cache, DynamicCache + + +def _patch_make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """Patched method.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [ + torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), + mask, + ], + dim=-1, + ) + + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window - 1 + + context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) + # In this case, the current implementation of torch fails (17/12/2024). + # Try model Phi-3.5-Mini-Instruct. + mask = mask.masked_fill(context_mask, torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +if sys.version_info[:2] <= (3, 11): + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + _PATCHES_ = ["_make_causal_mask"] + _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """Patched method.""" + return _patch_make_causal_mask( + input_ids_shape, dtype, device, past_key_values_length, sliding_window + ) + +else: + + @dataclass + class patched_AttentionMaskConverter: + """ + Patches + ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``. + """ + + _PATCHES_ = ["_make_causal_mask"] + _PATCHED_CLASS_ = transformers.modeling_attn_mask_utils.AttentionMaskConverter + + @staticmethod + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """Patched method.""" + return _patch_make_causal_mask( + input_ids_shape, dtype, device, past_key_values_length, sliding_window + ) + + +class patched_DynamicCache: + """ + Applies modifications implemented in PR + `transformers/#36652 `_. + """ + + _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"] + _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
+ A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + is_empty_layer = ( + len(self.key_cache) == 0 # no cache in any layer + or len(self.key_cache) + <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or self.key_cache[layer_idx].numel() == 0 # the layer has no cache + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` + and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if key_states is not None: + if len(self.key_cache) <= layer_idx: + # There may be skipped layers, fill them with empty lists + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif not self.key_cache[ + layer_idx + ].numel(): # prefers not t.numel() to len(t) == 0 to export the model + # fills previously skipped layers; checking for tensor causes errors + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def crop(self, max_length: int): + """Crop the past key values up to a new `max_length` + in terms of tokens. `max_length` can also be + negative to remove `max_length` tokens. + This is used in assisted decoding and contrastive search. 
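+
+        Example (hypothetical cache)::
+
+            cache.crop(-1)  # drops the last generated token from the cache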
+ """ + # In case it is negative + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + @classmethod + def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache: + """This is the opposite of the above `batch_split()` method. + This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + key_cache = [ + current.key_cache[idx] for current in splits if current.key_cache[idx].numel() + ] + value_cache = [ + current.value_cache[idx] + for current in splits + if current.value_cache[idx].numel() + ] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + +class patched_GenerationMixin: + """ + Applies modifications implemented in PR + `transformers/#36652 `_. + """ + + _PATCHES_ = [ + "_cache_dependant_input_preparation", + "_cache_dependant_input_preparation_exporting", + "prepare_inputs_for_generation", + ] + _PATCHED_CLASS_ = transformers.generation.utils.GenerationMixin + + def _cache_dependant_input_preparation( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """ + Generic cache-dependent input preparation + The code is put in a separate function to allow granular unit testing + as it needs a different implementation to be exportable. + + If we have cache: let's slice `input_ids` through `cache_position`, + to keep only the unprocessed tokens + - Exception 1: when passing input_embeds, + input_ids may be missing entries + - Exception 2: some generation methods do special slicing of input_ids, + so we don't need to do it here + - Exception 3: with synced GPUs cache_position may go out of bounds, + but we only want dummy token in that case. + - Exception 4: If input_embeds are passed then slice it through + `cache_position`, to keep only the unprocessed tokens and + generate the first token for each sequence. + Later use the generated Input ids for continuation. + + The current implementation does not rely on ``self`` and could be + a class method. It is left as a standard method to be easily rewritten. 
+ """ + return self._cache_dependant_input_preparation_exporting( + input_ids, inputs_embeds, cache_position + ) + """ + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None or ( # Exception 1 + cache_position[-1] >= input_ids.shape[1] + ): # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + return inputs_embeds, input_ids + """ + + def _cache_dependant_input_preparation_exporting( + self, + input_ids: torch.LongTensor, + inputs_embeds: Optional[torch.FloatTensor], + cache_position: Optional[torch.LongTensor], + ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """ + This method implements method ``_cache_dependant_input_preparation`` + with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. + The code is put in a separate function to allow granular unit testing. + """ + if inputs_embeds is None: + input_ids = input_ids[:, cache_position] + else: + # This is the code we need to implemented with torch.cond. + # if input_ids.shape[1] == 0: + # inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + # else: + # if cache_position[-1] >= input_ids.shape[1]: + # input_ids = input_ids[:, -cache_position.shape[0] :] + # else: + # if input_ids.shape[1] != cache_position.shape[0]: + # input_ids = input_ids[:, cache_position] + def branch_1(inputs_embeds, cache_position): + return inputs_embeds[:, -cache_position.shape[0] :] + + def branch_2(input_ids, cache_position): + return input_ids[:, -cache_position.shape[0] :] + + def branch_3(input_ids, cache_position): + return input_ids[:, cache_position] + + inputs_embeds, input_ids = torch.cond( + input_ids.shape[1] == 0, + ( + lambda input_ids, inputs_embeds, cache_position: ( + branch_1(inputs_embeds, cache_position), + input_ids, + ) + ), + ( + lambda input_ids, inputs_embeds, cache_position: ( + inputs_embeds, + torch.cond( + cache_position[-1] >= input_ids.shape[1], + branch_2, + lambda input_ids, cache_position: ( + torch.cond( + input_ids.shape[1] != cache_position.shape[0], + branch_3, + (lambda input_ids, cache_position: input_ids), + [input_ids, cache_position], + ) + ), + [input_ids, cache_position], + ), + ) + ), + [input_ids, inputs_embeds, cache_position], + ) + return inputs_embeds, input_ids + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + """ + Prepare the model inputs for generation. + In includes operations like computing the 4D attention mask or + slicing inputs given the existing cache. + + See the forward pass in the model documentation + for expected arguments (different models might have different + requirements for e.g. `past_key_values`). + This function should work as is for most LLMs. + """ + + # 1. 
Handle BC: + model_inputs = {} + # - some models don't have `Cache` support + # (which implies they don't expect `cache_position` in `forward`) + if self._supports_cache_class: + model_inputs["cache_position"] = cache_position + # - `cache_position` was not a mandatory input in + # `prepare_inputs_for_generation` for those models, and this + # function may be called outside of `generate`. + # Handle most use cases by creating `cache_position` on the fly + # (this alternative is not as robust as calling + # `generate` and letting it create `cache_position`) + elif cache_position is None: + past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + cache_position = torch.arange( + past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device + ) + + # 2. Generic cache-dependent input preparation + if past_key_values is not None: + model_inputs["past_key_values"] = past_key_values + inputs_embeds, input_ids = self._cache_dependant_input_preparation( + input_ids, inputs_embeds, cache_position + ) + + # 3. Prepare base model inputs + input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + # if `inputs_embeds` are passed, we only want + # to use them in the 1st generation step for every prompt. + if not self.config.is_encoder_decoder: + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs[input_ids_key] = None + model_inputs["inputs_embeds"] = inputs_embeds + else: + # `clone` calls in this function ensure a consistent stride. See #32227 + model_inputs[input_ids_key] = input_ids.clone( + memory_format=torch.contiguous_format + ) + model_inputs["inputs_embeds"] = None + else: + model_inputs[input_ids_key] = input_ids.clone( + memory_format=torch.contiguous_format + ) + + # 4. Create missing `position_ids` on the fly + encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None + attention_mask = ( + kwargs.pop("decoder_attention_mask", None) + if self.config.is_encoder_decoder + else attention_mask + ) + attention_mask_key = ( + "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" + ) + position_ids_key = ( + "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" + ) + if ( + attention_mask is not None + and kwargs.get(position_ids_key) is None + and position_ids_key in set(inspect.signature(self.forward).parameters.keys()) + ): + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + kwargs[position_ids_key] = ( + position_ids # placed in kwargs for further processing (see below) + ) + + # 5. Slice model inputs if it's an input + # that should have the same length as `input_ids` + for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]: + model_input = kwargs.get(model_input_name) + if model_input is not None: + if past_key_values is not None: + current_input_length = ( + model_inputs["inputs_embeds"].shape[1] + if model_inputs.get("inputs_embeds") is not None + else model_inputs[input_ids_key].shape[1] + ) + model_input = model_input[:, -current_input_length:] + model_input = model_input.clone(memory_format=torch.contiguous_format) + model_inputs[model_input_name] = model_input + + # 6. 
Create 4D attention mask if we are using a
+        # `StaticCache` (important for performant compiled forward pass)
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
+            else:
+                batch_size, sequence_length = model_inputs[input_ids_key].shape
+                device = model_inputs[input_ids_key].device
+
+            # Create the causal mask with fixed shape in advance,
+            # to reduce recompilations. If the function to create
+            # the 4D causal mask exists,
+            # it should be present in the base model (XXXModel class).
+            base_model = getattr(self, self.base_model_prefix, None)
+            if base_model is None:
+                causal_mask_creation_function = getattr(
+                    self, "_prepare_4d_causal_attention_mask_with_cache_position", None
+                )
+            else:
+                causal_mask_creation_function = getattr(
+                    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+                )
+            if causal_mask_creation_function is None:
+                pass
+                # logger.warning_once(
+                #     f"{self.__class__.__name__} has no "
+                #     "`_prepare_4d_causal_attention_mask_with_cache_position` method "
+                #     "defined in its base modeling class. "
+                #     "Compiled forward passes will be sub-optimal. If you're "
+                #     "writing code, see Llama for an example implementation. "
+                #     "If you're a user, please report this "
+                #     "issue on GitHub."
+                # )
+            else:
+                attention_mask = causal_mask_creation_function(
+                    attention_mask,
+                    sequence_length=sequence_length,
+                    target_length=past_key_values.get_max_cache_shape(),
+                    dtype=self.dtype,
+                    device=device,
+                    cache_position=cache_position,
+                    batch_size=batch_size,
+                    config=self.config,
+                    past_key_values=past_key_values,
+                )
+        if attention_mask is not None:
+            model_inputs[attention_mask_key] = attention_mask
+
+        if encoder_attention_mask is not None:
+            model_inputs["attention_mask"] = encoder_attention_mask
+
+        # 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
+        for key, value in kwargs.items():
+            if key not in model_inputs:
+                model_inputs[key] = value
+
+        # 8.
Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
+        model_inputs.pop("labels", None)
+        return model_inputs

From cdec2d0dcbe863fce1c5cc1416864a3b77eae522 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 31 Mar 2025 15:22:12 +0200
Subject: [PATCH 03/24] fix import

---
 .../models/llama/convert_to_onnx.py              | 13 ++++++-------
 .../torch_export_patches/onnx_export_errors.py   | 17 ++++++++++++++++-
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index ff67696ae0f6d..bf4ccb157b90d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -16,8 +16,12 @@
 import onnx
 import torch

-from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger
-from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
+# to patch transformers before exporting for transformers >= 4.45
+from models.torch_export_patches import bypass_export_some_errors
+from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes
+
+from benchmark_helper import Precision, prepare_environment, setup_logger
+from convert_generation import replace_mha_with_gqa
 from dist_settings import barrier, get_rank, get_size, init_dist
 from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
 from llama_parity import main as parity_check
@@ -30,11 +34,6 @@
 from onnxruntime import quantization as ort_quantization
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

-# to patch transformers before exporting for transformers >= 4.45
-from onnxruntime.transformers.models.torch_export_patches import bypass_export_some_errors
-from onnxruntime.transformers.models.torch_export_patches.patch_inputs import (
-    convert_dynamic_axes_into_dynamic_shapes,
-)

 torch_export_onnx_opset_version = 14

diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
index b9463303e4bf7..19a7d55d67411 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
@@ -68,6 +68,8 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # Cache serialization: to be moved into appropriate packages
     import torch
+    import transformers
+    import packaging.version as pv

     try:
         from transformers.cache_utils import DynamicCache
@@ -100,7 +102,20 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
             flatten_with_keys_fn=flatten_with_keys_mamba_cache,
         )

-    # DynamicCache
+    # DynamicCache serialization is different in transformers and does not
+    # play well with torch.export.export.
+ # This is caused by this line: + # torch.fx._pytree.register_pytree_flatten_spec( + # DynamicCache, _flatten_dynamic_cache_for_fx) + # so we remove it anyway + if ( + DynamicCache in torch.fx._pytree.SUPPORTED_NODES + and pv.Version(transformers.__version__) >= pv.Version("2.7") + ): + if verbose: + print("[_register_cache_serialization] DynamicCache is unregistered first.") + _unregister(DynamicCache) + unregistered_dynamic_cache = True if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: if verbose > 1: From 827d3bd714253ee23672c153c791e04aa10ea047 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 16:34:25 +0200 Subject: [PATCH 04/24] fix build and import --- cmake/onnxruntime_python.cmake | 14 ++++ .../torch_export_patches/cache_helper.py | 75 +++++++++++++++++++ .../onnx_export_errors.py | 2 +- 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index ca65c02a40c3b..2743c1e522c9f 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -471,6 +471,12 @@ file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py" ) +file(GLOB onnxruntime_python_transformers_models_torch_export_patches_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/*.py" +) +file(GLOB onnxruntime_python_transformers_models_torch_export_patches_patches_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/torch_export_patches/patches/*.py" +) file(GLOB onnxruntime_python_transformers_models_bart_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/bart/*.py" ) @@ -566,6 +572,8 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/sam2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/torch_export_patches + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/torch_export_patches/patches COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators @@ -682,6 +690,12 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_t5_src} $/onnxruntime/transformers/models/t5/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_torch_export_patches_src} + $/onnxruntime/transformers/models/torch_export_patches/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_torch_export_patches_patches_src} + $/onnxruntime/transformers/models/torch_export_patches/patches/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_whisper_src} $/onnxruntime/transformers/models/whisper/ diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py new file mode 100644 index 0000000000000..2a484206b1714 --- /dev/null +++ 
b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py
@@ -0,0 +1,75 @@
+from typing import List, Tuple
+import packaging.version as pv
+import torch
+import transformers
+import transformers.cache_utils
+
+
+def is_cache_dynamic_registered(fast: bool = False) -> bool:
+    """
+    Tells whether :class:`transformers.cache_utils.DynamicCache` can be
+    serialized and deserialized. Only then can :func:`torch.export.export`
+    export a model.
+
+    :param fast: if True, checks only the registration, not the serialization round-trip
+    :return: result of the check
+    """
+    if fast:
+        return transformers.cache_utils.DynamicCache in torch.utils._pytree.SUPPORTED_NODES
+    bsize, nheads, slen, dim = 2, 4, 3, 7
+    cache = make_dynamic_cache(
+        [
+            (
+                torch.randn(bsize, nheads, slen, dim),
+                torch.randn(bsize, nheads, slen, dim),
+            )
+            for i in range(2)
+        ]
+    )
+    values, spec = torch.utils._pytree.tree_flatten(cache)
+    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+    return len(cache2.key_cache) == len(cache.value_cache)
+
+
+if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
+
+    def make_dynamic_cache(
+        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> transformers.cache_utils.DynamicCache:
+        """
+        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
+        This version is valid for ``transformers >= 4.50``.
+
+        :param key_value_pairs: list of (key, value) pairs
+        :return: :class:`transformers.cache_utils.DynamicCache`
+        """
+        return transformers.cache_utils.DynamicCache(key_value_pairs)
+
+else:
+
+    def make_dynamic_cache(
+        key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> transformers.cache_utils.DynamicCache:
+        """
+        Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
+        This version is valid for ``transformers < 4.50``.
+
+        :param key_value_pairs: list of (key, value) pairs
+        :return: :class:`transformers.cache_utils.DynamicCache`
+        """
+        cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i)
+        return cache
+
+
+def make_encoder_decoder_cache(
+    self_attention_cache: transformers.cache_utils.DynamicCache,
+    cross_attention_cache: transformers.cache_utils.DynamicCache,
+) -> transformers.cache_utils.EncoderDecoderCache:
+    """
+    Creates an EncoderDecoderCache.
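+
+    :param self_attention_cache: cache for the decoder self-attention blocks
+    :param cross_attention_cache: cache for the encoder/decoder cross-attention blocks
+    :return: :class:`transformers.cache_utils.EncoderDecoderCache`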
+ """ + return transformers.cache_utils.EncoderDecoderCache( + self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache + ) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index 19a7d55d67411..164630c0d9c21 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -136,7 +136,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: ) # check - from ..cache_helpers import make_dynamic_cache + from .cache_helper import make_dynamic_cache cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) values, spec = torch.utils._pytree.tree_flatten(cache) From 18b649e986dbbcda9dc9312ba182b86b9620374e Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 16:36:43 +0200 Subject: [PATCH 05/24] build --- .../python/tools/transformers/models/llama/convert_to_onnx.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index bf4ccb157b90d..567989012c62a 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -34,8 +34,6 @@ from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer - - torch_export_onnx_opset_version = 14 logger = logging.getLogger("") init_dist() From 0e77ed4d3f3e8cdaebf9899fac12453c3407be3a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 18:17:30 +0200 Subject: [PATCH 06/24] fix lint --- .../torch_export_patches/cache_helper.py | 5 +- .../onnx_export_errors.py | 61 +++++------- .../onnx_export_serialization.py | 28 +++--- .../torch_export_patches/patch_inputs.py | 60 ++++-------- .../patches/patch_torch.py | 52 +++++----- .../patches/patch_transformers.py | 98 ++++++------------- 6 files changed, 114 insertions(+), 190 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py index 2a484206b1714..0cbe1e58a9e02 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/cache_helper.py @@ -1,4 +1,3 @@ -from typing import List, Tuple import packaging.version as pv import torch import transformers @@ -34,7 +33,7 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool: if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -48,7 +47,7 @@ def make_dynamic_cache( else: def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. 
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index 164630c0d9c21..8f4fdb155d003 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -1,18 +1,21 @@ import contextlib import pprint -from typing import Any, Callable, Dict +from collections.abs import Callable +from typing import Any + from .onnx_export_serialization import ( - flatten_with_keys_dynamic_cache, flatten_dynamic_cache, - unflatten_dynamic_cache, flatten_mamba_cache, + flatten_with_keys_dynamic_cache, flatten_with_keys_mamba_cache, + unflatten_dynamic_cache, unflatten_mamba_cache, ) -from onnxruntime.transformers.models.torch_export_patches.patches import patch_transformers as patch_transformers_list +from .patches import patch_transformers as patch_transformers_list -def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: + +def patch_module(mod, verbose: int = 0) -> dict[type, dict[type, Callable]]: """ Applies all patches defined in classes prefixed by ``patched_`` ``cls._PATCHED_CLASS_`` defines the class to patch, @@ -42,7 +45,7 @@ def patch_module(mod, verbose: int = 0) -> Dict[type, Dict[type, Callable]]: return res -def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0): +def unpatch_module(mod, info: dict[type, dict[type, Callable]], verbose: int = 0): """Reverts modification made by :func:`patch_module`.""" to_patch = [] for k in dir(mod): @@ -65,11 +68,11 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0 setattr(original, n, v) -def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: +def _register_cache_serialization(verbose: int = 0) -> dict[str, bool]: # Cache serialization: to be moved into appropriate packages + import packaging.version as pv import torch import transformers - import packaging.version as pv try: from transformers.cache_utils import DynamicCache @@ -108,10 +111,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # torch.fx._pytree.register_pytree_flatten_spec( # DynamicCache, _flatten_dynamic_cache_for_fx) # so we remove it anyway - if ( - DynamicCache in torch.fx._pytree.SUPPORTED_NODES - and pv.Version(transformers.__version__) >= pv.Version("2.7") - ): + if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(transformers.__version__) >= pv.Version("2.7"): if verbose: print("[_register_cache_serialization] DynamicCache is unregistered first.") _unregister(DynamicCache) @@ -131,9 +131,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", flatten_with_keys_fn=flatten_with_keys_dynamic_cache, ) - torch.fx._pytree.register_pytree_flatten_spec( - DynamicCache, lambda x, _: [x.key_cache, x.value_cache] - ) + torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, lambda x, _: [x.key_cache, x.value_cache]) # check from .cache_helper import make_dynamic_cache @@ -174,7 +172,7 @@ def _unregister(cls: type, verbose: int = 0): print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") -def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): +def _unregister_cache_serialization(undo: dict[str, bool], verbose: int = 0): if undo.get("MambaCache", False): 
from transformers.cache_utils import MambaCache @@ -192,9 +190,7 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): @contextlib.contextmanager -def register_additional_serialization_functions( - patch_transformers: bool = False, verbose: int = 0 -) -> Callable: +def register_additional_serialization_functions(patch_transformers: bool = False, verbose: int = 0) -> Callable: """The necessary modifications to run the fx Graph.""" fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) done = _register_cache_serialization(verbose=verbose) @@ -294,10 +290,7 @@ def bypass_export_some_errors( import torch.jit if verbose: - print( - "[bypass_export_some_errors] replace torch.jit.isinstance, " - "torch._dynamo.mark_static_address" - ) + print("[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address") ######## # caches @@ -317,7 +310,7 @@ def bypass_export_some_errors( if verbose: print("[bypass_export_some_errors] patch sympy") - sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}" + sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{self!s}" ############### # patch pytorch @@ -325,10 +318,10 @@ def bypass_export_some_errors( if patch_torch: from .patches.patch_torch import ( - patched_infer_size, - patched__broadcast_shapes, _catch_produce_guards_and_solve_constraints, patch__check_input_constraints_for_graph, + patched__broadcast_shapes, + patched_infer_size, ) if verbose: @@ -355,12 +348,8 @@ def bypass_export_some_errors( if catch_constraints: if verbose: print("[bypass_export_some_errors] modifies shape constraints") - f_produce_guards_and_solve_constraints = ( - torch._export.non_strict_utils.produce_guards_and_solve_constraints - ) - f__check_input_constraints_for_graph = ( - torch._export.utils._check_input_constraints_for_graph - ) + f_produce_guards_and_solve_constraints = torch._export.non_strict_utils.produce_guards_and_solve_constraints + f__check_input_constraints_for_graph = torch._export.utils._check_input_constraints_for_graph torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( lambda *args, **kwargs: _catch_produce_guards_and_solve_constraints( f_produce_guards_and_solve_constraints, *args, verbose=verbose, **kwargs @@ -374,9 +363,7 @@ def bypass_export_some_errors( if stop_if_static: if verbose: - print( - "[bypass_export_some_errors] assert when a dynamic dimension turns static" - ) + print("[bypass_export_some_errors] assert when a dynamic dimension turns static") from torch.fx.experimental.symbolic_shapes import ShapeEnv from .patches.patch_torch import patched_ShapeEnv @@ -448,9 +435,7 @@ def bypass_export_some_errors( torch._export.non_strict_utils.produce_guards_and_solve_constraints = ( f_produce_guards_and_solve_constraints ) - torch._export.utils._check_input_constraints_for_graph = ( - f__check_input_constraints_for_graph - ) + torch._export.utils._check_input_constraints_for_graph = f__check_input_constraints_for_graph if verbose: print("[bypass_export_some_errors] restored shape constraints") diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py index 66b447fd3566e..b21425e4713dc 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py +++ 
b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any import torch import transformers @@ -25,7 +25,7 @@ # ) def flatten_mamba_cache( mamba_cache: transformers.cache_utils.MambaCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: +) -> tuple[list[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" flat = [ (k, getattr(mamba_cache, k)) @@ -43,7 +43,7 @@ def flatten_mamba_cache( def unflatten_mamba_cache( - values: List[Any], + values: list[Any], context: torch.utils._pytree.Context, output_type=None, ) -> transformers.cache_utils.MambaCache: @@ -71,21 +71,23 @@ def __init__(self): dtype=values[-1][0].dtype, device="cpu" if values[-1][0].get_device() < 0 else "cuda", ) - values = dict(zip(context, values)) + values = dict(zip(context, values, strict=False)) for k, v in values.items(): setattr(cache, k, v) return cache -def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], +def flatten_with_keys_mamba_cache( + d: dict[Any, Any], +) -> tuple[ + list[tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context, ]: """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" import torch values, context = flatten_mamba_cache(d) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context ############## @@ -95,7 +97,7 @@ def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ def flatten_dynamic_cache( dynamic_cache: transformers.cache_utils.DynamicCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: +) -> tuple[list[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" flat = [ (k, getattr(dynamic_cache, k)) @@ -105,19 +107,21 @@ def flatten_dynamic_cache( return [f[1] for f in flat], [f[0] for f in flat] -def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], +def flatten_with_keys_dynamic_cache( + d: dict[Any, Any], +) -> tuple[ + list[tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context, ]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" import torch values, context = flatten_dynamic_cache(d) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values, strict=False)], context def unflatten_dynamic_cache( - values: List[Any], + values: list[Any], context: torch.utils._pytree.Context, output_type=None, ) -> transformers.cache_utils.DynamicCache: diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py index 6318cc746275a..563653d422b34 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -1,30 +1,24 @@ import inspect -from typing import Any, Dict, Optional, Tuple +from typing import Any import torch import transformers from onnxruntime.transformers.models.torch_export_patches import make_dynamic_cache, string_type 
def _process_cache(k: str, v): - assert k != "position_ids" or isinstance( - k, torch.Tensor - ), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}" - if ( - isinstance(v, list) - and all(isinstance(i, tuple) for i in v) - and set(len(t) for t in v) == {2} - ): + assert k != "position_ids" or isinstance(k, torch.Tensor), ( + f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}" + ) + if isinstance(v, list) and all(isinstance(i, tuple) for i in v) and set(len(t) for t in v) == {2}: # A dynamicCache cache = make_dynamic_cache(v) return cache if isinstance(v, torch.Tensor): return v - raise NotImplementedError( - f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}" - ) + raise NotImplementedError(f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}") -def _make_shape(subset: Dict, cls: type, value: Any) -> Any: +def _make_shape(subset: dict, cls: type, value: Any) -> Any: if cls is transformers.cache_utils.DynamicCache: assert subset, "DynamicCache cannot be empty" values = set(map(str, subset.values())) @@ -38,20 +32,17 @@ def _make_shape(subset: Dict, cls: type, value: Any) -> Any: break new_shape = [[axes for i in range(cache_length)], [axes for i in range(cache_length)]] return new_shape - raise NotImplementedError( - f"_make_shape not implemented for cls={cls}, " - f"subset={subset}, value={string_type(value)}" - ) + raise NotImplementedError(f"_make_shape not implemented for cls={cls}, subset={subset}, value={string_type(value)}") def convert_dynamic_axes_into_dynamic_shapes( model: torch.nn.Module, - args: Optional[Tuple[Any, ...]] = None, - kwargs: Optional[Dict[str, Any]] = None, - dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, - prefix_mapping: Optional[Dict[str, str]] = None, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + dynamic_axes: dict[str, dict[int, str]] | None = None, + prefix_mapping: dict[str, str] | None = None, verbose: int = 0, -) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]: +) -> tuple[tuple[Any, ...], dict[str, Any], dict[str, Any]]: """ Converts the input from an export to something :func:`torch.export.export` can handle. 
@@ -73,9 +64,7 @@ def convert_dynamic_axes_into_dynamic_shapes( f"{model if plus else model.__class__.__name__}" ) pars = inspect.signature(model.forward).parameters - assert len(pars) >= len( - args - ), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}" + assert len(pars) >= len(args), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}" for i, p in enumerate(pars): if i < plus: @@ -84,8 +73,8 @@ def convert_dynamic_axes_into_dynamic_shapes( break if verbose: print( - f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] " - f"to {p!r} ({string_type(args[i-plus])})" + f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i - plus}] " + f"to {p!r} ({string_type(args[i - plus])})" ) new_kwargs[p] = args[i - plus] @@ -113,10 +102,7 @@ def convert_dynamic_axes_into_dynamic_shapes( ) changes[k] = type(updated_kwargs[k]) continue - raise NotImplementedError( - f"Unexpected type {type(v)} for parameter {k!r} " - f"({string_type(v, with_shape=True)})" - ) + raise NotImplementedError(f"Unexpected type {type(v)} for parameter {k!r} ({string_type(v, with_shape=True)})") # process dynamic axes if changes: @@ -131,20 +117,12 @@ def convert_dynamic_axes_into_dynamic_shapes( prefix = k.split(".")[0] if prefix in done: continue - args_prefix = ( - prefix_mapping[prefix] - if prefix_mapping and prefix in prefix_mapping - else prefix - ) + args_prefix = prefix_mapping[prefix] if prefix_mapping and prefix in prefix_mapping else prefix if args_prefix in updated_kwargs and args_prefix in changes: # A cache. cls = changes[args_prefix] dynamic_shapes[args_prefix] = _make_shape( - { - _: __ - for _, __ in dynamic_axes.items() - if _.startswith(f"{prefix}.") - }, + {_: __ for _, __ in dynamic_axes.items() if _.startswith(f"{prefix}.")}, cls, updated_kwargs[args_prefix], ) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py index 082d8f3a1f886..5eb3fee71b450 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py @@ -1,16 +1,17 @@ import inspect import os -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from collections.abc import Callable, Sequence +from typing import Any import torch from torch._subclasses.fake_tensor import FakeTensorMode def _catch_produce_guards_and_solve_constraints( previous_function: Callable, - fake_mode: "FakeTensorMode", # noqa: F821 - gm: "torch.fx.GraphModule", # noqa: F821 + fake_mode: "FakeTensorMode", + gm: "torch.fx.GraphModule", dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], - equalities_inputs: "EqualityConstraint", # noqa: F821 + equalities_inputs: "EqualityConstraint", original_signature: inspect.Signature, _is_torch_jit_trace: bool = False, verbose: int = 0, @@ -114,12 +115,11 @@ def patched_infer_size(a, b): def patched__broadcast_shapes(*_shapes): """Patches ``torch._refs._broadcast_shapes``.""" from functools import reduce + from torch._prims_common import IntLike from torch.fx.experimental.symbolic_shapes import guard_size_oblivious - shapes = tuple( - (x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes) - ) + shapes = tuple((x,) if isinstance(x, IntLike) else x for x in filter(lambda x: x is not None, _shapes)) # Short-circuits on no input if len(shapes) == 0: @@ -138,9 +138,7 @@ 
def patched__broadcast_shapes(*_shapes): for idx in range(-1, -1 - len(shape), -1): if guard_size_oblivious(common_shape[idx] == 1): if shape[idx] < 0: - raise ValueError( - "Attempting to broadcast a dimension with negative length!" - ) + raise ValueError("Attempting to broadcast a dimension with negative length!") common_shape[idx] = shape[idx] elif guard_size_oblivious(shape[idx] != 1): common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx]) @@ -151,7 +149,10 @@ def patched__broadcast_shapes(*_shapes): class patched_ShapeEnv: def _set_replacement( - self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str # noqa: F821 + self, + a: "sympy.Symbol", # noqa: F821 + tgt: "sympy.Expr", # noqa: F821 + msg: str, ) -> None: """ Adds or updates a replacement for a symbol. @@ -164,16 +165,15 @@ def _set_replacement( return import sympy - from torch._logging import structured - from torch.utils._traceback import CapturedTraceback - from torch._logging import trace_structured from torch._guards import TracingContext - from torch.utils._sympy.functions import FloorToInt, CeilToInt - from torch.utils._sympy.solve import try_solve + from torch._logging import structured, trace_structured from torch.fx.experimental.symbolic_shapes import ( - _is_supported_equivalence, ValueRanges, + _is_supported_equivalence, ) + from torch.utils._sympy.functions import CeilToInt, FloorToInt + from torch.utils._sympy.solve import try_solve + from torch.utils._traceback import CapturedTraceback # Precondition: a == tgt assert isinstance(a, sympy.Symbol) @@ -219,14 +219,10 @@ def _set_replacement( # and then requantize the bound to integers when we are # done. rat_b_bound = self.bound_sympy(r[1]) - b_bound = ValueRanges( - CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper) - ) + b_bound = ValueRanges(CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper)) self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) tgt_bound = self.bound_sympy(tgt) - assert tgt_bound.issubset( - src_bound - ), f"{tgt_bound=} not a subset of {src_bound=}" + assert tgt_bound.issubset(src_bound), f"{tgt_bound=} not a subset of {src_bound=}" # TODO: Should we propagate size-like-ness? # @@ -279,8 +275,7 @@ def _set_replacement( src_bound_so = self.bound_sympy(a, size_oblivious=True) if not tgt_bound_so.issubset(src_bound_so): self.log.debug( - "skipped set_replacement %s = %s (%s) " - "[%s not subset of %s (size-oblivious conditions)]", + "skipped set_replacement %s = %s (%s) [%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, @@ -301,9 +296,7 @@ def _set_replacement( "sources": [s.name() for s in self.var_to_sources.get(a, [])], "value": repr(tgt), "reason": msg, - "stack": structured.from_traceback( - CapturedTraceback.extract(skip=1).summary() - ), + "stack": structured.from_traceback(CapturedTraceback.extract(skip=1).summary()), "user_stack": (structured.from_traceback(user_tb) if user_tb else None), }, ) @@ -314,8 +307,7 @@ def _set_replacement( # ) # self.log.debug("SPECIALIZATION", stack_info=True) assert msg != "range_refined_to_singleton", ( - f"A dynamic dimension becomes static! " - f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" + f"A dynamic dimension becomes static! 
a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" ) # log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) self.replacements[a] = tgt diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py index 76bd62561d55f..2d76608c3e406 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -1,11 +1,11 @@ import inspect import sys from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import torch import transformers import transformers.modeling_attn_mask_utils -from transformers.cache_utils import StaticCache, Cache, DynamicCache +from transformers.cache_utils import Cache, DynamicCache, StaticCache def _patch_make_causal_mask( @@ -13,7 +13,7 @@ def _patch_make_causal_mask( dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, - sliding_window: Optional[int] = None, + sliding_window: int | None = None, ): """Patched method.""" bsz, tgt_len = input_ids_shape @@ -61,12 +61,10 @@ def _make_causal_mask( dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, - sliding_window: Optional[int] = None, + sliding_window: int | None = None, ): """Patched method.""" - return _patch_make_causal_mask( - input_ids_shape, dtype, device, past_key_values_length, sliding_window - ) + return _patch_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length, sliding_window) else: @@ -87,12 +85,10 @@ def _make_causal_mask( dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0, - sliding_window: Optional[int] = None, + sliding_window: int | None = None, ): """Patched method.""" - return _patch_make_causal_mask( - input_ids_shape, dtype, device, past_key_values_length, sliding_window - ) + return _patch_make_causal_mask(input_ids_shape, dtype, device, past_key_values_length, sliding_window) class patched_DynamicCache: @@ -104,7 +100,7 @@ class patched_DynamicCache: _PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"] _PATCHED_CLASS_ = transformers.cache_utils.DynamicCache - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + def get_seq_length(self, layer_idx: int | None = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` @@ -122,21 +118,17 @@ def reorder_cache(self, beam_idx: torch.LongTensor): for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) if self.value_cache[layer_idx].numel(): device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, + cache_kwargs: Dict[str, Any] | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` @@ -169,19 +161,13 @@ def update( self.value_cache.append(torch.tensor([])) self.key_cache.append(key_states) self.value_cache.append(value_states) - elif not self.key_cache[ - layer_idx - ].numel(): # prefers not t.numel() to len(t) == 0 to export the model + elif not self.key_cache[layer_idx].numel(): # prefers not t.numel() to len(t) == 0 to export the model # fills previously skipped layers; checking for tensor causes errors self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=-2 - ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -211,14 +197,8 @@ def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache: `generation.utils`""" cache = cls() for idx in range(len(splits[0])): - key_cache = [ - current.key_cache[idx] for current in splits if current.key_cache[idx].numel() - ] - value_cache = [ - current.value_cache[idx] - for current in splits - if current.value_cache[idx].numel() - ] + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] if key_cache != []: layer_keys = torch.cat(key_cache, dim=0) layer_values = torch.cat(value_cache, dim=0) @@ -242,9 +222,9 @@ class patched_GenerationMixin: def _cache_dependant_input_preparation( self, input_ids: torch.LongTensor, - inputs_embeds: Optional[torch.FloatTensor], - cache_position: Optional[torch.LongTensor], - ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + inputs_embeds: torch.FloatTensor | None, + cache_position: torch.LongTensor | None, + ) -> tuple[torch.FloatTensor, torch.LongTensor]: """ Generic cache-dependent input preparation The code is put in a separate function to allow granular unit testing @@ -286,8 +266,8 @@ def _cache_dependant_input_preparation( def _cache_dependant_input_preparation_exporting( self, input_ids: torch.LongTensor, - inputs_embeds: Optional[torch.FloatTensor], - cache_position: Optional[torch.LongTensor], + inputs_embeds: torch.FloatTensor | None, + cache_position: torch.LongTensor | None, ) -> 
Tuple[torch.FloatTensor, torch.LongTensor]: """ This method implements method ``_cache_dependant_input_preparation`` @@ -348,10 +328,10 @@ def branch_3(input_ids, cache_position): def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: Optional[Cache] = None, - attention_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: Optional[torch.LongTensor] = None, + past_key_values: Cache | None = None, + attention_mask: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + cache_position: torch.LongTensor | None = None, **kwargs, ): """ @@ -386,9 +366,7 @@ def prepare_inputs_for_generation( # 2. Generic cache-dependent input preparation if past_key_values is not None: model_inputs["past_key_values"] = past_key_values - inputs_embeds, input_ids = self._cache_dependant_input_preparation( - input_ids, inputs_embeds, cache_position - ) + inputs_embeds, input_ids = self._cache_dependant_input_preparation(input_ids, inputs_embeds, cache_position) # 3. Prepare base model inputs input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" @@ -400,28 +378,18 @@ def prepare_inputs_for_generation( model_inputs["inputs_embeds"] = inputs_embeds else: # `clone` calls in this function ensure a consistent stride. See #32227 - model_inputs[input_ids_key] = input_ids.clone( - memory_format=torch.contiguous_format - ) + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) model_inputs["inputs_embeds"] = None else: - model_inputs[input_ids_key] = input_ids.clone( - memory_format=torch.contiguous_format - ) + model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format) # 4. Create missing `position_ids` on the fly encoder_attention_mask = attention_mask if self.config.is_encoder_decoder else None attention_mask = ( - kwargs.pop("decoder_attention_mask", None) - if self.config.is_encoder_decoder - else attention_mask - ) - attention_mask_key = ( - "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" - ) - position_ids_key = ( - "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" + kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask ) + attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask" + position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids" if ( attention_mask is not None and kwargs.get(position_ids_key) is None @@ -429,9 +397,7 @@ def prepare_inputs_for_generation( ): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - kwargs[position_ids_key] = ( - position_ids # placed in kwargs for further processing (see below) - ) + kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below) # 5. 
Slice model inputs if it's an input # that should have the same length as `input_ids` From 4633a3e942b83a5aca03cd53985133a050dcd4f4 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 18:39:58 +0200 Subject: [PATCH 07/24] lint --- .../models/llama/convert_to_onnx.py | 7 ++++--- .../models/torch_export_patches/__init__.py | 14 +++++++------- .../torch_export_patches/onnx_export_errors.py | 3 +-- .../onnx_export_serialization.py | 8 ++------ .../models/torch_export_patches/patch_inputs.py | 5 ++--- .../torch_export_patches/patches/patch_torch.py | 4 ++-- .../patches/patch_transformers.py | 17 ++++++----------- 7 files changed, 24 insertions(+), 34 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 567989012c62a..93cbbfe7e1a0e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -16,9 +16,6 @@ import onnx import torch -# to patch transformers before exporting for transformers >= 4.45 -from models.torch_export_patches import bypass_export_some_errors -from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes from benchmark_helper import Precision, prepare_environment, setup_logger from convert_generation import replace_mha_with_gqa @@ -26,6 +23,10 @@ from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check from llama_torch import setup_torch_model + +# to patch transformers before exporting for transformers >= 4.45 +from models.torch_export_patches import bypass_export_some_errors +from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py index a8f32918ef72d..5490cb956d592 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any import packaging.version as pv import torch import transformers @@ -35,9 +35,9 @@ def string_type(anything, **args): if pv.Version(transformers.__version__) > pv.Version("4.49.99999"): def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], ) -> transformers.cache_utils.DynamicCache: - ''' + """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. This version is valid for ``transformers >= 4.50``. @@ -61,15 +61,15 @@ def make_dynamic_cache( ] ) print(string_type(past_key_values, with_shape=True)) - ''' + """ return transformers.cache_utils.DynamicCache(key_value_pairs) else: def make_dynamic_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], + key_value_pairs: list[tuple[torch.Tensor, torch.Tensor]], ) -> transformers.cache_utils.DynamicCache: - ''' + """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. This version is valid for ``transformers < 4.50``. 
@@ -93,7 +93,7 @@ def make_dynamic_cache( ] ) print(string_type(past_key_values, with_shape=True)) - ''' + """ cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) for i, (key, value) in enumerate(key_value_pairs): cache.update(key, value, i) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index 8f4fdb155d003..e3b3f53c80335 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -115,7 +115,7 @@ def _register_cache_serialization(verbose: int = 0) -> dict[str, bool]: if verbose: print("[_register_cache_serialization] DynamicCache is unregistered first.") _unregister(DynamicCache) - + unregistered_dynamic_cache = True if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: if verbose > 1: @@ -173,7 +173,6 @@ def _unregister(cls: type, verbose: int = 0): def _unregister_cache_serialization(undo: dict[str, bool], verbose: int = 0): - if undo.get("MambaCache", False): from transformers.cache_utils import MambaCache diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py index b21425e4713dc..b191601885e48 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py @@ -99,11 +99,7 @@ def flatten_dynamic_cache( dynamic_cache: transformers.cache_utils.DynamicCache, ) -> tuple[list[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - flat = [ - (k, getattr(dynamic_cache, k)) - for k in ["key_cache", "value_cache"] - if hasattr(dynamic_cache, k) - ] + flat = [(k, getattr(dynamic_cache, k)) for k in ["key_cache", "value_cache"] if hasattr(dynamic_cache, k)] return [f[1] for f in flat], [f[0] for f in flat] @@ -129,7 +125,7 @@ def unflatten_dynamic_cache( from transformers.cache_utils import DynamicCache cache = DynamicCache() - values = dict(zip(context, values)) + values = dict(zip(context, values, strict=False)) for k, v in values.items(): setattr(cache, k, v) return cache diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py index 563653d422b34..b9cdf11d06bb4 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -15,7 +15,7 @@ def _process_cache(k: str, v): return cache if isinstance(v, torch.Tensor): return v - raise NotImplementedError(f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}") + raise NotImplementedError(f"Unable to process parameter {k!r} with v={string_type(v, with_shape=True)}") def _make_shape(subset: dict, cls: type, value: Any) -> Any: @@ -23,8 +23,7 @@ def _make_shape(subset: dict, cls: type, value: Any) -> Any: assert subset, "DynamicCache cannot be empty" values = set(map(str, subset.values())) assert len(values) == 1, ( - f"Inconsistencies in subset={subset}, found={values}, " - f"it cannot be a {cls}, 
value={string_type(value)}" + f"Inconsistencies in subset={subset}, found={values}, it cannot be a {cls}, value={string_type(value)}" ) cache_length = len(value.key_cache) for v in subset.values(): diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py index 5eb3fee71b450..c6117d0f94220 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py @@ -10,7 +10,7 @@ def _catch_produce_guards_and_solve_constraints( previous_function: Callable, fake_mode: "FakeTensorMode", gm: "torch.fx.GraphModule", - dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, equalities_inputs: "EqualityConstraint", original_signature: inspect.Signature, _is_torch_jit_trace: bool = False, @@ -131,7 +131,7 @@ def patched__broadcast_shapes(*_shapes): assert isinstance(shape, Sequence) # Computes common shape - common_shape = [ # List[Union[int, torch.SymInt]] + common_shape = [ # list[Union[int, torch.SymInt]] 1, ] * reduce(max, (len(shape) for shape in shapes)) for _arg_idx, shape in enumerate(shapes): diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py index 2d76608c3e406..25cca9ccbdf71 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -106,8 +106,7 @@ def get_seq_length(self, layer_idx: int | None = 0) -> int: # TODO: deprecate this function in favor of `cache_position` is_empty_layer = ( len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) - <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it + or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it or self.key_cache[layer_idx].numel() == 0 # the layer has no cache ) layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 @@ -128,7 +127,7 @@ def update( key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, - cache_kwargs: Dict[str, Any] | None = None, + cache_kwargs: dict[str, Any] | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` @@ -191,7 +190,7 @@ def crop(self, max_length: int): self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] @classmethod - def from_batch_splits(cls, splits: List[DynamicCache]) -> DynamicCache: + def from_batch_splits(cls, splits: list[DynamicCache]) -> DynamicCache: """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" @@ -246,9 +245,7 @@ def _cache_dependant_input_preparation( The current implementation does not rely on ``self`` and could be a class method. It is left as a standard method to be easily rewritten. 
""" - return self._cache_dependant_input_preparation_exporting( - input_ids, inputs_embeds, cache_position - ) + return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position) """ if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] @@ -268,7 +265,7 @@ def _cache_dependant_input_preparation_exporting( input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor | None, cache_position: torch.LongTensor | None, - ) -> Tuple[torch.FloatTensor, torch.LongTensor]: + ) -> tuple[torch.FloatTensor, torch.LongTensor]: """ This method implements method ``_cache_dependant_input_preparation`` with :func:`torch.cond` to make it exportable with :func:`torch.export.export`. @@ -359,9 +356,7 @@ def prepare_inputs_for_generation( # `generate` and letting it create `cache_position`) elif cache_position is None: past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - cache_position = torch.arange( - past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device - ) + cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) # 2. Generic cache-dependent input preparation if past_key_values is not None: From b12287a6e13d2702d0c71fd6a719e6fc694dd0bc Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 19:01:27 +0200 Subject: [PATCH 08/24] lint --- .../python/tools/transformers/models/llama/convert_to_onnx.py | 1 - .../models/torch_export_patches/patches/patch_torch.py | 2 +- .../models/torch_export_patches/patches/patch_transformers.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 93cbbfe7e1a0e..c67e15de88d01 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -366,7 +366,6 @@ def run_torchscript_merged_export( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - dynamic_shapes=dynamic_shapes, opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py index c6117d0f94220..2b40a0ee97c20 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py @@ -11,7 +11,7 @@ def _catch_produce_guards_and_solve_constraints( fake_mode: "FakeTensorMode", gm: "torch.fx.GraphModule", dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, - equalities_inputs: "EqualityConstraint", + equalities_inputs: "EqualityConstraint", # noqa: F821 original_signature: inspect.Signature, _is_torch_jit_trace: bool = False, verbose: int = 0, diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py index 25cca9ccbdf71..09034ee0375e5 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py +++ 
b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py @@ -128,7 +128,7 @@ def update( value_states: torch.Tensor, layer_idx: int, cache_kwargs: dict[str, Any] | None = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. From 1b926cbe51b2b082346e7753a4f3e7c89ab287b9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 19:14:38 +0200 Subject: [PATCH 09/24] rename --- .../models/torch_export_patches/onnx_export_errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index e3b3f53c80335..6f7169aabc003 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -1,6 +1,6 @@ import contextlib import pprint -from collections.abs import Callable +from collections.abc import Callable from typing import Any from .onnx_export_serialization import ( From 6646e61ff4f83ca431da53e64eefa94c03103845 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 31 Mar 2025 19:48:34 +0200 Subject: [PATCH 10/24] lint --- .../transformers/models/torch_export_patches/__init__.py | 2 ++ .../models/torch_export_patches/onnx_export_errors.py | 3 +-- .../models/torch_export_patches/onnx_export_serialization.py | 1 + .../transformers/models/torch_export_patches/patch_inputs.py | 4 +++- .../models/torch_export_patches/patches/patch_torch.py | 2 +- .../models/torch_export_patches/patches/patch_transformers.py | 1 + pyproject.toml | 1 + 7 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py index 5490cb956d592..84db86231f8f6 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -1,7 +1,9 @@ from typing import Any + import packaging.version as pv import torch import transformers + from onnxruntime.transformers.models.torch_export_patches.onnx_export_errors import ( bypass_export_some_errors, register_additional_serialization_functions, diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py index 6f7169aabc003..05293722bd803 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py @@ -11,7 +11,6 @@ unflatten_dynamic_cache, unflatten_mamba_cache, ) - from .patches import patch_transformers as patch_transformers_list @@ -365,6 +364,7 @@ def bypass_export_some_errors( print("[bypass_export_some_errors] assert when a dynamic dimension turns static") from torch.fx.experimental.symbolic_shapes import ShapeEnv + from .patches.patch_torch import patched_ShapeEnv f_shape_env__set_replacement = ShapeEnv._set_replacement @@ -397,7 +397,6 @@ def bypass_export_some_errors( print("[bypass_export_some_errors] remove patches") if patch_sympy: - # tracked by 
From 6646e61ff4f83ca431da53e64eefa94c03103845 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 31 Mar 2025 19:48:34 +0200
Subject: [PATCH 10/24] lint

---
 .../transformers/models/torch_export_patches/__init__.py             | 2 ++
 .../models/torch_export_patches/onnx_export_errors.py                | 3 +--
 .../models/torch_export_patches/onnx_export_serialization.py         | 1 +
 .../transformers/models/torch_export_patches/patch_inputs.py         | 4 +++-
 .../models/torch_export_patches/patches/patch_torch.py               | 2 +-
 .../models/torch_export_patches/patches/patch_transformers.py        | 1 +
 pyproject.toml                                                       | 1 +
 7 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py
index 5490cb956d592..84db86231f8f6 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py
@@ -1,7 +1,9 @@
 from typing import Any
+
 import packaging.version as pv
 import torch
 import transformers
+
 from onnxruntime.transformers.models.torch_export_patches.onnx_export_errors import (
     bypass_export_some_errors,
     register_additional_serialization_functions,
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
index 6f7169aabc003..05293722bd803 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
@@ -11,7 +11,6 @@
     unflatten_dynamic_cache,
     unflatten_mamba_cache,
 )
-
 from .patches import patch_transformers as patch_transformers_list
 
 
@@ -365,6 +364,7 @@ def bypass_export_some_errors(
             print("[bypass_export_some_errors] assert when a dynamic dimension turns static")
 
         from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
         from .patches.patch_torch import patched_ShapeEnv
 
         f_shape_env__set_replacement = ShapeEnv._set_replacement
@@ -397,7 +397,6 @@ def bypass_export_some_errors(
         print("[bypass_export_some_errors] remove patches")
 
     if patch_sympy:
-
         # tracked by https://github.com/pytorch/pytorch/issues/143494
         if f_sympy_name:
             sympy.core.numbers.IntegerConstant.name = f_sympy_name
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py
index b191601885e48..d109dd3059480 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_serialization.py
@@ -1,4 +1,5 @@
 from typing import Any
+
 import torch
 import transformers
 
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
index b9cdf11d06bb4..bda08ba7da7df 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
@@ -1,7 +1,9 @@
 import inspect
 from typing import Any
+
 import torch
 import transformers
+
 from onnxruntime.transformers.models.torch_export_patches import make_dynamic_cache, string_type
 
 
@@ -9,7 +11,7 @@ def _process_cache(k: str, v):
     assert k != "position_ids" or isinstance(k, torch.Tensor), (
         f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
     )
-    if isinstance(v, list) and all(isinstance(i, tuple) for i in v) and set(len(t) for t in v) == {2}:
+    if isinstance(v, list) and all(isinstance(i, tuple) for i in v) and {len(t) for t in v} == {2}:
         # A dynamicCache
         cache = make_dynamic_cache(v)
         return cache
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py
index 2b40a0ee97c20..c30a58f4290f1 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_torch.py
@@ -2,6 +2,7 @@
 import os
 from collections.abc import Callable, Sequence
 from typing import Any
+
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
 
@@ -147,7 +148,6 @@ def patched__broadcast_shapes(*_shapes):
 
 
 class patched_ShapeEnv:
-
     def _set_replacement(
         self,
         a: "sympy.Symbol",  # noqa: F821
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
index 09034ee0375e5..3ff52af654b8d 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
@@ -2,6 +2,7 @@
 import sys
 from dataclasses import dataclass
 from typing import Any
+
 import torch
 import transformers
 import transformers.modeling_attn_mask_utils
diff --git a/pyproject.toml b/pyproject.toml
index f95fb0ff955a4..2b96187028d16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ ignore = [
 "tools/nuget/generate_nuspec_for_native_nuget.py" = ["ISC003"]  # Too many errors to fix
 "onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_triton.py" = ["N806"]  # use of Q, K and V in triton script
 "onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v2/sparse_attention_v2_triton.py" = ["N806"]  # use of Q, K and V in triton script
+"onnxruntime/python/tools/transformers/models/torch_export_patches/*" = ["F401", "PLW0211", "N801", "N806", "RUF012"]  # patches are based on pytorch code
 "onnxruntime/test/python/quantization/test_op_gemm.py" = ["N806"]  # use of A for a matrix
 "onnxruntime/test/python/quantization/op_test_utils.py" = ["N806", "PERF203", "RUF012"]  # use of A for a matrix
 "orttraining/orttraining/python/training/ort_triton/kernel/_flash_attn.py" = ["N806", "PLW2901", "ISC001", "E731"]  # Long triton code from other repo.

From a14b8b3ef8b0be0dfd3fd2b2b7638292a444b079 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 31 Mar 2025 19:58:26 +0200
Subject: [PATCH 11/24] lint

---
 .../python/tools/transformers/models/llama/convert_to_onnx.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index c67e15de88d01..693bb645b9596 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -16,7 +16,6 @@
 
 import onnx
 import torch
-
 from benchmark_helper import Precision, prepare_environment, setup_logger
 from convert_generation import replace_mha_with_gqa
 from dist_settings import barrier, get_rank, get_size, init_dist
logger.info("Optimizing models...") for orig_path, opt_path in zip(old_paths, new_paths, strict=False): @@ -998,9 +995,6 @@ def main(): remove_existing_model(fp_path) barrier() - if args.use_dynamo_export: - return - logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py index 84db86231f8f6..caa6638e7f749 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/__init__.py @@ -4,7 +4,7 @@ import torch import transformers -from onnxruntime.transformers.models.torch_export_patches.onnx_export_errors import ( +from .onnx_export_errors import ( bypass_export_some_errors, register_additional_serialization_functions, ) diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py index bda08ba7da7df..371b8701ce367 100644 --- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py +++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py @@ -4,7 +4,8 @@ import torch import transformers -from onnxruntime.transformers.models.torch_export_patches import make_dynamic_cache, string_type +from . import string_type +from .cache_helper import make_dynamic_cache def _process_cache(k: str, v): From 0c88e42aa526855fa82b8ba070f1eca1b8509f71 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 1 Apr 2025 09:21:03 +0200 Subject: [PATCH 13/24] fix issues --- .../python/tools/transformers/models/llama/llama_parity.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index eab55154b50b1..4181413411491 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -11,7 +11,9 @@ import time import numpy as np +import packaging.version as pv import torch +import transformers from benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( @@ -23,6 +25,7 @@ verify_ort_inputs, ) from llama_torch import setup_torch_model +from models.torch_export_patches.cache_helper import make_dynamic_cache from transformers import AutoConfig import onnxruntime as ort @@ -92,6 +95,10 @@ def verify_parity( inputs = get_inputs(args, config) + if "past_key_values" in inputs and pv.Version(transformers.__version__) >= pv.Version("4.45"): + # Using DynamicCache + inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"]) + # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() From 8b605359dad9bcfe31a1e50c391b492ad4a50b57 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 1 Apr 2025 19:11:50 +0200 Subject: [PATCH 14/24] copy inputs --- .../transformers/models/llama/llama_parity.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 4181413411491..49652c18684fe 100644 --- 
From 8b605359dad9bcfe31a1e50c391b492ad4a50b57 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 1 Apr 2025 19:11:50 +0200
Subject: [PATCH 14/24] copy inputs

---
 .../transformers/models/llama/llama_parity.py | 29 ++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 4181413411491..49652c18684fe 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -74,6 +74,30 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig):
     return inputs
 
 
+def torch_deepcopy(value: Any) -> Any:
+    if isinstance(value, (int, float, str)):
+        return value
+    if isinstance(value, tuple):
+        return tuple(torch_deepcopy(v) for v in value)
+    if isinstance(value, list):
+        return [torch_deepcopy(v) for v in value]
+    if isinstance(value, set):
+        return {torch_deepcopy(v) for v in value}
+    if isinstance(value, dict):
+        return {k: torch_deepcopy(v) for k, v in value.items()}
+    if isinstance(value, np.ndarray):
+        return value.copy()
+    if hasattr(value, "clone"):
+        return value.clone()
+    if isinstance(value, transformers.cache_utils.DynamicCache):
+        return make_dynamic_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
+        )
+    # We should have a code using serialization, deserialization assuming a model
+    # cannot be exported without them.
+    raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
+
+
 def verify_parity(
     args: argparse.Namespace,
     location: str,
@@ -103,7 +127,10 @@ def verify_parity(
     if args.execution_provider != "cpu":
         torch.cuda.synchronize()
     start_time = time.time()
-    pt_outputs = py_model(**inputs).logits.detach().cpu().numpy()
+    # If there is a cache in the inputs, we need to make a copy as the model modify them inplace.
+    # DynamicCache inherits from torch.nn.Module in some version of transformers.
+    # We need to make the copy manually.
+    pt_outputs = py_model(**torch_deepcopy(inputs)).logits.detach().cpu().numpy()
     if args.execution_provider != "cpu":
         torch.cuda.synchronize()
     end_time = time.time()

From 741285b27582d5124a65aa94eb3c9011106c5107 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 1 Apr 2025 19:43:48 +0200
Subject: [PATCH 15/24] fix shape

---
 .../python/tools/transformers/models/llama/convert_to_onnx.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index c8924b7fef613..e2edab47e981d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -169,6 +169,7 @@ def run_dynamo_export(
             dynamic_shapes=dynamic_shapes,
             dynamo=True,
             verbose=args.verbose,
+            optimize=True,
         )
 
     # Check decoder_with_past_model.onnx and save all external data to one file
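The torch_deepcopy helper added in PATCH 14 exists because a forward pass appends to key_cache and value_cache in place, so the ONNX Runtime parity run would otherwise receive a mutated cache, and copy.deepcopy is unreliable on DynamicCache in the transformers versions where it inherits from torch.nn.Module. A usage sketch (run_ort is a hypothetical stand-in for the ORT side of verify_parity):

    inputs = get_inputs(args, config)
    inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])

    pt_logits = py_model(**torch_deepcopy(inputs)).logits  # original cache untouched
    ort_logits = run_ort(inputs)  # hypothetical helper, sees the pristine inputs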
From f8490a5c8570edd82c26bb98228bb5ee1a3346ff Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 2 Apr 2025 10:20:18 +0200
Subject: [PATCH 16/24] fix validation

---
 .../models/llama/convert_to_onnx.py       |  2 +-
 .../transformers/models/llama/llama_inputs.py |  10 +++++--
 .../transformers/models/llama/llama_parity.py |  6 ++--
 .../onnx_export_errors.py                 |  3 ++
 .../patches/patch_transformers.py         | 28 ++++++++++---------
 5 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index e2edab47e981d..269bf5c785594 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -126,7 +126,7 @@ def run_dynamo_export(
     config.capture_scalar_outputs = True
 
     # Dummy values for export
-    batch_size, sequence_length, past_sequence_length = 2, 8, 0
+    batch_size, sequence_length, past_sequence_length = 2, 8, 3
     device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
 
     temp_name = args.model_name.lower().replace("-", "").replace("_", "")
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 025d57f0b2d5d..4163b0a4ee9e5 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import torch
+import transformers
 from transformers import AutoConfig, AutoTokenizer
 
 from onnxruntime import InferenceSession, OrtValue
@@ -240,11 +241,14 @@ def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, u
 def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
     past_kv = {}
     for i, (past_k, past_v) in enumerate(past_key_values):
-        past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
-        past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
+        if isinstance(past_key_values, transformers.cache_utils.DynamicCache):
+            past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
+            past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
+        else:
+            past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
+            past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
     return past_kv
 
-
 # Format PyTorch inputs to ONNX Runtime inputs
 def convert_inputs_for_ort(
     pt_inputs: dict,
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 49652c18684fe..27e9db54bd0a4 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -74,7 +74,7 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig):
     return inputs
 
 
-def torch_deepcopy(value: Any) -> Any:
+def torch_deepcopy(value):
     if isinstance(value, (int, float, str)):
         return value
     if isinstance(value, tuple):
@@ -90,9 +90,7 @@ def torch_deepcopy(value):
     if hasattr(value, "clone"):
         return value.clone()
     if isinstance(value, transformers.cache_utils.DynamicCache):
-        return make_dynamic_cache(
-            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
-        )
+        return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
     # We should have a code using serialization, deserialization assuming a model
     # cannot be exported without them.
     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
index 05293722bd803..9eb239c59b8b9 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
@@ -313,6 +313,9 @@ def bypass_export_some_errors(
     ###############
     # patch pytorch
     ###############
+    # the linter gets confused if not initialized
+    f_jit_isinstance = f_mark_static_address = f_infer_size = ShapeEnv = None
+    f_shape_env__set_replacement = revert_patches_info = None
 
     if patch_torch:
         from .patches.patch_torch import (
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
index 3ff52af654b8d..828a883b7ab12 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patches/patch_transformers.py
@@ -245,21 +245,23 @@ def _cache_dependant_input_preparation(
         The current implementation does not rely on ``self`` and could
         be a class method. It is left as a standard method to be easily rewritten.
 
+        Original code:
+
+        .. code-block:: python
+
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif inputs_embeds is not None or (  # Exception 1
+                cache_position[-1] >= input_ids.shape[1]
+            ):  # Exception 3
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif (
+                input_ids.shape[1] != cache_position.shape[0]
+            ):  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+            return inputs_embeds, input_ids
         """
         return self._cache_dependant_input_preparation_exporting(input_ids, inputs_embeds, cache_position)
-        """
-        if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-            inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-        elif inputs_embeds is not None or (  # Exception 1
-            cache_position[-1] >= input_ids.shape[1]
-        ):  # Exception 3
-            input_ids = input_ids[:, -cache_position.shape[0] :]
-        elif (
-            input_ids.shape[1] != cache_position.shape[0]
-        ):  # Default case (the "else", a no op, is Exception 2)
-            input_ids = input_ids[:, cache_position]
-        return inputs_embeds, input_ids
-        """
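The two naming branches added to flatten_past_kv_inputs in PATCH 16 reflect how the exporters name the cache inputs: the dynamo exporter derives names from the flattened DynamicCache pytree, while the torchscript exporter keeps the historical dotted scheme. For a hypothetical two-layer cache, the ONNX Runtime feed would therefore contain:

    # DynamicCache (dynamo export):
    #   past_key_values_key_cache_0, past_key_values_value_cache_0,
    #   past_key_values_key_cache_1, past_key_values_value_cache_1
    # legacy tuples (torchscript export):
    #   past_key_values.0.key, past_key_values.0.value,
    #   past_key_values.1.key, past_key_values.1.value
    ort_inputs = flatten_past_kv_inputs(past_key_values)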
From ca43041b6bd03a2ad1ba1029a6162666de572822 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 2 Apr 2025 12:44:13 +0200
Subject: [PATCH 17/24] add use_dynamo_export

---
 .../github/azure-pipelines/bigmodels-ci-pipeline.yml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
index f4658f3a22c33..122c7651907b0 100644
--- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
@@ -332,7 +332,7 @@ stages:
               python3 -m pip install -r requirements.txt ; \
               popd ; \
               python3 -m pip install /ort-artifact/*.whl ; \
-              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --small_gp;\
+              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --small_gp --use_dynamo_export;\
               ls -l llama2-7b-fp16; \
               du -sh llama2-7b-fp16; \
               popd ; \
@@ -353,7 +353,7 @@ stages:
               python3 -m pip install -r requirements.txt ; \
               popd ; \
               python3 -m pip install /ort-artifact/*.whl ; \
-              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda;\
+              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda --use_dynamo_export;\
              ls -l llama2-7b-fp32-gpu; \
               du -sh llama2-7b-fp32-gpu; \
               popd ; \
@@ -374,7 +374,7 @@ stages:
               python3 -m pip install -r requirements.txt ; \
               popd ; \
               python3 -m pip install /ort-artifact/*.whl ; \
-              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --use_gqa;\
+              python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input /meta-llama2 --output llama2-7b-int4-gpu --precision int4 --execution_provider cuda --use_gqa --use_dynamo_export;\
               ls -l llama2-7b-int4-gpu; \
               du -sh llama2-7b-int4-gpu; \
               popd ; \

From 19d4dfb8c5e704d0159630969f01278bb722a586 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Wed, 2 Apr 2025 17:25:41 +0200
Subject: [PATCH 18/24] lint

---
 .../python/tools/transformers/models/llama/llama_inputs.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 4163b0a4ee9e5..5c9ccb118bc61 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -249,6 +249,7 @@ def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tenso
             past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
     return past_kv
 
+
 # Format PyTorch inputs to ONNX Runtime inputs
 def convert_inputs_for_ort(
     pt_inputs: dict,

From 8d3b0ba036ecdce3e52ec597768fc696aaf28bca Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 3 Apr 2025 13:12:20 +0200
Subject: [PATCH 19/24] fix requirements

---
 .../python/tools/transformers/models/llama/requirements.txt | 2 +-
 .../models/torch_export_patches/onnx_export_errors.py       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index c965cc5dab58a..81ab36af48978 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,5 +1,5 @@
 optimum>=1.14.1
-transformers>=4.33.2,<= 4.38.0
+transformers==4.48.0
 torch>=2.2.0
 onnx==1.17.0
 datasets>=2.8.0
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
index 9eb239c59b8b9..5dd3b38a8232a 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/onnx_export_errors.py
@@ -315,7 +315,7 @@ def bypass_export_some_errors(
     ###############
     # the linter gets confused if not initialized
     f_jit_isinstance = f_mark_static_address = f_infer_size = ShapeEnv = None
-    f_shape_env__set_replacement = revert_patches_info = None
+    f__broadcast_shapes = f_shape_env__set_replacement = revert_patches_info = None
 
     if patch_torch:
         from .patches.patch_torch import (

From 835b76e7831b32bebe2f010afe2d78e66b6e995c Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 3 Apr 2025 14:31:41 +0200
Subject: [PATCH 20/24] fix requirements

---
 .../python/tools/transformers/models/llama/requirements.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
index 81ab36af48978..40f3ab1c92f16 100644
--- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt
+++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt
@@ -1,3 +1,5 @@
+onnxscript>=0.2.3
+optree
 optimum>=1.14.1
 transformers==4.48.0
 torch>=2.2.0

From a0a8c21ffcb3465f3b9c0a03dea46a7489851cba Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 3 Apr 2025 16:13:19 +0200
Subject: [PATCH 21/24] fix dynamic shapes

---
 .../models/llama/convert_to_onnx.py      | 14 +++++++++++++-
 .../torch_export_patches/patch_inputs.py | 18 ++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 269bf5c785594..bae7a69c5c75a 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -25,7 +25,7 @@
 
 # to patch transformers before exporting for transformers >= 4.45
 from models.torch_export_patches import bypass_export_some_errors
-from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes
+from models.torch_export_patches.patch_inputs import convert_dynamic_axes_into_dynamic_shapes, replace_dynamic_shapes
 from onnx_model import OnnxModel
 from optimizer import optimize_model
 from packaging import version
@@ -160,6 +160,18 @@ def run_dynamo_export(
         llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"}
     )
 
+    if version.Version(torch.__version__) < version.Version("2.7"):
+        # strings are not allowed with torch 2.6, so we replace them by DYNAMIC
+        # {'input_ids': {0: 'batch_size', 1: 'sequence_length'},
+        #  'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
+        #  'position_ids': {0: 'batch_size', 1: 'sequence_length'},
+        #  'past_key_values': [[{0: 'batch_size', 2: 'sequence_length'}], [{0: 'batch_size', 2: 'sequence_length'}]]}
+        dynamic_shapes = replace_dynamic_shapes(
+            dynamic_shapes,
+            dict(batch_size=torch.export.Dim("batch_size")),
+            default_value=torch.export.Dim.DYNAMIC,
+        )
+
     with bypass_export_some_errors(patch_transformers=True):
         torch.onnx.export(
             llama,
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
index 371b8701ce367..518aa6aea47d7 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
@@ -146,3 +146,21 @@ def convert_dynamic_axes_into_dynamic_shapes(
     )
 
     return (), updated_kwargs, dynamic_shapes
+
+
+def replace_dynamic_shapes(ds, mapping, default_value):
+    if isinstance(ds, dict) and all(isinstance(k, int) for k in ds):
+        new_ds = {}
+        for k, v in ds.items():
+            if isinstance(v, str):
+                new_ds[k] = mapping.get(v, default_value)
+            else:
+                new_ds[k] = v
+        return new_ds
+    if isinstance(ds, tuple):
+        return tuple(replace_dynamic_shapes(d, mapping, default_value) for d in ds)
+    if isinstance(ds, list):
+        return [replace_dynamic_shapes(d, mapping, default_value) for d in ds]
+    if isinstance(ds, dict):
+        return {k: replace_dynamic_shapes(v, mapping, default_value) for k,v in ds.items()}
+    return ds
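A usage sketch of replace_dynamic_shapes, the helper PATCH 21 adds for torch < 2.7 where string dimension names are rejected: named strings in the dynamic_shapes specification are swapped for torch.export.Dim objects from the mapping, and every other string falls back to the default (the expected output is shown as a comment and assumes torch.export.Dim.DYNAMIC is available in the targeted torch version):

    import torch

    dynamic_shapes = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
    }
    ds = replace_dynamic_shapes(
        dynamic_shapes,
        dict(batch_size=torch.export.Dim("batch_size")),
        default_value=torch.export.Dim.DYNAMIC,
    )
    # "batch_size" is replaced by the shared Dim object in both entries,
    # "sequence_length" by torch.export.Dim.DYNAMIC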
From f61c27bfd5b6a8a356eee2fd5bb119336c7fe259 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 3 Apr 2025 17:22:39 +0200
Subject: [PATCH 22/24] 2.6

---
 .../models/llama/convert_to_onnx.py      | 69 +++++++++++++++----
 .../torch_export_patches/patch_inputs.py |  2 +-
 2 files changed, 56 insertions(+), 15 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index bae7a69c5c75a..4cba7fa06213d 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -162,28 +162,69 @@ def run_dynamo_export(
     if version.Version(torch.__version__) < version.Version("2.7"):
         # strings are not allowed with torch 2.6, so we replace them by DYNAMIC
-        # {'input_ids': {0: 'batch_size', 1: 'sequence_length'},
-        #  'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
-        #  'position_ids': {0: 'batch_size', 1: 'sequence_length'},
-        #  'past_key_values': [[{0: 'batch_size', 2: 'sequence_length'}], [{0: 'batch_size', 2: 'sequence_length'}]]}
         dynamic_shapes = replace_dynamic_shapes(
             dynamic_shapes,
             dict(batch_size=torch.export.Dim("batch_size")),
             default_value=torch.export.Dim.DYNAMIC,
         )
 
-    with bypass_export_some_errors(patch_transformers=True):
-        torch.onnx.export(
-            llama,
-            (),
-            temp_path,
-            kwargs=model_kwargs,
-            dynamic_shapes=dynamic_shapes,
-            dynamo=True,
-            verbose=args.verbose,
-            optimize=True,
-        )
+    if version.Version(torch.__version__) < version.Version("2.7"):
+        # This section is only needed for torch==2.6. The workaround implemented here
+        # to fix bugs is not necessary with torch>=2.7.
+        # - strings are not allowed with torch 2.6, so we replace them by DYNAMIC
+        # - TypePromotion was fixed in torch==2.7
+        from onnxscript import opset18 as op
+
+        dynamic_shapes = replace_dynamic_shapes(
+            dynamic_shapes,
+            dict(batch_size=torch.export.Dim("batch_size")),
+            default_value=torch.export.Dim.DYNAMIC,
+        )
+        # TypePromotion cannot fix a type issue after the conversion.
+        # We insert an additional CastLike when the exporter
+        def custom_aten_ge(self, other):
+            if isinstance(other, (int, float)):
+                return op.GreaterOrEqual(self, op.CastLike(other, self))
+            return op.GreaterOrEqual(self, other)
+
+        with bypass_export_some_errors(patch_transformers=True):
+            # ONNX pass TypePromotion crashes for torch 2.6.
+            # It can be bypassed by exporting first into an exported program.
+            # We then need to apply run_decompositions() before onnx conversion starts.
+            ep = torch.export.export(
+                llama,
+                (),
+                kwargs=model_kwargs,
+                dynamic_shapes=dynamic_shapes,
+                strict=False,
+            )
+            ep = ep.run_decompositions()
+            torch.onnx.export(
+                ep,
+                (),
+                temp_path,
+                kwargs=model_kwargs,
+                dynamic_shapes=dynamic_shapes,
+                dynamo=True,
+                verbose=args.verbose,
+                optimize=True,
+                custom_translation_table={torch.ops.aten.ge.Scalar: custom_aten_ge},
+            )
+    else:
+
+        with bypass_export_some_errors(patch_transformers=True):
+            torch.onnx.export(
+                llama,
+                (),
+                temp_path,
+                kwargs=model_kwargs,
+                dynamic_shapes=dynamic_shapes,
+                dynamo=True,
+                verbose=args.verbose,
+                optimize=True,
+            )
 
     # Check decoder_with_past_model.onnx and save all external data to one file
     onnx.checker.check_model(temp_path)
     onnx.shape_inference.infer_shapes_path(temp_path)
diff --git a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
index 518aa6aea47d7..ded05b8c37be5 100644
--- a/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/torch_export_patches/patch_inputs.py
@@ -162,5 +162,5 @@ def replace_dynamic_shapes(ds, mapping, default_value):
     if isinstance(ds, list):
         return [replace_dynamic_shapes(d, mapping, default_value) for d in ds]
     if isinstance(ds, dict):
-        return {k: replace_dynamic_shapes(v, mapping, default_value) for k,v in ds.items()}
+        return {k: replace_dynamic_shapes(v, mapping, default_value) for k, v in ds.items()}
     return ds

From 902c6af4840a2737cb461aa855c3a77f246dfeb6 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Thu, 3 Apr 2025 17:24:39 +0200
Subject: [PATCH 23/24] remove duplicated section

---
 .../tools/transformers/models/llama/convert_to_onnx.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
index 4cba7fa06213d..2f60af968d689 100644
--- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
@@ -160,14 +160,6 @@ def run_dynamo_export(
         llama, args=model_args, dynamic_axes=dynamic_axes, prefix_mapping={"present": "past_key_values"}
     )
 
-    if version.Version(torch.__version__) < version.Version("2.7"):
-        # strings are not allowed with torch 2.6, so we replace them by DYNAMIC
-        dynamic_shapes = replace_dynamic_shapes(
-            dynamic_shapes,
-            dict(batch_size=torch.export.Dim("batch_size")),
-            default_value=torch.export.Dim.DYNAMIC,
-        )
-
     if version.Version(torch.__version__) < version.Version("2.7"):
         # This section is only needed for torch==2.6. The workaround implemented here
         # to fix bugs is not necessary with torch>=2.7.
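The torch==2.6 branch introduced in PATCH 22 can be reproduced on a toy model: `x >= 1` lowers to aten.ge.Scalar, whose python scalar trips the exporter's TypePromotion pass on that version, so the export goes through torch.export.export plus run_decompositions first and routes the operator through a CastLike-based translation. A sketch under those assumptions (GEModel and ge.onnx are illustrative names; custom_translation_table requires a torch/onnxscript pair recent enough to support it):

    import torch
    from onnxscript import opset18 as op

    class GEModel(torch.nn.Module):
        def forward(self, x):
            return (x >= 1).to(torch.float32)  # lowers to aten.ge.Scalar

    def custom_ge(self, other):
        # promote the python scalar to the tensor dtype before comparing
        if isinstance(other, (int, float)):
            return op.GreaterOrEqual(self, op.CastLike(other, self))
        return op.GreaterOrEqual(self, other)

    ep = torch.export.export(GEModel(), (torch.randn(2, 3),))
    ep = ep.run_decompositions()
    torch.onnx.export(
        ep,
        (torch.randn(2, 3),),
        "ge.onnx",
        dynamo=True,
        custom_translation_table={torch.ops.aten.ge.Scalar: custom_ge},
    )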
From e3188ada895c2a8effc9f344f9f4f75ce2112c1b Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 3 Apr 2025 17:39:40 +0200 Subject: [PATCH 24/24] lint --- .../python/tools/transformers/models/llama/convert_to_onnx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 2f60af968d689..5c778424709c8 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -204,7 +204,6 @@ def custom_aten_ge(self, other): custom_translation_table={torch.ops.aten.ge.Scalar: custom_aten_ge}, ) else: - with bypass_export_some_errors(patch_transformers=True): torch.onnx.export( llama,