From 4bd73696132f966ad816f7068cde6cc5e52750d8 Mon Sep 17 00:00:00 2001
From: Damian
Date: Tue, 28 Nov 2023 09:20:15 +0000
Subject: [PATCH 1/2] initial commit

---
 src/sparseml/transformers/export.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index a159219970a..b33676c1c1a 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -297,10 +297,20 @@ def export_transformer_to_onnx(
     )
 
     if sequence_length is None:
+        if hasattr(config, "max_position_embeddings"):
+            sequence_length = config.max_position_embeddings
+        elif hasattr(config, "max_seq_len"):
+            sequence_length = config.max_seq_len
+        else:
+            raise ValueError(
+                "Could not infer a default sequence length "
+                "from the HF transformers config. Please specify "
+                "a sequence length with --sequence_length"
+            )
         _LOGGER.info(
-            f"Using default sequence length of {config.max_position_embeddings}"
+            f"Using default sequence length of {sequence_length} "
+            "(inferred from HF transformers config) "
         )
-        sequence_length = config.max_position_embeddings
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length

From 3cc37c99b21ab66fb248740c8e444d5d926ef303 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Tue, 28 Nov 2023 10:30:03 +0100
Subject: [PATCH 2/2] Update src/sparseml/transformers/export.py

---
 src/sparseml/transformers/export.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index b33676c1c1a..721426b3c2e 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -305,7 +305,7 @@ def export_transformer_to_onnx(
             raise ValueError(
                 "Could not infer a default sequence length "
                 "from the HF transformers config. Please specify "
-                "a sequence length with --sequence_length"
+                "the sequence length with --sequence_length"
             )
         _LOGGER.info(
             f"Using default sequence length of {sequence_length} "
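
Note (not part of the patches above): the two commits amount to a fallback for choosing a default sequence length from the Hugging Face config when the user does not pass one. Below is a minimal standalone sketch of that logic; the helper name infer_sequence_length and the SimpleNamespace mock configs are illustrative only and do not exist in export.py.

from types import SimpleNamespace


def infer_sequence_length(config, sequence_length=None):
    # Hypothetical helper mirroring the fallback added in PATCH 1/2:
    # prefer an explicit value, then max_position_embeddings,
    # then max_seq_len, otherwise fail with a clear message.
    if sequence_length is not None:
        return sequence_length
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    if hasattr(config, "max_seq_len"):
        return config.max_seq_len
    raise ValueError(
        "Could not infer a default sequence length "
        "from the HF transformers config. Please specify "
        "the sequence length with --sequence_length"
    )


# Usage with mocked configs; the real code reads AutoConfig.from_pretrained(model_path).
bert_like = SimpleNamespace(max_position_embeddings=512)
mpt_like = SimpleNamespace(max_seq_len=2048)
print(infer_sequence_length(bert_like))  # 512
print(infer_sequence_length(mpt_like))   # 2048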