From 62fd856487f0acdce0e60b02e1082cdbe34fdefe Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen
Date: Sat, 25 May 2024 03:27:21 +0000
Subject: [PATCH] Add recommended options to OpenAI

---
 .../openai_conversation/config_flow.py        | 109 +++++++++++-------
 .../components/openai_conversation/const.py   |  10 +-
 .../openai_conversation/conversation.py       |  59 +++++-----
 .../openai_conversation/strings.json          |   3 +-
 .../openai_conversation/test_config_flow.py   |  87 +++++++++++++-
 5 files changed, 191 insertions(+), 77 deletions(-)

diff --git a/homeassistant/components/openai_conversation/config_flow.py b/homeassistant/components/openai_conversation/config_flow.py
index 469d36e28d8666..553f6c1ffa26bf 100644
--- a/homeassistant/components/openai_conversation/config_flow.py
+++ b/homeassistant/components/openai_conversation/config_flow.py
@@ -31,14 +31,15 @@
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@
     }
 )
 
+RECOMMENDED_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+    CONF_PROMPT: DEFAULT_PROMPT,
+}
+
 
 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ async def async_step_user(
             return self.async_create_entry(
                 title="OpenAI Conversation",
                 data=user_input,
-                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+                options=RECOMMENDED_OPTIONS,
             )
 
         return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     def __init__(self, config_entry: ConfigEntry) -> None:
         """Initialize options flow."""
         self.config_entry = config_entry
+        self.last_rendered_recommended = config_entry.options.get(
+            CONF_RECOMMENDED, False
+        )
 
     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
     ) -> ConfigFlowResult:
         """Manage the options."""
+        options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+
         if user_input is not None:
-            if user_input[CONF_LLM_HASS_API] == "none":
-                user_input.pop(CONF_LLM_HASS_API)
-            return self.async_create_entry(title="", data=user_input)
-        schema = openai_config_option_schema(self.hass, self.config_entry.options)
+            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
+                if user_input[CONF_LLM_HASS_API] == "none":
+                    user_input.pop(CONF_LLM_HASS_API)
+                return self.async_create_entry(title="", data=user_input)
+
+            # Re-render the options again, now with the recommended options shown/hidden
+            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
+
+        schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
@@ -127,16 +150,16 @@
 
 def openai_config_option_schema(
     hass: HomeAssistant,
-    options: MappingProxyType[str, Any],
+    options: dict[str, Any] | MappingProxyType[str, Any],
 ) -> dict:
     """Return a schema for OpenAI completion options."""
-    apis: list[SelectOptionDict] = [
+    hass_apis: list[SelectOptionDict] = [
         SelectOptionDict(
             label="No control",
             value="none",
         )
     ]
-    apis.extend(
+    hass_apis.extend(
         SelectOptionDict(
             label=api.name,
             value=api.id,
@@ -144,38 +167,46 @@
         for api in llm.async_get_apis(hass)
     )
 
-    return {
+    schema = {
         vol.Optional(
             CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
         ): TemplateSelector(),
         vol.Optional(
             CONF_LLM_HASS_API,
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
-        ): SelectSelector(SelectSelectorConfig(options=apis)),
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={
-                # New key in HA 2023.4
-                "suggested_value": options.get(CONF_CHAT_MODEL)
-            },
-            default=DEFAULT_CHAT_MODEL,
-        ): str,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=DEFAULT_MAX_TOKENS,
-        ): int,
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=DEFAULT_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=DEFAULT_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
     }
+
+    if options.get(CONF_RECOMMENDED):
+        return schema
+
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): str,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        }
+    )
+    return schema
diff --git a/homeassistant/components/openai_conversation/const.py b/homeassistant/components/openai_conversation/const.py
index 27ef86bf918180..995d80e02f1603 100644
--- a/homeassistant/components/openai_conversation/const.py
+++ b/homeassistant/components/openai_conversation/const.py
@@ -4,13 +4,15 @@
 
 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
+
+CONF_RECOMMENDED = "recommended"
 CONF_PROMPT = "prompt"
 DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "gpt-4o"
+RECOMMENDED_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"
-DEFAULT_MAX_TOKENS = 150
+RECOMMENDED_MAX_TOKENS = 150
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 1.0
+RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 1.0
+RECOMMENDED_TEMPERATURE = 1.0
diff --git a/homeassistant/components/openai_conversation/conversation.py b/homeassistant/components/openai_conversation/conversation.py
index 7fe4ef6ac04a16..e5d88a9a05a67a 100644
--- a/homeassistant/components/openai_conversation/conversation.py
+++ b/homeassistant/components/openai_conversation/conversation.py
@@ -22,13 +22,13 @@
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
     LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )
 
 # Max number of back and forth with the LLM to generate a response
@@ -90,15 +90,14 @@
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        options = self.entry.options
         intent_response = intent.IntentResponse(language=user_input.language)
         llm_api: llm.API | None = None
         tools: list[dict[str, Any]] | None = None
 
-        if self.entry.options.get(CONF_LLM_HASS_API):
+        if options.get(CONF_LLM_HASS_API):
             try:
-                llm_api = llm.async_get_api(
-                    self.hass, self.entry.options[CONF_LLM_HASS_API]
-                )
+                llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
             except HomeAssistantError as err:
                 LOGGER.error("Error getting LLM API: %s", err)
                 intent_response.async_set_error(
@@ -110,26 +109,12 @@
                 )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
 
-        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
            messages = self.history[conversation_id]
        else:
             conversation_id = ulid.ulid_now()
             try:
-                prompt = template.Template(
-                    self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
-                ).async_render(
-                    {
-                        "ha_name": self.hass.config.location_name,
-                    },
-                    parse_result=False,
-                )
-
                 if llm_api:
                     empty_tool_input = llm.ToolInput(
                         tool_name="",
@@ -142,11 +127,23 @@
                         device_id=user_input.device_id,
                     )
 
-                    prompt = (
-                        await llm_api.async_get_api_prompt(empty_tool_input)
-                        + "\n"
-                        + prompt
+                    api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
+
+                else:
+                    api_prompt = llm.PROMPT_NO_API_CONFIGURED
+
+                prompt = (
+                    api_prompt
+                    + "\n"
+                    + template.Template(
+                        options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+                    ).async_render(
+                        {
+                            "ha_name": self.hass.config.location_name,
+                        },
+                        parse_result=False,
                     )
+                )
 
             except TemplateError as err:
                 LOGGER.error("Error rendering prompt: %s", err)
@@ -163,7 +160,7 @@
 
         messages.append({"role": "user", "content": user_input.text})
 
-        LOGGER.debug("Prompt for %s: %s", model, messages)
+        LOGGER.debug("Prompt: %s", messages)
 
         client = self.hass.data[DOMAIN][self.entry.entry_id]
 
@@ -171,12 +168,12 @@
         for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
                 result = await client.chat.completions.create(
-                    model=model,
+                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
                     messages=messages,
                     tools=tools,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    temperature=temperature,
+                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                     user=conversation_id,
                 )
             except openai.OpenAIError as err:
diff --git a/homeassistant/components/openai_conversation/strings.json b/homeassistant/components/openai_conversation/strings.json
index 01060afc7f189f..1e93c60b6a9094 100644
--- a/homeassistant/components/openai_conversation/strings.json
+++ b/homeassistant/components/openai_conversation/strings.json
@@ -22,7 +22,8 @@
           "max_tokens": "Maximum tokens to return in response",
           "temperature": "Temperature",
           "top_p": "Top P",
-          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
+          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
+          "recommended": "Recommended model settings"
         },
         "data_description": {
           "prompt": "Instruct how the LLM should respond. This can be a template."
diff --git a/tests/components/openai_conversation/test_config_flow.py b/tests/components/openai_conversation/test_config_flow.py
index 57f03d0c0bf04c..234e518b3c5a47 100644
--- a/tests/components/openai_conversation/test_config_flow.py
+++ b/tests/components/openai_conversation/test_config_flow.py
@@ -9,9 +9,17 @@
 from homeassistant import config_entries
 from homeassistant.components.openai_conversation.const import (
     CONF_CHAT_MODEL,
-    DEFAULT_CHAT_MODEL,
+    CONF_MAX_TOKENS,
+    CONF_PROMPT,
+    CONF_RECOMMENDED,
+    CONF_TEMPERATURE,
+    CONF_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TOP_P,
 )
+from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
 from homeassistant.data_entry_flow import FlowResultType
 
@@ -75,7 +83,7 @@ async def test_options(
     assert options["type"] is FlowResultType.CREATE_ENTRY
     assert options["data"]["prompt"] == "Speak like a pirate"
     assert options["data"]["max_tokens"] == 200
-    assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
+    assert options["data"][CONF_CHAT_MODEL] == RECOMMENDED_CHAT_MODEL
 
 
 @pytest.mark.parametrize(
@@ -115,3 +123,78 @@ async def test_form_invalid_auth(hass: HomeAssistant, side_effect, error) -> Non
 
     assert result2["type"] is FlowResultType.FORM
     assert result2["errors"] == {"base": error}
+
+
+@pytest.mark.parametrize(
+    ("current_options", "new_options", "expected_options"),
+    [
+        (
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "none",
+                CONF_PROMPT: "bla",
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+        ),
+        (
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_PROMPT: "",
+            },
+        ),
+    ],
+)
+async def test_options_switching(
+    hass: HomeAssistant,
+    mock_config_entry,
+    mock_init_component,
+    current_options,
+    new_options,
+    expected_options,
+) -> None:
+    """Test switching between recommended and custom options."""
+    hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
+    options_flow = await hass.config_entries.options.async_init(
+        mock_config_entry.entry_id
+    )
+    if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
+        options_flow = await hass.config_entries.options.async_configure(
+            options_flow["flow_id"],
+            {
+                **current_options,
+                CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
+            },
+        )
+    options = await hass.config_entries.options.async_configure(
+        options_flow["flow_id"],
+        new_options,
+    )
+    await hass.async_block_till_done()
+    assert options["type"] is FlowResultType.CREATE_ENTRY
+    assert options["data"] == expected_options
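
For reviewers who want to exercise the recommended/advanced gating outside Home Assistant, the standalone sketch below mirrors the branch at the heart of openai_config_option_schema using plain voluptuous. It is an illustration under simplifying assumptions rather than code from this patch: the Home Assistant selector classes are replaced with basic Python types, only two advanced keys are kept, and build_options_schema is a name invented for the sketch.

    from typing import Any

    import voluptuous as vol

    # Stand-ins for the RECOMMENDED_* constants from const.py; they double
    # as the defaults for the advanced fields, exactly as in the patch.
    RECOMMENDED_CHAT_MODEL = "gpt-4o"
    RECOMMENDED_MAX_TOKENS = 150


    def build_options_schema(options: dict[str, Any]) -> vol.Schema:
        """Return an options schema, hiding advanced fields in recommended mode."""
        # Fields that are always rendered: the prompt and the recommended toggle.
        schema: dict[Any, Any] = {
            vol.Optional("prompt", default=options.get("prompt", "")): str,
            vol.Required(
                "recommended", default=options.get("recommended", False)
            ): bool,
        }
        if options.get("recommended"):
            # Recommended mode: stop here, so the advanced keys never appear.
            return vol.Schema(schema)
        # Advanced model settings only appear once the user opts out of the
        # recommended defaults.
        schema.update(
            {
                vol.Optional("chat_model", default=RECOMMENDED_CHAT_MODEL): str,
                vol.Optional("max_tokens", default=RECOMMENDED_MAX_TOKENS): int,
            }
        )
        return vol.Schema(schema)


    # In recommended mode the validated output carries only the basic fields ...
    assert "chat_model" not in build_options_schema({"recommended": True})({})
    # ... while opting out brings the advanced fields back with their defaults.
    assert build_options_schema({"recommended": False})({})["chat_model"] == "gpt-4o"

The same coupling explains why async_step_init in the patch re-renders the form instead of saving when the recommended flag flips: the set of keys in the schema depends on the submitted value of the toggle, so the user confirms once more against the rebuilt schema before the entry is created.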