From eaff3e88b2f4ce9f411b131e123b69f64e811586 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 18 Nov 2025 16:30:52 +0800 Subject: [PATCH 1/5] fix cp mllm --- swift/megatron/trainers/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 594561cdd8..c5b89c135c 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -111,7 +111,8 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): if cp_size > 1: args = get_args() keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] - if args.is_multimodal: + is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' + if args.is_multimodal and not is_grpo: keys.append('decoder_input') else: keys.append('input_ids') From 61bbe18ad331cf9d0a444127f8cc41080c2b0ad9 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 18 Nov 2025 20:12:20 +0800 Subject: [PATCH 2/5] fix grpo cp --- swift/megatron/trainers/grpo_trainer.py | 63 ++++++++++++++----------- swift/megatron/trainers/utils.py | 3 +- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index d6e0a68690..8629f2f04d 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -74,7 +74,8 @@ def _prepare_template_data_collator(self): if args.tensor_model_parallel_size > 1 and args.sequence_parallel: padding_to = args.tensor_model_parallel_size if args.context_parallel_size > 1: - padding_to = (padding_to or 1) * args.context_parallel_size + # CP split uses 2*cp_size chunks for load balancing + padding_to = (padding_to or 1) * (1 * args.context_parallel_size) if args.fp8_format: padding_to = max((padding_to or 1) * 8, 16) logger.info(f'padding_to: {padding_to}') @@ -368,17 +369,16 @@ def _get_rollout_group(self): Get or create the rollout process group (TP×PP×CP). The rollout group is used for: - 1. Data slicing: distributing rollout data across all model parallel ranks (including CP) - 2. Gather operations: collecting results from all model parallel ranks (including CP) + 1. Data slicing: distributing rollout data across ranks with same data samples + 2. Gather operations: collecting results from ranks with same data samples - Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct - data distribution during rollout phase. + Note: Groups are created per data parallel index, containing TP×PP×CP ranks each. + This follows Megatron's data_iterator logic where same data_parallel_rank processes + identical data samples. - Key insight: ranks with the same DP index but different TP/PP/CP indices should be - in the same rollout group. These ranks will: - - During rollout: each process different data slices - - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split) - - During gather: collect all data from TP×PP×CP ranks for training + Key insight: ranks with the SAME data parallel index process the SAME data samples + and must coordinate for rollout data distribution. 
+ Megatron rank order: TP → CP → EP → DP → PP """ if self._rollout_group is not None: return self._rollout_group @@ -389,31 +389,38 @@ def _get_rollout_group(self): self._rollout_group = mpu.get_model_parallel_group() return self._rollout_group + # Use RankGenerator to create rollout groups following Megatron-LM logic + global_rank = torch.distributed.get_rank() + # Get parallel dimensions tp_size = mpu.get_tensor_model_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() - global_rank = torch.distributed.get_rank() - - # Calculate rollout group size - rollout_group_size = tp_size * pp_size * cp_size - - # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group - # This is typically true for the default order (tp-cp-ep-dp-pp) - # Each DP group has rollout_group_size consecutive ranks - ranks_per_dp_group = rollout_group_size - my_dp_block_index = global_rank // ranks_per_dp_group + cp_size = mpu.get_context_parallel_world_size() - # Calculate the rank range for my rollout group - group_start = my_dp_block_index * ranks_per_dp_group + # Create RankGenerator following Megatron-LM pattern + # Order: tp-cp-ep-dp-pp (default in Megatron-LM) + decoder_rank_generator = mpu.RankGenerator( + tp=tp_size, + ep=1, + dp=dp_size, + pp=pp_size, + cp=cp_size, + order='tp-cp-ep-dp-pp', + rank_offset=0, + ) - # Create all rollout groups (must be done on all ranks) + # Create rollout groups based on data consistency from data_iterator + # Same data_parallel_rank processes same data - group ranks with same DP index if not hasattr(self, '_rollout_groups_created'): - for dp_idx in range(dp_size): - group_start = dp_idx * ranks_per_dp_group - group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size))) - group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP') - if global_rank in group_ranks: + # Use 'tp-cp-ep-pp' to get groups with same DP index (DP is excluded from variation) + dp_groups = decoder_rank_generator.get_ranks('tp-cp-ep-pp') + for dp_group_ranks in dp_groups: + # Sort for consistency + dp_group_ranks = sorted(dp_group_ranks) + group = torch.distributed.new_group(ranks=dp_group_ranks, group_desc='ROLLOUT_GROUP') + + if global_rank in dp_group_ranks: self._rollout_group = group self._rollout_groups_created = True diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index c5b89c135c..594561cdd8 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -111,8 +111,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): if cp_size > 1: args = get_args() keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] - is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo' - if args.is_multimodal and not is_grpo: + if args.is_multimodal: keys.append('decoder_input') else: keys.append('input_ids') From bc150b29328460b24420f52aff122c273231c517 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 18 Nov 2025 20:13:40 +0800 Subject: [PATCH 3/5] clean up --- swift/megatron/trainers/grpo_trainer.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 8629f2f04d..18958ee37a 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -53,7 +53,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): 
self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor self._prepare_metrics() - self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() self._prepare_scheduler() # TODO @@ -66,22 +65,6 @@ def train(self, train_dataset, val_dataset, data_collator): self._train_valid_test_dataset_provider.is_distributed = True super().train(train_dataset, val_dataset, data_collator) - def _prepare_template_data_collator(self): - template = self.template - args = self.args - data_collator = template.data_collator - padding_to = None - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - padding_to = args.tensor_model_parallel_size - if args.context_parallel_size > 1: - # CP split uses 2*cp_size chunks for load balancing - padding_to = (padding_to or 1) * (1 * args.context_parallel_size) - if args.fp8_format: - padding_to = max((padding_to or 1) * 8, 16) - logger.info(f'padding_to: {padding_to}') - data_collator = partial(data_collator, padding_to=padding_to) - template.data_collator = data_collator - def _init_grpo_params(self): args: MegatronArguments = self.args # distributed params From 13d40a4289813e3227a4fd18584561dea0d59127 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 18 Nov 2025 20:39:12 +0800 Subject: [PATCH 4/5] fix --- swift/megatron/trainers/grpo_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index 18958ee37a..e81e470994 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -32,7 +32,7 @@ from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available, remove_response) from ..argument import MegatronArguments, MegatronRLHFArguments -from ..utils import forward_step_helper +from ..utils import forward_step_helper, get_padding_to from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, @@ -478,6 +478,8 @@ def _replace_data_iterator(self, data_iterator, model): def _generate_and_score_completions(self, batch): # Get or create the rollout group (TP×PP×CP) + args = get_args() + rollout_group = self._get_rollout_group() rollout_batch = self.get_local_rollout_batch(batch) @@ -496,7 +498,8 @@ def _get_encoded_batch(rollout_batch, advantages): template = self.template with self._template_context(template): encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] - encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + encoded_batch = to_device( + template.data_collator(encoded_batch, padding_to=get_padding_to(args)), self.device) labels = encoded_batch['labels'] assert self.template.padding_free position_ids = encoded_batch.get('text_position_ids') From 92758f3c3258eb2a9f62c86fd40ba59bd2a06511 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 18 Nov 2025 21:45:52 +0800 Subject: [PATCH 5/5] fix --- swift/megatron/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index c7b15cf652..ba4af92f7b 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -279,7 +279,7 @@ def forward_step_helper(model, inputs, dtype=None): args = get_args() if mpu.is_pipeline_first_stage(): 
micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] + seq_length = inputs['position_ids'].shape[-1] if args.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size],
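
For context on the context-parallel handling in PATCH 1/5 and the "2*cp_size chunks for load balancing" comment in PATCH 2/5: with causal attention, Megatron's load-balanced CP split hands each CP rank one chunk from the front of the sequence and the mirrored chunk from the back, so the attention workload is roughly even across ranks. The helper below is a minimal, self-contained sketch of that chunking, not the actual get_batch_on_this_cp_rank implementation; the function name and the toy tensor are illustrative only.

import torch

def cp_load_balanced_slice(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Split the (padded) sequence dimension into 2 * cp_size chunks and give
    # rank r the pair (r, 2*cp_size - 1 - r), mixing early and late chunks so
    # the causal-attention cost is balanced across CP ranks.
    chunks = seq.chunk(2 * cp_size, dim=-1)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=-1)

# Toy run: a length-16 sequence with cp_size=2 has to be padded to a multiple of
# 2 * cp_size = 4 so that all chunks are equal-sized.
tokens = torch.arange(16).unsqueeze(0)   # shape (1, 16)
for rank in range(2):
    print(rank, cp_load_balanced_slice(tokens, cp_size=2, cp_rank=rank))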
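
The docstring added in PATCH 2/5 states that a rollout group contains every rank sharing a data-parallel index under Megatron's default tp-cp-ep-dp-pp order (tp varies fastest, pp slowest). The snippet below enumerates those groups for toy parallel sizes without importing Megatron; it mirrors what decoder_rank_generator.get_ranks('tp-cp-ep-pp') yields conceptually, and also shows why the removed contiguous-block shortcut breaks as soon as pipeline parallelism is enabled. The sizes are illustrative only.

from itertools import product

# Toy parallel sizes: world_size = tp * cp * ep * dp * pp = 16.
tp, cp, ep, dp, pp = 2, 2, 1, 2, 2

def global_rank(t, c, e, d, p):
    # Default Megatron order 'tp-cp-ep-dp-pp': tp varies fastest, pp slowest.
    return t + tp * (c + cp * (e + ep * (d + dp * p)))

# One rollout group per data-parallel index: all TP/CP/EP/PP peers that consume
# the same data samples.
for d in range(dp):
    ranks = sorted(
        global_rank(t, c, e, d, p)
        for t, c, e, p in product(range(tp), range(cp), range(ep), range(pp)))
    print(f'dp index {d}: rollout group {ranks}')
# dp index 0: rollout group [0, 1, 2, 3, 8, 9, 10, 11]  (not a contiguous block)
# dp index 1: rollout group [4, 5, 6, 7, 12, 13, 14, 15]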
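
PATCH 4/5 switches the collator call to template.data_collator(encoded_batch, padding_to=get_padding_to(args)), but get_padding_to itself is not defined anywhere in this series. Judging from the collator setup removed in PATCH 3/5, it presumably folds the same parallel-size constraints into a single padding multiple; the function below is only a guess at that shape, not the real helper.

from types import SimpleNamespace

def get_padding_to_sketch(args):
    # Hypothetical reconstruction based on the removed _prepare_template_data_collator.
    padding_to = None
    if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
        padding_to = args.tensor_model_parallel_size
    if args.context_parallel_size > 1:
        # each packed sequence must split evenly across CP ranks
        padding_to = (padding_to or 1) * args.context_parallel_size
    if args.fp8_format:
        padding_to = max((padding_to or 1) * 8, 16)
    return padding_to

toy_args = SimpleNamespace(tensor_model_parallel_size=2, sequence_parallel=True,
                           context_parallel_size=2, fp8_format=None)
print(get_padding_to_sketch(toy_args))  # -> 4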