Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion homeassistant/components/ollama/ai_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class OllamaTaskEntity(
):
"""Ollama AI Task entity."""

_attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
_attr_supported_features = (
ai_task.AITaskEntityFeature.GENERATE_DATA
| ai_task.AITaskEntityFeature.SUPPORT_ATTACHMENTS
)

async def _async_generate_data(
self,
Expand Down
9 changes: 9 additions & 0 deletions homeassistant/components/ollama/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,18 @@ def _convert_content(
],
)
if isinstance(chat_content, conversation.UserContent):
images: list[ollama.Image] = []
for attachment in chat_content.attachments or ():
if not attachment.mime_type.startswith("image/"):
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="unsupported_attachment_type",
)
images.append(ollama.Image(value=attachment.path))
return ollama.Message(
role=MessageRole.USER.value,
content=chat_content.content,
images=images or None,
)
if isinstance(chat_content, conversation.SystemContent):
return ollama.Message(
Expand Down
5 changes: 5 additions & 0 deletions homeassistant/components/ollama/strings.json
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,10 @@
"download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
}
}
},
"exceptions": {
"unsupported_attachment_type": {
"message": "Ollama only supports image attachments in user content, but received non-image attachment."
}
}
}
116 changes: 115 additions & 1 deletion tests/components/ollama/test_ai_task.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Test AI Task platform of Ollama integration."""

from pathlib import Path
from unittest.mock import patch

import ollama
import pytest
import voluptuous as vol

from homeassistant.components import ai_task
from homeassistant.components import ai_task, media_source
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity_registry as er, selector
Expand Down Expand Up @@ -243,3 +245,115 @@ async def mock_chat_response():
},
),
)


@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data_with_attachment(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation with image attachments."""
entity_id = "ai_task.ollama_ai_task"

# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {"role": "assistant", "content": "Generated test data"},
"done": True,
"done_reason": "stop",
}

with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
],
),
patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
) as mock_chat,
):
result = await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
],
)

assert result.data == "Generated test data"

assert mock_chat.call_count == 1
messages = mock_chat.call_args[1]["messages"]
assert len(messages) == 2
chat_message = messages[1]
assert chat_message.role == "user"
assert chat_message.content == "Generate test data"
assert chat_message.images == [
ollama.Image(value=Path("doorbell_snapshot.jpg")),
]


@pytest.mark.usefixtures("mock_init_component")
async def test_generate_data_with_unsupported_file_format(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
entity_registry: er.EntityRegistry,
) -> None:
"""Test AI Task data generation with image attachments."""
entity_id = "ai_task.ollama_ai_task"

# Mock the Ollama chat response as an async iterator
async def mock_chat_response():
"""Mock streaming response."""
yield {
"message": {"role": "assistant", "content": "Generated test data"},
"done": True,
"done_reason": "stop",
}

with (
patch(
"homeassistant.components.media_source.async_resolve_media",
side_effect=[
media_source.PlayMedia(
url="http://example.com/doorbell_snapshot.jpg",
mime_type="image/jpeg",
path=Path("doorbell_snapshot.jpg"),
),
media_source.PlayMedia(
url="http://example.com/context.txt",
mime_type="text/plain",
path=Path("context.txt"),
),
],
),
patch(
"ollama.AsyncClient.chat",
return_value=mock_chat_response(),
),
pytest.raises(
HomeAssistantError,
match="Ollama only supports image attachments in user content",
),
):
await ai_task.async_generate_data(
hass,
task_name="Test Task",
entity_id=entity_id,
instructions="Generate test data",
attachments=[
{"media_content_id": "media-source://media/doorbell_snapshot.jpg"},
{"media_content_id": "media-source://media/context.txt"},
],
)
Loading