Merged
2 changes: 1 addition & 1 deletion swift/megatron/trainers/base.py
@@ -87,7 +87,7 @@ def initialize_megatron(*_args, **kwargs):
 args = get_args()
 data_parallel_size = mpu.get_data_parallel_world_size()
 step_batch_size = args.micro_batch_size * data_parallel_size
-num_generations = args.num_generations if hasattr(args, 'num_generations') else 1
+num_generations = args.num_generations if args.rlhf_type == 'grpo' else 1
critical

Accessing args.rlhf_type directly raises an AttributeError for training configurations that do not define this attribute, such as SFT, so non-RLHF training runs will crash.

To prevent this, use getattr to read the rlhf_type attribute with a default value.

Suggested change
-num_generations = args.num_generations if args.rlhf_type == 'grpo' else 1
+num_generations = args.num_generations if getattr(args, 'rlhf_type', None) == 'grpo' else 1
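
To illustrate the difference, here is a minimal sketch that assumes the parsed arguments behave like an argparse.Namespace. The helper name resolve_num_generations and the sample namespaces are hypothetical and only stand in for whatever get_args() returns in the trainer.

```python
from argparse import Namespace


def resolve_num_generations(args):
    """Hypothetical helper mirroring the suggested one-liner."""
    # getattr with a default avoids AttributeError when rlhf_type is absent.
    return args.num_generations if getattr(args, 'rlhf_type', None) == 'grpo' else 1


# Stand-ins for the parsed arguments (assumed shapes, not real configs).
sft_args = Namespace(micro_batch_size=2)                    # no rlhf_type at all
dpo_args = Namespace(rlhf_type='dpo')                       # rlhf_type set, not GRPO
grpo_args = Namespace(rlhf_type='grpo', num_generations=8)  # GRPO run

results = [resolve_num_generations(a) for a in (sft_args, dpo_args, grpo_args)]

# Direct attribute access on the SFT-like namespace fails, as the review warns.
try:
    sft_args.rlhf_type
    direct_access_failed = False
except AttributeError:
    direct_access_failed = True
```

Only the GRPO-like namespace reaches args.num_generations; the other two fall through to 1 without touching attributes they may not define.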

 if args.train_iters is None and args.max_epochs is not None:
     if hasattr(train_dataset, '__len__'):
         dataset_sample = len(train_dataset) // step_batch_size * step_batch_size