diff --git a/swift/llm/infer/infer_engine/infer_engine.py b/swift/llm/infer/infer_engine/infer_engine.py
index fe6057b383..d1c50c0bb7 100644
--- a/swift/llm/infer/infer_engine/infer_engine.py
+++ b/swift/llm/infer/infer_engine/infer_engine.py
@@ -101,6 +101,14 @@ def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageI
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+    @staticmethod
+    def _update_usage_info(origin_use_info: UsageInfo, num_generated_tokens: int) -> UsageInfo:
+        return UsageInfo(
+            prompt_tokens=origin_use_info.prompt_tokens,
+            completion_tokens=origin_use_info.completion_tokens + num_generated_tokens,
+            total_tokens=origin_use_info.total_tokens + num_generated_tokens,
+        )
+
     @staticmethod
     def _update_metrics(result, metrics: Optional[List[Metric]] = None):
         if metrics is None:
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index 45022dcb8a..7c65e204b0 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -104,6 +104,7 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> _Generati
         if request_config.logprobs:
             generation_config.output_logits = True
             generation_config.top_logprobs = request_config.top_logprobs
+        generation_config.num_return_sequences = request_config.n
         return _GenerationConfig(**generation_config.to_dict())
 
     def _add_stop_words(self, generation_config: _GenerationConfig, request_config: RequestConfig,
@@ -322,28 +323,35 @@ def _infer_full(self,
             output.get('logits'), batched_generate_ids, generation_config.top_logprobs)
 
         res = []
-        for i in range(batched_generate_ids.shape[0]):
-            generate_ids = batched_generate_ids[i]
+        num_return_sequences = generation_config.num_return_sequences
+        for i in range(inputs['attention_mask'].shape[0]):
+            choices = []
+            usage_info = self._get_usage_info(num_prompt_tokens, 0)
+            for j in range(num_return_sequences):
+                batched_index = i * num_return_sequences + j
+                generate_ids = batched_generate_ids[batched_index]
 
-            # ignore pad_token
-            masks = generate_ids != self.tokenizer.pad_token_id
-            generate_ids = generate_ids[masks].tolist()
-            logprobs_list = None
-            if batched_logprobs is not None:
-                logprobs_list = [logprobs for m, logprobs in zip(masks, batched_logprobs[i]) if m.item()]
-
-            logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
-            usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
-            response = template.decode(generate_ids, template_inputs=template_inputs[i])
-            finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
-            toolcall = self._get_toolcall(response, template.tools_prompt)
-            choices = [
-                ChatCompletionResponseChoice(
-                    index=0,
-                    message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
-                    finish_reason=finish_reason,
-                    logprobs=logprobs)
-            ]
+                # ignore pad_token
+                masks = generate_ids != self.tokenizer.pad_token_id
+                generate_ids = generate_ids[masks].tolist()
+                logprobs_list = None
+                if batched_logprobs is not None:
+                    logprobs_list = [
+                        logprobs for m, logprobs in zip(masks, batched_logprobs[batched_index]) if m.item()
+                    ]
+
+                logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids,
+                                              generation_config.top_logprobs)
+                usage_info = self._update_usage_info(usage_info, len(generate_ids))
+                response = template.decode(generate_ids, template_inputs=template_inputs[i])
+                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, num_prompt_tokens, True)
+                toolcall = self._get_toolcall(response, template.tools_prompt)
+                choices.append(
+                    ChatCompletionResponseChoice(
+                        index=j,
+                        message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+                        finish_reason=finish_reason,
+                        logprobs=logprobs))
             res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
         return res
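For context, a minimal sketch of how the new `RequestConfig.n` path would be exercised against `PtEngine` after this change. The model id, keyword values, and exact import style are assumptions for illustration, not part of this diff:

```python
# Minimal sketch (not part of this PR): requesting multiple sampled sequences
# per prompt via RequestConfig.n, which this diff maps onto
# generation_config.num_return_sequences in the PyTorch engine.
from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('Qwen/Qwen2.5-0.5B-Instruct')  # hypothetical model choice
request_config = RequestConfig(max_tokens=128, temperature=1.0, n=2)  # n > 1 -> multiple choices

infer_requests = [InferRequest(messages=[{'role': 'user', 'content': 'Tell me a joke.'}])]
responses = engine.infer(infer_requests, request_config)

# With this change, each ChatCompletionResponse should carry n choices
# (index 0..n-1), and usage.completion_tokens accumulates over all sequences.
for choice in responses[0].choices:
    print(choice.index, choice.message.content)
print(responses[0].usage)
```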