LLM Tools support for OpenAI integration (#117645)
* initial commit

* Add tests

* Move tests to the correct file

* Fix exception type

* Undo change to default prompt

* Add intent dependency

* Move format_tool out of the class

* Fix tests

* coverage

* Adjust to new API

* Update strings

* Update tests

* Remove unrelated change

* Test referencing non-existing API

* Add test to verify no exception on tool conversion for Assist tools

* Bump voluptuous-openapi==0.0.4

* Add device_id to tool input

* Fix tests

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
Shulyaka and balloob committed May 22, 2024
1 parent 09213d8 commit 2f0215b
Showing 11 changed files with 663 additions and 77 deletions.
@@ -8,5 +8,5 @@
"documentation": "https://www.home-assistant.io/integrations/google_generative_ai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.3"]
"requirements": ["google-generativeai==0.5.4", "voluptuous-openapi==0.0.4"]
}
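
The version bump matters because voluptuous-openapi's convert() is what turns a tool's voluptuous schema into the JSON-schema dict that OpenAI-style function calling expects (used in conversation.py below). A minimal sketch of its behavior; the printed output shape is illustrative:

```python
# Minimal sketch of voluptuous_openapi.convert(); output shape is illustrative.
import voluptuous as vol
from voluptuous_openapi import convert

schema = vol.Schema({vol.Required("name"): str, vol.Optional("brightness"): int})
print(convert(schema))
# -> {"type": "object",
#     "properties": {"name": {"type": "string"},
#                    "brightness": {"type": "integer"}},
#     "required": ["name"]}
```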
73 changes: 46 additions & 27 deletions homeassistant/components/openai_conversation/config_flow.py
@@ -3,7 +3,6 @@
from __future__ import annotations

import logging
import types
from types import MappingProxyType
from typing import Any

@@ -16,11 +15,15 @@
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_API_KEY
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
TemplateSelector,
)

@@ -46,16 +49,6 @@
}
)

DEFAULT_OPTIONS = types.MappingProxyType(
{
CONF_PROMPT: DEFAULT_PROMPT,
CONF_CHAT_MODEL: DEFAULT_CHAT_MODEL,
CONF_MAX_TOKENS: DEFAULT_MAX_TOKENS,
CONF_TOP_P: DEFAULT_TOP_P,
CONF_TEMPERATURE: DEFAULT_TEMPERATURE,
}
)


async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
"""Validate the user input allows us to connect.
@@ -92,7 +85,11 @@ async def async_step_user(
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
else:
return self.async_create_entry(title="OpenAI Conversation", data=user_input)
return self.async_create_entry(
title="OpenAI Conversation",
data=user_input,
options={CONF_LLM_HASS_API: llm.LLM_API_ASSIST},
)

return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
@@ -118,45 +115,67 @@ async def async_step_init(
) -> ConfigFlowResult:
"""Manage the options."""
if user_input is not None:
return self.async_create_entry(title="OpenAI Conversation", data=user_input)
schema = openai_config_option_schema(self.config_entry.options)
if user_input[CONF_LLM_HASS_API] == "none":
user_input.pop(CONF_LLM_HASS_API)
return self.async_create_entry(title="", data=user_input)
schema = openai_config_option_schema(self.hass, self.config_entry.options)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema),
)


def openai_config_option_schema(options: MappingProxyType[str, Any]) -> dict:
def openai_config_option_schema(
hass: HomeAssistant,
options: MappingProxyType[str, Any],
) -> dict:
"""Return a schema for OpenAI completion options."""
if not options:
options = DEFAULT_OPTIONS
apis: list[SelectOptionDict] = [
SelectOptionDict(
label="No control",
value="none",
)
]
apis.extend(
SelectOptionDict(
label=api.name,
value=api.id,
)
for api in llm.async_get_apis(hass)
)

return {
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options[CONF_PROMPT]},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_CHAT_MODEL,
description={
# New key in HA 2023.4
"suggested_value": options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
"suggested_value": options.get(CONF_CHAT_MODEL)
},
default=DEFAULT_CHAT_MODEL,
): str,
vol.Optional(
CONF_LLM_HASS_API,
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
default="none",
): SelectSelector(SelectSelectorConfig(options=apis)),
vol.Optional(
CONF_PROMPT,
description={"suggested_value": options.get(CONF_PROMPT)},
default=DEFAULT_PROMPT,
): TemplateSelector(),
vol.Optional(
CONF_MAX_TOKENS,
description={"suggested_value": options[CONF_MAX_TOKENS]},
description={"suggested_value": options.get(CONF_MAX_TOKENS)},
default=DEFAULT_MAX_TOKENS,
): int,
vol.Optional(
CONF_TOP_P,
description={"suggested_value": options[CONF_TOP_P]},
description={"suggested_value": options.get(CONF_TOP_P)},
default=DEFAULT_TOP_P,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
vol.Optional(
CONF_TEMPERATURE,
description={"suggested_value": options[CONF_TEMPERATURE]},
description={"suggested_value": options.get(CONF_TEMPERATURE)},
default=DEFAULT_TEMPERATURE,
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
}
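
For orientation, the options flow above ends up storing a mapping like the one sketched below (keys come from const.py; the values are hypothetical). Selecting "No control" pops llm_hass_api from the saved options entirely:

```python
# Hypothetical options entry written by the flow above; values illustrative.
options = {
    "prompt": "You are a voice assistant for Home Assistant.",
    "chat_model": "gpt-3.5-turbo",
    "max_tokens": 150,
    "top_p": 1.0,
    "temperature": 0.5,
    "llm_hass_api": "assist",  # key absent when the user picked "No control"
}
```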
4 changes: 0 additions & 4 deletions homeassistant/components/openai_conversation/const.py
@@ -21,10 +21,6 @@
{%- endif %}
{%- endfor %}
{%- endfor %}
Answer the user's questions about the world truthfully.
If the user wants to control a device, reject the request and suggest using the Home Assistant app.
"""
CONF_CHAT_MODEL = "chat_model"
DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
146 changes: 117 additions & 29 deletions homeassistant/components/openai_conversation/conversation.py
@@ -1,15 +1,18 @@
"""Conversation support for OpenAI."""

from typing import Literal
import json
from typing import Any, Literal

import openai
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import assist_pipeline, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import TemplateError
from homeassistant.helpers import intent, template
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import intent, llm, template
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util import ulid

@@ -28,6 +31,9 @@
LOGGER,
)

# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10


async def async_setup_entry(
hass: HomeAssistant,
@@ -39,6 +45,15 @@ async def async_setup_entry(
async_add_entities([agent])


def _format_tool(tool: llm.Tool) -> dict[str, Any]:
"""Format tool specification."""
tool_spec = {"name": tool.name}
if tool.description:
tool_spec["description"] = tool.description
tool_spec["parameters"] = convert(tool.parameters)
return {"type": "function", "function": tool_spec}
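
For a hypothetical Assist tool, _format_tool() would produce something like the envelope below, which is the shape OpenAI's tools parameter expects (the tool name, description, and parameters here are illustrative, not taken from the commit):

```python
# Illustrative _format_tool() output for a made-up "HassTurnOn" tool.
{
    "type": "function",
    "function": {
        "name": "HassTurnOn",
        "description": "Turns on an entity",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}
```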


class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
@@ -75,6 +90,26 @@ async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""
intent_response = intent.IntentResponse(language=user_input.language)
llm_api: llm.API | None = None
tools: list[dict[str, Any]] | None = None

if self.entry.options.get(CONF_LLM_HASS_API):
try:
llm_api = llm.async_get_api(
self.hass, self.entry.options[CONF_LLM_HASS_API]
)
except HomeAssistantError as err:
LOGGER.error("Error getting LLM API: %s", err)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Error preparing LLM API: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=user_input.conversation_id
)
tools = [_format_tool(tool) for tool in llm_api.async_get_tools()]

raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
max_tokens = self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS)
@@ -87,7 +122,10 @@ async def async_process(
else:
conversation_id = ulid.ulid_now()
try:
prompt = self._async_generate_prompt(raw_prompt)
prompt = self._async_generate_prompt(
raw_prompt,
llm_api,
)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
@@ -106,38 +144,88 @@

client = self.hass.data[DOMAIN][self.entry.entry_id]

try:
result = await client.chat.completions.create(
model=model,
messages=messages,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=conversation_id,
)
except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

LOGGER.debug("Response %s", result)
response = result.choices[0].message.model_dump(include={"role", "content"})
messages.append(response)
# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
result = await client.chat.completions.create(
model=model,
messages=messages,
tools=tools,
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature,
user=conversation_id,
)
except openai.OpenAIError as err:
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to OpenAI: {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

LOGGER.debug("Response %s", result)
response = result.choices[0].message
messages.append(response)
tool_calls = response.tool_calls

if not tool_calls or not llm_api:
break

for tool_call in tool_calls:
tool_input = llm.ToolInput(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
platform=DOMAIN,
context=user_input.context,
user_prompt=user_input.text,
language=user_input.language,
assistant=conversation.DOMAIN,
device_id=user_input.device_id,
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)

try:
tool_response = await llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)

LOGGER.debug("Tool response: %s", tool_response)
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call.function.name,
"content": json.dumps(tool_response),
}
)

self.history[conversation_id] = messages

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["content"])
intent_response.async_set_speech(response.content)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
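
To make the tool loop concrete, here is a hypothetical message history after one round-trip; on the next iteration the whole list goes back to the model so it can fold the tool result into its final answer. All ids, names, and payloads are illustrative, and the assistant turn is shown as a plain dict rather than the SDK's message object:

```python
# Illustrative history after one tool round-trip; all values are made up.
messages = [
    {"role": "system", "content": "<rendered prompt>"},
    {"role": "user", "content": "Turn on the kitchen light"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "HassTurnOn",
                         "arguments": '{"name": "kitchen light"}'},
        }],
    },
    {
        "role": "tool",
        "tool_call_id": "call_abc123",  # echoes the id from the tool call
        "name": "HassTurnOn",
        "content": '{"result": "turned_on"}',
    },
]
```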

def _async_generate_prompt(self, raw_prompt: str) -> str:
def _async_generate_prompt(
self,
raw_prompt: str,
llm_api: llm.API | None,
) -> str:
"""Generate a prompt for the user."""
raw_prompt += "\n"
if llm_api:
raw_prompt += llm_api.prompt_template
else:
raw_prompt += llm.PROMPT_NO_API_CONFIGURED

return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
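
The combined prompt string is rendered as a Home Assistant template, which is why it can reference variables such as ha_name. A minimal, hypothetical rendering sketch (assumes a running hass instance):

```python
# Minimal template-rendering sketch; assumes `hass` is a running instance.
from homeassistant.helpers import template

tmpl = template.Template("You are {{ ha_name }}'s voice assistant.", hass)
print(tmpl.async_render({"ha_name": "Home"}))
# -> "You are Home's voice assistant."
```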
4 changes: 2 additions & 2 deletions homeassistant/components/openai_conversation/manifest.json
@@ -1,12 +1,12 @@
{
"domain": "openai_conversation",
"name": "OpenAI Conversation",
"after_dependencies": ["assist_pipeline"],
"after_dependencies": ["assist_pipeline", "intent"],
"codeowners": ["@balloob"],
"config_flow": true,
"dependencies": ["conversation"],
"documentation": "https://www.home-assistant.io/integrations/openai_conversation",
"integration_type": "service",
"iot_class": "cloud_polling",
"requirements": ["openai==1.3.8"]
"requirements": ["openai==1.3.8", "voluptuous-openapi==0.0.4"]
}
5 changes: 3 additions & 2 deletions homeassistant/components/openai_conversation/strings.json
@@ -18,10 +18,11 @@
"init": {
"data": {
"prompt": "Prompt Template",
"model": "Completion Model",
"chat_model": "[%key:common::generic::model%]",
"max_tokens": "Maximum tokens to return in response",
"temperature": "Temperature",
"top_p": "Top P"
"top_p": "Top P",
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]"
}
}
}
3 changes: 2 additions & 1 deletion requirements_all.txt
@@ -2826,7 +2826,8 @@ voip-utils==0.1.0
volkszaehler==0.4.0

# homeassistant.components.google_generative_ai_conversation
voluptuous-openapi==0.0.3
# homeassistant.components.openai_conversation
voluptuous-openapi==0.0.4

# homeassistant.components.volvooncall
volvooncall==0.10.3
