Merged
8 changes: 8 additions & 0 deletions swift/llm/infer/infer_engine/infer_engine.py
@@ -101,6 +101,14 @@ def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageInfo:
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+    @staticmethod
+    def _update_usage_info(origin_use_info: UsageInfo, num_generated_tokens: int) -> UsageInfo:
+        return UsageInfo(
+            prompt_tokens=origin_use_info.prompt_tokens,
+            completion_tokens=origin_use_info.completion_tokens + num_generated_tokens,
+            total_tokens=origin_use_info.total_tokens + num_generated_tokens,
+        )
+
     @staticmethod
     def _update_metrics(result, metrics: Optional[List[Metric]] = None):
         if metrics is None:
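
A minimal sketch of the accumulation behaviour the new helper provides: prompt tokens are counted once per request, while completion and total tokens grow with every extra returned sequence. The UsageInfo dataclass below is a stand-in for illustration only, not the actual class from swift.llm.

from dataclasses import dataclass

@dataclass
class UsageInfo:  # stand-in for swift.llm's UsageInfo, for illustration only
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

def update_usage_info(origin: UsageInfo, num_generated_tokens: int) -> UsageInfo:
    # mirrors InferEngine._update_usage_info above
    return UsageInfo(
        prompt_tokens=origin.prompt_tokens,
        completion_tokens=origin.completion_tokens + num_generated_tokens,
        total_tokens=origin.total_tokens + num_generated_tokens,
    )

usage = UsageInfo(prompt_tokens=12, completion_tokens=0, total_tokens=12)  # as from _get_usage_info(12, 0)
for generated_ids in ([101, 102, 103], [104, 105]):  # two sampled sequences for the same prompt (n=2)
    usage = update_usage_info(usage, len(generated_ids))
print(usage)  # UsageInfo(prompt_tokens=12, completion_tokens=5, total_tokens=17)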
50 changes: 29 additions & 21 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -104,6 +104,7 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> _GenerationConfig:
         if request_config.logprobs:
             generation_config.output_logits = True
             generation_config.top_logprobs = request_config.top_logprobs
+        generation_config.num_return_sequences = request_config.n
         return _GenerationConfig(**generation_config.to_dict())
 
     def _add_stop_words(self, generation_config: _GenerationConfig, request_config: RequestConfig,
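
For context, request_config.n is forwarded to transformers' num_return_sequences here. A minimal sketch of the resulting generation config with illustrative values; note that transformers only accepts num_return_sequences > 1 when sampling (or when num_beams >= num_return_sequences), so a greedy request with n > 1 is expected to be rejected by generate().

from transformers import GenerationConfig

n = 2  # hypothetical request_config.n
generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=True,  # required by transformers for num_return_sequences > 1 without beam search
    temperature=0.9,
    num_return_sequences=n,
)
print(generation_config.num_return_sequences)  # 2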
@@ -322,28 +323,35 @@ def _infer_full(self,
                 output.get('logits'), batched_generate_ids, generation_config.top_logprobs)
 
         res = []
-        for i in range(batched_generate_ids.shape[0]):
-            generate_ids = batched_generate_ids[i]
+        num_return_sequences = generation_config.num_return_sequences
+        for i in range(inputs['attention_mask'].shape[0]):
+            choices = []
+            usage_info = self._get_usage_info(num_prompt_tokens, 0)
+            for j in range(num_return_sequences):
+                batched_index = i * num_return_sequences + j
+                generate_ids = batched_generate_ids[batched_index]
 
-            # ignore pad_token
-            masks = generate_ids != self.tokenizer.pad_token_id
-            generate_ids = generate_ids[masks].tolist()
-            logprobs_list = None
-            if batched_logprobs is not None:
-                logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()]
-
-            logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
-            usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
-            response = template.decode(generate_ids, template_inputs=template_inputs[i])
-            finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
-            toolcall = self._get_toolcall(response, template.tools_prompt)
-            choices = [
-                ChatCompletionResponseChoice(
-                    index=0,
-                    message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
-                    finish_reason=finish_reason,
-                    logprobs=logprobs)
-            ]
+                # ignore pad_token
+                masks = generate_ids != self.tokenizer.pad_token_id
+                generate_ids = generate_ids[masks].tolist()
+                logprobs_list = None
+                if batched_logprobs is not None:
+                    logprobs_list = [
+                        logprobs for m, logprobs in zip(masks, batched_logprobs[batched_index]) if m.item()
+                    ]
+
+                logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids,
+                                              generation_config.top_logprobs)
+                usage_info = self._update_usage_info(usage_info, len(generate_ids))
+                response = template.decode(generate_ids, template_inputs=template_inputs[i])
+                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
+                toolcall = self._get_toolcall(response, template.tools_prompt)
+                choices.append(
+                    ChatCompletionResponseChoice(
+                        index=j,
+                        message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+                        finish_reason=finish_reason,
+                        logprobs=logprobs))
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
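
With num_return_sequences set, generate() returns batch_size * num_return_sequences rows ordered prompt-major, which is what the batched_index = i * num_return_sequences + j lookup above relies on. A small illustration of that index mapping with made-up shapes, no model involved:

import torch

batch_size, n, seq_len = 2, 3, 5
# Stand-in for batched_generate_ids: one row per (prompt, sample) pair,
# rows [0 .. n-1] belong to prompt 0, rows [n .. 2n-1] to prompt 1, and so on.
batched_generate_ids = torch.arange(batch_size * n * seq_len).reshape(batch_size * n, seq_len)

for i in range(batch_size):        # one ChatCompletionResponse per prompt
    for j in range(n):             # one choice per returned sequence
        batched_index = i * n + j
        row = batched_generate_ids[batched_index]
        print(f'prompt {i}, choice {j} -> row {batched_index}, first token {row[0].item()}')

On the caller side, this change should mean that a RequestConfig with n > 1 yields one ChatCompletionResponse per prompt whose choices list has n entries indexed 0 .. n-1, with usage.completion_tokens summed over all of them.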
