Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chatkit/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ActionConfig(BaseModel):

class Action(BaseModel, Generic[TType, TPayload]):
type: TType = Field(default=TType, frozen=True) # pyright: ignore
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions
payload: TPayload = None # pyright: ignore - default to None to allow no-payload actions

@classmethod
def create(
Expand Down
84 changes: 57 additions & 27 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
from collections import defaultdict
from collections.abc import AsyncIterator
from datetime import datetime
from inspect import cleandoc
Expand Down Expand Up @@ -45,6 +46,7 @@
Annotation,
AssistantMessageContent,
AssistantMessageContentPartAdded,
AssistantMessageContentPartAnnotationAdded,
AssistantMessageContentPartDone,
AssistantMessageContentPartTextDelta,
AssistantMessageItem,
Expand Down Expand Up @@ -207,9 +209,10 @@ def _complete(self):

def _convert_content(content: Content) -> AssistantMessageContent:
if content.type == "output_text":
annotations = []
for annotation in content.annotations:
annotations.extend(_convert_annotation(annotation))
annotations = [
_convert_annotation(annotation) for annotation in content.annotations
]
annotations = [a for a in annotations if a is not None]
return AssistantMessageContent(
text=content.text,
annotations=annotations,
Expand All @@ -221,37 +224,43 @@ def _convert_content(content: Content) -> AssistantMessageContent:
)


def _convert_annotation(
annotation: ResponsesAnnotation,
) -> list[Annotation]:
def _convert_annotation(raw_annotation: object) -> Annotation | None:
# There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
# resulting into annotation being a dict instead of a ResponsesAnnotation
if isinstance(annotation, dict):
annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation)
# resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation
annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python(
raw_annotation
)
Comment on lines +229 to +232
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we put this in a defensive try except since the validator can throw?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered that, but decided maybe an exception is best here.

  1. It didn't have a try before this change
  2. This is only operating on objects we got from the responses API so it's a pretty serious contract breakage for this validator not to pass.


result: list[Annotation] = []
if annotation.type == "file_citation":
filename = annotation.filename
if not filename:
return []
result.append(
Annotation(
source=FileSource(filename=filename, title=filename),
index=annotation.index,
)
return None

return Annotation(
source=FileSource(filename=filename, title=filename),
index=annotation.index,
)
elif annotation.type == "url_citation":
result.append(
Annotation(
source=URLSource(
url=annotation.url,
title=annotation.title,
),
index=annotation.end_index,
)

if annotation.type == "url_citation":
return Annotation(
source=URLSource(
url=annotation.url,
title=annotation.title,
),
index=annotation.end_index,
)

return result
if annotation.type == "container_file_citation":
filename = annotation.filename
if not filename:
return None

return Annotation(
source=FileSource(filename=filename, title=filename),
index=annotation.end_index,
)

return None


T1 = TypeVar("T1")
Expand Down Expand Up @@ -349,6 +358,10 @@ async def stream_agent_response(
queue_iterator = _AsyncQueueIterator(context._events)
produced_items = set()
streaming_thought: None | StreamingThoughtTracker = None
# item_id -> content_index -> annotation count
item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
lambda: defaultdict(int)
)

# check if the last item in the thread was a workflow or a client tool call
# if it was a client tool call, check if the second last item was a workflow
Expand Down Expand Up @@ -462,7 +475,24 @@ def end_workflow(item: WorkflowItem):
),
)
elif event.type == "response.output_text.annotation.added":
# Ignore annotation-added events; annotations are reflected in the final item content.
annotation = _convert_annotation(event.annotation)
if annotation:
# Manually track annotation indices per content part in case we drop an annotation that
# we can't convert to our internal representation (e.g. missing filename).
annotation_index = item_annotation_count[event.item_id][
event.content_index
]
item_annotation_count[event.item_id][event.content_index] = (
annotation_index + 1
)
yield ThreadItemUpdated(
item_id=event.item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=event.content_index,
annotation_index=annotation_index,
annotation=annotation,
),
)
continue
elif event.type == "response.output_item.added":
item = event.item
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "openai-chatkit"
version = "1.1.2"
version = "1.2.2"
description = "A ChatKit backend SDK."
readme = "README.md"
requires-python = ">=3.10"
Expand Down
116 changes: 115 additions & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
ResponseContentPartAddedEvent,
)
from openai.types.responses.response_file_search_tool_call import Result
from openai.types.responses.response_output_text import (
AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
)
from openai.types.responses.response_output_text import (
AnnotationFileCitation as ResponsesAnnotationFileCitation,
)
Expand All @@ -64,6 +67,7 @@
Annotation,
AssistantMessageContent,
AssistantMessageContentPartAdded,
AssistantMessageContentPartAnnotationAdded,
AssistantMessageContentPartDone,
AssistantMessageContentPartTextDelta,
AssistantMessageItem,
Expand Down Expand Up @@ -790,7 +794,17 @@ async def test_stream_agent_response_maps_events():
sequence_number=3,
),
),
None,
ThreadItemUpdated(
item_id="123",
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
annotation_index=0,
annotation=Annotation(
source=FileSource(filename="file.txt", title="file.txt"),
index=5,
),
),
),
),
],
)
Expand All @@ -810,6 +824,91 @@ async def test_event_mapping(raw_event, expected_event):
assert events == []


async def test_stream_agent_response_emits_annotation_added_events():
context = AgentContext(
previous_response_id=None, thread=thread, store=mock_store, request_context=None
)
result = make_result()
item_id = "item_123"

def add_annotation_event(annotation, sequence_number):
result.add_event(
RawResponsesStreamEvent(
type="raw_response_event",
data=Mock(
type="response.output_text.annotation.added",
annotation=annotation,
content_index=0,
item_id=item_id,
annotation_index=sequence_number,
output_index=0,
sequence_number=sequence_number,
),
)
)

add_annotation_event(
ResponsesAnnotationFileCitation(
type="file_citation",
file_id="file_invalid",
filename="",
index=0,
),
sequence_number=0,
)
add_annotation_event(
ResponsesAnnotationContainerFileCitation(
type="container_file_citation",
container_id="container_1",
file_id="file_123",
filename="container.txt",
start_index=0,
end_index=3,
),
sequence_number=1,
)
add_annotation_event(
ResponsesAnnotationURLCitation(
type="url_citation",
url="https://example.com",
title="Example",
start_index=1,
end_index=5,
),
sequence_number=2,
)
result.done()

events = await all_events(stream_agent_response(context, result))
assert events == [
ThreadItemUpdated(
item_id=item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
annotation_index=0,
annotation=Annotation(
source=FileSource(filename="container.txt", title="container.txt"),
index=3,
),
),
),
ThreadItemUpdated(
item_id=item_id,
update=AssistantMessageContentPartAnnotationAdded(
content_index=0,
annotation_index=1,
annotation=Annotation(
source=URLSource(
url="https://example.com",
title="Example",
),
index=5,
),
),
),
]


@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
context = AgentContext(
Expand Down Expand Up @@ -942,6 +1041,14 @@ async def test_stream_agent_response_assistant_message_content_types():
index=0,
filename="test.txt",
),
ResponsesAnnotationContainerFileCitation(
type="container_file_citation",
container_id="container_1",
file_id="f_456",
filename="container.txt",
start_index=0,
end_index=3,
),
ResponsesAnnotationURLCitation(
type="url_citation",
url="https://www.google.com",
Expand Down Expand Up @@ -994,6 +1101,13 @@ async def test_stream_agent_response_assistant_message_content_types():
),
index=0,
),
Annotation(
source=FileSource(
filename="container.txt",
title="container.txt",
),
index=3,
),
Annotation(
source=URLSource(
url="https://www.google.com",
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.