diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md
index 1a81cc968b..0b90de50ed 100644
--- a/docs/source/Instruction/Command-line-parameters.md
+++ b/docs/source/Instruction/Command-line-parameters.md
@@ -91,7 +91,8 @@
 - 注意:数据集中的system**优先级**最高,然后是`--system`,最后是注册template时设置的`default_system`。
 - 🔥max_length: 限制单数据集样本经过`tokenizer.encode`后的tokens最大长度,超过的数据样本会根据`truncation_strategy`参数进行处理(避免训练OOM)。默认为None,即设置为模型支持的tokens最大长度(max_model_len)。
   - 当PPO、GRPO和推理情况下,`max_length`代表`max_prompt_length`。
-- truncation_strategy: 如果单样本的tokens超过`max_length`如何处理,支持`delete`、`left`和`right`,代表删除、左侧裁剪和右侧裁剪,默认为'delete'。
+- truncation_strategy: 如果单样本的tokens超过`max_length`如何处理,支持'delete'、'left'、'right'和'split',代表删除、左侧裁剪、右侧裁剪和切成多条数据样本,默认为'delete'。
+  - 注意:`--truncation_strategy split`只支持在预训练时使用,即`swift/megatron pt`场景,且需要"ms-swift>=3.11"。该策略会将超长样本切成多条数据样本,从而避免tokens浪费。(该特性不兼容cached_dataset)
   - 注意:若多模态模型的训练时将'truncation_strategy'设置为`left`或`right`,**ms-swift会保留所有的image_token等多模态tokens**,这可能会导致训练时OOM。
 - 🔥max_pixels: 多模态模型输入图片的最大像素数(H\*W),将超过该限制的图像进行缩放(避免训练OOM)。默认为None,不限制最大像素数。
   - 注意:该参数适用于所有的多模态模型。而Qwen2.5-VL特有的模型参数`MAX_PIXELS`(你可以在文档最下面找到)只针对Qwen2.5-VL模型。
@@ -698,7 +699,6 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数)
 - exist_ok: 如果output_dir存在,不抛出异常,进行覆盖。默认为False。
 - 🔥quant_method: 可选为'gptq'、'awq'、'bnb'和'fp8',默认为None。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/quantize)。
 - quant_n_samples: gptq/awq的校验集采样数,默认为256。
-- max_length: 校准集的max_length, 默认值2048。
 - quant_batch_size: 量化batch_size,默认为1。
 - group_size: 量化group大小,默认为128。
 - to_cached_dataset: 提前对数据集进行tokenize并导出,默认为False。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。更多介绍请查看`cached_dataset`。
diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md
index 7c977cb923..d575225d1f 100644
--- a/docs/source/Megatron-SWIFT/Quick-start.md
+++ b/docs/source/Megatron-SWIFT/Quick-start.md
@@ -82,6 +82,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 首先,我们需要将HF格式的权重转为Megatron格式:
 - 多卡权重转换:将`CUDA_VISIBLE_DEVICES=0`删除即可使用多卡权重转换。
 - 转换精度测试:`--test_convert_precision true`将测试转换精度。在MoE大型模型的转换时,该参数所需时间较长,且需要更多的内存消耗,可酌情去除。
+- ms-swift支持了Mcore-Bridge来避免权重转换的额外耗时,请参考[Mcore-Bridge文档](./Mcore-Bridge.md)。
 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 48d9d6e821..ccc65a2581 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -91,7 +91,8 @@ The command-line arguments will be introduced in four categories: basic argument
 - Note: In terms of priority, the `system` field from the dataset takes precedence, followed by `--system`, and finally the `default_system` set in the registered template.
 - 🔥max_length: Maximum token length after `tokenizer.encode` for a single data sample (to prevent OOM during training). Samples exceeding this limit are handled according to `truncation_strategy`. Default is `None`, meaning it's set to the model’s maximum supported sequence length (`max_model_len`).
   - In PPO, GRPO, and inference scenarios, `max_length` refers to `max_prompt_length`.
-- truncation_strategy: How to handle samples exceeding `max_length`. Options: `'delete'`, `'left'`, `'right'`, representing deletion, left-truncation, and right-truncation respectively. Default is `'delete'`.
+- truncation_strategy: How to handle samples whose tokens exceed `max_length`. Supports 'delete', 'left', 'right', and 'split', which represent deletion, left truncation, right truncation, and splitting into multiple data samples, respectively. The default is 'delete'.
+  - Note: `--truncation_strategy split` is only supported during pretraining, i.e., in `swift/megatron pt` scenarios, and requires "ms-swift>=3.11". This strategy splits over-length samples into multiple data samples to avoid wasting tokens. (This feature is not compatible with cached_dataset.)
   - Note: For multimodal models, if `truncation_strategy` is set to `'left'` or `'right'` during training, **ms-swift preserves all image tokens and other modality-specific tokens**, which may lead to OOM.
 - 🔥max_pixels: Maximum pixel count (H×W) for input images in multimodal models. Images exceeding this limit will be resized to avoid OOM during training. Default is `None` (no restriction).
   - Note: This parameter applies to all multimodal models. The Qwen2.5-VL specific parameter `MAX_PIXELS` (see bottom of doc) only affects Qwen2.5-VL.
@@ -716,7 +717,6 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum
 - exist_ok: If output_dir exists, do not raise an exception and overwrite the contents. The default value is False.
 - 🔥quant_method: Options are 'gptq', 'awq', 'bnb' or 'fp8', with the default being None. Examples can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/quantize).
 - quant_n_samples: The number of samples for the validation set used by gptq/awq, with a default of 256.
-- max_length: Max length for the calibration set, default value is 2048.
 - quant_batch_size: Quantization batch size, default is 1.
 - group_size: Group size for quantization, default is 128.
 - to_cached_dataset: pre-tokenize the dataset and export it in advance, default is False. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). For more information, please refer to cached_dataset.
diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md
index 3e4f390680..1161d76449 100644
--- a/docs/source_en/Megatron-SWIFT/Quick-start.md
+++ b/docs/source_en/Megatron-SWIFT/Quick-start.md
@@ -82,6 +82,7 @@ This section introduces a quick start example for fine-tuning the self-awareness
 First, we need to convert the weights from HF (Hugging Face) format to Megatron format:
 - Multi-GPU weight conversion: Remove `CUDA_VISIBLE_DEVICES=0` to enable multi-GPU weight conversion.
 - Conversion precision test: `--test_convert_precision true` will test the conversion precision. For large MoE model conversions, this option takes longer and consumes more memory, so you may omit it as needed.
+- ms-swift supports Mcore-Bridge to avoid the extra time cost of weight conversion. Please refer to the [Mcore-Bridge documentation](./Mcore-Bridge.md).
 ```shell
 CUDA_VISIBLE_DEVICES=0 \
 swift export \
diff --git a/examples/export/quantize/reward_model/gptq.sh b/examples/export/quantize/reward_model/gptq.sh
index 5bfa3023a8..41cac64e7a 100644
--- a/examples/export/quantize/reward_model/gptq.sh
+++ b/examples/export/quantize/reward_model/gptq.sh
@@ -3,6 +3,7 @@ CUDA_VISIBLE_DEVICES=0 swift export \
     --model Shanghai_AI_Laboratory/internlm2-1_8b-reward \
     --output_dir output/internlm2-1_8b-reward-gptq-int4 \
     --quant_bits 4 \
+    --max_length 2048 \
     --quant_method gptq \
     --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#1000' 'AI-ModelScope/alpaca-gpt4-data-en#1000'
diff --git a/swift/llm/argument/base_args/template_args.py b/swift/llm/argument/base_args/template_args.py
index 5e1aeff330..2ff2e2ca6b 100644
--- a/swift/llm/argument/base_args/template_args.py
+++ b/swift/llm/argument/base_args/template_args.py
@@ -32,7 +32,7 @@ class TemplateArguments:
     system: Optional[str] = None  # Override the default_system in the template.
     max_length: Optional[int] = None
-    truncation_strategy: Literal['delete', 'left', 'right', None] = None
+    truncation_strategy: Literal['delete', 'left', 'right', 'split', None] = None
     max_pixels: Optional[int] = None
     agent_template: Optional[str] = None
     norm_bbox: Literal['norm1000', 'none', None] = None
diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index ed36395444..fd1daa801d 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -37,7 +37,6 @@ class ExportArguments(MergeArguments, BaseArguments):
     # awq/gptq
     quant_method: Literal['awq', 'gptq', 'bnb', 'fp8', 'gptq_v2'] = None
     quant_n_samples: int = 256
-    max_length: int = 2048
     quant_batch_size: int = 1
     group_size: int = 128
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 054fffa5d3..491ca39c71 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -65,7 +65,7 @@ class GRPOArguments(GRPOArgumentsMixin):
     # multi step
     num_iterations: int = 1

-    truncation_strategy: Literal['delete', 'left', 'right', None] = None
+    truncation_strategy: Literal['delete', 'left', 'right', 'split', None] = None


 @dataclass
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index acc295849d..f70bd03517 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -63,7 +63,7 @@ def __init__(
         default_system: Optional[str] = None,
         max_length: Optional[int] = None,
         *,
-        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
+        truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
         max_pixels: Optional[int] = None,
         agent_template: Optional[str] = None,
         norm_bbox: Literal['norm1000', 'none', None] = None,
@@ -529,31 +529,36 @@ def encode(self,
         else:
             raise ValueError(f'task_type: {self.task_type} is not supported.')

-        if chosen.channel is not None:
-            encoded['channel'] = chosen.channel
-
-        lengths = [0] if self.task_type not in {'reranker', 'generative_reranker'} else []
-        for key in list(encoded.keys()):
-            if encoded[key] is None:
-                encoded.pop(key)
-            elif key.endswith('length'):
-                value = encoded[key]
-                if isinstance(value, int):
-                    lengths.append(value)
-                elif isinstance(value, (tuple, list)):
-                    lengths += value
-        if return_length:
-            if self.task_type in {'reranker', 'generative_reranker'}:
-                encoded['length'] = lengths
+        # compatible with `--truncation_strategy split`
+        batched = encoded
+        if not isinstance(batched, (list, tuple)):
+            batched = [batched]
+        for encoded in batched:
+            if chosen.channel is not None:
+                encoded['channel'] = chosen.channel
+
+            lengths = [0] if self.task_type not in {'reranker', 'generative_reranker'} else []
+            for key in list(encoded.keys()):
+                if encoded[key] is None:
+                    encoded.pop(key)
+                elif key.endswith('length'):
+                    value = encoded[key]
+                    if isinstance(value, int):
+                        lengths.append(value)
+                    elif isinstance(value, (tuple, list)):
+                        lengths += value
+            if return_length:
+                if self.task_type in {'reranker', 'generative_reranker'}:
+                    encoded['length'] = lengths
+                else:
+                    encoded['length'] = sum(lengths)
             else:
-                encoded['length'] = sum(lengths)
-        else:
-            encoded.pop('length', None)
-        if return_template_inputs:
-            encoded['template_inputs'] = chosen
-        if not self.remove_unused_columns:
-            encoded['_extra_kwargs'] = chosen.extra_kwargs
-        return encoded
+                encoded.pop('length', None)
+            if return_template_inputs:
+                encoded['template_inputs'] = chosen
+            if not self.remove_unused_columns:
+                encoded['_extra_kwargs'] = chosen.extra_kwargs
+        return batched[0] if len(batched) == 1 else batched

     def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]:
         packed = {}
@@ -1229,6 +1234,26 @@ def _encode_truncated(self, inputs: StdTemplateInputs):
         elif self.truncation_strategy == 'raise':
             raise MaxLengthError(f'Current length of row({length}) is larger'
                                  f' than the max_length({self.max_length}).')
+        elif self.truncation_strategy == 'split':
+            i = 0
+            batched = []
+            while i < length:
+                splited = {}
+                for key in ['input_ids', 'labels', 'loss_scale']:
+                    value = encoded.get(key)
+                    if value is not None:
+                        value = value[i:i + self.max_length]
+                        if key == 'labels' and len(value) > 0:
+                            value[0] = -100
+                        elif key == 'loss_scale' and len(value) > 0:
+                            value[0] = 0
+                    splited[key] = value
+                splited['length'] = self._get_length(splited.get('input_ids'), splited.get('labels'))
+                batched.append(splited)
+                i += self.max_length
+            return batched
+        else:
+            raise ValueError(f'Invalid truncation_strategy: {self.truncation_strategy}')
         encoded['length'] = length
         encoded['input_ids'] = input_ids
         encoded['labels'] = labels
diff --git a/swift/llm/template/register.py b/swift/llm/template/register.py
index 713ec37133..91f9778169 100644
--- a/swift/llm/template/register.py
+++ b/swift/llm/template/register.py
@@ -22,7 +22,7 @@ def get_template(
     default_system: Optional[str] = None,
     max_length: Optional[int] = None,
     *,
-    truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
+    truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
     max_pixels: Optional[int] = None,  # h * w
     agent_template: Optional[str] = None,
     norm_bbox: Literal['norm1000', 'none', None] = None,
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index 5c8e0c07fc..9f05659931 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -142,7 +142,7 @@ def _post_process_datasets(self, datasets: List) -> List:
             if i == 1 and predict_with_generate:
                 # val_dataset
                 continue
-            if not args.streaming:
+            if not args.streaming and args.truncation_strategy != 'split':
                 dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed)
             if args.packing:
                 packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
@@ -317,6 +317,13 @@ def _encode_dataset(self, train_dataset, val_dataset, pre_process=True):
         origin_template_model = template.model
         template.model = None  # Avoid serializing the model.
+        if args.truncation_strategy == 'split':
+            if (args.task_type != 'causal_lm' or template.mode != 'train' or args.use_chat_template
+                    or args.model_meta.is_multimodal):
+                raise ValueError(
+                    '`--truncation_strategy split` is currently only supported for plain text model pretraining')
+            assert not args.lazy_tokenize, '`--truncation_strategy split` does not support lazy_tokenize'
+
         for i, dataset in enumerate(datasets):
             if dataset is None:
                 continue
@@ -325,7 +332,8 @@ def _encode_dataset(self, train_dataset, val_dataset, pre_process=True):
                 continue
             if not args.lazy_tokenize and not args.streaming:
                 # Compatible with cached_dataset, only additionally write length here.
-                preprocessor = AddLengthPreprocessor(template=template)
+                preprocessor_cls = EncodePreprocessor if args.truncation_strategy == 'split' else AddLengthPreprocessor
+                preprocessor = preprocessor_cls(template=template)
                 batch_size = 100 if args.model_meta.is_multimodal else 1000
                 dataset = preprocessor(
                     dataset,
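Editor's note: the core of the new `split` strategy is the chunking loop added to `Template._encode_truncated` above. Below is a minimal standalone sketch of that logic for readers skimming the diff — the function name `split_overlong` is illustrative and not part of ms-swift, and it uses plain `len()` where the real code defers to `Template._get_length`:

```python
from typing import Any, Dict, List, Optional


def split_overlong(input_ids: List[int],
                   labels: Optional[List[int]],
                   max_length: int) -> List[Dict[str, Any]]:
    """Chunk one over-length sample into pieces of at most max_length tokens."""
    batched = []
    for i in range(0, len(input_ids), max_length):
        chunk: Dict[str, Any] = {'input_ids': input_ids[i:i + max_length]}
        if labels is not None:
            # Slicing copies the list, so masking the first label of each
            # chunk in place is safe; it mirrors `value[0] = -100` in the
            # diff (the chunk's first token has no left context to learn from).
            chunk_labels = labels[i:i + max_length]
            chunk_labels[0] = -100
            chunk['labels'] = chunk_labels
        chunk['length'] = len(chunk['input_ids'])
        batched.append(chunk)
    return batched


# A 5000-token sample with max_length=2048 becomes chunks of 2048/2048/904 tokens.
rows = split_overlong(list(range(5000)), list(range(5000)), 2048)
assert [r['length'] for r in rows] == [2048, 2048, 904]
```

This is why the docs describe `split` as avoiding token waste: instead of discarding (`delete`) or cutting off (`left`/`right`) the overflow, every token of the original sample ends up in some training row.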
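A behavioral consequence worth noting: after this change, `Template.encode` may return either a single encoded dict or, under `--truncation_strategy split`, a list of them (`return batched[0] if len(batched) == 1 else batched`). A caller wanting uniform handling can normalize the result — a hedged sketch based only on the call shape visible in this diff; `encode_to_rows` is a hypothetical helper, not ms-swift API:

```python
from typing import Any, Dict, List


def encode_to_rows(template, inputs) -> List[Dict[str, Any]]:
    """Always hand back a list of encoded rows, split or not."""
    encoded = template.encode(inputs)
    if isinstance(encoded, (list, tuple)):
        return list(encoded)
    return [encoded]
```

This asymmetric return type is also consistent with the routing change in `sft.py`: the `split` case bypasses `LazyLLMDataset` (one row in, one row out) and goes through `EncodePreprocessor`, which, presumably, can flatten a list of encoded rows into multiple dataset entries during the map phase.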