diff --git a/homeassistant/components/ai_task/__init__.py b/homeassistant/components/ai_task/__init__.py index daaf190fc55b86..767104916bf5d6 100644 --- a/homeassistant/components/ai_task/__init__.py +++ b/homeassistant/components/ai_task/__init__.py @@ -3,10 +3,8 @@ import logging from typing import Any -from aiohttp import web import voluptuous as vol -from homeassistant.components.http import KEY_HASS, HomeAssistantView from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR from homeassistant.core import ( @@ -28,7 +26,6 @@ ATTR_STRUCTURE, ATTR_TASK_NAME, DATA_COMPONENT, - DATA_IMAGES, DATA_PREFERENCES, DOMAIN, SERVICE_GENERATE_DATA, @@ -42,7 +39,6 @@ GenDataTaskResult, GenImageTask, GenImageTaskResult, - ImageData, async_generate_data, async_generate_image, ) @@ -55,7 +51,6 @@ "GenDataTaskResult", "GenImageTask", "GenImageTaskResult", - "ImageData", "async_generate_data", "async_generate_image", "async_setup", @@ -94,10 +89,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass) hass.data[DATA_COMPONENT] = entity_component hass.data[DATA_PREFERENCES] = AITaskPreferences(hass) - hass.data[DATA_IMAGES] = {} await hass.data[DATA_PREFERENCES].async_load() async_setup_http(hass) - hass.http.register_view(ImageView) hass.services.async_register( DOMAIN, SERVICE_GENERATE_DATA, @@ -209,28 +202,3 @@ def async_set_preferences( def as_dict(self) -> dict[str, str | None]: """Get the current preferences.""" return {key: getattr(self, key) for key in self.KEYS} - - -class ImageView(HomeAssistantView): - """View to generated images.""" - - url = f"/api/{DOMAIN}/images/{{filename}}" - name = f"api:{DOMAIN}/images" - - async def get( - self, - request: web.Request, - filename: str, - ) -> web.Response: - """Serve image.""" - hass = request.app[KEY_HASS] - image_storage = hass.data[DATA_IMAGES] - image_data = image_storage.get(filename) - - if image_data is None: - raise web.HTTPNotFound - - return web.Response( - body=image_data.data, - content_type=image_data.mime_type, - ) diff --git a/homeassistant/components/ai_task/const.py b/homeassistant/components/ai_task/const.py index b62f8002ecf25d..978e6f3cfb9dd1 100644 --- a/homeassistant/components/ai_task/const.py +++ b/homeassistant/components/ai_task/const.py @@ -8,19 +8,19 @@ from homeassistant.util.hass_dict import HassKey if TYPE_CHECKING: + from homeassistant.components.media_source import local_source from homeassistant.helpers.entity_component import EntityComponent from . import AITaskPreferences from .entity import AITaskEntity - from .task import ImageData DOMAIN = "ai_task" DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN) DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences") -DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images") +DATA_MEDIA_SOURCE: HassKey[local_source.LocalSource] = HassKey(f"{DOMAIN}_media_source") +IMAGE_DIR: Final = "image" IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour -MAX_IMAGES = 20 SERVICE_GENERATE_DATA = "generate_data" SERVICE_GENERATE_IMAGE = "generate_image" diff --git a/homeassistant/components/ai_task/manifest.json b/homeassistant/components/ai_task/manifest.json index 9e2eec4651d646..d05faf18055355 100644 --- a/homeassistant/components/ai_task/manifest.json +++ b/homeassistant/components/ai_task/manifest.json @@ -1,7 +1,7 @@ { "domain": "ai_task", "name": "AI Task", - "after_dependencies": ["camera", "http"], + "after_dependencies": ["camera"], "codeowners": ["@home-assistant/core"], "dependencies": ["conversation", "media_source"], "documentation": "https://www.home-assistant.io/integrations/ai_task", diff --git a/homeassistant/components/ai_task/media_source.py b/homeassistant/components/ai_task/media_source.py index 17995584fd7489..2906acf7a2d314 100644 --- a/homeassistant/components/ai_task/media_source.py +++ b/homeassistant/components/ai_task/media_source.py @@ -2,89 +2,21 @@ from __future__ import annotations -from datetime import timedelta -import logging - -from homeassistant.components.http.auth import async_sign_path -from homeassistant.components.media_player import BrowseError, MediaClass -from homeassistant.components.media_source import ( - BrowseMediaSource, - MediaSource, - MediaSourceItem, - PlayMedia, - Unresolvable, -) +from homeassistant.components.media_source import MediaSource, local_source from homeassistant.core import HomeAssistant -from .const import DATA_IMAGES, DOMAIN, IMAGE_EXPIRY_TIME - -_LOGGER = logging.getLogger(__name__) - - -async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource: - """Set up image media source.""" - _LOGGER.debug("Setting up image media source") - return ImageMediaSource(hass) - - -class ImageMediaSource(MediaSource): - """Provide images as media sources.""" - - name: str = "AI Generated Images" - - def __init__(self, hass: HomeAssistant) -> None: - """Initialize ImageMediaSource.""" - super().__init__(DOMAIN) - self.hass = hass - - async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia: - """Resolve media to a url.""" - image_storage = self.hass.data[DATA_IMAGES] - image = image_storage.get(item.identifier) - - if image is None: - raise Unresolvable(f"Could not resolve media item: {item.identifier}") - - return PlayMedia( - async_sign_path( - self.hass, - f"/api/{DOMAIN}/images/{item.identifier}", - timedelta(seconds=IMAGE_EXPIRY_TIME or 1800), - ), - image.mime_type, - ) - - async def async_browse_media( - self, - item: MediaSourceItem, - ) -> BrowseMediaSource: - """Return media.""" - if item.identifier: - raise BrowseError("Unknown item") +from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR - image_storage = self.hass.data[DATA_IMAGES] - children = [ - BrowseMediaSource( - domain=DOMAIN, - identifier=filename, - media_class=MediaClass.IMAGE, - media_content_type=image.mime_type, - title=image.title or filename, - can_play=True, - can_expand=False, - ) - for filename, image in image_storage.items() - ] +async def async_get_media_source(hass: HomeAssistant) -> MediaSource: + """Set up local media source.""" + media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}") - return BrowseMediaSource( - domain=DOMAIN, - identifier=None, - media_class=MediaClass.APP, - media_content_type="", - title="AI Generated Images", - can_play=False, - can_expand=True, - children_media_class=MediaClass.IMAGE, - children=children, - ) + hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource( + hass, + DOMAIN, + "AI Generated Images", + {IMAGE_DIR: media_dir}, + f"/{DOMAIN}", + ) + return source diff --git a/homeassistant/components/ai_task/task.py b/homeassistant/components/ai_task/task.py index 5cd57395d9dbf2..e6d86bee978c88 100644 --- a/homeassistant/components/ai_task/task.py +++ b/homeassistant/components/ai_task/task.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta -from functools import partial +import io import mimetypes from pathlib import Path import tempfile @@ -18,16 +18,15 @@ from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session -from homeassistant.helpers.event import async_call_later from homeassistant.util import RE_SANITIZE_FILENAME, slugify from .const import ( DATA_COMPONENT, - DATA_IMAGES, + DATA_MEDIA_SOURCE, DATA_PREFERENCES, DOMAIN, + IMAGE_DIR, IMAGE_EXPIRY_TIME, - MAX_IMAGES, AITaskEntityFeature, ) @@ -157,24 +156,6 @@ async def async_generate_data( ) -def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None: - """Remove old images to keep the storage size under the limit.""" - if num_to_remove <= 0: - return - - if num_to_remove >= len(image_storage): - image_storage.clear() - return - - sorted_images = sorted( - image_storage.items(), - key=lambda item: item[1].timestamp, - ) - - for filename, _ in sorted_images[:num_to_remove]: - image_storage.pop(filename, None) - - async def async_generate_image( hass: HomeAssistant, *, @@ -224,36 +205,34 @@ async def async_generate_image( if service_result.get("revised_prompt") is None: service_result["revised_prompt"] = instructions - image_storage = hass.data[DATA_IMAGES] - - if len(image_storage) + 1 > MAX_IMAGES: - _cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES) + source = hass.data[DATA_MEDIA_SOURCE] current_time = datetime.now() ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png" sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name)) - filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}" - image_storage[filename] = ImageData( - data=image_data, - timestamp=int(current_time.timestamp()), - mime_type=task_result.mime_type, - title=service_result["revised_prompt"], + image_file = ImageData( + filename=f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}", + file=io.BytesIO(image_data), + content_type=task_result.mime_type, ) - def _purge_image(filename: str, now: datetime) -> None: - """Remove image from storage.""" - image_storage.pop(filename, None) + target_folder = media_source.MediaSourceItem.from_uri( + hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None + ) - if IMAGE_EXPIRY_TIME > 0: - async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename)) + service_result["media_source_id"] = await source.async_upload_media( + target_folder, image_file + ) + item = media_source.MediaSourceItem.from_uri( + hass, service_result["media_source_id"], None + ) service_result["url"] = async_sign_path( hass, - f"/api/{DOMAIN}/images/{filename}", - timedelta(seconds=IMAGE_EXPIRY_TIME or 1800), + (await source.async_resolve_media(item)).url, + timedelta(seconds=IMAGE_EXPIRY_TIME), ) - service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}" return service_result @@ -358,20 +337,8 @@ def as_dict(self) -> dict[str, Any]: @dataclass(slots=True) class ImageData: - """Image data for stored generated images.""" - - data: bytes - """Raw image data.""" - - timestamp: int - """Timestamp when the image was generated, as a Unix timestamp.""" + """Implementation of media_source.local_source.UploadedFile protocol.""" - mime_type: str - """MIME type of the image.""" - - title: str - """Title of the image, usually the prompt used to generate it.""" - - def __str__(self) -> str: - """Return image data as a string.""" - return f"" + filename: str + file: io.IOBase + content_type: str diff --git a/homeassistant/components/backup/const.py b/homeassistant/components/backup/const.py index 773deaef1741c7..1cfb796bd2e6fc 100644 --- a/homeassistant/components/backup/const.py +++ b/homeassistant/components/backup/const.py @@ -26,6 +26,7 @@ "tmp_backups/*.tar", "OZW_Log.txt", "tts/*", + "ai_task/*", ] EXCLUDE_DATABASE_FROM_BACKUP = [ diff --git a/homeassistant/components/logbook/helpers.py b/homeassistant/components/logbook/helpers.py index 4fa0da9033ac1e..238e6a0dda8a27 100644 --- a/homeassistant/components/logbook/helpers.py +++ b/homeassistant/components/logbook/helpers.py @@ -5,8 +5,9 @@ from collections.abc import Callable, Mapping from typing import Any -from homeassistant.components.sensor import ATTR_STATE_CLASS +from homeassistant.components.sensor import ATTR_STATE_CLASS, NON_NUMERIC_DEVICE_CLASSES from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_DEVICE_ID, ATTR_DOMAIN, ATTR_ENTITY_ID, @@ -28,7 +29,13 @@ from homeassistant.helpers.event import async_track_state_change_event from homeassistant.util.event_type import EventType -from .const import ALWAYS_CONTINUOUS_DOMAINS, AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN +from .const import ( + ALWAYS_CONTINUOUS_DOMAINS, + AUTOMATION_EVENTS, + BUILT_IN_EVENTS, + DOMAIN, + SENSOR_DOMAIN, +) from .models import LogbookConfig @@ -38,8 +45,10 @@ def async_filter_entities(hass: HomeAssistant, entity_ids: list[str]) -> list[st return [ entity_id for entity_id in entity_ids - if split_entity_id(entity_id)[0] not in ALWAYS_CONTINUOUS_DOMAINS - and not is_sensor_continuous(hass, ent_reg, entity_id) + if (domain := split_entity_id(entity_id)[0]) not in ALWAYS_CONTINUOUS_DOMAINS + and not ( + domain == SENSOR_DOMAIN and is_sensor_continuous(hass, ent_reg, entity_id) + ) ] @@ -214,6 +223,10 @@ def _forward_state_events_filtered(event: Event[EventStateChangedData]) -> None: ) +def _device_class_is_numeric(device_class: str | None) -> bool: + return device_class is not None and device_class not in NON_NUMERIC_DEVICE_CLASSES + + def is_sensor_continuous( hass: HomeAssistant, ent_reg: er.EntityRegistry, entity_id: str ) -> bool: @@ -233,7 +246,11 @@ def is_sensor_continuous( # has a unit_of_measurement or state_class, and filter if # it does if (state := hass.states.get(entity_id)) and (attributes := state.attributes): - return ATTR_UNIT_OF_MEASUREMENT in attributes or ATTR_STATE_CLASS in attributes + return ( + ATTR_UNIT_OF_MEASUREMENT in attributes + or ATTR_STATE_CLASS in attributes + or _device_class_is_numeric(attributes.get(ATTR_DEVICE_CLASS)) + ) # If its not in the state machine, we need to check # the entity registry to see if its a sensor # filter with a state class. We do not check @@ -243,8 +260,10 @@ def is_sensor_continuous( # the state machine will always have the state. return bool( (entry := ent_reg.async_get(entity_id)) - and entry.capabilities - and entry.capabilities.get(ATTR_STATE_CLASS) + and ( + (entry.capabilities and entry.capabilities.get(ATTR_STATE_CLASS)) + or _device_class_is_numeric(entry.device_class) + ) ) @@ -258,6 +277,12 @@ def _is_state_filtered(new_state: State, old_state: State) -> bool: new_state.state == old_state.state or new_state.last_changed != new_state.last_updated or new_state.domain in ALWAYS_CONTINUOUS_DOMAINS - or ATTR_UNIT_OF_MEASUREMENT in new_state.attributes - or ATTR_STATE_CLASS in new_state.attributes + or ( + new_state.domain == SENSOR_DOMAIN + and ( + ATTR_UNIT_OF_MEASUREMENT in new_state.attributes + or ATTR_STATE_CLASS in new_state.attributes + or _device_class_is_numeric(new_state.attributes.get(ATTR_DEVICE_CLASS)) + ) + ) ) diff --git a/tests/components/ai_task/conftest.py b/tests/components/ai_task/conftest.py index 06f9a56a813f44..ceffb7c055e854 100644 --- a/tests/components/ai_task/conftest.py +++ b/tests/components/ai_task/conftest.py @@ -157,4 +157,4 @@ async def async_setup_entry_platform( with mock_config_flow(TEST_DOMAIN, ConfigFlow): assert await hass.config_entries.async_setup(mock_config_entry.entry_id) - await hass.async_block_till_done() + await hass.async_block_till_done(wait_background_tasks=True) diff --git a/tests/components/ai_task/test_init.py b/tests/components/ai_task/test_init.py index 5c6465936d91f9..83e1808b6d8c1e 100644 --- a/tests/components/ai_task/test_init.py +++ b/tests/components/ai_task/test_init.py @@ -4,13 +4,14 @@ from typing import Any from unittest.mock import patch +from freezegun import freeze_time from freezegun.api import FrozenDateTimeFactory import pytest import voluptuous as vol from homeassistant.components import media_source from homeassistant.components.ai_task import AITaskPreferences -from homeassistant.components.ai_task.const import DATA_PREFERENCES +from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE, DATA_PREFERENCES from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import selector @@ -291,6 +292,7 @@ async def test_generate_data_service_invalid_structure( ), ], ) +@freeze_time("2025-06-14 22:59:00") async def test_generate_image_service( hass: HomeAssistant, init_components: None, @@ -302,21 +304,32 @@ async def test_generate_image_service( preferences = hass.data[DATA_PREFERENCES] preferences.async_set_preferences(**set_preferences) - result = await hass.services.async_call( - "ai_task", - "generate_image", - { - "task_name": "Test Image", - "instructions": "Generate a test image", - } - | msg_extra, - blocking=True, - return_response=True, - ) + with patch.object( + hass.data[DATA_MEDIA_SOURCE], + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await hass.services.async_call( + "ai_task", + "generate_image", + { + "task_name": "Test Image", + "instructions": "Generate a test image", + } + | msg_extra, + blocking=True, + return_response=True, + ) + mock_upload_media.assert_called_once() assert "image_data" not in result - assert result["media_source_id"].startswith("media-source://ai_task/images/") - assert result["url"].startswith("/api/ai_task/images/") + assert ( + result["media_source_id"] + == "media-source://ai_task/image/2025-06-14_225900_test_task.png" + ) + assert result["url"].startswith( + "/ai_task/image/2025-06-14_225900_test_task.png?authSig=" + ) assert result["mime_type"] == "image/png" assert result["model"] == "mock_model" assert result["revised_prompt"] == "mock_revised_prompt" diff --git a/tests/components/ai_task/test_media_source.py b/tests/components/ai_task/test_media_source.py index eae597efb91268..18f1834e08255b 100644 --- a/tests/components/ai_task/test_media_source.py +++ b/tests/components/ai_task/test_media_source.py @@ -1,64 +1,11 @@ """Test ai_task media source.""" -import pytest - from homeassistant.components import media_source -from homeassistant.components.ai_task import ImageData from homeassistant.core import HomeAssistant -@pytest.fixture(name="image_id") -async def mock_image_generate(hass: HomeAssistant) -> str: - """Mock image generation and return the image_id.""" - image_storage = hass.data.setdefault("ai_task_images", {}) - filename = "2025-06-15_150640_test_task.png" - image_storage[filename] = ImageData( - data=b"A", - timestamp=1750000000, - mime_type="image/png", - title="Mock Image", - ) - return filename - - -async def test_browsing( - hass: HomeAssistant, init_components: None, image_id: str -) -> None: - """Test browsing image media source.""" - item = await media_source.async_browse_media(hass, "media-source://ai_task") - - assert item is not None - assert item.title == "AI Generated Images" - assert len(item.children) == 1 - assert item.children[0].media_content_type == "image/png" - assert item.children[0].identifier == image_id - assert item.children[0].title == "Mock Image" - - with pytest.raises( - media_source.BrowseError, - match="Unknown item", - ): - await media_source.async_browse_media( - hass, "media-source://ai_task/invalid_path" - ) - - -async def test_resolving( - hass: HomeAssistant, init_components: None, image_id: str -) -> None: - """Test resolving.""" - item = await media_source.async_resolve_media( - hass, f"media-source://ai_task/{image_id}", None - ) - assert item is not None - assert item.url.startswith(f"/api/ai_task/images/{image_id}?authSig=") - assert item.mime_type == "image/png" +async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None: + """Test that the image media source is created.""" + item = await media_source.async_browse_media(hass, "media-source://") - invalid_id = "aabbccddeeff" - with pytest.raises( - media_source.Unresolvable, - match=f"Could not resolve media item: {invalid_id}", - ): - await media_source.async_resolve_media( - hass, f"media-source://ai_task/{invalid_id}", None - ) + assert any(c.title == "AI Generated Images" for c in item.children) diff --git a/tests/components/ai_task/test_task.py b/tests/components/ai_task/test_task.py index bc8bff4e632b33..345d6c30981520 100644 --- a/tests/components/ai_task/test_task.py +++ b/tests/components/ai_task/test_task.py @@ -1,6 +1,6 @@ """Test tasks for the AI Task integration.""" -from datetime import datetime, timedelta +from datetime import timedelta from pathlib import Path from unittest.mock import patch @@ -11,10 +11,10 @@ from homeassistant.components import media_source from homeassistant.components.ai_task import ( AITaskEntityFeature, - ImageData, async_generate_data, async_generate_image, ) +from homeassistant.components.ai_task.const import DATA_MEDIA_SOURCE from homeassistant.components.camera import Image from homeassistant.components.conversation import async_get_chat_log from homeassistant.const import STATE_UNKNOWN @@ -257,6 +257,7 @@ async def test_generate_data_mixed_attachments( assert media_attachment.path == Path("/media/test.mp4") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, init_components: None, @@ -277,17 +278,26 @@ async def test_generate_image( assert state is not None assert state.state == STATE_UNKNOWN - result = await async_generate_image( - hass, - task_name="Test Task", - entity_id=TEST_ENTITY_ID, - instructions="Test prompt", - ) + with patch.object( + hass.data[DATA_MEDIA_SOURCE], + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await async_generate_image( + hass, + task_name="Test Task", + entity_id=TEST_ENTITY_ID, + instructions="Test prompt", + ) + mock_upload_media.assert_called_once() assert "image_data" not in result - assert result["media_source_id"].startswith("media-source://ai_task/images/") - assert result["media_source_id"].endswith("_test_task.png") - assert result["url"].startswith("/api/ai_task/images/") - assert result["url"].count("_test_task.png?authSig=") == 1 + assert ( + result["media_source_id"] + == "media-source://ai_task/image/2025-06-14_225900_test_task.png" + ) + assert result["url"].startswith( + "/ai_task/image/2025-06-14_225900_test_task.png?authSig=" + ) assert result["mime_type"] == "image/png" assert result["model"] == "mock_model" assert result["revised_prompt"] == "mock_revised_prompt" @@ -309,40 +319,3 @@ async def test_generate_image( entity_id=TEST_ENTITY_ID, instructions="Test prompt", ) - - -async def test_image_cleanup( - hass: HomeAssistant, - init_components: None, - mock_ai_task_entity: MockAITaskEntity, -) -> None: - """Test image cache cleanup.""" - image_storage = hass.data.setdefault("ai_task_images", {}) - image_storage.clear() - image_storage.update( - { - str(idx): ImageData( - data=b"mock_image_data", - timestamp=int(datetime.now().timestamp()), - mime_type="image/png", - title="Test Image", - ) - for idx in range(20) - } - ) - assert len(image_storage) == 20 - - result = await async_generate_image( - hass, - task_name="Test Task", - entity_id=TEST_ENTITY_ID, - instructions="Test prompt", - ) - - assert result["url"].split("?authSig=")[0].split("/")[-1] in image_storage - assert len(image_storage) == 20 - - async_fire_time_changed(hass, dt_util.utcnow() + timedelta(hours=1, seconds=1)) - await hass.async_block_till_done() - - assert len(image_storage) == 19 diff --git a/tests/components/google_generative_ai_conversation/test_ai_task.py b/tests/components/google_generative_ai_conversation/test_ai_task.py index 11e6864d312766..25799ef4bc106c 100644 --- a/tests/components/google_generative_ai_conversation/test_ai_task.py +++ b/tests/components/google_generative_ai_conversation/test_ai_task.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock, patch +from freezegun import freeze_time from google.genai.types import File, FileState, GenerateContentResponse import pytest import voluptuous as vol @@ -222,6 +223,7 @@ async def test_generate_data( @pytest.mark.usefixtures("mock_init_component") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -255,14 +257,17 @@ async def test_generate_image( ], ) - assert hass.data[ai_task.DATA_IMAGES] == {} - - result = await ai_task.async_generate_image( - hass, - task_name="Test Task", - entity_id="ai_task.google_ai_task", - instructions="Generate a test image", - ) + with patch.object( + media_source.local_source.LocalSource, + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await ai_task.async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.google_ai_task", + instructions="Generate a test image", + ) assert result["height"] is None assert result["width"] is None @@ -270,11 +275,11 @@ async def test_generate_image( assert result["mime_type"] == "image/png" assert result["model"] == RECOMMENDED_IMAGE_MODEL.partition("/")[-1] - assert len(hass.data[ai_task.DATA_IMAGES]) == 1 - image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values())) - assert image_data.data == mock_image_data - assert image_data.mime_type == "image/png" - assert image_data.title == "Generate a test image" + mock_upload_media.assert_called_once() + image_data = mock_upload_media.call_args[0][1] + assert image_data.file.getvalue() == mock_image_data + assert image_data.content_type == "image/png" + assert image_data.filename == "2025-06-14_225900_test_task.png" # Verify that generate_content was called with correct parameters assert mock_generate_content.called diff --git a/tests/components/logbook/test_websocket_api.py b/tests/components/logbook/test_websocket_api.py index 7b2550ccc82e51..80d52d02ee3292 100644 --- a/tests/components/logbook/test_websocket_api.py +++ b/tests/components/logbook/test_websocket_api.py @@ -16,8 +16,10 @@ from homeassistant.components.recorder import Recorder from homeassistant.components.recorder.util import get_instance from homeassistant.components.script import EVENT_SCRIPT_STARTED +from homeassistant.components.sensor import ATTR_STATE_CLASS, SensorDeviceClass from homeassistant.components.websocket_api import TYPE_RESULT from homeassistant.const import ( + ATTR_DEVICE_CLASS, ATTR_DOMAIN, ATTR_ENTITY_ID, ATTR_FRIENDLY_NAME, @@ -310,13 +312,15 @@ async def test_get_events_entities_filtered_away( hass.states.async_set("light.kitchen", STATE_ON) await hass.async_block_till_done() hass.states.async_set( - "light.filtered", STATE_ON, {"brightness": 100, ATTR_UNIT_OF_MEASUREMENT: "any"} + "sensor.filtered", + STATE_ON, + {"brightness": 100, ATTR_UNIT_OF_MEASUREMENT: "any"}, ) await hass.async_block_till_done() hass.states.async_set("light.kitchen", STATE_OFF, {"brightness": 200}) await hass.async_block_till_done() hass.states.async_set( - "light.filtered", + "sensor.filtered", STATE_OFF, {"brightness": 300, ATTR_UNIT_OF_MEASUREMENT: "any"}, ) @@ -345,7 +349,7 @@ async def test_get_events_entities_filtered_away( "id": 2, "type": "logbook/get_events", "start_time": now.isoformat(), - "entity_ids": ["light.filtered"], + "entity_ids": ["sensor.filtered"], } ) response = await client.receive_json() @@ -3041,3 +3045,160 @@ def auto_off_listener(event): assert listeners_without_writes( hass.bus.async_listeners() ) == listeners_without_writes(init_listeners) + + +@pytest.mark.parametrize( + ("entity_id", "attributes", "result_count"), + [ + ( + "light.kitchen", + {ATTR_UNIT_OF_MEASUREMENT: "any", "brightness": 100}, + 1, # Light is not a filterable domain + ), + ( + "sensor.sensor0", + {ATTR_UNIT_OF_MEASUREMENT: "any"}, + 0, # Sensor with UoM is always filtered + ), + ( + "sensor.sensor1", + {ATTR_DEVICE_CLASS: SensorDeviceClass.AQI}, + 0, # Sensor with a numeric device class is always filtered + ), + ( + "sensor.sensor2", + {ATTR_DEVICE_CLASS: SensorDeviceClass.ENUM}, + 1, # Sensor with a non-numeric device class is not filtered + ), + ( + "sensor.sensor3", + {ATTR_STATE_CLASS: "any"}, + 0, # Sensor with state class is always filtered + ), + ( + "sensor.sensor4", + {}, + 1, # Sensor with no UoM, device_class, or state_class is not filtered + ), + ( + "number.number0", + {ATTR_UNIT_OF_MEASUREMENT: "any"}, + 1, # Non-sensor domains are not filtered by presence of UoM + ), + ( + "number.number1", + {}, + 1, # Not a filtered domain + ), + ( + "input_number.number0", + {ATTR_UNIT_OF_MEASUREMENT: "any"}, + 1, # Non-sensor domains are not filtered by presence of UoM + ), + ( + "input_number.number1", + {}, + 1, # Not a filtered domain + ), + ( + "counter.counter0", + {}, + 0, # Counter is an always continuous domain + ), + ( + "zone.home", + {}, + 1, # Zone is not an always continuous domain + ), + ], +) +@patch("homeassistant.components.logbook.websocket_api.EVENT_COALESCE_TIME", 0) +async def test_consistent_stream_and_recorder_filtering( + recorder_mock: Recorder, + hass: HomeAssistant, + hass_ws_client: WebSocketGenerator, + entity_id: str, + attributes: dict, + result_count: int, +) -> None: + """Test that the logbook live stream and get_events apis use consistent filtering rules.""" + now = dt_util.utcnow() + await asyncio.gather( + *[ + async_setup_component(hass, comp, {}) + for comp in ("homeassistant", "logbook") + ] + ) + await async_recorder_block_till_done(hass) + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + + hass.states.async_set(entity_id, "1.0", attributes) + hass.states.async_set("binary_sensor.other_entity", "off") + + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + websocket_client = await hass_ws_client() + await websocket_client.send_json( + { + "id": 1, + "type": "logbook/event_stream", + "start_time": now.isoformat(), + "entity_ids": [entity_id, "binary_sensor.other_entity"], + } + ) + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 1 + assert msg["type"] == TYPE_RESULT + assert msg["success"] + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 1 + assert msg["type"] == "event" + assert msg["event"]["events"] == [] + assert "partial" in msg["event"] + await async_wait_recording_done(hass) + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 1 + assert msg["type"] == "event" + assert msg["event"]["events"] == [] + assert "partial" not in msg["event"] + await async_wait_recording_done(hass) + + hass.states.async_set( + entity_id, + "2.0", + attributes, + ) + hass.states.async_set("binary_sensor.other_entity", "on") + await get_instance(hass).async_block_till_done() + await hass.async_block_till_done() + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 1 + assert msg["type"] == "event" + assert "partial" not in msg["event"] + assert len(msg["event"]["events"]) == 1 + result_count + + await hass.async_block_till_done() + + await async_wait_recording_done(hass) + + await websocket_client.send_json( + { + "id": 2, + "type": "logbook/get_events", + "start_time": now.isoformat(), + "entity_ids": [entity_id], + } + ) + response = await websocket_client.receive_json() + assert response["success"] + assert response["id"] == 2 + + results = response["result"] + assert len(results) == result_count diff --git a/tests/components/openai_conversation/test_ai_task.py b/tests/components/openai_conversation/test_ai_task.py index 31a9212bff2125..51ac505893e119 100644 --- a/tests/components/openai_conversation/test_ai_task.py +++ b/tests/components/openai_conversation/test_ai_task.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest.mock import AsyncMock, patch +from freezegun import freeze_time import httpx from openai import PermissionDeniedError import pytest @@ -212,6 +213,7 @@ async def test_generate_data_with_attachments( @pytest.mark.usefixtures("mock_init_component") +@freeze_time("2025-06-14 22:59:00") async def test_generate_image( hass: HomeAssistant, mock_config_entry: MockConfigEntry, @@ -241,14 +243,17 @@ async def test_generate_image( create_message_item(id="msg_A", text="", output_index=1), ] - assert hass.data[ai_task.DATA_IMAGES] == {} - - result = await ai_task.async_generate_image( - hass, - task_name="Test Task", - entity_id="ai_task.openai_ai_task", - instructions="Generate test image", - ) + with patch.object( + media_source.local_source.LocalSource, + "async_upload_media", + return_value="media-source://ai_task/image/2025-06-14_225900_test_task.png", + ) as mock_upload_media: + result = await ai_task.async_generate_image( + hass, + task_name="Test Task", + entity_id="ai_task.openai_ai_task", + instructions="Generate test image", + ) assert result["height"] == 1024 assert result["width"] == 1536 @@ -256,11 +261,11 @@ async def test_generate_image( assert result["mime_type"] == "image/png" assert result["model"] == "gpt-image-1" - assert len(hass.data[ai_task.DATA_IMAGES]) == 1 - image_data = next(iter(hass.data[ai_task.DATA_IMAGES].values())) - assert image_data.data == b"A" - assert image_data.mime_type == "image/png" - assert image_data.title == "Mock revised prompt." + mock_upload_media.assert_called_once() + image_data = mock_upload_media.call_args[0][1] + assert image_data.file.getvalue() == b"A" + assert image_data.content_type == "image/png" + assert image_data.filename == "2025-06-14_225900_test_task.png" assert ( issue_registry.async_get_issue(DOMAIN, "organization_verification_required")