diff --git a/chatkit/actions.py b/chatkit/actions.py index 9048f63..e00b3ad 100644 --- a/chatkit/actions.py +++ b/chatkit/actions.py @@ -24,7 +24,7 @@ class ActionConfig(BaseModel): class Action(BaseModel, Generic[TType, TPayload]): type: TType = Field(default=TType, frozen=True) # pyright: ignore - payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions + payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions @classmethod def create( diff --git a/chatkit/agents.py b/chatkit/agents.py index 4d307d7..5f3b9ca 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -1,5 +1,6 @@ import asyncio import json +from collections import defaultdict from collections.abc import AsyncIterator from datetime import datetime from inspect import cleandoc @@ -45,6 +46,7 @@ Annotation, AssistantMessageContent, AssistantMessageContentPartAdded, + AssistantMessageContentPartAnnotationAdded, AssistantMessageContentPartDone, AssistantMessageContentPartTextDelta, AssistantMessageItem, @@ -207,9 +209,10 @@ def _complete(self): def _convert_content(content: Content) -> AssistantMessageContent: if content.type == "output_text": - annotations = [] - for annotation in content.annotations: - annotations.extend(_convert_annotation(annotation)) + annotations = [ + _convert_annotation(annotation) for annotation in content.annotations + ] + annotations = [a for a in annotations if a is not None] return AssistantMessageContent( text=content.text, annotations=annotations, @@ -221,37 +224,43 @@ def _convert_content(content: Content) -> AssistantMessageContent: ) -def _convert_annotation( - annotation: ResponsesAnnotation, -) -> list[Annotation]: +def _convert_annotation(raw_annotation: object) -> Annotation | None: # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class - # resulting into annotation being a dict instead of a ResponsesAnnotation - if isinstance(annotation, dict): - annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation) + # resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation + annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python( + raw_annotation + ) - result: list[Annotation] = [] if annotation.type == "file_citation": filename = annotation.filename if not filename: - return [] - result.append( - Annotation( - source=FileSource(filename=filename, title=filename), - index=annotation.index, - ) + return None + + return Annotation( + source=FileSource(filename=filename, title=filename), + index=annotation.index, ) - elif annotation.type == "url_citation": - result.append( - Annotation( - source=URLSource( - url=annotation.url, - title=annotation.title, - ), - index=annotation.end_index, - ) + + if annotation.type == "url_citation": + return Annotation( + source=URLSource( + url=annotation.url, + title=annotation.title, + ), + index=annotation.end_index, ) - return result + if annotation.type == "container_file_citation": + filename = annotation.filename + if not filename: + return None + + return Annotation( + source=FileSource(filename=filename, title=filename), + index=annotation.end_index, + ) + + return None T1 = TypeVar("T1") @@ -349,6 +358,10 @@ async def stream_agent_response( queue_iterator = _AsyncQueueIterator(context._events) produced_items = set() streaming_thought: None | StreamingThoughtTracker = None + # item_id -> content_index -> annotation count + item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict( + lambda: defaultdict(int) + ) # check if the last item in the thread was a workflow or a client tool call # if it was a client tool call, check if the second last item was a workflow @@ -462,7 +475,24 @@ def end_workflow(item: WorkflowItem): ), ) elif event.type == "response.output_text.annotation.added": - # Ignore annotation-added events; annotations are reflected in the final item content. + annotation = _convert_annotation(event.annotation) + if annotation: + # Manually track annotation indices per content part in case we drop an annotation that + # we can't convert to our internal representation (e.g. missing filename). + annotation_index = item_annotation_count[event.item_id][ + event.content_index + ] + item_annotation_count[event.item_id][event.content_index] = ( + annotation_index + 1 + ) + yield ThreadItemUpdated( + item_id=event.item_id, + update=AssistantMessageContentPartAnnotationAdded( + content_index=event.content_index, + annotation_index=annotation_index, + annotation=annotation, + ), + ) continue elif event.type == "response.output_item.added": item = event.item diff --git a/pyproject.toml b/pyproject.toml index e3dced4..bb76263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-chatkit" -version = "1.1.2" +version = "1.2.2" description = "A ChatKit backend SDK." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/test_agents.py b/tests/test_agents.py index ab9e9e6..dcd7bd0 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -38,6 +38,9 @@ ResponseContentPartAddedEvent, ) from openai.types.responses.response_file_search_tool_call import Result +from openai.types.responses.response_output_text import ( + AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation, +) from openai.types.responses.response_output_text import ( AnnotationFileCitation as ResponsesAnnotationFileCitation, ) @@ -64,6 +67,7 @@ Annotation, AssistantMessageContent, AssistantMessageContentPartAdded, + AssistantMessageContentPartAnnotationAdded, AssistantMessageContentPartDone, AssistantMessageContentPartTextDelta, AssistantMessageItem, @@ -790,7 +794,17 @@ async def test_stream_agent_response_maps_events(): sequence_number=3, ), ), - None, + ThreadItemUpdated( + item_id="123", + update=AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=0, + annotation=Annotation( + source=FileSource(filename="file.txt", title="file.txt"), + index=5, + ), + ), + ), ), ], ) @@ -810,6 +824,91 @@ async def test_event_mapping(raw_event, expected_event): assert events == [] +async def test_stream_agent_response_emits_annotation_added_events(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + item_id = "item_123" + + def add_annotation_event(annotation, sequence_number): + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_text.annotation.added", + annotation=annotation, + content_index=0, + item_id=item_id, + annotation_index=sequence_number, + output_index=0, + sequence_number=sequence_number, + ), + ) + ) + + add_annotation_event( + ResponsesAnnotationFileCitation( + type="file_citation", + file_id="file_invalid", + filename="", + index=0, + ), + sequence_number=0, + ) + add_annotation_event( + ResponsesAnnotationContainerFileCitation( + type="container_file_citation", + container_id="container_1", + file_id="file_123", + filename="container.txt", + start_index=0, + end_index=3, + ), + sequence_number=1, + ) + add_annotation_event( + ResponsesAnnotationURLCitation( + type="url_citation", + url="https://example.com", + title="Example", + start_index=1, + end_index=5, + ), + sequence_number=2, + ) + result.done() + + events = await all_events(stream_agent_response(context, result)) + assert events == [ + ThreadItemUpdated( + item_id=item_id, + update=AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=0, + annotation=Annotation( + source=FileSource(filename="container.txt", title="container.txt"), + index=3, + ), + ), + ), + ThreadItemUpdated( + item_id=item_id, + update=AssistantMessageContentPartAnnotationAdded( + content_index=0, + annotation_index=1, + annotation=Annotation( + source=URLSource( + url="https://example.com", + title="Example", + ), + index=5, + ), + ), + ), + ] + + @pytest.mark.parametrize("throw_guardrail", ["input", "output"]) async def test_stream_agent_response_yields_item_removed_event(throw_guardrail): context = AgentContext( @@ -942,6 +1041,14 @@ async def test_stream_agent_response_assistant_message_content_types(): index=0, filename="test.txt", ), + ResponsesAnnotationContainerFileCitation( + type="container_file_citation", + container_id="container_1", + file_id="f_456", + filename="container.txt", + start_index=0, + end_index=3, + ), ResponsesAnnotationURLCitation( type="url_citation", url="https://www.google.com", @@ -994,6 +1101,13 @@ async def test_stream_agent_response_assistant_message_content_types(): ), index=0, ), + Annotation( + source=FileSource( + filename="container.txt", + title="container.txt", + ), + index=3, + ), Annotation( source=URLSource( url="https://www.google.com", diff --git a/uv.lock b/uv.lock index 8f3f1d1..e301ccc 100644 --- a/uv.lock +++ b/uv.lock @@ -819,7 +819,7 @@ wheels = [ [[package]] name = "openai-chatkit" -version = "1.1.2" +version = "1.2.2" source = { virtual = "." } dependencies = [ { name = "openai" },