diff --git a/README.md b/README.md
index 81ddc47fb9..8b38e55180 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 
 ## 🎉 News
-- 2023.1.4: Support for **VLLM deployment**, compatible with the OpenAI API style. For more details, please refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md#部署)
+- 2023.1.4: Support for **VLLM deployment**, compatible with the **OpenAI API** style. For more details, please refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md#部署)
 - 2023.1.4: Update [Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md) to facilitate viewing the training speed and GPU memory required for different models.
 - 🔥 2023.12.29: Support web-ui for training and inference, use `swift web-ui` after the installation of ms-swift.
 - 🔥 2023.12.29: Support DPO RLHF(Reinforcement Learning from Human Feedback) and two datasets: AI-ModelScope/stack-exchange-paired and AI-ModelScope/hh-rlhf for this task. Use [this script](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh) to start training!
@@ -113,7 +113,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 - Quickly perform **inference** on LLM and build a **Web-UI**, see the [LLM Inference Documentation](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM推理文档.md).
 - Rapidly **fine-tune** and perform inference on LLM, and build a Web-UI. See the [LLM Fine-tuning Documentation](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md) and [WEB-UI Documentation](https://github.com/modelscope/swift/blob/main/docs/source/GetStarted/%E7%95%8C%E9%9D%A2%E8%AE%AD%E7%BB%83%E6%8E%A8%E7%90%86.md).
 - **DPO training** supported, start by using [this script](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh).
-- Utilize VLLM for **inference acceleration** and **deployment(openai API)**. Please refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md) for more information.
+- Utilize VLLM for **inference acceleration** and **deployment(OpenAI API)**. Please refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md) for more information.
 - View the models and datasets supported by Swift. You can check [supported models and datasets](https://github.com/modelscope/swift/blob/main/docs/source/LLM/支持的模型和数据集.md).
 - Expand and customize models, datasets, and dialogue templates in Swift, see [Customization and Expansion](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自定义与拓展.md).
 - Check command-line parameters for fine-tuning and inference, see [Command-Line parameters](https://github.com/modelscope/swift/blob/main/docs/source/LLM/命令行参数.md).
diff --git a/README_CN.md b/README_CN.md
index 8359d30eaf..4927328963 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -60,7 +60,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 用户可以查看 [SWIFT官方文档](docs/source/GetStarted/快速使用.md) 来了解详细信息。
 
 ## 🎉 新闻
-- 2023.1.4: 支持**VLLM部署**, 兼容openai API样式, 具体可以查看[VLLM推理加速与部署](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md#部署).
+- 2023.1.4: 支持**VLLM部署**, 兼容**OpenAI API**样式, 具体可以查看[VLLM推理加速与部署](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md#部署).
 - 2023.1.4: 更新[Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md), 方便查看不同模型训练的速度和所需显存.
 - 🔥 2023.12.29: 支持web-ui进行sft训练和推理,安装ms-swift后使用`swift web-ui`开启
 - 🔥 2023.12.29: 支持 DPO RLHF(Reinforcement Learning from Human Feedback) 和两个用于此任务的数据集: AI-ModelScope/stack-exchange-paired 以及 AI-ModelScope/hh-rlhf. 使用[这个脚本](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh)开启训练!
@@ -111,7 +111,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 - 快速对LLM进行**推理**, 搭建**Web-UI**, 可以查看[LLM推理文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM推理文档.md).
 - 快速对LLM进行**微调**, 推理并搭建Web-UI. 可以查看[LLM微调文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md) 和 [WEB-UI文档](https://github.com/modelscope/swift/blob/main/docs/source/GetStarted/%E7%95%8C%E9%9D%A2%E8%AE%AD%E7%BB%83%E6%8E%A8%E7%90%86.md).
 - 支持**DPO训练**, 使用[这个脚本](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh)开启训练
-- 使用VLLM进行**推理加速**和**部署(openai API)**. 可以查看[VLLM推理加速与部署](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md).
+- 使用VLLM进行**推理加速**和**部署(OpenAI API)**. 可以查看[VLLM推理加速与部署](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md).
 - 查看swift支持的模型和数据集. 可以查看[支持的模型和数据集](https://github.com/modelscope/swift/blob/main/docs/source/LLM/支持的模型和数据集.md).
 - 对swift中的模型, 数据集, 对话模板进行**拓展**, 可以查看[自定义与拓展](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自定义与拓展.md).
 - 查询微调和推理的命令行参数, 可以查看[命令行参数](https://github.com/modelscope/swift/blob/main/docs/source/LLM/命令行参数.md).
diff --git "a/docs/source/LLM/VLLM\346\216\250\347\220\206\345\212\240\351\200\237\344\270\216\351\203\250\347\275\262.md" "b/docs/source/LLM/VLLM\346\216\250\347\220\206\345\212\240\351\200\237\344\270\216\351\203\250\347\275\262.md"
index 4ade27a857..17bf5f72d7 100644
--- "a/docs/source/LLM/VLLM\346\216\250\347\220\206\345\212\240\351\200\237\344\270\216\351\203\250\347\275\262.md"
+++ "b/docs/source/LLM/VLLM\346\216\250\347\220\206\345\212\240\351\200\237\344\270\216\351\203\250\347\275\262.md"
@@ -19,6 +19,7 @@ pip install -e .[llm]
 
 # vllm与cuda版本有对应关系,请按照`https://docs.vllm.ai/en/latest/getting_started/installation.html`选择版本
 pip install vllm -U
+pip install openai -U
 
 # 环境对齐 (如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试)
 pip install -r requirements/framework.txt -U
@@ -239,14 +240,50 @@ swift使用VLLM作为推理后端, 并兼容openai的API样式.
 客户端的openai的API参数可以参考: https://platform.openai.com/docs/api-reference/introduction.
 
 ### 原始模型
-**qwen-7b-chat**
+#### qwen-7b-chat
 
-服务端:
+**服务端:**
 ```bash
 CUDA_VISIBLE_DEVICES=0 swift deploy --model_type qwen-7b-chat
 ```
 
-客户端:
+**客户端:**
+
+使用swift:
+```python
+from swift.llm import get_model_list_client, XRequest, inference_client
+
+model_list = get_model_list_client()
+model_type = model_list.data[0].id
+print(f'model_type: {model_type}')
+
+query = '浙江的省会在哪里?'
+request_kwargs = XRequest(model=model_type, seed=42)
+resp = inference_client(query, request_kwargs=request_kwargs)
+response = resp.choices[0].message.content
+print(f'query: {query}')
+print(f'response: {response}')
+
+history = [(query, response)]
+query = '这有什么好吃的?'
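+# 流式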
+request_kwargs = XRequest(model=model_type, stream=True, seed=42)
+stream_resp = inference_client(query, history, request_kwargs=request_kwargs)
+print(f'query: {query}')
+print('response: ', end='')
+for chunk in stream_resp:
+    print(chunk.choices[0].delta.content, end='', flush=True)
+print()
+
+"""Out[0]
+model_type: qwen-7b-chat
+query: 浙江的省会在哪里?
+response: 浙江省的省会是杭州市。
+query: 这有什么好吃的?
+response: 杭州有许多美食,例如西湖醋鱼、东坡肉、龙井虾仁、叫化童子鸡等。此外,杭州还有许多特色小吃,如西湖藕粉、杭州小笼包、杭州油条等。
+"""
+```
+
+使用openai:
 ```python
 from openai import OpenAI
 client = OpenAI(
@@ -263,7 +300,8 @@ messages = [{
 }]
 resp = client.chat.completions.create(
     model=model_type,
-    messages=messages)
+    messages=messages,
+    seed=42)
 response = resp.choices[0].message.content
 print(f'query: {query}')
 print(f'response: {response}')
@@ -272,14 +310,15 @@ print(f'response: {response}')
 
 messages.append({'role': 'assistant', 'content': response})
 query = '这有什么好吃的?'
 messages.append({'role': 'user', 'content': query})
-stream = client.chat.completions.create(
+stream_resp = client.chat.completions.create(
     model=model_type,
     messages=messages,
-    stream=True)
+    stream=True,
+    seed=42)
 print(f'query: {query}')
 print('response: ', end='')
-for chunk in stream:
+for chunk in stream_resp:
     print(chunk.choices[0].delta.content, end='', flush=True)
 print()
 
@@ -288,19 +327,67 @@ model_type: qwen-7b-chat
 query: 浙江的省会在哪里?
 response: 浙江省的省会是杭州市。
 query: 这有什么好吃的?
-response:
-浙江省是一个美食天堂,有着丰富多样的美食,如新鲜海鲜、麻糍、竹筒饭、西湖醋鱼、小吃等。至于具体哪个更好吃,可能还要看您个人的口味。
+response: 杭州有许多美食,例如西湖醋鱼、东坡肉、龙井虾仁、叫化童子鸡等。此外,杭州还有许多特色小吃,如西湖藕粉、杭州小笼包、杭州油条等。
 """
 ```
 
-**qwen-7b**
+#### qwen-7b
 
-服务端:
+**服务端:**
 ```bash
 CUDA_VISIBLE_DEVICES=0 swift deploy --model_type qwen-7b
 ```
 
-客户端:
+**客户端:**
+
+使用swift:
+```python
+from swift.llm import get_model_list_client, XRequest, inference_client
+
+model_list = get_model_list_client()
+model_type = model_list.data[0].id
+print(f'model_type: {model_type}')
+
+query = '浙江 -> 杭州\n安徽 -> 合肥\n四川 ->'
+request_kwargs = XRequest(model=model_type, max_tokens=32, temperature=0.1, seed=42)
+resp = inference_client(query, request_kwargs=request_kwargs)
+response = resp.choices[0].text
+print(f'query: {query}')
+print(f'response: {response}')
+
+request_kwargs.stream = True
+stream_resp = inference_client(query, request_kwargs=request_kwargs)
+print(f'query: {query}')
+print('response: ', end='')
+for chunk in stream_resp:
+    print(chunk.choices[0].text, end='', flush=True)
+print()
+
+"""Out[0]
+model_type: qwen-7b
+query: 浙江 -> 杭州
+安徽 -> 合肥
+四川 ->
+response: 成都
+广东 -> 广州
+江苏 -> 南京
+浙江 -> 杭州
+安徽 -> 合肥
+四川 -> 成都
+
+query: 浙江 -> 杭州
+安徽 -> 合肥
+四川 ->
+response: 成都
+广东 -> 广州
+江苏 -> 南京
+浙江 -> 杭州
+安徽 -> 合肥
+四川 -> 成都
+"""
+```
+
+使用openai:
 ```python
 from openai import OpenAI
 client = OpenAI(
@@ -311,7 +398,7 @@ model_type = client.models.list().data[0].id
 print(f'model_type: {model_type}')
 
 query = '浙江 -> 杭州\n安徽 -> 合肥\n四川 ->'
-kwargs = {'model': model_type, 'prompt': query, 'seed': 42, 'temperature': 0., 'max_tokens': 32}
+kwargs = {'model': model_type, 'prompt': query, 'seed': 42, 'temperature': 0.1, 'max_tokens': 32}
 resp = client.completions.create(**kwargs)
 response = resp.choices[0].text
 
@@ -319,12 +406,11 @@ print(f'query: {query}')
 print(f'response: {response}')
 
 # 流式
-query = '浙江 -> 杭州\n安徽 -> 合肥\n四川 ->'
-stream = client.completions.create(stream=True, **kwargs)
+stream_resp = client.completions.create(stream=True, **kwargs)
 response = resp.choices[0].text
 print(f'query: {query}')
 print('response: ', end='')
-for chunk in stream:
+for chunk in stream_resp:
     print(chunk.choices[0].text, end='', flush=True)
 print()
 
@@ -360,4 +446,4 @@ swift merge-lora --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx'
 CUDA_VISIBLE_DEVICES=0 swift deploy --ckpt_dir 'xxx/vx_xxx/checkpoint-xxx-merged'
 ```
 
-客户端代码示例同原始模型.
+客户端示例代码同原始模型.
diff --git "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index f9ac794d35..6eca8fc968 100644
--- "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -9,7 +9,7 @@
 - `--model_id_or_path`: 表示模型在ModelScope Hub中的`model_id`, 不区分大小写, 默认为`None`. 如果`--model_id_or_path`未被注册, 则会抛出异常. 你可以使用`model_type`的方式指定模型类型, 也可以通过`model_id_or_path`的方式指定模型类型.
 - `--model_revision`: 表示模型在ModelScope Hub中对应`model_id`的版本号, 默认为`None`. `model_revision`指定为`None`, 则使用注册在`MODEL_MAPPING`中的revision. 否则强制使用命令行传入的`model_revision`.
 - `--model_cache_dir`: 默认为`None`. 如果模型在本地已经有缓存, 且缓存路径并非ModelScope默认cache路径, 可以通过指定该参数从cache_dir中导入model和tokenizer.
-- `--sft_type`: 表示微调的方式, 默认是`'lora'`. 你可以选择的值包括: 'lora', 'full'. 如果你要使用qlora, 你需设置`--sft_type lora --quantization_bit 4`.
+- `--sft_type`: 表示微调的方式, 默认是`'lora'`. 你可以选择的值包括: 'lora', 'full', 'longlora', 'qalora'. 如果你要使用qlora, 你需设置`--sft_type lora --quantization_bit 4`.
 - `--freeze_parameters`: 当sft_type指定为'full'时, 将模型最底部的参数进行freeze. 指定范围为0. ~ 1., 默认为`0.`. 该参数提供了lora与全参数微调的折中方案.
 - `--tuner_backend`: 表示lora, qlora的后端支持, 默认是`'swift'`. 你可以选择的值包括: 'swift', 'peft'.
 - `--template_type`: 表示使用的对话模板的类型, 默认是`'AUTO'`, 即根据`model_type`查找`MODEL_MAPPING`中的`template`. 可以选择的`template_type`可以查看`TEMPLATE_MAPPING.keys()`.
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 499594b0ee..ca0c0fd909 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -1,4 +1,5 @@
 accelerate
+dacite
 datasets
 jieba
 matplotlib
diff --git a/requirements/llm.txt b/requirements/llm.txt
index c5b0c9fc3a..0aab5201c6 100644
--- a/requirements/llm.txt
+++ b/requirements/llm.txt
@@ -1,6 +1,8 @@
 charset_normalizer
 cpm_kernels
+fastapi
 gradio>=3.40.0
 sentencepiece
 tiktoken
 transformers_stream_generator
+uvicorn
diff --git a/swift/llm/deploy.py b/swift/llm/deploy.py
index 8f603e3fcb..313a4e374a 100644
--- a/swift/llm/deploy.py
+++ b/swift/llm/deploy.py
@@ -217,7 +217,7 @@ async def _generate_stream():
                 usage=usage_info,
                 id=request_id,
                 created=created_time)
-            yield f'data:{json.dumps(asdict(response))}\n\n'
+            yield f'data:{json.dumps(asdict(response), ensure_ascii=False)}\n\n'
         yield 'data:[DONE]\n\n'
 
     if request.stream:
diff --git a/swift/llm/tuner.py b/swift/llm/tuner.py
index 7faf7eb3e8..4a377ccabd 100644
--- a/swift/llm/tuner.py
+++ b/swift/llm/tuner.py
@@ -5,14 +5,14 @@
 from swift.tuners import (LongLoRAConfig, LongLoRAModelType, LoraConfig,
                           LoRAConfig, NEFTuneConfig, Swift)
 from swift.utils import freeze_model_parameters, get_logger
-from .utils import SftArguments, find_all_linear_for_lora
+from .utils import SftArguments, find_all_linear_for_lora, is_lora
 
 logger = get_logger()
 
 
 def prepare_model(model, args: SftArguments):
     # Preparing LoRA
-    if args.sft_type in ('lora', 'qalora', 'longlora'):
+    if is_lora(args.sft_type):
         if args.resume_from_checkpoint is None:
             if 'ALL' in args.lora_target_modules:
                 assert len(args.lora_target_modules) == 1
@@ -20,13 +20,13 @@ def prepare_model(model, args: SftArguments):
                     model, args.quantization_bit, args.model_type)
                 logger.info(
                     f'Setting lora_target_modules: {args.lora_target_modules}')
+            lora_kwargs = {
+                'r': args.lora_rank,
+                'target_modules': args.lora_target_modules,
+                'lora_alpha': args.lora_alpha,
+                'lora_dropout': args.lora_dropout_p
+            }
             if args.sft_type == 'lora':
-                lora_kwargs = {
-                    'r': args.lora_rank,
-                    'target_modules': args.lora_target_modules,
-                    'lora_alpha': args.lora_alpha,
-                    'lora_dropout': args.lora_dropout_p
-                }
                 if args.tuner_backend == 'swift':
                     lora_config = LoRAConfig(
                         lora_dtype=args.lora_dtype, **lora_kwargs)
@@ -36,35 +36,26 @@ def prepare_model(model, args: SftArguments):
                 model = Swift.prepare_model(model, lora_config)
                 logger.info(f'lora_config: {lora_config}')
             elif args.sft_type == 'longlora':
-                assert args.tuner_backend != 'peft', (
-                    'peft does not support longlora. You need to set `--tuner_backend swift`.'
-                )
+                assert args.tuner_backend == 'swift'
                 assert LongLoRAModelType.LLAMA in args.model_type
                 longlora_config = LongLoRAConfig(
-                    r=args.lora_rank,
-                    target_modules=args.lora_target_modules,
-                    lora_alpha=args.lora_alpha,
-                    lora_dropout=args.lora_dropout_p,
                     lora_dtype=args.lora_dtype,
                     model_type=LongLoRAModelType.LLAMA,
-                    use_flash_attn=args.use_flash_attn)
+                    use_flash_attn=args.use_flash_attn,
+                    **lora_kwargs)
                 model = Swift.prepare_model(model, longlora_config)
                 logger.info(f'longlora_config: {longlora_config}')
             elif args.sft_type == 'qalora':
+                assert args.tuner_backend == 'swift'
                 assert getattr(
                     model, 'quantization_method',
                     None) == 'gptq', 'qalora must be used with auto_gptq'
-                lora_kwargs = {}
-                lora_config = LoRAConfig(
-                    r=args.lora_rank,
-                    target_modules=args.lora_target_modules,
-                    lora_alpha=args.lora_alpha,
-                    lora_dropout=args.lora_dropout_p,
+                qalora_config = LoRAConfig(
                     lora_dtype=args.lora_dtype,
                     use_qa_lora=True,
                     **lora_kwargs)
-                model = Swift.prepare_model(model, lora_config)
-                logger.info(f'lora_config: {lora_config}')
+                model = Swift.prepare_model(model, qalora_config)
+                logger.info(f'qalora_config: {qalora_config}')
         else:
             model = Swift.from_pretrained(
                 model, args.resume_from_checkpoint, is_trainable=True)
diff --git a/swift/llm/utils/__init__.py b/swift/llm/utils/__init__.py
index 9dabbdfd16..273b49f58a 100644
--- a/swift/llm/utils/__init__.py
+++ b/swift/llm/utils/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .argument import (DeployArguments, DPOArguments, InferArguments,
-                       RomeArguments, SftArguments)
+                       RomeArguments, SftArguments, is_lora)
+from .client_utils import get_model_list_client, inference_client
 from .dataset import (DATASET_MAPPING, DatasetName, GetDatasetFunction,
                       HfDataset, add_self_cognition_dataset, get_dataset,
                       get_dataset_from_repo, load_dataset_from_local,
@@ -23,9 +24,10 @@
                        CompletionResponseChoice,
                        CompletionResponseStreamChoice,
                        CompletionStreamResponse, DeltaMessage, Model,
-                       ModelList, UsageInfo, random_uuid)
+                       ModelList, UsageInfo, XRequest, random_uuid)
 from .template import (DEFAULT_SYSTEM, TEMPLATE_MAPPING, History, Prompt,
-                       Template, TemplateType, get_template, register_template)
+                       StopWords, Template, TemplateType, get_template,
+                       register_template)
 from .utils import (LazyLLMDataset, LLMDataset, data_collate_fn, dataset_map,
                     download_dataset, find_all_linear_for_lora, get_time_info,
                     history_to_messages, inference, inference_stream,
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index ecfe94c0fe..bf0047c04a 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -23,6 +23,10 @@
 logger = get_logger()
 
 
+def is_lora(sft_type: str) -> bool:
+    return sft_type in {'lora', 'longlora', 'qalora'}
+
+
 @dataclass
 class SftArguments:
     # You can specify the model by either using the model_type or model_id_or_path.
@@ -33,7 +37,7 @@ class SftArguments:
     model_revision: Optional[str] = None
     model_cache_dir: Optional[str] = None
 
-    sft_type: Literal['lora', 'longlora', 'qalora', 'full'] = 'lora'
+    sft_type: Literal['lora', 'full', 'longlora', 'qalora'] = 'lora'
     freeze_parameters: float = 0.  # 0 ~ 1
     tuner_backend: Literal['swift', 'peft'] = 'swift'
     template_type: str = field(
@@ -79,7 +83,7 @@ class SftArguments:
     bnb_4bit_comp_dtype: Literal['fp16', 'bf16', 'fp32', 'AUTO'] = 'AUTO'
     bnb_4bit_quant_type: Literal['fp4', 'nf4'] = 'nf4'
     bnb_4bit_use_double_quant: bool = True
-
+    # lora
     lora_target_modules: List[str] = field(default_factory=lambda: ['DEFAULT'])
     lora_rank: int = 8
     lora_alpha: int = 32
@@ -200,7 +204,7 @@ def __post_init__(self) -> None:
         self.output_dir = add_version_to_work_dir(self.output_dir)
         logger.info(f'output_dir: {self.output_dir}')
 
-        if self.sft_type in ('lora', 'longlora', 'qalora'):
+        if is_lora(self.sft_type):
             assert self.freeze_parameters == 0., (
                 'lora does not support `freeze_parameters`, please set `--sft_type full`'
             )
diff --git a/swift/llm/utils/client_utils.py b/swift/llm/utils/client_utils.py
new file mode 100644
index 0000000000..320a277eee
--- /dev/null
+++ b/swift/llm/utils/client_utils.py
@@ -0,0 +1,88 @@
+from typing import Iterator, Optional, Union
+
+import json
+import requests
+from dacite import from_dict
+from requests.exceptions import HTTPError
+
+from .model import get_default_template_type
+from .protocol import (ChatCompletionResponse, ChatCompletionStreamResponse,
+                       CompletionResponse, CompletionStreamResponse, ModelList,
+                       XRequest)
+from .template import History
+from .utils import history_to_messages
+
+
+def get_model_list_client(host: str = '127.0.0.1',
+                          port: str = '8000') -> ModelList:
+    url = f'http://{host}:{port}/v1/models'
+    resp_obj = requests.get(url).json()
+    return from_dict(ModelList, resp_obj)
+
+
+def _parse_stream_data(data: bytes) -> Optional[str]:
+    data = data.decode(encoding='utf-8')
+    data = data.strip()
+    if len(data) == 0:
+        return
+    assert data.startswith('data:')
+    return data[5:].strip()
+
+
+def inference_client(
+    query: str,
+    history: Optional[History] = None,
+    system: Optional[str] = None,
+    *,
+    request_kwargs: Optional[XRequest],
+    host: str = '127.0.0.1',
+    port: str = '8000',
+    is_chat_request: Optional[bool] = None,
+) -> Union[ChatCompletionResponse, CompletionResponse,
+           Iterator[ChatCompletionStreamResponse],
+           Iterator[CompletionStreamResponse]]:
+    if is_chat_request is None:
+        template_type = get_default_template_type(request_kwargs.model)
+        is_chat_request = 'generation' not in template_type
+    data = {
+        k: v
+        for k, v in request_kwargs.__dict__.items() if not k.startswith('__')
+    }
+    if is_chat_request:
+        data['messages'] = history_to_messages(history, query, system)
+        url = f'http://{host}:{port}/v1/chat/completions'
+    else:
+        assert system is None and history is None, (
+            'The chat template for text generation does not support system and history.'
+        )
+        data['prompt'] = query
+        url = f'http://{host}:{port}/v1/completions'
+    if request_kwargs.stream:
+        if is_chat_request:
+            ret_cls = ChatCompletionStreamResponse
+        else:
+            ret_cls = CompletionStreamResponse
+        resp = requests.post(url, json=data, stream=True)
+
+        def _gen_stream() -> Union[Iterator[ChatCompletionStreamResponse],
+                                   Iterator[CompletionStreamResponse]]:
+            for data in resp.iter_lines():
+                data = _parse_stream_data(data)
+                if data == '[DONE]':
+                    break
+                if data is not None:
+                    resp_obj = json.loads(data)
+                    if resp_obj['object'] == 'error':
+                        raise HTTPError(resp_obj['message'])
+                    yield from_dict(ret_cls, resp_obj)
+
+        return _gen_stream()
+    else:
+        resp_obj = requests.post(url, json=data).json()
+        if is_chat_request:
+            ret_cls = ChatCompletionResponse
+        else:
+            ret_cls = CompletionResponse
+        if resp_obj['object'] == 'error':
+            raise HTTPError(resp_obj['message'])
+        return from_dict(ret_cls, resp_obj)
diff --git a/swift/llm/utils/protocol.py b/swift/llm/utils/protocol.py
index 22d3f940a9..fb93599ee8 100644
--- a/swift/llm/utils/protocol.py
+++ b/swift/llm/utils/protocol.py
@@ -12,6 +12,9 @@ def random_uuid() -> str:
 @dataclass
 class Model:
     id: str  # model_type
+    object: str = 'model'
+    created: int = field(default_factory=lambda: int(time.time()))
+    owned_by: str = 'swift'
 
 
 @dataclass
@@ -127,7 +130,7 @@ class DeltaMessage:
 class ChatCompletionResponseStreamChoice:
     index: int
     delta: DeltaMessage
-    finish_reason: Literal['stop', 'length']
+    finish_reason: Literal['stop', 'length', None]
 
 
 @dataclass
@@ -144,7 +147,7 @@ class ChatCompletionStreamResponse:
 class CompletionResponseStreamChoice:
     index: int
     text: str
-    finish_reason: Literal['stop', 'length']
+    finish_reason: Literal['stop', 'length', None]
 
 
 @dataclass
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 872cc3a37a..4845cf0fe4 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -649,13 +649,16 @@ def compute_token_length(history_length: int) -> int:
 Messages = List[Dict[str, str]]
 
 
-def history_to_messages(history: History,
+def history_to_messages(history: Optional[History],
                         query: Optional[str] = None,
                         system: Optional[str] = None) -> Messages:
+    if history is None:
+        history = []
     messages = []
     if system is not None:
         messages.append({'role': 'system', 'content': system})
     for h in history:
+        assert isinstance(h, (list, tuple))
         messages.append({'role': 'user', 'content': h[0]})
         messages.append({'role': 'assistant', 'content': h[1]})
     if query is not None: