diff --git a/README.md b/README.md
index 198e6ce001..1be393bcb7 100644
--- a/README.md
+++ b/README.md
@@ -254,7 +254,7 @@ RLHF:
 CUDA_VISIBLE_DEVICES=0 swift rlhf \
     --rlhf_type dpo \
     --model Qwen/Qwen2.5-7B-Instruct \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:en \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
     --train_type lora \
     --output_dir output \
     ...
diff --git a/README_CN.md b/README_CN.md
index da8e5b83fb..bf4443bd88 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -246,7 +246,7 @@ RLHF:
 CUDA_VISIBLE_DEVICES=0 swift rlhf \
     --rlhf_type dpo \
     --model Qwen/Qwen2.5-7B-Instruct \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
     --train_type lora \
     --output_dir output \
     ...
diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index 208c753a7e..f0920e9f5e 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -67,6 +67,7 @@
 - num_beams: beam search的并行保留数量,默认为1
 - 🔥stream: 流式输出,默认为`False`
 - stop_words: 额外的停止词,默认为`[]`
+- logprobs: 是否输出logprobs,默认为False
 
 ### 量化参数
 以下为拉起模型时量化的参数,具体含义可以查看[量化](https://huggingface.co/docs/transformers/main/en/main_classes/quantization)文档。这里不包含`swift export`中涉及的`gptq`、`awq`量化参数
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 6c1d0e6962..b56a0b1afc 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -68,6 +68,7 @@ Refer to the [generation_config](https://huggingface.co/docs/transformers/main_c
 - num_beams: Number of beams for beam search, default is 1.
 - 🔥stream: Stream output, default is `False`.
 - stop_words: Additional stop words, default is `[]`.
+- logprobs: Whether to output logprobs, default is False.
 
 ### Quantization Arguments
 
diff --git a/examples/train/multimodal/dpo.sh b/examples/train/multimodal/rlhf/dpo.sh
similarity index 81%
rename from examples/train/multimodal/dpo.sh
rename to examples/train/multimodal/rlhf/dpo.sh
index 1063661571..3dbb966b60 100644
--- a/examples/train/multimodal/dpo.sh
+++ b/examples/train/multimodal/rlhf/dpo.sh
@@ -1,15 +1,15 @@
-# 4*32GiB
+# 4*50GiB
 # You can refer to `https://github.com/QwenLM/Qwen2-VL` for the meaning of the `MAX_PIXELS` parameter.
-# --rlhf_type cpo/orpo/simpo/rm/kto are also supported
-nproc_per_node=4
+# --rlhf_type cpo/orpo/simpo are also supported
+nproc_per_node=2
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=0,1 \
 NPROC_PER_NODE=$nproc_per_node \
 MAX_PIXELS=1003520 \
 swift rlhf \
     --rlhf_type dpo \
     --model Qwen/Qwen2-VL-7B-Instruct \
-    --dataset swift/RLAIF-V-Dataset \
+    --dataset 'swift/RLAIF-V-Dataset#20000' \
     --train_type lora \
     --torch_dtype bfloat16 \
     --num_train_epochs 1 \
@@ -24,7 +24,7 @@ swift rlhf \
     --eval_steps 100 \
     --save_steps 100 \
     --save_total_limit 5 \
-    --deepspeed zero3 \
+    --deepspeed zero2 \
     --logging_steps 5 \
     --max_length 2048 \
     --output_dir output \
diff --git a/examples/train/multimodal/rlhf/kto.sh b/examples/train/multimodal/rlhf/kto.sh
new file mode 100644
index 0000000000..fcf07c28de
--- /dev/null
+++ b/examples/train/multimodal/rlhf/kto.sh
@@ -0,0 +1,31 @@
+# Due to the absence of a multi-modal open-source dataset for kto,
+# we will use a pure text kto dataset as an example here.
+nproc_per_node=2
+
+CUDA_VISIBLE_DEVICES=0,1 \
+NPROC_PER_NODE=$nproc_per_node \
+MAX_PIXELS=1003520 \
+swift rlhf \
+    --rlhf_type kto \
+    --model Qwen/Qwen2-VL-7B-Instruct \
+    --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#10000' \
+    --train_type lora \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-4 \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --freeze_vit true \
+    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 5 \
+    --deepspeed zero2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4
diff --git a/examples/train/rlhf/cpo.sh b/examples/train/rlhf/cpo.sh
index acf6e7202b..9473e8d02e 100644
--- a/examples/train/rlhf/cpo.sh
+++ b/examples/train/rlhf/cpo.sh
@@ -6,15 +6,22 @@ swift rlhf \
     --rlhf_type cpo \
     --model Qwen/Qwen2.5-7B-Instruct \
     --train_type lora \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+    --torch_dtype bfloat16 \
     --num_train_epochs 1 \
-    --weight_decay 0.1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-4 \
     --lora_rank 8 \
     --lora_alpha 32 \
+    --target_modules all-linear \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/examples/train/rlhf/dpo.sh b/examples/train/rlhf/dpo.sh
index 7d16026cb9..cd52333668 100644
--- a/examples/train/rlhf/dpo.sh
+++ b/examples/train/rlhf/dpo.sh
@@ -6,14 +6,22 @@ swift rlhf \
     --rlhf_type dpo \
     --model Qwen/Qwen2.5-7B-Instruct \
     --train_type lora \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+    --torch_dtype bfloat16 \
     --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-4 \
     --lora_rank 8 \
     --lora_alpha 32 \
+    --target_modules all-linear \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/examples/train/rlhf/kto.sh b/examples/train/rlhf/kto.sh
index 5d031bce43..d7a5731b81 100644
--- a/examples/train/rlhf/kto.sh
+++ b/examples/train/rlhf/kto.sh
@@ -1,19 +1,26 @@
-nproc_per_node=4
+nproc_per_node=2
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=0,1 \
 NPROC_PER_NODE=$nproc_per_node \
 swift rlhf \
     --rlhf_type kto \
     --model Qwen/Qwen2.5-7B-Instruct \
     --train_type lora \
     --dataset 'AI-ModelScope/ultrafeedback-binarized-preferences-cleaned-kto#10000' \
-    --num_train_epochs 2 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-4 \
     --lora_rank 8 \
     --lora_alpha 32 \
+    --target_modules all-linear \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/examples/train/rlhf/orpo.sh b/examples/train/rlhf/orpo.sh
index 97c8cff4f5..c11ec4860b 100644
--- a/examples/train/rlhf/orpo.sh
+++ b/examples/train/rlhf/orpo.sh
@@ -6,14 +6,22 @@ swift rlhf \
     --rlhf_type orpo \
     --model Qwen/Qwen2.5-7B-Instruct \
     --train_type lora \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+    --torch_dtype bfloat16 \
     --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-4 \
     --lora_rank 8 \
     --lora_alpha 32 \
+    --target_modules all-linear \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/examples/train/rlhf/rm.sh b/examples/train/rlhf/rm.sh
index 77bf50a134..98b05f1423 100644
--- a/examples/train/rlhf/rm.sh
+++ b/examples/train/rlhf/rm.sh
@@ -6,14 +6,22 @@ swift rlhf \
     --rlhf_type rm \
     --model Qwen/Qwen2.5-7B-Instruct \
     --train_type lora \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+    --torch_dtype bfloat16 \
     --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-4 \
     --lora_rank 8 \
     --lora_alpha 32 \
+    --target_modules all-linear \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --gradient_checkpointing_kwargs '{"use_reentrant": false}' \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/examples/train/rlhf/simpo.sh b/examples/train/rlhf/simpo.sh
index 804fe13a2c..2c3d066ebf 100644
--- a/examples/train/rlhf/simpo.sh
+++ b/examples/train/rlhf/simpo.sh
@@ -1,18 +1,25 @@
+# 2*50GB
 nproc_per_node=2
 
 CUDA_VISIBLE_DEVICES=0,1 \
 NPROC_PER_NODE=$nproc_per_node \
 swift rlhf \
     --rlhf_type simpo \
-    --model Qwen/Qwen2.5-7B-Instruct \
+    --model Qwen/Qwen2.5-3B-Instruct \
     --train_type full \
-    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh \
+    --dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
+    --torch_dtype bfloat16 \
     --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --learning_rate 1e-5 \
     --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
-    --warmup_ratio 0.03 \
     --eval_steps 100 \
     --save_steps 100 \
-    --save_total_limit 2 \
-    --deepspeed zero3 \
-    --logging_steps 5
+    --save_total_limit 5 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero2
diff --git a/swift/hub/hub.py b/swift/hub/hub.py
index 0b2297d68c..6170128f28 100644
--- a/swift/hub/hub.py
+++ b/swift/hub/hub.py
@@ -291,7 +291,7 @@ def load_dataset(cls,
         cls.try_login(token)
         if revision is None or revision == 'main':
             revision = 'master'
-        # noinspection PyTypeChecker
+
         return MsDataset.load(
             dataset_id,
             subset_name=subset_name,
diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index e4baa50e56..a867c2bdea 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Literal, Optional
 
+import torch
+
 from swift.utils import get_logger
 from .base_args import BaseArguments, to_abspath
 from .merge_args import MergeArguments
@@ -51,14 +53,6 @@ class ExportArguments(MergeArguments, BaseArguments):
     # compat
     to_peft_format: bool = False
 
-    def _init_quant(self):
-
-        if self.quant_bits:
-            if self.quant_method is None:
-                raise ValueError('Please specify the quantization method using `--quant_method awq/gptq`.')
-            if len(self.dataset) == 0 and self.quant_method in {'gptq', 'awq'}:
-                raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')
-
     def _init_output_dir(self):
         suffix = None
         if self.output_dir is None:
@@ -68,7 +62,7 @@ def _init_output_dir(self):
                 suffix = 'peft'
             elif self.merge_lora:
                 suffix = 'merged'
-            elif self.quant_bits:
+            elif self.quant_method:
                 suffix = f'{self.quant_method}-int{self.quant_bits}'
             elif self.to_ollama:
                 suffix = 'ollama'
@@ -82,13 +76,14 @@
         assert not os.path.exists(self.output_dir), f'args.output_dir: {self.output_dir} already exists.'
 
     def __post_init__(self):
+        if self.quant_bits and self.quant_method is None:
+            raise ValueError('Please specify the quantization method using `--quant_method awq/gptq/bnb`.')
+        if self.quant_method and self.quant_bits is None:
+            raise ValueError('Please specify `--quant_bits`.')
+        if self.quant_method in {'gptq', 'awq'} and self.torch_dtype is None:
+            self.torch_dtype = torch.float16
+
         BaseArguments.__post_init__(self)
         self._init_output_dir()
-        if self.quant_bits:
-            self._init_quant()
-
-    def _init_torch_dtype(self) -> None:
-        if self.quant_bits and self.torch_dtype is None:
-            self.torch_dtype = 'float16'
-            logger.info(f'Setting args.torch_dtype: {self.torch_dtype}')
-        super()._init_torch_dtype()
+        if self.quant_method in {'gptq', 'awq'} and len(self.dataset) == 0:
+            raise ValueError(f'self.dataset: {self.dataset}, Please input the quant dataset.')
diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py
index 68dbd9f7d6..89d167ee3c 100644
--- a/swift/llm/argument/rlhf_args.py
+++ b/swift/llm/argument/rlhf_args.py
@@ -44,6 +44,7 @@ class RLHFArguments(TrainArguments):
     undesirable_weight: float = 1.0
 
     def __post_init__(self):
+        self._init_rm()
         self._init_simpo()
         self._set_default()
         super().__post_init__()
@@ -65,6 +66,11 @@ def _init_simpo(self):
         if self.beta is None:
             self.beta = 2.
 
+    def _init_rm(self):
+        if self.rlhf_type == 'rm':
+            self.task_type = 'seq_cls'
+            self.num_labels = 1
+
     def _set_default(self):
         if self.beta is None:
             self.beta = 0.1
diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py
index 6b035ca7d9..de1a350c04 100644
--- a/swift/llm/dataset/dataset/llm.py
+++ b/swift/llm/dataset/dataset/llm.py
@@ -537,20 +537,10 @@ def repair_conversations(s: Union[str, Any]) -> Any:
 register_dataset(
     DatasetMeta(
         ms_dataset_id='hjh0119/shareAI-Llama3-DPO-zh-en-emoji',
-        subsets=[
-            SubsetDataset(
-                'zh',
-                preprocess_func=ResponsePreprocessor(columns_mapping={
-                    'answer_zh': 'response',
-                    'answer_en': 'rejected_response'
-                })),
-            SubsetDataset(
-                'en',
-                preprocess_func=ResponsePreprocessor(columns_mapping={
-                    'answer_en': 'response',
-                    'answer_zh': 'rejected_response'
-                }))
-        ],
+        preprocess_func=ResponsePreprocessor(columns_mapping={
+            'answer_zh': 'response',
+            'answer_en': 'rejected_response'
+        }),
         tags=['rlhf', 'dpo']))
 
 register_dataset(
diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py
index 51c20f36c3..1adb901dee 100644
--- a/swift/llm/export/export.py
+++ b/swift/llm/export/export.py
@@ -21,7 +21,7 @@ def run(self):
             args.adapters[0] = swift_to_peft_format(args.adapters[0], args.output_dir)
         elif args.merge_lora:
             merge_lora(args)
-        elif args.quant_method is not None:
+        elif args.quant_method:
             quantize_model(args)
         elif args.to_ollama:
             export_to_ollama(args)
diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py
index 21cf3ee255..80feb707b8 100644
--- a/swift/llm/infer/infer_engine/infer_engine.py
+++ b/swift/llm/infer/infer_engine/infer_engine.py
@@ -115,7 +115,6 @@ def _update_metrics(result, metrics: Optional[List[Metric]] = None):
                 metric.update(response)
         return result_origin
 
-    @torch.inference_mode()
     def infer(self,
               infer_requests: List[InferRequest],
              request_config: Optional[RequestConfig] = None,
diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py
index 57e70dc083..a665d4ea46 100644
--- a/swift/llm/infer/infer_engine/lmdeploy_engine.py
+++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -249,7 +249,6 @@ async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
         ]
         return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)
 
-    @torch.inference_mode()
     async def infer_async(self,
                           infer_request: InferRequest,
                           request_config: Optional[RequestConfig] = None,
@@ -266,7 +265,8 @@ async def infer_async(self,
             request_config.seed = get_seed()
 
         loop = asyncio.get_running_loop()
-        inputs = await loop.run_in_executor(None, template.encode, infer_request)
+        with torch.inference_mode():
+            inputs = await loop.run_in_executor(None, template.encode, infer_request)
         images = inputs.pop('images', None)
         if images:
             inputs['images'] = await self.engine.vl_encoder.async_infer(images)
@@ -283,7 +283,6 @@ async def infer_async(self,
         else:
             return await self._infer_full_async(**kwargs)
 
-    @torch.inference_mode()
     def infer(
         self,
         infer_requests: List[InferRequest],
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index b52bbbda2b..45022dcb8a 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -347,7 +347,6 @@ def _infer_full(self,
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
 
-    @torch.inference_mode()
     async def infer_async(
         self,
         infer_request: InferRequest,
@@ -376,6 +375,7 @@ async def _gen_wrapper():
         else:
             return res_or_gen[0]
 
+    @torch.inference_mode()
     def _infer(
         self,
         infer_requests: List[InferRequest],
@@ -436,7 +436,6 @@ def _gen_wrapper():
         infer_func = self._infer_seq_cls if template.mode == 'seq_cls' else self._infer_full
         return self._update_metrics(infer_func(**kwargs), metrics)
 
-    @torch.inference_mode()
     def infer(
         self,
         infer_requests: List[InferRequest],
diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py
index ff182e017b..78310eed53 100644
--- a/swift/llm/infer/infer_engine/vllm_engine.py
+++ b/swift/llm/infer/infer_engine/vllm_engine.py
@@ -330,7 +330,6 @@ async def _infer_full_async(self,
             choices.append(choice)
         return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info, id=request_id)
 
-    @torch.inference_mode()
     def infer(
         self,
         infer_requests: List[InferRequest],
@@ -349,7 +348,6 @@ def infer(
             use_tqdm=use_tqdm,
             adapter_request=adapter_request)
 
-    @torch.inference_mode()
     async def infer_async(
         self,
         infer_request: InferRequest,
@@ -365,7 +363,8 @@ async def infer_async(
         template.set_mode('vllm')
 
         loop = asyncio.get_running_loop()
-        inputs = await loop.run_in_executor(None, template.encode, infer_request)
+        with torch.inference_mode():
+            inputs = await loop.run_in_executor(None, template.encode, infer_request)
         self.set_default_max_tokens(request_config, inputs)
         generation_config = self._prepare_generation_config(request_config)
         self._add_stop_words(generation_config, request_config, template.template_meta)
diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py
index 903d31c8a0..5d1271c8e5 100644
--- a/swift/llm/model/patcher.py
+++ b/swift/llm/model/patcher.py
@@ -25,10 +25,7 @@ def patch_output_clone(module: torch.nn.Module):
     """Clone the output, to avoid the inplace problem"""
 
     def _clone_hook(module, input, output):
-        if module.training:
-            return output.requires_grad_(True).clone()
-        else:
-            return output.detach()
+        return output.requires_grad_(True).clone()
 
     module.register_forward_hook(_clone_hook)
 
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 43857a23a8..276420ea4b 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -110,7 +110,7 @@ def __init__(
         self.mode: Literal['pt', 'vllm', 'lmdeploy',  # infer
                            'train', 'rlhf', 'kto',  # train
                            'seq_cls'] = 'pt'
-        if self.model_info.task_type != 'causal':
+        if self.model_info.task_type != 'causal_lm':
             self.mode = self.model_info.task_type
         self._handles = []
         self._deepspeed_initialize = None
@@ -184,7 +184,7 @@ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
     def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         label, inputs.label = inputs.label, None
         encoded = self._rlhf_encode(inputs)
-        encoded['label'] = label
+        encoded['label'] = bool(label)
         return encoded
 
     def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
@@ -721,10 +721,7 @@ def is_training(self):
         return self.mode not in {'vllm', 'lmdeploy', 'pt'}
 
     def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
-        if self.model_info.task_type == 'causal_lm':
-            self.mode = mode
-        else:
-            swift.warning(f'task_type: `{self.model_info.task_type}` does not support modifying template.mode.')
+        self.mode = mode
 
     def register_post_encode_hook(self, models: List[nn.Module]) -> None:
         """This function is important for multi-modal training, as it registers the post_encode method
diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py
index 906a3e1166..2e5bff8910 100644
--- a/swift/llm/train/rlhf.py
+++ b/swift/llm/train/rlhf.py
@@ -32,18 +32,6 @@ def _prepare_template(self) -> None:
         # Avoid padding labels during the model's forward pass in multimodal models.
         self.template.loss_scale = 'last_round'
 
-    @classmethod
-    def prepare_model(cls, args, model, *_args, **kwargs):
-        model = super().prepare_model(args, model, *_args, **kwargs)
-        if args.rlhf_type == 'rm':
-            from trl import AutoModelForCausalLMWithValueHead
-            lm_head_namings = ['lm_head', 'embed_out']
-            if not any(hasattr(model, attribute) for attribute in lm_head_namings):
-                model.lm_head = None  # avoid error
-            model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-            patch_getattr(AutoModelForCausalLMWithValueHead, 'pretrained_model')
-        return model
-
     def _get_dataset(self):
         args = self.args
         train_dataset, val_dataset = super()._get_dataset()
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index 407c8b511b..4226df8bca 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -110,7 +110,7 @@ def run(self):
         args = self.args
 
         train_dataset, val_dataset = self._get_dataset()
-        if args.task_type == 'seq_cls' and isinstance(train_dataset, HfDataset):
+        if args.task_type == 'seq_cls' and isinstance(train_dataset, HfDataset) and 'label' in train_dataset.features:
             min_num_labels = int(max(train_dataset['label']) + 1)
             assert args.num_labels >= min_num_labels, (
                 f'args.num_labels: {args.num_labels}, min_num_labels: {min_num_labels}')
diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainer/reward_trainer.py
index 317c487e28..49fe5d683b 100644
--- a/swift/trainers/rlhf_trainer/reward_trainer.py
+++ b/swift/trainers/rlhf_trainer/reward_trainer.py
@@ -26,10 +26,7 @@ def compute_loss(self,
         inputs.pop('labels', None)  # not use
         attention_mask = inputs['attention_mask']
         batch_size = attention_mask.shape[0] // 2
-        values = model(**inputs)[2]
-
-        sequence_lengths = (torch.eq(attention_mask, 0).int().argmax(-1) - 1) % attention_mask.shape[1]
-        rewards = values.gather(dim=-1, index=sequence_lengths[:, None])
+        rewards = model(**inputs).logits
         rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
         if 'margin' in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs['margin']).mean()
diff --git a/tests/train/test_rlhf.py b/tests/train/test_rlhf.py
index 421bb39c15..7c20a70ae6 100644
--- a/tests/train/test_rlhf.py
+++ b/tests/train/test_rlhf.py
@@ -16,7 +16,7 @@ def test_llm():
         RLHFArguments(
             rlhf_type='dpo',
             model='Qwen/Qwen2-7B-Instruct',
-            dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji:zh#100'],
+            dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#100'],
             **kwargs))
     last_model_checkpoint = result['last_model_checkpoint']
     infer_main(InferArguments(adapters=last_model_checkpoint, load_data_args=True, merge_lora=True))