Add batch size tuning for LLMs #3871
Conversation
ludwig/trainers/trainer_llm.py (Outdated)
```python
input_msl = input_feature.input_shape[0]
output_msl = output_feature.output_shape[0]
```
Is this the most reliable way to get the MSL? Or should we be looking up properties in the feature's preprocessing configuration?
Based on what I see here, it looks like at least the input_shape should provide a tighter upper bound than looking at the preprocessing configuration, but maybe I'm misinterpreting that line. However, I do think output_shape needs to be changed.
Ok! Can you leave a quick comment about this in the code?
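A hedged sketch of what that requested inline comment might look like (the wording is a suggestion based on this thread, not the merged code):

```python
# input_shape[0] reflects the longest input sequence observed in the training
# data, so it gives a tighter upper bound than the preprocessing config's
# max_sequence_length, which is only a cap.
input_msl = input_feature.input_shape[0]
# output_shape[0] alone can be too loose; a later commit in this PR tightens it
# with the preprocessing max_sequence_length when that is set (see below).
output_msl = output_feature.output_shape[0]
```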
ludwig/trainers/trainer_llm.py (Outdated)
```python
snapshot_weights: bool = True,
on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None,
tune_for_training: bool = True,
max_sequence_length: Optional[int] = None,
```
nit: Rename to `global_max_sequence_length`?
ludwig/trainers/trainer_llm.py (Outdated)
```python
if not self.vocab_size:
    self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)
```
nit: Since the `trainer` object is available outside of the `_TrainerBatchSizeEvaluator` class, we don't actually need to add this if condition. Instead, we can just do

```python
# This is useful to create the synthetic input and target data, which will be a
# random sequence of integers between 0 and vocab_size.
self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)
```

in the constructor itself.
ludwig/trainers/trainer_llm.py (Outdated)
```python
if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
    self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
```
I think this can also move into the constructor, something like this:

```python
# Get the length of the longest output sequence from the training data.
self.output_msl = self.output_feature.output_shape[0]
if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
    # max_sequence_length here is the smaller value between the model's global
    # max sequence length and the model's context length.
    self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
```
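Putting the two constructor suggestions together, `__init__` might end up roughly like the sketch below. This is illustrative, not the merged code; it assumes `self.output_feature` is resolved earlier in the constructor (not shown):

```python
def __init__(self, trainer):
    self.trainer = trainer
    # self.output_feature is assumed to be set above; only the two suggested
    # additions from this review are sketched here.

    # Useful for creating the synthetic input and target data, which will be
    # random sequences of integers between 0 and vocab_size.
    self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)

    # Start from the longest output sequence in the training data, then
    # tighten it to the preprocessing max_sequence_length when that is set
    # (the smaller of the model's global max sequence length and its
    # context length).
    self.output_msl = self.output_feature.output_shape[0]
    if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
        self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
```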
ludwig/utils/batch_size_tuner.py (Outdated)
```diff
@@ -51,7 +52,9 @@ def _is_valid_batch_size(batch_size):
         gc.collect()

         try:
-            samples_per_sec = self.evaluate(batch_size, total_steps=5)
+            samples_per_sec = self.evaluate(
+                batch_size, total_steps=5, global_max_sequence_length=global_max_sequence_length
+            )
```
nit: Make this a constant in the file.
I'm not quite sure I understand. Why should `samples_per_sec` be a constant?
@Infernaught I think @justinxzhao was referring to the 5 in `total_steps=5`.
Ohh I see
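For the record, the nit sketched out: hoist the literal 5 into a module-level constant in ludwig/utils/batch_size_tuner.py. The constant name and the wrapper function below are my own illustration, not the merged code:

```python
TOTAL_STEPS = 5  # evaluation steps run per candidate batch size


def evaluate_candidate(evaluator, batch_size, global_max_sequence_length):
    # Mirrors the call site in the diff above, with the magic number replaced.
    return evaluator.evaluate(
        batch_size, total_steps=TOTAL_STEPS, global_max_sequence_length=global_max_sequence_length
    )
```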
```python
def test_llm_batch_size_tuning():
    dataset = pd.DataFrame({"instruction": ["a"] * 100, "output": ["a"] * 100})
    config = yaml.safe_load(
        """
        model_type: llm
        input_features:
          - name: instruction
            type: text
        output_features:
          - name: output
            type: text
        prompt:
          template: >-
            {instruction}
```
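The excerpt cuts off before the trainer section. For context, what would actually exercise the tuner in a config like this is a trainer block along the following lines; this is a sketch based on Ludwig's `batch_size: auto` behavior, not the test's literal YAML:

```python
# Assumed continuation of a config like the one above: batch_size "auto" is
# what triggers batch size tuning for the finetune trainer.
trainer_section = yaml.safe_load(
    """
    trainer:
      type: finetune
      batch_size: auto
    """
)
```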
@Infernaught Seeing the same test twice? On line 1258 and line 1348.
I think pre-commit is also complaining about this
Interesting. Probably an issue with merging?
Likely so
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arnav Garg <arnav@predibase.com>
This PR extends Ludwig's batch size tuning functionality to LLMs.

For each candidate batch size, we generate synthetic data as follows. We consider three values:

(1) The sum of the max_sequence_lengths of the input feature and the output feature.
(2) The global_max_sequence_length.
(3) The model's context length.

If (1) is the smallest, we generate synthetic inputs and outputs at their respective max_sequence_lengths.
If (2) is the smallest, we generate synthetic inputs and outputs, each of length global_max_sequence_length/2 + 1.
If (3) is the smallest, we generate synthetic inputs and outputs, each of length context_len/2 + 1.
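A minimal sketch of that selection logic plus the synthetic batch construction, assuming integer division for the /2 and using the names `input_msl`, `output_msl`, `context_len`, and `vocab_size` from the threads above (the helper names are hypothetical):

```python
import torch


def pick_synthetic_lengths(input_msl, output_msl, global_max_sequence_length, context_len):
    """Choose (input_len, output_len) for the synthetic tuning batch."""
    total_msl = input_msl + output_msl  # value (1)
    smallest = min(total_msl, global_max_sequence_length, context_len)
    if smallest == total_msl:
        # (1) wins: use each feature's own max sequence length.
        return input_msl, output_msl
    # (2) or (3) wins: input and output each get half the winning budget,
    # plus one, per the description above.
    return smallest // 2 + 1, smallest // 2 + 1


def make_synthetic_batch(batch_size, vocab_size, input_len, output_len):
    """Random token ids in [0, vocab_size), standing in for real data."""
    inputs = torch.randint(0, vocab_size, (batch_size, input_len))
    targets = torch.randint(0, vocab_size, (batch_size, output_len))
    return inputs, targets
```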