diff --git a/homeassistant/components/google_generative_ai_conversation/__init__.py b/homeassistant/components/google_generative_ai_conversation/__init__.py index 99e475a376bba2..a3b87c05e5a1aa 100644 --- a/homeassistant/components/google_generative_ai_conversation/__init__.py +++ b/homeassistant/components/google_generative_ai_conversation/__init__.py @@ -2,15 +2,12 @@ from __future__ import annotations -import asyncio from functools import partial -import mimetypes from pathlib import Path from types import MappingProxyType from google.genai import Client from google.genai.errors import APIError, ClientError -from google.genai.types import File, FileState from requests.exceptions import Timeout import voluptuous as vol @@ -42,13 +39,13 @@ DEFAULT_TITLE, DEFAULT_TTS_NAME, DOMAIN, - FILE_POLLING_INTERVAL_SECONDS, LOGGER, RECOMMENDED_AI_TASK_OPTIONS, RECOMMENDED_CHAT_MODEL, RECOMMENDED_TTS_OPTIONS, TIMEOUT_MILLIS, ) +from .entity import async_prepare_files_for_prompt SERVICE_GENERATE_CONTENT = "generate_content" CONF_IMAGE_FILENAME = "image_filename" @@ -92,58 +89,22 @@ async def generate_content(call: ServiceCall) -> ServiceResponse: client = config_entry.runtime_data - def append_files_to_prompt(): - image_filenames = call.data[CONF_IMAGE_FILENAME] - filenames = call.data[CONF_FILENAMES] - for filename in set(image_filenames + filenames): + files = call.data[CONF_IMAGE_FILENAME] + call.data[CONF_FILENAMES] + + if files: + for filename in files: if not hass.config.is_allowed_path(filename): raise HomeAssistantError( f"Cannot read `{filename}`, no access to path; " "`allowlist_external_dirs` may need to be adjusted in " "`configuration.yaml`" ) - if not Path(filename).exists(): - raise HomeAssistantError(f"`{filename}` does not exist") - mimetype = mimetypes.guess_type(filename)[0] - with open(filename, "rb") as file: - uploaded_file = client.files.upload( - file=file, config={"mime_type": mimetype} - ) - prompt_parts.append(uploaded_file) - - async def wait_for_file_processing(uploaded_file: File) -> None: - """Wait for file processing to complete.""" - while True: - uploaded_file = await client.aio.files.get( - name=uploaded_file.name, - config={"http_options": {"timeout": TIMEOUT_MILLIS}}, - ) - if uploaded_file.state not in ( - FileState.STATE_UNSPECIFIED, - FileState.PROCESSING, - ): - break - LOGGER.debug( - "Waiting for file `%s` to be processed, current state: %s", - uploaded_file.name, - uploaded_file.state, - ) - await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS) - if uploaded_file.state == FileState.FAILED: - raise HomeAssistantError( - f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}" + prompt_parts.extend( + await async_prepare_files_for_prompt( + hass, client, [Path(filename) for filename in files] ) - - await hass.async_add_executor_job(append_files_to_prompt) - - tasks = [ - asyncio.create_task(wait_for_file_processing(part)) - for part in prompt_parts - if isinstance(part, File) and part.state != FileState.ACTIVE - ] - async with asyncio.timeout(TIMEOUT_MILLIS / 1000): - await asyncio.gather(*tasks) + ) try: response = await client.aio.models.generate_content( diff --git a/homeassistant/components/google_generative_ai_conversation/entity.py b/homeassistant/components/google_generative_ai_conversation/entity.py index d471da36a8cfab..8f8edea18cb805 100644 --- a/homeassistant/components/google_generative_ai_conversation/entity.py +++ b/homeassistant/components/google_generative_ai_conversation/entity.py @@ -2,15 +2,21 @@ from __future__ import annotations +import asyncio import codecs from collections.abc import AsyncGenerator, Callable from dataclasses import replace +import mimetypes +from pathlib import Path from typing import Any, cast +from google.genai import Client from google.genai.errors import APIError, ClientError from google.genai.types import ( AutomaticFunctionCallingConfig, Content, + File, + FileState, FunctionDeclaration, GenerateContentConfig, GenerateContentResponse, @@ -26,6 +32,7 @@ from homeassistant.components import conversation from homeassistant.config_entries import ConfigEntry, ConfigSubentry +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import device_registry as dr, llm from homeassistant.helpers.entity import Entity @@ -42,6 +49,7 @@ CONF_TOP_P, CONF_USE_GOOGLE_SEARCH_TOOL, DOMAIN, + FILE_POLLING_INTERVAL_SECONDS, LOGGER, RECOMMENDED_CHAT_MODEL, RECOMMENDED_HARM_BLOCK_THRESHOLD, @@ -49,6 +57,7 @@ RECOMMENDED_TEMPERATURE, RECOMMENDED_TOP_K, RECOMMENDED_TOP_P, + TIMEOUT_MILLIS, ) # Max number of back and forth with the LLM to generate a response @@ -494,3 +503,68 @@ def create_generate_content_config(self) -> GenerateContentConfig: ), ], ) + + +async def async_prepare_files_for_prompt( + hass: HomeAssistant, client: Client, files: list[Path] +) -> list[File]: + """Append files to a prompt. + + Caller needs to ensure that the files are allowed. + """ + + def upload_files() -> list[File]: + prompt_parts: list[File] = [] + for filename in files: + if not filename.exists(): + raise HomeAssistantError(f"`{filename}` does not exist") + mimetype = mimetypes.guess_type(filename)[0] + prompt_parts.append( + client.files.upload( + file=filename, + config={ + "mime_type": mimetype, + "display_name": filename.name, + }, + ) + ) + return prompt_parts + + async def wait_for_file_processing(uploaded_file: File) -> None: + """Wait for file processing to complete.""" + first = True + while uploaded_file.state in ( + FileState.STATE_UNSPECIFIED, + FileState.PROCESSING, + ): + if first: + first = False + else: + LOGGER.debug( + "Waiting for file `%s` to be processed, current state: %s", + uploaded_file.name, + uploaded_file.state, + ) + await asyncio.sleep(FILE_POLLING_INTERVAL_SECONDS) + + uploaded_file = await client.aio.files.get( + name=uploaded_file.name, + config={"http_options": {"timeout": TIMEOUT_MILLIS}}, + ) + + if uploaded_file.state == FileState.FAILED: + raise HomeAssistantError( + f"File `{uploaded_file.name}` processing failed, reason: {uploaded_file.error.message}" + ) + + prompt_parts = await hass.async_add_executor_job(upload_files) + + tasks = [ + asyncio.create_task(wait_for_file_processing(part)) + for part in prompt_parts + if part.state != FileState.ACTIVE + ] + async with asyncio.timeout(TIMEOUT_MILLIS / 1000): + await asyncio.gather(*tasks) + + return prompt_parts diff --git a/homeassistant/components/homee/strings.json b/homeassistant/components/homee/strings.json index 9523d62c6711a0..267d5553a8c5f1 100644 --- a/homeassistant/components/homee/strings.json +++ b/homeassistant/components/homee/strings.json @@ -5,7 +5,7 @@ "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", - "wrong_hub": "IP-Address belongs to a different homee than the configured one." + "wrong_hub": "IP address belongs to a different homee than the configured one." }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", diff --git a/homeassistant/components/motion_blinds/manifest.json b/homeassistant/components/motion_blinds/manifest.json index a82da20396ffd4..eca520d8946f8b 100644 --- a/homeassistant/components/motion_blinds/manifest.json +++ b/homeassistant/components/motion_blinds/manifest.json @@ -21,5 +21,5 @@ "documentation": "https://www.home-assistant.io/integrations/motion_blinds", "iot_class": "local_push", "loggers": ["motionblinds"], - "requirements": ["motionblinds==0.6.28"] + "requirements": ["motionblinds==0.6.29"] } diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index 6fe4720d13f1e3..e16550c1e94a25 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -28,6 +28,7 @@ CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, + DEFAULT_AI_TASK_NAME, DEFAULT_NAME, DEFAULT_TIMEOUT, DOMAIN, @@ -47,7 +48,7 @@ ] CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) -PLATFORMS = (Platform.CONVERSATION,) +PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION) type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient] @@ -118,6 +119,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None: parent_entry = api_keys_entries[entry.data[CONF_URL]] hass.config_entries.async_add_subentry(parent_entry, subentry) + conversation_entity = entity_registry.async_get_entity_id( "conversation", DOMAIN, @@ -208,6 +210,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> minor_version=1, ) + if entry.version == 3 and entry.minor_version == 1: + # Add AI Task subentry with default options. We can only create a new + # subentry if we can find an existing model in the entry. The model + # was removed in the previous migration step, so we need to + # check the subentries for an existing model. + existing_model = next( + iter( + model + for subentry in entry.subentries.values() + if (model := subentry.data.get(CONF_MODEL)) is not None + ), + None, + ) + if existing_model: + hass.config_entries.async_add_subentry( + entry, + ConfigSubentry( + data=MappingProxyType({CONF_MODEL: existing_model}), + subentry_type="ai_task_data", + title=DEFAULT_AI_TASK_NAME, + unique_id=None, + ), + ) + hass.config_entries.async_update_entry(entry, minor_version=2) + _LOGGER.debug( "Migration to version %s:%s successful", entry.version, entry.minor_version ) diff --git a/homeassistant/components/ollama/ai_task.py b/homeassistant/components/ollama/ai_task.py new file mode 100644 index 00000000000000..d796b28aac8f33 --- /dev/null +++ b/homeassistant/components/ollama/ai_task.py @@ -0,0 +1,77 @@ +"""AI Task integration for Ollama.""" + +from __future__ import annotations + +from json import JSONDecodeError +import logging + +from homeassistant.components import ai_task, conversation +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback +from homeassistant.util.json import json_loads + +from .entity import OllamaBaseLLMEntity + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddConfigEntryEntitiesCallback, +) -> None: + """Set up AI Task entities.""" + for subentry in config_entry.subentries.values(): + if subentry.subentry_type != "ai_task_data": + continue + + async_add_entities( + [OllamaTaskEntity(config_entry, subentry)], + config_subentry_id=subentry.subentry_id, + ) + + +class OllamaTaskEntity( + ai_task.AITaskEntity, + OllamaBaseLLMEntity, +): + """Ollama AI Task entity.""" + + _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA + + async def _async_generate_data( + self, + task: ai_task.GenDataTask, + chat_log: conversation.ChatLog, + ) -> ai_task.GenDataTaskResult: + """Handle a generate data task.""" + await self._async_handle_chat_log(chat_log, task.structure) + + if not isinstance(chat_log.content[-1], conversation.AssistantContent): + raise HomeAssistantError( + "Last content in chat log is not an AssistantContent" + ) + + text = chat_log.content[-1].content or "" + + if not task.structure: + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=text, + ) + try: + data = json_loads(text) + except JSONDecodeError as err: + _LOGGER.error( + "Failed to parse JSON response: %s. Response: %s", + err, + text, + ) + raise HomeAssistantError("Error with Ollama structured response") from err + + return ai_task.GenDataTaskResult( + conversation_id=chat_log.conversation_id, + data=data, + ) diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index 49eb12a5c23665..cca917f6c29169 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -46,6 +46,8 @@ CONF_NUM_CTX, CONF_PROMPT, CONF_THINK, + DEFAULT_AI_TASK_NAME, + DEFAULT_CONVERSATION_NAME, DEFAULT_KEEP_ALIVE, DEFAULT_MAX_HISTORY, DEFAULT_MODEL, @@ -74,7 +76,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for Ollama.""" VERSION = 3 - MINOR_VERSION = 1 + MINOR_VERSION = 2 def __init__(self) -> None: """Initialize config flow.""" @@ -136,11 +138,14 @@ def async_get_supported_subentry_types( cls, config_entry: ConfigEntry ) -> dict[str, type[ConfigSubentryFlow]]: """Return subentries supported by this integration.""" - return {"conversation": ConversationSubentryFlowHandler} + return { + "conversation": OllamaSubentryFlowHandler, + "ai_task_data": OllamaSubentryFlowHandler, + } -class ConversationSubentryFlowHandler(ConfigSubentryFlow): - """Flow for managing conversation subentries.""" +class OllamaSubentryFlowHandler(ConfigSubentryFlow): + """Flow for managing Ollama subentries.""" def __init__(self) -> None: """Initialize the subentry flow.""" @@ -201,7 +206,11 @@ async def async_step_set_options( step_id="set_options", data_schema=vol.Schema( ollama_config_option_schema( - self.hass, self._is_new, options, models_to_list + self.hass, + self._is_new, + self._subentry_type, + options, + models_to_list, ) ), ) @@ -300,13 +309,19 @@ async def async_step_finish( def ollama_config_option_schema( hass: HomeAssistant, is_new: bool, + subentry_type: str, options: Mapping[str, Any], models_to_list: list[SelectOptionDict], ) -> dict: """Ollama options schema.""" if is_new: + if subentry_type == "ai_task_data": + default_name = DEFAULT_AI_TASK_NAME + else: + default_name = DEFAULT_CONVERSATION_NAME + schema: dict = { - vol.Required(CONF_NAME, default="Ollama Conversation"): str, + vol.Required(CONF_NAME, default=default_name): str, } else: schema = {} @@ -319,29 +334,38 @@ def ollama_config_option_schema( ): SelectSelector( SelectSelectorConfig(options=models_to_list, custom_value=True) ), - vol.Optional( - CONF_PROMPT, - description={ - "suggested_value": options.get( - CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT - ) - }, - ): TemplateSelector(), - vol.Optional( - CONF_LLM_HASS_API, - description={"suggested_value": options.get(CONF_LLM_HASS_API)}, - ): SelectSelector( - SelectSelectorConfig( - options=[ - SelectOptionDict( - label=api.name, - value=api.id, + } + ) + if subentry_type == "conversation": + schema.update( + { + vol.Optional( + CONF_PROMPT, + description={ + "suggested_value": options.get( + CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT ) - for api in llm.async_get_apis(hass) - ], - multiple=True, - ) - ), + }, + ): TemplateSelector(), + vol.Optional( + CONF_LLM_HASS_API, + description={"suggested_value": options.get(CONF_LLM_HASS_API)}, + ): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict( + label=api.name, + value=api.id, + ) + for api in llm.async_get_apis(hass) + ], + multiple=True, + ) + ), + } + ) + schema.update( + { vol.Optional( CONF_NUM_CTX, description={ diff --git a/homeassistant/components/ollama/const.py b/homeassistant/components/ollama/const.py index 3175525c70d230..7e80570bd5ed9a 100644 --- a/homeassistant/components/ollama/const.py +++ b/homeassistant/components/ollama/const.py @@ -159,3 +159,10 @@ "zephyr", ] DEFAULT_MODEL = "llama3.2:latest" + +DEFAULT_CONVERSATION_NAME = "Ollama Conversation" +DEFAULT_AI_TASK_NAME = "Ollama AI Task" + +RECOMMENDED_CONVERSATION_OPTIONS = { + CONF_MAX_HISTORY: DEFAULT_MAX_HISTORY, +} diff --git a/homeassistant/components/ollama/entity.py b/homeassistant/components/ollama/entity.py index 7b63b1dff002c1..4122d0c67d8680 100644 --- a/homeassistant/components/ollama/entity.py +++ b/homeassistant/components/ollama/entity.py @@ -8,6 +8,7 @@ from typing import Any import ollama +import voluptuous as vol from voluptuous_openapi import convert from homeassistant.components import conversation @@ -180,6 +181,7 @@ def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None: async def _async_handle_chat_log( self, chat_log: conversation.ChatLog, + structure: vol.Schema | None = None, ) -> None: """Generate an answer for the chat log.""" settings = {**self.entry.data, **self.subentry.data} @@ -200,6 +202,17 @@ async def _async_handle_chat_log( max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)) self._trim_history(message_history, max_messages) + output_format: dict[str, Any] | None = None + if structure: + output_format = convert( + structure, + custom_serializer=( + chat_log.llm_api.custom_serializer + if chat_log.llm_api + else llm.selector_serializer + ), + ) + # Get response # To prevent infinite loops, we limit the number of iterations for _iteration in range(MAX_TOOL_ITERATIONS): @@ -214,6 +227,7 @@ async def _async_handle_chat_log( keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s", options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)}, think=settings.get(CONF_THINK), + format=output_format, ) except (ollama.RequestError, ollama.ResponseError) as err: _LOGGER.error("Unexpected error talking to Ollama server: %s", err) diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index bb08e4a4684338..4261b2286bf8ec 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -55,6 +55,44 @@ "progress": { "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details." } + }, + "ai_task_data": { + "initiate_flow": { + "user": "Add Generate data with AI service", + "reconfigure": "Reconfigure Generate data with AI service" + }, + "entry_type": "Generate data with AI service", + "step": { + "set_options": { + "data": { + "model": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::model%]", + "name": "[%key:common::config_flow::data::name%]", + "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::prompt%]", + "max_history": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::max_history%]", + "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::num_ctx%]", + "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::keep_alive%]", + "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::think%]" + }, + "data_description": { + "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::prompt%]", + "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::keep_alive%]", + "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::num_ctx%]", + "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::think%]" + } + }, + "download": { + "title": "[%key:component::ollama::config_subentries::conversation::step::download::title%]" + } + }, + "abort": { + "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]", + "entry_not_loaded": "[%key:component::ollama::config_subentries::conversation::abort::entry_not_loaded%]", + "download_failed": "[%key:component::ollama::config_subentries::conversation::abort::download_failed%]", + "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]" + }, + "progress": { + "download": "[%key:component::ollama::config_subentries::conversation::progress::download%]" + } } } } diff --git a/homeassistant/components/sfr_box/manifest.json b/homeassistant/components/sfr_box/manifest.json index a2d65e9819d243..1987453a80d2a1 100644 --- a/homeassistant/components/sfr_box/manifest.json +++ b/homeassistant/components/sfr_box/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://www.home-assistant.io/integrations/sfr_box", "integration_type": "device", "iot_class": "local_polling", - "requirements": ["sfrbox-api==0.0.11"] + "requirements": ["sfrbox-api==0.0.12"] } diff --git a/homeassistant/components/sfr_box/strings.json b/homeassistant/components/sfr_box/strings.json index 35e9b1869ffde2..5139ec52badd9c 100644 --- a/homeassistant/components/sfr_box/strings.json +++ b/homeassistant/components/sfr_box/strings.json @@ -27,7 +27,7 @@ "host": "[%key:common::config_flow::data::host%]" }, "data_description": { - "host": "The hostname or IP address of your SFR device." + "host": "The hostname, IP address, or full URL of your SFR device. e.g.: '192.168.1.1' or 'https://sfrbox.example.com'" }, "description": "Setting the credentials is optional, but enables additional functionality." } diff --git a/requirements_all.txt b/requirements_all.txt index 01b297fd25070e..3e3c701508c323 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1452,7 +1452,7 @@ monzopy==1.4.2 mopeka-iot-ble==0.8.0 # homeassistant.components.motion_blinds -motionblinds==0.6.28 +motionblinds==0.6.29 # homeassistant.components.motionblinds_ble motionblindsble==0.1.3 @@ -2753,7 +2753,7 @@ sensoterra==2.0.1 sentry-sdk==1.45.1 # homeassistant.components.sfr_box -sfrbox-api==0.0.11 +sfrbox-api==0.0.12 # homeassistant.components.sharkiq sharkiq==1.1.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 29ab2676ee08bb..03f77b69b93ddc 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1244,7 +1244,7 @@ monzopy==1.4.2 mopeka-iot-ble==0.8.0 # homeassistant.components.motion_blinds -motionblinds==0.6.28 +motionblinds==0.6.29 # homeassistant.components.motionblinds_ble motionblindsble==0.1.3 @@ -2275,7 +2275,7 @@ sensoterra==2.0.1 sentry-sdk==1.45.1 # homeassistant.components.sfr_box -sfrbox-api==0.0.11 +sfrbox-api==0.0.12 # homeassistant.components.sharkiq sharkiq==1.1.0 diff --git a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr index a2603328959856..a0d34f49d375be 100644 --- a/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr +++ b/tests/components/google_generative_ai_conversation/snapshots/test_init.ambr @@ -122,8 +122,8 @@ dict({ 'contents': list([ 'Describe this image from my doorbell camera', - b'some file', - b'some file', + File(name='doorbell_snapshot.jpg', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=, source=None, video_metadata=None, error=None), + File(name='context.txt', display_name=None, mime_type=None, size_bytes=None, create_time=None, expiration_time=None, update_time=None, sha256_hash=None, uri=None, download_uri=None, state=, source=None, video_metadata=None, error=None), ]), 'model': 'models/gemini-2.5-flash', }), diff --git a/tests/components/google_generative_ai_conversation/test_init.py b/tests/components/google_generative_ai_conversation/test_init.py index c0a610f6a0ab88..351895c89fb617 100644 --- a/tests/components/google_generative_ai_conversation/test_init.py +++ b/tests/components/google_generative_ai_conversation/test_init.py @@ -80,7 +80,10 @@ async def test_generate_content_service_with_image( ) as mock_generate, patch( "google.genai.files.Files.upload", - return_value=b"some file", + side_effect=[ + File(name="doorbell_snapshot.jpg", state=FileState.ACTIVE), + File(name="context.txt", state=FileState.ACTIVE), + ], ), patch("pathlib.Path.exists", return_value=True), patch.object(hass.config, "is_allowed_path", return_value=True), @@ -92,7 +95,7 @@ async def test_generate_content_service_with_image( "generate_content", { "prompt": "Describe this image from my doorbell camera", - "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"], + "filenames": ["doorbell_snapshot.jpg", "context.txt"], }, blocking=True, return_response=True, @@ -146,7 +149,7 @@ async def test_generate_content_file_processing_succeeds( "generate_content", { "prompt": "Describe this image from my doorbell camera", - "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"], + "filenames": ["doorbell_snapshot.jpg", "context.txt"], }, blocking=True, return_response=True, @@ -208,7 +211,7 @@ async def test_generate_content_file_processing_fails( "generate_content", { "prompt": "Describe this image from my doorbell camera", - "filenames": ["doorbell_snapshot.jpg", "context.txt", "context.txt"], + "filenames": ["doorbell_snapshot.jpg", "context.txt"], }, blocking=True, return_response=True, diff --git a/tests/components/homee/fixtures/cover_with_position_slats.json b/tests/components/homee/fixtures/cover_with_position_slats.json index 8fd0d6f44fefd8..a61be87ab9f9dc 100644 --- a/tests/components/homee/fixtures/cover_with_position_slats.json +++ b/tests/components/homee/fixtures/cover_with_position_slats.json @@ -96,6 +96,55 @@ "options": { "automations": ["step"] } + }, + { + "id": 4, + "node_id": 3, + "instance": 0, + "minimum": -50, + "maximum": 125, + "current_value": 20.3, + "target_value": 20.3, + "last_value": 20.3, + "unit": "°C", + "step_value": 1.0, + "editable": 0, + "type": 5, + "state": 1, + "last_changed": 1709982925, + "changed_by": 1, + "changed_by_id": 0, + "based_on": 1, + "data": "", + "name": "", + "options": { + "history": { + "day": 1, + "week": 26, + "month": 6 + } + } + }, + { + "id": 5, + "node_id": 3, + "instance": 0, + "minimum": 0, + "maximum": 0, + "current_value": 0.0, + "target_value": 0.0, + "last_value": 0.0, + "unit": "text", + "step_value": 1.0, + "editable": 0, + "type": 44, + "state": 1, + "last_changed": 0, + "changed_by": 1, + "changed_by_id": 0, + "based_on": 1, + "data": "4.54", + "name": "" } ] } diff --git a/tests/components/homee/fixtures/cover_without_position.json b/tests/components/homee/fixtures/cover_without_position.json index e2bc6c7a38dd3f..f6e9ea19c8a7ae 100644 --- a/tests/components/homee/fixtures/cover_without_position.json +++ b/tests/components/homee/fixtures/cover_without_position.json @@ -43,6 +43,27 @@ "observes": [75], "automations": ["toggle"] } + }, + { + "id": 2, + "node_id": 3, + "instance": 0, + "minimum": 0, + "maximum": 0, + "current_value": 0.0, + "target_value": 0.0, + "last_value": 0.0, + "unit": "text", + "step_value": 1.0, + "editable": 0, + "type": 45, + "state": 1, + "last_changed": 0, + "changed_by": 1, + "changed_by_id": 0, + "based_on": 1, + "data": "1.45", + "name": "" } ] } diff --git a/tests/components/homee/snapshots/test_diagnostics.ambr b/tests/components/homee/snapshots/test_diagnostics.ambr index 76d3f426e1712b..d934c4e225e611 100644 --- a/tests/components/homee/snapshots/test_diagnostics.ambr +++ b/tests/components/homee/snapshots/test_diagnostics.ambr @@ -689,6 +689,55 @@ 'type': 113, 'unit': '°', }), + dict({ + 'based_on': 1, + 'changed_by': 1, + 'changed_by_id': 0, + 'current_value': 20.3, + 'data': '', + 'editable': 0, + 'id': 4, + 'instance': 0, + 'last_changed': 1709982925, + 'last_value': 20.3, + 'maximum': 125, + 'minimum': -50, + 'name': '', + 'node_id': 3, + 'options': dict({ + 'history': dict({ + 'day': 1, + 'month': 6, + 'week': 26, + }), + }), + 'state': 1, + 'step_value': 1.0, + 'target_value': 20.3, + 'type': 5, + 'unit': '°C', + }), + dict({ + 'based_on': 1, + 'changed_by': 1, + 'changed_by_id': 0, + 'current_value': 0.0, + 'data': '4.54', + 'editable': 0, + 'id': 5, + 'instance': 0, + 'last_changed': 0, + 'last_value': 0.0, + 'maximum': 0, + 'minimum': 0, + 'name': '', + 'node_id': 3, + 'state': 1, + 'step_value': 1.0, + 'target_value': 0.0, + 'type': 44, + 'unit': 'text', + }), ]), 'cube_type': 14, 'favorite': 0, diff --git a/tests/components/homee/snapshots/test_init.ambr b/tests/components/homee/snapshots/test_init.ambr new file mode 100644 index 00000000000000..664740dbeac6dc --- /dev/null +++ b/tests/components/homee/snapshots/test_init.ambr @@ -0,0 +1,71 @@ +# serializer version: 1 +# name: test_general_data + DeviceRegistryEntrySnapshot({ + 'area_id': None, + 'config_entries': , + 'config_entries_subentries': , + 'configuration_url': None, + 'connections': set({ + tuple( + 'mac', + '00:05:55:11:ee:cc', + ), + }), + 'disabled_by': None, + 'entry_type': None, + 'hw_version': None, + 'id': , + 'identifiers': set({ + tuple( + 'homee', + '00055511EECC', + ), + }), + 'is_new': False, + 'labels': set({ + }), + 'manufacturer': 'homee', + 'model': 'homee', + 'model_id': None, + 'name': 'TestHomee', + 'name_by_user': None, + 'primary_config_entry': , + 'serial_number': None, + 'suggested_area': None, + 'sw_version': '1.2.3', + 'via_device_id': None, + }) +# --- +# name: test_general_data.1 + DeviceRegistryEntrySnapshot({ + 'area_id': None, + 'config_entries': , + 'config_entries_subentries': , + 'configuration_url': None, + 'connections': set({ + }), + 'disabled_by': None, + 'entry_type': None, + 'hw_version': None, + 'id': , + 'identifiers': set({ + tuple( + 'homee', + '00055511EECC-3', + ), + }), + 'is_new': False, + 'labels': set({ + }), + 'manufacturer': None, + 'model': 'shutter_position_switch', + 'model_id': None, + 'name': 'Test Cover', + 'name_by_user': None, + 'primary_config_entry': , + 'serial_number': None, + 'suggested_area': None, + 'sw_version': '4.54', + 'via_device_id': , + }) +# --- diff --git a/tests/components/homee/test_cover.py b/tests/components/homee/test_cover.py index a3e26abc52a139..4f215c683a2caa 100644 --- a/tests/components/homee/test_cover.py +++ b/tests/components/homee/test_cover.py @@ -13,6 +13,10 @@ CoverEntityFeature, CoverState, ) +from homeassistant.components.homeassistant import ( + DOMAIN as HA_DOMAIN, + SERVICE_UPDATE_ENTITY, +) from homeassistant.components.homee.const import DOMAIN from homeassistant.const import ( ATTR_ENTITY_ID, @@ -23,9 +27,11 @@ SERVICE_SET_COVER_POSITION, SERVICE_SET_COVER_TILT_POSITION, SERVICE_STOP_COVER, + STATE_UNAVAILABLE, ) from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError +from homeassistant.setup import async_setup_component from . import build_mock_node, setup_integration @@ -39,6 +45,7 @@ async def test_open_close_stop_cover( ) -> None: """Test opening the cover.""" mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] await setup_integration(hass, mock_config_entry) @@ -73,6 +80,7 @@ async def test_open_close_reverse_cover( ) -> None: """Test opening the cover.""" mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] mock_homee.nodes[0].attributes[0].is_reversed = True await setup_integration(hass, mock_config_entry) @@ -102,6 +110,7 @@ async def test_set_cover_position( ) -> None: """Test setting the cover position.""" mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] await setup_integration(hass, mock_config_entry) @@ -246,6 +255,7 @@ async def test_cover_positions( # Cover open, tilt open. # mock_homee.nodes = [cover] mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] cover = mock_homee.nodes[0] await setup_integration(hass, mock_config_entry) @@ -348,3 +358,50 @@ async def test_send_error( assert exc_info.value.translation_domain == DOMAIN assert exc_info.value.translation_key == "connection_closed" + + +async def test_node_entity_connection_listener( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test if loss of connection is sensed correctly.""" + mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] + await setup_integration(hass, mock_config_entry) + + states = hass.states.get("cover.test_cover") + assert states.state != STATE_UNAVAILABLE + + await mock_homee.add_connection_listener.call_args_list[1][0][0](False) + await hass.async_block_till_done() + + states = hass.states.get("cover.test_cover") + assert states.state == STATE_UNAVAILABLE + + await mock_homee.add_connection_listener.call_args_list[1][0][0](True) + await hass.async_block_till_done() + + states = hass.states.get("cover.test_cover") + assert states.state != STATE_UNAVAILABLE + + +async def test_node_entity_update_action( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test the update_entity action for a HomeeEntity.""" + mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] + await setup_integration(hass, mock_config_entry) + await async_setup_component(hass, HA_DOMAIN, {}) + + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + {ATTR_ENTITY_ID: "cover.test_cover"}, + blocking=True, + ) + + mock_homee.update_node.assert_called_once_with(3) diff --git a/tests/components/homee/test_init.py b/tests/components/homee/test_init.py new file mode 100644 index 00000000000000..0b2ae21a8d0d20 --- /dev/null +++ b/tests/components/homee/test_init.py @@ -0,0 +1,131 @@ +"""Test Homee initialization.""" + +from unittest.mock import MagicMock + +from pyHomee import HomeeAuthFailedException, HomeeConnectionFailedException +import pytest +from syrupy.assertion import SnapshotAssertion + +from homeassistant.components.homee.const import DOMAIN +from homeassistant.config_entries import ConfigEntryState +from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr + +from . import build_mock_node, setup_integration +from .conftest import HOMEE_ID + +from tests.common import MockConfigEntry + + +@pytest.mark.parametrize( + "side_eff", + [ + HomeeConnectionFailedException("connection timed out"), + HomeeAuthFailedException("wrong username or password"), + ], +) +async def test_connection_errors( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, + side_eff: Exception, +) -> None: + """Test if connection errors on startup are handled correctly.""" + mock_homee.get_access_token.side_effect = side_eff + mock_config_entry.add_to_hass(hass) + + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY + + +async def test_connection_listener( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test if loss of connection is sensed correctly.""" + mock_homee.nodes = [build_mock_node("homee.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] + await setup_integration(hass, mock_config_entry) + + await mock_homee.add_connection_listener.call_args_list[0][0][0](False) + await hass.async_block_till_done() + assert "Disconnected from Homee" in caplog.text + await mock_homee.add_connection_listener.call_args_list[0][0][0](True) + await hass.async_block_till_done() + assert "Reconnected to Homee" in caplog.text + + +async def test_general_data( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, + device_registry: dr.DeviceRegistry, + snapshot: SnapshotAssertion, +) -> None: + """Test if data is set correctly.""" + mock_homee.nodes = [ + build_mock_node("cover_with_position_slats.json"), + build_mock_node("homee.json"), + ] + mock_homee.get_node_by_id = ( + lambda node_id: mock_homee.nodes[0] if node_id == 3 else mock_homee.nodes[1] + ) + await setup_integration(hass, mock_config_entry) + + # Verify hub and device created correctly using snapshots. + hub = device_registry.async_get_device(identifiers={(DOMAIN, f"{HOMEE_ID}")}) + device = device_registry.async_get_device(identifiers={(DOMAIN, f"{HOMEE_ID}-3")}) + + assert hub == snapshot + assert device == snapshot + + +async def test_software_version( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test sw_version for device with only AttributeType.SOFTWARE_VERSION.""" + mock_homee.nodes = [build_mock_node("cover_without_position.json")] + await setup_integration(hass, mock_config_entry) + + device = device_registry.async_get_device(identifiers={(DOMAIN, f"{HOMEE_ID}-3")}) + assert device.sw_version == "1.45" + + +async def test_invalid_profile( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, + device_registry: dr.DeviceRegistry, +) -> None: + """Test unknown value passed to get_name_for_enum.""" + mock_homee.nodes = [build_mock_node("cover_without_position.json")] + # This is a profile, that does not exist in the enum. + mock_homee.nodes[0].profile = 77 + await setup_integration(hass, mock_config_entry) + + device = device_registry.async_get_device(identifiers={(DOMAIN, f"{HOMEE_ID}-3")}) + assert device.model is None + + +async def test_unload_entry( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test unloading of config entry.""" + mock_homee.nodes = [build_mock_node("cover_with_position_slats.json")] + mock_homee.get_node_by_id.return_value = mock_homee.nodes[0] + await setup_integration(hass, mock_config_entry) + + assert mock_config_entry.state is ConfigEntryState.LOADED + + await hass.config_entries.async_unload(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_config_entry.state is ConfigEntryState.NOT_LOADED diff --git a/tests/components/homee/test_sensor.py b/tests/components/homee/test_sensor.py index 1d4ad4b0f66dbe..b51b3a23b75d8c 100644 --- a/tests/components/homee/test_sensor.py +++ b/tests/components/homee/test_sensor.py @@ -5,6 +5,10 @@ import pytest from syrupy.assertion import SnapshotAssertion +from homeassistant.components.homeassistant import ( + DOMAIN as HA_DOMAIN, + SERVICE_UPDATE_ENTITY, +) from homeassistant.components.homee.const import ( DOMAIN, OPEN_CLOSE_MAP, @@ -13,9 +17,10 @@ WINDOW_MAP_REVERSED, ) from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN -from homeassistant.const import Platform +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er, issue_registry as ir +from homeassistant.setup import async_setup_component from . import async_update_attribute_value, build_mock_node, setup_integration from .conftest import HOMEE_ID @@ -168,6 +173,49 @@ async def test_sensor_deprecation_unused_entity( ) +async def test_entity_connection_listener( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test if loss of connection is sensed correctly.""" + await setup_sensor(hass, mock_homee, mock_config_entry) + + states = hass.states.get("sensor.test_multisensor_energy_1") + assert states.state is not STATE_UNAVAILABLE + + await mock_homee.add_connection_listener.call_args_list[2][0][0](False) + await hass.async_block_till_done() + + states = hass.states.get("sensor.test_multisensor_energy_1") + assert states.state is STATE_UNAVAILABLE + + await mock_homee.add_connection_listener.call_args_list[2][0][0](True) + await hass.async_block_till_done() + + states = hass.states.get("sensor.test_multisensor_energy_1") + assert states.state is not STATE_UNAVAILABLE + + +async def test_entity_update_action( + hass: HomeAssistant, + mock_homee: MagicMock, + mock_config_entry: MockConfigEntry, +) -> None: + """Test the update_entity action for a HomeeEntity.""" + await setup_sensor(hass, mock_homee, mock_config_entry) + await async_setup_component(hass, HA_DOMAIN, {}) + + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + {ATTR_ENTITY_ID: "sensor.test_multisensor_temperature"}, + blocking=True, + ) + + mock_homee.update_attribute.assert_called_once_with(1, 23) + + async def test_sensor_snapshot( hass: HomeAssistant, mock_homee: MagicMock, diff --git a/tests/components/ollama/__init__.py b/tests/components/ollama/__init__.py index 92db3b13304279..9e7ae4772d4f2b 100644 --- a/tests/components/ollama/__init__.py +++ b/tests/components/ollama/__init__.py @@ -12,3 +12,8 @@ ollama.CONF_MAX_HISTORY: 2, ollama.CONF_MODEL: "test_model:latest", } + +TEST_AI_TASK_OPTIONS = { + ollama.CONF_MAX_HISTORY: 2, + ollama.CONF_MODEL: "test_model:latest", +} diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 552e7dee20a3b6..f3406bf5566d94 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -11,7 +11,7 @@ from homeassistant.helpers import llm from homeassistant.setup import async_setup_component -from . import TEST_OPTIONS, TEST_USER_DATA +from . import TEST_AI_TASK_OPTIONS, TEST_OPTIONS, TEST_USER_DATA from tests.common import MockConfigEntry @@ -31,14 +31,20 @@ def mock_config_entry( domain=ollama.DOMAIN, data=TEST_USER_DATA, version=3, - minor_version=1, + minor_version=2, subentries_data=[ { "data": {**TEST_OPTIONS, **mock_config_entry_options}, "subentry_type": "conversation", "title": "Ollama Conversation", "unique_id": None, - } + }, + { + "data": TEST_AI_TASK_OPTIONS, + "subentry_type": "ai_task_data", + "title": "Ollama AI Task", + "unique_id": None, + }, ], ) entry.add_to_hass(hass) diff --git a/tests/components/ollama/test_ai_task.py b/tests/components/ollama/test_ai_task.py new file mode 100644 index 00000000000000..ee812e7b316b43 --- /dev/null +++ b/tests/components/ollama/test_ai_task.py @@ -0,0 +1,245 @@ +"""Test AI Task platform of Ollama integration.""" + +from unittest.mock import patch + +import pytest +import voluptuous as vol + +from homeassistant.components import ai_task +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import entity_registry as er, selector + +from tests.common import MockConfigEntry + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Ensure entity is linked to the subentry + entity_entry = entity_registry.async_get(entity_id) + ai_task_entry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "ai_task_data" + ) + ) + assert entity_entry is not None + assert entity_entry.config_entry_id == mock_config_entry.entry_id + assert entity_entry.config_subentry_id == ai_task_entry.subentry_id + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": {"role": "assistant", "content": "Generated test data"}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + ) + + assert result.data == "Generated test data" + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_with_streaming( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation with streaming response.""" + entity_id = "ai_task.ollama_ai_task" + + async def mock_stream(): + """Mock streaming response.""" + yield {"message": {"role": "assistant", "content": "Stream "}} + yield { + "message": {"role": "assistant", "content": "response"}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_stream(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Streaming Task", + entity_id=entity_id, + instructions="Generate streaming data", + ) + + assert result.data == "Stream response" + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_connection_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task with connection error.""" + entity_id = "ai_task.ollama_ai_task" + + with ( + patch( + "ollama.AsyncClient.chat", + side_effect=Exception("Connection failed"), + ), + pytest.raises(Exception, match="Connection failed"), + ): + await ai_task.async_generate_data( + hass, + task_name="Test Error Task", + entity_id=entity_id, + instructions="Generate data that will fail", + ) + + +@pytest.mark.usefixtures("mock_init_component") +async def test_run_task_empty_response( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task with empty response.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock response with space (minimally non-empty) + async def mock_minimal_response(): + """Mock minimal streaming response.""" + yield { + "message": {"role": "assistant", "content": " "}, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_minimal_response(), + ): + result = await ai_task.async_generate_data( + hass, + task_name="Test Minimal Task", + entity_id=entity_id, + instructions="Generate minimal data", + ) + + assert result.data == " " + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": { + "role": "assistant", + "content": '{"characters": ["Mario", "Luigi"]}', + }, + "done": True, + "done_reason": "stop", + } + + with patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ) as mock_chat: + result = await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) + assert result.data == {"characters": ["Mario", "Luigi"]} + + assert mock_chat.call_count == 1 + assert mock_chat.call_args[1]["format"] == { + "type": "object", + "properties": {"characters": {"items": {"type": "string"}, "type": "array"}}, + "required": ["characters"], + } + + +@pytest.mark.usefixtures("mock_init_component") +async def test_generate_invalid_structured_data( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + entity_registry: er.EntityRegistry, +) -> None: + """Test AI Task data generation.""" + entity_id = "ai_task.ollama_ai_task" + + # Mock the Ollama chat response as an async iterator + async def mock_chat_response(): + """Mock streaming response.""" + yield { + "message": { + "role": "assistant", + "content": "INVALID JSON RESPONSE", + }, + "done": True, + "done_reason": "stop", + } + + with ( + patch( + "ollama.AsyncClient.chat", + return_value=mock_chat_response(), + ), + pytest.raises(HomeAssistantError), + ): + await ai_task.async_generate_data( + hass, + task_name="Test Task", + entity_id=entity_id, + instructions="Generate test data", + structure=vol.Schema( + { + vol.Required("characters"): selector.selector( + { + "text": { + "multiple": True, + } + } + ) + }, + ), + ) diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index 7372a460c95e1f..1a873c2adb7261 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -461,3 +461,78 @@ async def delayed_pull(self, model: str) -> None: ollama.CONF_NUM_CTX: 8192.0, ollama.CONF_THINK: True, } + + +async def test_creating_ai_task_subentry( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + mock_init_component, +) -> None: + """Test creating an AI task subentry.""" + old_subentries = set(mock_config_entry.subentries) + # Original conversation + original ai_task + assert len(mock_config_entry.subentries) == 2 + + with patch( + "ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model:latest"}]}, + ): + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task_data"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result.get("type") is FlowResultType.FORM + assert result.get("step_id") == "set_options" + assert not result.get("errors") + + with patch( + "ollama.AsyncClient.list", + return_value={"models": [{"model": "test_model:latest"}]}, + ): + result2 = await hass.config_entries.subentries.async_configure( + result["flow_id"], + { + "name": "Custom AI Task", + ollama.CONF_MODEL: "test_model:latest", + ollama.CONF_MAX_HISTORY: 5, + ollama.CONF_NUM_CTX: 4096, + ollama.CONF_KEEP_ALIVE: 30, + ollama.CONF_THINK: False, + }, + ) + await hass.async_block_till_done() + + assert result2.get("type") is FlowResultType.CREATE_ENTRY + assert result2.get("title") == "Custom AI Task" + assert result2.get("data") == { + ollama.CONF_MODEL: "test_model:latest", + ollama.CONF_MAX_HISTORY: 5, + ollama.CONF_NUM_CTX: 4096, + ollama.CONF_KEEP_ALIVE: 30, + ollama.CONF_THINK: False, + } + + assert ( + len(mock_config_entry.subentries) == 3 + ) # Original conversation + original ai_task + new ai_task + + new_subentry_id = list(set(mock_config_entry.subentries) - old_subentries)[0] + new_subentry = mock_config_entry.subentries[new_subentry_id] + assert new_subentry.subentry_type == "ai_task_data" + assert new_subentry.title == "Custom AI Task" + + +async def test_ai_task_subentry_not_loaded( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, +) -> None: + """Test creating an AI task subentry when entry is not loaded.""" + # Don't call mock_init_component to simulate not loaded state + result = await hass.config_entries.subentries.async_init( + (mock_config_entry.entry_id, "ai_task_data"), + context={"source": config_entries.SOURCE_USER}, + ) + + assert result.get("type") is FlowResultType.ABORT + assert result.get("reason") == "entry_not_loaded" diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index c7cd78fca9a5db..1db57302704108 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -93,16 +93,23 @@ async def test_migration_from_v1( return_value=True, ): await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() assert mock_config_entry.version == 3 - assert mock_config_entry.minor_version == 1 + assert mock_config_entry.minor_version == 2 # After migration, parent entry should only have URL assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} assert mock_config_entry.options == {} - assert len(mock_config_entry.subentries) == 1 + assert len(mock_config_entry.subentries) == 2 - subentry = next(iter(mock_config_entry.subentries.values())) + subentry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "conversation" + ) + ) assert subentry.unique_id is None assert subentry.title == "llama-3.2-8b" assert subentry.subentry_type == "conversation" @@ -110,6 +117,18 @@ async def test_migration_from_v1( expected_subentry_data = TEST_OPTIONS.copy() assert subentry.data == expected_subentry_data + # Find the AI Task subentry + ai_task_subentry = next( + iter( + entry + for entry in mock_config_entry.subentries.values() + if entry.subentry_type == "ai_task_data" + ) + ) + assert ai_task_subentry.unique_id is None + assert ai_task_subentry.title == "Ollama AI Task" + assert ai_task_subentry.subentry_type == "ai_task_data" + migrated_entity = entity_registry.async_get(entity.entity_id) assert migrated_entity is not None assert migrated_entity.config_entry_id == mock_config_entry.entry_id @@ -204,10 +223,17 @@ async def test_migration_from_v1_with_multiple_urls( for idx, entry in enumerate(entries): assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options - assert len(entry.subentries) == 1 - subentry = list(entry.subentries.values())[0] + assert len(entry.subentries) == 2 + + subentry = next( + iter( + subentry + for subentry in entry.subentries.values() + if subentry.subentry_type == "conversation" + ) + ) assert subentry.subentry_type == "conversation" # Subentry should include the model along with the original options expected_subentry_data = TEST_OPTIONS.copy() @@ -215,6 +241,17 @@ async def test_migration_from_v1_with_multiple_urls( assert subentry.data == expected_subentry_data assert subentry.title == f"Ollama {idx + 1}" + # Find the AI Task subentry + ai_task_subentry = next( + iter( + subentry + for subentry in entry.subentries.values() + if subentry.subentry_type == "ai_task_data" + ) + ) + assert ai_task_subentry.subentry_type == "ai_task_data" + assert ai_task_subentry.title == "Ollama AI Task" + dev = device_registry.async_get_device( identifiers={(DOMAIN, list(entry.subentries.values())[0].subentry_id)} ) @@ -295,9 +332,10 @@ async def test_migration_from_v1_with_same_urls( entry = entries[0] assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options - assert len(entry.subentries) == 2 # Two subentries from the two original entries + # Two conversation subentries from the two original entries and 1 aitask subentry + assert len(entry.subentries) == 3 # Check both subentries exist with correct data subentries = list(entry.subentries.values()) @@ -305,7 +343,11 @@ async def test_migration_from_v1_with_same_urls( assert "Ollama" in titles assert "Ollama 2" in titles - for subentry in subentries: + conversation_subentries = [ + subentry for subentry in subentries if subentry.subentry_type == "conversation" + ] + assert len(conversation_subentries) == 2 + for subentry in conversation_subentries: assert subentry.subentry_type == "conversation" # Subentry should include the model along with the original options expected_subentry_data = TEST_OPTIONS.copy() @@ -415,10 +457,10 @@ async def test_migration_from_v2_1( assert len(entries) == 1 entry = entries[0] assert entry.version == 3 - assert entry.minor_version == 1 + assert entry.minor_version == 2 assert not entry.options assert entry.title == "Ollama" - assert len(entry.subentries) == 2 + assert len(entry.subentries) == 3 conversation_subentries = [ subentry for subentry in entry.subentries.values() @@ -504,14 +546,44 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None: # Check migration to v3.1 assert mock_config_entry.version == 3 - assert mock_config_entry.minor_version == 1 + assert mock_config_entry.minor_version == 2 # Check that model was moved from main data to subentry assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} - assert len(mock_config_entry.subentries) == 1 + assert len(mock_config_entry.subentries) == 2 subentry = next(iter(mock_config_entry.subentries.values())) assert subentry.data == { **V21_TEST_USER_DATA, ollama.CONF_MODEL: "test_model:latest", } + + +async def test_migration_from_v3_1_without_subentry(hass: HomeAssistant) -> None: + """Test migration from version 3.1 where there is no existing subentry. + + This exercises the code path where the model is not moved to a subentry + because the subentry does not exist, which is a scenario that can happen + if the user created the config entry without adding a subentry, or + if the user manually removed the subentry after the migration to v3.1. + """ + mock_config_entry = MockConfigEntry( + domain=DOMAIN, + data={ + ollama.CONF_MODEL: "test_model:latest", + }, + version=3, + minor_version=1, + ) + mock_config_entry.add_to_hass(hass) + + with patch( + "homeassistant.components.ollama.async_setup_entry", + return_value=True, + ): + await hass.config_entries.async_setup(mock_config_entry.entry_id) + + assert mock_config_entry.version == 3 + assert mock_config_entry.minor_version == 2 + + assert next(iter(mock_config_entry.subentries.values()), None) is None