Set default max_sequence_length to None for LLM text input/output features (#3547)
arnavgarg1 committed Aug 28, 2023
1 parent feec8a6 commit 5ef0878
Showing 4 changed files with 45 additions and 4 deletions.
1 change: 1 addition & 0 deletions ludwig/models/llm.py
@@ -140,6 +140,7 @@ def __init__(
        self.max_new_tokens = self.config_obj.generation.max_new_tokens
        self.max_input_length = self.context_len - self.max_new_tokens - 8

+        # TODO(Arnav): This needs to be more flexible to account for RoPE Scaling
        # When merging input IDs and target IDs for LLM fine-tuning, we want to make sure that the merged tensor is
        # not longer than the global maximum sequence length. This is provided in the preprocessing config. We never
        # want to exceed the maximum possible context length so we also check for that.
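The comment above refers to the truncation applied when input IDs and target IDs are merged for fine-tuning. As a rough illustration, that policy can be sketched as below; this is a hypothetical helper, not Ludwig's actual code, and only the arithmetic (reserving room for max_new_tokens plus the hard-coded margin of 8, and never exceeding the global maximum sequence length when one is set) mirrors the diff.

# Hypothetical sketch of the truncation described in llm.py; not the actual Ludwig implementation.
from typing import Optional

import torch


def truncate_merged_ids(
    merged_ids: torch.Tensor,
    context_len: int,
    max_new_tokens: int,
    global_max_sequence_length: Optional[int] = None,
) -> torch.Tensor:
    """Trim merged input+target IDs so generation still fits in the model context."""
    # Reserve space for the tokens to be generated, plus a small safety margin.
    max_input_length = context_len - max_new_tokens - 8

    # Respect the preprocessing config's global max sequence length when set,
    # but never exceed the usable context length.
    limit = max_input_length
    if global_max_sequence_length is not None:
        limit = min(limit, global_max_sequence_length)

    return merged_ids[..., :limit]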
27 changes: 24 additions & 3 deletions ludwig/schema/features/preprocessing/text.py
@@ -151,12 +151,19 @@ class TextPreprocessingConfig(BaseTextPreprocessingConfig):


@DeveloperAPI
-@register_preprocessor("text_llm")
+@register_preprocessor("text_llm_input")
@ludwig_dataclass
-class LLMTextPreprocessingConfig(BaseTextPreprocessingConfig):
+class LLMTextInputPreprocessingConfig(BaseTextPreprocessingConfig):
    """LLMs require the prompt to be provided at the top-level, not preprocessing."""

-    pass
+    max_sequence_length: int = schema_utils.PositiveInteger(
+        default=None,
+        allow_none=True,
+        description="The maximum length (number of tokens) of the sequence. Sequences that are longer than this value "
+        "will be truncated. Useful as a stopgap measure if `sequence_length` is set to `None`. If `None`, max sequence "
+        "length will be inferred from the training dataset.",
+        parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["max_sequence_length"],
+    )


@DeveloperAPI
@@ -217,3 +224,17 @@ class TextOutputPreprocessingConfig(BaseTextPreprocessingConfig):
description="The size of the ngram when using the `ngram` tokenizer (e.g, 2 = bigram, 3 = trigram, etc.).",
parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["ngram_size"],
)


@DeveloperAPI
@register_preprocessor("text_llm_output")
@ludwig_dataclass
class LLMTextOutputPreprocessingConfig(TextOutputPreprocessingConfig):
max_sequence_length: int = schema_utils.PositiveInteger(
default=None,
allow_none=True,
description="The maximum length (number of tokens) of the sequence. Sequences that are longer than this value "
"will be truncated. Useful as a stopgap measure if `sequence_length` is set to `None`. If `None`, max sequence "
"length will be inferred from the training dataset.",
parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["max_sequence_length"],
)
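With these defaults, a Ludwig LLM config can omit max_sequence_length under a text feature's preprocessing section (it then defaults to None and is inferred from the training dataset), or set it explicitly to cap token counts. The snippet below is purely illustrative; the feature names and base model string are placeholders, not part of this change.

# Illustrative config only: "my-base-model" and the feature names are placeholders.
config = {
    "model_type": "llm",
    "base_model": "my-base-model",
    "input_features": [
        {
            "name": "input",
            "type": "text",
            # Explicit cap: sequences longer than 512 tokens are truncated.
            "preprocessing": {"max_sequence_length": 512},
        }
    ],
    "output_features": [
        {
            "name": "output",
            "type": "text",
            # No max_sequence_length given: it now defaults to None and is
            # inferred from the training dataset.
        }
    ],
}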
4 changes: 3 additions & 1 deletion ludwig/schema/features/text_feature.py
@@ -81,7 +81,7 @@ class GBMTextInputFeatureConfig(TextInputFeatureConfig):
@llm_input_config_registry.register(TEXT)
@ludwig_dataclass
class LLMTextInputFeatureConfig(TextInputFeatureConfig):
-    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm")
+    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm_input")

    encoder: BaseEncoderConfig = EncoderDataclassField(
        MODEL_LLM,
@@ -171,6 +171,8 @@ class LLMTextOutputFeatureConfig(TextOutputFeatureConfig):
        parameter_metadata=INTERNAL_ONLY,
    )

+    preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="text_llm_output")
+
    decoder: BaseDecoderConfig = DecoderDataclassField(
        MODEL_LLM,
        feature_type=TEXT,
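The feature_type strings passed to PreprocessingDataclassField are resolved through the register_preprocessor registry seen in the preprocessing diff, which is why the keys change in lockstep (text_llm becomes text_llm_input, and text_llm_output is added). A simplified sketch of that decorator-registry pattern follows; it is not Ludwig's actual implementation, and the names _registry and preprocessing_config_for are hypothetical.

# Simplified sketch of a decorator-based registry; not Ludwig's actual code.
from typing import Callable, Dict

_registry: Dict[str, type] = {}


def register_preprocessor(name: str) -> Callable[[type], type]:
    """Map a feature-type key (e.g. "text_llm_output") to its preprocessing config class."""

    def wrap(cls: type) -> type:
        _registry[name] = cls
        return cls

    return wrap


def preprocessing_config_for(feature_type: str) -> type:
    # A field like PreprocessingDataclassField(feature_type="text_llm_output")
    # looks up its config class along these lines.
    return _registry[feature_type]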
17 changes: 17 additions & 0 deletions tests/integration_tests/test_llm.py
@@ -481,6 +481,23 @@ def test_llama_rope_scaling():
    assert model.model.config.rope_scaling["factor"] == 2.0


+def test_default_max_sequence_length():
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: TEST_MODEL_NAME,
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {
+            TYPE: "finetune",
+            BATCH_SIZE: 8,
+            EPOCHS: 2,
+        },
+    }
+    config_obj = ModelConfig.from_dict(config)
+    assert config_obj.input_features[0].preprocessing.max_sequence_length is None
+    assert config_obj.output_features[0].preprocessing.max_sequence_length is None
+
+
def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
    # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
