diff --git "a/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" index fe9d94ba91..e560ae258c 100644 --- "a/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source/BestPractices/Qwen3\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -328,7 +328,7 @@ swift rlhf \ Qwen3-235B-A22B-Instruct-250718 单机8卡H20 LoRA训练的最佳实践参考:[https://github.com/modelscope/ms-swift/pull/5033](https://github.com/modelscope/ms-swift/pull/5033)。 -ms-swift 引入了 Megatron 并行技术以加速大模型的CPT/SFT/DPO。支持的模型可以在[支持的模型文档](../Instruction/支持的模型和数据集.md)中找到。 +ms-swift 引入了 Megatron 并行技术以加速大模型的CPT/SFT/DPO/KTO。支持的模型可以在[支持的模型文档](../Instruction/支持的模型和数据集.md)中找到。 关于环境准备以及 HF 和 MCore 模型权重的转换,可以参考[Megatron-SWIFT训练文档](../Megatron-SWIFT/快速开始.md)。 diff --git "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" index aea59cbc83..fcb4009898 100644 --- "a/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/GetStarted/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -10,7 +10,7 @@ ms-swift是魔搭社区提供的大模型与多模态大模型训练部署框架 - 量化训练:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。 - 🍊 RLHF训练:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。 - 🍓 多模态训练:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。 -- 🥥 Megatron并行技术:支持使用Megatron并行技术对CPT/SFT/DPO进行加速,现支持200+大语言模型。 +- 🥥 Megatron并行技术:支持使用Megatron并行技术对CPT/SFT/DPO/KTO进行加速,现支持200+大语言模型。 - 界面训练:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。 - 插件化与拓展:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。 - 🍉 工具箱能力:除了对大模型和多模态大模型的训练支持外,还支持其推理、评测、量化和部署全流程。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ea97a35fe2..9aabef22c7 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -186,7 +186,7 @@ - 注意:在"ms-swift<3.7.1",其默认为None,自动从config.json读取。 - moe_z_loss_coeff: z-loss 的缩放系数。默认为None。 - 🔥moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。 -- moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度。 +- 🔥moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。通过设置 `--moe_expert_capacity_factor`,超出专家容量的 token 会基于其被选中的概率被丢弃。可以令训练负载均匀,提升训练速度(例如设置为1)。 - moe_pad_expert_input_to_capacity: 对每个专家(expert)的输入进行填充,使其长度与专家容量(expert capacity length)对齐,默认为False。该操作仅在设置了 `--moe_expert_capacity_factor` 参数后才生效。 - moe_token_drop_policy: 可选为'probs', 'position'。默认为'probs'。 @@ -233,13 +233,15 @@ lora训练: - reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 - label_smoothing: 默认为0.。 - f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 -- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 +- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 **KTO参数**: -- beta: KL正则项系数,默认为`0.1`。 -- desirable_weight: KTO算法中对desirable response的loss权重 $\lambda_D$,默认为`1.`。 -- undesirable_weight: 
KTO算法中对undesirable response的loss权重 $\lambda_U$,默认为`1.`。 -- calculate_KL: 是否计算KL散度,默认为True。 +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 +- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 +- desirable_weight: 抵消 desirable 和 undesirable 配对数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 +- undesirable_weight: 抵消 desirable 和 undesirable 配对数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 ## 训练参数 @@ -252,7 +254,7 @@ Megatron训练参数继承自Megatron参数和基本参数(与ms-swift共用da - mlp_padding_free: 默认为False。用于padding_free设置为false时,对mlp进行padding_free优化。这可以在自定义attention_mask的同时,提升训练速度和减少显存占用。 - vit_gradient_checkpointing: 多模态模型训练时,是否对vit部分开启gradient_checkpointing。默认为True。 - gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。 -- 🔥packing: 是否使用序列packing,默认为False。当前支持CPT/SFT/DPO。 +- 🔥packing: 是否使用序列packing,默认为False。当前支持CPT/SFT/DPO/KTO。 - packing_length: packing的长度。默认为None,设置为max_length。 - streaming: 流式读取并处理数据集,默认False。 - 注意:因为流式数据集无法获得其长度,因此需要设置`--train_iters`参数。设置`max_epochs`参数确保训练到对应epochs时退出训练,并对权重进行验证和保存。 diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 16edbd35c5..6cd7b192de 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -1,6 +1,6 @@ # 多模态模型 -ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO/KTO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 环境准备请参考Megatron-SWIFT的[快速开始文档](./快速开始.md)。 diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index f5d1a3cc13..13f0bea0cf 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -9,6 +9,7 @@ ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数 | 预训练| ✅ | ✅| ✅ | ✅ | | 指令监督微调 | ✅ | ✅| ✅ | ✅ | | DPO | ✅ | ✅| ✅ | ✅ | +| KTO | ✅ | ✅| ✅ | ✅ | | 分类任务 | ✅ | ✅| ✅ | ✅ | diff --git a/docs/source_en/BestPractices/Qwen3-Best-Practice.md b/docs/source_en/BestPractices/Qwen3-Best-Practice.md index ffc50effdd..931e8b1834 100644 --- a/docs/source_en/BestPractices/Qwen3-Best-Practice.md +++ b/docs/source_en/BestPractices/Qwen3-Best-Practice.md @@ -332,7 +332,7 @@ swift rlhf \ Best practice reference for single-node 8xH20 LoRA training with Qwen3-235B-A22B-Instruct-250718: https://github.com/modelscope/ms-swift/pull/5033. -ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO for large models. Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/KTO for large models. 
Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md). For environment setup and conversion between HF and MCore model weights, refer to the [Megatron-SWIFT Training Documentation](../Megatron-SWIFT/Quick-start.md). diff --git a/docs/source_en/GetStarted/Quick-start.md b/docs/source_en/GetStarted/Quick-start.md index 845570a712..d36396023c 100644 --- a/docs/source_en/GetStarted/Quick-start.md +++ b/docs/source_en/GetStarted/Quick-start.md @@ -10,7 +10,7 @@ ms-swift is a comprehensive training and deployment framework for large language - Quantization Training: Provides training for quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ. - 🍊 RLHF Training: Supports human alignment training methods like DPO, GRPO, RM, PPO, GKD, KTO, CPO, SimPO, ORPO for both text-based and multimodal large models. - 🍓 Multimodal Training: Capable of training models for different modalities such as images, videos, and audios; supports tasks like VQA (Visual Question Answering), Captioning, OCR (Optical Character Recognition), and Grounding. -- 🥥 Megatron Parallelism: Supports accelerating CPT/SFT/DPO using Megatron parallelism techniques, currently compatible with 200+ large language models. +- 🥥 Megatron Parallelism: Supports accelerating CPT/SFT/DPO/KTO using Megatron parallelism techniques, currently compatible with 200+ large language models. - Interface-driven Training: Offers training, inference, evaluation, and quantization capabilities through an interface, enabling a complete workflow for large models. - Plugins and Extensions: Allows customization and extension of models and datasets, and supports customizations for components like loss, metric, trainer, loss-scale, callback, optimizer, etc. - 🍉 Toolbox Capabilities: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 7a9b50c880..4f72f6a52c 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -198,7 +198,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the - Note: In ms-swift versions earlier than 3.7.1, the default is None and the value is automatically loaded from config.json. - moe_z_loss_coeff: Scaling coefficient for z-loss. Default is None. - 🔥moe_shared_expert_overlap: Enables overlap between shared expert computation and the dispatcher. If not enabled, shared expert computation will be performed after routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False. -- moe_expert_capacity_factor: Capacity factor for each expert. `None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. This can balance the training load and improve training speed. +- 🔥moe_expert_capacity_factor: Capacity factor for each expert. `None` means no tokens will be dropped. Default is `None`. When `--moe_expert_capacity_factor` is set, tokens exceeding an expert’s capacity will be dropped based on their selection probability. This can balance the training load and improve training speed (for example, set it to 1.). 
- moe_pad_expert_input_to_capacity: Pad the input of each expert so that its length aligns with the expert capacity length. Default is `False`. This option only takes effect if `--moe_expert_capacity_factor` is set. - moe_token_drop_policy: Options are 'probs' and 'position'. Default is 'probs'. @@ -248,13 +248,15 @@ LoRA Training: - reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. - label_smoothing: Default is 0. - f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. -- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. +- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. **KTO Parameters**: -- beta: Coefficient for the KL regularization term. Default is `0.1`. -- desirable_weight: Loss weight $\lambda_D$ for desirable response in the KTO algorithm, default is `1.`. -- undesirable_weight: Loss weight $\lambda_U$ for undesirable response in the KTO algorithm, default is `1.`. -- calculate_KL: Whether to calculate KL divergence. Default is `True`. +- ref_load: same meaning as in DPO. +- ref_adapter_load: same meaning as in DPO. +- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. +- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. +- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. ## Training Parameters @@ -267,7 +269,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - mlp_padding_free: The default is False. This is used for applying padding-free optimization to the MLP when padding_free is set to false. It allows for improved training speed and reduced memory usage while customizing the attention_mask. - vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT part during multimodal model training. Default: True. - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Default: None. -- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports CPT/SFT/DPO. +- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports CPT/SFT/DPO/KTO. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - streaming: Stream data loading and processing, default is False. - Note: Since the length of a streaming dataset cannot be determined, the `--train_iters` parameter must be set. Also set the `max_epochs` parameter to ensure training exits after the specified number of epochs, and to validate and save the model weights accordingly. 
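To make the new KTO parameters documented above concrete (`beta`, `desirable_weight`, `undesirable_weight`), the snippet below is a minimal, illustrative sketch of the unpaired KTO loss in the TRL-style form that this PR reuses via `DummyKTOTrainer`. The standalone function name and signature are invented for illustration only and are not part of the codebase; `kl` stands for the detached, clamped KL estimate computed from the KL completions.

```python
import torch


def kto_loss_sketch(policy_chosen_logps: torch.Tensor,
                    policy_rejected_logps: torch.Tensor,
                    ref_chosen_logps: torch.Tensor,
                    ref_rejected_logps: torch.Tensor,
                    kl: torch.Tensor,
                    beta: float = 0.1,
                    desirable_weight: float = 1.0,
                    undesirable_weight: float = 1.0) -> torch.Tensor:
    # Desirable (chosen) examples: push the policy log-prob above the reference.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    # Undesirable (rejected) examples: push the policy log-prob below the reference.
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    # desirable_weight / undesirable_weight rebalance datasets with unequal
    # numbers of desirable and undesirable examples.
    losses = torch.cat((desirable_weight * chosen_losses,
                        undesirable_weight * rejected_losses), dim=0)
    return losses.mean()
```

Some loss variants drop the KL term entirely (for example `apo_zero_unpaired`), which is why the `_init_kto` logic added later in this diff disables `calculate_KL` for that `loss_type`.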
diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index a595946cb5..c72850c08f 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # Multimodal Models -ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO/KTO for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 8e9c1cf8d6..8ba3d4d44a 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -8,6 +8,7 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr | Pretraining | ✅ | ✅ | ✅ | ✅ | | Instruction-supervised fine-tuning | ✅ | ✅ | ✅ | ✅ | | DPO | ✅ | ✅ | ✅ | ✅ | +| KTO | ✅ | ✅ | ✅ | ✅ | | Classification tasks | ✅ | ✅ | ✅ | ✅ | ## Environment Setup diff --git a/examples/megatron/rlhf/kto/dense.sh b/examples/megatron/rlhf/kto/dense.sh new file mode 100644 index 0000000000..cbcb1c63c4 --- /dev/null +++ b/examples/megatron/rlhf/kto/dense.sh @@ -0,0 +1,36 @@ +# 4 * 43GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +megatron rlhf \ + --rlhf_type kto \ + --load Qwen2.5-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#20000' \ + --load_from_cache_file true \ + --split_dataset_ratio 0.01 \ + --tensor_model_parallel_size 4 \ + --packing true \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen2.5-7B-Instruct \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --desirable_weight 1 \ + --undesirable_weight 1 diff --git a/examples/megatron/rlhf/kto/moe.sh b/examples/megatron/rlhf/kto/moe.sh new file mode 100644 index 0000000000..c44936ab40 --- /dev/null +++ b/examples/megatron/rlhf/kto/moe.sh @@ -0,0 +1,44 @@ +# 2 * 48GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron rlhf \ + --rlhf_type kto \ + --load Qwen3-30B-A3B-Instruct-2507-mcore \ + --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#20000' \ + --load_from_cache_file true \ + --packing true \ + --train_type lora \ + --lora_rank 8 \ + 
--lora_alpha 32 \ + --target_modules all-linear \ + --split_dataset_ratio 0.01 \ + --expert_model_parallel_size 2 \ + --moe_permute_fusion true \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-3 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --save megatron_output/Qwen3-30B-A3B-Instruct-2507 \ + --eval_interval 100 \ + --save_interval 100 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --desirable_weight 1 \ + --undesirable_weight 1 diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py index 43ec3a8004..966c11cb61 100644 --- a/swift/llm/train/kto.py +++ b/swift/llm/train/kto.py @@ -72,7 +72,7 @@ def prepare_kto_dataset(args, train_dataset, val_dataset): f""" You have different amounts of desirable/positive and undesirable/negative examples but the weights on the desirable and undesirable losses don't seem to be in an ideal range. Based - on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, '{des_weight_upper_bound}] + on your data, we recommend EITHER desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). See the documentation on how to optimally set these weights.""", UserWarning) return train_dataset, val_dataset diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 92834e44d7..88b886ee1c 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -17,6 +17,7 @@ @dataclass class RLHFMegatronArgumentsMixin: + rlhf_type: Literal['dpo', 'kto'] = None ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -25,7 +26,28 @@ class RLHFMegatronArgumentsMixin: reference_free: bool = False label_smoothing: float = 0. f_divergence_type: str = 'reverse_kl' - loss_type: str = 'sigmoid' + loss_type: Optional[str] = None + + # kto + desirable_weight: float = 1. + undesirable_weight: float = 1. + calculate_KL: Optional[bool] = None + + def _init_kto(self): + if self.calculate_KL is None: + # Not all losses require a KL calculation + self.calculate_KL = True + if self.loss_type in ['apo_zero_unpaired']: + self.calculate_KL = False + + def __post_init__(self): + if self.rlhf_type is None: + return + default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'} + if self.loss_type is None: + self.loss_type = default_loss_type[self.rlhf_type] + if self.rlhf_type == 'kto': + self._init_kto() @dataclass @@ -403,6 +425,7 @@ def __post_init__(self): require_version('peft>=0.15') else: require_version('peft>=0.12') + RLHFMegatronArgumentsMixin.__post_init__(self) MegatronTunerMixin.__post_init__(self) os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index 304b8b58fc..74c8c29c1b 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -11,7 +11,3 @@ class MegatronRLHFArguments(MegatronTrainArguments): loss_scale: str = 'last_round' calculate_per_token_loss: bool = False - - desirable_weight: float = 1. 
- undesirable_weight: float = 1. - calculate_KL: bool = True diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 82b01d9ad3..da964950dc 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,8 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List, Optional, Union +from swift.llm.train.kto import prepare_kto_dataset from swift.utils import get_logger -from ...llm.train.kto import prepare_kto_dataset from ..argument import MegatronRLHFArguments from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer from .sft import MegatronSft diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 9bff200dcb..1ecd6cd3c0 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime -from typing import Dict +from typing import Dict, Literal import megatron.core import torch @@ -13,14 +13,13 @@ from megatron.core import mpu from megatron.core.dist_checkpointing.mapping import ShardedTensorFactory from megatron.core.enums import ModelType -from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper from megatron.core.utils import StragglerDetector -from megatron.training import (ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, +from megatron.training import (checkpointing, ft_integration, get_args, get_model, get_tensorboard_writer, get_timers, get_wandb_writer, is_last_rank, one_logger_utils, pretrain, print_rank_0, print_rank_last, training) from megatron.training.checkpointing import load_checkpoint @@ -28,14 +27,14 @@ from megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version -from transformers.utils import ContextManagers from swift.llm import dynamic_gradient_checkpointing from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model -from .utils import get_swift_datasets_provider +from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, + get_swift_datasets_provider) logger = get_logger() @@ -123,7 +122,7 @@ def new_cyclic_iter(self, iterable): yield x i += 1 - def _replace_data_iterator(self, data_iterator): + def _replace_data_iterator(self, data_iterator, model): return data_iterator @staticmethod @@ -151,7 +150,6 @@ def sh_ten_merge_fn(sub_state_dict): def _load_adapter_base_checkpoint(self, *_args, **kwargs): adapter_name = kwargs.pop('adapter_name', None) or 'ref_adapter' - from megatron.training import checkpointing sharded_state_dict = kwargs.get('sharded_state_dict') if sharded_state_dict is None: return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) @@ -180,7 +178,6 @@ def _load_adapter_base_checkpoint(self, 
*_args, **kwargs): return res def _load_base_checkpoint(self, *_args, **kwargs): - from megatron.training import checkpointing sharded_state_dict = kwargs.get('sharded_state_dict') if sharded_state_dict is None: return checkpointing.origin__load_base_checkpoint(*_args, **kwargs) @@ -224,7 +221,6 @@ def _load_base_checkpoint(self, *_args, **kwargs): @contextmanager def _patch_load_state_dict(self, load_base_checkpoint): - from megatron.training import checkpointing checkpointing.origin__load_base_checkpoint = checkpointing._load_base_checkpoint checkpointing._load_base_checkpoint = load_base_checkpoint @@ -317,15 +313,16 @@ def _initialize_embedding(model): tensor = module.weight.new_empty(num_to_initialize, module.weight.shape[1]) module.weight.data[initialize_mask] = init_method(tensor) - def _all_reduce_metric(self, metric: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def _all_reduce_metric(self, + metric: Dict[str, torch.Tensor], + reduction=torch.distributed.ReduceOp.AVG) -> Dict[str, torch.Tensor]: values = list(metric.values()) reporting_metric = values[0].new_tensor(values) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - new_data_iterator = self._replace_data_iterator(data_iterator) + new_data_iterator = self._replace_data_iterator(data_iterator, model) return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, config) @@ -374,7 +371,7 @@ def evaluate(self, # Don't care about timing during evaluation config.timers = None ft_integration.on_eval_step_start() - new_data_iterator = self._replace_data_iterator(data_iterator) + new_data_iterator = self._replace_data_iterator(data_iterator, model) loss_dicts = forward_backward_func( forward_step_func=forward_step_func, data_iterator=new_data_iterator, @@ -458,11 +455,7 @@ def evaluate(self, timers('evaluate').stop() timers.log(['evaluate']) - - total_loss_dict.update({ - k: torch.tensor([v], device='cuda') - for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics['eval'], 'eval_').items() - }) + self.custom_log(total_loss_dict, 'eval') rerun_state_machine.set_mode(rerun_mode) if is_last_rank(): logs = {} @@ -471,6 +464,13 @@ def evaluate(self, self.jsonl_writer.append(logs) return total_loss_dict, collected_non_loss_data, False + def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: + advanced_iters = total_loss_dict['advanced iterations'] if mode == 'train' else 1 + total_loss_dict.update({ + k: torch.tensor([v * advanced_iters], device='cuda') + for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics[mode]).items() + }) + # Code borrowed from NVIDIA/Megatron-LM def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): @@ -619,10 +619,7 @@ def training_log(self, loss_dict, total_loss_dict, learning_rate, decoupled_lear mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) if iteration % args.log_interval == 0 or iteration == 1: - total_loss_dict.update({ - k: torch.tensor([v 
* total_loss_dict[advanced_iters_key]], device='cuda') - for k, v in SwiftMixin.compute_custom_metrics(self.custom_metrics['train']).items() - }) + self.custom_log(total_loss_dict, 'train') origin_total_loss_dict = total_loss_dict.copy() if args.record_memory_history and is_last_rank(): @@ -802,68 +799,26 @@ def build_pretraining_data_loader(*_args, **kwargs): def forward_step(self, data_iterator, model): pass - -class MegatronRLHFTrainer(BaseMegatronTrainer): - - def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): - args = get_args() - if args.train_type == 'full': - ref_models = get_model(model_provider_func, model_type, wrap_with_ddp=False) - for m in ref_models: - m = unwrap_model(m) - m.requires_grad_(False).eval() - if args.ref_load is None: - args.ref_load = args.load - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_models, None, None, load_arg='ref_load') - self.ref_models = ref_models - return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) - - @contextmanager - def null_ref_context(self): + def _prepare_batch(self, data, vp_stage, num_samples=None): + batch = get_batch_on_this_tp_rank(data, vp_stage=vp_stage) + if num_samples is None: + num_samples = batch.pop('num_samples') args = get_args() - contexts = [] - if args.train_type == 'full': - ref_models = self.ref_models - else: - if args.ref_adapter_load is None: - for m in self.peft_models: - contexts.append(m.disable_adapter()) - ref_models = self.unwrapped_models - with ContextManagers(contexts): - if args.ref_adapter_load: - for m in self.peft_models: - m.set_adapter('ref_adapter') - yield ref_models - if args.ref_adapter_load: - for m in self.peft_models: - m.set_adapter('default') - - @staticmethod - def _forward_step_helper(model, inputs): + text_position_ids = batch.pop('text_position_ids', None) + if text_position_ids is None: + text_position_ids = batch.get('position_ids') + if args.padding_free and text_position_ids is not None: + batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) + batch['packed_seq_params'].num_samples = num_samples + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + return batch + + def get_batch(self, data_iterator, vp_stage=None): + """Generate a batch.""" args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor + data = next(data_iterator) + is_finished = data.pop('is_finished', False) + if is_finished: + args.train_iters = args.curr_iteration + 1 
+ return self._prepare_batch(data, vp_stage) diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 783dea8a6e..c067612201 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -9,8 +9,7 @@ from swift.trainers import DPOTrainer from swift.utils import get_current_device, get_logger -from .base import MegatronRLHFTrainer -from .utils import get_batch +from .rlhf_mixin import MegatronRLHFTrainer logger = get_logger() @@ -37,22 +36,6 @@ def __init__(self, args, template): self.dummy_dpo_trainer = DummyDPOTrainer(args) self.ref_models = [] - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - num_samples = packed_seq_params.num_samples - cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((num_samples * 2, )) - for i in range(num_samples * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - def loss_func(self, output_tensor: torch.Tensor, *, labels: torch.Tensor, packed_seq_params): ref_output_tensor = output_tensor[:output_tensor.shape[0] // 2].detach() output_tensor = output_tensor[output_tensor.shape[0] // 2:] @@ -99,20 +82,21 @@ def forward_step(self, data_iterator, model): # Get the batch. unwrapped_model = model.module.module input_tensor = unwrapped_model.get_input_tensor() - if input_tensor is not None: - unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:]) vp_stage = unwrapped_model.vp_stage + timers('batch-generator', log_level=2).start() + with self.stimer(bdata=True): + data = self.get_batch(data_iterator, vp_stage) + timers('batch-generator').stop() + data.pop('loss_scale', None) + # ref_model with torch.no_grad(), self.null_ref_context() as ref_models: ref_model = ref_models[vp_stage or 0] if input_tensor is not None: ref_model.set_input_tensor(input_tensor[:input_tensor.shape[0] // 2].detach()) - timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = get_batch(data_iterator, vp_stage) - timers('batch-generator').stop() - data.pop('loss_scale', None) ref_output_tensor = ref_model(**data) + if input_tensor is not None: + unwrapped_model.set_input_tensor(input_tensor[input_tensor.shape[0] // 2:]) with self.stimer: output_tensor = model(**data) return torch.concat([ref_output_tensor, output_tensor], dim=0), partial( diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index 46bd4deb72..a3d8cd2f01 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -1,253 +1,179 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from collections import namedtuple from functools import partial +from typing import Literal import torch -import torch.distributed as dist -import torch.nn.functional as F from megatron.core import mpu -from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy -from megatron.training import get_args -from megatron.training.utils import unwrap_model -from torch.distributed.nn import all_reduce +from megatron.training import get_args, get_timers +from trl import KTOTrainer from swift.utils import get_current_device, get_logger -from .base import MegatronRLHFTrainer -from .utils import get_kto_batch +from .rlhf_mixin import MegatronRLHFTrainer logger = get_logger() -class MegatronKTOTrainer(MegatronRLHFTrainer): +class DummyKTOTrainer(KTOTrainer): + # For reusing the kto_loss function in TRL. - def __init__(self, args, template): - super().__init__(args, template) + def gather_for_metrics(self, input_data, *args, **kwargs): + output_tensors = torch.empty( + mpu.get_data_parallel_world_size() * input_data.numel(), + dtype=input_data.dtype, + device=input_data.device, + ) + torch.distributed.all_gather_into_tensor(output_tensors, input_data, group=mpu.get_data_parallel_group()) + return output_tensors + + def __init__(self, args): + self.accelerator = namedtuple('Accelerator', ['device', 'gather_for_metrics'])( + device=get_current_device(), gather_for_metrics=self.gather_for_metrics) + self.loss_type = args.loss_type self.beta = args.beta self.desirable_weight = args.desirable_weight self.undesirable_weight = args.undesirable_weight self.calculate_KL = args.calculate_KL - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params=None): - args = get_args() - if output_tensor is None: - return None - - shifted_logits = output_tensor[:, :-1, :].contiguous() - shifted_labels = labels[:, 1:].contiguous() - - logits_for_loss = shifted_logits.transpose(0, 1).contiguous() - labels_for_loss = shifted_labels.transpose(0, 1).contiguous() - - per_token_cross_entropy_loss = vocab_parallel_cross_entropy( - logits_for_loss, labels_for_loss, label_smoothing=0.0) - - per_token_logps = -per_token_cross_entropy_loss - loss_mask = (labels_for_loss != -100) - masked_logps = per_token_logps * loss_mask - - if args.padding_free and packed_seq_params is not None: - flattened_logps = masked_logps.squeeze(1) # [seq-1] - - cu_seqlens = packed_seq_params.cu_seqlens_q - num_sequences = cu_seqlens.shape[0] - 1 - all_logps = flattened_logps.new_zeros((num_sequences, )) - for i in range(num_sequences): - start_index, end_index = cu_seqlens[i], cu_seqlens[i + 1] - 1 - if end_index > start_index: - all_logps[i] = flattened_logps[start_index:end_index].sum() - else: - all_logps = masked_logps.sum(dim=0) - - if args.context_parallel_size > 1: - all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) - - return all_logps - @staticmethod - def kto_loss(policy_chosen_logps, policy_rejected_logps, policy_KL_logps, reference_chosen_logps, - reference_rejected_logps, reference_KL_logps, beta, desirable_weight, undesirable_weight, calculate_KL, - device): - if calculate_KL and policy_KL_logps is not None and reference_KL_logps is not None: - kl = (policy_KL_logps - reference_KL_logps).mean().detach() - dist.all_reduce(kl, group=mpu.get_data_parallel_group()) - kl = kl / mpu.get_data_parallel_world_size() - kl = kl.clamp(min=0) - else: - kl = torch.tensor(0.0, device=device) - - chosen_rewards = torch.tensor([], device=kl.device) - if policy_chosen_logps.shape[0] > 0: - 
chosen_logratios = policy_chosen_logps - reference_chosen_logps - chosen_losses = 1 - F.sigmoid(beta * (chosen_logratios - kl)) - chosen_rewards = beta * chosen_logratios.detach() - else: - chosen_losses = torch.tensor([], device=kl.device) +class MegatronKTOTrainer(MegatronRLHFTrainer): - rejected_rewards = torch.tensor([], device=kl.device) - if policy_rejected_logps.shape[0] > 0: - rejected_logratios = policy_rejected_logps - reference_rejected_logps - rejected_losses = 1 - F.sigmoid(beta * (kl - rejected_logratios)) - rejected_rewards = beta * rejected_logratios.detach() + def __init__(self, args, template): + super().__init__(args, template) + assert args.padding_free, 'Currently `rlhf_type="kto"` only supports padding_free.' + self.dummy_kto_trainer = DummyKTOTrainer(args) + + def _kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool, length: int): + labels = data['labels'] + packed_seq_params = data['packed_seq_params'] + output = self._get_input_tensor(output_tensor, is_KL, is_ref, length, dim=1) + return self.get_logps(output, labels, packed_seq_params, packed_seq_params.num_samples) + + def loss_func(self, output_tensor, *, data, kl_data, label): + length = data['packed_seq_params'].cu_seqlens_q[-1] + policy_logps = self._kto_get_logps(output_tensor, data, False, False, length) + ref_logps = self._kto_get_logps(output_tensor, data, False, True, length) + if self.args.calculate_KL: + policy_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, False, length) + ref_KL_logps = self._kto_get_logps(output_tensor, kl_data, True, True, length) else: - rejected_losses = torch.tensor([], device=kl.device) - - losses = torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0) - return losses, chosen_rewards, rejected_rewards, kl - - def loss_func(self, output_tensor, *, policy_KL_logps, reference_logps, reference_KL_logps, labels, all_labels, - packed_seq_params): - policy_logps = self.get_logps(output_tensor, labels, packed_seq_params) - is_desirable = all_labels.bool() - - policy_chosen_logps = policy_logps[is_desirable] - policy_rejected_logps = policy_logps[~is_desirable] - reference_chosen_logps = reference_logps[is_desirable] - reference_rejected_logps = reference_logps[~is_desirable] - - loss, chosen_rewards, rejected_rewards, kl = self.kto_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - policy_KL_logps=policy_KL_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - reference_KL_logps=reference_KL_logps, - beta=self.beta, - desirable_weight=self.desirable_weight, - undesirable_weight=self.undesirable_weight, - calculate_KL=self.calculate_KL, - device=policy_logps.device, + policy_KL_logps, ref_KL_logps = None, None + label = output_tensor.new_tensor(label, dtype=torch.bool) + policy_chosen_logps = policy_logps[label] + policy_rejected_logps = policy_logps[~label] + ref_chosen_logps = ref_logps[label] + ref_rejected_logps = ref_logps[~label] + + loss, chosen_rewards, rejected_rewards, kl = self.dummy_kto_trainer.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + ref_chosen_logps, + ref_rejected_logps, + ref_KL_logps, ) - loss = loss.mean() if loss.numel() > 0 else torch.tensor(0.0, device=policy_logps.device) - - with torch.no_grad(): - chosen_rewards_mean = chosen_rewards.mean() if chosen_rewards.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - rejected_rewards_mean = 
rejected_rewards.mean() if rejected_rewards.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - policy_chosen_logps_mean = policy_chosen_logps.mean() if policy_chosen_logps.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - policy_rejected_logps_mean = policy_rejected_logps.mean( - ) if policy_rejected_logps.numel() > 0 else torch.tensor( - 0.0, device=loss.device) - - metric = { - 'loss': loss.clone().detach(), - 'logps/chosen': policy_chosen_logps_mean, - 'logps/rejected': policy_rejected_logps_mean, - 'rewards/chosen': chosen_rewards_mean, - 'rewards/rejected': rejected_rewards_mean, - 'rewards/margins': chosen_rewards_mean - rejected_rewards_mean, - 'kl': kl.detach() if kl is not None else torch.tensor(0.0, device=loss.device), + loss = loss.mean() + mean_metric = { + 'loss': loss.detach().clone(), + 'kl': kl.detach(), } - - reporting_metric = loss.new_tensor(list(metric.values())) - torch.distributed.all_reduce( - reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) - reporting_metric = {k: reporting_metric[i] for i, k in enumerate(metric.keys())} + metric = self._all_reduce_metric(mean_metric) + sum_metric = { + 'logps/chosen_sum': policy_chosen_logps.nansum(), + 'logps/rejected_sum': policy_rejected_logps.nansum(), + 'rewards/chosen_sum': chosen_rewards.nansum(), + 'rewards/rejected_sum': rejected_rewards.nansum(), + 'count/chosen': loss.new_tensor(chosen_rewards.shape[0]), + 'count/rejected': loss.new_tensor(rejected_rewards.shape[0]), + } + metric.update(self._all_reduce_metric(sum_metric, torch.distributed.ReduceOp.SUM)) # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 loss = loss / mpu.get_context_parallel_world_size() - return loss, reporting_metric - - def _replace_data_iterator_with_model(self, data_iterator, model): - args = get_args() - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - - processed_data_list = [] - policy_model = unwrap_model(model)[0] - - for _ in range(num_iters_per_step): - with torch.no_grad(), self.null_ref_context() as ref_models: - assert len(ref_models) == 1, 'KTO currently does not support VPP.' - data = self.ref_forward(ref_models[0], data_iterator) + return loss, metric - if self.calculate_KL: - with torch.no_grad(): - kl_inputs = { - 'input_ids': data.get('KL_completion_input_ids'), - 'attention_mask': data.get('KL_completion_attention_mask'), - 'position_ids': data.get('KL_completion_position_ids'), - } - - kl_output_tensor = self._forward_step_helper(policy_model, kl_inputs) + @staticmethod + def _get_input_tensor(input_tensor, is_KL: bool, is_ref: bool, length: int, dim: int): + # policy, ref, policy_KL, ref_KL + total_length = input_tensor.shape[dim] + KL_length = (total_length - 2 * length) // 2 + slice_list = [0, length, 2 * length, total_length - KL_length, total_length] + idx = is_KL * 2 + is_ref + slice_ = (slice(None), ) * dim + (slice(slice_list[idx], slice_list[idx + 1]), ) + res = input_tensor[slice_] + if is_KL or is_ref: + res = res.detach() + return res - policy_KL_logps = self.get_logps(kl_output_tensor, data['KL_completion_labels'], - data.get('KL_completion_packed_seq_params')) - data['policy_KL_logps'] = policy_KL_logps + def forward_step(self, data_iterator, model): + timers = get_timers() + # Get the batch. 
+ unwrapped_model = model.module.module + input_tensor = unwrapped_model.get_input_tensor() + vp_stage = unwrapped_model.vp_stage + timers('batch-generator', log_level=2).start() + with self.stimer(bdata=True): + # not support loss_scale + data, kl_data = self.get_batch(data_iterator, vp_stage) + timers('batch-generator').stop() + label = data.pop('label') + data.pop('loss_scale', None) + kl_data.pop('loss_scale', None) - processed_data_list.append(data) + length = data['packed_seq_params'].cu_seqlens_q[-1] - return iter(processed_data_list) + with torch.no_grad(), self.null_ref_context() as ref_models: + ref_model = ref_models[vp_stage or 0] + if self.args.calculate_KL: + if input_tensor is not None: + ref_model.set_input_tensor(self._get_input_tensor(input_tensor, True, True, length, 0)) + ref_KL_output_tensor = ref_model(**kl_data) - def ref_forward(self, ref_model, data_iterator): - with self.stimer(bdata=True): - data = get_kto_batch(data_iterator) - data.pop('loss_scale', None) + if input_tensor is not None: + ref_model.set_input_tensor(self._get_input_tensor(input_tensor, False, True, length, 0)) + ref_output_tensor = ref_model(**data) - ref_inputs = { - 'input_ids': data.get('completion_input_ids'), - 'attention_mask': data.get('completion_attention_mask'), - 'position_ids': data.get('completion_position_ids'), - } - with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, ref_inputs) - data['reference_logps'] = self.get_logps(output_tensor, data['completion_labels'], - data.get('completion_packed_seq_params')) - - if self.calculate_KL: - kl_inputs = { - 'input_ids': data.get('KL_completion_input_ids'), - 'attention_mask': data.get('KL_completion_attention_mask'), - 'position_ids': data.get('KL_completion_position_ids'), - } + if self.args.calculate_KL: with torch.no_grad(): - kl_output_tensor = self._forward_step_helper(ref_model, kl_inputs) - data['reference_KL_logps'] = self.get_logps(kl_output_tensor, data['KL_completion_labels'], - data.get('KL_completion_packed_seq_params')) + if input_tensor is not None: + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, True, False, length, 0)) + KL_output_tensor = model(**kl_data) + + if input_tensor is not None: + unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, 0)) + with self.stimer: + output_tensor = model(**data) + dim = 1 if mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else 0 + if self.args.calculate_KL: + res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=dim) else: - data['reference_KL_logps'] = None - return data - - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - new_data_iterator = self._replace_data_iterator_with_model(data_iterator, model) - return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler, - config) - - def forward_step(self, data_iterator, model): - data = next(data_iterator) - - reference_logps = data.pop('reference_logps') - reference_KL_logps = data.pop('reference_KL_logps', None) - policy_KL_logps = data.pop('policy_KL_logps', None) - all_labels = torch.tensor(data.pop('label')).to(get_current_device()) - completion_packed_seq_params = data.get('completion_packed_seq_params') - - main_inputs = { - 'input_ids': data['completion_input_ids'], - 'attention_mask': data.get('completion_attention_mask'), - 'position_ids': data.get('completion_position_ids') - } - 
with self.stimer(): - output_tensor = model(**main_inputs) - - return output_tensor, partial( - self.loss_func, - policy_KL_logps=policy_KL_logps, - reference_logps=reference_logps, - reference_KL_logps=reference_KL_logps, - labels=data['completion_labels'], - all_labels=all_labels, - packed_seq_params=completion_packed_seq_params) - - def evaluate(self, - forward_step_func, - data_iterator, - model, - process_non_loss_data_func, - config, - verbose=False, - non_loss_data_func=None): - self._replace_data_iterator = partial(self._replace_data_iterator_with_model, model=model) - return super().evaluate(forward_step_func, data_iterator, model, process_non_loss_data_func, config, verbose, - non_loss_data_func) + res = torch.concat([output_tensor, ref_output_tensor], dim=dim) + return res, partial(self.loss_func, data=data, kl_data=kl_data, label=label) + + def _prepare_batch(self, data, vp_stage): + res = [] + num_samples = data.pop('num_samples') + for key in ['completion_', 'KL_completion_']: + _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)} + res.append(super()._prepare_batch(_data, vp_stage, num_samples)) + res[0]['label'] = data['label'] + return res + + def custom_log(self, total_loss_dict, mode: Literal['train', 'eval']) -> None: + super().custom_log(total_loss_dict, mode) + res = {} + for k, v in total_loss_dict.items(): + if k.startswith('count/') or k.endswith('_sum'): + continue + res[k] = v + for key in ['chosen', 'rejected']: + count = total_loss_dict.get(f'count/{key}') + if count is None or count.item() == 0: + continue + res[f'logps/{key}'] = total_loss_dict[f'logps/{key}_sum'] / count + res[f'rewards/{key}'] = total_loss_dict[f'rewards/{key}_sum'] / count + if 'rewards/chosen' in res and 'rewards/rejected' in res: + res['rewards/margins'] = res['rewards/chosen'] - res['rewards/rejected'] + total_loss_dict.clear() + total_loss_dict.update(res) diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py new file mode 100644 index 0000000000..ead111435e --- /dev/null +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -0,0 +1,99 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from contextlib import contextmanager + +import torch +import torch.nn +from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.training import get_args, get_model +from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model +from torch.distributed.nn import all_reduce +from transformers.utils import ContextManagers + +from swift.utils import get_logger +from .base import BaseMegatronTrainer + +logger = get_logger() + + +class MegatronRLHFTrainer(BaseMegatronTrainer): + + def setup_model_and_optimizer(self, model_provider_func, model_type, *_args, **kwargs): + args = get_args() + if args.train_type == 'full': + ref_models = get_model(model_provider_func, model_type, wrap_with_ddp=False) + for m in ref_models: + m = unwrap_model(m) + m.requires_grad_(False).eval() + if args.ref_load is None: + args.ref_load = args.load + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_models, None, None, load_arg='ref_load') + self.ref_models = ref_models + return super().setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) + + @contextmanager + def null_ref_context(self): + args = get_args() + contexts = [] + if args.train_type == 'full': + ref_models = self.ref_models + else: + if args.ref_adapter_load is None: + for m in self.peft_models: + contexts.append(m.disable_adapter()) + ref_models = self.unwrapped_models + with ContextManagers(contexts): + if args.ref_adapter_load: + for m in self.peft_models: + m.set_adapter('ref_adapter') + yield ref_models + if args.ref_adapter_load: + for m in self.peft_models: + m.set_adapter('default') + + @staticmethod + def _forward_step_helper(model, inputs): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + + def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + if num_samples is None: + num_samples = packed_seq_params.num_samples * 2 + cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size + all_logps = per_token_logps.new_zeros((num_samples, )) + for i in range(num_samples): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) + 
return all_logps diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 7e87f63eb7..98422b8c43 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -12,7 +12,6 @@ from swift.utils import get_logger from .base import BaseMegatronTrainer -from .utils import get_batch logger = get_logger() @@ -139,7 +138,7 @@ def forward_step(self, data_iterator, model): vp_stage = model.module.module.vp_stage timers('batch-generator', log_level=2).start() with self.stimer(bdata=True): - data = get_batch(data_iterator, vp_stage) + data = self.get_batch(data_iterator, vp_stage) timers('batch-generator').stop() loss_scale = data.pop('loss_scale', None) channels = data.pop('channel', None) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 0a301e479d..35dd538f0d 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -27,29 +27,21 @@ def swift_datasets_provider(train_val_test_num_samples): # Code borrowed from NVIDIA/Megatron-LM -def get_batch_on_this_tp_rank(data_iterator, vp_stage=None): +def get_batch_on_this_tp_rank(data, vp_stage=None): args = get_args() - data = next(data_iterator) - is_finished = data.pop('is_finished', False) if args.task_type == 'causal_lm': data['labels'] = torch.roll(data['labels'], -1, dims=-1) if 'loss_scale' in data: data['loss_scale'] = torch.roll(data['loss_scale'], -1, dims=-1) batch = to_device(data, 'cuda', non_blocking=True) if args.pipeline_model_parallel_size == 1: - pass - elif mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + return batch + if not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + batch['input_ids'] = None + if not mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): batch['labels'] = None batch['loss_scale'] = None - elif mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): - batch['input_ids'] = None - else: - for key in ('input_ids', 'labels', 'loss_scale'): - batch[key] = None - - if is_finished: - args.train_iters = args.curr_iteration + 1 return batch @@ -115,83 +107,3 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1) return batch - - -def get_batch(data_iterator, vp_stage=None): - """Generate a batch.""" - # get batches based on the TP rank you are on - batch = get_batch_on_this_tp_rank(data_iterator, vp_stage=vp_stage) - args = get_args() - num_samples = batch.pop('num_samples') - text_position_ids = batch.pop('text_position_ids', None) - if text_position_ids is None: - text_position_ids = batch.get('position_ids') - if args.padding_free and text_position_ids is not None: - batch['packed_seq_params'] = get_packed_seq_params(text_position_ids) - batch['packed_seq_params'].num_samples = num_samples - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) - return batch - - -def get_kto_batch(data_iterator): - """Generate a kto batch.""" - args = get_args() - - data = next(data_iterator) - is_finished = data.pop('is_finished', False) - - batch = to_device(data, 'cuda', non_blocking=True) - - kto_tensor_keys = [ - 'completion_input_ids', 'completion_labels', 'completion_attention_mask', 'completion_position_ids', - 'KL_completion_input_ids', 'KL_completion_labels', 'KL_completion_attention_mask', 'KL_completion_position_ids' - ] - - # pp - if args.pipeline_model_parallel_size == 1: - pass - elif 
mpu.is_pipeline_first_stage(): - for key in kto_tensor_keys: - if 'labels' in key: - batch[key] = None - elif mpu.is_pipeline_last_stage(): - for key in kto_tensor_keys: - if 'input_ids' in key: - batch[key] = None - else: - for key in kto_tensor_keys: - batch[key] = None - - # Padding-Free - num_samples = batch.get('num_samples') - if args.padding_free: - if 'completion_position_ids' in batch and batch['completion_position_ids'] is not None: - batch['completion_packed_seq_params'] = get_packed_seq_params(batch['completion_position_ids']) - if num_samples is not None: - batch['completion_packed_seq_params'].num_samples = num_samples - - if 'KL_completion_position_ids' in batch and batch['KL_completion_position_ids'] is not None: - batch['KL_completion_packed_seq_params'] = get_packed_seq_params(batch['KL_completion_position_ids']) - if num_samples is not None: - batch['KL_completion_packed_seq_params'].num_samples = num_samples - - # cp - cp_size = mpu.get_context_parallel_world_size() - if cp_size > 1: - completion_psp = batch.get('completion_packed_seq_params') - kl_psp = batch.get('KL_completion_packed_seq_params') - - if completion_psp is None and kl_psp is None: - batch = mcore_get_batch_on_this_cp_rank(batch) - else: - for key, val in batch.items(): - if key in kto_tensor_keys and val is not None: - if key.startswith('KL_completion_') and kl_psp is not None: - batch[key] = split_cp_inputs(val, kl_psp.cu_seqlens_q, -1) - elif key.startswith('completion_') and completion_psp is not None: - batch[key] = split_cp_inputs(val, completion_psp.cu_seqlens_q, -1) - - if is_finished: - args.train_iters = args.curr_iteration + 1 - return batch diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainer/kto_trainer.py index 9d93abe5bf..f24529aa68 100644 --- a/swift/trainers/rlhf_trainer/kto_trainer.py +++ b/swift/trainers/rlhf_trainer/kto_trainer.py @@ -1,6 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch import torch.nn as nn