[template] support truncation_strategy spllit (swift pt) #6672
Conversation
Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances data handling by introducing a 'split' truncation strategy, which divides large data samples into smaller, manageable chunks during pretraining to optimize resource usage. It also refines the default …

Highlights
Code Review
This pull request introduces support for a new truncation_strategy, 'split', which is useful for pre-training tasks to avoid wasting tokens on long sequences. The changes span across argument definitions, template encoding logic, and documentation.
I've identified a critical bug in the data collation logic when using the 'split' strategy without packing, and a missing validation check for cached_dataset incompatibility. I've also included a suggestion to improve code clarity.
By the way, there's a small typo in the pull request title: "spllit" should be "split".
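To make the 'split' strategy concrete, here is a minimal sketch of the idea: a sample longer than `max_length` is divided into `max_length`-sized chunks, and the first label of each chunk is masked with `-100` (mirroring the masking visible in the diffs below). The function and variable names are made up for illustration; this is not the PR's actual implementation.

```python
def split_sample(input_ids, labels, max_length):
    """Split one over-long sample into several samples of at most max_length.

    The first label of each chunk is set to -100 so that a token whose
    context was cut off at the chunk boundary does not contribute to the loss.
    """
    chunks = []
    for i in range(0, len(input_ids), max_length):
        ids = input_ids[i:i + max_length]
        lbs = labels[i:i + max_length]
        if lbs:  # guard against an empty slice
            lbs[0] = -100
        chunks.append({'input_ids': ids, 'labels': lbs})
    return chunks
```

Compared with the default truncation (which simply drops tokens beyond `max_length`), this keeps every token of a long pretraining document as training signal.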
swift/llm/train/sft.py
Outdated
```python
if args.truncation_strategy == 'split' and (args.task_type != 'causal_lm' or template.mode != 'train'
                                            or args.use_chat_template or args.model_meta.is_multimodal):
    raise ValueError(
        '`--truncation_strategy split` is currently only supported for plain text model pretraining')
```
The documentation mentions that truncation_strategy='split' is incompatible with cached_dataset. It would be good to add a check here to enforce this and prevent potential issues. For example, you could add if args.cached_dataset: raise ValueError(...) before this block or within a refactored conditional.
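As a hedged sketch of that suggestion (the helper name is hypothetical; the attribute names follow the diff above), the validation could be gathered into one function:

```python
def check_split_strategy(args):
    """Validate flag combinations for `--truncation_strategy split` (sketch only)."""
    if args.truncation_strategy != 'split':
        return
    # Enforce the documented incompatibility before the broader check.
    if args.cached_dataset:
        raise ValueError('`--truncation_strategy split` is not compatible with `--cached_dataset`')
    if args.task_type != 'causal_lm' or args.use_chat_template:
        raise ValueError(
            '`--truncation_strategy split` is currently only supported for plain text model pretraining')
```

Failing fast here surfaces the misconfiguration at startup instead of producing confusing errors deep in the data pipeline.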
swift/llm/template/base.py
Outdated
```python
while i < length:
    splited = {}
    for key in ['input_ids', 'labels', 'loss_scale']:
        value = locals()[key]
```
Using locals() can be a bit magical and less explicit. It would be clearer and more robust to access the variables directly from the encoded dictionary, since they were populated from it just a few lines above. This improves maintainability by making the data flow more explicit.
Suggested change:

```diff
-        value = locals()[key]
+        value = encoded.get(key)
```
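The explicit-dictionary style the review recommends can be shown with a small standalone helper (hypothetical name; `encoded` plays the role of the dict populated earlier in the template code):

```python
def slice_fields_explicit(encoded, max_length):
    """Slice each known field by reading it from the dict directly.

    Using encoded.get(key) makes the data flow explicit: the dict is the
    single source of truth, and a missing field is simply skipped instead
    of depending on whatever happens to be in locals().
    """
    out = {}
    for key in ['input_ids', 'labels', 'loss_scale']:
        value = encoded.get(key)
        if value is not None:
            out[key] = value[:max_length]
    return out
```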
/gemini review
Code Review
This pull request introduces a 'split' truncation strategy, primarily for pre-training, to handle long sequences by splitting them into multiple samples. The changes span across argument definitions, template encoding logic, and the training pipeline, with corresponding documentation updates. My review identifies a critical bug that could lead to an IndexError during data processing and suggests a fix. I also recommend improving an assertion message for better clarity and adding a check to enforce a documented incompatibility. Lastly, there is a minor typo in the pull request title ('spllit' instead of 'split').
```python
value = value[i:i + self.max_length]
if key == 'labels':
    value[0] = -100
elif key == 'loss_scale':
    value[0] = 0
```
There's a potential IndexError here. If len(encoded.get(key)) is less than or equal to i, the slice value[i:i + self.max_length] will be an empty list, and accessing value[0] on an empty list will raise an IndexError. This can happen if, for example, input_ids is longer than labels, since the loop condition is based on len(input_ids). You should check that the sliced value is non-empty before accessing its first element.
Suggested change:

```diff
 value = value[i:i + self.max_length]
-if key == 'labels':
-    value[0] = -100
-elif key == 'loss_scale':
-    value[0] = 0
+if value:
+    if key == 'labels':
+        value[0] = -100
+    elif key == 'loss_scale':
+        value[0] = 0
```
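The guarded version can be tried out as a self-contained helper (hypothetical name, extracted from the loop above for illustration):

```python
def mask_chunk(value, i, max_length, key):
    """Slice value to the window starting at i and mask its first element.

    Returns an empty list instead of raising IndexError when the field
    is shorter than the window start i.
    """
    value = value[i:i + max_length]
    if value:  # guard: the slice is empty when len(value) <= i
        if key == 'labels':
            value[0] = -100
        elif key == 'loss_scale':
            value[0] = 0
    return value
```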
swift/llm/train/sft.py
Outdated
```python
        or args.model_meta.is_multimodal):
    raise ValueError(
        '`--truncation_strategy split` is currently only supported for plain text model pretraining')
assert not args.lazy_tokenize, 'not support'
```
The assertion message 'not support' is not very descriptive. It would be more helpful for developers if it clearly stated why the assertion failed. Additionally, the documentation mentions that truncation_strategy='split' is incompatible with cached_dataset. It would be good to add an assertion here to enforce this and prevent potential issues.
Suggested change:

```diff
-assert not args.lazy_tokenize, 'not support'
+assert not args.lazy_tokenize, '`--truncation_strategy split` does not support lazy tokenization'
+assert not args.cached_dataset, '`--truncation_strategy split` is not compatible with `--cached_dataset`'
```
No description provided.