Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@

from __future__ import annotations

import asyncio
from functools import partial
import mimetypes
from pathlib import Path
from types import MappingProxyType

from google.genai import Client
from google.genai.errors import APIError, ClientError
from google.genai.types import File, FileState
from requests.exceptions import Timeout
import voluptuous as vol

Expand Down Expand Up @@ -42,13 +39,13 @@
DEFAULT_TITLE,
DEFAULT_TTS_NAME,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_AI_TASK_OPTIONS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_TTS_OPTIONS,
TIMEOUT_MILLIS,
)
from .entity import async_prepare_files_for_prompt

SERVICE_GENERATE_CONTENT = "generate_content"
CONF_IMAGE_FILENAME = "image_filename"
Expand Down Expand Up @@ -92,58 +89,22 @@ async def generate_content(call: ServiceCall) -> ServiceResponse:

client = config_entry.runtime_data

def append_files_to_prompt():
image_filenames = call.data[CONF_IMAGE_FILENAME]
filenames = call.data[CONF_FILENAMES]
for filename in set(image_filenames + filenames):
files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES]

if files:
for filename in files:
if not hass.config.is_allowed_path(filename):
raise HomeAssistantError(
f"Cannot read `{filename}`, no access to path; "
"`allowlist_external_dirs` may need to be adjusted in "
"`configuration.yaml`"
)
if not Path(filename).exists():
raise HomeAssistantError(f"`{filename}` does not exist")
mimetype = mimetypes.guess_type(filename)[0]
with open(filename, "rb") as file:
uploaded_file = client.files.upload(
file=file, config={"mime_type": mimetype}
)
prompt_parts.append(uploaded_file)

async def wait_for_file_processing(uploaded_file: File) -> None:
"""Wait for file processing to complete."""
while True:
uploaded_file = await client.aio.files.get(
name=uploaded_file.name,
config={"http_options": {"timeout": TIMEOUT_MILLIS}},
)
if uploaded_file.state not in (
FileState.STATE_UNSPECIFIED,
FileState.PROCESSING,
):
break
LOGGER.debug(
"Waiting for file `%s` to be processed, current state: %s",
uploaded_file.name,
uploaded_file.state,
)
await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

if uploaded_file.state == FileState.FAILED:
raise HomeAssistantError(
f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
prompt_parts.extend(
await async_prepare_files_for_prompt(
hass, client, [Path(filename) for filename in files]
)

await hass.async_add_executor_job(append_files_to_prompt)

tasks = [
asyncio.create_task(wait_for_file_processing(part))
for part in prompt_parts
if isinstance(part, File) and part.state != FileState.ACTIVE
]
async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
await asyncio.gather(*tasks)
)

try:
response = await client.aio.models.generate_content(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@

from __future__ import annotations

import asyncio
import codecs
from collections.abc import AsyncGenerator, Callable
from dataclasses import replace
import mimetypes
from pathlib import Path
from typing import Any, cast

from google.genai import Client
from google.genai.errors import APIError, ClientError
from google.genai.types import (
AutomaticFunctionCallingConfig,
Content,
File,
FileState,
FunctionDeclaration,
GenerateContentConfig,
GenerateContentResponse,
Expand All @@ -26,6 +32,7 @@

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr, llm
from homeassistant.helpers.entity import Entity
Expand All @@ -42,13 +49,15 @@
CONF_TOP_P,
CONF_USE_GOOGLE_SEARCH_TOOL,
DOMAIN,
FILE_POLLING_INTERVAL_SECONDS,
LOGGER,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_HARM_BLOCK_THRESHOLD,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_TOP_K,
RECOMMENDED_TOP_P,
TIMEOUT_MILLIS,
)

# Max number of back and forth with the LLM to generate a response
Expand Down Expand Up @@ -494,3 +503,68 @@ def create_generate_content_config(self) -> GenerateContentConfig:
),
],
)


async def async_prepare_files_for_prompt(
    hass: HomeAssistant, client: Client, files: list[Path]
) -> list[File]:
    """Upload files to the Google Generative AI API for use in a prompt.

    Uploads run in the executor (the upload API is blocking), then the files
    are polled asynchronously until the API finishes processing them.

    Caller needs to ensure that the files are allowed
    (e.g. checked against ``hass.config.is_allowed_path``).

    Args:
        hass: Home Assistant instance, used to off-load blocking uploads.
        client: Google GenAI client used for upload and status polling.
        files: Local file paths to upload.

    Returns:
        The uploaded ``File`` objects, in the same order as ``files``.

    Raises:
        HomeAssistantError: If a file does not exist or its processing fails.
        TimeoutError: If processing does not finish within TIMEOUT_MILLIS.
    """

    def upload_files() -> list[File]:
        """Blocking upload of every file; executed in the executor."""
        prompt_parts: list[File] = []
        for filename in files:
            if not filename.exists():
                # Fixed: the f-string previously had no placeholder, so the
                # error did not say WHICH file was missing.
                raise HomeAssistantError(f"`{filename}` does not exist")
            # Best-effort MIME detection from the extension; may be None, in
            # which case the API is left to infer the type itself.
            mimetype = mimetypes.guess_type(filename)[0]
            prompt_parts.append(
                client.files.upload(
                    file=filename,
                    config={
                        "mime_type": mimetype,
                        "display_name": filename.name,
                    },
                )
            )
        return prompt_parts

    async def wait_for_file_processing(uploaded_file: File) -> None:
        """Poll until the uploaded file leaves the PROCESSING state."""
        first = True
        while uploaded_file.state in (
            FileState.STATE_UNSPECIFIED,
            FileState.PROCESSING,
        ):
            if first:
                # First check happens immediately — no log spam, no sleep.
                first = False
            else:
                LOGGER.debug(
                    "Waiting for file `%s` to be processed, current state: %s",
                    uploaded_file.name,
                    uploaded_file.state,
                )
                await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS)

            uploaded_file = await client.aio.files.get(
                name=uploaded_file.name,
                config={"http_options": {"timeout": TIMEOUT_MILLIS}},
            )

        if uploaded_file.state == FileState.FAILED:
            raise HomeAssistantError(
                f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}"
            )

    prompt_parts = await hass.async_add_executor_job(upload_files)

    # Only files not yet ACTIVE need to be awaited; poll them concurrently.
    tasks = [
        asyncio.create_task(wait_for_file_processing(part))
        for part in prompt_parts
        if part.state != FileState.ACTIVE
    ]
    async with asyncio.timeout(TIMEOUT_MILLIS / 1000):
        await asyncio.gather(*tasks)

    return prompt_parts
2 changes: 1 addition & 1 deletion homeassistant/components/homee/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
"wrong_hub": "IP-Address belongs to a different homee than the configured one."
"wrong_hub": "IP address belongs to a different homee than the configured one."
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/motion_blinds/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
"documentation": "https://www.home-assistant.io/integrations/motion_blinds",
"iot_class": "local_push",
"loggers": ["motionblinds"],
"requirements": ["motionblinds==0.6.28"]
"requirements": ["motionblinds==0.6.29"]
}
29 changes: 28 additions & 1 deletion homeassistant/components/ollama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CONF_NUM_CTX,
CONF_PROMPT,
CONF_THINK,
DEFAULT_AI_TASK_NAME,
DEFAULT_NAME,
DEFAULT_TIMEOUT,
DOMAIN,
Expand All @@ -47,7 +48,7 @@
]

CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
PLATFORMS = (Platform.CONVERSATION,)
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)

type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]

Expand Down Expand Up @@ -118,6 +119,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
parent_entry = api_keys_entries[entry.data[CONF_URL]]

hass.config_entries.async_add_subentry(parent_entry, subentry)

conversation_entity = entity_registry.async_get_entity_id(
"conversation",
DOMAIN,
Expand Down Expand Up @@ -208,6 +210,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
minor_version=1,
)

if entry.version == 3 and entry.minor_version == 1:
# Add AI Task subentry with default options. We can only create a new
# subentry if we can find an existing model in the entry. The model
# was removed in the previous migration step, so we need to
# check the subentries for an existing model.
existing_model = next(
iter(
model
for subentry in entry.subentries.values()
if (model := subentry.data.get(CONF_MODEL)) is not None
),
None,
)
if existing_model:
hass.config_entries.async_add_subentry(
entry,
ConfigSubentry(
data=MappingProxyType({CONF_MODEL: existing_model}),
subentry_type="ai_task_data",
title=DEFAULT_AI_TASK_NAME,
unique_id=None,
),
)
hass.config_entries.async_update_entry(entry, minor_version=2)

_LOGGER.debug(
"Migration to version %s:%s successful", entry.version, entry.minor_version
)
Expand Down
77 changes: 77 additions & 0 deletions homeassistant/components/ollama/ai_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""AI Task integration for Ollama."""

from __future__ import annotations

from json import JSONDecodeError
import logging

from homeassistant.components import ai_task, conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from homeassistant.util.json import json_loads

from .entity import OllamaBaseLLMEntity

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Create one Ollama AI Task entity per ``ai_task_data`` subentry."""
    for sub in config_entry.subentries.values():
        if sub.subentry_type == "ai_task_data":
            async_add_entities(
                [OllamaTaskEntity(config_entry, sub)],
                config_subentry_id=sub.subentry_id,
            )


class OllamaTaskEntity(
    ai_task.AITaskEntity,
    OllamaBaseLLMEntity,
):
    """AI Task entity backed by an Ollama model."""

    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA

    async def _async_generate_data(
        self,
        task: ai_task.GenDataTask,
        chat_log: conversation.ChatLog,
    ) -> ai_task.GenDataTaskResult:
        """Handle a generate data task."""
        await self._async_handle_chat_log(chat_log, task.structure)

        last = chat_log.content[-1]
        if not isinstance(last, conversation.AssistantContent):
            raise HomeAssistantError(
                "Last content in chat log is not an AssistantContent"
            )

        raw = last.content or ""

        # Unstructured task: hand the model text back verbatim.
        if not task.structure:
            return ai_task.GenDataTaskResult(
                conversation_id=chat_log.conversation_id,
                data=raw,
            )

        # Structured task: the model response must be valid JSON.
        try:
            parsed = json_loads(raw)
        except JSONDecodeError as err:
            _LOGGER.error(
                "Failed to parse JSON response: %s. Response: %s",
                err,
                raw,
            )
            raise HomeAssistantError("Error with Ollama structured response") from err

        return ai_task.GenDataTaskResult(
            conversation_id=chat_log.conversation_id,
            data=parsed,
        )
Loading
Loading