83 changes: 83 additions & 0 deletions examples/infer/demo_agent.py
@@ -0,0 +1,83 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def infer(engine: 'InferEngine', infer_request: 'InferRequest'):
    request_config = RequestConfig(max_tokens=512, temperature=0, stop=['Observation:'])
    resp_list = engine.infer([infer_request], request_config)
    query = infer_request.messages[0]['content']
    response = resp_list[0].choices[0].message.content
    tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'
    print(f'query: {query}')
    print(f'response: {response}{tool}', end='')

    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
    resp_list = engine.infer([infer_request], request_config)
    response2 = resp_list[0].choices[0].message.content
    print(response2)


def infer_stream(engine: 'InferEngine', infer_request: 'InferRequest'):
    request_config = RequestConfig(max_tokens=512, temperature=0, stop=['Observation:'], stream=True)
    gen = engine.infer([infer_request], request_config)
    query = infer_request.messages[0]['content']
    response = ''
    tool = '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'
    print(f'query: {query}')
    for resp_list in gen:
        delta = resp_list[0].choices[0].delta.content
        response += delta
        print(delta, end='', flush=True)
    print(tool, end='')

    infer_request.messages += [{'role': 'assistant', 'content': response}, {'role': 'tool', 'content': tool}]
    gen = engine.infer([infer_request], request_config)
    for resp_list in gen:
        print(resp_list[0].choices[0].delta.content, end='', flush=True)
    print()


def get_infer_request():
    return InferRequest(
        messages=[{
            'role': 'user',
            'content': "How's the weather today?"
        }],
        tools=[{
            'name': 'get_current_weather',
            'description': 'Get the current weather in a given location',
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The city and state, e.g. San Francisco, CA'
                    },
                    'unit': {
                        'type': 'string',
                        'enum': ['celsius', 'fahrenheit']
                    }
                },
                'required': ['location']
            }
        }])


if __name__ == '__main__':
    from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig
    model = 'Qwen/Qwen2.5-1.5B-Instruct'
    infer_backend = 'pt'

    if infer_backend == 'pt':
        engine = PtEngine(model, max_batch_size=64)
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine(model, max_model_len=32768)
    elif infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        engine = LmdeployEngine(model)

    infer(engine, get_infer_request())
    infer_stream(engine, get_infer_request())
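
For reference, a sketch of what infer_request.messages holds just before the second engine.infer call in the demo above; the assistant content is whatever the model generated up to the 'Observation:' stop word, so the text shown here is only illustrative:

messages = [
    {'role': 'user', 'content': "How's the weather today?"},
    # model output truncated at the 'Observation:' stop word, e.g. an Action / Action Input block
    {'role': 'assistant', 'content': '...'},
    {'role': 'tool', 'content': '{"temperature": 72, "condition": "Sunny", "humidity": 50}\n'},
]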
31 changes: 20 additions & 11 deletions examples/infer/demo_multilora.py → examples/infer/demo_lora.py
@@ -33,27 +33,36 @@ def infer_multilora(infer_request: 'InferRequest', infer_backend: Literal['vllm'
    print(f'lora2-response: {response}')


def infer_pt(infer_request: 'InferRequest'):
def infer_lora(infer_request: 'InferRequest'):
    request_config = RequestConfig(max_tokens=512, temperature=0)
    adapter_path = safe_snapshot_download('swift/test_lora')
    args = BaseArguments.from_pretrained(adapter_path)
    engine = PtEngine(args.model, adapters=[adapter_path])
    template = get_template(args.template, engine.tokenizer, args.system)
    request_config = RequestConfig(max_tokens=512, temperature=0)
    # method1
    # engine = PtEngine(args.model, adapters=[adapter_path])
    # template = get_template(args.template, engine.tokenizer, args.system)
    # engine.default_template = template

    # use lora
    resp_list = engine.infer([infer_request], request_config, template=template)
    response = resp_list[0].choices[0].message.content
    print(f'lora-response: {response}')
    # method2
    engine.default_template = template
    # model, processor = args.get_model_processor()
    # model = Swift.from_pretrained(model, adapter_path)
    # template = args.get_template(processor)
    # engine = PtEngine.from_model_template(model, template)

    # method3
    model, tokenizer = get_model_tokenizer(args.model)
    model = Swift.from_pretrained(model, adapter_path)
    template = get_template(args.template, tokenizer, args.system)
    engine = PtEngine.from_model_template(model, template)

    resp_list = engine.infer([infer_request], request_config)
    response = resp_list[0].choices[0].message.content
    print(f'lora-response: {response}')


if __name__ == '__main__':
    from swift.llm import (PtEngine, RequestConfig, AdapterRequest, get_template, BaseArguments, InferRequest,
                           safe_snapshot_download)
                           safe_snapshot_download, get_model_tokenizer)
    from swift.tuners import Swift
    infer_request = InferRequest(messages=[{'role': 'user', 'content': '你是谁'}])
    # infer_lora(infer_request)
    infer_multilora(infer_request, 'pt')
    # infer_pt(infer_request)
10 changes: 0 additions & 10 deletions swift/llm/argument/base_args/model_args.py
@@ -125,19 +125,9 @@ def _init_model_info(self) -> torch.dtype:
        self._init_rope_scaling()
        return self.model_info.torch_dtype

    def _init_task_type(self):
        if self.task_type is None:
            if self.num_labels is None:
                self.task_type = 'causal_lm'
            else:
                self.task_type = 'seq_cls'
        if self.task_type == 'seq_cls':
            assert self.num_labels is not None, 'Please set --num_labels <num_labels>.'

    def __post_init__(self):
        if self.model is None:
            raise ValueError(f'Please set --model <model_id_or_path>`, model: {self.model}')
        self._init_task_type()
        self.model_suffix = get_model_name(self.model)
        self._init_device_map()
        self._init_torch_dtype()
14 changes: 10 additions & 4 deletions swift/llm/model/register.py
@@ -382,9 +382,15 @@ def get_model_info_meta(
        torch_dtype = model_meta.torch_dtype or get_default_torch_dtype(model_info.torch_dtype)
        logger.info(f'Setting torch_dtype: {torch_dtype}')
    model_info.torch_dtype = torch_dtype
    if model_meta.is_reward:
        task_type = 'seq_cls'
        num_labels = 1
    if task_type is None:
        if model_meta.is_reward:
            num_labels = 1
        if num_labels is None:
            task_type = 'causal_lm'
        else:
            task_type = 'seq_cls'
    if task_type == 'seq_cls':
        assert num_labels is not None, 'Please pass the parameter `num_labels`.'
    model_info.task_type = task_type
    model_info.num_labels = num_labels

@@ -409,7 +415,7 @@ def get_model_tokenizer(
        attn_impl: Literal['flash_attn', 'sdpa', 'eager', None] = None,
        rope_scaling: Optional[Dict[str, Any]] = None,
        automodel_class=None,
        task_type: Literal['causal_lm', 'seq_cls'] = 'causal_lm',
        task_type: Literal['causal_lm', 'seq_cls'] = None,
        num_labels: Optional[int] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs) -> Tuple[Optional[PreTrainedModel], PreTrainedTokenizerBase]:
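
A minimal standalone sketch (not the library code) of how the task-type inference added to get_model_info_meta behaves; is_reward stands in for model_meta.is_reward:

def resolve_task(task_type, num_labels, is_reward):
    # mirrors the branch added above
    if task_type is None:
        if is_reward:
            num_labels = 1
        task_type = 'causal_lm' if num_labels is None else 'seq_cls'
    if task_type == 'seq_cls':
        assert num_labels is not None, 'Please pass the parameter `num_labels`.'
    return task_type, num_labels

resolve_task(None, None, is_reward=True)   # ('seq_cls', 1)
resolve_task(None, None, is_reward=False)  # ('causal_lm', None)
resolve_task(None, 3, is_reward=False)     # ('seq_cls', 3)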
22 changes: 19 additions & 3 deletions swift/llm/template/base.py
@@ -164,9 +164,26 @@ def _preprocess_inputs(
                'The template does not support multi-round chat. Only use the last round of the conversation.')
            inputs.messages = inputs.messages[-2:]

        if self.model_meta.is_multimodal:
            self._replace_image_tags(inputs)
        if inputs.is_multimodal:
            self._add_default_tags(inputs)

    @staticmethod
    def _replace_image_tags(inputs: StdTemplateInputs):
        # compat
        images = []
        pattern = r'<img>(.+?)</img>'
        for message in inputs.messages:
            content = message['content']
            if not isinstance(content, str):
                continue
            images += re.findall(pattern, content)
            message['content'] = re.sub(pattern, '<image>', content)
        if images:
            assert not inputs.images, f'images: {images}, inputs.images: {inputs.images}'
            inputs.images = images

    def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
        assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
@@ -206,13 +223,12 @@ def encode(self,
            inputs = asdict(inputs)

        if isinstance(inputs, dict):
            if not self.is_training:
                InferRequest.remove_response(inputs['messages'])
            inputs = StdTemplateInputs.from_dict(inputs, tools_prompt=self.tools_prompt)
        elif isinstance(inputs, StdTemplateInputs):
            inputs = deepcopy(inputs)

        if not self.is_training:
            InferRequest.remove_response(inputs.messages)

        assert isinstance(inputs, StdTemplateInputs)
        self._preprocess_inputs(inputs)
        if self.mode in {'vllm', 'lmdeploy'}:
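
A quick illustration of the <img>...</img> compatibility handling added in _replace_image_tags (the file name here is made up): inline <img>path</img> tags are rewritten to the generic <image> placeholder and the extracted paths end up in inputs.images:

import re

pattern = r'<img>(.+?)</img>'
content = '<img>cat.png</img>What is in the image?'
print(re.findall(pattern, content))         # ['cat.png']
print(re.sub(pattern, '<image>', content))  # '<image>What is in the image?'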