1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -574,6 +574,7 @@ soft overlong reward parameters
- write_batch_size: The batch size for writing results to `result_path`. Defaults to 1000. If set to -1, there is no restriction.
- metric: Evaluates the inference results; currently supports 'acc' and 'rouge'. Defaults to None, meaning no evaluation is performed.
- val_dataset_sample: Number of samples drawn from the inference dataset. Defaults to None.
- reranker_use_activation: Whether to apply a sigmoid after the score is computed. Defaults to True.


### Deployment Arguments
3 changes: 3 additions & 0 deletions docs/source/Instruction/支持的模型和数据集.md
@@ -242,6 +242,9 @@
|[Qwen/Qwen3-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-Reranker-8B)|qwen3_reranker|qwen3_reranker|-|✘|-|[Qwen/Qwen3-Reranker-8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B)|
|[iic/gte_Qwen2-1.5B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-1.5B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)|
|[iic/gte_Qwen2-7B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-7B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-7B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)|
|[BAAI/bge-reranker-base](https://modelscope.cn/models/BAAI/bge-reranker-base)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)|
|[BAAI/bge-reranker-v2-m3](https://modelscope.cn/models/BAAI/bge-reranker-v2-m3)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)|
|[BAAI/bge-reranker-large](https://modelscope.cn/models/BAAI/bge-reranker-large)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)|
|[codefuse-ai/CodeFuse-QWen-14B](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B)|codefuse_qwen|codefuse|-|✘|coding|[codefuse-ai/CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B)|
|[iic/ModelScope-Agent-7B](https://modelscope.cn/models/iic/ModelScope-Agent-7B)|modelscope_agent|modelscope_agent|-|✘|-|-|
|[iic/ModelScope-Agent-14B](https://modelscope.cn/models/iic/ModelScope-Agent-14B)|modelscope_agent|modelscope_agent|-|✘|-|-|
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -594,6 +594,7 @@ Inference arguments include the [base arguments](#base-arguments), [merge argume
- write_batch_size: The batch size for writing results to result_path. Defaults to 1000. If set to -1, there is no restriction.
- metric: Evaluates the inference results; currently supports 'acc' and 'rouge'. The default is None, meaning no evaluation is performed.
- val_dataset_sample: Number of samples from the inference dataset, default is None.
- reranker_use_activation: Whether to apply a sigmoid to the reranker score. Defaults to True (see the sketch below).
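
A minimal sketch of what this flag controls, assuming the reranker emits one raw relevance logit per query-document pair (illustrative only; not the library's exact code):

```python
import torch

raw_scores = torch.tensor([2.3, -1.1])     # hypothetical raw reranker logits
print(torch.sigmoid(raw_scores).tolist())  # reranker_use_activation=True: scores squashed into (0, 1)
print(raw_scores.tolist())                 # reranker_use_activation=False: raw logits are returned
```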

### Deployment Arguments

3 changes: 3 additions & 0 deletions docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -242,6 +242,9 @@ The table below introduces the models integrated with ms-swift:
|[Qwen/Qwen3-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-Reranker-8B)|qwen3_reranker|qwen3_reranker|-|✘|-|[Qwen/Qwen3-Reranker-8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B)|
|[iic/gte_Qwen2-1.5B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-1.5B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)|
|[iic/gte_Qwen2-7B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-7B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-7B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)|
|[BAAI/bge-reranker-base](https://modelscope.cn/models/BAAI/bge-reranker-base)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)|
|[BAAI/bge-reranker-v2-m3](https://modelscope.cn/models/BAAI/bge-reranker-v2-m3)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)|
|[BAAI/bge-reranker-large](https://modelscope.cn/models/BAAI/bge-reranker-large)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)|
|[codefuse-ai/CodeFuse-QWen-14B](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B)|codefuse_qwen|codefuse|-|✘|coding|[codefuse-ai/CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B)|
|[iic/ModelScope-Agent-7B](https://modelscope.cn/models/iic/ModelScope-Agent-7B)|modelscope_agent|modelscope_agent|-|✘|-|-|
|[iic/ModelScope-Agent-14B](https://modelscope.cn/models/iic/ModelScope-Agent-14B)|modelscope_agent|modelscope_agent|-|✘|-|-|
47 changes: 47 additions & 0 deletions examples/deploy/reranker/client.py
@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from openai import OpenAI

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def infer(client, model: str, messages):
    resp = client.chat.completions.create(model=model, messages=messages)
    scores = resp.choices[0].message.content
    print(f'messages: {messages}')
    print(f'scores: {scores}')
    return scores


def run_client(host: str = '127.0.0.1', port: int = 8000):
    client = OpenAI(
        api_key='EMPTY',
        base_url=f'http://{host}:{port}/v1',
    )
    model = client.models.list().data[0].id
    print(f'model: {model}')

    messages = [{
        'role': 'user',
        'content': 'what is the capital of China?',
    }, {
        'role': 'assistant',
        'content': 'Beijing',
    }]
    infer(client, model, messages)


if __name__ == '__main__':
    from swift.llm import run_deploy, DeployArguments
    with run_deploy(
            DeployArguments(
                model='BAAI/bge-reranker-v2-m3',
                task_type='reranker',
                infer_backend='vllm',
                gpu_memory_utilization=0.7,
                vllm_enforce_eager=True,
                reranker_use_activation=False,
                verbose=False,
                log_interval=-1)) as port:
        run_client(port=port)
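
In this client, the user/assistant message pair encodes a (query, candidate document) pair for the reranker to score. A hypothetical extension under the same assumed API, reusing `infer` from the example to score several candidates against one query (`rank_documents` is illustrative, not part of the PR):

```python
def rank_documents(client, model: str, query: str, documents: list):
    # One request per (query, document) pair; collect (document, score) tuples.
    scored = []
    for doc in documents:
        messages = [{'role': 'user', 'content': query}, {'role': 'assistant', 'content': doc}]
        scored.append((doc, infer(client, model, messages)))
    return scored
```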
45 changes: 45 additions & 0 deletions examples/deploy/reranker/client_generative.py
@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from openai import OpenAI

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def infer(client, model: str, messages):
    resp = client.chat.completions.create(model=model, messages=messages)
    scores = resp.choices[0].message.content
    print(f'messages: {messages}')
    print(f'scores: {scores}')
    return scores


def run_client(host: str = '127.0.0.1', port: int = 8000):
    client = OpenAI(
        api_key='EMPTY',
        base_url=f'http://{host}:{port}/v1',
    )
    model = client.models.list().data[0].id
    print(f'model: {model}')

    messages = [{
        'role': 'user',
        'content': 'what is the capital of China?',
    }, {
        'role': 'assistant',
        'content': 'Beijing.',
    }]
    infer(client, model, messages)


if __name__ == '__main__':
    from swift.llm import run_deploy, DeployArguments
    with run_deploy(
            DeployArguments(
                model='Qwen/Qwen3-Reranker-0.6B',
                task_type='generative_reranker',
                infer_backend='vllm',
                gpu_memory_utilization=0.7,
                verbose=False,
                log_interval=-1)) as port:
        run_client(port=port)
9 changes: 9 additions & 0 deletions examples/deploy/reranker/server.sh
@@ -0,0 +1,9 @@
# GME/GTE models or your checkpoints are also supported
# pt/vllm/sglang supported
CUDA_VISIBLE_DEVICES=0 swift deploy \
--host 0.0.0.0 \
--port 8000 \
--model BAAI/bge-reranker-v2-m3 \
--infer_backend vllm \
--task_type reranker \
--vllm_enforce_eager true
45 changes: 45 additions & 0 deletions examples/deploy/seq_cls/client.py
@@ -0,0 +1,45 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from openai import OpenAI

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def infer(client, model: str, messages):
    resp = client.chat.completions.create(model=model, messages=messages)
    classify = resp.choices[0].message.content
    print(f'messages: {messages}')
    print(f'classify: {classify}')
    return classify


def run_client(host: str = '127.0.0.1', port: int = 8000):
    client = OpenAI(
        api_key='EMPTY',
        base_url=f'http://{host}:{port}/v1',
    )
    model = client.models.list().data[0].id
    print(f'model: {model}')

    messages = [{
        'role': 'user',
        'content': 'What is the capital of China?',
    }, {
        'role': 'assistant',
        'content': 'Beijing',
    }]
    infer(client, model, messages)


if __name__ == '__main__':
    from swift.llm import run_deploy, DeployArguments
    with run_deploy(
            DeployArguments(
                model='/your/seq_cls/checkpoint-xxx',
                task_type='seq_cls',
                infer_backend='vllm',
                num_labels=2,
                verbose=False,
                log_interval=-1)) as port:
        run_client(port=port)
9 changes: 9 additions & 0 deletions examples/deploy/seq_cls/server.sh
@@ -0,0 +1,9 @@
# GME/GTE models or your checkpoints are also supported
# pt/vllm/sglang supported
CUDA_VISIBLE_DEVICES=0 swift deploy \
--host 0.0.0.0 \
--port 8000 \
--model /your/seq_cls/checkpoint-xxx \
--infer_backend vllm \
--task_type seq_cls \
--num_labels 2
4 changes: 4 additions & 0 deletions swift/llm/argument/infer_args.py
@@ -95,6 +95,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg
        result_path (Optional[str]): Directory to store inference results. Default is None.
        max_batch_size (int): Maximum batch size for the pt engine. Default is 1.
        val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None.
        reranker_use_activation (bool): Whether to apply a sigmoid activation to reranker scores. Default is True.
    """
    infer_backend: Literal['vllm', 'pt', 'sglang', 'lmdeploy'] = 'pt'

@@ -107,6 +108,9 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg
    # only for inference
    val_dataset_sample: Optional[int] = None

    # for reranker
    reranker_use_activation: bool = True

    def _get_result_path(self, folder_name: str) -> str:
        result_dir = self.ckpt_dir or f'result/{self.model_suffix}'
        os.makedirs(result_dir, exist_ok=True)
2 changes: 1 addition & 1 deletion swift/llm/infer/deploy.py
@@ -116,7 +116,7 @@ def _post_process(self, request_info, response, return_cmpl_response: bool = Fal
                               (tuple, list)):
                continue
            for j, content in enumerate(response.choices[i].message.content):
                if content['type'] == 'image':
                if isinstance(content, dict) and content['type'] == 'image':
                    b64_image = MultiModalRequestMixin.to_base64(content['image'])
                    response.choices[i].message.content[j]['image'] = f'data:image/jpg;base64,{b64_image}'

2 changes: 2 additions & 0 deletions swift/llm/infer/infer.py
@@ -32,6 +32,7 @@ def __init__(self, args: Optional[Union[List[str], InferArguments]] = None) -> N
        if args.infer_backend == 'pt':
            model, self.template = prepare_model_template(args)
            self.infer_engine = PtEngine.from_model_template(model, self.template, max_batch_size=args.max_batch_size)
            self.infer_engine.reranker_use_activation = args.reranker_use_activation
            logger.info(f'model: {self.infer_engine.model}')
        else:
            self.template = args.get_template(None)
@@ -54,6 +55,7 @@ def get_infer_engine(args: InferArguments, template=None, **kwargs):
            'revision': args.model_revision,
            'torch_dtype': args.torch_dtype,
            'template': template,
            'reranker_use_activation': args.reranker_use_activation,
        })
        infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend
        if infer_backend == 'pt':
32 changes: 30 additions & 2 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -2,6 +2,7 @@
import asyncio
import hashlib
import inspect
import os
import pickle
import time
from copy import deepcopy
@@ -11,6 +12,7 @@

import json
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import GenerationConfig, LogitsProcessorList
@@ -76,6 +78,7 @@ def __init__(
            task_type=task_type,
            model_kwargs=model_kwargs,
            **kwargs)
        self.reranker_use_activation = kwargs.pop('reranker_use_activation', True)
        self.max_batch_size = max_batch_size
        if isinstance(adapters, str):
            adapters = [adapters]
@@ -327,6 +330,9 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
        elif 'last_hidden_state' in output:
            # embeddings
            logits = output['last_hidden_state']
        else:
            raise NotImplementedError('Only support `logits` or `hidden_state` in output.')

        if template.task_type == 'seq_cls':
            preds, logprobs = template.decode_seq_cls(logits, top_logprobs)
        elif template.task_type == 'prm':
@@ -335,6 +341,27 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req
        elif template.task_type == 'embedding':
            preds = logits
            logprobs = [None] * len(preds)
        elif template.task_type in ('reranker', 'generative_reranker'):
            if template.task_type == 'generative_reranker':
                # Qwen3-reranker like
                positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
                negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
                token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token)
                token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token)
                batch_scores = logits[:, -1, :]
                true_vector = batch_scores[:, token_true_id]
                false_vector = batch_scores[:, token_false_id]
                batch_scores = torch.stack([false_vector, true_vector], dim=1)
                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
                preds = batch_scores[:, 1].exp()
            else:
                preds = logits
                if self.reranker_use_activation:
                    preds = F.sigmoid(preds)
            preds = preds.tolist()
            if not isinstance(preds[0], list):
                preds = [preds]
            logprobs = [None] * len(preds)
        else:
            raise ValueError(f'Unsupported task_type: {template.task_type}')

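The generative branch above converts next-token logits into a relevance score: it reads the logits of the 'yes' and 'no' tokens at the last position and takes the softmax probability of 'yes'. A standalone sketch of that calculation with hypothetical logit values:

```python
import torch

pair = torch.tensor([1.2, 3.4])  # hypothetical [logit('no'), logit('yes')] at the last position
# log-softmax over the pair, then exp of the 'yes' entry == softmax P('yes')
score = torch.log_softmax(pair, dim=0)[1].exp()
print(score.item())  # relevance score in (0, 1); ~0.90 for these values
```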
@@ -521,8 +548,9 @@ def _gen_wrapper():
            return _gen_wrapper()
        else:
            if len(kwargs) > 0:
                infer_func = self._infer_forward if template.task_type in {'seq_cls', 'prm', 'embedding'
                                                                           } else self._infer_full
                infer_func = self._infer_forward if template.task_type in {
                    'seq_cls', 'prm', 'embedding', 'reranker', 'generative_reranker'
                } else self._infer_full
                res = infer_func(**kwargs)
            else:
                res = []