From ff8161876e83d81a6547fa6c59aced7c92437c53 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 19 Nov 2023 19:56:33 +0800
Subject: [PATCH 1/2] fix load_from_ckpt_dir bug

load_from_ckpt_dir restores the InferArguments fields recorded in the
sft_args.json saved at training time, but 'model_type' was missing from
imported_keys, so inference on a checkpoint could resolve a different
model than the one that was fine-tuned. Add 'model_type' to the
imported keys.
---
 swift/llm/utils/argument.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 70e4184057..6869807075 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -537,8 +537,8 @@ def load_from_ckpt_dir(args: InferArguments) -> None:
     with open(sft_args_path, 'r') as f:
         sft_args = json.load(f)
     imported_keys = [
-        'model_id_or_path', 'model_revision', 'model_cache_dir', 'sft_type',
-        'template_type', 'dtype', 'system', 'quantization_bit',
+        'model_type', 'model_id_or_path', 'model_revision', 'model_cache_dir',
+        'sft_type', 'template_type', 'dtype', 'system', 'quantization_bit',
         'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
         'bnb_4bit_use_double_quant'
     ]
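
Note on PATCH 1/2: below is a minimal, runnable sketch of the restore
logic this patch touches, to show how the imported keys take effect.
Only IMPORTED_KEYS comes from the diff; the sft_args.json path handling
and the setattr loop are assumptions about the surrounding function, and
the args object is stubbed with SimpleNamespace so the snippet is
self-contained.

    import json
    import os
    from types import SimpleNamespace

    # Keys copied back from the training-time sft_args.json; the patch
    # adds 'model_type' so inference resolves the same model that was
    # fine-tuned.
    IMPORTED_KEYS = [
        'model_type', 'model_id_or_path', 'model_revision',
        'model_cache_dir', 'sft_type', 'template_type', 'dtype', 'system',
        'quantization_bit', 'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
        'bnb_4bit_use_double_quant'
    ]

    def load_from_ckpt_dir(args) -> None:
        """Restore training-time settings recorded next to a checkpoint."""
        sft_args_path = os.path.join(args.ckpt_dir, 'sft_args.json')
        with open(sft_args_path, 'r') as f:
            sft_args = json.load(f)
        for key in IMPORTED_KEYS:
            # Overwrite inference-side defaults with the recorded values.
            if key in sft_args:
                setattr(args, key, sft_args[key])

    # Usage (hypothetical checkpoint path):
    # args = SimpleNamespace(ckpt_dir='/path/to/your/vx_xxx/checkpoint-xxx')
    # load_from_ckpt_dir(args)
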
From 644e8d81d5588440f7deab3c6d4bfae5cf6cf6a6 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 19 Nov 2023 20:59:26 +0800
Subject: [PATCH 2/2] remove max_length=None

Passing max_length=None into GenerationConfig alongside max_new_tokens
conflicts with how transformers resolves the two settings. Drop the
explicit max_length=None, and where an incoming config may carry both,
reset max_length to the transformers default (20) whenever
max_new_tokens is set, so max_new_tokens takes effect. In trainers.py's
prediction_step, pop max_length from gen_kwargs when max_new_tokens is
present. Also fix the generate_config -> generation_config typo in both
READMEs.
---
 examples/pytorch/llm/README.md    | 2 +-
 examples/pytorch/llm/README_CN.md | 2 +-
 swift/llm/infer.py                | 1 -
 swift/llm/rome.py                 | 1 -
 swift/llm/sft.py                  | 1 -
 swift/llm/utils/utils.py          | 5 ++++-
 swift/trainers/trainers.py        | 2 +-
 7 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/pytorch/llm/README.md b/examples/pytorch/llm/README.md
index cb584544d2..9a80690733 100644
--- a/examples/pytorch/llm/README.md
+++ b/examples/pytorch/llm/README.md
@@ -614,4 +614,4 @@ The template initialization function retrieves the complete chat template based
 - `--ignore_args_error`: Default value is `False`. For specific parameter details, please refer to the `sft.sh Command Line Arguments`.
 - `--stream`: Whether to use streaming output. Default value is `True`.
 - `--merge_lora_and_save`: Whether to merge the lora weights into the base model and save the complete weights. Default value is `False`. The weights will be saved in a directory named `checkpoint-xxx-merged` at the same level as `ckpt_dir`, e.g., `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`.
-- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generate_config file saved during training will be overwritten.
+- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generation_config file saved during training will be overwritten.
diff --git a/examples/pytorch/llm/README_CN.md b/examples/pytorch/llm/README_CN.md
index 5bde99ef3d..80f5b80e07 100644
--- a/examples/pytorch/llm/README_CN.md
+++ b/examples/pytorch/llm/README_CN.md
@@ -617,4 +617,4 @@ if __name__ == '__main__':
 - `--ignore_args_error`: 默认值为`False`, 具体的参数介绍可以在`sft.sh命令行参数`中查看.
 - `--stream`: 是否使用流式输出, 默认为`True`.
 - `--merge_lora_and_save`: 是否将lora权重merge到基模型中, 并保存完整的权重, 默认为`False`. 权重会保存在`ckpt_dir`的同级目录中, e.g. `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`目录下.
-- `--overwrite_generation_config`: 是否将评估所使用的generation_config保存成`generation_config.json`文件, 默认为`False`. 训练时保存的generate_config文件将被覆盖.
+- `--overwrite_generation_config`: 是否将评估所使用的generation_config保存成`generation_config.json`文件, 默认为`False`. 训练时保存的generation_config文件将被覆盖.
diff --git a/swift/llm/infer.py b/swift/llm/infer.py
index 407965600b..84c79531d0 100644
--- a/swift/llm/infer.py
+++ b/swift/llm/infer.py
@@ -110,7 +110,6 @@ def prepare_model_template(
         args.system, args.max_length, args.truncation_strategy)

     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
diff --git a/swift/llm/rome.py b/swift/llm/rome.py
index 6fd4c524d3..37b986f2f6 100644
--- a/swift/llm/rome.py
+++ b/swift/llm/rome.py
@@ -59,7 +59,6 @@ def rome_infer(args: RomeArguments) -> None:
         args.system, args.max_length, args.truncation_strategy)

     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 75859d0157..85723e1ff4 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -164,7 +164,6 @@ def llm_sft(args: SftArguments) -> str:

     # Setting training_args
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 850acee452..c8695dd00a 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -350,7 +350,8 @@ def inference_stream(
     model.__class__.sample_stream = NewGenerationMixin.sample_stream
     stream_config = StreamGenerationConfig(
         **generation_config.to_dict(), do_stream=True)
-    stream_config.max_length = int(1e9)  # fix max_length, max_new_tokens bug
+    if stream_config.max_new_tokens is not None:
+        stream_config.max_length = 20  # fix max_length, max_new_tokens bug
     stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
     gen = model.generate_stream(
         input_ids=input_ids,
@@ -395,6 +396,8 @@ def inference(model: PreTrainedModel,
     streamer = None
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
+    if generation_config.max_new_tokens is not None:
+        generation_config.max_length = 20  # fix max_length, max_new_tokens bug
     generate_ids = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index bf9996920f..867c70f8a9 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -107,7 +107,7 @@ def prediction_step(
         gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
         # fix generate warning
         if ('max_length' in gen_kwargs and 'max_new_tokens' in gen_kwargs
-                and gen_kwargs['max_length'] is None):
+                and gen_kwargs['max_new_tokens'] is not None):
             gen_kwargs.pop('max_length')
         gen_time = time.time()
         generate_inputs = inputs.copy()
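
Note on PATCH 2/2: why the guard sets max_length = 20 rather than None.
GenerationConfig() defaults max_length to 20, so resetting it to that
default whenever max_new_tokens is set leaves max_new_tokens in control
of the output length; the exact warning and priority behavior around the
two fields varies across transformers versions, so treat this as the
rationale rather than a guarantee. A minimal runnable sketch with
illustrative sampling values:

    from transformers import GenerationConfig

    generation_config = GenerationConfig(
        max_new_tokens=128,  # illustrative value
        temperature=0.3,
        top_k=20,
    )

    # 20 is the GenerationConfig default for max_length; with the
    # default in place, generate() derives the output budget from
    # max_new_tokens instead of a stale max_length.
    if generation_config.max_new_tokens is not None:
        generation_config.max_length = 20

    assert generation_config.max_length == 20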