From ce344eded52bb9e9432db606791ac1471f3e4489 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 10:21:38 +0800 Subject: [PATCH 1/6] fix non-padding_free qwen3_vl --- swift/megatron/model/mm_gpt_model.py | 2 +- swift/megatron/trainers/utils.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 3cceddb1e4..391904a502 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -67,7 +67,7 @@ def forward(_self, input_): kwargs.update(res) res = inputs_embeds if args.context_parallel_size > 1: - res = split_cp_inputs(res, packed_seq_params.cu_seqlens_q, 1) + res = split_cp_inputs(res, None if packed_seq_params is None else packed_seq_params.cu_seqlens_q, 1) if reduce_scatter_embeddings: res = res.transpose(0, 1).contiguous() group_kwargs = {'group': _self.tp_group} if mcore_013 else {} diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 3c552578fb..697d554995 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -86,7 +86,7 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams: qkv_format='thd') -def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int): +def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int): # TODO: compat bshd if dim < 0: dim = (dim + inputs.ndim) % inputs.ndim @@ -127,15 +127,14 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): keys.append('input_ids') packed_seq_params = batch.get('packed_seq_params') - if packed_seq_params is None: - return mcore_get_batch_on_this_cp_rank(batch) for key, val in batch.items(): if key not in keys: continue if args.task_type == 'seq_cls' and key == 'labels': continue if val is not None: - batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1) + batch[key] = split_cp_inputs(val, None if packed_seq_params is None else packed_seq_params.cu_seqlens_q, + -1) return batch From 7c0cf2a25e307dfe37e48269002a685635f3da5d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 10:30:02 +0800 Subject: [PATCH 2/6] update --- swift/megatron/trainers/utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 697d554995..f29738ecd0 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -13,7 +13,6 @@ from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.optimizer import ChainedOptimizer from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank from megatron.training import get_args, get_wandb_writer from packaging import version @@ -87,16 +86,18 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams: def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int): - # TODO: compat bshd if dim < 0: dim = (dim + inputs.ndim) % inputs.ndim new_inputs = [] cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() - for i in range(cu_seqlens.shape[0] - 1): - slices = [slice(None)] * inputs.ndim - slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) - val = inputs[tuple(slices)] + for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)): + if cu_seqlens is None: + val = inputs + else: + slices = [slice(None)] * inputs.ndim + slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) + val = inputs[tuple(slices)] view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', From da6f8acb453321c383fc47fda653254dfb48bf76 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 10:49:15 +0800 Subject: [PATCH 3/6] update --- swift/megatron/init.py | 4 ---- swift/megatron/model/mm_gpt/qwen3_vl.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index f38a20fcb2..b1c5ce9843 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -729,10 +729,6 @@ def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = Fal # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group()) return emb MultimodalRotaryEmbedding.forward = forward diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index 52e8e8d7f9..89c825e5a7 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -122,12 +122,12 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config): # compat cp args = get_args() if args.context_parallel_size > 1: - assert packed_seq_params is not None device = visual_pos_masks.device cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device) cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device) - cp_mask = split_cp_inputs(cp_mask, packed_seq_params.cu_seqlens_q, 0) - visual_pos_masks = split_cp_inputs(visual_pos_masks, packed_seq_params.cu_seqlens_q, 0) + cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q + cp_mask = split_cp_inputs(cp_mask, cu_seqlens, 0) + visual_pos_masks = split_cp_inputs(visual_pos_masks, cu_seqlens, 0) deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]] # compat sp tp_world_size = parallel_state.get_tensor_model_parallel_world_size() From 72cf9a2055bec60a2ecbfe6bd407a4cb3eb28a90 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 10:51:14 +0800 Subject: [PATCH 4/6] update --- swift/megatron/init.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index b1c5ce9843..5857f77c6e 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -670,10 +670,8 @@ def _write_item(self, *args, **kwargs): def _patch_mrope(): from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding - from megatron.core import parallel_state import megatron.core - from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank, - _apply_rotary_pos_emb_bshd) + from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd from megatron.core.models.common.embeddings import rope_utils from megatron.training import get_args From 27218dad39d54078b28765247296ace22164e7f5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 11:03:10 +0800 Subject: [PATCH 5/6] update --- swift/megatron/model/mm_gpt/qwen3_vl.py | 2 +- swift/megatron/model/mm_gpt_model.py | 2 +- swift/megatron/trainers/utils.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/swift/megatron/model/mm_gpt/qwen3_vl.py b/swift/megatron/model/mm_gpt/qwen3_vl.py index 89c825e5a7..99145fead4 100644 --- a/swift/megatron/model/mm_gpt/qwen3_vl.py +++ b/swift/megatron/model/mm_gpt/qwen3_vl.py @@ -125,7 +125,7 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config): device = visual_pos_masks.device cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device) cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device) - cu_seqlens = None if packed_seq_params is None else packed_seq_params.cu_seqlens_q + cu_seqlens = getattr(packed_seq_params, 'cu_seqlens_q', None) cp_mask = split_cp_inputs(cp_mask, cu_seqlens, 0) visual_pos_masks = split_cp_inputs(visual_pos_masks, cu_seqlens, 0) deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]] diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 391904a502..83c8dd3ade 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -67,7 +67,7 @@ def forward(_self, input_): kwargs.update(res) res = inputs_embeds if args.context_parallel_size > 1: - res = split_cp_inputs(res, None if packed_seq_params is None else packed_seq_params.cu_seqlens_q, 1) + res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) if reduce_scatter_embeddings: res = res.transpose(0, 1).contiguous() group_kwargs = {'group': _self.tp_group} if mcore_013 else {} diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index f29738ecd0..0a05f64194 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -134,8 +134,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): if args.task_type == 'seq_cls' and key == 'labels': continue if val is not None: - batch[key] = split_cp_inputs(val, None if packed_seq_params is None else packed_seq_params.cu_seqlens_q, - -1) + batch[key] = split_cp_inputs(val, getattr(packed_seq_params, 'cu_seqlens_q', None), -1) return batch From 50a532a927abf8a37a097ba2aa18b9b51826a76f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 29 Dec 2025 11:07:11 +0800 Subject: [PATCH 6/6] update --- docs/source/Instruction/Command-line-parameters.md | 2 +- docs/source/Megatron-SWIFT/Command-line-parameters.md | 2 +- docs/source_en/Instruction/Command-line-parameters.md | 2 +- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index a45fefcc91..30c3bfc6ee 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -466,7 +466,7 @@ Vera使用`target_modules`、`target_regex`、`modules_to_save`三个参数, - add_version: 在output_dir上额外增加目录`'<版本号>-<时间戳>'`防止权重覆盖,默认为True。 - check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。**如果是断网环境,请设置为False**。 - 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。 -- 🔥packing: 将不同长度的数据样本打包成**近似**统一长度的样本(packing能保证不对完整的序列进行切分),实现训练时各节点与进程的负载均衡(避免长文本拖慢短文本的训练速度),从而提高GPU利用率,保持显存占用稳定。当使用 `--attn_impl flash_attn` 时,可确保packed样本内的不同序列之间相互独立,互不可见。该参数默认为`False`,目前支持 CPT/SFT/DPO/KTO/GKD。注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 +- 🔥packing: 使用`padding_free`的方式将不同长度的数据样本打包成**近似**统一长度的样本(packing能保证不对完整的序列进行切分),实现训练时各节点与进程的负载均衡(避免长文本拖慢短文本的训练速度),从而提高GPU利用率,保持显存占用稳定。当使用 `--attn_impl flash_attn` 时,可确保packed样本内的不同序列之间相互独立,互不可见。该参数默认为`False`,目前支持 CPT/SFT/DPO/KTO/GKD。注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 - "ms-swift>=3.12"新支持了embedding/reranker/seq_cls任务的packing。 - packing_length: packing的长度。默认为None,设置为max_length。 - packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效)。通常不需要修改该值,packing速度远快于tokenize速度。 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 64b3be63de..dd35923dff 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -300,7 +300,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - 提示:在日志中打印的"learning rate"为llm的学习率。 - aligner_lr: 当训练多模态大模型时,该参数指定aligner的学习率,默认为None,等于learning_rate。 - gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。该参数只对`vit_gradient_checkpointing`生效。 -- 🔥packing: 将不同长度的数据样本打包成**近似**统一长度的样本(packing能保证不对完整的序列进行切分),实现训练时各节点与进程的负载均衡(避免长文本拖慢短文本的训练速度),从而提高GPU利用率,保持显存占用稳定。当使用 `--attention_backend flash` 时,可确保packed样本内的不同序列之间相互独立,互不可见(除Qwen3-Next,因为含有linear-attention)。该参数默认为`False`。Megatron-SWIFT的所有训练任务都支持该参数。注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 +- 🔥packing: 使用`padding_free`的方式将不同长度的数据样本打包成**近似**统一长度的样本(packing能保证不对完整的序列进行切分),实现训练时各节点与进程的负载均衡(避免长文本拖慢短文本的训练速度),从而提高GPU利用率,保持显存占用稳定。当使用 `--attention_backend flash` 时,可确保packed样本内的不同序列之间相互独立,互不可见(除Qwen3-Next,因为含有linear-attention)。该参数默认为`False`。Megatron-SWIFT的所有训练任务都支持该参数。注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 - packing_length: packing的长度。默认为None,设置为max_length。 - packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效)。通常不需要修改该值,packing速度远快于tokenize速度。 - streaming: 流式读取并处理数据集,默认False。(流式数据集的随机并不彻底,可能导致loss波动剧烈。) diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index c251bbfb92..7afebd4a2c 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -476,7 +476,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - add_version: Add directory to output_dir with `'-'` to prevent weight overwrite, default is True. - check_model: Check local model files for corruption or modification and give a prompt, default is True. **If in an offline environment, please set to False.** - 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively. -- 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**. +- 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attn_impl flash_attn`, it ensures that different sequences within packed samples remain independent and invisible to each other. This parameter defaults to `False` and currently supports CPT/SFT/DPO/KTO/GKD. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**. - "ms-swift>=3.12" has newly added support for packing in embedding/reranker/seq_cls tasks. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index ec8a06672f..6180520402 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -319,7 +319,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - Note: The "learning rate" printed in the logs is the learning rate of the LLM. - aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`. - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled. -- 🔥packing: Packs data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**. +- 🔥packing: Use the `padding_free` method to pack data samples of different lengths into samples of **approximately** uniform length (packing ensures that complete sequences are not split), achieving load balancing across nodes and processes during training (preventing long texts from slowing down short text training), thereby improving GPU utilization and maintaining stable memory usage. When using `--attention_backend flash`, it ensures that different sequences within packed samples remain independent and invisible to each other (except for Qwen3-Next, which contains linear-attention). This parameter defaults to `False`. All training tasks in Megatron-SWIFT support this parameter. Note: **packing will reduce the number of dataset samples, please adjust gradient accumulation steps and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing). Usually there is no need to modify this value, as packing speed is much faster than tokenization speed. - streaming: Stream data loading and processing, default is False. (The shuffling of streaming datasets is not thorough, which may lead to severe loss fluctuations.)