Merged
83 changes: 38 additions & 45 deletions swift/megatron/trainers/grpo_trainer.py
@@ -32,7 +32,7 @@
from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available,
remove_response)
from ..argument import MegatronArguments, MegatronRLHFArguments
from ..utils import forward_step_helper
from ..utils import forward_step_helper, get_padding_to
from .rlhf_mixin import MegatronRLHFTrainer
from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu,
load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer,
@@ -53,7 +53,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
self.hf_model_dir = args.model_info.model_dir
self.processing_class = self.template.processor
self._prepare_metrics()
self._prepare_template_data_collator()
self._init_grpo_params()
self._prepare_rewards()
self._prepare_scheduler() # TODO
@@ -66,21 +65,6 @@ def train(self, train_dataset, val_dataset, data_collator):
self._train_valid_test_dataset_provider.is_distributed = True
super().train(train_dataset, val_dataset, data_collator)

def _prepare_template_data_collator(self):
template = self.template
args = self.args
data_collator = template.data_collator
padding_to = None
if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
padding_to = args.tensor_model_parallel_size
if args.context_parallel_size > 1:
padding_to = (padding_to or 1) * args.context_parallel_size
if args.fp8_format:
padding_to = max((padding_to or 1) * 8, 16)
logger.info(f'padding_to: {padding_to}')
data_collator = partial(data_collator, padding_to=padding_to)
template.data_collator = data_collator

def _init_grpo_params(self):
args: MegatronArguments = self.args
# distributed params
@@ -368,17 +352,16 @@ def _get_rollout_group(self):
Get or create the rollout process group (TP×PP×CP).

The rollout group is used for:
1. Data slicing: distributing rollout data across all model parallel ranks (including CP)
2. Gather operations: collecting results from all model parallel ranks (including CP)
1. Data slicing: distributing rollout data across ranks with same data samples
2. Gather operations: collecting results from ranks with same data samples

Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct
data distribution during rollout phase.
Note: Groups are created per data parallel index, containing TP×PP×CP ranks each.
This follows Megatron's data_iterator logic where same data_parallel_rank processes
identical data samples.

Key insight: ranks with the same DP index but different TP/PP/CP indices should be
in the same rollout group. These ranks will:
- During rollout: each process different data slices
- During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split)
- During gather: collect all data from TP×PP×CP ranks for training
Key insight: ranks with the SAME data parallel index process the SAME data samples
and must coordinate for rollout data distribution.
Megatron rank order: TP → CP → EP → DP → PP
"""
if self._rollout_group is not None:
return self._rollout_group
@@ -389,31 +372,38 @@ def _get_rollout_group(self):
self._rollout_group = mpu.get_model_parallel_group()
return self._rollout_group

# Use RankGenerator to create rollout groups following Megatron-LM logic
global_rank = torch.distributed.get_rank()

# Get parallel dimensions
tp_size = mpu.get_tensor_model_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
global_rank = torch.distributed.get_rank()

# Calculate rollout group size
rollout_group_size = tp_size * pp_size * cp_size

# Simple and reliable method: assume ranks are organized in contiguous blocks per DP group
# This is typically true for the default order (tp-cp-ep-dp-pp)
# Each DP group has rollout_group_size consecutive ranks
ranks_per_dp_group = rollout_group_size
my_dp_block_index = global_rank // ranks_per_dp_group
cp_size = mpu.get_context_parallel_world_size()

# Calculate the rank range for my rollout group
group_start = my_dp_block_index * ranks_per_dp_group
# Create RankGenerator following Megatron-LM pattern
# Order: tp-cp-ep-dp-pp (default in Megatron-LM)
decoder_rank_generator = mpu.RankGenerator(
tp=tp_size,
ep=1,
dp=dp_size,
pp=pp_size,
cp=cp_size,
order='tp-cp-ep-dp-pp',
rank_offset=0,
)
Comment on lines +382 to +394
Contributor

high

The RankGenerator is initialized with a hardcoded ep=1. This could lead to incorrect rollout group creation if expert parallelism is used (i.e., expert_model_parallel_size > 1). The rollout group should encompass all model parallel dimensions for a given data parallel rank, including expert parallelism.

To make this more robust, I suggest fetching the expert parallel size from the arguments and using it to initialize the RankGenerator.

Suggested change

Before:
cp_size = mpu.get_context_parallel_world_size()
# Calculate the rank range for my rollout group
group_start = my_dp_block_index * ranks_per_dp_group
# Create RankGenerator following Megatron-LM pattern
# Order: tp-cp-ep-dp-pp (default in Megatron-LM)
decoder_rank_generator = mpu.RankGenerator(
tp=tp_size,
ep=1,
dp=dp_size,
pp=pp_size,
cp=cp_size,
order='tp-cp-ep-dp-pp',
rank_offset=0,
)

After:
cp_size = mpu.get_context_parallel_world_size()
ep_size = self.args.expert_model_parallel_size
# Create RankGenerator following Megatron-LM pattern
# Order: tp-cp-ep-dp-pp (default in Megatron-LM)
decoder_rank_generator = mpu.RankGenerator(
tp=tp_size,
ep=ep_size,
dp=dp_size,
pp=pp_size,
cp=cp_size,
order='tp-cp-ep-dp-pp',
rank_offset=0,
)

Collaborator Author

same in Megatron-LM initialize_model_parallel, ignore


# Create all rollout groups (must be done on all ranks)
# Create rollout groups based on data consistency from data_iterator
# Same data_parallel_rank processes same data - group ranks with same DP index
if not hasattr(self, '_rollout_groups_created'):
for dp_idx in range(dp_size):
group_start = dp_idx * ranks_per_dp_group
group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size)))
group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP')
if global_rank in group_ranks:
# Use 'tp-cp-ep-pp' to get groups with same DP index (DP is excluded from variation)
dp_groups = decoder_rank_generator.get_ranks('tp-cp-ep-pp')
for dp_group_ranks in dp_groups:
# Sort for consistency
dp_group_ranks = sorted(dp_group_ranks)
group = torch.distributed.new_group(ranks=dp_group_ranks, group_desc='ROLLOUT_GROUP')

if global_rank in dp_group_ranks:
self._rollout_group = group
self._rollout_groups_created = True

@@ -488,6 +478,8 @@ def _replace_data_iterator(self, data_iterator, model):

def _generate_and_score_completions(self, batch):
# Get or create the rollout group (TP×PP×CP)
args = get_args()

rollout_group = self._get_rollout_group()

rollout_batch = self.get_local_rollout_batch(batch)
@@ -506,7 +498,8 @@ def _get_encoded_batch(rollout_batch, advantages):
template = self.template
with self._template_context(template):
encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch]
encoded_batch = to_device(template.data_collator(encoded_batch), self.device)
encoded_batch = to_device(
template.data_collator(encoded_batch, padding_to=get_padding_to(args)), self.device)
labels = encoded_batch['labels']
assert self.template.padding_free
position_ids = encoded_batch.get('text_position_ids')
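Context for the collator change above: the per-trainer _prepare_template_data_collator (removed at the top of this file) computed the padding multiple once at init, whereas the new code passes padding_to=get_padding_to(args) at collate time. A rough sketch of what such a helper could compute, mirroring only the removed logic (the real get_padding_to lives in swift/megatron/utils and may differ):

from typing import Optional

def get_padding_to_sketch(args) -> Optional[int]:
    # Mirrors the padding rules from the removed _prepare_template_data_collator.
    padding_to = None
    # Sequence parallelism needs the packed length divisible by the TP size.
    if args.tensor_model_parallel_size > 1 and args.sequence_parallel:
        padding_to = args.tensor_model_parallel_size
    # Context parallelism additionally splits the sequence across CP ranks.
    if args.context_parallel_size > 1:
        padding_to = (padding_to or 1) * args.context_parallel_size
    # The removed code padded to a multiple of 8 (at least 16) under fp8_format.
    if args.fp8_format:
        padding_to = max((padding_to or 1) * 8, 16)
    return padding_to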
2 changes: 1 addition & 1 deletion swift/megatron/utils/utils.py
@@ -279,7 +279,7 @@ def forward_step_helper(model, inputs, dtype=None):
args = get_args()
if mpu.is_pipeline_first_stage():
micro_batch_size = 1 # use qkv_format 'thd'
seq_length = inputs['input_ids'].shape[1]
seq_length = inputs['position_ids'].shape[-1]
if args.sequence_parallel:
seq_length //= mpu.get_tensor_model_parallel_world_size()
recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size],
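The hunk above switches the packed sequence length from input_ids.shape[1] to position_ids.shape[-1]. Below is an illustrative, self-contained sketch of the shapes in the padding-free 'thd' layout that forward_step_helper assumes; the sample data, sizes, and the sequence_parallel stand-in flag are made up.

import torch

# Three packed samples of lengths 3, 2 and 3 concatenated into one row.
input_ids = torch.tensor([[11, 12, 13, 21, 22, 31, 32, 33]])   # shape [1, 8]
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])        # shape [1, 8]

micro_batch_size = 1                    # qkv_format 'thd': one packed row
seq_length = position_ids.shape[-1]     # 8 packed tokens
tp_world_size = 2
sequence_parallel = True                # stands in for args.sequence_parallel
if sequence_parallel:
    # Each TP rank holds 1/tp of the packed tokens, so the shape communicated to
    # the next pipeline stage is divided accordingly; this is also why the
    # collator pads the packed length to a multiple of the TP size.
    seq_length //= tp_world_size

hidden_size = 4096                      # placeholder value
recv_shape = [seq_length, micro_batch_size, hidden_size]
print(recv_shape)                       # [4, 1, 4096]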