Allow llm API to render dynamic template prompt (#118055)
* Allow llm API to render dynamic template prompt

* Make rendering the API prompt async so it can become a RAG

* Fix test
balloob committed May 24, 2024
1 parent 3b2cdb6 commit 7554ca9
Showing 11 changed files with 137 additions and 359 deletions.
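
At its core, this commit replaces the static prompt_template: str attribute on the llm.API helper class with an abstract async method, so each API renders its prompt per request and can await I/O while doing so (the "RAG" note in the commit message). A minimal sketch of the reworked contract, with unrelated members elided:

```python
from abc import ABC, abstractmethod

from homeassistant.helpers.llm import ToolInput


class API(ABC):
    """Sketch of the reshaped llm.API base class (hass/id/name fields elided)."""

    @abstractmethod
    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
        """Return the prompt for the API, rendered fresh for each request."""
        raise NotImplementedError
```

Both conversation integrations touched here (Google Generative AI and OpenAI) are updated to await this hook, and their bulky device-overview default prompts shrink to a one-line instruction.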
homeassistant/components/google_generative_ai_conversation/config_flow.py
@@ -36,7 +36,6 @@
     CONF_PROMPT,
     CONF_RECOMMENDED,
     CONF_TEMPERATURE,
-    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_PROMPT,
@@ -59,7 +58,7 @@
 RECOMMENDED_OPTIONS = {
     CONF_RECOMMENDED: True,
     CONF_LLM_HASS_API: llm.LLM_API_ASSIST,
-    CONF_TONE_PROMPT: "",
+    CONF_PROMPT: "",
 }


@@ -142,16 +141,11 @@ async def async_step_init(
             # Re-render the options again, now with the recommended options shown/hidden
             self.last_rendered_recommended = user_input[CONF_RECOMMENDED]
 
-            # If we switch to not recommended, generate used prompt.
-            if user_input[CONF_RECOMMENDED]:
-                options = RECOMMENDED_OPTIONS
-            else:
-                options = {
-                    CONF_RECOMMENDED: False,
-                    CONF_PROMPT: DEFAULT_PROMPT
-                    + "\n"
-                    + user_input.get(CONF_TONE_PROMPT, ""),
-                }
+            options = {
+                CONF_RECOMMENDED: user_input[CONF_RECOMMENDED],
+                CONF_PROMPT: user_input[CONF_PROMPT],
+                CONF_LLM_HASS_API: user_input[CONF_LLM_HASS_API],
+            }
 
         schema = await google_generative_ai_config_option_schema(self.hass, options)
         return self.async_show_form(
@@ -179,22 +173,24 @@ async def google_generative_ai_config_option_schema(
         for api in llm.async_get_apis(hass)
     )
 
+    schema = {
+        vol.Optional(
+            CONF_PROMPT,
+            description={"suggested_value": options.get(CONF_PROMPT)},
+            default=DEFAULT_PROMPT,
+        ): TemplateSelector(),
+        vol.Optional(
+            CONF_LLM_HASS_API,
+            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+            default="none",
+        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
+        vol.Required(
+            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
+        ): bool,
+    }
+
     if options.get(CONF_RECOMMENDED):
-        return {
-            vol.Required(
-                CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
-            ): bool,
-            vol.Optional(
-                CONF_TONE_PROMPT,
-                description={"suggested_value": options.get(CONF_TONE_PROMPT)},
-                default="",
-            ): TemplateSelector(),
-            vol.Optional(
-                CONF_LLM_HASS_API,
-                description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-                default="none",
-            ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
-        }
+        return schema
 
     api_models = await hass.async_add_executor_job(partial(genai.list_models))

@@ -211,45 +207,35 @@ async def google_generative_ai_config_option_schema(
             )
         ]
 
-    return {
-        vol.Required(
-            CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
-        ): bool,
-        vol.Optional(
-            CONF_CHAT_MODEL,
-            description={"suggested_value": options.get(CONF_CHAT_MODEL)},
-            default=RECOMMENDED_CHAT_MODEL,
-        ): SelectSelector(
-            SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
-        ),
-        vol.Optional(
-            CONF_PROMPT,
-            description={"suggested_value": options.get(CONF_PROMPT)},
-            default=DEFAULT_PROMPT,
-        ): TemplateSelector(),
-        vol.Optional(
-            CONF_LLM_HASS_API,
-            description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-            default="none",
-        ): SelectSelector(SelectSelectorConfig(options=hass_apis)),
-        vol.Optional(
-            CONF_TEMPERATURE,
-            description={"suggested_value": options.get(CONF_TEMPERATURE)},
-            default=RECOMMENDED_TEMPERATURE,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TOP_P,
-            description={"suggested_value": options.get(CONF_TOP_P)},
-            default=RECOMMENDED_TOP_P,
-        ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
-        vol.Optional(
-            CONF_TOP_K,
-            description={"suggested_value": options.get(CONF_TOP_K)},
-            default=RECOMMENDED_TOP_K,
-        ): int,
-        vol.Optional(
-            CONF_MAX_TOKENS,
-            description={"suggested_value": options.get(CONF_MAX_TOKENS)},
-            default=RECOMMENDED_MAX_TOKENS,
-        ): int,
-    }
+    schema.update(
+        {
+            vol.Optional(
+                CONF_CHAT_MODEL,
+                description={"suggested_value": options.get(CONF_CHAT_MODEL)},
+                default=RECOMMENDED_CHAT_MODEL,
+            ): SelectSelector(
+                SelectSelectorConfig(mode=SelectSelectorMode.DROPDOWN, options=models)
+            ),
+            vol.Optional(
+                CONF_TEMPERATURE,
+                description={"suggested_value": options.get(CONF_TEMPERATURE)},
+                default=RECOMMENDED_TEMPERATURE,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TOP_P,
+                description={"suggested_value": options.get(CONF_TOP_P)},
+                default=RECOMMENDED_TOP_P,
+            ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
+            vol.Optional(
+                CONF_TOP_K,
+                description={"suggested_value": options.get(CONF_TOP_K)},
+                default=RECOMMENDED_TOP_K,
+            ): int,
+            vol.Optional(
+                CONF_MAX_TOKENS,
+                description={"suggested_value": options.get(CONF_MAX_TOKENS)},
+                default=RECOMMENDED_MAX_TOKENS,
+            ): int,
+        }
+    )
+    return schema
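
The refactor above changes how the options form is built: a base dict with the prompt, API selector, and recommended toggle is always present, and the model-tuning fields are appended only when recommended mode is off. A condensed, self-contained sketch of that pattern (constants inlined from the component's const.py; selectors simplified to plain validators; the numeric defaults here are placeholders, not the integration's recommended values):

```python
import voluptuous as vol

# Inlined stand-ins for the component's const.py values.
CONF_PROMPT = "prompt"
CONF_LLM_HASS_API = "llm_hass_api"
CONF_RECOMMENDED = "recommended"
CONF_TEMPERATURE = "temperature"
CONF_MAX_TOKENS = "max_tokens"
DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."


def build_options_schema(options: dict) -> dict:
    """Mirror the base-then-extend flow of google_generative_ai_config_option_schema."""
    schema = {
        vol.Optional(CONF_PROMPT, default=DEFAULT_PROMPT): str,
        vol.Optional(CONF_LLM_HASS_API, default="none"): str,
        vol.Required(CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)): bool,
    }
    if options.get(CONF_RECOMMENDED):
        # Recommended mode: only the three base fields are shown.
        return schema

    # Advanced mode: extend the same dict with the tuning knobs.
    schema.update(
        {
            vol.Optional(CONF_TEMPERATURE, default=1.0): vol.Coerce(float),  # placeholder default
            vol.Optional(CONF_MAX_TOKENS, default=150): int,  # placeholder default
        }
    )
    return schema


# Usage: wrap the dict in a voluptuous Schema and validate user input.
options_schema = vol.Schema(build_options_schema({CONF_RECOMMENDED: False}))
```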
homeassistant/components/google_generative_ai_conversation/const.py
@@ -5,24 +5,7 @@
 DOMAIN = "google_generative_ai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
-CONF_TONE_PROMPT = "tone_prompt"
-DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
-
-An overview of the areas and the devices in this smart home:
-{%- for area in areas() %}
-  {%- set area_info = namespace(printed=false) %}
-  {%- for device in area_devices(area) -%}
-    {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
-      {%- if not area_info.printed %}
-
-{{ area_name(area) }}:
-        {%- set area_info.printed = true %}
-      {%- endif %}
-    - {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
-    {%- endif %}
-  {%- endfor %}
-{%- endfor %}
-"""
+DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."
 
 CONF_RECOMMENDED = "recommended"
 CONF_CHAT_MODEL = "chat_model"
homeassistant/components/google_generative_ai_conversation/conversation.py
@@ -25,7 +25,6 @@
     CONF_MAX_TOKENS,
     CONF_PROMPT,
     CONF_TEMPERATURE,
-    CONF_TONE_PROMPT,
     CONF_TOP_K,
     CONF_TOP_P,
     DEFAULT_PROMPT,
@@ -179,12 +178,32 @@ async def async_process(
             conversation_id = ulid.ulid_now()
             messages = [{}, {}]
 
-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
-        if tone_prompt := self.entry.options.get(CONF_TONE_PROMPT):
-            raw_prompt += "\n" + tone_prompt
-
         try:
-            prompt = self._async_generate_prompt(raw_prompt, llm_api)
+            prompt = template.Template(
+                self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+            ).async_render(
+                {
+                    "ha_name": self.hass.config.location_name,
+                },
+                parse_result=False,
+            )
+
+            if llm_api:
+                empty_tool_input = llm.ToolInput(
+                    tool_name="",
+                    tool_args={},
+                    platform=DOMAIN,
+                    context=user_input.context,
+                    user_prompt=user_input.text,
+                    language=user_input.language,
+                    assistant=conversation.DOMAIN,
+                    device_id=user_input.device_id,
+                )
+
+                prompt = (
+                    await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt
+                )
+
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
             intent_response.async_set_error(
@@ -271,18 +290,3 @@ async def async_process(
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
-
-    def _async_generate_prompt(self, raw_prompt: str, llm_api: llm.API | None) -> str:
-        """Generate a prompt for the user."""
-        raw_prompt += "\n"
-        if llm_api:
-            raw_prompt += llm_api.prompt_template
-        else:
-            raw_prompt += llm.PROMPT_NO_API_CONFIGURED
-
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-            },
-            parse_result=False,
-        )
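
Both conversation agents now assemble their system prompt the same way: render the user-configured template first, then, when an LLM API is attached, await the API's dynamically generated prompt and prepend it. A standalone sketch of that shared flow (constants inlined from const.py; hass, entry_options, llm_api, and user_input are supplied by the surrounding agent code):

```python
from homeassistant.helpers import llm, template

# Inlined stand-ins for the component's const.py values.
CONF_PROMPT = "prompt"
DEFAULT_PROMPT = "Answer in plain text. Keep it simple and to the point."
DOMAIN = "google_generative_ai_conversation"


async def build_prompt(hass, entry_options, llm_api, user_input) -> str:
    """Sketch of the prompt assembly both updated agents share."""
    # 1) Render the user's instructions; the option may contain Jinja templating.
    prompt = template.Template(
        entry_options.get(CONF_PROMPT, DEFAULT_PROMPT), hass
    ).async_render(
        {"ha_name": hass.config.location_name},
        parse_result=False,
    )

    # 2) If an API is selected, await its request-specific prompt and prepend it.
    #    The "empty" ToolInput invokes no tool; it only carries request context.
    if llm_api:
        empty_tool_input = llm.ToolInput(
            tool_name="",
            tool_args={},
            platform=DOMAIN,
            context=user_input.context,
            user_prompt=user_input.text,
            language=user_input.language,
            assistant="conversation",
            device_id=user_input.device_id,
        )
        prompt = await llm_api.async_get_api_prompt(empty_tool_input) + "\n" + prompt

    return prompt
```

A TemplateError raised during rendering is caught by the callers and reported on the intent response, as the hunks above show.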
homeassistant/components/google_generative_ai_conversation/strings.json
@@ -18,9 +18,8 @@
     "step": {
       "init": {
         "data": {
-          "recommended": "Recommended settings",
-          "prompt": "Prompt",
-          "tone_prompt": "Tone",
+          "recommended": "Recommended model settings",
+          "prompt": "Instructions",
           "chat_model": "[%key:common::generic::model%]",
           "temperature": "Temperature",
           "top_p": "Top P",
@@ -29,8 +28,7 @@
           "llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
         },
         "data_description": {
-          "prompt": "Extra data to provide to the LLM. This can be a template.",
-          "tone_prompt": "Instructions for the LLM on the style of the generated text. This can be a template."
+          "prompt": "Instruct how the LLM should respond. This can be a template."
         }
       }
     }
18 changes: 1 addition & 17 deletions homeassistant/components/openai_conversation/const.py
@@ -5,23 +5,7 @@
 DOMAIN = "openai_conversation"
 LOGGER = logging.getLogger(__package__)
 CONF_PROMPT = "prompt"
-DEFAULT_PROMPT = """This smart home is controlled by Home Assistant.
-
-An overview of the areas and the devices in this smart home:
-{%- for area in areas() %}
-  {%- set area_info = namespace(printed=false) %}
-  {%- for device in area_devices(area) -%}
-    {%- if not device_attr(device, "disabled_by") and not device_attr(device, "entry_type") and device_attr(device, "name") %}
-      {%- if not area_info.printed %}
-
-{{ area_name(area) }}:
-        {%- set area_info.printed = true %}
-      {%- endif %}
-    - {{ device_attr(device, "name") }}{% if device_attr(device, "model") and (device_attr(device, "model") | string) not in (device_attr(device, "name") | string) %} ({{ device_attr(device, "model") }}){% endif %}
-    {%- endif %}
-  {%- endfor %}
-{%- endfor %}
-"""
+DEFAULT_PROMPT = """Answer in plain text. Keep it simple and to the point."""
 CONF_CHAT_MODEL = "chat_model"
 DEFAULT_CHAT_MODEL = "gpt-4o"
 CONF_MAX_TOKENS = "max_tokens"
50 changes: 27 additions & 23 deletions homeassistant/components/openai_conversation/conversation.py
@@ -110,7 +110,6 @@ async def async_process(
             )
             tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]
 
-        raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
         model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
         max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
         top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
@@ -122,10 +121,33 @@ async def async_process(
         else:
             conversation_id = ulid.ulid_now()
         try:
-            prompt = self._async_generate_prompt(
-                raw_prompt,
-                llm_api,
+            prompt = template.Template(
+                self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT), self.hass
+            ).async_render(
+                {
+                    "ha_name": self.hass.config.location_name,
+                },
+                parse_result=False,
             )
+
+            if llm_api:
+                empty_tool_input = llm.ToolInput(
+                    tool_name="",
+                    tool_args={},
+                    platform=DOMAIN,
+                    context=user_input.context,
+                    user_prompt=user_input.text,
+                    language=user_input.language,
+                    assistant=conversation.DOMAIN,
+                    device_id=user_input.device_id,
+                )
+
+                prompt = (
+                    await llm_api.async_get_api_prompt(empty_tool_input)
+                    + "\n"
+                    + prompt
+                )
+
         except TemplateError as err:
             LOGGER.error("Error rendering prompt: %s", err)
             intent_response = intent.IntentResponse(language=user_input.language)
@@ -136,6 +158,7 @@ async def async_process(
             return conversation.ConversationResult(
                 response=intent_response, conversation_id=conversation_id
             )
+
         messages = [{"role": "system", "content": prompt}]
 
         messages.append({"role": "user", "content": user_input.text})
@@ -213,22 +236,3 @@ async def async_process(
         return conversation.ConversationResult(
             response=intent_response, conversation_id=conversation_id
         )
-
-    def _async_generate_prompt(
-        self,
-        raw_prompt: str,
-        llm_api: llm.API | None,
-    ) -> str:
-        """Generate a prompt for the user."""
-        raw_prompt += "\n"
-        if llm_api:
-            raw_prompt += llm_api.prompt_template
-        else:
-            raw_prompt += llm.PROMPT_NO_API_CONFIGURED
-
-        return template.Template(raw_prompt, self.hass).async_render(
-            {
-                "ha_name": self.hass.config.location_name,
-            },
-            parse_result=False,
-        )
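
On the caller side, API discovery is unchanged: integrations still look APIs up through the llm helper and now simply await the prompt. A hypothetical snippet under those assumptions (hass and tool_input come from the surrounding code; async_get_apis and LLM_API_ASSIST are as shown in the diffs above):

```python
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm


async def render_assist_prompt(hass: HomeAssistant, tool_input: llm.ToolInput) -> str:
    """Hypothetical helper: pick the built-in Assist API and render its prompt."""
    apis = llm.async_get_apis(hass)  # registry lookup, as used in config_flow.py above
    assist_api = next(api for api in apis if api.id == llm.LLM_API_ASSIST)
    # For AssistAPI this returns "Call the intent tools to control Home Assistant. ..."
    return await assist_api.async_get_api_prompt(tool_input)
```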
11 changes: 9 additions & 2 deletions homeassistant/helpers/llm.py
@@ -102,7 +102,11 @@ class API(ABC):
     hass: HomeAssistant
     id: str
     name: str
-    prompt_template: str
+
+    @abstractmethod
+    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
+        """Return the prompt for the API."""
+        raise NotImplementedError
 
     @abstractmethod
     @callback
@@ -183,9 +187,12 @@ def __init__(self, hass: HomeAssistant) -> None:
             hass=hass,
             id=LLM_API_ASSIST,
             name="Assist",
-            prompt_template="Call the intent tools to control the system. Just pass the name to the intent.",
         )
 
+    async def async_get_api_prompt(self, tool_input: ToolInput) -> str:
+        """Return the prompt for the API."""
+        return "Call the intent tools to control Home Assistant. Just pass the name to the intent."
+
     @callback
     def async_get_tools(self) -> list[Tool]:
         """Return a list of LLM tools."""