
[LLM] Various fixes for LLM Fine-Tuning issues that caused loss disparity between train and val sets #3437

Merged
29 commits merged into master from llm_debug on Jun 23, 2023

Conversation


@arnavgarg1 arnavgarg1 commented Jun 13, 2023

This PR fixes the following issues with training and preprocessing that were causing training loss to be significantly lower than val/test loss during LLM fine-tuning. It also removes dead code, and adds an extended suite of unit tests for all of the LLM utility functions.

Finally, this PR also introduces a new preprocessing parameter called global_max_sequence_length that controls the number of tokens fed into the model during LLM fine-tuning. This parameter represents the total number of tokens that the model's forward pass will receive once input_ids and target_ids are merged together. If not set, truncation is skipped and the merged tensors are passed to the model's forward pass unmodified (see the sketch after the config examples below).

preprocessing:
    global_max_sequence_length: null

or

preprocessing:
    global_max_sequence_length: 128
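
A minimal sketch (the helper name and signature are assumptions, not Ludwig's actual implementation) of how the merged prompt and target tensors can be capped at global_max_sequence_length before the forward pass:

    from typing import Optional

    import torch

    def merge_and_truncate(
        input_ids: torch.Tensor,
        target_ids: torch.Tensor,
        global_max_sequence_length: Optional[int],
    ) -> torch.Tensor:
        """Concatenate prompt and target token IDs; optionally cap the result."""
        merged = torch.cat([input_ids, target_ids], dim=-1)
        if global_max_sequence_length is not None:
            # Keep at most global_max_sequence_length tokens for the forward pass.
            merged = merged[..., :global_max_sequence_length]
        return merged

If global_max_sequence_length is null, no truncation is applied and the merged tensor is passed through unchanged.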

Here's the complete changelog of all issues that were fixed:

Preprocessing

  • During preprocessing, we were inserting quotes that change the token IDs completely. For example, the base PEFT notebook creates 'Tweet text : @HMRCcustomers No this is my first job Label : no complaint', whereas we were creating 'Tweet text : "@HMRCcustomers No this is my first job" Label : ' (ignore the 'no complaint' part for now, since we append it at a later stage during the forward pass). This is now fixed ✅
  • While removing left padding from input_ids or target_ids, we would on some occasions return a tensor with the first token missing. This happened when the incoming tensor had no pad tokens because it was already at max_sequence_length. The logic now only removes pad tokens when they are actually present ✅
  • attention_mask issues:
    • It was being set to float32 instead of int64, which is now fixed ✅
    • The final pad token that we add after merging the labels is not attended to in the attention mask (it was being set to 0). It is now correctly set to 1. ✅
  • The model_inputs tensor that we create by concatenating input_ids and target_ids could have length > max_sequence_length. This is wrong, since it should always be capped at max_sequence_length; it is now truncated correctly (see the sketch after this list) ✅
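
A minimal sketch (hypothetical helper names and pad token ID, not the code in this PR) of the corrected left-padding removal and attention_mask behavior described above:

    import torch

    PAD_TOKEN_ID = 0  # assumption: the tokenizer's pad token ID

    def remove_left_padding(ids: torch.Tensor) -> torch.Tensor:
        """Strip leading pad tokens, but only if any are actually present."""
        non_pad_positions = (ids != PAD_TOKEN_ID).nonzero(as_tuple=True)[0]
        if non_pad_positions.numel() == 0:
            return ids  # tensor is all padding; leave it untouched
        return ids[non_pad_positions[0].item():]

    def build_attention_mask(merged_ids: torch.Tensor) -> torch.Tensor:
        """Build an int64 attention mask in which the final pad token appended
        after the labels are merged in is attended to (set to 1)."""
        mask = (merged_ids != PAD_TOKEN_ID).to(torch.int64)
        mask[-1] = 1  # attend to the trailing pad token added after the labels
        return mask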

Training

  • While computing loss, the target tensors need to have a pad token appended to them. This is now fixed ✅
  • When realigning target and prediction tensors to have the same shape:
    • We were not considering max_sequence_length, which needs to be taken into account here as well. This case is tricky: we don't want the target tensor to contain any output values if the input tensor did not actually include the labels (because they were truncated away by the max sequence length), since we don't want to compute loss on such inputs. The fix is somewhat awkward: it caches the model_inputs from the forward pass as class variables and re-uses them when realigning the tensors. This is now fixed ✅
    • We were casting predictions to float instead of leaving them as int32/int64. This is now fixed ✅
    • We were casting targets to float32 instead of leaving them as int64. This is now fixed ✅
  • Fixes an issue where perplexity was being computed incorrectly. The root cause: perplexity should be calculated from the shifted (next-token) cross entropy loss rather than the regular cross entropy loss, so we now compute the next-token shifted cross entropy loss and manually call torch.exp() on it to get the perplexity score (see the sketch after this list). Fixed now ✅
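
A minimal sketch (not the actual NextTokenPerplexity metric added in this PR; the function name is hypothetical) of computing perplexity from the shifted next-token cross entropy loss:

    import torch
    import torch.nn.functional as F

    def next_token_perplexity(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
        """Shift logits and targets so each position predicts the next token,
        then exponentiate the mean cross entropy loss to get perplexity."""
        shift_logits = logits[..., :-1, :].contiguous()  # drop the last time step
        shift_labels = target_ids[..., 1:].contiguous()  # drop the first token
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        return torch.exp(loss)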

@arnavgarg1 arnavgarg1 changed the title [WIP] Fix LLM Fine-tuning issues Fix LLM Fine-tuning issues Jun 14, 2023
@arnavgarg1 arnavgarg1 marked this pull request as ready for review June 14, 2023 19:03

github-actions bot commented Jun 14, 2023

Unit Test Results

  6 files ±0   6 suites ±0   1h 24m 35s ⏱️ (+4m 57s)
 33 tests ±0: 29 ✔️ ±0,  4 💤 ±0, 0 ❌ ±0
 99 runs  ±0: 87 ✔️ ±0, 12 💤 ±0, 0 ❌ ±0

Results for commit 48bd590. ± Comparison against base commit 9112470.

♻️ This comment has been updated with latest results.

@arnavgarg1 arnavgarg1 marked this pull request as draft June 14, 2023 21:00
@arnavgarg1 arnavgarg1 changed the title Fix LLM Fine-tuning issues Fix LLM Fine-Tuning issues that caused loss disparity between train and val sets Jun 15, 2023
@arnavgarg1 arnavgarg1 marked this pull request as ready for review June 15, 2023 03:19
@arnavgarg1 arnavgarg1 requested a review from tgaddair June 15, 2023 03:19
@justinxzhao (Collaborator) left a comment


Nice work!

It's great that you've gotten to the bottom of the loss/metric differences with the PEFT notebook you were referencing. I would like to understand the tensor manipulations more deeply.

Most of my comments are questions.

It looks like some unit tests are failing due to the new NextTokenPerplexity metric. Could you check on that?

ludwig/data/prompt.py (resolved)
ludwig/modules/metric_modules.py (outdated, resolved)
self.model.update_metrics_finetune(targets, preds)

# accumulate predictions from batch for each output feature
if collect_predictions:
Collaborator:

When is collect_predictions=False?

Contributor Author:

It's always False by default. I just copied this over from ECD, to be honest.

Collaborator:

Curious @w4nderlust if you have more context on this collect_predictions mechanism?

Collaborator:

Trying to remember: I believe it has to do with evaluation/training time. Collecting the actual predictions, in particular the probabilities, could be expensive, especially for sequence/text output features, where these big tensors end up hanging around.
If I remember correctly, when running evaluate and the experiment functions, it is actually set to False for those reasons. (Also, the predictions were previously concatenated before being written to disk, and in some cases that meant running out of memory for no good reason after training.)
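
A minimal sketch (hypothetical names, not Ludwig's predictor code) of the pattern described here, where prediction tensors are only accumulated when the caller explicitly asks for them:

    import torch

    def evaluation_loop(batches, model, collect_predictions: bool = False):
        collected = []
        for batch in batches:
            preds = model(batch)  # forward pass; metrics are updated elsewhere
            if collect_predictions:
                # Only retain the (potentially large) prediction tensors, e.g.
                # probabilities for text features, when they are explicitly needed.
                collected.append({k: v.detach().cpu() for k, v in preds.items()})
        return collected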

ludwig/models/predictor.py (outdated, resolved)
ludwig/modules/metric_modules.py (outdated, resolved)
ludwig/models/llm.py (resolved)
ludwig/models/llm.py (resolved)
ludwig/models/llm.py (outdated, resolved)
ludwig/models/llm.py (resolved)
ludwig/models/llm.py (resolved)
@arnavgarg1 arnavgarg1 changed the title Fix LLM Fine-Tuning issues that caused loss disparity between train and val sets [LLM] Various fixes for LLM Fine-Tuning issues that caused loss disparity between train and val sets Jun 20, 2023
@arnavgarg1 (Contributor Author)

@justinxzhao ready for re-review!

@justinxzhao (Collaborator) left a comment

Just a few more questions!

ludwig/models/llm.py (resolved)
@@ -235,5 +234,5 @@ class LLMTextDefaultsConfig(TextInputFeatureConfigMixin, TextOutputFeatureConfig

loss: BaseLossConfig = LossDataclassField(
Collaborator:

Hmm -- I wonder if this would be as simple as using a StringOptions field (CC: @ksbrar)

schema_utils.StringOptions(
    [NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY],
    default=NEXT_TOKEN_SOFTMAX_CROSS_ENTROPY,
    allow_none=False,
)

ludwig/models/llm.py (outdated, resolved)
ludwig/models/llm.py (resolved)

ludwig/utils/llm_utils.py (resolved)
ludwig/utils/llm_utils.py (resolved)
@justinxzhao (Collaborator) left a comment

Thanks for the changes! LGTM.

@arnavgarg1 arnavgarg1 merged commit dc0d685 into master Jun 23, 2023
16 checks passed
@arnavgarg1 arnavgarg1 deleted the llm_debug branch June 23, 2023 18:24