Add sequence parallel compatibility with transformers >= 5.4.0 #9167
Conversation
Code Review
This pull request updates the minimum Python version requirement to 3.10 in the documentation and modifies the loss scaling logic in seq2seq_trainer.py. A review comment identifies that changing the label indexing from labels[:, 1:] to labels for calculating num_items_in_batch introduces an inconsistency with other parts of the trainer and may lead to incorrect gradient scaling in Causal Language Modeling tasks.
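As a toy illustration of why the indexing matters for causal LM loss averaging (the label values below are made up, not from the PR): logits at position i are scored against the label at position i + 1, so a label in the first column is never a prediction target and arguably should not be counted in the denominator.

```python
import torch

# Toy batch: 2 sequences of length 5; -100 marks positions ignored by the loss.
labels = torch.tensor([
    [17, 23, 42, -100, -100],
    [-100, 9, 11, 13, -100],
])

# Denominator the new code uses: every non-ignored label.
tokens_full = (labels != -100).sum()            # tensor(6)

# Denominator matching the shifted targets a causal LM is actually trained on.
tokens_shifted = (labels[:, 1:] != -100).sum()  # tensor(5)

# The extra token in the first column slightly deflates the averaged loss.
assert tokens_full > tokens_shifted
```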
/gemini review
Code Review
This pull request updates the minimum Python version requirement to 3.10 and refactors sequence parallel masking functions to improve compatibility with different versions of the transformers library. It also adjusts the loss normalization logic in the seq2seq trainer. However, the refactoring of flash_attention_mask and sdpa_mask introduces potential bugs: by switching to *args and **kwargs but only checking kwargs for specific parameters like attention_mask and cache_position, the code may fail to capture these values if they are passed positionally. This could result in missing masks or TypeError exceptions during execution.
/gemini review
Code Review
This pull request updates the minimum Python version to 3.10 and refactors attention masking and loss calculation logic to improve compatibility with newer transformer versions and sequence parallelism. Feedback highlights potential issues with positional argument handling in the sdpa_mask wrapper, concerns about the accuracy of the num_items_in_batch calculation for non-sequence-parallel training, and a discrepancy between the code and comments regarding the reduction group used for batch item counts.
```diff
 def sdpa_mask(*args, **kwargs):
     if self.world_size == 1:
-        return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, cache_position,
-                                                                                         kv_length, *args, **kwargs)
-    device = cache_position.device
+        return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](*args, **kwargs)
+    if 'cache_position' in kwargs:
+        device = kwargs['cache_position'].device
+    else:
+        # transformers>=5.4.0
+        device = kwargs['device']
     cache_position = self.real_position_ids[0]
     cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0)
     cache_position = torch.arange(0, cache_position.shape[0], device=device)
-    kv_length = cache_position.shape[0]
-    return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, cache_position,
-                                                                                     kv_length, *args, **kwargs)
+    kwargs['kv_length'] = cache_position.shape[0]
+    if 'cache_position' in kwargs:
+        kwargs['cache_position'] = cache_position
+    else:
+        kwargs['q_length'] = kwargs['kv_length']
+    return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](*args, **kwargs)
```
Similar to flash_attention_mask, the sdpa_mask wrapper fails to handle positional arguments correctly. If cache_position is passed positionally (index 1), it will not be detected in kwargs, leading to a KeyError when trying to access kwargs['device'] (or kwargs['cache_position']). Furthermore, updating kwargs while leaving the original positional args intact will cause a TypeError when calling sdpa_origin due to multiple values for the same parameter, or it will cause sdpa_origin to use the old positional value instead of the updated one.
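A minimal sketch of the kind of normalization this comment is asking for, assuming the pre-5.4.0 positional order `(batch_size, cache_position, kv_length)` seen in the diff above. The helper `normalize_mask_args` and the name `_POSITIONAL_NAMES` are illustrative, not part of the PR; the idea is to fold positional arguments into `kwargs` first, so lookups and overrides like `kwargs['cache_position']` work regardless of call style, and the wrapped function can then be invoked with keywords only.

```python
import torch

# Assumed positional order, taken from the pre-5.4.0 call in the diff above.
_POSITIONAL_NAMES = ('batch_size', 'cache_position', 'kv_length')


def normalize_mask_args(args, kwargs):
    """Fold positional arguments into a single kwargs dict.

    After this, a wrapper can read and override 'cache_position' / 'kv_length'
    by key regardless of how the caller passed them, and can forward keywords
    only, avoiding "multiple values for argument" errors.
    """
    merged = dict(zip(_POSITIONAL_NAMES, args))
    duplicated = merged.keys() & kwargs.keys()
    if duplicated:
        raise TypeError(f'got multiple values for: {sorted(duplicated)}')
    merged.update(kwargs)
    return merged


# A positional call (older transformers) and a keyword call (newer transformers)
# normalize to the same dict.
pos_call = normalize_mask_args((2, torch.arange(4), 4), {'dtype': torch.float32})
kw_call = normalize_mask_args(
    (), {'batch_size': 2, 'cache_position': torch.arange(4), 'kv_length': 4, 'dtype': torch.float32})
assert pos_call.keys() == kw_call.keys()

# Overriding by key is now safe for either call style.
pos_call['cache_position'] = torch.arange(8)
pos_call['kv_length'] = pos_call['cache_position'].shape[0]
```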
```diff
 if num_items_in_batch is None:
-    num_items_in_batch = (labels[:, 1:] != -100).sum()
+    # https://github.com/huggingface/transformers/blob/9dff7ca5c9693f4c02cdd2a9c2abc4772fcea5da/src/transformers/trainer.py#L2137
+    num_items_in_batch = (labels != -100).sum()  # compat SP
```
Changing the num_items_in_batch calculation to (labels != -100).sum() for all cases may introduce a slight inaccuracy in non-sequence-parallel training for Causal LMs. In the non-SP case, labels are typically not shifted yet, and the loss is computed on labels[:, 1:]. Using the full labels count as the denominator will result in a slightly lower loss value than intended. It is better to keep the conditional logic to ensure the count matches the tokens that actually contribute to the loss.
Suggested change:

```python
num_items_in_batch = (labels != -100).sum() if self.template.sequence_parallel_size > 1 else (labels[:, 1:] != -100).sum()  # compat SP
```

```diff
 # to full length via GatherLoss. Reduce the denominator
 # across the SP group so it matches the gathered loss.
-dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
+dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM)
```
The all_reduce call for num_items_in_batch now performs a global reduction (default group) instead of reducing only across the sp_group. However, the comment on lines 201-203 still states that the reduction is 'across the SP group'. If a global reduction is intended (e.g., to support global token-based averaging in newer transformers versions), the comment should be updated. Otherwise, if the goal is to reconstruct the local batch size for the SP group, the group argument should be restored.
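For concreteness, the two resolutions described above look roughly like this. The names `num_items_in_batch` and `sp_group` mirror the diff; the single-process gloo setup is only there so the sketch runs standalone, and in real SP training the two calls would reduce over different sets of ranks and yield different sums.

```python
import os
import torch
import torch.distributed as dist

# Single-process illustration (gloo backend) of the two reduction scopes the
# comment contrasts; with multiple ranks, sp_group would be a proper subgroup.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

num_items_in_batch = torch.tensor(128)
sp_group = dist.new_group(ranks=[0])  # stand-in for sequence_parallel.sp_group

# Option A: restore the SP-group reduction so the denominator matches the loss
# gathered across that same group (keeps the existing comment accurate).
dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sp_group)

# Option B: keep the new global reduction over the default group (e.g. for
# global token-based averaging in newer transformers) and update the comment.
dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM)

dist.destroy_process_group()
```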