Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions src/agents/models/openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,30 +290,39 @@ async def _fetch_response(

stream_param: Literal[True] | Omit = True if stream else omit

ret = await self._get_client().chat.completions.create(
model=self.model,
messages=converted_messages,
tools=tools_param,
temperature=self._non_null_or_omit(model_settings.temperature),
top_p=self._non_null_or_omit(model_settings.top_p),
frequency_penalty=self._non_null_or_omit(model_settings.frequency_penalty),
presence_penalty=self._non_null_or_omit(model_settings.presence_penalty),
max_tokens=self._non_null_or_omit(model_settings.max_tokens),
tool_choice=tool_choice,
response_format=response_format,
parallel_tool_calls=parallel_tool_calls,
stream=cast(Any, stream_param),
stream_options=self._non_null_or_omit(stream_options),
store=self._non_null_or_omit(store),
reasoning_effort=self._non_null_or_omit(reasoning_effort),
verbosity=self._non_null_or_omit(model_settings.verbosity),
top_logprobs=self._non_null_or_omit(model_settings.top_logprobs),
extra_headers=self._merge_headers(model_settings),
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
metadata=self._non_null_or_omit(model_settings.metadata),
**(model_settings.extra_args or {}),
)
request_kwargs: dict[str, Any] = {
"model": self.model,
"messages": converted_messages,
"tools": tools_param,
"temperature": self._non_null_or_omit(model_settings.temperature),
"top_p": self._non_null_or_omit(model_settings.top_p),
"frequency_penalty": self._non_null_or_omit(model_settings.frequency_penalty),
"presence_penalty": self._non_null_or_omit(model_settings.presence_penalty),
"max_tokens": self._non_null_or_omit(model_settings.max_tokens),
"tool_choice": tool_choice,
"response_format": response_format,
"parallel_tool_calls": parallel_tool_calls,
"stream": cast(Any, stream_param),
"stream_options": self._non_null_or_omit(stream_options),
"store": self._non_null_or_omit(store),
"reasoning_effort": self._non_null_or_omit(reasoning_effort),
"verbosity": self._non_null_or_omit(model_settings.verbosity),
"top_logprobs": self._non_null_or_omit(model_settings.top_logprobs),
"extra_headers": self._merge_headers(model_settings),
"extra_query": model_settings.extra_query,
"extra_body": model_settings.extra_body,
"metadata": self._non_null_or_omit(model_settings.metadata),
}

request_kwargs.update(model_settings.extra_args or {})

sanitized_kwargs = {
key: value
for key, value in request_kwargs.items()
if not isinstance(value, Omit) and value is not omit
}

ret = await self._get_client().chat.completions.create(**sanitized_kwargs)

if isinstance(ret, ChatCompletion):
return ret
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_kwargs_functionality.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import litellm
import pytest
from litellm.types.utils import Choices, Message, ModelResponse, Usage
from openai import Omit, omit
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
Expand Down Expand Up @@ -124,6 +125,9 @@ def __init__(self):
# Verify regular parameters are still passed
assert captured["temperature"] == 0.7

assert all(not isinstance(value, Omit) for value in captured.values())
assert omit not in captured.values()


@pytest.mark.allow_call_model_methods
@pytest.mark.asyncio
Expand Down
29 changes: 19 additions & 10 deletions tests/test_openai_chatcompletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import httpx
import pytest
from openai import AsyncOpenAI, omit
from openai import AsyncOpenAI, Omit, omit
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
Expand Down Expand Up @@ -283,19 +283,22 @@ def __init__(self, completions: DummyCompletions) -> None:
stream=False,
)
assert result is chat

# Ensure expected args were passed through to OpenAI client.
kwargs = completions.kwargs
assert kwargs["stream"] is omit
assert kwargs["store"] is omit
assert kwargs["model"] == "gpt-4"
assert kwargs["messages"][0]["role"] == "system"
assert kwargs["messages"][0]["content"] == "sys"
assert kwargs["messages"][1]["role"] == "user"
# Defaults for optional fields become the omit sentinel
assert kwargs["tools"] is omit
assert kwargs["tool_choice"] is omit
assert kwargs["response_format"] is omit
assert kwargs["stream_options"] is omit
assert kwargs["messages"][1]["content"] == "hi"
assert "stream" not in kwargs
assert "store" not in kwargs
assert "tools" not in kwargs
assert "tool_choice" not in kwargs
assert "response_format" not in kwargs
assert "stream_options" not in kwargs
assert "parallel_tool_calls" not in kwargs
assert all(not isinstance(value, Omit) and value is not omit for value in kwargs.values())


@pytest.mark.asyncio
Expand Down Expand Up @@ -340,8 +343,14 @@ def __init__(self, completions: DummyCompletions) -> None:
)
# Check OpenAI client was called for streaming
assert completions.kwargs["stream"] is True
assert completions.kwargs["store"] is omit
assert completions.kwargs["stream_options"] is omit
assert completions.kwargs["model"] == "gpt-4"
assert "store" not in completions.kwargs
assert "stream_options" not in completions.kwargs
assert "tools" not in completions.kwargs
assert "parallel_tool_calls" not in completions.kwargs
assert all(
not isinstance(value, Omit) and value is not omit for value in completions.kwargs.values()
)
# Response is a proper openai Response
assert isinstance(response, Response)
assert response.id == FAKE_RESPONSES_ID
Expand Down