Merged
swift/megatron/model/gpt_model.py (2 additions & 2 deletions)

@@ -250,9 +250,9 @@ def forward(
             logits, _ = self.output_layer(
                 hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
         else:
-            logits = self.output_layer(hidden_states)[0]
             if args.sequence_parallel and args.tensor_model_parallel_size > 1:
-                logits = gather_from_sequence_parallel_region(logits)
+                hidden_states = gather_from_sequence_parallel_region(hidden_states)
+            logits = self.output_layer(hidden_states)[0]
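As a sanity check on the efficiency argument in the review comment below: for an ordinary position-wise linear head, projecting each sequence shard and then concatenating yields the same result as concatenating the shards first, so any correctness issue presumably stems from Megatron's parallel output layer rather than from this ordering alone. A minimal sketch, with a plain `torch.nn.Linear` standing in for the real `output_layer` (an assumption, not the actual Megatron layer):

```python
import torch

# Hypothetical dimensions: hidden size 8, two classification labels.
head = torch.nn.Linear(8, 2)

# Two sequence shards, as if split across two sequence-parallel ranks.
shards = [torch.randn(4, 8) for _ in range(2)]

full = torch.cat(shards)                         # "gathered" hidden_states
a = head(full)                                   # new path: gather, then project
b = torch.cat([head(s) for s in shards])         # old path: project, then gather
print(torch.allclose(a, b, atol=1e-6))
```

Because a position-wise head touches each token independently, the two orderings differ only in communication and compute placement, which is exactly the trade-off the reviewer raises.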
Comment on lines 253 to +255

Contributor

medium
This change appears to fix a correctness issue for sequence classification with sequence parallelism by gathering hidden_states before the output_layer. While this ensures correctness, it may introduce a performance regression.

The new implementation involves:

  1. An all_gather operation on the potentially large hidden_states tensor.
  2. Redundant computation of the output_layer on all tensor parallel ranks, since each now computes the full logits.

The previous implementation gathered the smaller logits tensor, which was more efficient in terms of communication and avoided redundant computation.
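A rough back-of-the-envelope comparison of the two gather volumes; the shapes below are hypothetical example values, not taken from the PR:

```python
# Assumed dimensions for illustration only.
seq_len, batch, hidden_size = 4096, 1, 4096   # sequence length and model width
num_labels = 2                                # classification head output size
bytes_per_elem = 2                            # bf16

# New code: all_gather the full hidden_states over the sequence dimension.
hidden_bytes = seq_len * batch * hidden_size * bytes_per_elem

# Old code: gather only the (much smaller) classification logits.
logits_bytes = seq_len * batch * num_labels * bytes_per_elem

print(hidden_bytes // logits_bytes)  # ratio is hidden_size / num_labels
```

Under these assumptions the new path moves `hidden_size / num_labels` times more data, which is why the reviewer flags it as a potential regression.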

If sequence classification relies only on the hidden state of a single token (e.g., the last one), could we consider a more optimized approach? For example:

  1. Compute the logit only on the rank that holds the required token's hidden state.
  2. Broadcast the resulting logit to all other ranks.

This would be more efficient in both communication and computation. If feasible, it would be a valuable optimization.
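The two-step suggestion above could be sketched as follows. This is a hedged illustration assuming a PyTorch-distributed setup; the function name `classify_last_token`, the `sp_group` handle, and the even-split assumption are all hypothetical and not part of the actual PR:

```python
import torch
import torch.distributed as dist

def classify_last_token(local_hidden, output_layer, sp_group):
    """Compute classification logits only on the rank that owns the last
    token's hidden state, then broadcast them to every other rank."""
    world = dist.get_world_size(sp_group)
    rank = dist.get_rank(sp_group)
    # With an even sequence split, the last token lives on the last rank.
    src = world - 1
    if rank == src:
        # local_hidden: [local_seq, batch, hidden]; project only the final position.
        logits = output_layer(local_hidden[-1])
    else:
        # Allocate a receive buffer matching the broadcast's shape and dtype.
        batch = local_hidden.size(1)
        logits = torch.empty(batch, output_layer.out_features,
                             dtype=local_hidden.dtype, device=local_hidden.device)
    # One small broadcast replaces an all_gather of the full hidden_states.
    dist.broadcast(logits, src=dist.get_global_rank(sp_group, src), group=sp_group)
    return logits
```

The broadcast payload is only `[batch, num_labels]`, versus `[seq, batch, hidden]` for the all_gather, and the projection runs on a single rank instead of redundantly on every rank.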

if has_config_logger_enabled(self.config):
payload = OrderedDict({
'input_ids': input_ids,