Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions homeassistant/components/ai_task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import logging
from typing import Any

from aiohttp import web
import voluptuous as vol

from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ATTR_ENTITY_ID, CONF_DESCRIPTION, CONF_SELECTOR
from homeassistant.core import (
Expand All @@ -28,7 +26,6 @@
ATTR_STRUCTURE,
ATTR_TASK_NAME,
DATA_COMPONENT,
DATA_IMAGES,
DATA_PREFERENCES,
DOMAIN,
SERVICE_GENERATE_DATA,
Expand All @@ -42,7 +39,6 @@
GenDataTaskResult,
GenImageTask,
GenImageTaskResult,
ImageData,
async_generate_data,
async_generate_image,
)
Expand All @@ -55,7 +51,6 @@
"GenDataTaskResult",
"GenImageTask",
"GenImageTaskResult",
"ImageData",
"async_generate_data",
"async_generate_image",
"async_setup",
Expand Down Expand Up @@ -94,10 +89,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
entity_component = EntityComponent[AITaskEntity](_LOGGER, DOMAIN, hass)
hass.data[DATA_COMPONENT] = entity_component
hass.data[DATA_PREFERENCES] = AITaskPreferences(hass)
hass.data[DATA_IMAGES] = {}
await hass.data[DATA_PREFERENCES].async_load()
async_setup_http(hass)
hass.http.register_view(ImageView)
hass.services.async_register(
DOMAIN,
SERVICE_GENERATE_DATA,
Expand Down Expand Up @@ -209,28 +202,3 @@ def async_set_preferences(
def as_dict(self) -> dict[str, str | None]:
"""Get the current preferences."""
return {key: getattr(self, key) for key in self.KEYS}


class ImageView(HomeAssistantView):
"""View to generated images."""

url = f"/api/{DOMAIN}/images/{{filename}}"
name = f"api:{DOMAIN}/images"

async def get(
self,
request: web.Request,
filename: str,
) -> web.Response:
"""Serve image."""
hass = request.app[KEY_HASS]
image_storage = hass.data[DATA_IMAGES]
image_data = image_storage.get(filename)

if image_data is None:
raise web.HTTPNotFound

return web.Response(
body=image_data.data,
content_type=image_data.mime_type,
)
6 changes: 3 additions & 3 deletions homeassistant/components/ai_task/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from homeassistant.util.hass_dict import HassKey

if TYPE_CHECKING:
from homeassistant.components.media_source import local_source
from homeassistant.helpers.entity_component import EntityComponent

from . import AITaskPreferences
from .entity import AITaskEntity
from .task import ImageData

DOMAIN = "ai_task"
DATA_COMPONENT: HassKey[EntityComponent[AITaskEntity]] = HassKey(DOMAIN)
DATA_PREFERENCES: HassKey[AITaskPreferences] = HassKey(f"{DOMAIN}_preferences")
DATA_IMAGES: HassKey[dict[str, ImageData]] = HassKey(f"{DOMAIN}_images")
DATA_MEDIA_SOURCE: HassKey[local_source.LocalSource] = HassKey(f"{DOMAIN}_media_source")

IMAGE_DIR: Final = "image"
IMAGE_EXPIRY_TIME = 60 * 60 # 1 hour
MAX_IMAGES = 20

SERVICE_GENERATE_DATA = "generate_data"
SERVICE_GENERATE_IMAGE = "generate_image"
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/ai_task/manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"domain": "ai_task",
"name": "AI Task",
"after_dependencies": ["camera", "http"],
"after_dependencies": ["camera"],
"codeowners": ["@home-assistant/core"],
"dependencies": ["conversation", "media_source"],
"documentation": "https://www.home-assistant.io/integrations/ai_task",
Expand Down
94 changes: 13 additions & 81 deletions homeassistant/components/ai_task/media_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,21 @@

from __future__ import annotations

from datetime import timedelta
import logging

from homeassistant.components.http.auth import async_sign_path
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.components.media_source import (
BrowseMediaSource,
MediaSource,
MediaSourceItem,
PlayMedia,
Unresolvable,
)
from homeassistant.components.media_source import MediaSource, local_source
from homeassistant.core import HomeAssistant

from .const import DATA_IMAGES, DOMAIN, IMAGE_EXPIRY_TIME

_LOGGER = logging.getLogger(__name__)


async def async_get_media_source(hass: HomeAssistant) -> ImageMediaSource:
"""Set up image media source."""
_LOGGER.debug("Setting up image media source")
return ImageMediaSource(hass)


class ImageMediaSource(MediaSource):
"""Provide images as media sources."""

name: str = "AI Generated Images"

def __init__(self, hass: HomeAssistant) -> None:
"""Initialize ImageMediaSource."""
super().__init__(DOMAIN)
self.hass = hass

async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
image_storage = self.hass.data[DATA_IMAGES]
image = image_storage.get(item.identifier)

if image is None:
raise Unresolvable(f"Could not resolve media item: {item.identifier}")

return PlayMedia(
async_sign_path(
self.hass,
f"/api/{DOMAIN}/images/{item.identifier}",
timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
),
image.mime_type,
)

async def async_browse_media(
self,
item: MediaSourceItem,
) -> BrowseMediaSource:
"""Return media."""
if item.identifier:
raise BrowseError("Unknown item")
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR

image_storage = self.hass.data[DATA_IMAGES]

children = [
BrowseMediaSource(
domain=DOMAIN,
identifier=filename,
media_class=MediaClass.IMAGE,
media_content_type=image.mime_type,
title=image.title or filename,
can_play=True,
can_expand=False,
)
for filename, image in image_storage.items()
]
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
"""Set up local media source."""
media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}")

return BrowseMediaSource(
domain=DOMAIN,
identifier=None,
media_class=MediaClass.APP,
media_content_type="",
title="AI Generated Images",
can_play=False,
can_expand=True,
children_media_class=MediaClass.IMAGE,
children=children,
)
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
hass,
DOMAIN,
"AI Generated Images",
{IMAGE_DIR: media_dir},
f"/{DOMAIN}",
)
return source
79 changes: 23 additions & 56 deletions homeassistant/components/ai_task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial
import io
import mimetypes
from pathlib import Path
import tempfile
Expand All @@ -18,16 +18,15 @@
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
from homeassistant.helpers.event import async_call_later
from homeassistant.util import RE_SANITIZE_FILENAME, slugify

from .const import (
DATA_COMPONENT,
DATA_IMAGES,
DATA_MEDIA_SOURCE,
DATA_PREFERENCES,
DOMAIN,
IMAGE_DIR,
IMAGE_EXPIRY_TIME,
MAX_IMAGES,
AITaskEntityFeature,
)

Expand Down Expand Up @@ -157,24 +156,6 @@ async def async_generate_data(
)


def _cleanup_images(image_storage: dict[str, ImageData], num_to_remove: int) -> None:
"""Remove old images to keep the storage size under the limit."""
if num_to_remove <= 0:
return

if num_to_remove >= len(image_storage):
image_storage.clear()
return

sorted_images = sorted(
image_storage.items(),
key=lambda item: item[1].timestamp,
)

for filename, _ in sorted_images[:num_to_remove]:
image_storage.pop(filename, None)


async def async_generate_image(
hass: HomeAssistant,
*,
Expand Down Expand Up @@ -224,36 +205,34 @@ async def async_generate_image(
if service_result.get("revised_prompt") is None:
service_result["revised_prompt"] = instructions

image_storage = hass.data[DATA_IMAGES]

if len(image_storage) + 1 > MAX_IMAGES:
_cleanup_images(image_storage, len(image_storage) + 1 - MAX_IMAGES)
source = hass.data[DATA_MEDIA_SOURCE]

current_time = datetime.now()
ext = mimetypes.guess_extension(task_result.mime_type, False) or ".png"
sanitized_task_name = RE_SANITIZE_FILENAME.sub("", slugify(task_name))
filename = f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}"

image_storage[filename] = ImageData(
data=image_data,
timestamp=int(current_time.timestamp()),
mime_type=task_result.mime_type,
title=service_result["revised_prompt"],
image_file = ImageData(
filename=f"{current_time.strftime('%Y-%m-%d_%H%M%S')}_{sanitized_task_name}{ext}",
file=io.BytesIO(image_data),
content_type=task_result.mime_type,
)

def _purge_image(filename: str, now: datetime) -> None:
"""Remove image from storage."""
image_storage.pop(filename, None)
target_folder = media_source.MediaSourceItem.from_uri(
hass, f"media-source://{DOMAIN}/{IMAGE_DIR}", None
)

if IMAGE_EXPIRY_TIME > 0:
async_call_later(hass, IMAGE_EXPIRY_TIME, partial(_purge_image, filename))
service_result["media_source_id"] = await source.async_upload_media(
target_folder, image_file
)

item = media_source.MediaSourceItem.from_uri(
hass, service_result["media_source_id"], None
)
service_result["url"] = async_sign_path(
hass,
f"/api/{DOMAIN}/images/{filename}",
timedelta(seconds=IMAGE_EXPIRY_TIME or 1800),
(await source.async_resolve_media(item)).url,
timedelta(seconds=IMAGE_EXPIRY_TIME),
)
service_result["media_source_id"] = f"media-source://{DOMAIN}/images/{filename}"

return service_result

Expand Down Expand Up @@ -358,20 +337,8 @@ def as_dict(self) -> dict[str, Any]:

@dataclass(slots=True)
class ImageData:
"""Image data for stored generated images."""

data: bytes
"""Raw image data."""

timestamp: int
"""Timestamp when the image was generated, as a Unix timestamp."""
"""Implementation of media_source.local_source.UploadedFile protocol."""

mime_type: str
"""MIME type of the image."""

title: str
"""Title of the image, usually the prompt used to generate it."""

def __str__(self) -> str:
"""Return image data as a string."""
return f"<ImageData {self.title}: {id(self)}>"
filename: str
file: io.IOBase
content_type: str
1 change: 1 addition & 0 deletions homeassistant/components/backup/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"tmp_backups/*.tar",
"OZW_Log.txt",
"tts/*",
"ai_task/*",
]

EXCLUDE_DATABASE_FROM_BACKUP = [
Expand Down
Loading
Loading