diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 2fc10ae43..8d9d4270e 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -32,6 +32,7 @@ ) from openai.types.chat.chat_completion_message_function_tool_call import Function from openai.types.responses import Response +from openai.types.responses.tool_choice_function import ToolChoiceFunction from ... import _debug from ...agent_output import AgentOutputSchemaBase @@ -77,6 +78,68 @@ def __init__( self.base_url = base_url self.api_key = api_key + @staticmethod + def _convert_tool_choice_for_response( + tool_choice: Any, + ) -> Literal["auto", "required", "none"] | ToolChoiceFunction: + """ + Convert various tool_choice formats to the format expected by Response. + + Args: + tool_choice: Can be: + - omit/NotGiven: Defaults to "auto" + - Literal["auto", "required", "none"]: Used directly + - ToolChoiceFunction: Used directly + - dict (from ChatCompletions Converter): Converted to ToolChoiceFunction + Format: {"type": "function", "function": {"name": "tool_name"}} + + Returns: + Literal["auto", "required", "none"] | ToolChoiceFunction + + Examples: + >>> LitellmModel._convert_tool_choice_for_response(omit) + "auto" + >>> LitellmModel._convert_tool_choice_for_response("required") + "required" + >>> LitellmModel._convert_tool_choice_for_response( + ... {"type": "function", "function": {"name": "my_tool"}} + ... ) + ToolChoiceFunction(type="function", name="my_tool") + """ + # Handle omit/NotGiven + if tool_choice is omit or isinstance(tool_choice, NotGiven): + return "auto" + + # Already a ToolChoiceFunction, use directly + if isinstance(tool_choice, ToolChoiceFunction): + return tool_choice + + # Convert from ChatCompletions format dict to ToolChoiceFunction + # ChatCompletions Converter returns: {"type": "function", "function": {"name": "..."}} + if isinstance(tool_choice, dict): + func_data = tool_choice.get("function") + if ( + tool_choice.get("type") == "function" + and func_data is not None + and isinstance(func_data, dict) + ): + tool_name = func_data.get("name") + if isinstance(tool_name, str) and tool_name: # Ensure non-empty string + return ToolChoiceFunction(type="function", name=tool_name) + else: + # Fallback to auto if name is missing or invalid + return "auto" + else: + # Fallback to auto if unexpected format + return "auto" + + # Handle literal strings + if tool_choice in ("auto", "required", "none"): + return cast(Literal["auto", "required", "none"], tool_choice) + + # Fallback to auto for any other case + return "auto" + async def get_response( self, system_instructions: str | None, @@ -367,15 +430,16 @@ async def _fetch_response( if isinstance(ret, litellm.types.utils.ModelResponse): return ret + # Convert tool_choice to the correct type for Response + response_tool_choice = self._convert_tool_choice_for_response(tool_choice) + response = Response( id=FAKE_RESPONSES_ID, created_at=time.time(), model=self.model, object="response", output=[], - tool_choice=cast(Literal["auto", "required", "none"], tool_choice) - if tool_choice is not omit - else "auto", + tool_choice=response_tool_choice, top_p=model_settings.top_p, temperature=model_settings.temperature, tools=[], diff --git a/tests/extensions/test_litellm_tool_choice_conversion.py b/tests/extensions/test_litellm_tool_choice_conversion.py new file mode 100644 index 000000000..ede12d33a --- /dev/null +++ b/tests/extensions/test_litellm_tool_choice_conversion.py @@ -0,0 +1,156 @@ +""" +Unit tests for LitellmModel._convert_tool_choice_for_response + +Tests the static method that converts various tool_choice formats +to the format expected by the Response type. + +Related to Issue #1846: Support tool_choice with specific tool names in LiteLLM streaming +""" + +from openai import NotGiven, omit +from openai.types.responses.tool_choice_function import ToolChoiceFunction + +from agents.extensions.models.litellm_model import LitellmModel + + +class TestConvertToolChoiceForResponse: + """Test the _convert_tool_choice_for_response static method.""" + + def test_convert_omit_returns_auto(self): + """Test that omit is converted to 'auto'""" + result = LitellmModel._convert_tool_choice_for_response(omit) + assert result == "auto" + + def test_convert_not_given_returns_auto(self): + """Test that NotGiven is converted to 'auto'""" + result = LitellmModel._convert_tool_choice_for_response(NotGiven()) + assert result == "auto" + + def test_convert_literal_auto(self): + """Test that literal 'auto' is preserved""" + result = LitellmModel._convert_tool_choice_for_response("auto") + assert result == "auto" + + def test_convert_literal_required(self): + """Test that literal 'required' is preserved""" + result = LitellmModel._convert_tool_choice_for_response("required") + assert result == "required" + + def test_convert_literal_none(self): + """Test that literal 'none' is preserved""" + result = LitellmModel._convert_tool_choice_for_response("none") + assert result == "none" + + def test_convert_tool_choice_function_preserved(self): + """Test that ToolChoiceFunction is preserved as-is""" + tool_choice = ToolChoiceFunction(type="function", name="my_tool") + result = LitellmModel._convert_tool_choice_for_response(tool_choice) + assert result == tool_choice + assert isinstance(result, ToolChoiceFunction) + assert result.name == "my_tool" + + def test_convert_dict_from_chatcompletions_converter(self): + """ + Test conversion from ChatCompletions Converter dict format. + Format: {"type": "function", "function": {"name": "tool_name"}} + """ + tool_choice_dict = { + "type": "function", + "function": {"name": "my_custom_tool"}, + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert isinstance(result, ToolChoiceFunction) + assert result.type == "function" + assert result.name == "my_custom_tool" + + def test_convert_dict_missing_function_name_returns_auto(self): + """Test that dict without function name falls back to 'auto'""" + tool_choice_dict = { + "type": "function", + "function": {}, # Missing 'name' + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert result == "auto" + + def test_convert_dict_empty_function_name_returns_auto(self): + """Test that dict with empty function name falls back to 'auto'""" + tool_choice_dict = { + "type": "function", + "function": {"name": ""}, # Empty name + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert result == "auto" + + def test_convert_dict_missing_function_key_returns_auto(self): + """Test that dict without 'function' key falls back to 'auto'""" + tool_choice_dict = {"type": "function"} # Missing 'function' key + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert result == "auto" + + def test_convert_dict_wrong_type_returns_auto(self): + """Test that dict with wrong type falls back to 'auto'""" + tool_choice_dict = { + "type": "wrong_type", + "function": {"name": "my_tool"}, + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert result == "auto" + + def test_convert_dict_function_not_dict_returns_auto(self): + """Test that dict with non-dict function value falls back to 'auto'""" + tool_choice_dict = { + "type": "function", + "function": "not_a_dict", + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert result == "auto" + + def test_convert_unexpected_type_returns_auto(self): + """Test that unexpected types fall back to 'auto'""" + result = LitellmModel._convert_tool_choice_for_response(123) + assert result == "auto" + + result = LitellmModel._convert_tool_choice_for_response([]) + assert result == "auto" + + result = LitellmModel._convert_tool_choice_for_response(None) + assert result == "auto" + + +class TestToolChoiceConversionEdgeCases: + """Test edge cases and real-world scenarios.""" + + def test_real_world_scenario_chatcompletions_format(self): + """ + Test a real-world scenario from ChatCompletions Converter. + This is the actual format returned when tool_choice specifies a tool name. + """ + # This is what ChatCompletions Converter returns + tool_choice_from_converter = { + "type": "function", + "function": {"name": "get_weather"}, + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_from_converter) + assert isinstance(result, ToolChoiceFunction) + assert result.name == "get_weather" + assert result.type == "function" + + def test_none_string_vs_none_literal(self): + """Test that string 'none' works but None (NoneType) defaults to auto""" + # String "none" should be preserved + result = LitellmModel._convert_tool_choice_for_response("none") + assert result == "none" + + # NoneType should fallback to auto + result = LitellmModel._convert_tool_choice_for_response(None) + assert result == "auto" + + def test_complex_tool_name(self): + """Test that complex tool names are handled correctly""" + tool_choice_dict = { + "type": "function", + "function": {"name": "get_user_profile_with_special_chars_123"}, + } + result = LitellmModel._convert_tool_choice_for_response(tool_choice_dict) + assert isinstance(result, ToolChoiceFunction) + assert result.name == "get_user_profile_with_special_chars_123"