diff --git a/src/sparseml/transformers/export.py b/src/sparseml/transformers/export.py
index a159219970a..721426b3c2e 100644
--- a/src/sparseml/transformers/export.py
+++ b/src/sparseml/transformers/export.py
@@ -297,10 +297,20 @@ def export_transformer_to_onnx(
     )
 
     if sequence_length is None:
+        if hasattr(config, "max_position_embeddings"):
+            sequence_length = config.max_position_embeddings
+        elif hasattr(config, "max_seq_len"):
+            sequence_length = config.max_seq_len
+        else:
+            raise ValueError(
+                "Could not infer a default sequence length "
+                "from the HF transformers config. Please specify "
+                "the sequence length with --sequence_length"
+            )
         _LOGGER.info(
-            f"Using default sequence length of {config.max_position_embeddings}"
+            f"Using default sequence length of {sequence_length} "
+            "(inferred from HF transformers config) "
         )
-        sequence_length = config.max_position_embeddings
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length
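
For reference, a minimal standalone sketch of the same fallback order introduced above (assumes the `transformers` package is installed; `infer_default_sequence_length` and the example model name are illustrative and not part of this PR):

from transformers import AutoConfig


def infer_default_sequence_length(model_path: str) -> int:
    """Mirror the fallback order from the diff: prefer max_position_embeddings,
    then max_seq_len, otherwise fail and ask for an explicit --sequence_length."""
    config = AutoConfig.from_pretrained(model_path)
    if hasattr(config, "max_position_embeddings"):
        return config.max_position_embeddings
    if hasattr(config, "max_seq_len"):
        return config.max_seq_len
    raise ValueError(
        "Could not infer a default sequence length from the HF transformers "
        "config. Please specify the sequence length with --sequence_length"
    )


# BERT-style configs expose max_position_embeddings (512 for bert-base-uncased)
print(infer_default_sequence_length("bert-base-uncased"))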