diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index 2bc63ef6873..f42cd44a668 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -55,6 +55,8 @@
 """
 
 import argparse
+import collections
+import inspect
 import logging
 import math
 import os
@@ -180,6 +182,24 @@ def export_transformer_to_onnx(
     inputs = tokenizer(
         "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value
     ).data  # Dict[Tensor]
+
+    # Rearrange the inputs' keys to match the arguments of the model's forward
+    # function, which appears to determine the input order in the exported model
+    forward_args_spec = inspect.getfullargspec(model.__class__.forward)
+    dropped = [f for f in inputs.keys() if f not in forward_args_spec.args]
+    inputs = collections.OrderedDict(
+        [
+            (f, inputs[f][0].reshape(1, -1))
+            for f in forward_args_spec.args
+            if f in inputs
+        ]
+    )
+    if dropped:
+        _LOGGER.warning(
+            "The following inputs were not present in the model forward function "
+            f"and therefore dropped from ONNX export: {dropped}"
+        )
+
     inputs_shapes = {
         key: (
             f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: "
@@ -187,6 +207,7 @@ def export_transformer_to_onnx(
         )
         for key, val in inputs.items()
     }
+    _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}")
 
     # run export
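
Below is a minimal, standalone sketch of the reordering technique the second hunk applies: re-order the tokenizer's keyword inputs to follow the positional order of the model's `forward()` signature and drop anything the signature does not accept. `ToyModel`, `raw_inputs`, and `unused_feature` are illustrative names, not part of SparseML or Hugging Face.

```python
import collections
import inspect

import torch


class ToyModel(torch.nn.Module):
    # Hypothetical forward signature; real transformer models expose a
    # similar (much longer) argument list.
    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        return input_ids


# Pretend these came from `tokenizer("", return_tensors="pt", ...).data`.
raw_inputs = {
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "input_ids": torch.zeros(1, 8, dtype=torch.long),
    "unused_feature": torch.zeros(1, 8),  # not accepted by forward(), dropped
}

# Argument names in declaration order, e.g. ['self', 'input_ids', ...]
forward_args = inspect.getfullargspec(ToyModel.forward).args

# Keep only the keys forward() accepts, ordered by the signature.
ordered_inputs = collections.OrderedDict(
    (name, raw_inputs[name]) for name in forward_args if name in raw_inputs
)
dropped = [name for name in raw_inputs if name not in forward_args]

print(list(ordered_inputs))  # ['input_ids', 'attention_mask']
print(dropped)               # ['unused_feature']
```

As the comment in the diff hedges, this signature order appears to govern how the inputs are bound when the model is traced for ONNX export, which is why the sample inputs are rearranged before the export call rather than passed through in tokenizer order.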