## Inference
We trained a checkpoint in the `self-cognition-sft.ipynb` tutorial; here we use `PtEngine` to run inference on it.

In [8]:
# import some libraries
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from swift.llm import InferEngine, InferRequest, PtEngine, RequestConfig, get_template
from swift.tuners import Swift

In [9]:
# ---- Inference settings ----
# Base model: either a model id or a local model path.
model_id_or_path = 'Qwen/Qwen2.5-3B-Instruct'
# Directory of the fine-tuned checkpoint to load on top of the base model.
# NOTE(review): 'xxx' looks like a placeholder for the real checkpoint step — confirm.
last_model_checkpoint = 'output/checkpoint-xxx'
infer_backend = 'pt'
# System prompt used as the template's default system message.
system = 'You are a helpful assistant.'

# ---- Generation settings ----
# When True, replies are printed incrementally via the streaming code path.
stream = True
temperature = 0
# Upper bound on the number of tokens generated per reply.
max_new_tokens = 512

In [None]:
# Get model and template, and load LoRA weights.
# The base model is loaded into a PtEngine first; the fine-tuned weights from
# `last_model_checkpoint` are then applied on top of it.
engine = PtEngine(model_id_or_path)
engine.model = Swift.from_pretrained(engine.model, last_model_checkpoint)
# Build a template that carries the custom system prompt.
template = get_template(engine.model.model_meta.template, engine.tokenizer, default_system=system)
# The default mode of the template is 'pt', so there is no need to make any changes.
# template.set_mode('pt')
# Attach the template to the engine. Without this, `template` (and therefore
# `system`) is created but never handed to the engine.
# NOTE(review): relies on the engine falling back to `default_template` when no
# template is passed to `infer` — confirm for the installed swift version.
engine.default_template = template

In [None]:
# Prompts sent to the model, one request per entry.
query_list = [
    'who are you?',
    "What should I do if I can't sleep at night?",
    '你是谁训练的？',  # Chinese: "Who trained you?"
]

def infer_stream(engine: InferEngine, infer_request: InferRequest):
    """Run one request in streaming mode and print the reply as it arrives.

    Args:
        engine: Inference engine; its ``infer`` method is called with a
            single-element request list and a streaming ``RequestConfig``.
        infer_request: Request to run. The content of its first message is
            echoed as the query.

    Returns:
        None. The query and the streamed reply are written to stdout.
    """
    request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)
    gen = engine.infer([infer_request], request_config)
    query = infer_request.messages[0]['content']
    print(f'query: {query}\nresponse: ', end='')
    for resp_list in gen:
        # Only one request was submitted, so only slot 0 is relevant.
        resp = resp_list[0]
        # A slot may be empty for a given step; skip it instead of failing
        # with an AttributeError.
        # NOTE(review): assumes the engine can yield None placeholders — confirm.
        if resp is None:
            continue
        content = resp.choices[0].delta.content
        # Skip chunks without text so the literal string 'None' is never printed.
        if content:
            print(content, end='', flush=True)
    print()

def infer(engine: InferEngine, infer_request: InferRequest):
    """Run one request without streaming and print the query and full reply.

    Args:
        engine: Inference engine; its ``infer`` method is called with a
            single-element request list.
        infer_request: Request to run. The content of its first message is
            echoed as the query.

    Returns:
        None. The query and the reply are written to stdout.
    """
    config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
    responses = engine.infer([infer_request], config)
    prompt = infer_request.messages[0]['content']
    # Only one request was submitted, so the reply is the first choice of slot 0.
    answer = responses[0].choices[0].message.content
    print(f'query: {prompt}', f'response: {answer}', sep='\n')

# Choose the streaming or non-streaming helper based on the `stream` flag.
if stream:
    infer_func = infer_stream
else:
    infer_func = infer

# Send each prompt as a single user turn and print a divider after each reply.
for query in query_list:
    user_message = {'role': 'user', 'content': query}
    infer_func(engine, InferRequest(messages=[user_message]))
    print('-' * 50)