Dynamically set max_new_tokens based on output feature length, GMSL and model window size #3713

Merged (14 commits) on Oct 13, 2023
9 changes: 2 additions & 7 deletions ludwig/models/llm.py
@@ -21,6 +21,7 @@
from ludwig.utils.llm_utils import (
add_left_padding,
generate_merged_ids,
get_context_len,
pad_target_tensor_for_fine_tuning,
realign_target_and_prediction_tensors_for_inference,
remove_left_padding,
@@ -126,13 +127,7 @@ def __init__(
self.curr_device = next(self.model.parameters()).device
logger.info("Done.")

# Determines the maximum length of the context (input + output tokens)
if hasattr(self.model_config, "max_sequence_length"):
self.context_len = self.model_config.max_sequence_length
elif hasattr(self.model_config, "max_position_embeddings"):
self.context_len = self.model_config.max_position_embeddings
else:
self.context_len = 2048
self.context_len = get_context_len(self.model_config)

# TODO(Arnav): This needs to be more flexible to account for RoPE Scaling
# When merging input IDs and target IDs for LLM fine-tuning, we want to make sure that the merged tensor is
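For context, the branching that used to live in llm.py now sits in a reusable helper. Below is a minimal sketch of exercising that helper on its own, assuming this branch of Ludwig and transformers are installed; the tiny GPT-J checkpoint is the same one the new unit tests at the end of this diff use.

from transformers import AutoConfig

from ludwig.utils.llm_utils import get_context_len

# Resolve the context window (input + output tokens) from a pretrained config.
# This checkpoint exposes `n_positions`, so that value is returned; configs with
# no recognized length attribute fall back to FALLBACK_CONTEXT_LEN (2048).
model_config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
print(get_context_len(model_config))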
6 changes: 3 additions & 3 deletions ludwig/schema/model_types/base.py
@@ -30,7 +30,7 @@
sanitize_and_filter_combiner_entities_,
set_derived_feature_columns_,
set_hyperopt_defaults_,
set_llm_tokenizers,
set_llm_parameters,
set_preprocessing_parameters,
set_tagger_decoder_parameters,
set_validation_parameters,
@@ -69,8 +69,8 @@ def __post_init__(self):
set_tagger_decoder_parameters(self)
sanitize_and_filter_combiner_entities_(self)

# Set preprocessing parameters for text features for LLM model type
set_llm_tokenizers(self)
# Reconcile LLM parameters
set_llm_parameters(self)

# Reconcile conflicting preprocessing parameters
set_preprocessing_parameters(self)
72 changes: 68 additions & 4 deletions ludwig/schema/model_types/utils.py
@@ -17,6 +17,7 @@
INPUT_FEATURES,
LOSS,
MODEL_ECD,
MODEL_LLM,
OUTPUT_FEATURES,
PARAMETERS,
PREPROCESSING,
@@ -299,16 +300,24 @@ def set_tagger_decoder_parameters(config: "ModelConfig") -> None:
output_feature.reduce_input = None


def set_llm_tokenizers(config: "ModelConfig") -> None:
def set_llm_parameters(config: "ModelConfig") -> None:
if config.model_type != MODEL_LLM:
return

# Set preprocessing parameters for text features for LLM model type
_set_llm_tokenizers(config)

# Set max_new_tokens in generation config to the max sequence length of the output features
_set_generation_max_new_tokens(config)


def _set_llm_tokenizers(config: "ModelConfig") -> None:
"""Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
correct shared vocabulary from the tokenizer.

This also ensures padding is correctly set to left padding, to prevent the LLM from trying to continue the sequence
based on right padding tokens, which may be present depending on the sequence length.
"""
if config.model_type != "llm":
return

pretrained_model_name_or_path = config.base_model
if not isinstance(pretrained_model_name_or_path, str) or pretrained_model_name_or_path is None:
raise ValueError("Must set `base_model` when using the LLM model.")
@@ -337,6 +346,61 @@ def set_llm_tokenizers(config: "ModelConfig") -> None:
output_feature.decoder.fallback_label = output_feature.preprocessing.fallback_label


def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
"""Sets the max_new_tokens parameter in the generation config to the max sequence length of the output
features.

This ensures that the generation config is set to the correct value for the LLM model type.
"""
from transformers import AutoConfig

from ludwig.schema.llms.generation import LLMGenerationConfig
from ludwig.utils.llm_utils import get_context_len

_DEFAULT_MAX_SEQUENCE_LENGTH = LLMGenerationConfig().max_new_tokens
if config.generation.max_new_tokens != _DEFAULT_MAX_SEQUENCE_LENGTH:
# Max new tokens is explicitly set by user, so don't override
return

if config.output_features[0].type != TEXT:
# This is trickier to set for other output features, so don't override for now.
# TODO: Add better support for category output features
return

max_possible_sequence_length = _DEFAULT_MAX_SEQUENCE_LENGTH
if config.output_features[0].preprocessing.max_sequence_length is not None:
# Note: We don't need to check for max between feature.preprocessing.max_sequence_length and
# defaults.text.preprocessing.max_sequence_length because the latter is only applied to input features.
max_possible_sequence_length = max(
_DEFAULT_MAX_SEQUENCE_LENGTH, config.output_features[0].preprocessing.max_sequence_length
)
elif config.preprocessing.global_max_sequence_length is not None:
# This is not perfect since it includes tokens from both the input and output features, but it at least
# ensures that the maximum possible sequence length is used. It is very likely that the model learns
# to generate sequences shorter than this value.
max_possible_sequence_length = max(
max_possible_sequence_length, config.preprocessing.global_max_sequence_length
)
elif max_possible_sequence_length == _DEFAULT_MAX_SEQUENCE_LENGTH:
# It's possible that both max_sequence_length and global_max_sequence_length are not set, in which case
# we should fall back to the window size of the pretrained model. By this point, because of schema validation
# checks, we know that the base_model exists so we can safely grab the base model's config.
# TODO (Arnav): Figure out how to factor in rope scaling factor into this calculation.
model_config = AutoConfig.from_pretrained(config.base_model)
max_possible_sequence_length = get_context_len(model_config)
# Artificially leave a buffer of half the total model window size to trade off
# runtime while likely covering a majority of the max sequence length.
max_possible_sequence_length = max_possible_sequence_length // 2

logger.info(
f"Setting generation max_new_tokens to {max_possible_sequence_length} to correspond with the max "
"sequence length assigned to the output feature or the global max sequence length. This will ensure that "
"the correct number of tokens are generated at inference time. To override this behavior, set "
"`generation.max_new_tokens` to a different value in your Ludwig config."
)
config.generation.max_new_tokens = max_possible_sequence_length


@DeveloperAPI
def contains_grid_search_parameters(hyperopt_config: HyperoptConfigDict) -> bool:
"""Returns True if any hyperopt parameter in the config is using the grid_search space."""
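To illustrate the override end to end, here is a minimal sketch mirroring the unit tests added at the bottom of this diff. It assumes the uppercase keys (MODEL_TYPE, BASE_MODEL, etc.) are importable from ludwig.constants, as in the existing test suite, and uses the same tiny Llama test checkpoint.

from ludwig.constants import BASE_MODEL, INPUT_FEATURES, MODEL_LLM, MODEL_TYPE, NAME, OUTPUT_FEATURES, TYPE
from ludwig.schema.model_types.base import ModelConfig

# The output feature's max_sequence_length (100) is larger than the default
# generation.max_new_tokens (32), so the override raises max_new_tokens to 100.
config_obj = ModelConfig.from_dict(
    {
        MODEL_TYPE: MODEL_LLM,
        BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
        INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
        OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
    }
)
print(config_obj.generation.max_new_tokens)  # 100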
43 changes: 43 additions & 0 deletions ludwig/utils/llm_utils.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from bitsandbytes.nn.modules import Embedding
from transformers import (
AutoConfig,
AutoModelForCausalLM,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
@@ -22,6 +23,9 @@
logger = logging.getLogger(__name__)


FALLBACK_CONTEXT_LEN = 2048


def set_pad_token(tokenizer: PreTrainedTokenizer):
"""Sets the pad token for the tokenizer if it is not already set.

@@ -57,6 +61,45 @@ def set_pad_token(tokenizer: PreTrainedTokenizer):
tokenizer.pad_token_id = tokenizer.eos_token_id


def get_context_len(model_config: AutoConfig):
"""Determines the maximum length of the context (input + output tokens) based on the provided model
configuration.

Args:
model_config (AutoConfig): The model configuration object containing information about the model's properties.

Returns:
int: The maximum context length, which can be derived from the model configuration. If no relevant attribute
is found, the default value of 2048 is returned.

This function examines the provided model configuration object to identify the attribute that specifies the maximum
context length. It checks for attributes in the following order of preference:
1. 'max_sequence_length': If this attribute is present in the model configuration, its value is returned.
2. 'max_position_embeddings': If 'max_sequence_length' is not found but 'max_position_embeddings' is present, its
value is returned.
3. 'n_positions': If neither 'max_sequence_length' nor 'max_position_embeddings' is found and 'n_positions' is
present, its value is returned.
4. Default: If none of the relevant attributes are present, the function returns a default value of 2048.

Note:
- The maximum context length is important for defining the size of input and output sequences in a model.

Example Usage:
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> context_len = get_context_len(config)
>>> print(context_len)
512
"""
if hasattr(model_config, "max_sequence_length"):
return model_config.max_sequence_length
elif hasattr(model_config, "max_position_embeddings"):
return model_config.max_position_embeddings
elif hasattr(model_config, "n_positions"):
return model_config.n_positions
else:
return FALLBACK_CONTEXT_LEN


def has_padding_token(input_tensor: torch.Tensor, tokenizer: PreTrainedTokenizer):
"""Checks if the input tensor contains any padding tokens.

59 changes: 59 additions & 0 deletions tests/ludwig/schema/test_model_config.py
@@ -950,3 +950,62 @@ def test_llm_quantization_backend_compatibility():
ModelConfig.from_dict(config)

ray.shutdown()


class TestMaxNewTokensOverride:
def test_max_new_tokens_override_no_changes_to_max_new_tokens(self):
"""Tests that the default value for max_new_tokens is respected when explicitly set in the config."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_new_tokens is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
"generation": {"max_new_tokens": 64},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 64

def test_max_new_tokens_override_large_max_sequence_length(self):
"""Tests that the default value for max_new_tokens is overridden when max_sequence_length is set to a large
value than the default max_new_tokens."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_new_tokens is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text", "preprocessing": {"max_sequence_length": 100}}],
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 100

def test_max_new_tokens_override_large_global_max_sequence_length(self):
"""Tests that the default value for max_new_tokens is overridden when global_max_sequence_length is set to
a larger value than the default max_new_tokens."""
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_new_tokens is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
PREPROCESSING: {"global_max_sequence_length": 100},
}

config_obj = ModelConfig.from_dict(config)
assert config_obj.generation.max_new_tokens == 100

def test_max_new_tokens_override_fallback_to_model_window_size(self):
config = {
MODEL_TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
INPUT_FEATURES: [{NAME: "text_input", TYPE: "text"}],
# Default value for generation.max_new_tokens is 32
OUTPUT_FEATURES: [{NAME: "text_output", TYPE: "text"}],
}

config_obj = ModelConfig.from_dict(config)
# Base model context length is 2048 tokens by default.
# Since we fall back to setting max_new_tokens to the model context length / 2, we expect it to be 1024 tokens.
assert config_obj.generation.max_new_tokens == 1024
27 changes: 26 additions & 1 deletion tests/ludwig/utils/test_llm_utils.py
@@ -1,13 +1,15 @@
import pytest
import torch
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer

from ludwig.constants import LOGITS, PREDICTIONS, PROBABILITIES
from ludwig.utils.llm_utils import (
add_left_padding,
create_attention_mask,
FALLBACK_CONTEXT_LEN,
find_last_matching_index,
generate_merged_ids,
get_context_len,
has_padding_token,
pad_target_tensor_for_fine_tuning,
realign_target_and_prediction_tensors_for_inference,
@@ -57,6 +59,29 @@ def test_set_pad_token_already_exists():
assert tokenizer.pad_token_id == 1


class TestSetContextLen:
Review thread on TestSetContextLen:

justinxzhao (Collaborator): Interesting mechanic for grouping tests - curious if you saw this pattern recommended somewhere?

arnavgarg1 (Contributor, author): @justinxzhao Definitely didn't see it recommended anywhere, but I wanted to find a logical way to group these tests together since they're about the same "topic" but testing different aspects of it, so I decided to write a class. Is that fine, or would you like me to just write 4 individual tests?

arnavgarg1 (Oct 12, 2023): The idea in general is that since they're all testing the same function but in different scenarios, it makes sense to either put them all in the same dedicated module for clarity or in some sort of container like a class. Typically you could just use parameterization, but this one would require a lot of conditionals in the test, so I decided to skip it. Alternatively, I could also combine them all into one test. All options are ok - no strong preference.

justinxzhao (Collaborator): I can see this being a useful way to organize tests, particularly for very large test files. It's a bit more to maintain, but seems sufficiently lightweight.

arnavgarg1: Got it, let me split them up!

def test_max_sequence_length(self):
# Test when 'max_sequence_length' is present in the model configuration
config = AutoConfig.from_pretrained("huggyllama/llama-7b")
assert get_context_len(config) == config.max_sequence_length

def test_max_position_embeddings(self):
# Test when 'max_position_embeddings' is present in the model configuration
config = AutoConfig.from_pretrained("huggyllama/llama-7b")
del config.max_sequence_length
assert get_context_len(config) == config.max_position_embeddings

def test_n_positions(self):
# Test when 'n_positions' is present in the model configuration
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
assert get_context_len(config) == config.n_positions

def test_default_value(self):
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
del config.n_positions
assert get_context_len(config) == FALLBACK_CONTEXT_LEN


def test_has_padding_token_with_padding_tokens(tokenizer):
input_sentence = "This is an example sentence."
input_ids = tokenizer([input_sentence])