diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index d3253d4b39..d6e0a68690 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -245,7 +245,7 @@ def _export_and_load_weights(self): # Colocate mode: load_weights supports iterator, pass directly llm_model = self.engine.inner_model llm_model.load_weights(weight_iterator) - elif self.vllm_mode == 'server' and self.is_main_process: + elif self.vllm_mode == 'server': # Server mode: process in buckets and sync with flattened tensors self._load_weights_to_server_in_buckets(weight_iterator) @@ -285,7 +285,7 @@ def _sync_bucket_to_server(self, bucket_params: List[Tuple[str, torch.Tensor]]): Args: bucket_params: List of (name, tensor) tuples to sync """ - if not bucket_params: + if not bucket_params or not self.is_main_process: return # Create FlattenedTensorBucket for efficient transfer