Add recommended options to OpenAI
balloob committed May 25, 2024
1 parent 4b0f58e · commit 62fd856
Showing 5 changed files with 191 additions and 77 deletions.
109 changes: 70 additions & 39 deletions homeassistant/components/openai_conversation/config_flow.py
@@ -31,14 +31,15 @@
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
     CONF_PROMPT,
+    CONF_RECOMMENDED,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -49,6 +50,12 @@
     }
 )

+RECOMMENDED_OPTIONS = {
+    CONF_RECOMMENDED: True,
+    CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
+    CONF_PROMPT: DEFAULT_PROMPT,
+}
+

 async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
     """Validate the user input allows us to connect.
@@ -88,7 +95,7 @@ async def async_step_user(
         return self.async_create_entry(
             title="OpenAI Conversation",
             data=user_input,
-            options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
+            options=RECOMMENDED_OPTIONS,
         )

     return self.async_show_form(
@@ -109,16 +116,32 @@ class OpenAIOptionsFlow(OptionsFlow):
     def __init__(self, config_entry: ConfigEntry) -> None:
         """Initialize options flow."""
         self.config_entry = config_entry
+        self.last_rendered_recommended = config_entry.options.get(
+            CONF_RECOMMENDED, False
+        )

     async def async_step_init(
         self, user_input: dict[str, Any] | None = None
     ) -> ConfigFlowResult:
         """Manage the options."""
+        options: dict[str, Any] | MappingProxyType[str, Any] = self.config_entry.options
+
         if user_input is not None:
-            if user_input[CONF_LLM_HASS_API] == "none":
-                user_input.pop(CONF_LLM_HASS_API)
-            return self.async_create_entry(title="", data=user_input)
-        schema = openai_config_option_schema(self.hass, self.config_entry.options)
+            if user_input[CONF_RECOMMENDED] == self.last_rendered_recommended:
+                if user_input[CONF_LLM_HASS_API] == "none":
+                    user_input.pop(CONF_LLM_HASS_API)
+                return self.async_create_entry(title="", data=user_input)
+
+            # Re-render the options again, now with the recommended options shown/hidden
+            self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
+
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
+
+        schema = openai_config_option_schema(self.hass, options)
         return self.async_show_form(
             step_id="init",
             data_schema=vol.Schema(schema),
@@ -127,55 +150,63 @@ async def async_step_init(

 def openai_config_option_schema(
     hass: HomeAssistant,
-    options: MappingProxyType[str, Any],
+    options: dict[str, Any] | MappingProxyType[str, Any],
 ) -> dict:
     """Return a schema for OpenAI completion options."""
-    apis: list[SelectOptionDict] = [
+    hass_apis: list[SelectOptionDict] = [
         SelectOptionDict(
             label="No control",
             value="none",
         )
     ]
-    apis.extend(
+    hass_apis.extend(
         SelectOptionDict(
             label=api.name,
             value=api.id,
         )
         for api in llm.async_get_apis(hass)
     )

-    return {
+    schema = {
         vol.Optional(
             CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT, DEFAULT_PROMPT)},
         ): TemplateSelector(),
         vol.Optional(
             CONF_LLM_HASS_API,
             description={"suggested_value": options.get(CONF_LLM_HASS_API)},
             default="none",
-        ): SelectSelector(SelectSelectorConfig(options=apis)),
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={
-                # New key in HA 2023.4
-                "suggested_value": options.get(CONF_CHAT_MODEL)
-            },
-            default=DEFAULT_CHAT_MODEL,
-        ): str,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=DEFAULT_MAX_TOKENS,
-        ): int,
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=DEFAULT_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=DEFAULT_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
     }
+
+    if options.get(CONF_RECOMMENDED):
+        return schema
+
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): str,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=2, step=0.05)),
+        }
+    )
+    return schema
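The heart of this file's change is that openai_config_option_schema now branches on the recommended flag, and the options flow re-renders the form instead of saving when that flag is toggled. Below is a minimal, self-contained sketch of the same two-pass pattern — not part of the commit; plain dicts and field lists stand in for voluptuous schemas and Home Assistant selectors, and field names simply mirror the diff.

# Sketch of the "recommended options" pattern above (illustrative only).
RECOMMENDED_FIELDS = ["prompt", "llm_hass_api", "recommended"]
ADVANCED_FIELDS = ["chat_model", "max_tokens", "top_p", "temperature"]


def option_fields(options: dict) -> list[str]:
    """Return the fields the options form would render."""
    if options.get("recommended"):
        return list(RECOMMENDED_FIELDS)  # advanced fields stay hidden
    return RECOMMENDED_FIELDS + ADVANCED_FIELDS


class SketchOptionsFlow:
    """Two-pass submit: toggling 'recommended' re-renders instead of saving."""

    def __init__(self, stored: dict) -> None:
        self.stored = stored
        self.last_rendered_recommended = stored.get("recommended", False)

    def step_init(self, user_input: dict | None = None):
        if user_input is not None:
            if user_input["recommended"] == self.last_rendered_recommended:
                return ("save", user_input)  # flag unchanged: persist options
            # Flag flipped: remember it, keep only the basic fields, re-render
            self.last_rendered_recommended = user_input["recommended"]
            self.stored = {k: user_input[k] for k in RECOMMENDED_FIELDS}
        return ("form", option_fields(self.stored))


flow = SketchOptionsFlow({"recommended": True, "prompt": "Be brief."})
assert flow.step_init() == ("form", ["prompt", "llm_hass_api", "recommended"])
# Turning the toggle off re-renders the form with the advanced fields added.
action, fields = flow.step_init(
    {"recommended": False, "prompt": "Be brief.", "llm_hass_api": "assist"}
)
assert action == "form" and "temperature" in fields

A third submit with the toggle left unchanged would then pass the equality check and be saved.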
10 changes: 6 additions & 4 deletions homeassistant/components/openai_conversation/const.py
@@ -4,13 +4,15 @@

 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
+
+CONF_RECOMMENDED = "recommended"
 CONF_PROMPT = "prompt"
 DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
 CONF_CHAT_MODEL = "chat_model"
-DEFAULT_CHAT_MODEL = "gpt-4o"
+RECOMMENDED_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"
-DEFAULT_MAX_TOKENS = 150
+RECOMMENDED_MAX_TOKENS = 150
 CONF_TOP_P = "top_p"
-DEFAULT_TOP_P = 1.0
+RECOMMENDED_TOP_P = 1.0
 CONF_TEMPERATURE = "temperature"
-DEFAULT_TEMPERATURE = 1.0
+RECOMMENDED_TEMPERATURE = 1.0
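The rename from DEFAULT_* to RECOMMENDED_* reflects how the constants are now consumed: in recommended mode the advanced keys are never written to the entry's options, so readers fall back to these values through dict.get. A quick illustration (constant values copied from above; the "gpt-4-turbo" override is a hypothetical user choice):

RECOMMENDED_CHAT_MODEL = "gpt-4o"
RECOMMENDED_MAX_TOKENS = 150

recommended = {"recommended": True, "prompt": "Be brief."}    # no advanced keys stored
custom = {"recommended": False, "chat_model": "gpt-4-turbo"}  # hypothetical override

assert recommended.get("chat_model", RECOMMENDED_CHAT_MODEL) == "gpt-4o"
assert recommended.get("max_tokens", RECOMMENDED_MAX_TOKENS) == 150
assert custom.get("chat_model", RECOMMENDED_CHAT_MODEL) == "gpt-4-turbo"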
59 changes: 28 additions & 31 deletions homeassistant/components/openai_conversation/conversation.py
@@ -22,13 +22,13 @@
     CONF_PROMPT,
     CONF_TEMPERATURE,
     CONF_TOP_P,
-    DEFAULT_CHAT_MODEL,
-    DEFAULT_MAX_TOKENS,
     DEFAULT_PROMPT,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_P,
     DOMAIN,
     LOGGER,
+    RECOMMENDED_CHAT_MODEL,
+    RECOMMENDED_MAX_TOKENS,
+    RECOMMENDED_TEMPERATURE,
+    RECOMMENDED_TOP_P,
 )

 # Max number of back and forth with the LLM to generate a response
@@ -90,15 +90,14 @@ async def async_process(
         self, user_input: conversation.ConversationInput
     ) -> conversation.ConversationResult:
         """Process a sentence."""
+        options = self.entry.options
         intent_response = intent.IntentResponse(language=user_input.language)
         llm_api: llm.API | None = None
         tools: list[dict[str, Any]] | None = None

-        if self.entry.options.get(CONF_LLM_HASS_API):
+        if options.get(CONF_LLM_HASS_API):
             try:
-                llm_api = llm.async_get_api(
-                    self.hass, self.entry.options[CONF_LLM_HASS_API]
-                )
+                llm_api = llm.async_get_api(self.hass, options[CONF_LLM_HASS_API])
             except HomeAssistantError as err:
                 LOGGER.error("Error getting LLM API: %s", err)
                 intent_response.async_set_error(
@@ -110,26 +109,12 @@ async def async_process(
                 )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

-        model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
-        max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
-        top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
-        temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
-
         if user_input.conversation_id in self.history:
             conversation_id = user_input.conversation_id
             messages = self.history[conversation_id]
         else:
             conversation_id = ulid.ulid_now()
             try:
-                prompt = template.Template(
-                    self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
-                ).async_render(
-                    {
-                        "ha_name": self.hass.config.location_name,
-                    },
-                    parse_result=False,
-                )
-
                 if llm_api:
                     empty_tool_input = llm.ToolInput(
                         tool_name="",
@@ -142,11 +127,23 @@ async def async_process(
                         device_id=user_input.device_id,
                     )

-                    prompt = (
-                        await llm_api.async_get_api_prompt(empty_tool_input)
-                        + "\n"
-                        + prompt
-                    )
+                    api_prompt = await llm_api.async_get_api_prompt(empty_tool_input)
+
+                else:
+                    api_prompt = llm.PROMPT_NO_API_CONFIGURED
+
+                prompt = (
+                    api_prompt
+                    + "\n"
+                    + template.Template(
+                        options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+                    ).async_render(
+                        {
+                            "ha_name": self.hass.config.location_name,
+                        },
+                        parse_result=False,
+                    )
+                )

             except TemplateError as err:
                 LOGGER.error("Error rendering prompt: %s", err)
@@ -163,20 +160,20 @@

         messages.append({"role": "user", "content": user_input.text})

-        LOGGER.debug("Prompt for %s: %s", model, messages)
+        LOGGER.debug("Prompt: %s", messages)

         client = self.hass.data[DOMAIN][self.entry.entry_id]

         # To prevent infinite loops, we limit the number of iterations
        for _iteration in range(MAX_TOOL_ITERATIONS):
             try:
                 result = await client.chat.completions.create(
-                    model=model,
+                    model=options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL),
                     messages=messages,
                     tools=tools,
-                    max_tokens=max_tokens,
-                    top_p=top_p,
-                    temperature=temperature,
+                    max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS),
+                    top_p=options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
+                    temperature=options.get(CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE),
                     user=conversation_id,
                 )
             except openai.OpenAIError as err:
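The notable behavioral change in this file is prompt assembly: the system prompt is now always the API-derived prompt (or llm.PROMPT_NO_API_CONFIGURED when no API is selected), a newline, then the user's rendered prompt template — previously the API prompt was only prepended when an API was configured. A rough sketch under stated assumptions: template rendering is reduced to str.replace, and the fallback text is invented for the example — only the constant's name comes from the diff.

# Sketch of the new prompt assembly (illustrative only).
# ASSUMPTION: this fallback text is invented; only the name is real.
PROMPT_NO_API_CONFIGURED = "You cannot control devices; answer questions only."


def build_prompt(api_prompt: str | None, user_template: str, ha_name: str) -> str:
    """API prompt (or the no-API fallback) first, then the rendered user prompt."""
    base = api_prompt if api_prompt is not None else PROMPT_NO_API_CONFIGURED
    # Stand-in for template.Template(...).async_render({"ha_name": ha_name})
    return base + "\n" + user_template.replace("{{ ha_name }}", ha_name)


print(build_prompt(None, "Answer in plain text. Keep it simple and to the point.", "Home"))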
3 changes: 2 additions & 1 deletion homeassistant/components/openai_conversation/strings.json
@@ -22,7 +22,8 @@
           "max_tokens": "Maximum tokens to return in response",
           "temperature": "Temperature",
           "top_p": "Top P",
-          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
+          "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
+          "recommended": "Recommended model settings"
         },
         "data_description": {
           "prompt": "Instruct how the LLM should respond. This can be a template."
Diff for the fifth changed file not shown.