Conversation
Add `to_transformers_dict` function to convert InputFeature instances into a dictionary compatible with transformers models. The function extracts relevant keys and ensures values are either numpy arrays or torch tensors as required by the transformers library.
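A minimal sketch of what such a conversion could look like (the key list comes from the change description; `InputFeature` is assumed to be dict-like, and the real implementation also passes torch tensors through unchanged — this is an illustration, not the actual Twinkle code):

```python
import numpy as np

# Keys a transformers forward pass accepts, per the change description.
_TRANSFORMERS_KEYS = {
    'input_ids', 'input_embeddings', 'attention_mask', 'position_ids',
    'labels', 'completion_mask', 'logits_to_keep', 'num_items_in_batch',
}

def to_transformers_dict(feature):
    """Filter a dict-like InputFeature down to transformers-compatible kwargs."""
    out = {}
    for key, value in feature.items():
        if key not in _TRANSFORMERS_KEYS:
            continue
        # Plain Python lists become numpy arrays; arrays (and, in the real
        # code, torch tensors) pass through unchanged.
        out[key] = np.asarray(value) if isinstance(value, list) else value
    return out
```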
- Update `eval` function to pass `adapter_name="default"` to `forward_only`, `calculate_loss`, and `calculate_metric` methods
- In `train` function, set optimizer for adapter and include `adapter_name` in `get_train_configs`, `forward_backward`, and `clip_grad_and_step` calls
- Ensures proper adapter-specific operations during training and evaluation
- Precompute decay and no-decay parameter name lists before optimizer group creation
- Add explicit param_names field to optimizer groups for better debugging and transparency
- Maintain identical functional behavior while improving code readability
Modify TransformersModel to only apply sp_strategy.postprocess_outputs when labels are None, preventing unintended postprocessing during training or evaluation with labels present. This ensures postprocessing is reserved for inference scenarios.
Add conditional loss reduction using sp_strategy when labels are present in inputs. This ensures that the loss calculation accounts for the sp_strategy's specific reduction logic, improving model training consistency and alignment with the strategy's objectives.
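Together with the previous change, the two label-gated behaviors (postprocess outputs only when labels are absent; apply the strategy's loss reduction only when they are present) could be sketched as follows. `DummySPStrategy` and its method names are illustrative stand-ins, not the actual Twinkle API:

```python
class DummySPStrategy:
    # Hypothetical stand-in for sp_strategy; method names are assumed.
    def postprocess_outputs(self, outputs):
        return {**outputs, 'postprocessed': True}

    def reduce_loss(self, loss, sp_world_size=2):
        # e.g. average a summed loss across sequence-parallel ranks
        return loss / sp_world_size

def forward_step(outputs, loss, labels, sp_strategy):
    if labels is None:
        # Inference path: postprocess outputs; there is no loss.
        return sp_strategy.postprocess_outputs(outputs), None
    # Training/eval path: keep outputs raw, reduce the loss per the strategy.
    return outputs, sp_strategy.reduce_loss(loss)
```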
- Add comprehensive docstring to `_get_sp_group_from_device_mesh` explaining how SP groups are derived when no explicit "sp" mesh dimension exists
- Include inline comments in backward passes and attention logic to clarify gradient handling and layout transformations
- Improve readability and maintainability of sequence parallel implementation
Summary of Changes

Hello @meichangsu1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request refines the implementation of sequence parallelism.
Code Review
This pull request introduces support for sequence parallelism, primarily by adding a NativeFSDPStrategy and integrating it into the TransformersModel. The changes also involve extensive additions of comments and docstrings to clarify the new sequence parallel logic, especially in src/twinkle/model/transformers/strategy/sequence_parallel.py. My review focuses on improving code efficiency and maintainability in the new and modified code sections. I've suggested using constants for magic strings, optimizing data handling functions, and improving the efficiency of parameter grouping logic.
```python
model.forward_only(inputs=batch, adapter_name="default")
model.calculate_loss(adapter_name="default")
metrics = model.calculate_metric(is_training=False, adapter_name="default")
```
The string "default" is used as a magic string for adapter_name in multiple places in this file (e.g., lines 76, 77, 78, 101, 102, 104, 115, 119). It would be better to define this as a constant at the beginning of the file to improve readability and maintainability. For example: DEFAULT_ADAPTER_NAME = "default".
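For instance (a sketch; `forward_only` here is a stub standing in for the real model method, just to show the call shape):

```python
DEFAULT_ADAPTER_NAME = "default"

def forward_only(inputs, adapter_name=DEFAULT_ADAPTER_NAME):
    # Stub illustrating how call sites read once the constant exists;
    # the real method lives on the model object.
    return f"forward pass with adapter '{adapter_name}'"
```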
```python
_keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'logits_to_keep', 'num_items_in_batch']
for key in list(feature.keys()):
```
For performance, it's better to use a set for _keys for O(1) average time complexity for membership testing. Also, iterating over list(feature.keys()) is inefficient as it creates a new list. You can iterate directly over the dictionary keys.
Additionally, import torch is inside the function. According to PEP 8, imports should usually be at the top of the file, unless there's a specific reason for lazy loading.
```diff
-_keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'logits_to_keep', 'num_items_in_batch']
-for key in list(feature.keys()):
+_keys = {'input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask', 'logits_to_keep', 'num_items_in_batch'}
+for key in feature:
```
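One caveat on this suggestion: if the loop body deletes non-matching keys from `feature`, iterating over the dict directly raises a RuntimeError (dictionary changed size during iteration) — the `list(...)` copy is what made deletion safe. Building a new dict sidesteps both the copy and the mutation issue, e.g. (a sketch with illustrative data and an abbreviated key set):

```python
_keys = {'input_ids', 'attention_mask', 'labels'}  # abbreviated for the example

feature = {'input_ids': [1, 2], 'labels': [0], 'debug_info': 'not for the model'}

# A dict comprehension filters in one pass without mutating `feature`
# mid-iteration, which `del feature[key]` inside `for key in feature:` would do.
filtered = {k: v for k, v in feature.items() if k in _keys}
```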
```python
decay_param_names = [
    n for n, p in params.items() if n in decay_parameters and p.requires_grad
]
no_decay_param_names = [
    n for n, p in params.items() if n not in decay_parameters and p.requires_grad
]
```
This logic iterates over params.items() twice to create decay_param_names and no_decay_param_names. You could achieve the same result with a single loop for better performance, especially if params is large.
Consider this alternative:

```python
decay_param_names = []
no_decay_param_names = []
for n, p in params.items():
    if p.requires_grad:
        if n in decay_parameters:
            decay_param_names.append(n)
        else:
            no_decay_param_names.append(n)
```

Also, for better performance of `n in decay_parameters`, consider converting `decay_parameters` to a set after it's created.
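A runnable illustration of the single-pass version with `decay_parameters` as a set (the parameter names and `FakeParam` class are made up for the example; the real code uses `torch.nn.Parameter`):

```python
class FakeParam:
    # Minimal stand-in for torch.nn.Parameter, carrying only requires_grad.
    def __init__(self, requires_grad=True):
        self.requires_grad = requires_grad

params = {
    'linear.weight': FakeParam(),
    'linear.bias': FakeParam(),
    'frozen.weight': FakeParam(requires_grad=False),
}
decay_parameters = {'linear.weight'}  # a set gives O(1) membership tests

decay_param_names = []
no_decay_param_names = []
for n, p in params.items():
    if p.requires_grad:
        if n in decay_parameters:
            decay_param_names.append(n)
        else:
            no_decay_param_names.append(n)
```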