From 2f0215b0341b4dbe7c96f5c8581cc18ccbd583ed Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Wed, 22 May 2024 05:45:04 +0300 Subject: [PATCH] LLM Tools support for OpenAI integration (#117645) * initial commit * Add tests * Move tests to the correct file * Fix exception type * Undo change to default prompt * Add intent dependency * Move format_tool out of the class * Fix tests * coverage * Adjust to new API * Update strings * Update tests * Remove unrelated change * Test referencing non-existing API * Add test to verify no exception on tool conversion for Assist tools * Bump voluptuous-openapi==0.0.4 * Add device_id to tool input * Fix tests --------- Co-authored-by: Paulus Schoutsen --- .../manifest.json | 2 +- .../openai_conversation/config_flow.py | 73 ++-- .../components/openai_conversation/const.py | 4 - .../openai_conversation/conversation.py | 146 ++++++-- .../openai_conversation/manifest.json | 4 +- .../openai_conversation/strings.json | 5 +- requirements_all.txt | 3 +- requirements_test_all.txt | 3 +- .../openai_conversation/conftest.py | 11 + .../snapshots/test_conversation.ambr | 153 +++++++- .../openai_conversation/test_conversation.py | 336 +++++++++++++++++- 11 files changed, 663 insertions(+), 77 deletions(-) diff --git a/homeassistant/components/google_generative_ai_conversation/manifest.json b/homeassistant/components/google_generative_ai_conversation/manifest.json index 00ba74f16b2c21..ee9d78d6c2e49c 100644 --- a/homeassistant/components/google_generative_ai_conversation/manifest.json +++ b/homeassistant/components/google_generative_ai_conversation/manifest.json @@ -8,5 +8,5 @@ "documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"] + "requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.4"] } diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py index 2fde6f37690230..c9f6e2660558de 100644 --- a/homeassistant/components/openai_conversation/config_flow.py +++ b/homeassistant/components/openai_conversation/config_flow.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging -import types from types import MappingProxyType from typing import Any @@ -16,11 +15,15 @@ ConfigFlowResult, OptionsFlow, ) -from homeassistant.const import CONF_API_KEY +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.helpers.selector import ( NumberSelector, NumberSelectorConfig, + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, TemplateSelector, ) @@ -46,16 +49,6 @@ } ) -DEFAULT_OPTIONS = types.MappingProxyType( - { - CONF_PROMPT: DEFAULT_PROMPT, - CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL, - CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS, - CONF_TOP_P: DEFAULT_TOP_P, - CONF_TEMPERATURE: DEFAULT_TEMPERATURE, - } -) - async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None: """Validate the user input allows us to connect. @@ -92,7 +85,11 @@ async def async_step_user( _LOGGER.exception("Unexpected exception") errors["base"] = "unknown" else: - return self.async_create_entry(title="OpenAI Conversation", data=user_input) + return self.async_create_entry( + title="OpenAI Conversation", + data=user_input, + options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST}, + ) return self.async_show_form( step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors @@ -118,45 +115,67 @@ async def async_step_init( ) -> ConfigFlowResult: """Manage the options.""" if user_input is not None: - return self.async_create_entry(title="OpenAI Conversation", data=user_input) - schema = openai_config_option_schema(self.config_entry.options) + if user_input[CONF_LLM_HASS_API] == "none": + user_input.pop(CONF_LLM_HASS_API) + return self.async_create_entry(title="", data=user_input) + schema = openai_config_option_schema(self.hass, self.config_entry.options) return self.async_show_form( step_id="init", data_schema=vol.Schema(schema), ) -def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict: +def openai_config_option_schema( + hass: HomeAssistant, + options: MappingProxyType[str, Any], +) -> dict: """Return a schema for OpenAI completion options.""" - if not options: - options = DEFAULT_OPTIONS + apis: list[SelectOptionDict] = [ + SelectOptionDict( + label="No control", + value="none", + ) + ] + apis.extend( + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ) + return { - vol.Optional( - CONF_PROMPT, - description={"suggested_value": options[CONF_PROMPT]}, - default=DEFAULT_PROMPT, - ): TemplateSelector(), vol.Optional( CONF_CHAT_MODEL, description={ # New key in HA 2023.4 - "suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) + "suggested_value": options.get(CONF_CHAT_MODEL) }, default=DEFAULT_CHAT_MODEL, ): str, + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + default="none", + ): SelectSelector(SelectSelectorConfig(options=apis)), + vol.Optional( + CONF_PROMPT, + description={"suggested_value": options.get(CONF_PROMPT)}, + default=DEFAULT_PROMPT, + ): TemplateSelector(), vol.Optional( CONF_MAX_TOKENS, - description={"suggested_value": options[CONF_MAX_TOKENS]}, + description={"suggested_value": options.get(CONF_MAX_TOKENS)}, default=DEFAULT_MAX_TOKENS, ): int, vol.Optional( CONF_TOP_P, - description={"suggested_value": options[CONF_TOP_P]}, + description={"suggested_value": options.get(CONF_TOP_P)}, default=DEFAULT_TOP_P, ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), vol.Optional( CONF_TEMPERATURE, - description={"suggested_value": options[CONF_TEMPERATURE]}, + description={"suggested_value": options.get(CONF_TEMPERATURE)}, default=DEFAULT_TEMPERATURE, ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), } diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py index f992849f9b11dc..1e1fe27f547ab7 100644 --- a/homeassistant/components/openai_conversation/const.py +++ b/homeassistant/components/openai_conversation/const.py @@ -21,10 +21,6 @@ {%- endif %} {%- endfor %} {%- endfor %} - -Answer the user's questions about the world truthfully. - -If the user wants to control a device, reject the request and suggest using the Home Assistant app. """ CONF_CHAT_MODEL = "chat_model" DEFAULT_CHAT_MODEL = "gpt-3.5-turbo" diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py index 39549af3b883dc..b7219aad608d14 100644 --- a/homeassistant/components/openai_conversation/conversation.py +++ b/homeassistant/components/openai_conversation/conversation.py @@ -1,15 +1,18 @@ """Conversation support for OpenAI.""" -from typing import Literal +import json +from typing import Any, Literal import openai +import voluptuous as vol +from voluptuous_openapi import convert from homeassistant.components import assist_pipeline, conversation from homeassistant.config_entries import ConfigEntry -from homeassistant.const import MATCH_ALL +from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL from homeassistant.core import HomeAssistant -from homeassistant.exceptions import TemplateError -from homeassistant.helpers import intent, template +from homeassistant.exceptions import HomeAssistantError, TemplateError +from homeassistant.helpers import intent, llm, template from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util import ulid @@ -28,6 +31,9 @@ LOGGER, ) +# Max number of back and forth with the LLM to generate a response +MAX_TOOL_ITERATIONS = 10 + async def async_setup_entry( hass: HomeAssistant, @@ -39,6 +45,15 @@ async def async_setup_entry( async_add_entities([agent]) +def _format_tool(tool: llm.Tool) -> dict[str, Any]: + """Format tool specification.""" + tool_spec = {"name": tool.name} + if tool.description: + tool_spec["description"] = tool.description + tool_spec["parameters"] = convert(tool.parameters) + return {"type": "function", "function": tool_spec} + + class OpenAIConversationEntity( conversation.ConversationEntity, conversation.AbstractConversationAgent ): @@ -75,6 +90,26 @@ async def async_process( self, user_input: conversation.ConversationInput ) -> conversation.ConversationResult: """Process a sentence.""" + intent_response = intent.IntentResponse(language=user_input.language) + llm_api: llm.API | None = None + tools: list[dict[str, Any]] | None = None + + if self.entry.options.get(CONF_LLM_HASS_API): + try: + llm_api = llm.async_get_api( + self.hass, self.entry.options[CONF_LLM_HASS_API] + ) + except HomeAssistantError as err: + LOGGER.error("Error getting LLM API: %s", err) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Error preparing LLM API: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=user_input.conversation_id + ) + tools = [_format_tool(tool) for tool in llm_api.async_get_tools()] + raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT) model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL) max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS) @@ -87,7 +122,10 @@ async def async_process( else: conversation_id = ulid.ulid_now() try: - prompt = self._async_generate_prompt(raw_prompt) + prompt = self._async_generate_prompt( + raw_prompt, + llm_api, + ) except TemplateError as err: LOGGER.error("Error rendering prompt: %s", err) intent_response = intent.IntentResponse(language=user_input.language) @@ -106,38 +144,88 @@ async def async_process( client = self.hass.data[DOMAIN][self.entry.entry_id] - try: - result = await client.chat.completions.create( - model=model, - messages=messages, - max_tokens=max_tokens, - top_p=top_p, - temperature=temperature, - user=conversation_id, - ) - except openai.OpenAIError as err: - intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_error( - intent.IntentResponseErrorCode.UNKNOWN, - f"Sorry, I had a problem talking to OpenAI: {err}", - ) - return conversation.ConversationResult( - response=intent_response, conversation_id=conversation_id - ) - - LOGGER.debug("Response %s", result) - response = result.choices[0].message.model_dump(include={"role", "content"}) - messages.append(response) + # To prevent infinite loops, we limit the number of iterations + for _iteration in range(MAX_TOOL_ITERATIONS): + try: + result = await client.chat.completions.create( + model=model, + messages=messages, + tools=tools, + max_tokens=max_tokens, + top_p=top_p, + temperature=temperature, + user=conversation_id, + ) + except openai.OpenAIError as err: + intent_response = intent.IntentResponse(language=user_input.language) + intent_response.async_set_error( + intent.IntentResponseErrorCode.UNKNOWN, + f"Sorry, I had a problem talking to OpenAI: {err}", + ) + return conversation.ConversationResult( + response=intent_response, conversation_id=conversation_id + ) + + LOGGER.debug("Response %s", result) + response = result.choices[0].message + messages.append(response) + tool_calls = response.tool_calls + + if not tool_calls or not llm_api: + break + + for tool_call in tool_calls: + tool_input = llm.ToolInput( + tool_name=tool_call.function.name, + tool_args=json.loads(tool_call.function.arguments), + platform=DOMAIN, + context=user_input.context, + user_prompt=user_input.text, + language=user_input.language, + assistant=conversation.DOMAIN, + device_id=user_input.device_id, + ) + LOGGER.debug( + "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args + ) + + try: + tool_response = await llm_api.async_call_tool(tool_input) + except (HomeAssistantError, vol.Invalid) as e: + tool_response = {"error": type(e).__name__} + if str(e): + tool_response["error_text"] = str(e) + + LOGGER.debug("Tool response: %s", tool_response) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call.function.name, + "content": json.dumps(tool_response), + } + ) + self.history[conversation_id] = messages intent_response = intent.IntentResponse(language=user_input.language) - intent_response.async_set_speech(response["content"]) + intent_response.async_set_speech(response.content) return conversation.ConversationResult( response=intent_response, conversation_id=conversation_id ) - def _async_generate_prompt(self, raw_prompt: str) -> str: + def _async_generate_prompt( + self, + raw_prompt: str, + llm_api: llm.API | None, + ) -> str: """Generate a prompt for the user.""" + raw_prompt += "\n" + if llm_api: + raw_prompt += llm_api.prompt_template + else: + raw_prompt += llm.PROMPT_NO_API_CONFIGURED + return template.Template(raw_prompt, self.hass).async_render( { "ha_name": self.hass.config.location_name, diff --git a/homeassistant/components/openai_conversation/manifest.json b/homeassistant/components/openai_conversation/manifest.json index b71c84e2081895..480712574c4e91 100644 --- a/homeassistant/components/openai_conversation/manifest.json +++ b/homeassistant/components/openai_conversation/manifest.json @@ -1,12 +1,12 @@ { "domain": "openai_conversation", "name": "OpenAI Conversation", - "after_dependencies": ["assist_pipeline"], + "after_dependencies": ["assist_pipeline", "intent"], "codeowners": ["@balloob"], "config_flow": true, "dependencies": ["conversation"], "documentation": "https://www.home-assistant.io/integrations/openai_conversation", "integration_type": "service", "iot_class": "cloud_polling", - "requirements": ["openai==1.3.8"] + "requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"] } diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json index 1a7d5a03c6532d..6ab2ffb2855305 100644 --- a/homeassistant/components/openai_conversation/strings.json +++ b/homeassistant/components/openai_conversation/strings.json @@ -18,10 +18,11 @@ "init": { "data": { "prompt": "Prompt Template", - "model": "Completion Model", + "chat_model": "[%key:common::generic::model%]", "max_tokens": "Maximum tokens to return in response", "temperature": "Temperature", - "top_p": "Top P" + "top_p": "Top P", + "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]" } } } diff --git a/requirements_all.txt b/requirements_all.txt index 396b1c7875c22e..8074401a9552e5 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2826,7 +2826,8 @@ voip-utils==0.1.0 volkszaehler==0.4.0 # homeassistant.components.google_generative_ai_conversation -voluptuous-openapi==0.0.3 +# homeassistant.components.openai_conversation +voluptuous-openapi==0.0.4 # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 5431041bc0116f..24892d2093d9eb 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2191,7 +2191,8 @@ vilfo-api-client==0.5.0 voip-utils==0.1.0 # homeassistant.components.google_generative_ai_conversation -voluptuous-openapi==0.0.3 +# homeassistant.components.openai_conversation +voluptuous-openapi==0.0.4 # homeassistant.components.volvooncall volvooncall==0.10.3 diff --git a/tests/components/openai_conversation/conftest.py b/tests/components/openai_conversation/conftest.py index 272c23a951004f..6d770b51ce9be4 100644 --- a/tests/components/openai_conversation/conftest.py +++ b/tests/components/openai_conversation/conftest.py @@ -4,7 +4,9 @@ import pytest +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import HomeAssistant +from homeassistant.helpers import llm from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -24,6 +26,15 @@ def mock_config_entry(hass): return entry +@pytest.fixture +def mock_config_entry_with_assist(hass, mock_config_entry): + """Mock a config entry with assist.""" + hass.config_entries.async_update_entry( + mock_config_entry, options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST} + ) + return mock_config_entry + + @pytest.fixture async def mock_init_component(hass, mock_config_entry): """Initialize integration.""" diff --git a/tests/components/openai_conversation/snapshots/test_conversation.ambr b/tests/components/openai_conversation/snapshots/test_conversation.ambr index 1a488bb948c25a..3a89f943399eea 100644 --- a/tests/components/openai_conversation/snapshots/test_conversation.ambr +++ b/tests/components/openai_conversation/snapshots/test_conversation.ambr @@ -16,9 +16,35 @@ - Test Device 4 - 1 (3) - Answer the user's questions about the world truthfully. + If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), + ]) +# --- +# name: test_default_prompt[config_entry_options0-None] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. - If the user wants to control a device, reject the request and suggest using the Home Assistant app. + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Call the intent tools to control the system. Just pass the name to the intent. ''', 'role': 'system', }), @@ -26,13 +52,38 @@ 'content': 'hello', 'role': 'user', }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), + ]) +# --- +# name: test_default_prompt[config_entry_options0-conversation.openai] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Call the intent tools to control the system. Just pass the name to the intent. + ''', + 'role': 'system', + }), dict({ - 'content': 'Hello, how can I help you?', - 'role': 'assistant', + 'content': 'hello', + 'role': 'user', }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), ]) # --- -# name: test_default_prompt[conversation.openai] +# name: test_default_prompt[config_entry_options1-None] list([ dict({ 'content': ''' @@ -49,9 +100,35 @@ - Test Device 4 - 1 (3) - Answer the user's questions about the world truthfully. + Call the intent tools to control the system. Just pass the name to the intent. + ''', + 'role': 'system', + }), + dict({ + 'content': 'hello', + 'role': 'user', + }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), + ]) +# --- +# name: test_default_prompt[config_entry_options1-conversation.openai] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. - If the user wants to control a device, reject the request and suggest using the Home Assistant app. + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + Call the intent tools to control the system. Just pass the name to the intent. ''', 'role': 'system', }), @@ -59,9 +136,67 @@ 'content': 'hello', 'role': 'user', }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), + ]) +# --- +# name: test_default_prompt[conversation.openai] + list([ + dict({ + 'content': ''' + This smart home is controlled by Home Assistant. + + An overview of the areas and the devices in this smart home: + + Test Area: + - Test Device (Test Model) + + Test Area 2: + - Test Device 2 + - Test Device 3 (Test Model 3A) + - Test Device 4 + - 1 (3) + + If the user wants to control a device, tell them to edit the AI configuration and allow access to Home Assistant. + ''', + 'role': 'system', + }), dict({ - 'content': 'Hello, how can I help you?', - 'role': 'assistant', + 'content': 'hello', + 'role': 'user', }), + ChatCompletionMessage(content='Hello, how can I help you?', role='assistant', function_call=None, tool_calls=None), ]) # --- +# name: test_unknown_hass_api + dict({ + 'conversation_id': None, + 'response': IntentResponse( + card=dict({ + }), + error_code=, + failed_results=list([ + ]), + intent=None, + intent_targets=list([ + ]), + language='en', + matched_states=list([ + ]), + reprompt=dict({ + }), + response_type=, + speech=dict({ + 'plain': dict({ + 'extra_data': None, + 'speech': 'Error preparing LLM API: API non-existing not found', + }), + }), + speech_slots=dict({ + }), + success_results=list([ + ]), + unmatched_states=list([ + ]), + ), + }) +# --- diff --git a/tests/components/openai_conversation/test_conversation.py b/tests/components/openai_conversation/test_conversation.py index 9e50204cddec6c..431feb9d48247c 100644 --- a/tests/components/openai_conversation/test_conversation.py +++ b/tests/components/openai_conversation/test_conversation.py @@ -6,18 +6,34 @@ from openai import RateLimitError from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, +) from openai.types.completion_usage import CompletionUsage import pytest from syrupy.assertion import SnapshotAssertion +import voluptuous as vol from homeassistant.components import conversation +from homeassistant.const import CONF_LLM_HASS_API from homeassistant.core import Context, HomeAssistant -from homeassistant.helpers import area_registry as ar, device_registry as dr, intent +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import ( + area_registry as ar, + device_registry as dr, + intent, + llm, +) +from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @pytest.mark.parametrize("agent_id", [None, "conversation.openai"]) +@pytest.mark.parametrize( + "config_entry_options", [{}, {CONF_LLM_HASS_API: llm.LLM_API_ASSIST}] +) async def test_default_prompt( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -26,6 +42,7 @@ async def test_default_prompt( device_registry: dr.DeviceRegistry, snapshot: SnapshotAssertion, agent_id: str, + config_entry_options: dict, ) -> None: """Test that the default prompt works.""" entry = MockConfigEntry(title=None) @@ -36,6 +53,14 @@ async def test_default_prompt( if agent_id is None: agent_id = mock_config_entry.entry_id + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + **mock_config_entry.options, + CONF_LLM_HASS_API: llm.LLM_API_ASSIST, + }, + ) + device_registry.async_get_or_create( config_entry_id=entry.entry_id, connections={("test", "1234")}, @@ -194,3 +219,312 @@ async def test_conversation_agent( mock_config_entry.entry_id ) assert agent.supported_languages == "*" + + +@patch( + "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" +) +async def test_function_call( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call from the assistant.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.return_value = "Test response" + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + role = message["role"] if isinstance(message, dict) else message.role + if role == "tool": + return ChatCompletion( + id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="I have successfully called the function", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + + return ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + message=ChatCompletionMessage( + content=None, + role="assistant", + function_call=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_AbCdEfGhIjKlMnOpQrStUvWx", + function=Function( + arguments='{"param1":"test_value"}', + name="test_tool", + ), + type="function", + ) + ], + ), + ) + ], + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + side_effect=completion_result, + ) as mock_create: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert mock_create.mock_calls[1][2]["messages"][3] == { + "role": "tool", + "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", + "name": "test_tool", + "content": '"Test response"', + } + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "test_value"}, + platform="openai_conversation", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id=None, + ), + ) + + +@patch( + "homeassistant.components.openai_conversation.conversation.llm.AssistAPI.async_get_tools" +) +async def test_function_exception( + mock_get_tools, + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test function call with exception.""" + agent_id = mock_config_entry_with_assist.entry_id + context = Context() + + mock_tool = AsyncMock() + mock_tool.name = "test_tool" + mock_tool.description = "Test function" + mock_tool.parameters = vol.Schema( + {vol.Optional("param1", description="Test parameters"): str} + ) + mock_tool.async_call.side_effect = HomeAssistantError("Test tool exception") + + mock_get_tools.return_value = [mock_tool] + + def completion_result(*args, messages, **kwargs): + for message in messages: + role = message["role"] if isinstance(message, dict) else message.role + if role == "tool": + return ChatCompletion( + id="chatcmpl-1234567890ZYXWVUTSRQPONMLKJIH", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="There was an error calling the function", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + + return ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="tool_calls", + index=0, + message=ChatCompletionMessage( + content=None, + role="assistant", + function_call=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_AbCdEfGhIjKlMnOpQrStUvWx", + function=Function( + arguments='{"param1":"test_value"}', + name="test_tool", + ), + type="function", + ) + ], + ), + ) + ], + created=1700000000, + model="gpt-4-1106-preview", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ) + + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + side_effect=completion_result, + ) as mock_create: + result = await conversation.async_converse( + hass, + "Please call the test function", + None, + context, + agent_id=agent_id, + ) + + assert result.response.response_type == intent.IntentResponseType.ACTION_DONE + assert mock_create.mock_calls[1][2]["messages"][3] == { + "role": "tool", + "tool_call_id": "call_AbCdEfGhIjKlMnOpQrStUvWx", + "name": "test_tool", + "content": '{"error": "HomeAssistantError", "error_text": "Test tool exception"}', + } + mock_tool.async_call.assert_awaited_once_with( + hass, + llm.ToolInput( + tool_name="test_tool", + tool_args={"param1": "test_value"}, + platform="openai_conversation", + context=context, + user_prompt="Please call the test function", + language="en", + assistant="conversation", + device_id=None, + ), + ) + + +async def test_assist_api_tools_conversion( + hass: HomeAssistant, + mock_config_entry_with_assist: MockConfigEntry, + mock_init_component, +) -> None: + """Test that we are able to convert actual tools from Assist API.""" + for component in [ + "intent", + "todo", + "light", + "shopping_list", + "humidifier", + "climate", + "media_player", + "vacuum", + "cover", + "weather", + ]: + assert await async_setup_component(hass, component, {}) + + agent_id = mock_config_entry_with_assist.entry_id + with patch( + "openai.resources.chat.completions.AsyncCompletions.create", + new_callable=AsyncMock, + return_value=ChatCompletion( + id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS", + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage( + content="Hello, how can I help you?", + role="assistant", + function_call=None, + tool_calls=None, + ), + ) + ], + created=1700000000, + model="gpt-3.5-turbo-0613", + object="chat.completion", + system_fingerprint=None, + usage=CompletionUsage( + completion_tokens=9, prompt_tokens=8, total_tokens=17 + ), + ), + ) as mock_create: + await conversation.async_converse(hass, "hello", None, None, agent_id=agent_id) + + tools = mock_create.mock_calls[0][2]["tools"] + assert tools + + +async def test_unknown_hass_api( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + snapshot: SnapshotAssertion, + mock_init_component, +) -> None: + """Test when we reference an API that no longer exists.""" + hass.config_entries.async_update_entry( + mock_config_entry, + options={ + **mock_config_entry.options, + CONF_LLM_HASS_API: "non-existing", + }, + ) + + result = await conversation.async_converse( + hass, "hello", None, Context(), agent_id=mock_config_entry.entry_id + ) + + assert result == snapshot