diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index 254af6dbab4..fc7319e3114 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -140,6 +140,7 @@ def __init__(
         self.max_new_tokens = self.config_obj.generation.max_new_tokens
         self.max_input_length = self.context_len - self.max_new_tokens - 8
 
+        # TODO(Arnav): This needs to be more flexible to account for RoPE scaling
         # When merging input IDs and target IDs for LLM fine-tuning, we want to make sure that the merged tensor is
         # not longer than the global maximum sequence length. This is provided in the preprocessing config. We never
         # want to exceed the maximum possible context length so we also check for that.
diff --git a/ludwig/schema/features/preprocessing/text.py b/ludwig/schema/features/preprocessing/text.py
index d6e48a261d4..27b51f60ba1 100644
--- a/ludwig/schema/features/preprocessing/text.py
+++ b/ludwig/schema/features/preprocessing/text.py
@@ -151,12 +151,19 @@ class TextPreprocessingConfig(BaseTextPreprocessingConfig):
 
 
 @DeveloperAPI
-@register_preprocessor("text_llm")
+@register_preprocessor("text_llm_input")
 @ludwig_dataclass
-class LLMTextPreprocessingConfig(BaseTextPreprocessingConfig):
+class LLMTextInputPreprocessingConfig(BaseTextPreprocessingConfig):
     """LLMs require the prompt to be provided at the top-level, not preprocessing."""
 
-    pass
+    max_sequence_length: int = schema_utils.PositiveInteger(
+        default=None,
+        allow_none=True,
+        description="The maximum length (number of tokens) of the sequence. Sequences that are longer than this value "
+        "will be truncated. Useful as a stopgap measure if `sequence_length` is set to `None`. If `None`, max sequence "
+        "length will be inferred from the training dataset.",
+        parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["max_sequence_length"],
+    )
 
 
 @DeveloperAPI
@@ -217,3 +224,17 @@ class TextOutputPreprocessingConfig(BaseTextPreprocessingConfig):
         description="The size of the ngram when using the `ngram` tokenizer (e.g, 2 = bigram, 3 = trigram, etc.).",
         parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["ngram_size"],
     )
+
+
+@DeveloperAPI
+@register_preprocessor("text_llm_output")
+@ludwig_dataclass
+class LLMTextOutputPreprocessingConfig(TextOutputPreprocessingConfig):
+    max_sequence_length: int = schema_utils.PositiveInteger(
+        default=None,
+        allow_none=True,
+        description="The maximum length (number of tokens) of the sequence. Sequences that are longer than this value "
+        "will be truncated. Useful as a stopgap measure if `sequence_length` is set to `None`. If `None`, max sequence "
+        "length will be inferred from the training dataset.",
+        parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["max_sequence_length"],
+    )
diff --git a/ludwig/schema/features/text_feature.py b/ludwig/schema/features/text_feature.py
index 604d2fd9828..8c9984a6016 100644
--- a/ludwig/schema/features/text_feature.py
+++ b/ludwig/schema/features/text_feature.py
@@ -81,7 +81,7 @@ class GBMTextInputFeatureConfig(TextInputFeatureConfig):
 @llm_input_config_registry.register(TEXT)
 @ludwig_dataclass
 class LLMTextInputFeatureConfig(TextInputFeatureConfig):
-    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm")
+    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm_input")
 
     encoder: BaseEncoderConfig = EncoderDataclassField(
         MODEL_LLM,
@@ -171,6 +171,8 @@ class LLMTextOutputFeatureConfig(TextOutputFeatureConfig):
         parameter_metadata=INTERNAL_ONLY,
     )
 
+    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm_output")
+
     decoder: BaseDecoderConfig = DecoderDataclassField(
         MODEL_LLM,
         feature_type=TEXT,
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index 9644a9e95a2..f32a7716911 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -481,6 +481,23 @@ def test_llama_rope_scaling():
     assert model.model.config.rope_scaling["factor"] == 2.0
 
 
+def test_default_max_sequence_length():
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: TEST_MODEL_NAME,
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {
+            TYPE: "finetune",
+            BATCH_SIZE: 8,
+            EPOCHS: 2,
+        },
+    }
+    config_obj = ModelConfig.from_dict(config)
+    assert config_obj.input_features[0].preprocessing.max_sequence_length is None
+    assert config_obj.output_features[0].preprocessing.max_sequence_length is None
+
+
 def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
     # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
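
For context, a minimal sketch (not part of the diff above) of how the new `max_sequence_length` fields could be set explicitly in an LLM fine-tuning config. The base model identifier and trainer values are placeholders, and the `ModelConfig` import path is an assumption based on how the test module uses it:

# Sketch only: shows the new per-feature preprocessing.max_sequence_length knobs.
# The base_model string and trainer settings are placeholders, not values from this PR.
from ludwig.schema.model_config import ModelConfig  # assumed import path

config = {
    "model_type": "llm",
    "base_model": "hf-internal-testing/tiny-random-GPTJForCausalLM",  # placeholder tiny model
    "input_features": [
        {
            "name": "input",
            "type": "text",
            "preprocessing": {"max_sequence_length": 512},  # cap input tokens
        }
    ],
    "output_features": [
        {
            "name": "output",
            "type": "text",
            "preprocessing": {"max_sequence_length": 64},  # cap target tokens
        }
    ],
    "trainer": {"type": "finetune", "batch_size": 8, "epochs": 2},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.input_features[0].preprocessing.max_sequence_length == 512
assert config_obj.output_features[0].preprocessing.max_sequence_length == 64

When the fields are omitted, both default to `None` (sequence length inferred from the training dataset), which is what `test_default_max_sequence_length` asserts.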