src/agents/extensions/models/litellm_model.py (26 changes: 13 additions & 13 deletions)
@@ -78,7 +78,7 @@ async def get_response(
| {"base_url": str(self.base_url or ""), "model_impl": "litellm"},
disabled=tracing.is_disabled(),
) as span_generation:
- response = await self._fetch_response(
+ response, stream = await self._fetch_response(
system_instructions,
input,
model_settings,
@@ -87,26 +87,30 @@ async def get_response(
handoffs,
span_generation,
tracing,
- stream=False,
+ stream=True,
)

- assert isinstance(response.choices[0], litellm.types.utils.Choices)
+ async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+     if chunk.type == "response.completed":
+         response = chunk.response
+
+ message = Converter.output_items_to_message(response.output)

if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Received model response")
else:
logger.debug(
f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
f"LLM resp:\n{json.dumps(message.model_dump(), indent=2)}\n"
)

if hasattr(response, "usage"):
response_usage = response.usage
usage = (
Usage(
requests=1,
- input_tokens=response_usage.prompt_tokens,
- output_tokens=response_usage.completion_tokens,
- total_tokens=response_usage.total_tokens,
+ input_tokens=response.usage.input_tokens,
+ output_tokens=response.usage.output_tokens,
+ total_tokens=response.usage.total_tokens,
)
if response.usage
else Usage()
@@ -116,18 +120,14 @@ async def get_response(
logger.warning("No usage information returned from Litellm")

if tracing.include_data():
- span_generation.span_data.output = [response.choices[0].message.model_dump()]
+ span_generation.span_data.output = [message.model_dump()]
span_generation.span_data.usage = {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
}

- items = Converter.message_to_output_items(
-     LitellmConverter.convert_message_to_openai(response.choices[0].message)
- )
-
return ModelResponse(
- output=items,
+ output=response.output,
usage=usage,
response_id=None,
)
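Note on the change above: `LitellmModel.get_response` no longer makes a separate non-streaming request. It always calls `_fetch_response` with `stream=True`, folds the chunks back into a single `Response` via `ChatCmplStreamHandler.handle_stream`, and reads the output, usage, and the logged/traced message from that completed `Response`. A minimal sketch of that fold, assuming the handler lives at `agents.models.chatcmpl_stream_handler` (as the file layout in this diff suggests) and emits a `response.completed` event carrying the assembled `Response`:

```python
# A minimal sketch (not the PR's exact code) of the fold applied above:
# request a stream, then keep only the final assembled Response.
from openai.types.responses import Response

from agents.models.chatcmpl_stream_handler import ChatCmplStreamHandler


async def collect_final_response(response, stream) -> Response:
    """Drain the chunk stream and return the completed Response object."""
    final: Response | None = None
    async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
        # The handler emits Responses-API stream events; the terminal
        # "response.completed" event carries the fully assembled Response.
        if chunk.type == "response.completed":
            final = chunk.response
    if final is None:
        raise RuntimeError("stream ended without a response.completed event")
    return final
```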
src/agents/models/chatcmpl_converter.py (29 changes: 29 additions & 0 deletions)
@@ -20,6 +20,7 @@
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.chat.completion_create_params import ResponseFormat
from openai.types.responses import (
EasyInputMessageParam,
@@ -81,6 +82,34 @@ def convert_response_format(
},
}

+ @classmethod
+ def output_items_to_message(cls, items: list[TResponseOutputItem]) -> ChatCompletionMessage:
+     tool_calls: list[ChatCompletionMessageToolCall] | None = None
+     message = ChatCompletionMessage(role="assistant")
+
+     for item in items:
+         if isinstance(item, ResponseOutputMessage):
+             if isinstance(item.content, ResponseOutputText):
+                 message.content = item.content.text
+             elif isinstance(item.content, ResponseOutputRefusal):
+                 message.refusal = item.content.refusal
+         elif isinstance(item, ResponseFunctionToolCall):
+             if tool_calls is None:
+                 tool_calls = []
+             tool_calls.append(
+                 ChatCompletionMessageToolCall(
+                     id=item.call_id,
+                     type="function",
+                     function=Function(
+                         name=item.name,
+                         arguments=item.arguments,
+                     ),
+                 )
+             )
+
+     message.tool_calls = tool_calls
+     return message
+
@classmethod
def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]:
items: list[TResponseOutputItem] = []
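The new `Converter.output_items_to_message` classmethod added above is the rough inverse of the existing `message_to_output_items`: it walks Responses-API output items and rebuilds a `ChatCompletionMessage`, copying text or refusal from `ResponseOutputMessage` items (note the branch matches on `item.content` directly rather than iterating a list of content parts) and turning each `ResponseFunctionToolCall` into a `ChatCompletionMessageToolCall`. A hedged usage sketch for the tool-call path, with invented ids and arguments:

```python
# Illustrative only: the item values are made up, and the keyword fields follow
# the openai-python Responses types already imported by chatcmpl_converter.py.
from openai.types.responses import ResponseFunctionToolCall

from agents.models.chatcmpl_converter import Converter

items = [
    ResponseFunctionToolCall(
        id="fc_123",          # hypothetical ids and arguments
        call_id="call_123",
        type="function_call",
        name="get_weather",
        arguments='{"city": "Tokyo"}',
    )
]

message = Converter.output_items_to_message(items)

assert message.role == "assistant"
assert message.tool_calls is not None
assert message.tool_calls[0].function.name == "get_weather"
assert message.tool_calls[0].function.arguments == '{"city": "Tokyo"}'
```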
src/agents/models/openai_chatcompletions.py (24 changes: 14 additions & 10 deletions)
@@ -7,7 +7,7 @@

from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
from openai.types import ChatModel
- from openai.types.chat import ChatCompletion, ChatCompletionChunk
+ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.responses import Response

from .. import _debug
@@ -58,7 +58,7 @@ async def get_response(
model_config=model_settings.to_json_dict() | {"base_url": str(self._client.base_url)},
disabled=tracing.is_disabled(),
) as span_generation:
- response = await self._fetch_response(
+ response, stream = await self._fetch_response(
system_instructions,
input,
model_settings,
@@ -67,37 +67,41 @@ async def get_response(
handoffs,
span_generation,
tracing,
- stream=False,
+ stream=True,
)

+ async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
+     if chunk.type == "response.completed":
+         response = chunk.response
+
+ message = Converter.output_items_to_message(response.output)

if _debug.DONT_LOG_MODEL_DATA:
logger.debug("Received model response")
else:
logger.debug(
f"LLM resp:\n{json.dumps(response.choices[0].message.model_dump(), indent=2)}\n"
f"LLM resp:\n{json.dumps(message.model_dump(), indent=2)}\n"
)

usage = (
Usage(
requests=1,
- input_tokens=response.usage.prompt_tokens,
- output_tokens=response.usage.completion_tokens,
+ input_tokens=response.usage.input_tokens,
+ output_tokens=response.usage.output_tokens,
total_tokens=response.usage.total_tokens,
)
if response.usage
else Usage()
)
if tracing.include_data():
- span_generation.span_data.output = [response.choices[0].message.model_dump()]
+ span_generation.span_data.output = [message.model_dump()]
span_generation.span_data.usage = {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
}

- items = Converter.message_to_output_items(response.choices[0].message)
-
return ModelResponse(
- output=items,
+ output=response.output,
usage=usage,
response_id=None,
)
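`OpenAIChatCompletionsModel.get_response` gets the same treatment as the LiteLLM model: it always streams, rebuilds the final `Response`, builds `Usage` from the Responses-style usage block (`input_tokens`/`output_tokens` instead of the chat-completions `prompt_tokens`/`completion_tokens`), and returns `response.output` directly instead of converting the message back into output items. A small sketch of the usage mapping, assuming `Usage` is the dataclass from `agents.usage`:

```python
# Sketch of the usage mapping after this change; `final_response` stands for
# the Response captured from the "response.completed" event above.
from openai.types.responses import Response

from agents.usage import Usage


def usage_from_response(final_response: Response) -> Usage:
    if final_response.usage is None:
        # Mirrors the diff's fallback when the provider returns no usage block.
        return Usage()
    return Usage(
        requests=1,
        input_tokens=final_response.usage.input_tokens,
        output_tokens=final_response.usage.output_tokens,
        total_tokens=final_response.usage.total_tokens,
    )
```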