[LLM] Various fixes for LLM Fine-Tuning issues that caused loss disparity between train and val sets #3437
Conversation
Nice work!
It's great that you've gotten to the bottom of the loss/metric differences with the PEFT notebook you were referencing. I would like to understand the tensor manipulations more deeply.
Most of my comments are questions.
It looks like some unit tests are failing due to the new NextTokenPerplexity metric. Could you check on that?
self.model.update_metrics_finetune(targets, preds)

# accumulate predictions from batch for each output feature
if collect_predictions:
When is collect_predictions=False?
It's always False by default. I just copied this over from ECD, to be honest.
Curious @w4nderlust if you have more context on this collect_predictions mechanism?
Trying to remember: I believe it has to do with evaluation/training time. Collecting the actual predictions, in particular the probabilities, can be expensive, especially for sequence/text output features, where there are big tensors hanging around.
If I remember correctly, when running the evaluate and experiment functions, it is actually set to False for those reasons. (Also, the predictions were previously concatenated before being written to disk, which in some cases meant running out of memory for no good reason after training.)
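To make the memory concern concrete, here is a minimal sketch of how a flag like this typically gates prediction accumulation. The loop structure, the dict-shaped predictions, and the accumulated list are illustrative assumptions, not Ludwig's actual implementation; only update_metrics_finetune comes from the diff above.

import torch

def evaluation_step(model, targets, inputs, accumulated, collect_predictions=False):
    with torch.no_grad():
        # Assumes the model returns a dict of output-feature tensors.
        preds = model(inputs)
        # Metric aggregation is incremental and cheap.
        model.update_metrics_finetune(targets, preds)
        if collect_predictions:
            # Expensive path: for sequence/text output features, preds can
            # contain [batch, sequence_length, vocab_size] probability
            # tensors. Keeping one per batch alive until the end of
            # evaluation (or concatenating them all before writing to
            # disk) can exhaust memory.
            accumulated.append({name: tensor.cpu() for name, tensor in preds.items()})

With collect_predictions left at False, only the running metric state survives each batch, which matches the behavior described for evaluate and experiment.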
@justinxzhao ready for re-review!
Just a few more questions!
@@ -235,5 +234,5 @@ class LLMTextDefaultsConfig(TextInputFeatureConfigMixin, TextOutputFeatureConfig

loss: BaseLossConfig = LossDataclassField(
Hmm -- I wonder if this would be as simple as using a StringOptions field (CC: @ksbrar):

schema_utils.StringOptions(
    [NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY],
    default=NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY,
    allow_none=False,
)
Thanks for the changes! LGTM.
This PR fixes the following issues with training and preprocessing that were causing training loss to be significantly lower than val/test loss during LLM fine-tuning. It also removes dead code and adds an extended suite of unit tests for all of the LLM utility functions.
Finally, this PR also introduces a new preprocessing parameter called global_max_sequence_length that controls the number of tokens fed into the model during LLM fine-tuning. This parameter represents the total number of tokens that the model's forward pass will receive once input_ids and target_ids are merged together. If it is not set, truncation is skipped and the merged tensors are passed into the model's forward pass unmodified.
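A minimal sketch of the merge-and-truncate behavior described above; the function name and the choice to truncate from the right are illustrative assumptions, not the exact implementation:

from typing import Optional

import torch

def merge_and_truncate(input_ids: torch.Tensor,
                       target_ids: torch.Tensor,
                       global_max_sequence_length: Optional[int] = None) -> torch.Tensor:
    # Prompt tokens and target tokens are concatenated into a single
    # sequence for the model's forward pass.
    merged = torch.cat([input_ids, target_ids], dim=-1)
    if global_max_sequence_length is not None:
        # Cap the total token count seen by the forward pass. Truncating
        # from the right is an assumption made for this sketch.
        merged = merged[..., :global_max_sequence_length]
    # When the parameter is unset, truncation is skipped entirely.
    return merged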
Here's the complete changelog of all issues that were fixed:
Preprocessing
Training