Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ class GoogleGenerativeAITaskEntity(
):
"""Google Generative AI AI Task entity."""

_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
_attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)

async def _async_generate_data(
self,
task: ai_task.GenDataTask,
chat_log: conversation.ChatLog,
) -> ai_task.GenDataTaskResult:
"""Handle a generate data task."""
await self._async_handle_chat_log(chat_log, task.structure)
await self._async_handle_chat_log(chat_log, task.structure, task.attachments)

if not isinstance(chat_log.content[-1], conversation.AssistantContent):
LOGGER.error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import replace
import mimetypes
from pathlib import Path
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from google.genai import Client
from google.genai.errors import APIError, ClientError
Expand All @@ -30,8 +30,8 @@
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
Expand Down Expand Up @@ -60,6 +60,9 @@
TIMEOUT_MILLIS,
)

if TYPE_CHECKING:
from . import GoogleGenerativeAIConfigEntry

# Max number of back and forth with the LLM to generate a response
MAX_TOOL_ITERATIONS = 10

Expand Down Expand Up @@ -313,7 +316,7 @@ class GoogleGenerativeAILLMBaseEntity(Entity):

def __init__(
self,
entry: ConfigEntry,
entry: GoogleGenerativeAIConfigEntry,
subentry: ConfigSubentry,
default_model: str = RECOMMENDED_CHAT_MODEL,
) -> None:
Expand All @@ -335,6 +338,7 @@ async def _async_handle_chat_log(
self,
chat_log: conversation.ChatLog,
structure: vol.Schema | None = None,
attachments: list[ai_task.PlayMediaWithId] | None = None,
) -> None:
"""Generate an answer for the chat log."""
options = self.subentry.data
Expand Down Expand Up @@ -438,6 +442,18 @@ async def _async_handle_chat_log(
user_message = chat_log.content[-1]
assert isinstance(user_message, conversation.UserContent)
chat_request: str | list[Part] = user_message.content
if attachments:
if any(a.path is None for a in attachments):
raise HomeAssistantError(
"Only local attachments are currently supported"
)
files = await async_prepare_files_for_prompt(
self.hass,
self._genai_client,
[a.path for a in attachments], # type: ignore[misc]
)
chat_request = [chat_request, *files]

# To prevent infinite loops, we limit the number of iterations
for _iteration in range(MAX_TOOL_ITERATIONS):
try:
Expand Down Expand Up @@ -508,7 +524,7 @@ def create_generate_content_config(self) -> GenerateContentConfig:
async def async_prepare_files_for_prompt(
hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
"""Append files to a prompt.
"""Upload files so they can be attached to a prompt.

Caller needs to ensure that the files are allowed.
"""
Expand Down
32 changes: 32 additions & 0 deletions homeassistant/components/matter/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,36 @@ def _update_from_device(self) -> None:
clusters.TemperatureControl.Attributes.MaxTemperature,
),
),
MatterDiscoverySchema(
platform=Platform.NUMBER,
entity_description=MatterNumberEntityDescription(
key="InovelliLEDIndicatorIntensityOff",
entity_category=EntityCategory.CONFIG,
translation_key="led_indicator_intensity_off",
native_max_value=75,
native_min_value=0,
native_step=1,
mode=NumberMode.BOX,
),
entity_class=MatterNumber,
required_attributes=(
custom_clusters.InovelliCluster.Attributes.LEDIndicatorIntensityOff,
),
),
MatterDiscoverySchema(
platform=Platform.NUMBER,
entity_description=MatterNumberEntityDescription(
key="InovelliLEDIndicatorIntensityOn",
entity_category=EntityCategory.CONFIG,
translation_key="led_indicator_intensity_on",
native_max_value=75,
native_min_value=0,
native_step=1,
mode=NumberMode.BOX,
),
entity_class=MatterNumber,
required_attributes=(
custom_clusters.InovelliCluster.Attributes.LEDIndicatorIntensityOn,
),
),
]
6 changes: 6 additions & 0 deletions homeassistant/components/matter/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,12 @@
},
"auto_relock_timer": {
"name": "Autorelock time"
},
"led_indicator_intensity_off": {
"name": "LED off intensity"
},
"led_indicator_intensity_on": {
"name": "LED on intensity"
}
},
"light": {
Expand Down
4 changes: 3 additions & 1 deletion homeassistant/components/tesla_fleet/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,5 +226,7 @@ async def async_step_reauth_confirm(
def _is_valid_domain(self, domain: str) -> bool:
"""Validate domain format."""
# Basic domain validation regex
domain_pattern = re.compile(r"^(?:[a-zA-Z0-9]+\.)+[a-zA-Z0-9-]+$")
domain_pattern = re.compile(
r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$"
)
return bool(domain_pattern.match(domain))
14 changes: 10 additions & 4 deletions homeassistant/helpers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,19 @@ def selector_serializer(schema: Any) -> Any: # noqa: C901
if isinstance(schema, selector.ObjectSelector):
result = {"type": "object"}
if fields := schema.config.get("fields"):
result["properties"] = {
field: convert(
properties = {}
required = []
for field, field_schema in fields.items():
properties[field] = convert(
selector.selector(field_schema["selector"]),
custom_serializer=selector_serializer,
)
for field, field_schema in fields.items()
}
if field_schema.get("required"):
required.append(field)
result["properties"] = properties

if required:
result["required"] = required
else:
result["additionalProperties"] = True
if schema.config.get("multiple"):
Expand Down
96 changes: 92 additions & 4 deletions tests/components/google_generative_ai_conversation/test_ai_task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test AI Task platform of Google Generative AI Conversation integration."""

from unittest.mock import AsyncMock
from pathlib import Path
from unittest.mock import AsyncMock, patch

from google.genai.types import GenerateContentResponse
from google.genai.types import File, FileState, GenerateContentResponse
import pytest
import voluptuous as vol

from homeassistant.components import ai_task
from homeassistant.components import ai_task, media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
Expand Down Expand Up @@ -64,6 +65,93 @@ async def test_generate_data(
)
assert result.data == "Hi there!"

# Test with attachments
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
candidates=[
{
"content": {
"parts": [{"text": "Hi there!"}],
"role": "model",
},
}
],
),
],
]
file1 = File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE)
file2 = File(name="context.txt", state=FileState.ACTIVE)
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
media_source.PlayMedia(
url="http://example.com/context.txt",
mime_type="text/plain",
path=Path("context.txt"),
),
],
),
patch(
"google.genai.files.Files.upload",
side_effect=[file1, file2],
) as mock_upload,
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
{"media_content_id": "media-source://media/context.txt"},
],
)

outgoing_message = mock_send_message_stream.mock_calls[1][2]["message"]
assert outgoing_message == ["Test prompt", file1, file2]

assert result.data == "Hi there!"
assert len(mock_upload.mock_calls) == 2
assert mock_upload.mock_calls[0][2]["file"] == Path("doorbell_snapshot.jpg")
assert mock_upload.mock_calls[1][2]["file"] == Path("context.txt")

# Test attachments require play media with a path
with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=None,
),
],
),
pytest.raises(
HomeAssistantError, match="Only local attachments are currently supported"
),
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Test prompt",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
],
)

# Test with structure
mock_send_message_stream.return_value = [
[
GenerateContentResponse(
Expand Down Expand Up @@ -97,7 +185,7 @@ async def test_generate_data(
)
assert result.data == {"characters": ["Mario", "Luigi"]}

assert len(mock_chat_create.mock_calls) == 2
assert len(mock_chat_create.mock_calls) == 4
config = mock_chat_create.mock_calls[-1][2]["config"]
assert config.response_mime_type == "application/json"
assert config.response_schema == {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ async def test_generate_content_service_with_image(
),
patch("pathlib.Path.exists", return_value=True),
patch.object(hass.config, "is_allowed_path", return_value=True),
patch("builtins.open", mock_open(read_data="this is an image")),
patch("mimetypes.guess_type", return_value=["image/jpeg"]),
):
response = await hass.services.async_call(
Expand Down
Loading
Loading