diff --git a/homeassistant/components/google_generative_ai_conversation/config_flow.py b/homeassistant/components/google_generative_ai_conversation/config_flow.py
index 97b5fc25b2fd52..2f9040344b3ecf 100644
--- a/homeassistant/components/google_generative_ai_conversation/config_flow.py
+++ b/homeassistant/components/google_generative_ai_conversation/config_flow.py
@@ -34,16 +34,18 @@
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_RECOMMENDED,
     CONF_TEMPERATURE,
+    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_K,
+    RECOMMENDED_TOP_P,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -54,6 +56,12 @@
     }
 )

+RECOMMENDED_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+    CONF_TONE_PROMPT: "",
+}
+

 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -94,7 +102,7 @@ async def async_step_user(
             return self.async_create_entry(
                 title="Google Generative AI",
                 data=user_input,
-                options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+                options=RECOMMENDED_OPTIONS,
             )

         return self.async_show_form(
@@ -115,18 +123,37 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
     def __init__(self, config_entry: ConfigEntry) -> None:
         """Initialize options flow."""
         self.config_entry = config_entry
+        self.last_rendered_recommended = config_entry.options.get(
+            CONF_RECOMMENDED, False
+        )

     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
     ) -> ConfigFlowResult:
         """Manage the options."""
+        options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+
         if user_input is not None:
-            if user_input[CONF_LLM_HASS_API] == "none":
-                user_input.pop(CONF_LLM_HASS_API)
-            return self.async_create_entry(title="", data=user_input)
-        schema = await google_generative_ai_config_option_schema(
-            self.hass, self.config_entry.options
-        )
+            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
+                if user_input[CONF_LLM_HASS_API] == "none":
+                    user_input.pop(CONF_LLM_HASS_API)
+                return self.async_create_entry(title="", data=user_input)
+
+            # Re-render the options again, now with the recommended options shown/hidden
+            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+            # If we switch to not recommended, generate the prompt that was in use.
+            if user_input[CONF_RECOMMENDED]:
+                options = RECOMMENDED_OPTIONS
+            else:
+                options = {
+                    CONF_RECOMMENDED: False,
+                    CONF_PROMPT: DEFAULT_PROMPT
+                    + "\n"
+                    + user_input.get(CONF_TONE_PROMPT, ""),
+                }
+
+        schema = await google_generative_ai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
@@ -135,87 +162,94 @@ async def async_step_init(
 async def google_generative_ai_config_option_schema(
     hass: HomeAssistant,
-    options: MappingProxyType[str, Any],
+    options: dict[str, Any] | MappingProxyType[str, Any],
 ) -> dict:
     """Return a schema for Google Generative AI completion options."""
-    api_models = await hass.async_add_executor_job(partial(genai.list_models))
-
-    models: list[SelectOptionDict] = [
+    hass_apis: list[SelectOptionDict] = [
         SelectOptionDict(
-            label="Gemini 1.5 Flash (recommended)",
-            value="models/gemini-1.5-flash-latest",
-        ),
+            label="No control",
+            value="none",
+        )
     ]
-    models.extend(
+    hass_apis.extend(
+        SelectOptionDict(
+            label=api.name,
+            value=api.id,
+        )
+        for api in llm.async_get_apis(hass)
+    )
+
+    if options.get(CONF_RECOMMENDED):
+        return {
+            vol.Required(
+                CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+            ): bool,
+            vol.Optional(
+                CONF_TONE_PROMPT,
+                description={"suggested_value": options.get(CONF_TONE_PROMPT)},
+                default="",
+            ): TemplateSelector(),
+            vol.Optional(
+                CONF_LLM_HASS_API,
+                description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+                default="none",
+            ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        }
+
+    api_models = await hass.async_add_executor_job(partial(genai.list_models))
+
+    models = [
         SelectOptionDict(
             label=api_model.display_name,
             value=api_model.name,
         )
         for api_model in sorted(api_models, key=lambda x: x.display_name)
         if (
-            api_model.name
-            not in (
-                "models/gemini-1.0-pro",  # duplicate of gemini-pro
-                "models/gemini-1.5-flash-latest",
-            )
+            api_model.name != "models/gemini-1.0-pro"  # duplicate of gemini-pro
             and "vision" not in api_model.name
             and "generateContent" in api_model.supported_generation_methods
         )
-    )
-
-    apis: list[SelectOptionDict] = [
-        SelectOptionDict(
-            label="No control",
-            value="none",
-        )
     ]
-    apis.extend(
-        SelectOptionDict(
-            label=api.name,
-            value=api.id,
-        )
-        for api in llm.async_get_apis(hass)
-    )

     return {
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
         vol.Optional(
             CONF_CHAT_MODEL,
             description={"suggested_value": options.get(CONF_CHAT_MODEL)},
-            default=DEFAULT_CHAT_MODEL,
+            default=RECOMMENDED_CHAT_MODEL,
         ): SelectSelector(
-            SelectSelectorConfig(
-                mode=SelectSelectorMode.DROPDOWN,
-                options=models,
-            )
+            SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
         ),
-        vol.Optional(
-            CONF_LLM_HASS_API,
-            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-            default="none",
-        ): SelectSelector(SelectSelectorConfig(options=apis)),
         vol.Optional(
             CONF_PROMPT,
             description={"suggested_value": options.get(CONF_PROMPT)},
             default=DEFAULT_PROMPT,
         ): TemplateSelector(),
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
         vol.Optional(
             CONF_TEMPERATURE,
             description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=DEFAULT_TEMPERATURE,
+            default=RECOMMENDED_TEMPERATURE,
         ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
         vol.Optional(
             CONF_TOP_P,
             description={"suggested_value": options.get(CONF_TOP_P)},
-            default=DEFAULT_TOP_P,
+            default=RECOMMENDED_TOP_P,
         ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
         vol.Optional(
             CONF_TOP_K,
             description={"suggested_value": options.get(CONF_TOP_K)},
-            default=DEFAULT_TOP_K,
+            default=RECOMMENDED_TOP_K,
         ): int,
         vol.Optional(
             CONF_MAX_TOKENS,
             description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=DEFAULT_MAX_TOKENS,
+            default=RECOMMENDED_MAX_TOKENS,
         ): int,
     }
diff --git a/homeassistant/components/google_generative_ai_conversation/const.py b/homeassistant/components/google_generative_ai_conversation/const.py
index ba47b2acfe3a77..53a1e2a74a94a6 100644
--- a/homeassistant/components/google_generative_ai_conversation/const.py
+++ b/homeassistant/components/google_generative_ai_conversation/const.py
@@ -5,6 +5,7 @@
 DOMAIN = "google_generative_ai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
+CONF_TONE_PROMPT = "tone_prompt"
 DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.

 An overview of the areas and the devices in this smart home:
@@ -23,14 +24,14 @@
 {%- endfor %}
 """

+CONF_RECOMMENDED = "recommended"
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "models/gemini-pro"
+RECOMMENDED_CHAT_MODEL = "models/gemini-1.5-flash-latest"
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 0.9
+RECOMMENDED_TEMPERATURE = 1.0
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 1.0
+RECOMMENDED_TOP_P = 0.95
 CONF_TOP_K = "top_k"
-DEFAULT_TOP_K = 1
+RECOMMENDED_TOP_K = 64
 CONF_MAX_TOKENS = "max_tokens"
-DEFAULT_MAX_TOKENS = 150
-DEFAULT_ALLOW_HASS_ACCESS = False
+RECOMMENDED_MAX_TOKENS = 150
diff --git a/homeassistant/components/google_generative_ai_conversation/conversation.py b/homeassistant/components/google_generative_ai_conversation/conversation.py
index bc21a1a524ae24..b68ab39d53be08 100644
--- a/homeassistant/components/google_generative_ai_conversation/conversation.py
+++ b/homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -25,16 +25,17 @@
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
+    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
     DOMAIN,
     LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_K,
+    RECOMMENDED_TOP_P,
 )

 # Max number of back and forth with the LLM to generate a response
@@ -156,17 +157,16 @@ async def async_process(
             )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
         model = genai.GenerativeModel(
-            model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
+            model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
             generation_config={
                 "temperature": self.entry.options.get(
-                    CONF_TEMPERATURE, DEFAULT_TEMPERATURE
+                    CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
                 ),
-                "top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
-                "top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
+                "top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                "top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
                 "max_output_tokens": self.entry.options.get(
-                    CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
+                    CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
                 ),
             },
             tools=tools or None,
@@ -179,6 +179,10 @@ async def async_process(
             conversation_id = ulid.ulid_now()
             messages = [{}, {}]

+        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
+        if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
+            raw_prompt += "\n" + tone_prompt
+
         try:
             prompt = self._async_generate_prompt(raw_prompt, llm_api)
         except TemplateError as err:
@@ -221,7 +225,7 @@ async def async_process(
         if not chat_response.parts:
             intent_response.async_set_error(
                 intent.IntentResponseErrorCode.UNKNOWN,
-                "Sorry, I had a problem talking to Google Generative AI. Likely blocked",
+                "Sorry, I had a problem getting a response from Google Generative AI.",
             )
             return conversation.ConversationResult(
                 response=intent_response, conversation_id=conversation_id
diff --git a/homeassistant/components/google_generative_ai_conversation/strings.json b/homeassistant/components/google_generative_ai_conversation/strings.json
index a6be0c694c17a6..8a961c9e3d3cb7 100644
--- a/homeassistant/components/google_generative_ai_conversation/strings.json
+++ b/homeassistant/components/google_generative_ai_conversation/strings.json
@@ -18,13 +18,19 @@
     "step": {
       "init": {
         "data": {
-          "prompt": "Prompt Template",
+          "recommended": "Recommended settings",
+          "prompt": "Prompt",
+          "tone_prompt": "Tone",
           "chat_model": "[%key:common::generic::model%]",
           "temperature": "Temperature",
           "top_p": "Top P",
           "top_k": "Top K",
           "max_tokens": "Maximum tokens to return in response",
           "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
+        },
+        "data_description": {
+          "prompt": "Extra data to provide to the LLM. This can be a template.",
+          "tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
         }
       }
     }
diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
index f296c3a37c3294..24342bc0b1e00b 100644
--- a/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
+++ b/tests/components/google_generative_ai_conversation/snapshots/test_conversation.ambr
@@ -8,11 +8,11 @@
       dict({
         'generation_config': dict({
           'max_output_tokens': 150,
-          'temperature': 0.9,
-          'top_k': 1,
-          'top_p': 1.0,
+          'temperature': 1.0,
+          'top_k': 64,
+          'top_p': 0.95,
         }),
-        'model_name': 'models/gemini-pro',
+        'model_name': 'models/gemini-1.5-flash-latest',
         'tools': None,
       }),
     ),
@@ -67,11 +67,11 @@
       dict({
         'generation_config': dict({
           'max_output_tokens': 150,
-          'temperature': 0.9,
-          'top_k': 1,
-          'top_p': 1.0,
+          'temperature': 1.0,
+          'top_k': 64,
+          'top_p': 0.95,
         }),
-        'model_name': 'models/gemini-pro',
+        'model_name': 'models/gemini-1.5-flash-latest',
         'tools': None,
       }),
     ),
@@ -126,11 +126,11 @@
       dict({
         'generation_config': dict({
           'max_output_tokens': 150,
-          'temperature': 0.9,
-          'top_k': 1,
-          'top_p': 1.0,
+          'temperature': 1.0,
+          'top_k': 64,
+          'top_p': 0.95,
         }),
-        'model_name': 'models/gemini-pro',
+        'model_name': 'models/gemini-1.5-flash-latest',
         'tools': None,
       }),
     ),
@@ -185,11 +185,11 @@
       dict({
         'generation_config': dict({
           'max_output_tokens': 150,
-          'temperature': 0.9,
-          'top_k': 1,
-          'top_p': 1.0,
+          'temperature': 1.0,
+          'top_k': 64,
+          'top_p': 0.95,
         }),
-        'model_name': 'models/gemini-pro',
+        'model_name': 'models/gemini-1.5-flash-latest',
         'tools': None,
       }),
     ),
diff --git a/tests/components/google_generative_ai_conversation/test_config_flow.py b/tests/components/google_generative_ai_conversation/test_config_flow.py
index 57c9633a743aee..a4972d03496355 100644
--- a/tests/components/google_generative_ai_conversation/test_config_flow.py
+++ b/tests/components/google_generative_ai_conversation/test_config_flow.py
@@ -10,13 +10,17 @@
 from homeassistant.components.google_generative_ai_conversation.const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
+    CONF_PROMPT,
+    CONF_RECOMMENDED,
+    CONF_TEMPERATURE,
+    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TOP_K,
+    RECOMMENDED_TOP_P,
 )
 from homeassistant.const import CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant
@@ -42,7 +46,7 @@ def mock_models():
     model_10_pro.name = "models/gemini-pro"
     with patch(
         "homeassistant.components.google_generative_ai_conversation.config_flow.genai.list_models",
-        return_value=[model_10_pro],
+        return_value=[model_15_flash, model_10_pro],
     ):
         yield

@@ -84,36 +88,89 @@ async def test_form(hass: HomeAssistant) -> None:
         "api_key": "bla",
     }
     assert result2["options"] == {
+        CONF_RECOMMENDED: True,
         CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+        CONF_TONE_PROMPT: "",
     }
     assert len(mock_setup_entry.mock_calls) == 1


-async def test_options(
-    hass: HomeAssistant, mock_config_entry, mock_init_component, mock_models
+@pytest.mark.parametrize(
+    ("current_options", "new_options", "expected_options"),
+    [
+        (
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "none",
+                CONF_TONE_PROMPT: "bla",
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+            },
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_TOP_K: RECOMMENDED_TOP_K,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+        ),
+        (
+            {
+                CONF_RECOMMENDED: False,
+                CONF_PROMPT: "Speak like a pirate",
+                CONF_TEMPERATURE: 0.3,
+                CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL,
+                CONF_TOP_P: RECOMMENDED_TOP_P,
+                CONF_TOP_K: RECOMMENDED_TOP_K,
+                CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS,
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_TONE_PROMPT: "",
+            },
+            {
+                CONF_RECOMMENDED: True,
+                CONF_LLM_HASS_API: "assist",
+                CONF_TONE_PROMPT: "",
+            },
+        ),
+    ],
+)
+async def test_options_switching(
+    hass: HomeAssistant,
+    mock_config_entry,
+    mock_init_component,
+    mock_models,
+    current_options,
+    new_options,
+    expected_options,
 ) -> None:
     """Test the options form."""
+    hass.config_entries.async_update_entry(mock_config_entry, options=current_options)
     options_flow = await hass.config_entries.options.async_init(
         mock_config_entry.entry_id
     )
+    if current_options.get(CONF_RECOMMENDED) != new_options.get(CONF_RECOMMENDED):
+        options_flow = await hass.config_entries.options.async_configure(
+            options_flow["flow_id"],
+            {
+                **current_options,
+                CONF_RECOMMENDED: new_options[CONF_RECOMMENDED],
+            },
+        )
     options = await hass.config_entries.options.async_configure(
         options_flow["flow_id"],
-        {
-            "prompt": "Speak like a pirate",
-            "temperature": 0.3,
-        },
+        new_options,
     )
     await hass.async_block_till_done()
     assert options["type"] is FlowResultType.CREATE_ENTRY
-    assert options["data"]["prompt"] == "Speak like a pirate"
-    assert options["data"]["temperature"] == 0.3
-    assert options["data"][CONF_CHAT_MODEL] == DEFAULT_CHAT_MODEL
-    assert options["data"][CONF_TOP_P] == DEFAULT_TOP_P
-    assert options["data"][CONF_TOP_K] == DEFAULT_TOP_K
-    assert options["data"][CONF_MAX_TOKENS] == DEFAULT_MAX_TOKENS
-    assert (
-        CONF_LLM_HASS_API not in options["data"]
-    ), "Options flow should not set this key"
+    assert options["data"] == expected_options


 @pytest.mark.parametrize(
diff --git a/tests/components/google_generative_ai_conversation/test_conversation.py b/tests/components/google_generative_ai_conversation/test_conversation.py
index 76fe10a0d155b8..af7aebace35058 100644
--- a/tests/components/google_generative_ai_conversation/test_conversation.py
+++ b/tests/components/google_generative_ai_conversation/test_conversation.py
@@ -354,7 +354,7 @@ async def test_blocked_response(
     assert result.response.response_type == intent.IntentResponseType.ERROR, result
     assert result.response.error_code == "unknown", result
     assert result.response.as_dict()["speech"]["plain"]["speech"] == (
-        "Sorry, I had a problem talking to Google Generative AI. Likely blocked"
+        "Sorry, I had a problem getting a response from Google Generative AI."
     )
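
Reviewer note (not part of the patch): a minimal sketch of the two option shapes the updated options flow can emit, derived from RECOMMENDED_OPTIONS, the non-recommended branch in config_flow.py, and the expected_options in test_options_switching above. The dictionary variable names are illustrative only; the keys and literal values mirror const.py and llm.LLM_API_ASSIST.

    # Hypothetical illustration of the stored options, not code from this patch.
    recommended_options = {
        "recommended": True,        # CONF_RECOMMENDED
        "llm_hass_api": "assist",   # CONF_LLM_HASS_API -> llm.LLM_API_ASSIST
        "tone_prompt": "",          # CONF_TONE_PROMPT
    }

    custom_options = {
        "recommended": False,                            # CONF_RECOMMENDED
        "prompt": "Speak like a pirate",                 # CONF_PROMPT (free-form template)
        "temperature": 0.3,                              # CONF_TEMPERATURE
        "chat_model": "models/gemini-1.5-flash-latest",  # RECOMMENDED_CHAT_MODEL
        "top_p": 0.95,                                   # RECOMMENDED_TOP_P
        "top_k": 64,                                     # RECOMMENDED_TOP_K
        "max_tokens": 150,                               # RECOMMENDED_MAX_TOKENS
    }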