From 0cab6d085dbeb06798f97eec3b825d661618ddfb Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Mon, 7 Feb 2022 14:14:43 -0500 Subject: [PATCH 1/2] Enforce order on input keys to export --- src/sparseml/transformers/export.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index 2bc63ef6873..e21fd24a7a3 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -55,6 +55,8 @@ """ import argparse +import collections +import inspect import logging import math import os @@ -180,6 +182,17 @@ def export_transformer_to_onnx( inputs = tokenizer( "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value ).data # Dict[Tensor] + + # Rearrange inputs' keys to match those defined by model foward func + forward_args_spec = inspect.getfullargspec(model.__class__.forward) + inputs = collections.OrderedDict( + [ + (f, inputs[f][0].reshape(1, -1)) + for f in forward_args_spec.args + if f in inputs + ] + ) + inputs_shapes = { key: ( f"{val.dtype if hasattr(val, 'dtype') else 'unknown'}: " @@ -187,6 +200,7 @@ def export_transformer_to_onnx( ) for key, val in inputs.items() } + _LOGGER.info(f"Created sample inputs for the ONNX export process: {inputs_shapes}") # run export From 3aeea54accf1f8fb81d02c3fa19407e5d3340586 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Tue, 8 Feb 2022 12:01:49 -0500 Subject: [PATCH 2/2] Warn if input dropped from onnx export --- src/sparseml/transformers/export.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py index e21fd24a7a3..f42cd44a668 100644 --- a/src/sparseml/transformers/export.py +++ b/src/sparseml/transformers/export.py @@ -183,8 +183,10 @@ def export_transformer_to_onnx( "", return_tensors="pt", padding=PaddingStrategy.MAX_LENGTH.value ).data # Dict[Tensor] - # Rearrange inputs' keys to match those defined by model foward func + # Rearrange inputs' keys to match those defined by model foward func, which + # seem to define how the order of inputs is determined in the exported model forward_args_spec = inspect.getfullargspec(model.__class__.forward) + dropped = [f for f in inputs.keys() if f not in forward_args_spec.args] inputs = collections.OrderedDict( [ (f, inputs[f][0].reshape(1, -1)) @@ -192,6 +194,11 @@ def export_transformer_to_onnx( if f in inputs ] ) + if dropped: + _LOGGER.warning( + "The following inputs were not present in the model forward function " + f"and therefore dropped from ONNX export: {dropped}" + ) inputs_shapes = { key: (