70 changes: 67 additions & 3 deletions src/agents/extensions/models/litellm_model.py
@@ -32,6 +32,7 @@
)
from openai.types.chat.chat_completion_message_function_tool_call import Function
from openai.types.responses import Response
from openai.types.responses.tool_choice_function import ToolChoiceFunction

from ... import _debug
from ...agent_output import AgentOutputSchemaBase
@@ -77,6 +78,68 @@ def __init__(
self.base_url = base_url
self.api_key = api_key

@staticmethod
def _convert_tool_choice_for_response(
tool_choice: Any,
) -> Literal["auto", "required", "none"] | ToolChoiceFunction:
"""
Convert various tool_choice formats to the format expected by Response.

Args:
tool_choice: Can be:
- omit/NotGiven: Defaults to "auto"
- Literal["auto", "required", "none"]: Used directly
- ToolChoiceFunction: Used directly
- dict (from ChatCompletions Converter): Converted to ToolChoiceFunction
Format: {"type": "function", "function": {"name": "tool_name"}}

Returns:
Literal["auto", "required", "none"] | ToolChoiceFunction

Examples:
            >>> LitellmModel._convert_tool_choice_for_response(omit)
            'auto'
            >>> LitellmModel._convert_tool_choice_for_response("required")
            'required'
            >>> LitellmModel._convert_tool_choice_for_response(
            ...     {"type": "function", "function": {"name": "my_tool"}}
            ... )
            ToolChoiceFunction(name='my_tool', type='function')
"""
# Handle omit/NotGiven
if tool_choice is omit or isinstance(tool_choice, NotGiven):
return "auto"

# Already a ToolChoiceFunction, use directly
if isinstance(tool_choice, ToolChoiceFunction):
return tool_choice

        # Convert from ChatCompletions format dict to ToolChoiceFunction
        # ChatCompletions Converter returns: {"type": "function", "function": {"name": "..."}}
        if isinstance(tool_choice, dict):
            func_data = tool_choice.get("function")
            if tool_choice.get("type") == "function" and isinstance(func_data, dict):
                tool_name = func_data.get("name")
                if isinstance(tool_name, str) and tool_name:  # Ensure non-empty string
                    return ToolChoiceFunction(type="function", name=tool_name)
            # Fallback to auto if the dict has an unexpected shape or the name is missing/invalid
            return "auto"

# Handle literal strings
if tool_choice in ("auto", "required", "none"):
return cast(Literal["auto", "required", "none"], tool_choice)

# Fallback to auto for any other case
return "auto"

async def get_response(
self,
system_instructions: str | None,
@@ -367,15 +430,16 @@ async def _fetch_response(
if isinstance(ret, litellm.types.utils.ModelResponse):
return ret

# Convert tool_choice to the correct type for Response
response_tool_choice = self._convert_tool_choice_for_response(tool_choice)

response = Response(
id=FAKE_RESPONSES_ID,
created_at=time.time(),
model=self.model,
object="response",
output=[],
tool_choice=cast(Literal["auto", "required", "none"], tool_choice)
if tool_choice is not omit
else "auto",
tool_choice=response_tool_choice,
Contributor:

I tested this, and it’s still not fixed: response_tool_choice always ends up being "auto", even when I pass ModelSettings(tool_choice="my_tool").
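
A repro along these lines (hypothetical tool/model names; standard Agent/Runner streaming wiring assumed):

```python
# Hypothetical repro sketch -- tool and model names are placeholders.
import asyncio

from agents import Agent, ModelSettings, Runner, function_tool
from agents.extensions.models.litellm_model import LitellmModel


@function_tool
def my_tool(city: str) -> str:
    """Placeholder tool so tool_choice has a concrete target."""
    return f"Sunny in {city}"


async def main() -> None:
    agent = Agent(
        name="repro",
        instructions="Always answer using the tool.",
        tools=[my_tool],
        model=LitellmModel(model="openai/gpt-4o-mini"),
        model_settings=ModelSettings(tool_choice="my_tool"),
    )
    result = Runner.run_streamed(agent, input="Weather in Tokyo?")
    async for _event in result.stream_events():
        pass
    # Observed: the Response built in _fetch_response still reports
    # tool_choice="auto" instead of ToolChoiceFunction(name="my_tool").


asyncio.run(main())
```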

Contributor Author:

Thanks for your test, I will test it again later!

Contributor Author:

✅ Fixed in commit fca3ed5 and verified with integration testing.

Root cause: The initial fix incorrectly assumed LiteLLM uses openai_responses.Converter (flat format), but it actually uses chatcmpl_converter.Converter, which returns the nested ChatCompletions format.

The fix: Now correctly handles the nested dict structure {"type": "function", "function": {"name": "my_tool"}} by accessing tool_choice.get("function").get("name") (lines 382-393).

Verification: Integration test confirms that when ModelSettings(tool_choice="my_specific_tool") is passed, litellm.acompletion receives the correct nested dict format, and Response.tool_choice is properly set to ToolChoiceFunction(name="my_specific_tool").

Test output:

litellm.acompletion called with tool_choice: {'type': 'function', 'function': {'name': 'my_specific_tool'}}

The fix is now working correctly!
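
For reference, a minimal sketch of the conversion path on this branch (the tool name is a placeholder; the nested dict mirrors what chatcmpl_converter.Converter emits):

```python
# Minimal sketch of the conversion on this branch; tool name is a placeholder.
from openai import omit
from openai.types.responses.tool_choice_function import ToolChoiceFunction

from agents.extensions.models.litellm_model import LitellmModel

# Nested ChatCompletions-style dict, as emitted by chatcmpl_converter.Converter
nested = {"type": "function", "function": {"name": "my_specific_tool"}}
converted = LitellmModel._convert_tool_choice_for_response(nested)
assert converted == ToolChoiceFunction(type="function", name="my_specific_tool")

# Literals pass through unchanged; omitted values default to "auto"
assert LitellmModel._convert_tool_choice_for_response("required") == "required"
assert LitellmModel._convert_tool_choice_for_response(omit) == "auto"
```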

top_p=model_settings.top_p,
temperature=model_settings.temperature,
tools=[],
156 changes: 156 additions & 0 deletions tests/extensions/test_litellm_tool_choice_conversion.py
@@ -0,0 +1,156 @@
"""
Unit tests for LitellmModel._convert_tool_choice_for_response

Tests the static method that converts various tool_choice formats
to the format expected by the Response type.

Related to Issue #1846: Support tool_choice with specific tool names in LiteLLM streaming
"""

from openai import NotGiven, omit
from openai.types.responses.tool_choice_function import ToolChoiceFunction

from agents.extensions.models.litellm_model import LitellmModel


class TestConvertToolChoiceForResponse:
"""Test the _convert_tool_choice_for_response static method."""

def test_convert_omit_returns_auto(self):
"""Test that omit is converted to 'auto'"""
result = LitellmModel._convert_tool_choice_for_response(omit)
assert result == "auto"

def test_convert_not_given_returns_auto(self):
"""Test that NotGiven is converted to 'auto'"""
result = LitellmModel._convert_tool_choice_for_response(NotGiven())
assert result == "auto"

def test_convert_literal_auto(self):
"""Test that literal 'auto' is preserved"""
result = LitellmModel._convert_tool_choice_for_response("auto")
assert result == "auto"

def test_convert_literal_required(self):
"""Test that literal 'required' is preserved"""
result = LitellmModel._convert_tool_choice_for_response("required")
assert result == "required"

def test_convert_literal_none(self):
"""Test that literal 'none' is preserved"""
result = LitellmModel._convert_tool_choice_for_response("none")
assert result == "none"

def test_convert_tool_choice_function_preserved(self):
"""Test that ToolChoiceFunction is preserved as-is"""
tool_choice = ToolChoiceFunction(type="function", name="my_tool")
result = LitellmModel._convert_tool_choice_for_response(tool_choice)
assert result == tool_choice
assert isinstance(result, ToolChoiceFunction)
assert result.name == "my_tool"

def test_convert_dict_from_chatcompletions_converter(self):
"""
Test conversion from ChatCompletions Converter dict format.
Format: {"type": "function", "function": {"name": "tool_name"}}
"""
tool_choice_dict = {
"type": "function",
"function": {"name": "my_custom_tool"},
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert isinstance(result, ToolChoiceFunction)
assert result.type == "function"
assert result.name == "my_custom_tool"

def test_convert_dict_missing_function_name_returns_auto(self):
"""Test that dict without function name falls back to 'auto'"""
tool_choice_dict = {
"type": "function",
"function": {}, # Missing 'name'
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert result == "auto"

def test_convert_dict_empty_function_name_returns_auto(self):
"""Test that dict with empty function name falls back to 'auto'"""
tool_choice_dict = {
"type": "function",
"function": {"name": ""}, # Empty name
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert result == "auto"

def test_convert_dict_missing_function_key_returns_auto(self):
"""Test that dict without 'function' key falls back to 'auto'"""
tool_choice_dict = {"type": "function"} # Missing 'function' key
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert result == "auto"

def test_convert_dict_wrong_type_returns_auto(self):
"""Test that dict with wrong type falls back to 'auto'"""
tool_choice_dict = {
"type": "wrong_type",
"function": {"name": "my_tool"},
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert result == "auto"

def test_convert_dict_function_not_dict_returns_auto(self):
"""Test that dict with non-dict function value falls back to 'auto'"""
tool_choice_dict = {
"type": "function",
"function": "not_a_dict",
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert result == "auto"

def test_convert_unexpected_type_returns_auto(self):
"""Test that unexpected types fall back to 'auto'"""
result = LitellmModel._convert_tool_choice_for_response(123)
assert result == "auto"

result = LitellmModel._convert_tool_choice_for_response([])
assert result == "auto"

result = LitellmModel._convert_tool_choice_for_response(None)
assert result == "auto"


class TestToolChoiceConversionEdgeCases:
"""Test edge cases and real-world scenarios."""

def test_real_world_scenario_chatcompletions_format(self):
"""
Test a real-world scenario from ChatCompletions Converter.
This is the actual format returned when tool_choice specifies a tool name.
"""
# This is what ChatCompletions Converter returns
tool_choice_from_converter = {
"type": "function",
"function": {"name": "get_weather"},
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_from_converter)
assert isinstance(result, ToolChoiceFunction)
assert result.name == "get_weather"
assert result.type == "function"

def test_none_string_vs_none_literal(self):
"""Test that string 'none' works but None (NoneType) defaults to auto"""
# String "none" should be preserved
result = LitellmModel._convert_tool_choice_for_response("none")
assert result == "none"

# NoneType should fallback to auto
result = LitellmModel._convert_tool_choice_for_response(None)
assert result == "auto"

def test_complex_tool_name(self):
"""Test that complex tool names are handled correctly"""
tool_choice_dict = {
"type": "function",
"function": {"name": "get_user_profile_with_special_chars_123"},
}
result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict)
assert isinstance(result, ToolChoiceFunction)
assert result.name == "get_user_profile_with_special_chars_123"