diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 08267b918f..9c37bda427 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -250,9 +250,9 @@ def forward( logits, _ = self.output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) else: - logits = self.output_layer(hidden_states)[0] if args.sequence_parallel and args.tensor_model_parallel_size > 1: - logits = gather_from_sequence_parallel_region(logits) + hidden_states = gather_from_sequence_parallel_region(hidden_states) + logits = self.output_layer(hidden_states)[0] if has_config_logger_enabled(self.config): payload = OrderedDict({ 'input_ids': input_ids,