diff --git "a/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md" "b/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md" index d6cb841f07..76a574b957 100644 --- "a/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md" +++ "b/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md" @@ -4,7 +4,7 @@ SWIFT支持以命令行、Python代码和界面方式进行推理和部署: - 使用`engine.infer`或者`engine.infer_async`进行python的方式推理. 参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo.py). - 使用`swift infer`使用命令行的方式进行推理. 参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/infer/cli_demo.sh). - 使用`swift deploy`进行服务部署,并使用openai API或者`client.infer`的方式推理. 服务端参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/deploy/server), 客户端参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/deploy/client). -- 使用`swift app`部署模型进行界面推理, 可以查看[这里](../GetStarted/界面使用.md) +- 使用`swift app`部署模型进行界面推理, 可以查看[这里](../GetStarted/Web-UI.md) ## 命令行推理指令 diff --git a/docs/source_en/Instruction/Inference-and-deployment.md b/docs/source_en/Instruction/Inference-and-deployment.md index e77b7a970d..1229ba3590 100644 --- a/docs/source_en/Instruction/Inference-and-deployment.md +++ b/docs/source_en/Instruction/Inference-and-deployment.md @@ -4,7 +4,7 @@ SWIFT supports inference and deployment through command line, Python code, and i - Use `engine.infer` or `engine.infer_async` for Python-based inference. See [here](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo.py) for reference. - Use `swift infer` for command-line-based inference. See [here](https://github.com/modelscope/ms-swift/blob/main/examples/infer/cli_demo.sh) for reference. - Use `swift deploy` for service deployment and perform inference using the OpenAI API or `client.infer`. Refer to the server guidelines [here](https://github.com/modelscope/ms-swift/tree/main/examples/deploy/server) and the client guidelines [here](https://github.com/modelscope/ms-swift/tree/main/examples/deploy/client). -- Deploy the model with `swift app` for web-based inference. You can check [here](../GetStarted/Interface-usage.md) for details. +- Deploy the model with `swift app` for web-based inference. You can check [here](../GetStarted/Web-UI.md) for details. ## Command Line Inference diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 2851683db2..22340beef7 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -43,7 +43,6 @@ def _handle_ckpt_dir(self: 'BaseArguments'): return self.adapters.insert(0, self.ckpt_dir) else: - assert self.model is None, f'self.model: {self.model}' self.model = self.ckpt_dir self.ckpt_dir = None logger.warning('The `--ckpt_dir` parameter will be removed in `ms-swift>=3.2`. 
' @@ -236,13 +235,14 @@ def _init_device(self): else: torch.cuda.set_device(self.local_rank) - def get_template(self, processor: 'Processor') -> 'Template': + def get_template(self, processor: 'Processor', template_type=None) -> 'Template': template_kwargs = self.get_template_kwargs() - template = get_template(self.template, processor, **template_kwargs) + template_type = template_type or self.template + template = get_template(template_type, processor, **template_kwargs) logger.info(f'default_system: {template.template_meta.default_system}') return template - def get_model_processor(self, *, model=None, model_type=None, model_revision=None, **kwargs): + def get_model_processor(self, *, model=None, model_type=None, model_revision=None, task_type=None, **kwargs): if self.tuner_backend == 'unsloth': return load_by_unsloth(self) kwargs.update(self.get_model_kwargs()) @@ -250,5 +250,6 @@ def get_model_processor(self, *, model=None, model_type=None, model_revision=Non kwargs['model_id_or_path'] = model or self.model kwargs['model_type'] = model_type or self.model_type kwargs['model_revision'] = model_revision or self.model_revision + kwargs['task_type'] = task_type or self.task_type return get_model_tokenizer(**kwargs) diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py index 80feb707b8..fe6057b383 100644 --- a/swift/llm/infer/infer_engine/infer_engine.py +++ b/swift/llm/infer/infer_engine/infer_engine.py @@ -174,7 +174,7 @@ def _get_num_tokens(inputs: Dict[str, Any]) -> int: else: return input_ids.shape[-1] elif 'inputs_embeds' in inputs: # 2d or 3d - return inputs['inputs_embeds'].shape[-1] + return inputs['inputs_embeds'].shape[-2] raise ValueError(f'Unable to retrieve input_ids and inputs_embeds. inputs: {inputs}') def set_default_max_tokens(self, request_config: RequestConfig, inputs: Dict[str, Any]) -> None: diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py index 67f1b02f01..6510b989f3 100644 --- a/swift/llm/infer/utils.py +++ b/swift/llm/infer/utils.py @@ -118,7 +118,7 @@ def check_query(self, query: str) -> Optional[str]: return query -def _prepare_adapter(args, model): +def prepare_adapter(args, model, adapters=None): if args.tuner_backend == 'unsloth': if args.model_meta.is_multimodal: from unsloth import FastVisionModel as UnslothModel @@ -131,7 +131,8 @@ def _prepare_adapter(args, model): else: tuner = Swift # compat deploy - for adapter in args.adapters: + adapters = adapters or args.adapters + for adapter in adapters: model = tuner.from_pretrained(model, adapter) if args.train_type == 'bone': # Bone has a problem of float32 matmul with bloat16 in `peft==0.14.0` @@ -141,6 +142,6 @@ def _prepare_adapter(args, model): def prepare_model_template(args, **kwargs): model, processor = args.get_model_processor(**kwargs) - model = _prepare_adapter(args, model) + model = prepare_adapter(args, model) template = args.get_template(processor) return model, template diff --git a/swift/llm/template/template_meta.py b/swift/llm/template/template_meta.py index 98520516c2..176870878d 100644 --- a/swift/llm/template/template_meta.py +++ b/swift/llm/template/template_meta.py @@ -128,6 +128,18 @@ def init(self, tokenizer: PreTrainedTokenizerBase) -> None: if tokenizer.eos_token not in self.stop_words: self.stop_words.append(tokenizer.eos_token) + self.stop_token_id = tokenizer.eos_token_id + if self.suffix: + suffix_tokens = self.suffix[-1] + if isinstance(suffix_tokens, str): + stop_token_id = 
tokenizer.convert_tokens_to_ids(suffix_tokens) + elif isinstance(suffix_tokens, list) and len(suffix_tokens) == 1: + stop_token_id = suffix_tokens[0] + else: + stop_token_id = None + if stop_token_id is not None: + self.stop_token_id = stop_token_id + def check_system(self, system: Optional[str]) -> None: if system is not None: assert self.support_system, ( diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py index 4584bdccae..c47212934d 100644 --- a/swift/llm/train/tuner.py +++ b/swift/llm/train/tuner.py @@ -105,7 +105,7 @@ def get_target_modules(args, model) -> Union[str, List[str]]: return target_modules -def get_modules_to_save(args, model): +def get_modules_to_save(args, model, task_type=None): modules_to_save = args.modules_to_save.copy() if 'all-embedding' in args.modules_to_save: modules_to_save.remove('all-embedding') @@ -113,6 +113,8 @@ def get_modules_to_save(args, model): if 'all-norm' in args.modules_to_save: modules_to_save.remove('all-norm') modules_to_save += find_norm(model) + if task_type and task_type.lower() == 'seq_cls': # reward_model + modules_to_save.append('v_head') return modules_to_save @@ -136,11 +138,12 @@ def get_vera_target_modules(model, config): return config -def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None): +def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None, task_type=None): from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, LLaMAProConfig, LongLoRAModelType, LoraConfig, LoRAConfig, ReftConfig, Swift, VeraConfig) + task_type = (task_type or args.task_type).upper() target_modules = get_target_modules(args, model) - modules_to_save = get_modules_to_save(args, model) + modules_to_save = get_modules_to_save(args, model, task_type) lora_kwargs = { 'r': args.lora_rank, 'target_modules': target_modules, @@ -153,7 +156,6 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset 'lorap_lr_ratio': args.lorap_lr_ratio, 'init_lora_weights': args.init_weights, } - task_type = args.task_type.upper() if args.train_type in ('lora', 'longlora'): if args.use_swift_lora: lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs) @@ -329,14 +331,7 @@ def torchacc_resume_from_checkpoint(args, model): class TunerMixin: @classmethod - def prepare_model( - cls, - args, - model, - *, - template=None, - train_dataset=None, - ): + def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None): if args.use_liger: # Apply liger apply_liger(args.model_type) @@ -361,7 +356,8 @@ def prepare_model( tuner: Tuner = extra_tuners[args.train_type] model = tuner.prepare_model(args, model) else: - model = prepare_adapter(args, model, template=template, train_dataset=train_dataset) + model = prepare_adapter( + args, model, template=template, train_dataset=train_dataset, task_type=task_type) # fix bug: Attempting to unscale FP16 gradients. 
# peft: https://github.com/huggingface/peft/issues/1249 for p in model.parameters(): diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index d0281bcdfc..5eb7a72dd3 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -44,21 +44,20 @@ class SwiftMixin: - def __init__( - self, - model: Union[PreTrainedModel, Module] = None, - args: TrainingArguments = None, - data_collator: Optional[DataCollator] = None, - train_dataset: Optional[HfDataset] = None, - eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None, - template: Optional[Template] = None, - model_init: Optional[Callable[[], PreTrainedModel]] = None, - compute_loss_func: Optional[Callable] = None, - compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], - torch.Tensor]] = None) -> None: + def __init__(self, + model: Union[PreTrainedModel, Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[HfDataset] = None, + eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None, + template: Optional[Template] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + **kwargs) -> None: if args.check_model and hasattr(model, 'model_dir'): check_local_model_is_latest( model.model_dir, user_agent={