From e7ab82ce362987615793d2ee76cff22e563f2bb2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 3 Oct 2025 17:39:50 +0800 Subject: [PATCH 01/21] update --- swift/megatron/trainers/base.py | 22 ++++++- swift/megatron/trainers/dpo_trainer.py | 2 +- swift/megatron/trainers/kto_trainer.py | 65 ++++++++++++++++++++- swift/megatron/trainers/trainer.py | 2 +- swift/megatron/trainers/utils.py | 80 -------------------------- 5 files changed, 85 insertions(+), 86 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9bff200dcb..9b03675cad 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -123,7 +123,7 @@ def new_cyclic_iter(self, iterable): yield x i += 1 - def _replace_data_iterator(self, data_iterator): + def _replace_data_iterator(self, data_iterator, model): return data_iterator @staticmethod @@ -325,7 +325,7 @@ def _all_reduce_metric(self, metric: Dict[str, torch.Tensor]) -> Dict[str, torch return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - new_data_iterator = self._replace_data_iterator(data_iterator) + new_data_iterator = self._replace_data_iterator(data_iterator, model) return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, config) @@ -374,7 +374,7 @@ def evaluate(self, # Don't care about timing during evaluation config.timers = None ft_integration.on_eval_step_start() - new_data_iterator = self._replace_data_iterator(data_iterator) + new_data_iterator = self._replace_data_iterator(data_iterator, model) loss_dicts = forward_backward_func( forward_step_func=forward_step_func, data_iterator=new_data_iterator, @@ -867,3 +867,19 @@ def _forward_step_helper(model, inputs): output_tensor = None return output_tensor + + def get_batch(self, data_iterator, vp_stage=None): + """Generate a batch.""" + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator, vp_stage=vp_stage) + args = get_args() + num_samples = batch.pop('num_samples') + text_position_ids = batch.pop('text_position_ids', None) + if text_position_ids is None: + text_position_ids = batch.get('position_ids') + if args.padding_free and text_position_ids is not None: + batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) + batch['packed_seq_params'].num_samples = num_samples + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + return batch diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 783dea8a6e..ab1de23857 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -108,7 +108,7 @@ def forward_step(self, data_iterator, model): ref_model.set_input_tensor(input_tensor[:input_tensor.shape[0] // 2].detach()) timers('batch-generator', log_level=2).start() with self.stimer(bdata=True): - data = get_batch(data_iterator, vp_stage) + data = self.get_batch(data_iterator, vp_stage) timers('batch-generator').stop() data.pop('loss_scale', None) ref_output_tensor = ref_model(**data) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 46bd4deb72..0af1aec0e1 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -182,7 +182,7 @@ def _replace_data_iterator_with_model(self, data_iterator, model): def 
ref_forward(self, ref_model, data_iterator): with self.stimer(bdata=True): - data = get_kto_batch(data_iterator) + data = self.get_batch(data_iterator) data.pop('loss_scale', None) ref_inputs = { @@ -251,3 +251,66 @@ def evaluate(self, self._replace_data_iterator = partial(self._replace_data_iterator_with_model, model=model) return super().evaluate(forward_step_func, data_iterator, model, process_non_loss_data_func, config, verbose, non_loss_data_func) + + def get_batch(data_iterator): + """Generate a kto batch.""" + args = get_args() + + data = next(data_iterator) + is_finished = data.pop('is_finished', False) + + batch = to_device(data, 'cuda', non_blocking=True) + + kto_tensor_keys = [ + 'completion_input_ids', 'completion_labels', 'completion_attention_mask', 'completion_position_ids', + 'KL_completion_input_ids', 'KL_completion_labels', 'KL_completion_attention_mask', + 'KL_completion_position_ids' + ] + + # pp + if args.pipeline_model_parallel_size == 1: + pass + elif mpu.is_pipeline_first_stage(): + for key in kto_tensor_keys: + if 'labels' in key: + batch[key] = None + elif mpu.is_pipeline_last_stage(): + for key in kto_tensor_keys: + if 'input_ids' in key: + batch[key] = None + else: + for key in kto_tensor_keys: + batch[key] = None + + # Padding-Free + num_samples = batch.get('num_samples') + if args.padding_free: + if 'completion_position_ids' in batch and batch['completion_position_ids'] is not None: + batch['completion_packed_seq_params'] = get_packed_seq_params(batch['completion_position_ids']) + if num_samples is not None: + batch['completion_packed_seq_params'].num_samples = num_samples + + if 'KL_completion_position_ids' in batch and batch['KL_completion_position_ids'] is not None: + batch['KL_completion_packed_seq_params'] = get_packed_seq_params(batch['KL_completion_position_ids']) + if num_samples is not None: + batch['KL_completion_packed_seq_params'].num_samples = num_samples + + # cp + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + completion_psp = batch.get('completion_packed_seq_params') + kl_psp = batch.get('KL_completion_packed_seq_params') + + if completion_psp is None and kl_psp is None: + batch = mcore_get_batch_on_this_cp_rank(batch) + else: + for key, val in batch.items(): + if key in kto_tensor_keys and val is not None: + if key.startswith('KL_completion_') and kl_psp is not None: + batch[key] = split_cp_inputs(val, kl_psp.cu_seqlens_q, -1) + elif key.startswith('completion_') and completion_psp is not None: + batch[key] = split_cp_inputs(val, completion_psp.cu_seqlens_q, -1) + + if is_finished: + args.train_iters = args.curr_iteration + 1 + return batch diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 7e87f63eb7..95f70a662e 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -139,7 +139,7 @@ def forward_step(self, data_iterator, model): vp_stage = model.module.module.vp_stage timers('batch-generator', log_level=2).start() with self.stimer(bdata=True): - data = get_batch(data_iterator, vp_stage) + data = self.get_batch(data_iterator, vp_stage) timers('batch-generator').stop() loss_scale = data.pop('loss_scale', None) channels = data.pop('channel', None) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 0a301e479d..9f77e0aaba 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -115,83 +115,3 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): batch[key] = split_cp_inputs(val, 
packed_seq_params.cu_seqlens_q, -1) return batch - - -def get_batch(data_iterator, vp_stage=None): - """Generate a batch.""" - # get batches based on the TP rank you are on - batch = get_batch_on_this_tp_rank(data_iterator, vp_stage=vp_stage) - args = get_args() - num_samples = batch.pop('num_samples') - text_position_ids = batch.pop('text_position_ids', None) - if text_position_ids is None: - text_position_ids = batch.get('position_ids') - if args.padding_free and text_position_ids is not None: - batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) - return batch - - -def get_kto_batch(data_iterator): - """Generate a kto batch.""" - args = get_args() - - data = next(data_iterator) - is_finished = data.pop('is_finished', False) - - batch = to_device(data, 'cuda', non_blocking=True) - - kto_tensor_keys = [ - 'completion_input_ids', 'completion_labels', 'completion_attention_mask', 'completion_position_ids', - 'KL_completion_input_ids', 'KL_completion_labels', 'KL_completion_attention_mask', 'KL_completion_position_ids' - ] - - # pp - if args.pipeline_model_parallel_size == 1: - pass - elif mpu.is_pipeline_first_stage(): - for key in kto_tensor_keys: - if 'labels' in key: - batch[key] = None - elif mpu.is_pipeline_last_stage(): - for key in kto_tensor_keys: - if 'input_ids' in key: - batch[key] = None - else: - for key in kto_tensor_keys: - batch[key] = None - - # Padding-Free - num_samples = batch.get('num_samples') - if args.padding_free: - if 'completion_position_ids' in batch and batch['completion_position_ids'] is not None: - batch['completion_packed_seq_params'] = get_packed_seq_params(batch['completion_position_ids']) - if num_samples is not None: - batch['completion_packed_seq_params'].num_samples = num_samples - - if 'KL_completion_position_ids' in batch and batch['KL_completion_position_ids'] is not None: - batch['KL_completion_packed_seq_params'] = get_packed_seq_params(batch['KL_completion_position_ids']) - if num_samples is not None: - batch['KL_completion_packed_seq_params'].num_samples = num_samples - - # cp - cp_size = mpu.get_context_parallel_world_size() - if cp_size > 1: - completion_psp = batch.get('completion_packed_seq_params') - kl_psp = batch.get('KL_completion_packed_seq_params') - - if completion_psp is None and kl_psp is None: - batch = mcore_get_batch_on_this_cp_rank(batch) - else: - for key, val in batch.items(): - if key in kto_tensor_keys and val is not None: - if key.startswith('KL_completion_') and kl_psp is not None: - batch[key] = split_cp_inputs(val, kl_psp.cu_seqlens_q, -1) - elif key.startswith('completion_') and completion_psp is not None: - batch[key] = split_cp_inputs(val, completion_psp.cu_seqlens_q, -1) - - if is_finished: - args.train_iters = args.curr_iteration + 1 - return batch From 0c89beedf1bb5a071bcc3439538bd0bea1e37933 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 3 Oct 2025 17:46:02 +0800 Subject: [PATCH 02/21] update --- swift/megatron/trainers/base.py | 8 +++----- swift/megatron/trainers/kto_trainer.py | 24 ++++------------------ swift/trainers/rlhf_trainer/kto_trainer.py | 2 +- 3 files changed, 8 insertions(+), 26 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9b03675cad..619fe1488e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -20,7 +20,7 @@ from 
megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector -from megatron.training import (ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, +from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, get_wandb_writer, is_last_rank, one_logger_utils, pretrain, print_rank_0, print_rank_last, training) from megatron.training.checkpointing import load_checkpoint @@ -35,7 +35,8 @@ from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model -from .utils import get_swift_datasets_provider +from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, + get_swift_datasets_provider) logger = get_logger() @@ -151,7 +152,6 @@ def sh_ten_merge_fn(sub_state_dict): def _load_adapter_base_checkpoint(self, *_args, **kwargs): adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter' - from megatron.training import checkpointing sharded_state_dict = kwargs.get('sharded_state_dict') if sharded_state_dict is None: return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) @@ -180,7 +180,6 @@ def _load_adapter_base_checkpoint(self, *_args, **kwargs): return res def _load_base_checkpoint(self, *_args, **kwargs): - from megatron.training import checkpointing sharded_state_dict = kwargs.get('sharded_state_dict') if sharded_state_dict is None: return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) @@ -224,7 +223,6 @@ def _load_base_checkpoint(self, *_args, **kwargs): @contextmanager def _patch_load_state_dict(self, load_base_checkpoint): - from megatron.training import checkpointing checkpointing.origin__load_base_checkpoint = checkpointing._load_base_checkpoint checkpointing._load_base_checkpoint = load_base_checkpoint diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 0af1aec0e1..21c9a41292 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -10,9 +10,10 @@ from megatron.training.utils import unwrap_model from torch.distributed.nn import all_reduce +from swift.llm import to_device from swift.utils import get_current_device, get_logger from .base import MegatronRLHFTrainer -from .utils import get_kto_batch +from .utils import get_packed_seq_params, mcore_get_batch_on_this_cp_rank, split_cp_inputs logger = get_logger() @@ -150,7 +151,7 @@ def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, referenc loss = loss / mpu.get_context_parallel_world_size() return loss, reporting_metric - def _replace_data_iterator_with_model(self, data_iterator, model): + def _replace_data_iterator(self, data_iterator, model): args = get_args() num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) @@ -209,11 +210,6 @@ def ref_forward(self, ref_model, data_iterator): data['reference_KL_logps'] = None return data - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - new_data_iterator = self._replace_data_iterator_with_model(data_iterator, model) - return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, - config) - def forward_step(self, 
data_iterator, model): data = next(data_iterator) @@ -240,19 +236,7 @@ def forward_step(self, data_iterator, model): all_labels=all_labels, packed_seq_params=completion_packed_seq_params) - def evaluate(self, - forward_step_func, - data_iterator, - model, - process_non_loss_data_func, - config, - verbose=False, - non_loss_data_func=None): - self._replace_data_iterator = partial(self._replace_data_iterator_with_model, model=model) - return super().evaluate(forward_step_func, data_iterator, model, process_non_loss_data_func, config, verbose, - non_loss_data_func) - - def get_batch(data_iterator): + def get_batch(self, data_iterator, vp_stage=None): """Generate a kto batch.""" args = get_args() diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainer/kto_trainer.py index 9d93abe5bf..f24529aa68 100644 --- a/swift/trainers/rlhf_trainer/kto_trainer.py +++ b/swift/trainers/rlhf_trainer/kto_trainer.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch import torch.nn as nn From 01c9dfe0d1f32da00a631237238f91f4d4138002 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 3 Oct 2025 21:30:48 +0800 Subject: [PATCH 03/21] update --- swift/megatron/trainers/dpo_trainer.py | 1 - swift/megatron/trainers/kto_trainer.py | 57 +++++++------------------- 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index ab1de23857..6a2cd5ce0a 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -10,7 +10,6 @@ from swift.trainers import DPOTrainer from swift.utils import get_current_device, get_logger from .base import MegatronRLHFTrainer -from .utils import get_batch logger = get_logger() diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 21c9a41292..655245f723 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -1,14 +1,14 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from collections import namedtuple from functools import partial import torch -import torch.distributed as dist -import torch.nn.functional as F from megatron.core import mpu from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy from megatron.training import get_args from megatron.training.utils import unwrap_model from torch.distributed.nn import all_reduce +from trl import KTOTrainer from swift.llm import to_device from swift.utils import get_current_device, get_logger @@ -18,14 +18,23 @@ logger = get_logger() +class DummyKTOTrainer(KTOTrainer): + # For reusing the dpo_loss function in TRL. 
+ def __init__(self, args): + self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) + self.loss_type = args.loss_type + self.beta = args.beta + self.desirable_weight = args.desirable_weight + self.undesirable_weight = args.undesirable_weight + self.calculate_KL = args.calculate_KL + + class MegatronKTOTrainer(MegatronRLHFTrainer): def __init__(self, args, template): super().__init__(args, template) - self.beta = args.beta - self.desirable_weight = args.desirable_weight - self.undesirable_weight = args.undesirable_weight self.calculate_KL = args.calculate_KL + self.dummy_kto_trainer = DummyKTOTrainer(args) @staticmethod def get_logps(output_tensor, labels, packed_seq_params=None): @@ -64,37 +73,6 @@ def get_logps(output_tensor, labels, packed_seq_params=None): return all_logps - @staticmethod - def kto_loss(policy_chosen_logps, policy_rejected_logps, policy_KL_logps, reference_chosen_logps, - reference_rejected_logps, reference_KL_logps, beta, desirable_weight, undesirable_weight, calculate_KL, - device): - if calculate_KL and policy_KL_logps is not None and reference_KL_logps is not None: - kl = (policy_KL_logps - reference_KL_logps).mean().detach() - dist.all_reduce(kl, group=mpu.get_data_parallel_group()) - kl = kl / mpu.get_data_parallel_world_size() - kl = kl.clamp(min=0) - else: - kl = torch.tensor(0.0, device=device) - - chosen_rewards = torch.tensor([], device=kl.device) - if policy_chosen_logps.shape[0] > 0: - chosen_logratios = policy_chosen_logps - reference_chosen_logps - chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl)) - chosen_rewards = beta * chosen_logratios.detach() - else: - chosen_losses = torch.tensor([], device=kl.device) - - rejected_rewards = torch.tensor([], device=kl.device) - if policy_rejected_logps.shape[0] > 0: - rejected_logratios = policy_rejected_logps - reference_rejected_logps - rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios)) - rejected_rewards = beta * rejected_logratios.detach() - else: - rejected_losses = torch.tensor([], device=kl.device) - - losses = torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0) - return losses, chosen_rewards, rejected_rewards, kl - def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, reference_KL_logps, labels, all_labels, packed_seq_params): policy_logps = self.get_logps(output_tensor, labels, packed_seq_params) @@ -105,18 +83,13 @@ def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, referenc reference_chosen_logps = reference_logps[is_desirable] reference_rejected_logps = reference_logps[~is_desirable] - loss, chosen_rewards, rejected_rewards, kl = self.kto_loss( + loss, chosen_rewards, rejected_rewards, kl = self.dummy_kto_trainer.kto_loss( policy_chosen_logps=policy_chosen_logps, policy_rejected_logps=policy_rejected_logps, policy_KL_logps=policy_KL_logps, reference_chosen_logps=reference_chosen_logps, reference_rejected_logps=reference_rejected_logps, reference_KL_logps=reference_KL_logps, - beta=self.beta, - desirable_weight=self.desirable_weight, - undesirable_weight=self.undesirable_weight, - calculate_KL=self.calculate_KL, - device=policy_logps.device, ) loss = loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=policy_logps.device) From f622656e40f3b7af80294c3379ff21b0537c84c9 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 4 Oct 2025 14:19:29 +0800 Subject: [PATCH 04/21] update --- .../Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" | 2 +- 
...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/BestPractices/Qwen3-Best-Practice.md | 2 +- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git "a/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" index fe9d94ba91..e560ae258c 100644 --- "a/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -328,7 +328,7 @@ swift rlhf \ Qwen3-235B-A22B-Instruct-250718 单机8卡H20 LoRA训练的最佳实践参考:[https://github.com/modelscope/ms-swift/pull/5033](https://github.com/modelscope/ms-swift/pull/5033)。 -ms-swift 引入了 Megatron 并行技术以加速大模型的CPT/SFT/DPO。支持的模型可以在[支持的模型文档](../Instruction/支持的模型和数据集.md)中找到。 +ms-swift 引入了 Megatron 并行技术以加速大模型的CPT/SFT/DPO/KTO。支持的模型可以在[支持的模型文档](../Instruction/支持的模型和数据集.md)中找到。 关于环境准备以及 HF 和 MCore 模型权重的转换,可以参考[Megatron-SWIFT训练文档](../Megatron-SWIFT/快速开始.md)。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ea97a35fe2..970b19a3a4 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -252,7 +252,7 @@ Megatron训练参数继承自Megatron参数和基本参数(与ms-swift共用da - mlp_padding_free: 默认为False。用于padding_free设置为false时,对mlp进行padding_free优化。这可以在自定义attention_mask的同时,提升训练速度和减少显存占用。 - vit_gradient_checkpointing: 多模态模型训练时,是否对vit部分开启gradient_checkpointing。默认为True。 - gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。 -- 🔥packing: 是否使用序列packing,默认为False。当前支持CPT/SFT/DPO。 +- 🔥packing: 是否使用序列packing,默认为False。当前支持CPT/SFT/DPO/KTO。 - packing_length: packing的长度。默认为None,设置为max_length。 - streaming: 流式读取并处理数据集,默认False。 - 注意:因为流式数据集无法获得其长度,因此需要设置`--train_iters`参数。设置`max_epochs`参数确保训练到对应epochs时退出训练,并对权重进行验证和保存。 diff --git a/docs/source_en/BestPractices/Qwen3-Best-Practice.md b/docs/source_en/BestPractices/Qwen3-Best-Practice.md index ffc50effdd..931e8b1834 100644 --- a/docs/source_en/BestPractices/Qwen3-Best-Practice.md +++ b/docs/source_en/BestPractices/Qwen3-Best-Practice.md @@ -332,7 +332,7 @@ swift rlhf \ Best practice reference for single-node 8xH20 LoRA training with Qwen3-235B-A22B-Instruct-250718: https://github.com/modelscope/ms-swift/pull/5033. -ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO for large models. Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/KTO for large models. Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md). For environment setup and conversion between HF and MCore model weights, refer to the [Megatron-SWIFT Training Documentation](../Megatron-SWIFT/Quick-start.md). 
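The `packing` option documented in this commit relies on packed-sequence metadata: several samples are concatenated into a single row, and their boundaries are recovered from `position_ids` (the trainers in this series build a `packed_seq_params` object via `get_packed_seq_params` and read `cu_seqlens_q` from it). A minimal sketch of that idea for a single packed row, assuming the usual convention that each packed sample restarts its positions at 0; the helper name below is illustrative, not the actual ms-swift implementation:

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Each packed sample restarts its position_ids at 0, so the zeros mark
    # sample starts; appending the total length gives cumulative sequence lengths.
    starts = torch.nonzero(position_ids == 0).flatten()
    total = position_ids.new_tensor([position_ids.numel()])
    return torch.cat([starts, total]).to(torch.int32)

# Three packed samples of lengths 4, 3 and 2:
print(cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 1])))
# tensor([0, 4, 7, 9], dtype=torch.int32)
```

With offsets like these, per-sample quantities (for example the per-sequence log-probabilities summed in `get_logps`) can be recovered from a single packed forward pass.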
diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 7a9b50c880..73507d15fc 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -267,7 +267,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - mlp_padding_free: The default is False. This is used for applying padding-free optimization to the MLP when padding_free is set to false. It allows for improved training speed and reduced memory usage while customizing the attention_mask. - vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT part during multimodal model training. Default: True. - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Default: None. -- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports CPT/SFT/DPO. +- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports CPT/SFT/DPO/KTO. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - streaming: Stream data loading and processing, default is False. - Note: Since the length of a streaming dataset cannot be determined, the `--train_iters` parameter must be set. Also set the `max_epochs` parameter to ensure training exits after the specified number of epochs, and to validate and save the model weights accordingly. From a0a866fae52248d7cef2c7b1894a1895afbdbd43 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 4 Oct 2025 14:38:15 +0800 Subject: [PATCH 05/21] update --- ...5\277\253\351\200\237\345\274\200\345\247\213.md" | 2 +- ...4\273\244\350\241\214\345\217\202\346\225\260.md" | 12 +++++++----- ...6\250\241\346\200\201\346\250\241\345\236\213.md" | 2 +- docs/source_en/GetStarted/Quick-start.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 12 +++++++----- docs/source_en/Megatron-SWIFT/Multimodal-Model.md | 2 +- 6 files changed, 18 insertions(+), 14 deletions(-) diff --git "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" index aea59cbc83..fcb4009898 100644 --- "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -10,7 +10,7 @@ ms-swift是魔搭社区提供的大模型与多模态大模型训练部署框架 - 量化训练:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。 - 🍊 RLHF训练:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。 - 🍓 多模态训练:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。 -- 🥥 Megatron并行技术:支持使用Megatron并行技术对CPT/SFT/DPO进行加速,现支持200+大语言模型。 +- 🥥 Megatron并行技术:支持使用Megatron并行技术对CPT/SFT/DPO/KTO进行加速,现支持200+大语言模型。 - 界面训练:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。 - 插件化与拓展:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。 - 🍉 工具箱能力:除了对大模型和多模态大模型的训练支持外,还支持其推理、评测、量化和部署全流程。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 970b19a3a4..1a51eaa66a 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -233,13 +233,15 @@ lora训练: - 
reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 - label_smoothing: 默认为0.。 - f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 -- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 +- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 **KTO参数**: -- beta: KL正则项系数,默认为`0.1`。 -- desirable_weight: KTO算法中对desirable response的loss权重 $\lambda_D$,默认为`1.`。 -- undesirable_weight: KTO算法中对undesirable response的loss权重 $\lambda_U$,默认为`1.`。 -- calculate_KL: 是否计算KL散度,默认为True。 +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 +- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 +- desirable_weight: 抵消 desirable 和 undesirable 配对数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 +- undesirable_weight: 抵消 desirable 和 undesirable 配对数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 ## 训练参数 diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 16edbd35c5..6cd7b192de 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -1,6 +1,6 @@ # 多模态模型 -ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO/KTO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 环境准备请参考Megatron-SWIFT的[快速开始文档](./快速开始.md)。 diff --git a/docs/source_en/GetStarted/Quick-start.md b/docs/source_en/GetStarted/Quick-start.md index 845570a712..d36396023c 100644 --- a/docs/source_en/GetStarted/Quick-start.md +++ b/docs/source_en/GetStarted/Quick-start.md @@ -10,7 +10,7 @@ ms-swift is a comprehensive training and deployment framework for large language - Quantization Training: Provides training for quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ. - 🍊 RLHF Training: Supports human alignment training methods like DPO, GRPO, RM, PPO, GKD, KTO, CPO, SimPO, ORPO for both text-based and multimodal large models. - 🍓 Multimodal Training: Capable of training models for different modalities such as images, videos, and audios; supports tasks like VQA (Visual Question Answering), Captioning, OCR (Optical Character Recognition), and Grounding. -- 🥥 Megatron Parallelism: Supports accelerating CPT/SFT/DPO using Megatron parallelism techniques, currently compatible with 200+ large language models. +- 🥥 Megatron Parallelism: Supports accelerating CPT/SFT/DPO/KTO using Megatron parallelism techniques, currently compatible with 200+ large language models. - Interface-driven Training: Offers training, inference, evaluation, and quantization capabilities through an interface, enabling a complete workflow for large models. - Plugins and Extensions: Allows customization and extension of models and datasets, and supports customizations for components like loss, metric, trainer, loss-scale, callback, optimizer, etc. 
- 🍉 Toolbox Capabilities: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 73507d15fc..b52b9089aa 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -248,13 +248,15 @@ LoRA Training: - reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. - label_smoothing: Default is 0. - f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. -- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. +- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. **KTO Parameters**: -- beta: Coefficient for the KL regularization term. Default is `0.1`. -- desirable_weight: Loss weight $\lambda_D$ for desirable response in the KTO algorithm, default is `1.`. -- undesirable_weight: Loss weight $\lambda_U$ for undesirable response in the KTO algorithm, default is `1.`. -- calculate_KL: Whether to calculate KL divergence. Default is `True`. +- ref_load: same meaning as in DPO. +- ref_adapter_load: same meaning as in DPO. +- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. +- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. +- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. ## Training Parameters diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index a595946cb5..c72850c08f 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # Multimodal Models -ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO/KTO for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). 
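The KTO parameters documented above (`beta`, `desirable_weight`, `undesirable_weight`, and the KL handling) all feed one per-sample loss, which the trainer code in this series delegates to TRL's `kto_loss` via `DummyKTOTrainer`. A minimal sketch of the standard KTO weighting on plain tensors, assuming the KL baseline `kl` has already been estimated; the function name and signature are illustrative, not the patch's API:

```python
import torch

def kto_loss_sketch(policy_chosen_logps, ref_chosen_logps,
                    policy_rejected_logps, ref_rejected_logps,
                    kl, beta=0.1, desirable_weight=1.0, undesirable_weight=1.0):
    # Desirable samples are rewarded for log-ratios above the KL baseline,
    # undesirable samples for log-ratios below it; the two groups are weighted
    # separately to offset imbalance between desirable and undesirable data.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    losses = torch.cat([desirable_weight * chosen_losses,
                        undesirable_weight * rejected_losses])
    return losses.mean()
```

A higher `beta` sharpens the sigmoid and keeps the policy closer to the reference model, which is why the documentation above describes it as controlling deviation from `ref_model`.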
From 3176ede45b7a7b23531843880492a8c4b221f910 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 4 Oct 2025 20:40:12 +0800 Subject: [PATCH 06/21] update --- swift/megatron/argument/megatron_args.py | 25 +++- swift/megatron/argument/rlhf_args.py | 4 - swift/megatron/trainers/dpo_trainer.py | 15 +- swift/megatron/trainers/kto_trainer.py | 179 ++++------------------- swift/megatron/trainers/trainer.py | 1 - 5 files changed, 64 insertions(+), 160 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 92834e44d7..f403d19e65 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -17,6 +17,7 @@ @dataclass class RLHFMegatronArgumentsMixin: + rlhf_type: Literal['dpo', 'kto', None] = None ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -25,7 +26,28 @@ class RLHFMegatronArgumentsMixin: reference_free: bool = False label_smoothing: float = 0. f_divergence_type: str = 'reverse_kl' - loss_type: str = 'sigmoid' + loss_type: Optional[str] = None + + # kto + desirable_weight: float = 1. + undesirable_weight: float = 1. + calculate_KL: Optional[bool] = None + + def _init_kto(self): + if self.calculate_KL is None: + # Not all losses require a KL calculation + self.calculate_KL = True + if self.loss_type in ['apo_zero_unpaired']: + self.calculate_KL = False + + def __post_init__(self): + if self.rlhf_type is None: + return + default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'} + if self.loss_type is None: + self.loss_type = default_loss_type[self.rlhf_type] + if self.rlhf_type == 'kto': + self._init_kto() @dataclass @@ -403,6 +425,7 @@ def __post_init__(self): require_version('peft>=0.15') else: require_version('peft>=0.12') + RLHFMegatronArgumentsMixin.__post_init__(self) MegatronTunerMixin.__post_init__(self) os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index 304b8b58fc..74c8c29c1b 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -11,7 +11,3 @@ class MegatronRLHFArguments(MegatronTrainArguments): loss_scale: str = 'last_round' calculate_per_token_loss: bool = False - - desirable_weight: float = 1. - undesirable_weight: float = 1. - calculate_KL: bool = True diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 6a2cd5ce0a..841003382b 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -98,20 +98,21 @@ def forward_step(self, data_iterator, model): # Get the batch. 
unwrapped_model = model.module.module input_tensor = unwrapped_model.get_input_tensor() - if input_tensor is not None: - unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:]) vp_stage = unwrapped_model.vp_stage + timers('batch-generator', log_level=2).start() + with self.stimer(bdata=True): + data = self.get_batch(data_iterator, vp_stage) + timers('batch-generator').stop() + data.pop('loss_scale', None) + # ref_model with torch.no_grad(), self.null_ref_context() as ref_models: ref_model = ref_models[vp_stage or 0] if input_tensor is not None: ref_model.set_input_tensor(input_tensor[:input_tensor.shape[0] // 2].detach()) - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() - data.pop('loss_scale', None) ref_output_tensor = ref_model(**data) + if input_tensor is not None: + unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:]) with self.stimer: output_tensor = model(**data) return torch.concat([ref_output_tensor, output_tensor], dim=0), partial( diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 655245f723..d34b62cc67 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -5,15 +5,12 @@ import torch from megatron.core import mpu from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy -from megatron.training import get_args -from megatron.training.utils import unwrap_model +from megatron.training import get_args, get_timers from torch.distributed.nn import all_reduce from trl import KTOTrainer -from swift.llm import to_device from swift.utils import get_current_device, get_logger from .base import MegatronRLHFTrainer -from .utils import get_packed_seq_params, mcore_get_batch_on_this_cp_rank, split_cp_inputs logger = get_logger() @@ -33,7 +30,6 @@ class MegatronKTOTrainer(MegatronRLHFTrainer): def __init__(self, args, template): super().__init__(args, template) - self.calculate_KL = args.calculate_KL self.dummy_kto_trainer = DummyKTOTrainer(args) @staticmethod @@ -124,150 +120,39 @@ def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, referenc loss = loss / mpu.get_context_parallel_world_size() return loss, reporting_metric - def _replace_data_iterator(self, data_iterator, model): - args = get_args() - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - - processed_data_list = [] - policy_model = unwrap_model(model)[0] - - for _ in range(num_iters_per_step): - with torch.no_grad(), self.null_ref_context() as ref_models: - assert len(ref_models) == 1, 'KTO currently does not support VPP.' 
- data = self.ref_forward(ref_models[0], data_iterator) - - if self.calculate_KL: - with torch.no_grad(): - kl_inputs = { - 'input_ids': data.get('KL_completion_input_ids'), - 'attention_mask': data.get('KL_completion_attention_mask'), - 'position_ids': data.get('KL_completion_position_ids'), - } - - kl_output_tensor = self._forward_step_helper(policy_model, kl_inputs) - - policy_KL_logps = self.get_logps(kl_output_tensor, data['KL_completion_labels'], - data.get('KL_completion_packed_seq_params')) - data['policy_KL_logps'] = policy_KL_logps - - processed_data_list.append(data) - - return iter(processed_data_list) + @staticmethod + def _get_input_tensor(input_tensor, is_ref: bool, is_KL: bool): + i = (not is_ref) * 2 + is_KL + return input_tensor[i:i + 1] - def ref_forward(self, ref_model, data_iterator): + def forward_step(self, data_iterator, model): + timers = get_timers() + # Get the batch. + unwrapped_model = model.module.module + input_tensor = unwrapped_model.get_input_tensor() + vp_stage = unwrapped_model.vp_stage + timers('batch-generator', log_level=2).start() with self.stimer(bdata=True): - data = self.get_batch(data_iterator) + data = self.get_batch(data_iterator, vp_stage) + timers('batch-generator').stop() data.pop('loss_scale', None) - ref_inputs = { - 'input_ids': data.get('completion_input_ids'), - 'attention_mask': data.get('completion_attention_mask'), - 'position_ids': data.get('completion_position_ids'), - } + with torch.no_grad(), self.null_ref_context() as ref_models: + ref_model = ref_models[vp_stage or 0] + if input_tensor is not None: + ref_model.set_input_tensor(self._get_input_tensor(True, False).detach()) + ref_output_tensor = ref_model(**data) + if input_tensor is not None: + ref_model.set_input_tensor(self._get_input_tensor(True, True).detach()) + ref_KL_output_tensor = ref_model(**data) with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, ref_inputs) - data['reference_logps'] = self.get_logps(output_tensor, data['completion_labels'], - data.get('completion_packed_seq_params')) - - if self.calculate_KL: - kl_inputs = { - 'input_ids': data.get('KL_completion_input_ids'), - 'attention_mask': data.get('KL_completion_attention_mask'), - 'position_ids': data.get('KL_completion_position_ids'), - } - with torch.no_grad(): - kl_output_tensor = self._forward_step_helper(ref_model, kl_inputs) - data['reference_KL_logps'] = self.get_logps(kl_output_tensor, data['KL_completion_labels'], - data.get('KL_completion_packed_seq_params')) - else: - data['reference_KL_logps'] = None - return data - - def forward_step(self, data_iterator, model): - data = next(data_iterator) - - reference_logps = data.pop('reference_logps') - reference_KL_logps = data.pop('reference_KL_logps', None) - policy_KL_logps = data.pop('policy_KL_logps', None) - all_labels = torch.tensor(data.pop('label')).to(get_current_device()) - completion_packed_seq_params = data.get('completion_packed_seq_params') - - main_inputs = { - 'input_ids': data['completion_input_ids'], - 'attention_mask': data.get('completion_attention_mask'), - 'position_ids': data.get('completion_position_ids') - } - with self.stimer(): - output_tensor = model(**main_inputs) - - return output_tensor, partial( - self.loss_func, - policy_KL_logps=policy_KL_logps, - reference_logps=reference_logps, - reference_KL_logps=reference_KL_logps, - labels=data['completion_labels'], - all_labels=all_labels, - packed_seq_params=completion_packed_seq_params) - - def get_batch(self, data_iterator, vp_stage=None): - """Generate 
a kto batch.""" - args = get_args() - - data = next(data_iterator) - is_finished = data.pop('is_finished', False) - - batch = to_device(data, 'cuda', non_blocking=True) - - kto_tensor_keys = [ - 'completion_input_ids', 'completion_labels', 'completion_attention_mask', 'completion_position_ids', - 'KL_completion_input_ids', 'KL_completion_labels', 'KL_completion_attention_mask', - 'KL_completion_position_ids' - ] - - # pp - if args.pipeline_model_parallel_size == 1: - pass - elif mpu.is_pipeline_first_stage(): - for key in kto_tensor_keys: - if 'labels' in key: - batch[key] = None - elif mpu.is_pipeline_last_stage(): - for key in kto_tensor_keys: - if 'input_ids' in key: - batch[key] = None - else: - for key in kto_tensor_keys: - batch[key] = None - - # Padding-Free - num_samples = batch.get('num_samples') - if args.padding_free: - if 'completion_position_ids' in batch and batch['completion_position_ids'] is not None: - batch['completion_packed_seq_params'] = get_packed_seq_params(batch['completion_position_ids']) - if num_samples is not None: - batch['completion_packed_seq_params'].num_samples = num_samples - - if 'KL_completion_position_ids' in batch and batch['KL_completion_position_ids'] is not None: - batch['KL_completion_packed_seq_params'] = get_packed_seq_params(batch['KL_completion_position_ids']) - if num_samples is not None: - batch['KL_completion_packed_seq_params'].num_samples = num_samples - - # cp - cp_size = mpu.get_context_parallel_world_size() - if cp_size > 1: - completion_psp = batch.get('completion_packed_seq_params') - kl_psp = batch.get('KL_completion_packed_seq_params') - - if completion_psp is None and kl_psp is None: - batch = mcore_get_batch_on_this_cp_rank(batch) - else: - for key, val in batch.items(): - if key in kto_tensor_keys and val is not None: - if key.startswith('KL_completion_') and kl_psp is not None: - batch[key] = split_cp_inputs(val, kl_psp.cu_seqlens_q, -1) - elif key.startswith('completion_') and completion_psp is not None: - batch[key] = split_cp_inputs(val, completion_psp.cu_seqlens_q, -1) - - if is_finished: - args.train_iters = args.curr_iteration + 1 - return batch + if input_tensor is not None: + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True)) + KL_output_tensor = model(*data) + + if input_tensor is not None: + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False)) + with self.stimer: + output_tensor = model(**data) + return torch.concat([ref_output_tensor, ref_KL_output_tensor, output_tensor, KL_output_tensor], dim=0), partial( + self.loss_func, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params')) diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 95f70a662e..98422b8c43 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -12,7 +12,6 @@ from swift.utils import get_logger from .base import BaseMegatronTrainer -from .utils import get_batch logger = get_logger() From bdfc93955ab50c0cc9dd3ec8d83dcc639d7330c0 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 4 Oct 2025 23:38:28 +0800 Subject: [PATCH 07/21] update --- swift/megatron/trainers/base.py | 59 ++++++--- swift/megatron/trainers/dpo_trainer.py | 16 --- swift/megatron/trainers/kto_trainer.py | 174 ++++++++++++------------- swift/megatron/trainers/utils.py | 7 +- 4 files changed, 123 insertions(+), 133 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 
619fe1488e..a58bb537fc 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -315,11 +315,12 @@ def _initialize_embedding(model): tensor = module.weight.new_empty(num_to_initialize, module.weight.shape[1]) module.weight.data[initialize_mask] = init_method(tensor) - def _all_reduce_metric(self, metric: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _all_reduce_metric(self, + metric: Dict[str, torch.Tensor], + reduction=torch.distributed.ReduceOp.AVG) -> Dict[str, torch.Tensor]: values = list(metric.values()) reporting_metric = values[0].new_tensor(values) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): @@ -800,6 +801,30 @@ def build_pretraining_data_loader(*_args, **kwargs): def forward_step(self, data_iterator, model): pass + def _prepare_batch(self, data, vp_stage, num_samples=None): + batch = get_batch_on_this_tp_rank(data, vp_stage=vp_stage) + if num_samples is None: + num_samples = batch.pop('num_samples') + args = get_args() + text_position_ids = batch.pop('text_position_ids', None) + if text_position_ids is None: + text_position_ids = batch.get('position_ids') + if args.padding_free and text_position_ids is not None: + batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) + batch['packed_seq_params'].num_samples = num_samples + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + return batch + + def get_batch(self, data_iterator, vp_stage=None): + """Generate a batch.""" + args = get_args() + data = next(data_iterator) + is_finished = data.pop('is_finished', False) + if is_finished: + args.train_iters = args.curr_iteration + 1 + return self._prepare_batch(data, vp_stage) + class MegatronRLHFTrainer(BaseMegatronTrainer): @@ -866,18 +891,18 @@ def _forward_step_helper(model, inputs): return output_tensor - def get_batch(self, data_iterator, vp_stage=None): - """Generate a batch.""" - # get batches based on the TP rank you are on - batch = get_batch_on_this_tp_rank(data_iterator, vp_stage=vp_stage) + def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): args = get_args() - num_samples = batch.pop('num_samples') - text_position_ids = batch.pop('text_position_ids', None) - if text_position_ids is None: - text_position_ids = batch.get('position_ids') - if args.padding_free and text_position_ids is not None: - batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) - return batch + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + if num_samples is None: + num_samples = packed_seq_params.num_samples * 2 + cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size + all_logps = per_token_logps.new_zeros((num_samples, )) + for i in range(num_samples): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, 
group=mpu.get_context_parallel_group()) + return all_logps diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 841003382b..868069b197 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -36,22 +36,6 @@ def __init__(self, args, template): self.dummy_dpo_trainer = DummyDPOTrainer(args) self.ref_models = [] - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params): ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach() output_tensor = output_tensor[output_tensor.shape[0] // 2:] diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index d34b62cc67..1a5b14b61c 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -30,100 +30,68 @@ class MegatronKTOTrainer(MegatronRLHFTrainer): def __init__(self, args, template): super().__init__(args, template) + assert args.padding_free, 'Currently `rlhf_type="kto"` only supports padding_free.' self.dummy_kto_trainer = DummyKTOTrainer(args) - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params=None): - args = get_args() - if output_tensor is None: - return None - - shifted_logits = output_tensor[:, :-1, :].contiguous() - shifted_labels = labels[:, 1:].contiguous() - - logits_for_loss = shifted_logits.transpose(0, 1).contiguous() - labels_for_loss = shifted_labels.transpose(0, 1).contiguous() - - per_token_cross_entropy_loss = vocab_parallel_cross_entropy( - logits_for_loss, labels_for_loss, label_smoothing=0.0) - - per_token_logps = -per_token_cross_entropy_loss - loss_mask = (labels_for_loss != -100) - masked_logps = per_token_logps * loss_mask - - if args.padding_free and packed_seq_params is not None: - flattened_logps = masked_logps.squeeze(1) # [seq-1] - - cu_seqlens = packed_seq_params.cu_seqlens_q - num_sequences = cu_seqlens.shape[0] - 1 - all_logps = flattened_logps.new_zeros((num_sequences, )) - for i in range(num_sequences): - start_index, end_index = cu_seqlens[i], cu_seqlens[i + 1] - 1 - if end_index > start_index: - all_logps[i] = flattened_logps[start_index:end_index].sum() - else: - all_logps = masked_logps.sum(dim=0) - - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - - return all_logps - - def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, reference_KL_logps, labels, all_labels, - packed_seq_params): - policy_logps = self.get_logps(output_tensor, labels, packed_seq_params) - is_desirable = all_labels.bool() - - policy_chosen_logps = policy_logps[is_desirable] - policy_rejected_logps = policy_logps[~is_desirable] - reference_chosen_logps = reference_logps[is_desirable] - reference_rejected_logps = reference_logps[~is_desirable] + def 
_kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool): + labels = data['labels'] + packed_seq_params = data['packed_seq_params'] + length = packed_seq_params.cu_seqlens_q[-1] + output = self._get_input_tensor(output_tensor, is_KL, is_ref, length, dim=1) + return self.get_logps(output, labels, packed_seq_params, packed_seq_params.num_samples) + + def loss_func(self, output_tensor, *, data, kl_data, label): + label = data['label'] + policy_logps = self._kto_get_logps(output_tensor, data, False, False) + ref_logps = self._kto_get_logps(output_tensor, data, False, True) + policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False) + ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True) + + label = label.bool() + policy_chosen_logps = policy_logps[label] + policy_rejected_logps = policy_logps[~label] + ref_chosen_logps = ref_logps[label] + ref_rejected_logps = ref_logps[~label] loss, chosen_rewards, rejected_rewards, kl = self.dummy_kto_trainer.kto_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - policy_KL_logps=policy_KL_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - reference_KL_logps=reference_KL_logps, + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + ref_chosen_logps, + ref_rejected_logps, + ref_KL_logps, ) - loss = loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=policy_logps.device) - - with torch.no_grad(): - chosen_rewards_mean = chosen_rewards.mean() if chosen_rewards.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - rejected_rewards_mean = rejected_rewards.mean() if rejected_rewards.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - policy_chosen_logps_mean = policy_chosen_logps.mean() if policy_chosen_logps.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - policy_rejected_logps_mean = policy_rejected_logps.mean( - ) if policy_rejected_logps.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - - metric = { - 'loss': loss.clone().detach(), - 'logps/chosen': policy_chosen_logps_mean, - 'logps/rejected': policy_rejected_logps_mean, - 'rewards/chosen': chosen_rewards_mean, - 'rewards/rejected': rejected_rewards_mean, - 'rewards/margins': chosen_rewards_mean - rejected_rewards_mean, - 'kl': kl.detach() if kl is not None else torch.tensor(0.0, device=loss.device), + loss = loss.mean() + mean_metric = { + 'loss': loss.detach().clone(), + 'kl': kl.detach(), } - - reporting_metric = loss.new_tensor(list(metric.values())) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + metric = self._all_reduce_metric(mean_metric) + sum_metric = { + 'logps/chosen': torch.stack([policy_chosen_logps.nansum(), policy_chosen_logps.shape[0]]), + 'logps/rejected': torch.stack([policy_rejected_logps.nansum(), policy_rejected_logps.shape[0]]), + 'rewards/chosen': torch.stack([chosen_rewards.nansum(), chosen_rewards.shape[0]]), + 'rewards/rejected': torch.stack([rejected_rewards.nansum(), rejected_rewards.shape[0]]), + } + metric.update(self._all_reduce_metric(sum_metric, torch.distributed.ReduceOp.SUM)) # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 loss = loss / mpu.get_context_parallel_world_size() - return loss, reporting_metric + return loss, metric 
@staticmethod - def _get_input_tensor(input_tensor, is_ref: bool, is_KL: bool): - i = (not is_ref) * 2 + is_KL - return input_tensor[i:i + 1] + def _get_input_tensor(input_tensor, is_KL: bool, is_ref: bool, length: int, dim: int): + # polocy, ref, polocy_KL, ref_KL + total_length = input_tensor.shape[dim] + KL_length = (total_length - 2 * length) // 2 + slice_list = [0, length, 2 * length, total_length - KL_length, total_length] + idx = is_KL * 2 + is_ref + slice_ = (slice(None), ) * dim + (slice(slice_list[idx], slice_list[idx + 1]), ) + res = input_tensor[slice_] + if is_KL or is_ref: + res = res.detach() + return res def forward_step(self, data_iterator, model): timers = get_timers() @@ -133,26 +101,44 @@ def forward_step(self, data_iterator, model): vp_stage = unwrapped_model.vp_stage timers('batch-generator', log_level=2).start() with self.stimer(bdata=True): - data = self.get_batch(data_iterator, vp_stage) + # not support loss_scale + data, kl_data = self.get_batch(data_iterator, vp_stage) timers('batch-generator').stop() - data.pop('loss_scale', None) + label = data.pop('label') + length = data['packed_seq_params'].cu_seqlens_q[-1] with torch.no_grad(), self.null_ref_context() as ref_models: ref_model = ref_models[vp_stage or 0] + if self.args.calculate_KL: + if input_tensor is not None: + ref_model.set_input_tensor(self._get_input_tensor(True, True, length)) + ref_KL_output_tensor = ref_model(**kl_data) + if input_tensor is not None: - ref_model.set_input_tensor(self._get_input_tensor(True, False).detach()) + ref_model.set_input_tensor(self._get_input_tensor(True, False, length)) ref_output_tensor = ref_model(**data) - if input_tensor is not None: - ref_model.set_input_tensor(self._get_input_tensor(True, True).detach()) - ref_KL_output_tensor = ref_model(**data) - with torch.no_grad(): - if input_tensor is not None: - unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True)) - KL_output_tensor = model(*data) + + if self.args.calculate_KL: + with torch.no_grad(): + if input_tensor is not None: + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True)) + KL_output_tensor = model(**kl_data) if input_tensor is not None: unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False)) with self.stimer: output_tensor = model(**data) - return torch.concat([ref_output_tensor, ref_KL_output_tensor, output_tensor, KL_output_tensor], dim=0), partial( - self.loss_func, labels=data.get('labels'), packed_seq_params=data.get('packed_seq_params')) + if self.args.calculate_KL: + res = torch.concat([output_tensor, ref_output_tensor], dim=1) + else: + res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=1) + return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) + + def _prepare_batch(self, data, vp_stage): + res = [] + num_samples = data.pop('num_samples') + for key in ['completion_', 'KL_completion_']: + _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)} + res.append(super()._prepare_batch(_data, vp_stage, num_samples)) + res[0]['label'] = data['label'] + return res diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 9f77e0aaba..06695298a8 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -27,11 +27,9 @@ def swift_datasets_provider(train_val_test_num_samples): # Code borrowed from NVIDIA/Megatron-LM -def get_batch_on_this_tp_rank(data_iterator, vp_stage=None): +def 
get_batch_on_this_tp_rank(data, vp_stage=None): args = get_args() - data = next(data_iterator) - is_finished = data.pop('is_finished', False) if args.task_type == 'causal_lm': data['labels'] = torch.roll(data['labels'], -1, dims=-1) if 'loss_scale' in data: @@ -48,9 +46,6 @@ def get_batch_on_this_tp_rank(data_iterator, vp_stage=None): for key in ('input_ids', 'labels', 'loss_scale'): batch[key] = None - if is_finished: - args.train_iters = args.curr_iteration + 1 - return batch From 849972cc9e33dd7becd829dae2d7f78996a4a2c5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 14:36:20 +0800 Subject: [PATCH 08/21] update --- swift/megatron/trainers/kto_trainer.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 1a5b14b61c..217a76f9b9 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -18,7 +18,7 @@ class DummyKTOTrainer(KTOTrainer): # For reusing the dpo_loss function in TRL. def __init__(self, args): - self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) + self.accelerator = namedtuple('Accelerator', ['device', 'gather_for_metrics'])(device=get_current_device(), gather_for_metrics=) self.loss_type = args.loss_type self.beta = args.beta self.desirable_weight = args.desirable_weight @@ -33,21 +33,20 @@ def __init__(self, args, template): assert args.padding_free, 'Currently `rlhf_type="kto"` only supports padding_free.' self.dummy_kto_trainer = DummyKTOTrainer(args) - def _kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool): + def _kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool, length: int): labels = data['labels'] packed_seq_params = data['packed_seq_params'] - length = packed_seq_params.cu_seqlens_q[-1] output = self._get_input_tensor(output_tensor, is_KL, is_ref, length, dim=1) return self.get_logps(output, labels, packed_seq_params, packed_seq_params.num_samples) def loss_func(self, output_tensor, *, data, kl_data, label): - label = data['label'] - policy_logps = self._kto_get_logps(output_tensor, data, False, False) - ref_logps = self._kto_get_logps(output_tensor, data, False, True) - policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False) - ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True) + length = data['packed_seq_params'].cu_seqlens_q[-1] + policy_logps = self._kto_get_logps(output_tensor, data, False, False, length) + ref_logps = self._kto_get_logps(output_tensor, data, False, True, length) + policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length) + ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length) - label = label.bool() + label = output_tensor.new_tensor(label, dtype=torch.bool) policy_chosen_logps = policy_logps[label] policy_rejected_logps = policy_logps[~label] ref_chosen_logps = ref_logps[label] @@ -129,9 +128,9 @@ def forward_step(self, data_iterator, model): with self.stimer: output_tensor = model(**data) if self.args.calculate_KL: - res = torch.concat([output_tensor, ref_output_tensor], dim=1) - else: res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=1) + else: + res = torch.concat([output_tensor, ref_output_tensor], dim=1) return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) def _prepare_batch(self, data, vp_stage): From 
b087499f09f5d3b5dd22142346e5c8da9feb89d6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 16:30:44 +0800 Subject: [PATCH 09/21] updae --- swift/megatron/trainers/base.py | 21 ++++++++-------- swift/megatron/trainers/kto_trainer.py | 34 ++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index a58bb537fc..0f9127097a 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime -from typing import Dict +from typing import Dict, Literal import megatron.core import torch @@ -457,11 +457,7 @@ def evaluate(self, timers('evaluate').stop() timers.log(['evaluate']) - - total_loss_dict.update({ - k: torch.tensor([v], device='cuda') - for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics['eval'], 'eval_').items() - }) + self.custom_log(total_loss_dict, 'eval') rerun_state_machine.set_mode(rerun_mode) if is_last_rank(): logs = {} @@ -470,6 +466,14 @@ def evaluate(self, self.jsonl_writer.append(logs) return total_loss_dict, collected_non_loss_data, False + def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: + prefix = '' if mode == 'train' else 'eval_' + advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 + total_loss_dict.update({ + k: torch.tensor([v * advanced_iters], device='cuda') + for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics[mode], prefix).items() + }) + # Code borrowed from NVIDIA/Megatron-LM def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): @@ -618,10 +622,7 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) if iteration % args.log_interval == 0 or iteration == 1: - total_loss_dict.update({ - k: torch.tensor([v * total_loss_dict[advanced_iters_key]], device='cuda') - for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics['train']).items() - }) + self.custom_log(total_loss_dict, 'train') origin_total_loss_dict = total_loss_dict.copy() if args.record_memory_history and is_last_rank(): diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 217a76f9b9..f7b262d23a 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -1,8 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections import namedtuple from functools import partial +from typing import Literal import torch +from accelerate.utils import gather from megatron.core import mpu from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy from megatron.training import get_args, get_timers @@ -17,8 +19,13 @@ class DummyKTOTrainer(KTOTrainer): # For reusing the dpo_loss function in TRL. 
+ + def gather_for_metrics(self, input_data, *args, **kwargs): + return gather(input_data) + def __init__(self, args): - self.accelerator = namedtuple('Accelerator', ['device', 'gather_for_metrics'])(device=get_current_device(), gather_for_metrics=) + self.accelerator = namedtuple('Accelerator', ['device', 'gather_for_metrics'])( + device=get_current_device(), gather_for_metrics=self.gather_for_metrics) self.loss_type = args.loss_type self.beta = args.beta self.desirable_weight = args.desirable_weight @@ -68,10 +75,12 @@ def loss_func(self, output_tensor, *, data, kl_data, label): } metric = self._all_reduce_metric(mean_metric) sum_metric = { - 'logps/chosen': torch.stack([policy_chosen_logps.nansum(), policy_chosen_logps.shape[0]]), - 'logps/rejected': torch.stack([policy_rejected_logps.nansum(), policy_rejected_logps.shape[0]]), - 'rewards/chosen': torch.stack([chosen_rewards.nansum(), chosen_rewards.shape[0]]), - 'rewards/rejected': torch.stack([rejected_rewards.nansum(), rejected_rewards.shape[0]]), + 'logps/chosen_sum': policy_chosen_logps.nansum(), + 'logps/rejected_sum': policy_rejected_logps.nansum(), + 'rewards/chosen_sum': chosen_rewards.nansum(), + 'rewards/rejected_sum': rejected_rewards.nansum(), + 'count/chosen': loss.new_tensor(policy_rejected_logps.shape[0]), + 'count/rejected': loss.new_tensor(rejected_rewards.shape[0]), } metric.update(self._all_reduce_metric(sum_metric, torch.distributed.ReduceOp.SUM)) # fix megatron-lm bug @@ -141,3 +150,18 @@ def _prepare_batch(self, data, vp_stage): res.append(super()._prepare_batch(_data, vp_stage, num_samples)) res[0]['label'] = data['label'] return res + + def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: + super().custom_log(total_loss_dict, mode) + res = {} + for k, v in total_loss_dict.items(): + if k.startswith('count/'): + continue + if k.endswith('_sum'): + new_k = k.rsplit('_', 1)[-2] + count = total_loss_dict[f"count/{new_k.rsplit('/', 1)[-1]}"] + res[new_k] = v / count + else: + res[k] = v + total_loss_dict.clear() + total_loss_dict.update(res) From cb1d0fbcefd3601ab87b45e672ab78dcb59d6573 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 16:39:21 +0800 Subject: [PATCH 10/21] update --- swift/megatron/trainers/kto_trainer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index f7b262d23a..029430169d 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -152,16 +152,17 @@ def _prepare_batch(self, data, vp_stage): return res def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: + prefix = '' if mode == 'train' else 'eval_' super().custom_log(total_loss_dict, mode) res = {} for k, v in total_loss_dict.items(): - if k.startswith('count/'): + if k.startswith(f'{prefix}count/') or k.endswith('_sum'): continue - if k.endswith('_sum'): - new_k = k.rsplit('_', 1)[-2] - count = total_loss_dict[f"count/{new_k.rsplit('/', 1)[-1]}"] - res[new_k] = v / count - else: - res[k] = v + res[k] = v + for key in ['chosen', 'rejected']: + count = total_loss_dict[f'{prefix}count/{key}'] + res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count + res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count + res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected'] total_loss_dict.clear() total_loss_dict.update(res) From 
a76e6afd24eba818d6d9d48691faaae7bd5f8caf Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 16:57:05 +0800 Subject: [PATCH 11/21] update --- swift/megatron/trainers/base.py | 4 ++-- swift/megatron/trainers/kto_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 0f9127097a..7ed5290c9e 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -470,8 +470,8 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: prefix = '' if mode == 'train' else 'eval_' advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 total_loss_dict.update({ - k: torch.tensor([v * advanced_iters], device='cuda') - for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics[mode], prefix).items() + f'{prefix}{k}': torch.tensor([v * advanced_iters], device='cuda') + for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics[mode]).items() }) # Code borrowed from NVIDIA/Megatron-LM diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 029430169d..7ebe2d08dd 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -79,7 +79,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label): 'logps/rejected_sum': policy_rejected_logps.nansum(), 'rewards/chosen_sum': chosen_rewards.nansum(), 'rewards/rejected_sum': rejected_rewards.nansum(), - 'count/chosen': loss.new_tensor(policy_rejected_logps.shape[0]), + 'count/chosen': loss.new_tensor(chosen_rewards.shape[0]), 'count/rejected': loss.new_tensor(rejected_rewards.shape[0]), } metric.update(self._all_reduce_metric(sum_metric, torch.distributed.ReduceOp.SUM)) From b493fcff3c9af778a85a28b595c778981c83f124 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 17:06:45 +0800 Subject: [PATCH 12/21] update --- swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/train/rlhf.py | 2 +- swift/megatron/trainers/base.py | 1 + swift/megatron/trainers/kto_trainer.py | 2 -- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index f403d19e65..88b886ee1c 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -17,7 +17,7 @@ @dataclass class RLHFMegatronArgumentsMixin: - rlhf_type: Literal['dpo', 'kto', None] = None + rlhf_type: Literal['dpo', 'kto'] = None ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 82b01d9ad3..da964950dc 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,8 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import List, Optional, Union +from swift.llm.train.kto import prepare_kto_dataset from swift.utils import get_logger -from ...llm.train.kto import prepare_kto_dataset from ..argument import MegatronRLHFArguments from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer from .sft import MegatronSft diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 7ed5290c9e..9b42affcd2 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -28,6 +28,7 @@ from megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version +from torch.distributed.nn import all_reduce from transformers.utils import ContextManagers from swift.llm import dynamic_gradient_checkpointing diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 7ebe2d08dd..0a4084c3a9 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -6,9 +6,7 @@ import torch from accelerate.utils import gather from megatron.core import mpu -from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy from megatron.training import get_args, get_timers -from torch.distributed.nn import all_reduce from trl import KTOTrainer from swift.utils import get_current_device, get_logger From 973cd6df1b106e6fed8d08a9af18d0c3e66b87e5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 19:27:02 +0800 Subject: [PATCH 13/21] update --- swift/megatron/trainers/base.py | 2 +- swift/megatron/trainers/kto_trainer.py | 32 +++++++++++++++----------- swift/megatron/trainers/utils.py | 11 ++++----- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9b42affcd2..833221d845 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -471,7 +471,7 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: prefix = '' if mode == 'train' else 'eval_' advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 total_loss_dict.update({ - f'{prefix}{k}': torch.tensor([v * advanced_iters], device='cuda') + k: torch.tensor([v * advanced_iters], device='cuda') for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics[mode]).items() }) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 0a4084c3a9..db171e6db8 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -48,9 +48,11 @@ def loss_func(self, output_tensor, *, data, kl_data, label): length = data['packed_seq_params'].cu_seqlens_q[-1] policy_logps = self._kto_get_logps(output_tensor, data, False, False, length) ref_logps = self._kto_get_logps(output_tensor, data, False, True, length) - policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length) - ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length) - + if self.args.calculate_KL: + policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length) + ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length) + else: + policy_KL_logps, ref_KL_logps = None, None label = output_tensor.new_tensor(label, dtype=torch.bool) policy_chosen_logps = policy_logps[label] policy_rejected_logps = policy_logps[~label] @@ -111,31 
+113,35 @@ def forward_step(self, data_iterator, model): data, kl_data = self.get_batch(data_iterator, vp_stage) timers('batch-generator').stop() label = data.pop('label') + data.pop('loss_scale', None) + kl_data.pop('loss_scale', None) + length = data['packed_seq_params'].cu_seqlens_q[-1] with torch.no_grad(), self.null_ref_context() as ref_models: ref_model = ref_models[vp_stage or 0] if self.args.calculate_KL: if input_tensor is not None: - ref_model.set_input_tensor(self._get_input_tensor(True, True, length)) + ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, True, length, 0)) ref_KL_output_tensor = ref_model(**kl_data) if input_tensor is not None: - ref_model.set_input_tensor(self._get_input_tensor(True, False, length)) + ref_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True, length, 0)) ref_output_tensor = ref_model(**data) if self.args.calculate_KL: with torch.no_grad(): if input_tensor is not None: - unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True)) + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, 0)) KL_output_tensor = model(**kl_data) if input_tensor is not None: - unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False)) + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, 0)) with self.stimer: output_tensor = model(**data) + dim = 1 if mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else 0 if self.args.calculate_KL: - res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=1) + res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=dim) else: res = torch.concat([output_tensor, ref_output_tensor], dim=1) return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) @@ -154,13 +160,13 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: super().custom_log(total_loss_dict, mode) res = {} for k, v in total_loss_dict.items(): - if k.startswith(f'{prefix}count/') or k.endswith('_sum'): + if k.startswith(f'count/') or k.endswith('_sum'): continue res[k] = v for key in ['chosen', 'rejected']: - count = total_loss_dict[f'{prefix}count/{key}'] - res[f'{prefix}logps/{key}'] = total_loss_dict[f'{prefix}logps/{key}_sum'] / count - res[f'{prefix}rewards/{key}'] = total_loss_dict[f'{prefix}rewards/{key}_sum'] / count - res[f'{prefix}rewards/margins'] = res[f'{prefix}rewards/chosen'] - res[f'{prefix}rewards/rejected'] + count = total_loss_dict[f'count/{key}'] + res[f'logps/{key}'] = total_loss_dict[f'logps/{key}_sum'] / count + res[f'rewards/{key}'] = total_loss_dict[f'rewards/{key}_sum'] / count + res[f'rewards/margins'] = res[f'rewards/chosen'] - res[f'rewards/rejected'] total_loss_dict.clear() total_loss_dict.update(res) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 06695298a8..35dd538f0d 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -36,15 +36,12 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): data['loss_scale'] = torch.roll(data['loss_scale'], -1, dims=-1) batch = to_device(data, 'cuda', non_blocking=True) if args.pipeline_model_parallel_size == 1: - pass - elif mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + return batch + if not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + batch['input_ids'] = None + if not 
mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): batch['labels'] = None batch['loss_scale'] = None - elif mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): - batch['input_ids'] = None - else: - for key in ('input_ids', 'labels', 'loss_scale'): - batch[key] = None return batch From 255aece39c60c6ce67daae78d9f466b0d90509ab Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 5 Oct 2025 21:24:15 +0800 Subject: [PATCH 14/21] update --- swift/megatron/trainers/base.py | 1 - swift/megatron/trainers/kto_trainer.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 833221d845..d28736f220 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -468,7 +468,6 @@ def evaluate(self, return total_loss_dict, collected_non_loss_data, False def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: - prefix = '' if mode == 'train' else 'eval_' advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 total_loss_dict.update({ k: torch.tensor([v * advanced_iters], device='cuda') diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index db171e6db8..ba348a7419 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -156,17 +156,16 @@ def _prepare_batch(self, data, vp_stage): return res def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: - prefix = '' if mode == 'train' else 'eval_' super().custom_log(total_loss_dict, mode) res = {} for k, v in total_loss_dict.items(): - if k.startswith(f'count/') or k.endswith('_sum'): + if k.startswith('count/') or k.endswith('_sum'): continue res[k] = v for key in ['chosen', 'rejected']: count = total_loss_dict[f'count/{key}'] res[f'logps/{key}'] = total_loss_dict[f'logps/{key}_sum'] / count res[f'rewards/{key}'] = total_loss_dict[f'rewards/{key}_sum'] / count - res[f'rewards/margins'] = res[f'rewards/chosen'] - res[f'rewards/rejected'] + res['rewards/margins'] = res['rewards/chosen'] - res['rewards/rejected'] total_loss_dict.clear() total_loss_dict.update(res) From 38023e968d8b91c5af11e786676680fb3b153145 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 18:11:30 +0800 Subject: [PATCH 15/21] update --- ...345\277\253\351\200\237\345\274\200\345\247\213.md" | 1 + docs/source_en/Megatron-SWIFT/Quick-start.md | 1 + swift/megatron/trainers/kto_trainer.py | 10 +++++----- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index f5d1a3cc13..13f0bea0cf 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -9,6 +9,7 @@ ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数 | 预训练| ✅ | ✅| ✅ | ✅ | | 指令监督微调 | ✅ | ✅| ✅ | ✅ | | DPO | ✅ | ✅| ✅ | ✅ | +| KTO | ✅ | ✅| ✅ | ✅ | | 分类任务 | ✅ | ✅| ✅ | ✅ | diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 8e9c1cf8d6..8ba3d4d44a 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -8,6 +8,7 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr | Pretraining | ✅ | ✅ | ✅ | ✅ | | 
Instruction-supervised fine-tuning | ✅ | ✅ | ✅ | ✅ | | DPO | ✅ | ✅ | ✅ | ✅ | +| KTO | ✅ | ✅ | ✅ | ✅ | | Classification tasks | ✅ | ✅ | ✅ | ✅ | ## Environment Setup diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index ba348a7419..9b927de435 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -16,7 +16,7 @@ class DummyKTOTrainer(KTOTrainer): - # For reusing the dpo_loss function in TRL. + # For reusing the kto_loss function in TRL. def gather_for_metrics(self, input_data, *args, **kwargs): return gather(input_data) @@ -90,7 +90,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label): @staticmethod def _get_input_tensor(input_tensor, is_KL: bool, is_ref: bool, length: int, dim: int): - # polocy, ref, polocy_KL, ref_KL + # policy, ref, policy_KL, ref_KL total_length = input_tensor.shape[dim] KL_length = (total_length - 2 * length) // 2 slice_list = [0, length, 2 * length, total_length - KL_length, total_length] @@ -132,18 +132,18 @@ def forward_step(self, data_iterator, model): if self.args.calculate_KL: with torch.no_grad(): if input_tensor is not None: - unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, 0)) + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, length, 0)) KL_output_tensor = model(**kl_data) if input_tensor is not None: - unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, 0)) + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, 0)) with self.stimer: output_tensor = model(**data) dim = 1 if mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else 0 if self.args.calculate_KL: res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=dim) else: - res = torch.concat([output_tensor, ref_output_tensor], dim=1) + res = torch.concat([output_tensor, ref_output_tensor], dim=dim) return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) def _prepare_batch(self, data, vp_stage): From caeaae2379b6c47293def189827122c5e877d58a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 19:36:26 +0800 Subject: [PATCH 16/21] fix --- swift/megatron/trainers/kto_trainer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 9b927de435..290a133004 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -19,7 +19,13 @@ class DummyKTOTrainer(KTOTrainer): # For reusing the kto_loss function in TRL. 
def gather_for_metrics(self, input_data, *args, **kwargs): - return gather(input_data) + output_tensors = torch.empty( + mpu.get_data_parallel_world_size() * input_data.numel(), + dtype=input_data.dtype, + device=input_data.device, + ) + torch.distributed.all_gather_into_tensor(output_tensors, input_data, group=mpu.get_data_parallel_group()) + return output_tensors def __init__(self, args): self.accelerator = namedtuple('Accelerator', ['device', 'gather_for_metrics'])( @@ -163,9 +169,12 @@ def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: continue res[k] = v for key in ['chosen', 'rejected']: - count = total_loss_dict[f'count/{key}'] + count = total_loss_dict.get(f'count/{key}') + if count is None or count.item() == 0: + continue res[f'logps/{key}'] = total_loss_dict[f'logps/{key}_sum'] / count res[f'rewards/{key}'] = total_loss_dict[f'rewards/{key}_sum'] / count - res['rewards/margins'] = res['rewards/chosen'] - res['rewards/rejected'] + if 'rewards/chosen' in res and 'rewards/rejected' in res: + res['rewards/margins'] = res['rewards/chosen'] - res['rewards/rejected'] total_loss_dict.clear() total_loss_dict.update(res) From a438ff7922f1e2f3d992e820cf989539b2ac3062 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 19:40:04 +0800 Subject: [PATCH 17/21] fix --- swift/megatron/trainers/kto_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 290a133004..f85ab0aee1 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -4,7 +4,6 @@ from typing import Literal import torch -from accelerate.utils import gather from megatron.core import mpu from megatron.training import get_args, get_timers from trl import KTOTrainer From b7cb9b2d8f3d9f646476129531743763b1086ffa Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 19:44:22 +0800 Subject: [PATCH 18/21] update --- swift/megatron/trainers/base.py | 85 ---------------------- swift/megatron/trainers/dpo_trainer.py | 2 +- swift/megatron/trainers/kto_trainer.py | 2 +- swift/megatron/trainers/rlhf_mixin.py | 99 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 87 deletions(-) create mode 100644 swift/megatron/trainers/rlhf_mixin.py diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index d28736f220..1ecd6cd3c0 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -13,7 +13,6 @@ from megatron.core import mpu from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.enums import ModelType -from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine @@ -28,8 +27,6 @@ from megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version -from torch.distributed.nn import all_reduce -from transformers.utils import ContextManagers from swift.llm import dynamic_gradient_checkpointing from swift.plugin import MeanMetric @@ -825,85 +822,3 @@ def get_batch(self, data_iterator, vp_stage=None): if is_finished: args.train_iters = args.curr_iteration + 1 return 
self._prepare_batch(data, vp_stage) - - -class MegatronRLHFTrainer(BaseMegatronTrainer): - - def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - args = get_args() - if args.train_type == 'full': - ref_models = get_model(model_provider_func, model_type, wrap_with_ddp=False) - for m in ref_models: - m = unwrap_model(m) - m.requires_grad_(False).eval() - if args.ref_load is None: - args.ref_load = args.load - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_models, None, None, load_arg='ref_load') - self.ref_models = ref_models - return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) - - @contextmanager - def null_ref_context(self): - args = get_args() - contexts = [] - if args.train_type == 'full': - ref_models = self.ref_models - else: - if args.ref_adapter_load is None: - for m in self.peft_models: - contexts.append(m.disable_adapter()) - ref_models = self.unwrapped_models - with ContextManagers(contexts): - if args.ref_adapter_load: - for m in self.peft_models: - m.set_adapter('ref_adapter') - yield ref_models - if args.ref_adapter_load: - for m in self.peft_models: - m.set_adapter('default') - - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - if num_samples is None: - num_samples = packed_seq_params.num_samples * 2 - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples, )) - for i in range(num_samples): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 868069b197..c067612201 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -9,7 +9,7 @@ from swift.trainers import DPOTrainer from swift.utils import get_current_device, get_logger -from .base import MegatronRLHFTrainer +from .rlhf_mixin import MegatronRLHFTrainer logger = get_logger() diff --git a/swift/megatron/trainers/kto_trainer.py 
b/swift/megatron/trainers/kto_trainer.py index f85ab0aee1..a3d8cd2f01 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -9,7 +9,7 @@ from trl import KTOTrainer from swift.utils import get_current_device, get_logger -from .base import MegatronRLHFTrainer +from .rlhf_mixin import MegatronRLHFTrainer logger = get_logger() diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py new file mode 100644 index 0000000000..ead111435e --- /dev/null +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -0,0 +1,99 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from contextlib import contextmanager + +import torch +import torch.nn +from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.training import get_args, get_model +from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model +from torch.distributed.nn import all_reduce +from transformers.utils import ContextManagers + +from swift.utils import get_logger +from .base import BaseMegatronTrainer + +logger = get_logger() + + +class MegatronRLHFTrainer(BaseMegatronTrainer): + + def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): + args = get_args() + if args.train_type == 'full': + ref_models = get_model(model_provider_func, model_type, wrap_with_ddp=False) + for m in ref_models: + m = unwrap_model(m) + m.requires_grad_(False).eval() + if args.ref_load is None: + args.ref_load = args.load + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_models, None, None, load_arg='ref_load') + self.ref_models = ref_models + return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) + + @contextmanager + def null_ref_context(self): + args = get_args() + contexts = [] + if args.train_type == 'full': + ref_models = self.ref_models + else: + if args.ref_adapter_load is None: + for m in self.peft_models: + contexts.append(m.disable_adapter()) + ref_models = self.unwrapped_models + with ContextManagers(contexts): + if args.ref_adapter_load: + for m in self.peft_models: + m.set_adapter('ref_adapter') + yield ref_models + if args.ref_adapter_load: + for m in self.peft_models: + m.set_adapter('default') + + @staticmethod + def _forward_step_helper(model, inputs): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + + def get_logps(self, output_tensor, labels, 
packed_seq_params, num_samples=None): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + if num_samples is None: + num_samples = packed_seq_params.num_samples * 2 + cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size + all_logps = per_token_logps.new_zeros((num_samples, )) + for i in range(num_samples): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) + return all_logps From 2062ecbb699cd61386b022fd0ce2a1e8c776fbfe Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 21:38:39 +0800 Subject: [PATCH 19/21] update --- examples/megatron/rlhf/kto/dense.sh | 36 +++++++++++++++++++++++ examples/megatron/rlhf/kto/moe.sh | 44 +++++++++++++++++++++++++++++ swift/llm/train/kto.py | 2 +- 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 examples/megatron/rlhf/kto/dense.sh create mode 100644 examples/megatron/rlhf/kto/moe.sh diff --git a/examples/megatron/rlhf/kto/dense.sh b/examples/megatron/rlhf/kto/dense.sh new file mode 100644 index 0000000000..cbcb1c63c4 --- /dev/null +++ b/examples/megatron/rlhf/kto/dense.sh @@ -0,0 +1,36 @@ +# 4 * 43GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +megatron rlhf \ + --rlhf_type kto \ + --load Qwen2.5-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#20000' \ + --load_from_cache_file true \ + --split_dataset_ratio 0.01 \ + --tensor_model_parallel_size 4 \ + --packing true \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen2.5-7B-Instruct \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --desirable_weight 1 \ + --undesirable_weight 1 diff --git a/examples/megatron/rlhf/kto/moe.sh b/examples/megatron/rlhf/kto/moe.sh new file mode 100644 index 0000000000..c44936ab40 --- /dev/null +++ b/examples/megatron/rlhf/kto/moe.sh @@ -0,0 +1,44 @@ +# 2 * 48GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron rlhf \ + --rlhf_type kto \ + --load Qwen3-30B-A3B-Instruct-2507-mcore \ + --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#20000' \ + --load_from_cache_file true \ + --packing true \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --split_dataset_ratio 0.01 \ + --expert_model_parallel_size 2 \ + --moe_permute_fusion true \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-3 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --save megatron_output/Qwen3-30B-A3B-Instruct-2507 \ + --eval_interval 100 \ + 
--save_interval 100 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --desirable_weight 1 \ + --undesirable_weight 1 diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py index 43ec3a8004..966c11cb61 100644 --- a/swift/llm/train/kto.py +++ b/swift/llm/train/kto.py @@ -72,7 +72,7 @@ def prepare_kto_dataset(args, train_dataset, val_dataset): f""" You have different amounts of desirable/positive and undesirable/negative examples but the weights on the desirable and undesirable losses don't seem to be in an ideal range. Based - on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, '{des_weight_upper_bound}] + on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). See the documentation on how to optimally set these weights.""", UserWarning) return train_dataset, val_dataset From 3a87e2a4a0280803cc8a44a4df38431ef6625133 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 21:48:24 +0800 Subject: [PATCH 20/21] update --- ...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 1a51eaa66a..4898d612ae 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -186,7 +186,7 @@ - 注意:在"ms-swift<3.7.1",其默认为None,自动从config.json读取。 - moe_z_loss_coeff: z-loss 的缩放系数。默认为None。 - 🔥moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。 -- moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度。 +- 🔥moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度。 - moe_pad_expert_input_to_capacity: 对每个专家(expert)的输入进行填充,使其长度与专家容量(expert capacity length)对齐,默认为False。该操作仅在设置了 `--moe_expert_capacity_factor` 参数后才生效。 - moe_token_drop_policy: 可选为'probs', 'position'。默认为'probs'。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index b52b9089aa..2a54f5c9f2 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -198,7 +198,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the - Note: In ms-swift versions earlier than 3.7.1, the default is None and the value is automatically loaded from config.json. - moe_z_loss_coeff: Scaling coefficient for z-loss. Default is None. - 🔥moe_shared_expert_overlap: Enables overlap between shared expert computation and the dispatcher. If not enabled, shared expert computation will be performed after routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False. -- moe_expert_capacity_factor: Capacity factor for each expert. 
`None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. This can balance the training load and improve training speed. +- 🔥moe_expert_capacity_factor: Capacity factor for each expert. `None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. This can balance the training load and improve training speed. - moe_pad_expert_input_to_capacity: Pad the input of each expert so that its length aligns with the expert capacity length. Default is `False`. This option only takes effect if `--moe_expert_capacity_factor` is set. - moe_token_drop_policy: Options are 'probs' and 'position'. Default is 'probs'. From 65f17ffa9c8be6e0ee3e25a97534986eb47949d4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 8 Oct 2025 21:49:33 +0800 Subject: [PATCH 21/21] update --- ...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 4898d612ae..9aabef22c7 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -186,7 +186,7 @@ - 注意:在"ms-swift<3.7.1",其默认为None,自动从config.json读取。 - moe_z_loss_coeff: z-loss 的缩放系数。默认为None。 - 🔥moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。 -- 🔥moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度。 +- 🔥moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度(例如设置为1)。 - moe_pad_expert_input_to_capacity: 对每个专家(expert)的输入进行填充,使其长度与专家容量(expert capacity length)对齐,默认为False。该操作仅在设置了 `--moe_expert_capacity_factor` 参数后才生效。 - moe_token_drop_policy: 可选为'probs', 'position'。默认为'probs'。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 2a54f5c9f2..4f72f6a52c 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -198,7 +198,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the - Note: In ms-swift versions earlier than 3.7.1, the default is None and the value is automatically loaded from config.json. - moe_z_loss_coeff: Scaling coefficient for z-loss. Default is None. - 🔥moe_shared_expert_overlap: Enables overlap between shared expert computation and the dispatcher. If not enabled, shared expert computation will be performed after routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False. -- 🔥moe_expert_capacity_factor: Capacity factor for each expert. `None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. 
This can balance the training load and improve training speed.
+- 🔥moe_expert_capacity_factor: Capacity factor for each expert. `None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. This can balance the training load and improve training speed (for example, set it to 1.).
 - moe_pad_expert_input_to_capacity: Pad the input of each expert so that its length aligns with the expert capacity length. Default is `False`. This option only takes effect if `--moe_expert_capacity_factor` is set.
 - moe_token_drop_policy: Options are 'probs' and 'position'. Default is 'probs'.
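
A note on the `--moe_expert_capacity_factor` behavior documented above: the parameter effectively sets a per-expert token budget per micro-batch, and tokens routed beyond that budget are dropped. The sketch below is only an illustration of that arithmetic; the function name `expert_capacity`, the formula, and the example numbers are assumptions for illustration, not the Megatron-LM implementation (which additionally handles `moe_pad_expert_input_to_capacity` and the drop policy).

```python
def expert_capacity(num_tokens: int, num_experts: int, top_k: int, capacity_factor: float) -> int:
    # Illustrative per-expert token budget implied by a capacity factor (assumed formula):
    # tokens routed to an expert beyond this budget are dropped, e.g. by their routing
    # probability when moe_token_drop_policy='probs'.
    return int(capacity_factor * num_tokens * top_k / num_experts)


# Example: 8192 packed tokens, top-2 routing, 64 experts, capacity factor 1.0 -> budget of 256 tokens per expert
print(expert_capacity(num_tokens=8192, num_experts=64, top_k=2, capacity_factor=1.0))
```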