From 42d0fbc05bdf5ac42405955f35ac475727ddf2e1 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Tue, 10 Oct 2023 22:34:03 +0300 Subject: [PATCH 01/13] Dynamically set max_new_tokens based on output feature length, global max sequence length and model window size --- ludwig/schema/model_types/base.py | 6 +- ludwig/schema/model_types/utils.py | 76 ++++++++++++++++++++++-- tests/ludwig/schema/test_model_config.py | 59 ++++++++++++++++++ 3 files changed, 134 insertions(+), 7 deletions(-) diff --git a/ludwig/schema/model_types/base.py b/ludwig/schema/model_types/base.py index b3ebc154ffb..410aa5c454e 100644 --- a/ludwig/schema/model_types/base.py +++ b/ludwig/schema/model_types/base.py @@ -30,7 +30,7 @@ sanitize_and_filter_combiner_entities_, set_derived_feature_columns_, set_hyperopt_defaults_, - set_llm_tokenizers, + set_llm_parameters, set_preprocessing_parameters, set_tagger_decoder_parameters, set_validation_parameters, @@ -69,8 +69,8 @@ def __post_init__(self): set_tagger_decoder_parameters(self) sanitize_and_filter_combiner_entities_(self) - # Set preprocessing parameters for text features for LLM model type - set_llm_tokenizers(self) + # Reconcile LLM parameters + set_llm_parameters(self) # Reconcile conflicting preprocessing parameters set_preprocessing_parameters(self) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index a38a3d53409..84b3b8d89f4 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -17,6 +17,7 @@ INPUT_FEATURES, LOSS, MODEL_ECD, + MODEL_LLM, OUTPUT_FEATURES, PARAMETERS, PREPROCESSING, @@ -299,16 +300,24 @@ def set_tagger_decoder_parameters(config: "ModelConfig") -> None: output_feature.reduce_input = None -def set_llm_tokenizers(config: "ModelConfig") -> None: +def set_llm_parameters(config: "ModelConfig") -> None: + if config.model_type != MODEL_LLM: + return + + # Set preprocessing parameters for text features for LLM model type + _set_llm_tokenizers(config) + + # Set max_new_tokens in generation config to the max sequence length of the output features + _set_generation_max_new_tokens(config) + + +def _set_llm_tokenizers(config: "ModelConfig") -> None: """Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the correct shared vocabulary from the tokenizer. This also ensures padding is correctly set to left padding to prevent the LLM from trying to continue to sequence based on the right padding tokens, which might exist based on sequence length. """ - if config.model_type != "llm": - return - pretrained_model_name_or_path = config.base_model if not isinstance(pretrained_model_name_or_path, str) or pretrained_model_name_or_path is None: raise ValueError("Must set `base_model` when using the LLM model.") @@ -337,6 +346,65 @@ def set_llm_tokenizers(config: "ModelConfig") -> None: output_feature.decoder.fallback_label = output_feature.preprocessing.fallback_label +def _set_generation_max_new_tokens(config: "ModelConfig") -> None: + """Sets the max_new_tokens parameter in the generation config to the max sequence length of the output + features. + + This ensures that the generation config is set to the correct value for the LLM model type. + """ + from transformers import AutoConfig + + from ludwig.schema.llms.generation import LLMGenerationConfig + + default_max_sequence_length = LLMGenerationConfig().max_new_tokens + if config.generation.max_new_tokens != default_max_sequence_length: + # Max new tokens is explicitly set by user, so don't override + return + + if config.output_features[0].type != TEXT: + # This is trickier to set for other output features, so don't override for now. + # TODO: Add support for other category features + return + + max_possible_sequence_length = default_max_sequence_length + if config.output_features[0].preprocessing.max_sequence_length is not None: + # Note: We don't need to check for max between feature.preprocessing.max_sequence_length and + # defaults.text.preprocessing.max_sequence_length because the latter is only applied to input features. + max_possible_sequence_length = max( + default_max_sequence_length, config.output_features[0].preprocessing.max_sequence_length + ) + if config.preprocessing.global_max_sequence_length is not None: + # This is not perfect since it includes tokens from both input + output features, but this at least + # ensures that max possible of the sequence length is used. It is very likely that the model learns + # to generate sequences than this value. + max_possible_sequence_length = max( + max_possible_sequence_length, config.preprocessing.global_max_sequence_length + ) + + # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case + # we should fall back to the window size of the pretrained model. By this point, because of schema validation + # checks, we know that the base_model exists so we can safely grab the base model's config. + if max_possible_sequence_length == default_max_sequence_length: + model_config = AutoConfig.from_pretrained(config.base_model) + # Determines the maximum length of the context (input + output tokens) + if hasattr(model_config, "max_sequence_length"): + max_possible_sequence_length = model_config.max_sequence_length + elif hasattr(model_config, "max_position_embeddings"): + max_possible_sequence_length = model_config.max_position_embeddings + else: + # Fallback to 2048 for now. + # TODO: Determine this dynamically + max_possible_sequence_length = 2048 + + logger.info( + f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " + "sequence length assigned to the output feature or the global max sequence length. This will ensure that " + "the correct number of tokens are generated at inference time. To override this behavior, set " + "`generation.max_new_tokens` to a different value in your Ludwig config." + ) + config.generation.max_new_tokens = max_possible_sequence_length + + @DeveloperAPI def contains_grid_search_parameters(hyperopt_config: HyperoptConfigDict) -> bool: """Returns True if any hyperopt parameter in the config is using the grid_search space.""" diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py index 054a274b378..41060dab94b 100644 --- a/tests/ludwig/schema/test_model_config.py +++ b/tests/ludwig/schema/test_model_config.py @@ -950,3 +950,62 @@ def test_llm_quantization_backend_compatibility(): ModelConfig.from_dict(config) ray.shutdown() + + +def test_max_new_tokens_override_no_changes_to_max_new_tokens(): + """Tests that the default value for max_new_tokens is respected when explicitly set in the config.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + "generation": {"max_new_tokens": 64}, + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 64 + + +def test_max_new_tokens_override_large_max_sequence_length(): + """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large + value than the default max_new_tokens.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}], + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 100 + + +def test_max_new_tokens_override_large_global_max_sequence_length(): + """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to a + larger value than the default max_new_tokens.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + PREPROCESSING: {"global_max_sequence_length": 100}, + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 100 + + +def test_max_new_tokens_override_fallback_to_model_window_size(): + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 2048 From d9c880519e30857cb966ba4a4b8bcd9130f4464f Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Tue, 10 Oct 2023 22:40:44 +0300 Subject: [PATCH 02/13] TODO comment --- ludwig/schema/model_types/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index 84b3b8d89f4..946eb12f327 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -384,6 +384,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case # we should fall back to the window size of the pretrained model. By this point, because of schema validation # checks, we know that the base_model exists so we can safely grab the base model's config. + # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation. if max_possible_sequence_length == default_max_sequence_length: model_config = AutoConfig.from_pretrained(config.base_model) # Determines the maximum length of the context (input + output tokens) From 00a97df00ad4fffafe7329685bd494173f209ef9 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Tue, 10 Oct 2023 22:57:54 +0300 Subject: [PATCH 03/13] Fix comment' --- ludwig/schema/model_types/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index 946eb12f327..31ec3f569a9 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -363,7 +363,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: if config.output_features[0].type != TEXT: # This is trickier to set for other output features, so don't override for now. - # TODO: Add support for other category features + # TODO: Add better support for category output features return max_possible_sequence_length = default_max_sequence_length From 75988063bdc1222e3935f23cba1cb746001fe524 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 01:46:17 +0300 Subject: [PATCH 04/13] Address PR comments --- ludwig/models/llm.py | 9 ++---- ludwig/schema/model_types/utils.py | 11 ++----- ludwig/utils/llm_utils.py | 43 ++++++++++++++++++++++++++++ tests/ludwig/utils/test_llm_utils.py | 27 ++++++++++++++++- 4 files changed, 73 insertions(+), 17 deletions(-) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index bf43d025531..47b55f19fb2 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -21,6 +21,7 @@ from ludwig.utils.llm_utils import ( add_left_padding, generate_merged_ids, + get_context_len, pad_target_tensor_for_fine_tuning, realign_target_and_prediction_tensors_for_inference, remove_left_padding, @@ -126,13 +127,7 @@ def __init__( self.curr_device = next(self.model.parameters()).device logger.info("Done.") - # Determines the maximum length of the context (input + output tokens) - if hasattr(self.model_config, "max_sequence_length"): - self.context_len = self.model_config.max_sequence_length - elif hasattr(self.model_config, "max_position_embeddings"): - self.context_len = self.model_config.max_position_embeddings - else: - self.context_len = 2048 + self.context_len = get_context_len(self.model_config) # TODO(Arnav): This needs be more flexible to account for RoPE Scaling # When merging input IDs and target IDs for LLM fine-tuning, we want to make sure that the merged tensor is diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index 31ec3f569a9..dd73b53e41c 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -355,6 +355,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: from transformers import AutoConfig from ludwig.schema.llms.generation import LLMGenerationConfig + from ludwig.utils.llm_utils import get_context_len default_max_sequence_length = LLMGenerationConfig().max_new_tokens if config.generation.max_new_tokens != default_max_sequence_length: @@ -387,15 +388,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation. if max_possible_sequence_length == default_max_sequence_length: model_config = AutoConfig.from_pretrained(config.base_model) - # Determines the maximum length of the context (input + output tokens) - if hasattr(model_config, "max_sequence_length"): - max_possible_sequence_length = model_config.max_sequence_length - elif hasattr(model_config, "max_position_embeddings"): - max_possible_sequence_length = model_config.max_position_embeddings - else: - # Fallback to 2048 for now. - # TODO: Determine this dynamically - max_possible_sequence_length = 2048 + max_possible_sequence_length = get_context_len(model_config) logger.info( f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " diff --git a/ludwig/utils/llm_utils.py b/ludwig/utils/llm_utils.py index 12c0a1f509f..1dd9d1df85a 100644 --- a/ludwig/utils/llm_utils.py +++ b/ludwig/utils/llm_utils.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from bitsandbytes.nn.modules import Embedding from transformers import ( + AutoConfig, AutoModelForCausalLM, CodeLlamaTokenizer, CodeLlamaTokenizerFast, @@ -22,6 +23,9 @@ logger = logging.getLogger(__name__) +FALLBACK_CONTEXT_LEN = 2048 + + def set_pad_token(tokenizer: PreTrainedTokenizer): """Sets the pad token for the tokenizer if it is not already set. @@ -57,6 +61,45 @@ def set_pad_token(tokenizer: PreTrainedTokenizer): tokenizer.pad_token_id = tokenizer.eos_token_id +def get_context_len(model_config: AutoConfig): + """Determines the maximum length of the context (input + output tokens) based on the provided model + configuration. + + Args: + model_config (AutoConfig): The model configuration object containing information about the model's properties. + + Returns: + int: The maximum context length, which can be derived from the model configuration. If no relevant attribute + is found, the default value of 2048 is returned. + + This function examines the provided model configuration object to identify the attribute that specifies the maximum + context length. It checks for attributes in the following order of preference: + 1. 'max_sequence_length': If this attribute is present in the model configuration, its value is returned. + 2. 'max_position_embeddings': If 'max_sequence_length' is not found but 'max_position_embeddings' is present, its + value is returned. + 3. 'n_positions': If neither 'max_sequence_length' nor 'max_position_embeddings' are found, and 'n_positions' is + present, its value is returned. + 4. Default: If none of the relevant attributes are present, the function returns a default value of 2048. + + Note: + - The maximum context length is important for defining the size of input and output sequences in a model. + + Example Usage: + >>> config = AutoConfig.from_pretrained("bert-base-uncased") + >>> context_len = get_context_len(config) + >>> print(context_len) + 512 + """ + if hasattr(model_config, "max_sequence_length"): + return model_config.max_sequence_length + elif hasattr(model_config, "max_position_embeddings"): + return model_config.max_position_embeddings + elif hasattr(model_config, "n_positions"): + return model_config.n_positions + else: + return FALLBACK_CONTEXT_LEN + + def has_padding_token(input_tensor: torch.Tensor, tokenizer: PreTrainedTokenizer): """Checks if the input tensor contains any padding tokens. diff --git a/tests/ludwig/utils/test_llm_utils.py b/tests/ludwig/utils/test_llm_utils.py index e652297d285..d79264bf26a 100644 --- a/tests/ludwig/utils/test_llm_utils.py +++ b/tests/ludwig/utils/test_llm_utils.py @@ -1,13 +1,15 @@ import pytest import torch -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from ludwig.constants import LOGITS, PREDICTIONS, PROBABILITIES from ludwig.utils.llm_utils import ( add_left_padding, create_attention_mask, + FALLBACK_CONTEXT_LEN, find_last_matching_index, generate_merged_ids, + get_context_len, has_padding_token, pad_target_tensor_for_fine_tuning, realign_target_and_prediction_tensors_for_inference, @@ -57,6 +59,29 @@ def test_set_pad_token_already_exists(): assert tokenizer.pad_token_id == 1 +class TestSetContextLen: + def test_max_sequence_length(self): + # Test when 'max_sequence_length' is present in the model configuration + config = AutoConfig.from_pretrained("huggyllama/llama-7b") + assert get_context_len(config) == config.max_sequence_length + + def test_max_position_embeddings(self): + # Test when 'max_position_embeddings' is present in the model configuration + config = AutoConfig.from_pretrained("huggyllama/llama-7b") + del config.max_sequence_length + assert get_context_len(config) == config.max_position_embeddings + + def test_n_positions(self): + # Test when 'n_positions' is present in the model configuration + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM") + assert get_context_len(config) == config.n_positions + + def test_default_value(self): + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM") + del config.n_positions + assert get_context_len(config) == FALLBACK_CONTEXT_LEN + + def test_has_padding_token_with_padding_tokens(tokenizer): input_sentence = "This is an example sentence." input_ids = tokenizer([input_sentence]) From a57db6ffda5a48eb2a3063ef62880b95b08a9ec3 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 02:07:18 +0300 Subject: [PATCH 05/13] Fix failing tests --- ludwig/models/llm.py | 9 ++++++--- ludwig/schema/model_types/utils.py | 27 +++++++++++++++++++-------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 47b55f19fb2..ed7383c351d 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -206,10 +206,13 @@ def _set_generation_config(self, new_generation_config_dict: Dict[str, Any]): # CodeLlama to avoid getting an error. This workaround can be found here: # (https://github.com/huggingface/transformers/issues/25353#issuecomment-1669339754) self.generation.pad_token_id = self.tokenizer.pad_token_id - self.max_new_tokens = self.generation.max_new_tokens + self.max_new_tokens = self.generation.max_new_tokens or self.generation.max_length # max input length value copied from FastChat - # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L183 # noqa E501 - self.max_input_length = self.context_len - self.max_new_tokens - 8 + # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L180 # noqa E501 + if self.model_config.is_encoder_decoder: + self.max_input_length = self.context_len - self.max_new_tokens - 8 + else: + self.max_input_length = self.context_len @property def output_feature_decoder(self) -> OutputFeature: diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index dd73b53e41c..b7f338ffa07 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -389,14 +389,25 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: if max_possible_sequence_length == default_max_sequence_length: model_config = AutoConfig.from_pretrained(config.base_model) max_possible_sequence_length = get_context_len(model_config) - - logger.info( - f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " - "sequence length assigned to the output feature or the global max sequence length. This will ensure that " - "the correct number of tokens are generated at inference time. To override this behavior, set " - "`generation.max_new_tokens` to a different value in your Ludwig config." - ) - config.generation.max_new_tokens = max_possible_sequence_length + # Max length only works if max_new_tokens is not set. + # If max_new_tokens is set, then we need to set max_length to None to ensure that the correct number of tokens + # are generated (input + output tokens), otherwise we will exceed the bounds of generation resulting in errors. + config.generation.max_new_tokens = None + config.generation.max_length = max_possible_sequence_length + logger.info( + f"Setting generation max_length to {max_possible_sequence_length} to correspond with the max sequence " + "length assigned to the output feature or the global max sequence length. This will ensure that the " + "correct number of tokens are generated at inference time. To override this behavior, set " + "`generation.max_length` to a different value in your Ludwig config." + ) + else: + logger.info( + f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " + "sequence length assigned to the output feature or the global max sequence length. This will ensure that " + "the correct number of tokens are generated at inference time. To override this behavior, set " + "`generation.max_new_tokens` to a different value in your Ludwig config." + ) + config.generation.max_new_tokens = max_possible_sequence_length @DeveloperAPI From 1cec86f93adeba37f581061d36b9bcf64801ca66 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 02:09:50 +0300 Subject: [PATCH 06/13] Fix tests --- tests/ludwig/schema/test_model_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py index 41060dab94b..2813540572e 100644 --- a/tests/ludwig/schema/test_model_config.py +++ b/tests/ludwig/schema/test_model_config.py @@ -1008,4 +1008,5 @@ def test_max_new_tokens_override_fallback_to_model_window_size(): } config_obj = ModelConfig.from_dict(config) - assert config_obj.generation.max_new_tokens == 2048 + assert config_obj.generation.max_new_tokens is None + assert config_obj.generation.max_length == 2048 From 58406b7b940742007b09306be3682b578798f247 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 02:32:17 +0300 Subject: [PATCH 07/13] Revert to base --- ludwig/models/llm.py | 10 +++----- ludwig/schema/model_types/utils.py | 39 +++++++++++++++--------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index ed7383c351d..85ad170dc43 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -206,13 +206,11 @@ def _set_generation_config(self, new_generation_config_dict: Dict[str, Any]): # CodeLlama to avoid getting an error. This workaround can be found here: # (https://github.com/huggingface/transformers/issues/25353#issuecomment-1669339754) self.generation.pad_token_id = self.tokenizer.pad_token_id - self.max_new_tokens = self.generation.max_new_tokens or self.generation.max_length + self.max_new_tokens = self.generation.max_new_tokens # max input length value copied from FastChat - # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L180 # noqa E501 - if self.model_config.is_encoder_decoder: - self.max_input_length = self.context_len - self.max_new_tokens - 8 - else: - self.max_input_length = self.context_len + # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L183 # noqa E501 + self.max_input_length = self.context_len - self.max_new_tokens - 8 + assert self.max_input_length > 0 @property def output_feature_decoder(self) -> OutputFeature: diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index b7f338ffa07..b167e54a84d 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -389,25 +389,26 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: if max_possible_sequence_length == default_max_sequence_length: model_config = AutoConfig.from_pretrained(config.base_model) max_possible_sequence_length = get_context_len(model_config) - # Max length only works if max_new_tokens is not set. - # If max_new_tokens is set, then we need to set max_length to None to ensure that the correct number of tokens - # are generated (input + output tokens), otherwise we will exceed the bounds of generation resulting in errors. - config.generation.max_new_tokens = None - config.generation.max_length = max_possible_sequence_length - logger.info( - f"Setting generation max_length to {max_possible_sequence_length} to correspond with the max sequence " - "length assigned to the output feature or the global max sequence length. This will ensure that the " - "correct number of tokens are generated at inference time. To override this behavior, set " - "`generation.max_length` to a different value in your Ludwig config." - ) - else: - logger.info( - f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " - "sequence length assigned to the output feature or the global max sequence length. This will ensure that " - "the correct number of tokens are generated at inference time. To override this behavior, set " - "`generation.max_new_tokens` to a different value in your Ludwig config." - ) - config.generation.max_new_tokens = max_possible_sequence_length + # # Max length only works if max_new_tokens is not set. + # # If max_new_tokens is set, then we need to set max_length to None to ensure that the correct number of tokens + # # are generated (input + output tokens), otherwise we will exceed the bounds of generation + # # resulting in errors. + # config.generation.max_new_tokens = max_possible_sequence_length - 8 + # config.generation.max_length = max_possible_sequence_length + # logger.info( + # f"Setting generation max_length to {max_possible_sequence_length} to correspond with the max sequence " + # "length assigned to the output feature or the global max sequence length. This will ensure that the " + # "correct number of tokens are generated at inference time. To override this behavior, set " + # "`generation.max_length` to a different value in your Ludwig config." + # ) + # else: + logger.info( + f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " + "sequence length assigned to the output feature or the global max sequence length. This will ensure that " + "the correct number of tokens are generated at inference time. To override this behavior, set " + "`generation.max_new_tokens` to a different value in your Ludwig config." + ) + config.generation.max_new_tokens = max_possible_sequence_length @DeveloperAPI From f6adf0849db4cd76f2c9a203b85bba6df340f57d Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 02:36:35 +0300 Subject: [PATCH 08/13] Add temporary buffer --- ludwig/schema/model_types/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index b167e54a84d..7a9427e28b1 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -389,6 +389,9 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: if max_possible_sequence_length == default_max_sequence_length: model_config = AutoConfig.from_pretrained(config.base_model) max_possible_sequence_length = get_context_len(model_config) + # Artifically leave a buffer for now to prevent the following error: + # RuntimeError: index 512 is out of bounds for dimension 1 with size 512 + max_possible_sequence_length -= 32 # # Max length only works if max_new_tokens is not set. # # If max_new_tokens is set, then we need to set max_length to None to ensure that the correct number of tokens # # are generated (input + output tokens), otherwise we will exceed the bounds of generation From 27011e0b11d0185aa6b2cb43b8fdb1d13638814f Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 21:05:47 +0200 Subject: [PATCH 09/13] Fix issues by defaulting to model context len // 2 --- ludwig/models/llm.py | 3 +-- ludwig/schema/model_types/utils.py | 30 +++++++----------------- tests/ludwig/schema/test_model_config.py | 11 ++++----- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index 24a3ec8b7c9..9582a6c33b1 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -208,9 +208,8 @@ def _set_generation_config(self, new_generation_config_dict: Dict[str, Any]): self.generation.pad_token_id = self.tokenizer.pad_token_id self.max_new_tokens = self.generation.max_new_tokens # max input length value copied from FastChat - # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L183 # noqa E501 + # https://github.com/lm-sys/FastChat/blob/0e958b852a14f4bef5f0e9d7a5e7373477329cf2/fastchat/serve/inference.py#L183 # noqa E501 self.max_input_length = self.context_len - self.max_new_tokens - 8 - assert self.max_input_length > 0 @property def output_feature_decoder(self) -> OutputFeature: diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index 7a9427e28b1..e8d2ad58da2 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -357,8 +357,8 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: from ludwig.schema.llms.generation import LLMGenerationConfig from ludwig.utils.llm_utils import get_context_len - default_max_sequence_length = LLMGenerationConfig().max_new_tokens - if config.generation.max_new_tokens != default_max_sequence_length: + _DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens + if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH: # Max new tokens is explicitly set by user, so don't override return @@ -367,12 +367,12 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: # TODO: Add better support for category output features return - max_possible_sequence_length = default_max_sequence_length + max_possible_sequence_length = _DEFAULT_MAX_SEQUENCE_LENGTH if config.output_features[0].preprocessing.max_sequence_length is not None: # Note: We don't need to check for max between feature.preprocessing.max_sequence_length and # defaults.text.preprocessing.max_sequence_length because the latter is only applied to input features. max_possible_sequence_length = max( - default_max_sequence_length, config.output_features[0].preprocessing.max_sequence_length + _DEFAULT_MAX_SEQUENCE_LENGTH, config.output_features[0].preprocessing.max_sequence_length ) if config.preprocessing.global_max_sequence_length is not None: # This is not perfect since it includes tokens from both input + output features, but this at least @@ -386,25 +386,13 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: # we should fall back to the window size of the pretrained model. By this point, because of schema validation # checks, we know that the base_model exists so we can safely grab the base model's config. # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation. - if max_possible_sequence_length == default_max_sequence_length: + if max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH: model_config = AutoConfig.from_pretrained(config.base_model) max_possible_sequence_length = get_context_len(model_config) - # Artifically leave a buffer for now to prevent the following error: - # RuntimeError: index 512 is out of bounds for dimension 1 with size 512 - max_possible_sequence_length -= 32 - # # Max length only works if max_new_tokens is not set. - # # If max_new_tokens is set, then we need to set max_length to None to ensure that the correct number of tokens - # # are generated (input + output tokens), otherwise we will exceed the bounds of generation - # # resulting in errors. - # config.generation.max_new_tokens = max_possible_sequence_length - 8 - # config.generation.max_length = max_possible_sequence_length - # logger.info( - # f"Setting generation max_length to {max_possible_sequence_length} to correspond with the max sequence " - # "length assigned to the output feature or the global max sequence length. This will ensure that the " - # "correct number of tokens are generated at inference time. To override this behavior, set " - # "`generation.max_length` to a different value in your Ludwig config." - # ) - # else: + # Artifically leave a buffer of half the total model window size to trade off + # runtime while likely covering a majority of the max sequence length. + max_possible_sequence_length = max_possible_sequence_length // 2 + logger.info( f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max " "sequence length assigned to the output feature or the global max sequence length. This will ensure that " diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py index 2813540572e..ad271b60d00 100644 --- a/tests/ludwig/schema/test_model_config.py +++ b/tests/ludwig/schema/test_model_config.py @@ -958,7 +958,7 @@ def test_max_new_tokens_override_no_changes_to_max_new_tokens(): MODEL_TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for max_sequence_length is 32 + # Default value for generation.max_sequence_length is 32 OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], "generation": {"max_new_tokens": 64}, } @@ -974,7 +974,7 @@ def test_max_new_tokens_override_large_max_sequence_length(): MODEL_TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for max_sequence_length is 32 + # Default value for generation.max_sequence_length is 32 OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}], } @@ -989,7 +989,7 @@ def test_max_new_tokens_override_large_global_max_sequence_length(): MODEL_TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for max_sequence_length is 32 + # Default value for generation.max_sequence_length is 32 OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], PREPROCESSING: {"global_max_sequence_length": 100}, } @@ -1003,10 +1003,9 @@ def test_max_new_tokens_override_fallback_to_model_window_size(): MODEL_TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for max_sequence_length is 32 + # Default value for generation.max_sequence_length is 32 OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], } config_obj = ModelConfig.from_dict(config) - assert config_obj.generation.max_new_tokens is None - assert config_obj.generation.max_length == 2048 + assert config_obj.generation.max_new_tokens == 1024 From a8a2cc8005b7b528cb90f10a2128f718ebe992f5 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 21:08:27 +0200 Subject: [PATCH 10/13] Add better comment for test --- tests/ludwig/schema/test_model_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py index ad271b60d00..4165a8801e4 100644 --- a/tests/ludwig/schema/test_model_config.py +++ b/tests/ludwig/schema/test_model_config.py @@ -1008,4 +1008,6 @@ def test_max_new_tokens_override_fallback_to_model_window_size(): } config_obj = ModelConfig.from_dict(config) + # Base model context length is 2048 tokens by default + # Since we fallback to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens assert config_obj.generation.max_new_tokens == 1024 From 4d198d44a1d5722bc8956de6e5609c3a2371c63e Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 21:11:30 +0200 Subject: [PATCH 11/13] Group tests under a single test class --- tests/ludwig/schema/test_model_config.py | 116 +++++++++++------------ 1 file changed, 57 insertions(+), 59 deletions(-) diff --git a/tests/ludwig/schema/test_model_config.py b/tests/ludwig/schema/test_model_config.py index 4165a8801e4..21e2883b989 100644 --- a/tests/ludwig/schema/test_model_config.py +++ b/tests/ludwig/schema/test_model_config.py @@ -952,62 +952,60 @@ def test_llm_quantization_backend_compatibility(): ray.shutdown() -def test_max_new_tokens_override_no_changes_to_max_new_tokens(): - """Tests that the default value for max_new_tokens is respected when explicitly set in the config.""" - config = { - MODEL_TYPE: MODEL_LLM, - BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", - INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for generation.max_sequence_length is 32 - OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], - "generation": {"max_new_tokens": 64}, - } - - config_obj = ModelConfig.from_dict(config) - assert config_obj.generation.max_new_tokens == 64 - - -def test_max_new_tokens_override_large_max_sequence_length(): - """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large - value than the default max_new_tokens.""" - config = { - MODEL_TYPE: MODEL_LLM, - BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", - INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for generation.max_sequence_length is 32 - OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}], - } - - config_obj = ModelConfig.from_dict(config) - assert config_obj.generation.max_new_tokens == 100 - - -def test_max_new_tokens_override_large_global_max_sequence_length(): - """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to a - larger value than the default max_new_tokens.""" - config = { - MODEL_TYPE: MODEL_LLM, - BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", - INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for generation.max_sequence_length is 32 - OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], - PREPROCESSING: {"global_max_sequence_length": 100}, - } - - config_obj = ModelConfig.from_dict(config) - assert config_obj.generation.max_new_tokens == 100 - - -def test_max_new_tokens_override_fallback_to_model_window_size(): - config = { - MODEL_TYPE: MODEL_LLM, - BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", - INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], - # Default value for generation.max_sequence_length is 32 - OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], - } - - config_obj = ModelConfig.from_dict(config) - # Base model context length is 2048 tokens by default - # Since we fallback to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens - assert config_obj.generation.max_new_tokens == 1024 +class TestMaxNewTokensOverride: + def test_max_new_tokens_override_no_changes_to_max_new_tokens(self): + """Tests that the default value for max_new_tokens is respected when explicitly set in the config.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for generation.max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + "generation": {"max_new_tokens": 64}, + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 64 + + def test_max_new_tokens_override_large_max_sequence_length(self): + """Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large + value than the default max_new_tokens.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for generation.max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}], + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 100 + + def test_max_new_tokens_override_large_global_max_sequence_length(self): + """Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to + a larger value than the default max_new_tokens.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for generation.max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + PREPROCESSING: {"global_max_sequence_length": 100}, + } + + config_obj = ModelConfig.from_dict(config) + assert config_obj.generation.max_new_tokens == 100 + + def test_max_new_tokens_override_fallback_to_model_window_size(self): + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}], + # Default value for generation.max_sequence_length is 32 + OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}], + } + + config_obj = ModelConfig.from_dict(config) + # Base model context length is 2048 tokens by default + # Since we fallback to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens + assert config_obj.generation.max_new_tokens == 1024 From 3b5ec39fc10608997506cf225054340374e480e7 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Wed, 11 Oct 2023 21:21:15 +0200 Subject: [PATCH 12/13] Refactor --- ludwig/schema/model_types/utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index e8d2ad58da2..6df01f9125d 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -374,19 +374,18 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: max_possible_sequence_length = max( _DEFAULT_MAX_SEQUENCE_LENGTH, config.output_features[0].preprocessing.max_sequence_length ) - if config.preprocessing.global_max_sequence_length is not None: + elif config.preprocessing.global_max_sequence_length is not None: # This is not perfect since it includes tokens from both input + output features, but this at least # ensures that max possible of the sequence length is used. It is very likely that the model learns # to generate sequences than this value. max_possible_sequence_length = max( max_possible_sequence_length, config.preprocessing.global_max_sequence_length ) - - # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case - # we should fall back to the window size of the pretrained model. By this point, because of schema validation - # checks, we know that the base_model exists so we can safely grab the base model's config. - # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation. - if max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH: + elif max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH: + # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case + # we should fall back to the window size of the pretrained model. By this point, because of schema validation + # checks, we know that the base_model exists so we can safely grab the base model's config. + # TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation. model_config = AutoConfig.from_pretrained(config.base_model) max_possible_sequence_length = get_context_len(model_config) # Artifically leave a buffer of half the total model window size to trade off From d40b75f7546f26a3b5190765bcee663200dd5f57 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 12 Oct 2023 18:56:10 +0200 Subject: [PATCH 13/13] Address comments --- ludwig/schema/model_types/utils.py | 52 ++++++++++++++++-------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/ludwig/schema/model_types/utils.py b/ludwig/schema/model_types/utils.py index 6df01f9125d..b8550d06838 100644 --- a/ludwig/schema/model_types/utils.py +++ b/ludwig/schema/model_types/utils.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Set, TYPE_CHECKING from marshmallow import ValidationError +from transformers import AutoConfig from ludwig.api_annotations import DeveloperAPI from ludwig.constants import ( @@ -29,9 +30,11 @@ from ludwig.features.feature_utils import compute_feature_hash from ludwig.schema.features.utils import output_config_registry from ludwig.schema.hyperopt.scheduler import BaseHyperbandSchedulerConfig +from ludwig.schema.llms.generation import LLMGenerationConfig from ludwig.schema.trainer import ECDTrainerConfig from ludwig.types import HyperoptConfigDict, ModelConfigDict from ludwig.utils.data_utils import get_sanitized_feature_name +from ludwig.utils.llm_utils import get_context_len if TYPE_CHECKING: from ludwig.schema.model_types.base import ModelConfig @@ -346,33 +349,14 @@ def _set_llm_tokenizers(config: "ModelConfig") -> None: output_feature.decoder.fallback_label = output_feature.preprocessing.fallback_label -def _set_generation_max_new_tokens(config: "ModelConfig") -> None: - """Sets the max_new_tokens parameter in the generation config to the max sequence length of the output - features. - - This ensures that the generation config is set to the correct value for the LLM model type. - """ - from transformers import AutoConfig - - from ludwig.schema.llms.generation import LLMGenerationConfig - from ludwig.utils.llm_utils import get_context_len - - _DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens - if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH: - # Max new tokens is explicitly set by user, so don't override - return - - if config.output_features[0].type != TEXT: - # This is trickier to set for other output features, so don't override for now. - # TODO: Add better support for category output features - return - - max_possible_sequence_length = _DEFAULT_MAX_SEQUENCE_LENGTH +def _get_maximum_possible_sequence_length(config: "ModelConfig", default_max_sequence_length: int) -> int: + """Returns the maximum possible sequence length for the LLM model based on the model config.""" + max_possible_sequence_length = default_max_sequence_length if config.output_features[0].preprocessing.max_sequence_length is not None: # Note: We don't need to check for max between feature.preprocessing.max_sequence_length and # defaults.text.preprocessing.max_sequence_length because the latter is only applied to input features. max_possible_sequence_length = max( - _DEFAULT_MAX_SEQUENCE_LENGTH, config.output_features[0].preprocessing.max_sequence_length + default_max_sequence_length, config.output_features[0].preprocessing.max_sequence_length ) elif config.preprocessing.global_max_sequence_length is not None: # This is not perfect since it includes tokens from both input + output features, but this at least @@ -381,7 +365,7 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: max_possible_sequence_length = max( max_possible_sequence_length, config.preprocessing.global_max_sequence_length ) - elif max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH: + elif max_possible_sequence_length == default_max_sequence_length: # It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case # we should fall back to the window size of the pretrained model. By this point, because of schema validation # checks, we know that the base_model exists so we can safely grab the base model's config. @@ -391,6 +375,26 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None: # Artifically leave a buffer of half the total model window size to trade off # runtime while likely covering a majority of the max sequence length. max_possible_sequence_length = max_possible_sequence_length // 2 + return max_possible_sequence_length + + +def _set_generation_max_new_tokens(config: "ModelConfig") -> None: + """Sets the max_new_tokens parameter in the generation config to the max sequence length of the output + features. + + This ensures that the generation config is set to the correct value for the LLM model type. + """ + _DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens + if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH: + # Max new tokens is explicitly set by user, so don't override + return + + if config.output_features[0].type != TEXT: + # This is trickier to set for other output features, so don't override for now. + # TODO: Add better support for category output features + return + + max_possible_sequence_length = _get_maximum_possible_sequence_length(config, _DEFAULT_MAX_SEQUENCE_LENGTH) logger.info( f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max "