examples/pytorch/llm/README.md (2 changes: 1 addition & 1 deletion)

@@ -614,4 +614,4 @@ The template initialization function retrieves the complete chat template based
 - `--ignore_args_error`: Default value is `False`. For specific parameter details, please refer to the `sft.sh Command Line Arguments`.
 - `--stream`: Whether to use streaming output. Default value is `True`.
 - `--merge_lora_and_save`: Whether to merge the lora weights into the base model and save the complete weights. Default value is `False`. The weights will be saved in a directory named `checkpoint-xxx-merged` at the same level as `ckpt_dir`, e.g., `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`.
-- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generate_config file saved during training will be overwritten.
+- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default value is `False`. The generation_config file saved during training will be overwritten.
examples/pytorch/llm/README_CN.md (2 changes: 1 addition & 1 deletion)

@@ -617,4 +617,4 @@ if __name__ == '__main__':
 - `--ignore_args_error`: Default value is `False`. For specific parameter details, see `sft.sh Command Line Arguments`.
 - `--stream`: Whether to use streaming output. Default is `True`.
 - `--merge_lora_and_save`: Whether to merge the lora weights into the base model and save the complete weights. Default is `False`. The weights are saved in a directory at the same level as `ckpt_dir`, e.g. `'/path/to/your/vx_xxx/checkpoint-xxx-merged'`.
-- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default is `False`. The generate_config file saved during training will be overwritten.
+- `--overwrite_generation_config`: Whether to save the generation_config used for evaluation as a `generation_config.json` file. Default is `False`. The generation_config file saved during training will be overwritten.
swift/llm/infer.py (1 change: 0 additions & 1 deletion)

@@ -110,7 +110,6 @@ def prepare_model_template(
         args.system, args.max_length,
         args.truncation_strategy)
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
swift/llm/rome.py (1 change: 0 additions & 1 deletion)

@@ -59,7 +59,6 @@ def rome_infer(args: RomeArguments) -> None:
         args.system, args.max_length,
         args.truncation_strategy)
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
swift/llm/sft.py (1 change: 0 additions & 1 deletion)

@@ -164,7 +164,6 @@ def llm_sft(args: SftArguments) -> str:

     # Setting training_args
     generation_config = GenerationConfig(
-        max_length=None,
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         top_k=args.top_k,
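All three call sites (infer.py, rome.py, sft.py) drop the same `max_length=None` argument for the same reason: `GenerationConfig` already defaults `max_length` to 20, and an explicit `None` replaces that default with a value the downstream length handling does not expect. A minimal sketch of the difference, with placeholder values standing in for the `args.*` fields (only `transformers` is assumed):

    from transformers import GenerationConfig

    # Without an explicit max_length, the library default (20) survives
    # and max_new_tokens alone bounds generation.
    trimmed = GenerationConfig(
        max_new_tokens=2048,  # placeholder for args.max_new_tokens
        temperature=0.9,      # placeholder for args.temperature
        top_k=20)             # placeholder for args.top_k
    print(trimmed.max_length)   # 20

    # The removed pattern: an explicit None overrides the default.
    explicit = GenerationConfig(max_length=None, max_new_tokens=2048)
    print(explicit.max_length)  # None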
swift/llm/utils/argument.py (4 changes: 2 additions & 2 deletions)

@@ -537,8 +537,8 @@ def load_from_ckpt_dir(args: InferArguments) -> None:
     with open(sft_args_path, 'r') as f:
         sft_args = json.load(f)
     imported_keys = [
-        'model_id_or_path', 'model_revision', 'model_cache_dir', 'sft_type',
-        'template_type', 'dtype', 'system', 'quantization_bit',
+        'model_type', 'model_id_or_path', 'model_revision', 'model_cache_dir',
+        'sft_type', 'template_type', 'dtype', 'system', 'quantization_bit',
         'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type',
         'bnb_4bit_use_double_quant'
     ]
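Adding `model_type` to `imported_keys` lets inference restore the model type recorded at training time instead of requiring it again on the command line. A hedged sketch of the pattern such a key list typically drives; the loop, the sample values, and the `Args` stand-in are illustrative, not the repository's exact code:

    sft_args = {'model_type': 'qwen-7b-chat', 'sft_type': 'lora'}  # sample
    imported_keys = ['model_type', 'model_id_or_path', 'sft_type']

    class Args:  # stand-in for InferArguments
        pass

    args = Args()
    # Copy each recorded training setting onto the inference arguments.
    for key in imported_keys:
        if key in sft_args:
            setattr(args, key, sft_args[key])
    print(args.model_type)  # qwen-7b-chat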
swift/llm/utils/utils.py (5 changes: 4 additions & 1 deletion)

@@ -350,7 +350,8 @@ def inference_stream(
     model.__class__.sample_stream = NewGenerationMixin.sample_stream
     stream_config = StreamGenerationConfig(
         **generation_config.to_dict(), do_stream=True)
-    stream_config.max_length = int(1e9)  # fix max_length, max_new_tokens bug
+    if stream_config.max_new_tokens is not None:
+        stream_config.max_length = 20  # fix max_length, max_new_tokens bug
     stream_config.do_sample = True  # avoid is_greedy_gen_mode = True
     gen = model.generate_stream(
         input_ids=input_ids,

@@ -395,6 +396,8 @@ def inference(model: PreTrainedModel,
     streamer = None
     if stream:
         streamer = TextStreamer(tokenizer, skip_prompt=True)
+    if generation_config.max_new_tokens is not None:
+        generation_config.max_length = 20  # fix max_length, max_new_tokens bug
     generate_ids = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
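The value 20 is not arbitrary: it is the default `max_length` of a fresh `GenerationConfig`. Resetting `max_length` to that default makes the config look as if only `max_new_tokens` were user-set, so `max_new_tokens` alone bounds the output; the replaced `int(1e9)` cap left a non-default `max_length` in place, which the "Both `max_new_tokens` and `max_length` seem to have been set" warning in `generate()` appears to key on (warning text quoted from memory of the 4.3x `transformers` releases this PR targets). A runnable sketch of the guard:

    from transformers import GenerationConfig

    config = GenerationConfig(max_new_tokens=64)
    assert config.max_length == 20         # library default, untouched

    config.max_length = int(1e9)           # old workaround: looks user-set
    if config.max_new_tokens is not None:  # the PR's guard
        config.max_length = 20             # back to the default
    assert config.max_length == 20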
swift/trainers/trainers.py (2 changes: 1 addition & 1 deletion)

@@ -107,7 +107,7 @@ def prediction_step(
         gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
         # fix generate warning
         if ('max_length' in gen_kwargs and 'max_new_tokens' in gen_kwargs
-                and gen_kwargs['max_length'] is None):
+                and gen_kwargs['max_new_tokens'] is not None):
             gen_kwargs.pop('max_length')
         gen_time = time.time()
         generate_inputs = inputs.copy()
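The old condition only fired when `max_length` was `None`, so a concrete `max_length` arriving alongside `max_new_tokens` still reached `generate()` and triggered the warning the comment refers to. The corrected condition keys off `max_new_tokens` being genuinely set. A self-contained illustration; the helper name is invented for this sketch:

    def drop_conflicting_length(gen_kwargs: dict) -> dict:
        # Remove max_length whenever max_new_tokens is actually set,
        # so generate() never receives both bounds at once.
        if ('max_length' in gen_kwargs and 'max_new_tokens' in gen_kwargs
                and gen_kwargs['max_new_tokens'] is not None):
            gen_kwargs.pop('max_length')
        return gen_kwargs

    print(drop_conflicting_length({'max_length': 512, 'max_new_tokens': 128}))
    # -> {'max_new_tokens': 128}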