diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ada07de804..914b5defbb 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -574,6 +574,7 @@ soft overlong 奖励参数 - write_batch_size: 结果写入`result_path`的batch_size。默认为1000。若设置为-1,则不受限制。 - metric: 对推理的结果进行评估,目前支持'acc'和'rouge'。默认为None,即不进行评估。 - val_dataset_sample: 推理数据集采样数,默认为None。 +- reranker_use_activation: 是否在score之后使用sigmoid,默认为True。 ### 部署参数 diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index e16e54d522..d4f880342a 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -242,6 +242,9 @@ |[Qwen/Qwen3-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-Reranker-8B)|qwen3_reranker|qwen3_reranker|-|✘|-|[Qwen/Qwen3-Reranker-8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B)| |[iic/gte_Qwen2-1.5B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-1.5B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)| |[iic/gte_Qwen2-7B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-7B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-7B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)| +|[BAAI/bge-reranker-base](https://modelscope.cn/models/BAAI/bge-reranker-base)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)| +|[BAAI/bge-reranker-v2-m3](https://modelscope.cn/models/BAAI/bge-reranker-v2-m3)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)| +|[BAAI/bge-reranker-large](https://modelscope.cn/models/BAAI/bge-reranker-large)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)| |[codefuse-ai/CodeFuse-QWen-14B](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B)|codefuse_qwen|codefuse|-|✘|coding|[codefuse-ai/CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B)| |[iic/ModelScope-Agent-7B](https://modelscope.cn/models/iic/ModelScope-Agent-7B)|modelscope_agent|modelscope_agent|-|✘|-|-| |[iic/ModelScope-Agent-14B](https://modelscope.cn/models/iic/ModelScope-Agent-14B)|modelscope_agent|modelscope_agent|-|✘|-|-| diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index aec9572478..6de9b94f5a 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -594,6 +594,7 @@ Inference arguments include the [base arguments](#base-arguments), [merge argume - write_batch_size: The batch size for writing results to result_path. Defaults to 1000. If set to -1, there is no restriction. - metric: Evaluate the results of the inference, currently supporting 'acc' and 'rouge'. 
The default is None, meaning no evaluation is performed. - val_dataset_sample: Number of samples from the inference dataset, default is None. +- reranker_use_activation: Use sigmoid after reranker score, default True. ### Deployment Arguments diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 22e07e8693..273a1d9a45 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -242,6 +242,9 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen3-Reranker-8B](https://modelscope.cn/models/Qwen/Qwen3-Reranker-8B)|qwen3_reranker|qwen3_reranker|-|✘|-|[Qwen/Qwen3-Reranker-8B](https://huggingface.co/Qwen/Qwen3-Reranker-8B)| |[iic/gte_Qwen2-1.5B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-1.5B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)| |[iic/gte_Qwen2-7B-instruct](https://modelscope.cn/models/iic/gte_Qwen2-7B-instruct)|qwen2_gte|dummy|-|✘|-|[Alibaba-NLP/gte-Qwen2-7B-instruct](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)| +|[BAAI/bge-reranker-base](https://modelscope.cn/models/BAAI/bge-reranker-base)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)| +|[BAAI/bge-reranker-v2-m3](https://modelscope.cn/models/BAAI/bge-reranker-v2-m3)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)| +|[BAAI/bge-reranker-large](https://modelscope.cn/models/BAAI/bge-reranker-large)|bge_reranker|bge_reranker|-|✘|-|[BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large)| |[codefuse-ai/CodeFuse-QWen-14B](https://modelscope.cn/models/codefuse-ai/CodeFuse-QWen-14B)|codefuse_qwen|codefuse|-|✘|coding|[codefuse-ai/CodeFuse-QWen-14B](https://huggingface.co/codefuse-ai/CodeFuse-QWen-14B)| |[iic/ModelScope-Agent-7B](https://modelscope.cn/models/iic/ModelScope-Agent-7B)|modelscope_agent|modelscope_agent|-|✘|-|-| |[iic/ModelScope-Agent-14B](https://modelscope.cn/models/iic/ModelScope-Agent-14B)|modelscope_agent|modelscope_agent|-|✘|-|-| diff --git a/examples/deploy/reranker/client.py b/examples/deploy/reranker/client.py new file mode 100644 index 0000000000..a9c0fd655f --- /dev/null +++ b/examples/deploy/reranker/client.py @@ -0,0 +1,47 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
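+# Example client: sends the (query, document) pair as user/assistant messages to the
+# OpenAI-compatible /v1/chat/completions route of a deployed reranker; the response
+# content carries the relevance score(s).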
+import os + +from openai import OpenAI + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def infer(client, model: str, messages): + resp = client.chat.completions.create(model=model, messages=messages) + scores = resp.choices[0].message.content + print(f'messages: {messages}') + print(f'scores: {scores}') + return scores + + +def run_client(host: str = '127.0.0.1', port: int = 8000): + client = OpenAI( + api_key='EMPTY', + base_url=f'http://{host}:{port}/v1', + ) + model = client.models.list().data[0].id + print(f'model: {model}') + + messages = [{ + 'role': 'user', + 'content': 'what is the capital of China?', + }, { + 'role': 'assistant', + 'content': 'Beijing', + }] + infer(client, model, messages) + + +if __name__ == '__main__': + from swift.llm import run_deploy, DeployArguments + with run_deploy( + DeployArguments( + model='BAAI/bge-reranker-v2-m3', + task_type='reranker', + infer_backend='vllm', + gpu_memory_utilization=0.7, + vllm_enforce_eager=True, + reranker_use_activation=False, + verbose=False, + log_interval=-1)) as port: + run_client(port=port) diff --git a/examples/deploy/reranker/client_generative.py b/examples/deploy/reranker/client_generative.py new file mode 100644 index 0000000000..6c4c79e5e6 --- /dev/null +++ b/examples/deploy/reranker/client_generative.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +from openai import OpenAI + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def infer(client, model: str, messages): + resp = client.chat.completions.create(model=model, messages=messages) + scores = resp.choices[0].message.content + print(f'messages: {messages}') + print(f'scores: {scores}') + return scores + + +def run_client(host: str = '127.0.0.1', port: int = 8000): + client = OpenAI( + api_key='EMPTY', + base_url=f'http://{host}:{port}/v1', + ) + model = client.models.list().data[0].id + print(f'model: {model}') + + messages = [{ + 'role': 'user', + 'content': 'what is the capital of China?', + }, { + 'role': 'assistant', + 'content': 'Beijing.', + }] + infer(client, model, messages) + + +if __name__ == '__main__': + from swift.llm import run_deploy, DeployArguments + with run_deploy( + DeployArguments( + model='Qwen/Qwen3-Reranker-0.6B', + task_type='generative_reranker', + infer_backend='vllm', + gpu_memory_utilization=0.7, + verbose=False, + log_interval=-1)) as port: + run_client(port=port) diff --git a/examples/deploy/reranker/server.sh b/examples/deploy/reranker/server.sh new file mode 100644 index 0000000000..a61b6dd9b4 --- /dev/null +++ b/examples/deploy/reranker/server.sh @@ -0,0 +1,9 @@ +# GME/GTE models or your checkpoints are also supported +# pt/vllm/sglang supported +CUDA_VISIBLE_DEVICES=0 swift deploy \ + --host 0.0.0.0 \ + --port 8000 \ + --model BAAI/bge-reranker-v2-m3 \ + --infer_backend vllm \ + --task_type reranker \ + --vllm_enforce_eager true \ diff --git a/examples/deploy/seq_cls/client.py b/examples/deploy/seq_cls/client.py new file mode 100644 index 0000000000..e26f2d4671 --- /dev/null +++ b/examples/deploy/seq_cls/client.py @@ -0,0 +1,45 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
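+# Example client: queries a deployed sequence-classification checkpoint through the
+# OpenAI-compatible /v1/chat/completions route; the response content is the predicted label.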
+import os + +from openai import OpenAI + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def infer(client, model: str, messages): + resp = client.chat.completions.create(model=model, messages=messages) + classify = resp.choices[0].message.content + print(f'messages: {messages}') + print(f'classify: {classify}') + return classify + + +def run_client(host: str = '127.0.0.1', port: int = 8000): + client = OpenAI( + api_key='EMPTY', + base_url=f'http://{host}:{port}/v1', + ) + model = client.models.list().data[0].id + print(f'model: {model}') + + messages = [{ + 'role': 'user', + 'content': 'What is the capital of China?', + }, { + 'role': 'assistant', + 'content': 'Beijing', + }] + infer(client, model, messages) + + +if __name__ == '__main__': + from swift.llm import run_deploy, DeployArguments + with run_deploy( + DeployArguments( + model='/your/seq_cls/checkpoint-xxx', + task_type='seq_cls', + infer_backend='vllm', + num_labels=2, + verbose=False, + log_interval=-1)) as port: + run_client(port=port) diff --git a/examples/deploy/seq_cls/server.sh b/examples/deploy/seq_cls/server.sh new file mode 100644 index 0000000000..31dcfb68a4 --- /dev/null +++ b/examples/deploy/seq_cls/server.sh @@ -0,0 +1,9 @@ +# GME/GTE models or your checkpoints are also supported +# pt/vllm/sglang supported +CUDA_VISIBLE_DEVICES=0 swift deploy \ + --host 0.0.0.0 \ + --port 8000 \ + --model /your/seq_cls/checkpoint-xxx \ + --infer_backend vllm \ + --task_type seq_cls \ + --num_labels 2 \ diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index 1976e4713c..a7c00174f1 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -95,6 +95,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg result_path (Optional[str]): Directory to store inference results. Default is None. max_batch_size (int): Maximum batch size for the pt engine. Default is 1. val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None. + reranker_use_activation (bool): reranker use activation after calculating. Default is True. 
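+            When False, the raw relevance logits are returned without the sigmoid.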
""" infer_backend: Literal['vllm', 'pt', 'sglang', 'lmdeploy'] = 'pt' @@ -107,6 +108,9 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg # only for inference val_dataset_sample: Optional[int] = None + # for reranker + reranker_use_activation: bool = True + def _get_result_path(self, folder_name: str) -> str: result_dir = self.ckpt_dir or f'result/{self.model_suffix}' os.makedirs(result_dir, exist_ok=True) diff --git a/swift/llm/infer/deploy.py b/swift/llm/infer/deploy.py index 9bf1a25f00..bce6fad469 100644 --- a/swift/llm/infer/deploy.py +++ b/swift/llm/infer/deploy.py @@ -116,7 +116,7 @@ def _post_process(self, request_info, response, return_cmpl_response: bool = Fal (tuple, list)): continue for j, content in enumerate(response.choices[i].message.content): - if content['type'] == 'image': + if isinstance(content, dict) and content['type'] == 'image': b64_image = MultiModalRequestMixin.to_base64(content['image']) response.choices[i].message.content[j]['image'] = f'data:image/jpg;base64,{b64_image}' diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py index 906bd53635..b393ae05d6 100644 --- a/swift/llm/infer/infer.py +++ b/swift/llm/infer/infer.py @@ -32,6 +32,7 @@ def __init__(self, args: Optional[Union[List[str], InferArguments]] = None) -> N if args.infer_backend == 'pt': model, self.template = prepare_model_template(args) self.infer_engine = PtEngine.from_model_template(model, self.template, max_batch_size=args.max_batch_size) + self.infer_engine.reranker_use_activation = args.reranker_use_activation logger.info(f'model: {self.infer_engine.model}') else: self.template = args.get_template(None) @@ -54,6 +55,7 @@ def get_infer_engine(args: InferArguments, template=None, **kwargs): 'revision': args.model_revision, 'torch_dtype': args.torch_dtype, 'template': template, + 'reranker_use_activation': args.reranker_use_activation, }) infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend if infer_backend == 'pt': diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index 75630d388f..f173c4dbfc 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -2,6 +2,7 @@ import asyncio import hashlib import inspect +import os import pickle import time from copy import deepcopy @@ -11,6 +12,7 @@ import json import torch +import torch.nn.functional as F from PIL import Image from tqdm import tqdm from transformers import GenerationConfig, LogitsProcessorList @@ -76,6 +78,7 @@ def __init__( task_type=task_type, model_kwargs=model_kwargs, **kwargs) + self.reranker_use_activation = kwargs.pop('reranker_use_activation', True) self.max_batch_size = max_batch_size if isinstance(adapters, str): adapters = [adapters] @@ -327,6 +330,9 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req elif 'last_hidden_state' in output: # embeddings logits = output['last_hidden_state'] + else: + raise NotImplementedError('Only support `logits` or `hidden_state` in output.') + if template.task_type == 'seq_cls': preds, logprobs = template.decode_seq_cls(logits, top_logprobs) elif template.task_type == 'prm': @@ -335,6 +341,27 @@ def _infer_forward(self, template: Template, inputs: Dict[str, Any], adapter_req elif template.task_type == 'embedding': preds = logits logprobs = [None] * len(preds) + elif template.task_type in ('reranker', 'generative_reranker'): + if template.task_type == 'generative_reranker': + # Qwen3-reranker like + 
positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') + negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') + token_false_id = template.tokenizer.convert_tokens_to_ids(negative_token) + token_true_id = template.tokenizer.convert_tokens_to_ids(positive_token) + batch_scores = logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + preds = batch_scores[:, 1].exp() + else: + preds = logits + if self.reranker_use_activation: + preds = F.sigmoid(preds) + preds = preds.tolist() + if not isinstance(preds[0], list): + preds = [preds] + logprobs = [None] * len(preds) else: raise ValueError(f'Unsupported task_type: {template.task_type}') @@ -521,8 +548,9 @@ def _gen_wrapper(): return _gen_wrapper() else: if len(kwargs) > 0: - infer_func = self._infer_forward if template.task_type in {'seq_cls', 'prm', 'embedding' - } else self._infer_full + infer_func = self._infer_forward if template.task_type in { + 'seq_cls', 'prm', 'embedding', 'reranker', 'generative_reranker' + } else self._infer_full res = infer_func(**kwargs) else: res = [] diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index b2bd8d8b0a..33bca0f44d 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -30,6 +30,7 @@ os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400' import vllm from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLMEngine + from vllm.pooling_params import PoolingParams except Exception: raise @@ -80,6 +81,8 @@ def __init__( reasoning_parser: Optional[str] = None, engine_kwargs: Optional[Dict[str, Any]] = None, template: Optional[Template] = None, + num_labels: Optional[int] = None, + reranker_use_activation: bool = True, ) -> None: if engine_kwargs is None: engine_kwargs = {} @@ -92,6 +95,7 @@ def __init__( self.default_adapter_request = AdapterRequest('default', adapters[0]) patch_vllm_memory_leak() self.use_async_engine = use_async_engine + self.reranker_use_activation = reranker_use_activation self.processor = get_model_tokenizer( model_id_or_path, torch_dtype, @@ -101,6 +105,7 @@ def __init__( use_hf=use_hf, hub_token=hub_token, revision=revision, + num_labels=num_labels, task_type=task_type)[1] self._post_init(template) @@ -168,6 +173,10 @@ def _prepare_engine_kwargs( ) -> None: if task == 'embedding': task = 'embed' + elif task == 'seq_cls': + task = 'classify' + elif task in ('reranker', 'generative_reranker'): + task = 'score' disable_log_stats = engine_kwargs.pop('disable_log_stats', True) if self.use_async_engine: engine_cls = AsyncEngineArgs @@ -203,6 +212,8 @@ def _prepare_engine_kwargs( if self.model_meta.model_type in arch_mapping: architectures = arch_mapping[self.model_meta.model_type] engine_kwargs['hf_overrides'] = {'architectures': architectures} + self.default_template.set_mode('vllm') + engine_kwargs.update(self.default_template.prepare_engine_kwargs()) engine_args = engine_cls( model=self.model_dir, dtype=dtype_mapping[model_info.torch_dtype], @@ -333,12 +344,23 @@ def _add_request(self, mm_processor_kwargs = inputs.get('mm_processor_kwargs') if mm_processor_kwargs: llm_inputs['mm_processor_kwargs'] = mm_processor_kwargs - if self.task_type == 'embedding': - from vllm.pooling_params import PoolingParams - if 'task' in 
inspect.signature(PoolingParams).parameters: - pooling_params = PoolingParams(task='embed') - else: - pooling_params = PoolingParams() + + has_task_arg = 'task' in inspect.signature(PoolingParams).parameters + has_activation_arg = 'activation' in inspect.signature(PoolingParams).parameters + task_mapping = { + 'embedding': 'embed', + 'seq_cls': 'classify', + 'reranker': 'score', + 'generative_reranker': 'score', + } + if self.task_type in task_mapping: + pooling_kwargs = {} + if has_task_arg: + pooling_kwargs['task'] = task_mapping[self.task_type] + if self.task_type in ('reranker', 'generative_reranker') and \ + has_activation_arg and self.reranker_use_activation: + pooling_kwargs['activation'] = True + pooling_params = PoolingParams(**pooling_kwargs) return self.engine.encode(llm_inputs, pooling_params, request_id) elif self.use_async_engine: return self.engine.generate(llm_inputs, generation_config, request_id, **kwargs) @@ -549,6 +571,44 @@ def _create_chat_completion_response( prompt_token_ids=prompt_token_ids, images_size=images_size) + def _create_seq_cls_response( + self, + result, + template, + request_config, + request_id, + ) -> ChatCompletionResponse: + assert result is not None + choices = [] + preds = result.outputs.data + if preds.dim() == 1: + preds = preds.unsqueeze(0) + if self.task_type == 'seq_cls': + top_logprobs = request_config.top_logprobs or 20 + preds, logprobs = template.decode_seq_cls(preds, top_logprobs) + else: + logprobs = [None] * len(preds) + num_prompt_token_ids = 0 + num_generated_tokens = 0 + for i, pred in enumerate(preds): + num_prompt_token_ids += len(result.prompt_token_ids) + num_generated_tokens += 1 + if isinstance(pred, torch.Tensor): + pred = pred.tolist() + choices.append( + ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role='assistant', content=pred, tool_calls=None), + finish_reason='stop', + logprobs=logprobs[i])) + usage_info = self._get_usage_info(num_prompt_token_ids, num_generated_tokens) + return ChatCompletionResponse( + model=self.model_name, + choices=choices, + usage=usage_info, + id=request_id, + prompt_token_ids=result.prompt_token_ids) + async def _infer_full_async( self, template: Template, @@ -566,6 +626,8 @@ async def _infer_full_async( pass if self.task_type == 'embedding': return self._create_embedding_response(result, template, generation_config, request_id) + elif self.task_type in ('seq_cls', 'reranker', 'generative_reranker'): + return self._create_seq_cls_response(result, template, request_config, request_id) else: return self._create_chat_completion_response(result, inputs, template, request_config, request_id) diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index 5df0984e10..e14e9127f9 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -25,6 +25,8 @@ class LLMModelType: qwen2_gte = 'qwen2_gte' + bge_reranker = 'bge_reranker' + codefuse_qwen = 'codefuse_qwen' modelscope_agent = 'modelscope_agent' marco_o1 = 'marco_o1' diff --git a/swift/llm/model/model/baai.py b/swift/llm/model/model/baai.py index fdc7a0fe29..3edfe6afb6 100644 --- a/swift/llm/model/model/baai.py +++ b/swift/llm/model/model/baai.py @@ -3,11 +3,11 @@ import sys from typing import Any, Dict -from transformers import AutoModel +from transformers import AutoModel, AutoModelForSequenceClassification from swift.llm import TemplateType from swift.utils import get_device -from ..constant import MLLMModelType +from ..constant import LLMModelType, MLLMModelType from ..model_arch import ModelArch 
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model from ..utils import ModelInfo, git_clone_github, safe_snapshot_download @@ -94,3 +94,18 @@ def get_model_tokenizer_emu3_chat(model_dir: str, tags=['vision'], requires=['transformers>=4.44.0'], )) + +register_model( + ModelMeta( + LLMModelType.bge_reranker, + [ + ModelGroup([ + Model('BAAI/bge-reranker-base', 'BAAI/bge-reranker-base'), + Model('BAAI/bge-reranker-v2-m3', 'BAAI/bge-reranker-v2-m3'), + Model('BAAI/bge-reranker-large', 'BAAI/bge-reranker-large'), + ]), + ], + TemplateType.bge_reranker, + get_model_tokenizer_with_flash_attn, + architectures=['XLMRobertaForSequenceClassification'], + )) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 6313215305..43aefde09e 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -252,6 +252,9 @@ def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None: else: i += 1 + def prepare_engine_kwargs(self) -> Dict[str, Any]: + return {} + def _preprocess_inputs( self, inputs: StdTemplateInputs, @@ -413,31 +416,36 @@ def _embedding_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: return _encoded def _reranker_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: - chosen = inputs.chosen - instruction = chosen.system + if self.is_training: + chosen = inputs.chosen + instruction = chosen.system - _encoded = defaultdict(list) - labels = [] + _encoded = defaultdict(list) + labels = [] - for positive in inputs.positive: - if instruction is not None and positive.system is None: - positive.system = instruction - positive.messages = chosen.messages + positive.messages - positive_encoded = self._encode_truncated(positive) - labels.append(1) - for key in positive_encoded: - _encoded[key].append(positive_encoded[key]) - - for negative in inputs.negative: - if instruction is not None and negative.system is None: - negative.system = instruction - negative.messages = chosen.messages + negative.messages - negative_encoded = self._encode_truncated(negative) - labels.append(0) - for key in negative_encoded: - _encoded[key].append(negative_encoded[key]) - - _encoded['labels'] = labels + for positive in inputs.positive: + if instruction is not None and positive.system is None: + positive.system = instruction + positive.messages = chosen.messages + positive.messages + positive_encoded = self._encode_truncated(positive) + labels.append(1) + for key in positive_encoded: + _encoded[key].append(positive_encoded[key]) + + for negative in inputs.negative: + if instruction is not None and negative.system is None: + negative.system = instruction + negative.messages = chosen.messages + negative.messages + negative_encoded = self._encode_truncated(negative) + labels.append(0) + for key in negative_encoded: + _encoded[key].append(negative_encoded[key]) + + _encoded['labels'] = labels + else: + anchor = inputs.chosen + _encoded = self._encode_truncated(anchor) + _encoded.pop('labels', None) return _encoded def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: @@ -1190,6 +1198,8 @@ def _get_length(input_ids, labels): def _encode_truncated(self, inputs: StdTemplateInputs): self._preprocess_inputs(inputs) if self.mode in {'vllm', 'lmdeploy', 'sglang'}: + # For multi-modal models, images do not need to be pre processed here + # vllm/lmdeploy/sglang will handle the logic encoded = Template._encode(self, inputs) keys = ['images', 'audios', 'videos'] if self.mode == 'vllm': @@ -1534,26 +1544,32 @@ def 
_reranker_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: - max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1)) - max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) - labels_list = [] - new_batch = [] - for b in batch: - labels = b.pop('labels') - positive_num = sum(labels) - negative_num = len(labels) - positive_num - max_positive = min(positive_num, max_positive_samples) - max_negative = min(negative_num, max_negative_samples) - for i in random.sample(range(positive_num), max_positive): - new_batch.append({'input_ids': b['input_ids'][i]}) - labels_list.append(1) - for j in random.sample(range(negative_num), max_negative): - new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) - labels_list.append(0) - - res = self._data_collator(new_batch, padding_to=padding_to) - if labels_list: - res['labels'] = torch.tensor(labels_list, dtype=torch.long) + if self.is_training: + max_positive_samples = int(os.environ.get('MAX_POSITIVE_SAMPLES', 1)) + max_negative_samples = int(os.environ.get('MAX_NEGATIVE_SAMPLES', 7)) + labels_list = [] + new_batch = [] + for b in batch: + labels = b.pop('labels', None) + positive_num = sum(labels) + negative_num = len(labels) - positive_num + max_positive = min(positive_num, max_positive_samples) + max_negative = min(negative_num, max_negative_samples) + for i in random.sample(range(positive_num), max_positive): + new_batch.append({'input_ids': b['input_ids'][i]}) + labels_list.append(1) + for j in random.sample(range(negative_num), max_negative): + new_batch.append({'input_ids': b['input_ids'][j + positive_num]}) + labels_list.append(0) + + res = self._data_collator(new_batch, padding_to=padding_to) + if labels_list: + res['labels'] = torch.tensor(labels_list, dtype=torch.long) + else: + new_batch = [] + for b in batch: + new_batch.append({'input_ids': b['input_ids']}) + res = self._data_collator(new_batch, padding_to=padding_to) return res def _seq_cls_data_collator(self, diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py index e5d15d2e73..4aed17b3b7 100644 --- a/swift/llm/template/constant.py +++ b/swift/llm/template/constant.py @@ -41,6 +41,7 @@ class LLMTemplateType: ziya = 'ziya' atom = 'atom' mengzi = 'mengzi' + bge_reranker = 'bge_reranker' chatglm2 = 'chatglm2' glm4 = 'glm4' diff --git a/swift/llm/template/template/__init__.py b/swift/llm/template/template/__init__.py index 21856663c1..3c875c517e 100644 --- a/swift/llm/template/template/__init__.py +++ b/swift/llm/template/template/__init__.py @@ -1,3 +1,3 @@ -from . import (baidu, bert, deepseek, dots, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm, +from . 
import (baai, baidu, bert, deepseek, dots, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm, megrez, microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral, qwen, seed, stepfun, valley, yi) diff --git a/swift/llm/template/template/emu3.py b/swift/llm/template/template/baai.py similarity index 97% rename from swift/llm/template/template/emu3.py rename to swift/llm/template/template/baai.py index 872c9cfa3e..2db1fe241e 100644 --- a/swift/llm/template/template/emu3.py +++ b/swift/llm/template/template/baai.py @@ -8,7 +8,7 @@ from swift.utils import get_device from ..base import Template -from ..constant import MLLMTemplateType +from ..constant import LLMTemplateType, MLLMTemplateType from ..register import register_template from ..template_inputs import StdTemplateInputs from ..template_meta import TemplateMeta @@ -193,3 +193,12 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: suffix=[['eos_token_id']], default_system=DEFAULT_SYSTEM, template_cls=Emu3ChatTemplate)) + +register_template( + TemplateMeta( + LLMTemplateType.bge_reranker, + prefix=[' '], + chat_sep=[], + prompt=['{{QUERY}} '], + suffix=[''], + )) diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index 6d490691c0..81bd55550f 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -86,6 +86,18 @@ def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None: inputs.messages = [{'role': 'user', 'content': user_message}] return inputs + def prepare_engine_kwargs(self) -> Dict[str, Any]: + if self.mode == 'vllm': + return { + 'hf_overrides': { + 'architectures': ['Qwen3ForSequenceClassification'], + 'classifier_from_token': ['no', 'yes'], + 'is_original_qwen3_reranker': True, + } + } + else: + return super().prepare_engine_kwargs() + qwen3_reranker_system = ( 'Judge whether the Document meets the requirements based on the Query and the Instruct provided. ' diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 68a23b6cdd..547fe81431 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -247,10 +247,12 @@ def get_vllm_engine_kwargs(self): 'use_async_engine': self.vllm_use_async_engine, 'quantization': self.vllm_quantization, 'reasoning_parser': self.vllm_reasoning_parser, - 'disable_cascade_attn': self.vllm_disable_cascade_attn + 'disable_cascade_attn': self.vllm_disable_cascade_attn, + 'num_labels': self.num_labels, } - if self.task_type == 'embedding': - kwargs['task_type'] = 'embedding' + if self.task_type in ('embedding', 'seq_cls') or 'reranker' in self.task_type: + kwargs['task_type'] = self.task_type + return kwargs
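Not part of the patch — a minimal sketch of how a caller might consume the new reranker deployment. It assumes the OpenAI-compatible server started by examples/deploy/reranker/server.sh is reachable at 127.0.0.1:8000 and, as in examples/deploy/reranker/client.py, returns the relevance score(s) in choices[0].message.content; the rerank helper and the manual sigmoid are illustrative, covering the case where the server is launched with --reranker_use_activation false and therefore returns raw logits.

```python
import math

import requests


def rerank(query: str, document: str, base_url: str = 'http://127.0.0.1:8000/v1') -> float:
    # Discover the served model name, as client.py does via client.models.list().
    model = requests.get(f'{base_url}/models').json()['data'][0]['id']
    payload = {
        'model': model,
        'messages': [
            {'role': 'user', 'content': query},         # the query goes in the user turn
            {'role': 'assistant', 'content': document},  # the candidate document in the assistant turn
        ],
    }
    resp = requests.post(f'{base_url}/chat/completions', json=payload).json()
    content = resp['choices'][0]['message']['content']
    raw = float(content[0]) if isinstance(content, list) else float(content)
    # With --reranker_use_activation false the server returns the raw logit;
    # the sigmoid below reproduces the default (activation enabled) behaviour.
    return 1.0 / (1.0 + math.exp(-raw))


if __name__ == '__main__':
    print(rerank('what is the capital of China?', 'Beijing'))
```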