feature: Add robust token counting with padding exclusion #40416
Conversation
…ens_seen variable, kept bool for backward compatibility, added string support, kept the default as is, and created robust test cases
…t and also fixed a code quality issue
Hello, I made the changes and the feature test case passes. I am now working on getting the remaining checks to pass. On my first commit, run_tests succeeded but code_quality failed, which I fixed. On the 3rd, 4th, and 5th commits, run_tests reports an inconsistent number of failures (3, 2, and 1 respectively). Could this be an environment issue, or something else?
cc @SunMarc
Thanks 😉
Thank you all!
…e#40416)
* created robust token counting by using the existing include_num_input_tokens_seen variable, kept bool for backward compatibility, added string support, and kept the default as is; also created robust test cases
* some codebase mismatched between my local and remote; committing to resolve it, and also fixed a code quality issue
* ci: retrigger tests
* another attempt to trigger CI for checks
Fixes #40401
This pull request improves the Trainer's input-token counting by adding an option to exclude padding tokens. It does this by extending the existing include_num_input_tokens_seen argument in TrainingArguments while preserving full backward compatibility.
What was the feature?
The goal was to give users more precise control over how input tokens are counted during training: padding tokens can now be excluded from the total count. This is useful for accurate logging and throughput analysis, especially in tasks with variable sequence lengths; a usage sketch follows.
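As a minimal, hypothetical usage sketch (output_dir and the commented-out Trainer wiring are placeholders; the accepted values are the ones described in this PR), enabling padding-aware counting might look like this:

```python
from transformers import TrainingArguments

# Hypothetical sketch: "non_padding" excludes padding tokens from the running
# num_input_tokens_seen count; True / "all" keeps the previous behavior.
args = TrainingArguments(
    output_dir="out",  # placeholder
    include_num_input_tokens_seen="non_padding",
    logging_steps=10,
)
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()  # logged num_input_tokens_seen would now exclude padding
```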
What was done and why?
To implement this without introducing an extra boolean flag, the following changes were made:
Updated Existing Parameter: The include_num_input_tokens_seen argument in TrainingArguments now accepts the string values "all" and "non_padding" in addition to booleans. This gives clearer control while keeping full backward compatibility (True maps to "all", False to "no").
Improved Counting Logic: The Trainer's token counting logic was made more reliable. When "non_padding" is selected, the Trainer follows a prioritized approach (a combined sketch follows this list):
It first tries to use attention_mask.sum() for the most accurate count of non-padded tokens.
If attention_mask is not available, it counts tokens where input_ids are not equal to the pad_token_id.
If neither method works, it counts all tokens and logs a warning to inform the user.
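Below is a minimal sketch of the two pieces above — the backward-compatible normalization of the argument and the prioritized counting. The helper names (_normalize_token_count_mode, count_input_tokens) are illustrative assumptions, not the actual Trainer internals:

```python
import logging
from typing import Optional, Union

import torch

logger = logging.getLogger(__name__)


def _normalize_token_count_mode(value: Union[bool, str]) -> str:
    """Backward-compatible mapping: True -> "all", False -> "no"."""
    if value is True:
        return "all"
    if value is False:
        return "no"
    return value  # already "all", "non_padding", or "no"


def count_input_tokens(inputs: dict, pad_token_id: Optional[int] = None) -> int:
    """Prioritized non-padding count, mirroring the fallback order above."""
    input_ids = inputs.get("input_ids")
    attention_mask = inputs.get("attention_mask")

    # 1) Preferred: sum the attention mask (1 = real token, 0 = padding).
    if attention_mask is not None:
        return int(attention_mask.sum().item())

    # 2) Fallback: count positions whose id differs from pad_token_id.
    if pad_token_id is not None and input_ids is not None:
        return int((input_ids != pad_token_id).sum().item())

    # 3) Last resort: count everything and warn that padding is included.
    logger.warning(
        "Could not determine padding tokens (no attention_mask or pad_token_id); "
        "counting all tokens."
    )
    return int(input_ids.numel()) if input_ids is not None else 0
```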
Testing:
To ensure the reliability of this feature, a thorough test suite was added to tests/trainer/test_trainer.py (a simplified example appears after this list). The new tests cover:
All token counting modes ("all", "non_padding", True, False).
The new fallback logic, with specific test cases for when attention_mask is present, when it is absent (falling back to pad_token_id), and when neither is available (testing the warning and fallback to counting all tokens).
Full backward compatibility.
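For illustration only, tests along these lines (greatly simplified — the real tests in tests/trainer/test_trainer.py exercise the full Trainer; count_input_tokens is the hypothetical helper from the sketch above) would check the attention-mask path and the pad-token fallback:

```python
import torch


def test_non_padding_count_prefers_attention_mask():
    # Two sequences padded to length 4; 5 real tokens in total.
    batch = {
        "input_ids": torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
    }
    assert count_input_tokens(batch) == 5


def test_non_padding_count_falls_back_to_pad_token_id():
    # No attention_mask: positions equal to pad_token_id (0) are excluded.
    batch = {"input_ids": torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])}
    assert count_input_tokens(batch, pad_token_id=0) == 5
```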
I noticed that torch_dtype has been replaced by dtype (#39782), so I made the corresponding manual changes in the affected files to avoid merge issues.
I also clicked the Update Branch button.
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.