
Add sequence parallel compatibility with transformers >= 5.4.0 #9167

Merged
Jintao-Huang merged 6 commits into modelscope:main from Jintao-Huang:update_docs_0421
Apr 21, 2026

Conversation

@Jintao-Huang
Collaborator

No description provided.

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the minimum Python version requirement to 3.10 in the documentation and modifies the loss scaling logic in `seq2seq_trainer.py`. A review comment identifies that changing the label indexing from `labels[:, 1:]` to `labels` when calculating `num_items_in_batch` introduces an inconsistency with other parts of the trainer and may lead to incorrect gradient scaling in causal language modeling tasks.

Comment thread: swift/trainers/seq2seq_trainer.py (outdated)
@Jintao-Huang changed the title from "[docs] update docs" to "Add sequence parallel compatibility with transformers >= 5.4.0" on Apr 21, 2026
@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the minimum Python version requirement to 3.10 and refactors the sequence-parallel masking functions to improve compatibility with different versions of the transformers library. It also adjusts the loss normalization logic in the seq2seq trainer. However, the refactoring of `flash_attention_mask` and `sdpa_mask` introduces potential bugs: by switching to `*args` and `**kwargs` but only checking `kwargs` for specific parameters such as `attention_mask` and `cache_position`, the code may fail to capture these values if they are passed positionally. This could result in missing masks or `TypeError` exceptions at runtime.

Comment thread: swift/sequence_parallel/ulysses.py
Comment thread: swift/sequence_parallel/ulysses.py (outdated)
@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the minimum Python version to 3.10 and refactors attention masking and loss calculation logic to improve compatibility with newer transformers versions and sequence parallelism. Feedback highlights potential issues with positional argument handling in the `sdpa_mask` wrapper, concerns about the accuracy of the `num_items_in_batch` calculation for non-sequence-parallel training, and a discrepancy between the code and its comments regarding the reduction group used for batch item counts.

Comment on lines +198 to +214
```diff
 def sdpa_mask(*args, **kwargs):
     if self.world_size == 1:
-        return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, cache_position, kv_length, *args, **kwargs)
-    device = cache_position.device
+        return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](*args, **kwargs)
+    if 'cache_position' in kwargs:
+        device = kwargs['cache_position'].device
+    else:
+        # transformers>=5.4.0
+        device = kwargs['device']
     cache_position = self.real_position_ids[0]
     cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0)
     cache_position = torch.arange(0, cache_position.shape[0], device=device)
-    kv_length = cache_position.shape[0]
-    return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, cache_position, kv_length, *args, **kwargs)
+    kwargs['kv_length'] = cache_position.shape[0]
+    if 'cache_position' in kwargs:
+        kwargs['cache_position'] = cache_position
+    else:
+        kwargs['q_length'] = kwargs['kv_length']
+    return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](*args, **kwargs)
```
Contributor


Severity: high

Similar to `flash_attention_mask`, the `sdpa_mask` wrapper fails to handle positional arguments correctly. If `cache_position` is passed positionally (index 1), it will not be detected in `kwargs`, leading to a `KeyError` when trying to access `kwargs['device']` (or `kwargs['cache_position']`). Furthermore, updating `kwargs` while leaving the original positional args intact will cause a `TypeError` when calling `sdpa_origin` due to multiple values for the same parameter, or it will cause `sdpa_origin` to use the old positional value instead of the updated one.
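
One way to sidestep this pitfall is to bind positional arguments to their parameter names before inspecting or overriding anything. The sketch below is illustrative only and not the PR's code; `original_sdpa_mask`, its parameters, and `make_wrapper` are hypothetical stand-ins used to show the pattern with `inspect.signature`.

```python
# Illustrative sketch (not the PR's implementation): normalize positional args
# into keywords before overriding them, so the wrapper never misses a value that
# the caller passed positionally. `original_sdpa_mask` is a hypothetical stand-in.
import inspect

import torch


def original_sdpa_mask(batch_size, cache_position, kv_length, attention_mask=None, **kwargs):
    # Stand-in for the real mask function; just reports what it received.
    return {'batch_size': batch_size, 'kv_length': kv_length,
            'cache_position': cache_position, 'attention_mask': attention_mask}


def make_wrapper(wrapped):
    sig = inspect.signature(wrapped)

    def wrapper(*args, **kwargs):
        # Bind positional and keyword arguments to their parameter names.
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()
        params = dict(bound.arguments)
        # Flatten the **kwargs bucket ('kwargs' is the var-keyword name in the stand-in).
        params.update(params.pop('kwargs', {}))

        # Overrides are now safe regardless of how the caller passed the values.
        if params.get('cache_position') is not None:
            device = params['cache_position'].device
            params['cache_position'] = torch.arange(params['kv_length'], device=device)

        # Call the wrapped function with keywords only, avoiding duplicate-argument TypeErrors.
        return wrapped(**params)

    return wrapper


sdpa_mask = make_wrapper(original_sdpa_mask)
# Works whether cache_position is positional or keyword:
print(sdpa_mask(2, torch.arange(4), 4))
print(sdpa_mask(batch_size=2, cache_position=torch.arange(4), kv_length=4))
```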

```diff
 if num_items_in_batch is None:
-    num_items_in_batch = (labels[:, 1:] != -100).sum()
     # https://github.com/huggingface/transformers/blob/9dff7ca5c9693f4c02cdd2a9c2abc4772fcea5da/src/transformers/trainer.py#L2137
+    num_items_in_batch = (labels != -100).sum()  # compat SP
```
Contributor


Severity: medium

Changing the `num_items_in_batch` calculation to `(labels != -100).sum()` for all cases may introduce a slight inaccuracy in non-sequence-parallel training for causal LMs. In the non-SP case, labels are typically not yet shifted, and the loss is computed on `labels[:, 1:]`. Using the full `labels` count as the denominator will therefore yield a slightly lower loss value than intended. It is better to keep the conditional logic so that the count matches the tokens that actually contribute to the loss.

```python
num_items_in_batch = (labels != -100).sum() if self.template.sequence_parallel_size > 1 else (labels[:, 1:] != -100).sum()  # compat SP
```
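
For context on this point, a toy example (not from the PR) showing how the two denominators diverge when the first label column is supervised:

```python
# Illustrative only: with right-shifted targets, position t's logits score
# labels[:, t+1], so the first label column never contributes to the loss.
import torch

IGNORE_INDEX = -100
labels = torch.tensor([
    [101, 7, 8, 9, IGNORE_INDEX],             # 4 supervised tokens, 3 after the shift
    [101, 5, 6, IGNORE_INDEX, IGNORE_INDEX],  # 3 supervised tokens, 2 after the shift
])

full_count = (labels != IGNORE_INDEX).sum()             # counts 7 tokens
shifted_count = (labels[:, 1:] != IGNORE_INDEX).sum()   # counts 5 tokens actually scored
print(full_count.item(), shifted_count.item())  # 7 5  -> larger denominator, smaller loss
```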

```diff
 # to full length via GatherLoss. Reduce the denominator
 # across the SP group so it matches the gathered loss.
-dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=sequence_parallel.sp_group)
+dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM)
```
Contributor


Severity: medium

The `all_reduce` call for `num_items_in_batch` now performs a global reduction (default group) instead of reducing only across the `sp_group`. However, the comment on lines 201-203 still states that the reduction is 'across the SP group'. If a global reduction is intended (e.g., to support global token-based averaging in newer transformers versions), the comment should be updated. Otherwise, if the goal is to reconstruct the local batch size for the SP group, the `group` argument should be restored.
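
To illustrate the difference the reviewer is pointing at, here is a small, self-contained sketch (not the PR's code) of group-scoped versus global `all_reduce` with `torch.distributed`; the world size, SP group layout, and per-rank counts are assumptions made for the demo.

```python
# Illustrative sketch: the same all_reduce yields different denominators depending
# on the process group. With 4 ranks split into two sequence-parallel (SP) groups
# of size 2, reducing over the SP group sums the shard counts within one replica,
# while the default (global) group sums across every rank.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 4
SP_SIZE = 2


def worker(rank):
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29510')
    dist.init_process_group('gloo', rank=rank, world_size=WORLD_SIZE)

    # Ranks [0, 1] form one SP group, ranks [2, 3] the other.
    sp_groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
    sp_group = sp_groups[rank // SP_SIZE]

    # Pretend each rank holds a shard with `rank + 1` supervised tokens.
    local_count = torch.tensor(rank + 1)

    sp_count = local_count.clone()
    dist.all_reduce(sp_count, op=dist.ReduceOp.SUM, group=sp_group)  # 3 or 7

    global_count = local_count.clone()
    dist.all_reduce(global_count, op=dist.ReduceOp.SUM)  # 10 on every rank

    print(f'rank {rank}: sp_group sum = {sp_count.item()}, global sum = {global_count.item()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, nprocs=WORLD_SIZE)
```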

@Jintao-Huang Jintao-Huang merged commit e7c9c7f into modelscope:main Apr 21, 2026
2 of 3 checks passed
