Google gen updates (#117893)
* Add a recommended model for Google Gen AI

* Add recommended settings to Google Gen AI

* Revert no API msg

* Use correct default settings

* Make sure options are cleared when using recommended

* Update snapshots

* address comments
balloob committed May 23, 2024
1 parent c0bcf00 commit d1af40f
Showing 7 changed files with 211 additions and 109 deletions.
homeassistant/components/google_generative_ai_conversation/config_flow.py
@@ -34,16 +34,18 @@
CONF_CHAT_MODEL,
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_RECOMMENDED,
CONF_TEMPERATURE,
CONF_TONE_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
)

_LOGGER = logging.getLogger(__name__)
@@ -54,6 +56,12 @@
}
)

RECOMMENDED_OPTIONS = {
CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
CONF_TONE_PROMPT: "",
}


async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@@ -94,7 +102,7 @@ async def async_step_user(
return self.async_create_entry(
title="Google Generative AI",
data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
options=RECOMMENDED_OPTIONS,
)

return self.async_show_form(
@@ -115,18 +123,37 @@ class GoogleGenerativeAIOptionsFlow(OptionsFlow):
def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
self.last_rendered_recommended = config_entry.options.get(
CONF_RECOMMENDED, False
)

async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage the options."""
options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options

if user_input is not None:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
schema = await google_generative_ai_config_option_schema(
self.hass, self.config_entry.options
)
if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)

# Re-render the options again, now with the recommended options shown/hidden
self.last_rendered_recommended = user_input[CONF_RECOMMENDED]

# If we switch to not recommended, generate used prompt.
if user_input[CONF_RECOMMENDED]:
options = RECOMMENDED_OPTIONS
else:
options = {
CONF_RECOMMENDED: False,
CONF_PROMPT: DEFAULT_PROMPT
+ "\n"
+ user_input.get(CONF_TONE_PROMPT, ""),
}

schema = await google_generative_ai_config_option_schema(self.hass, options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
@@ -135,87 +162,94 @@ async def async_step_init(

async def google_generative_ai_config_option_schema(
hass: HomeAssistant,
options: MappingProxyType[str, Any],
options: dict[str, Any] | MappingProxyType[str, Any],
) -> dict:
"""Return a schema for Google Generative AI completion options."""
api_models = await hass.async_add_executor_job(partial(genai.list_models))

models: list[SelectOptionDict] = [
hass_apis: list[SelectOptionDict] = [
SelectOptionDict(
label="Gemini 1.5 Flash (recommended)",
value="models/gemini-1.5-flash-latest",
),
label="No control",
value="none",
)
]
models.extend(
hass_apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)

if options.get(CONF_RECOMMENDED):
return {
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
vol.Optional(
CONF_TONE_PROMPT,
description={"suggested_value": options.get(CONF_TONE_PROMPT)},
default="",
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
}

api_models = await hass.async_add_executor_job(partial(genai.list_models))

models = [
SelectOptionDict(
label=api_model.display_name,
value=api_model.name,
)
for api_model in sorted(api_models, key=lambda x: x.display_name)
if (
api_model.name
not in (
"models/gemini-1.0-pro", # duplicate of gemini-pro
"models/gemini-1.5-flash-latest",
)
api_model.name != "models/gemini-1.0-pro" # duplicate of gemini-pro
and "vision" not in api_model.name
and "generateContent" in api_model.supported_generation_methods
)
)

apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)

return {
vol.Required(
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
): bool,
vol.Optional(
CONF_CHAT_MODEL,
description={"suggested_value": options.get(CONF_CHAT_MODEL)},
default=DEFAULT_CHAT_MODEL,
default=RECOMMENDED_CHAT_MODEL,
): SelectSelector(
SelectSelectorConfig(
mode=SelectSelectorMode.DROPDOWN,
options=models,
)
SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=hass_apis)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
default=RECOMMENDED_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
default=RECOMMENDED_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TOP_K,
description={"suggested_value": options.get(CONF_TOP_K)},
default=DEFAULT_TOP_K,
default=RECOMMENDED_TOP_K,
): int,
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
default=RECOMMENDED_MAX_TOKENS,
): int,
}
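
The notable piece in the options flow above is its two-pass behavior: submitting the form with a flipped "Recommended settings" toggle only re-renders the form with the matching fields shown or hidden, and only a second submit (toggle unchanged) saves via async_create_entry. A minimal standalone sketch of the re-render branch, with literal string keys standing in for the CONF_* constants and "assist" standing in for llm.LLM_API_ASSIST (both assumptions, for readability):

# Simplified from GoogleGenerativeAIOptionsFlow.async_step_init above.
DEFAULT_PROMPT = "This smart home is controlled by Home Assistant. ..."  # abbreviated

RECOMMENDED_OPTIONS = {
    "recommended": True,
    "llm_hass_api": "assist",  # stands in for llm.LLM_API_ASSIST
    "tone_prompt": "",
}

def options_for_rerender(user_input: dict) -> dict:
    """Pick the options used to rebuild the form after the toggle flips."""
    if user_input["recommended"]:
        # Switching to recommended: clear every advanced option.
        return RECOMMENDED_OPTIONS
    # Switching to advanced: seed the full prompt from the default
    # plus whatever tone the user had configured.
    return {
        "recommended": False,
        "prompt": DEFAULT_PROMPT + "\n" + user_input.get("tone_prompt", ""),
    }

assert options_for_rerender({"recommended": True}) == RECOMMENDED_OPTIONS
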
homeassistant/components/google_generative_ai_conversation/const.py
@@ -5,6 +5,7 @@
DOMAIN = "google_generative_ai_conversation"
LOGGER = logging.getLogger(__package__)
CONF_PROMPT = "prompt"
CONF_TONE_PROMPT = "tone_prompt"
DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
An overview of the areas and the devices in this smart home:
@@ -23,14 +24,14 @@
{%- endfor %}
"""

CONF_RECOMMENDED = "recommended"
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "models/gemini-pro"
RECOMMENDED_CHAT_MODEL = "models/gemini-1.5-flash-latest"
CONF_TEMPERATURE = "temperature"
DEFAULT_TEMPERATURE = 0.9
RECOMMENDED_TEMPERATURE = 1.0
CONF_TOP_P = "top_p"
DEFAULT_TOP_P = 1.0
RECOMMENDED_TOP_P = 0.95
CONF_TOP_K = "top_k"
DEFAULT_TOP_K = 1
RECOMMENDED_TOP_K = 64
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_MAX_TOKENS = 150
DEFAULT_ALLOW_HASS_ACCESS = False
RECOMMENDED_MAX_TOKENS = 150
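
For quick reference, the tuning constants above are renamed from DEFAULT_* to RECOMMENDED_*, most values change with the move to Gemini 1.5 Flash, and DEFAULT_ALLOW_HASS_ACCESS is removed (all values read directly from the diff):

# option       DEFAULT_* (old)        RECOMMENDED_* (new)
# chat_model   "models/gemini-pro"    "models/gemini-1.5-flash-latest"
# temperature  0.9                    1.0
# top_p        1.0                    0.95
# top_k        1                      64
# max_tokens   150                    150 (unchanged)
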
homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -25,16 +25,17 @@
CONF_MAX_TOKENS,
CONF_PROMPT,
CONF_TEMPERATURE,
CONF_TONE_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
DEFAULT_CHAT_MODEL,
DEFAULT_MAX_TOKENS,
DEFAULT_PROMPT,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
)

# Max number of back and forth with the LLM to generate a response
@@ -156,17 +157,16 @@ async def async_process(
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = genai.GenerativeModel(
model_name=self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL),
model_name=self.entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
generation_config={
"temperature": self.entry.options.get(
CONF_TEMPERATURE, DEFAULT_TEMPERATURE
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"top_p": self.entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
"top_k": self.entry.options.get(CONF_TOP_K, RECOMMENDED_TOP_K),
"max_output_tokens": self.entry.options.get(
CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
),
},
tools=tools or None,
@@ -179,6 +179,10 @@ async def async_process(
conversation_id = ulid.ulid_now()
messages = [{}, {}]

raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
raw_prompt += "\n" + tone_prompt

try:
prompt = self._async_generate_prompt(raw_prompt, llm_api)
except TemplateError as err:
@@ -221,7 +225,7 @@ async def async_process(
if not chat_response.parts:
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to Google Generative AI. Likely blocked",
"Sorry, I had a problem getting a response from Google Generative AI.",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
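
To see how the renamed fallbacks and the new tone prompt come together at call time, here is a hedged sketch of the model construction from async_process above (google.generativeai as used by the integration; opts stands in for self.entry.options, and the literal fallbacks mirror the RECOMMENDED_* constants):

import google.generativeai as genai

opts: dict = {}  # stand-in for self.entry.options; empty means recommended values apply

# Constructing the model is local; no API call happens here.
model = genai.GenerativeModel(
    model_name=opts.get("chat_model", "models/gemini-1.5-flash-latest"),
    generation_config={
        "temperature": opts.get("temperature", 1.0),
        "top_p": opts.get("top_p", 0.95),
        "top_k": opts.get("top_k", 64),
        "max_output_tokens": opts.get("max_tokens", 150),
    },
)

# The tone prompt, when set, is appended to the base prompt before
# template rendering, mirroring the new block in async_process.
raw_prompt = "...base prompt..."  # DEFAULT_PROMPT in the integration
if tone_prompt := opts.get("tone_prompt"):
    raw_prompt += "\n" + tone_prompt
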
homeassistant/components/google_generative_ai_conversation/strings.json
@@ -18,13 +18,19 @@
"step": {
"init": {
"data": {
"prompt": "Prompt Template",
"recommended": "Recommended settings",
"prompt": "Prompt",
"tone_prompt": "Tone",
"chat_model": "[%key:common::generic::model%]",
"temperature": "Temperature",
"top_p": "Top P",
"top_k": "Top K",
"max_tokens": "Maximum tokens to return in response",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
},
"data_description": {
"prompt": "Extra data to provide to the LLM. This can be a template.",
"tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
}
}
}
