Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chatkit/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ActionConfig(BaseModel):


class Action(BaseModel, Generic[TType, TPayload]):
type: TType = Field(default=TType, frozen=True) # pyright: ignore
type: TType = Field(frozen=True) # pyright: ignore
payload: TPayload

@classmethod
Expand Down
101 changes: 51 additions & 50 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,41 +386,41 @@ def end_workflow(item: WorkflowItem):
async for event in _merge_generators(result.stream_events(), queue_iterator):
# Events emitted from agent context helpers
if isinstance(event, _EventWrapper):
event = event.event
unwrapped_event = event.event
if (
event.type == "thread.item.added"
or event.type == "thread.item.done"
unwrapped_event.type == "thread.item.added"
or unwrapped_event.type == "thread.item.done"
):
# End the current workflow if visual item is added after it
if (
ctx.workflow_item
and ctx.workflow_item.id != event.item.id
and event.item.type != "client_tool_call"
and event.item.type != "hidden_context_item"
and ctx.workflow_item.id != unwrapped_event.item.id
and unwrapped_event.item.type != "client_tool_call"
and unwrapped_event.item.type != "hidden_context_item"
):
yield end_workflow(ctx.workflow_item)

# track the current workflow if one is added
if (
event.type == "thread.item.added"
and event.item.type == "workflow"
unwrapped_event.type == "thread.item.added"
and unwrapped_event.item.type == "workflow"
):
ctx.workflow_item = event.item
ctx.workflow_item = unwrapped_event.item

# track integration produced items so we can clean them up if
# there is a guardrail tripwire
produced_items.add(event.item.id)
yield event
produced_items.add(unwrapped_event.item.id)
yield unwrapped_event
continue

if event.type == "run_item_stream_event":
event = event.item
run_item = event.item
if (
event.type == "tool_call_item"
and event.raw_item.type == "function_call"
run_item.type == "tool_call_item"
and run_item.raw_item.type == "function_call"
):
current_tool_call = event.raw_item.call_id
current_item_id = event.raw_item.id
current_tool_call = run_item.raw_item.call_id
current_item_id = run_item.raw_item.id
assert current_item_id
produced_items.add(current_item_id)
continue
Expand All @@ -430,42 +430,42 @@ def end_workflow(item: WorkflowItem):
continue

# Handle Responses events
event = event.data
if event.type == "response.content_part.added":
if event.part.type == "reasoning_text":
response_event = event.data
if response_event.type == "response.content_part.added":
if response_event.part.type == "reasoning_text":
continue
content = _convert_content(event.part)
content = _convert_content(response_event.part)
yield ThreadItemUpdated(
item_id=event.item_id,
item_id=response_event.item_id,
update=AssistantMessageContentPartAdded(
content_index=event.content_index,
content_index=response_event.content_index,
content=content,
),
)
elif event.type == "response.output_text.delta":
elif response_event.type == "response.output_text.delta":
yield ThreadItemUpdated(
item_id=event.item_id,
item_id=response_event.item_id,
update=AssistantMessageContentPartTextDelta(
content_index=event.content_index,
delta=event.delta,
content_index=response_event.content_index,
delta=response_event.delta,
),
)
elif event.type == "response.output_text.done":
elif response_event.type == "response.output_text.done":
yield ThreadItemUpdated(
item_id=event.item_id,
item_id=response_event.item_id,
update=AssistantMessageContentPartDone(
content_index=event.content_index,
content_index=response_event.content_index,
content=AssistantMessageContent(
text=event.text,
text=response_event.text,
annotations=[],
),
),
)
elif event.type == "response.output_text.annotation.added":
elif response_event.type == "response.output_text.annotation.added":
# Ignore annotation-added events; annotations are reflected in the final item content.
continue
elif event.type == "response.output_item.added":
item = event.item
elif response_event.type == "response.output_item.added":
item = response_event.item
if item.type == "reasoning" and not ctx.workflow_item:
ctx.workflow_item = WorkflowItem(
id=ctx.generate_id("workflow"),
Expand All @@ -488,7 +488,7 @@ def end_workflow(item: WorkflowItem):
created_at=datetime.now(),
),
)
elif event.type == "response.reasoning_summary_text.delta":
elif response_event.type == "response.reasoning_summary_text.delta":
if not ctx.workflow_item:
continue

Expand All @@ -498,9 +498,9 @@ def end_workflow(item: WorkflowItem):
and len(ctx.workflow_item.workflow.tasks) == 0
):
streaming_thought = StreamingThoughtTracker(
item_id=event.item_id,
index=event.summary_index,
task=ThoughtTask(content=event.delta),
item_id=response_event.item_id,
index=response_event.summary_index,
task=ThoughtTask(content=response_event.delta),
)
ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
yield ThreadItemUpdated(
Expand All @@ -513,10 +513,10 @@ def end_workflow(item: WorkflowItem):
elif (
streaming_thought
and streaming_thought.task in ctx.workflow_item.workflow.tasks
and event.item_id == streaming_thought.item_id
and event.summary_index == streaming_thought.index
and response_event.item_id == streaming_thought.item_id
and response_event.summary_index == streaming_thought.index
):
streaming_thought.task.content += event.delta
streaming_thought.task.content += response_event.delta
yield ThreadItemUpdated(
item_id=ctx.workflow_item.id,
update=WorkflowTaskUpdated(
Expand All @@ -526,23 +526,24 @@ def end_workflow(item: WorkflowItem):
),
),
)
elif event.type == "response.reasoning_summary_text.done":
elif response_event.type == "response.reasoning_summary_text.done":
if ctx.workflow_item:
update: WorkflowTaskUpdated | WorkflowTaskAdded
if (
streaming_thought
and streaming_thought.task in ctx.workflow_item.workflow.tasks
and event.item_id == streaming_thought.item_id
and event.summary_index == streaming_thought.index
and response_event.item_id == streaming_thought.item_id
and response_event.summary_index == streaming_thought.index
):
task = streaming_thought.task
task.content = event.text
task.content = response_event.text
streaming_thought = None
update = WorkflowTaskUpdated(
task=task,
task_index=ctx.workflow_item.workflow.tasks.index(task),
)
else:
task = ThoughtTask(content=event.text)
task = ThoughtTask(content=response_event.text)
ctx.workflow_item.workflow.tasks.append(task)
update = WorkflowTaskAdded(
task=task,
Expand All @@ -552,8 +553,8 @@ def end_workflow(item: WorkflowItem):
item_id=ctx.workflow_item.id,
update=update,
)
elif event.type == "response.output_item.done":
item = event.item
elif response_event.type == "response.output_item.done":
item = response_event.item
if item.type == "message":
produced_items.add(item.id)
yield ThreadItemDoneEvent(
Expand Down Expand Up @@ -614,7 +615,7 @@ def end_workflow(item: WorkflowItem):
async def accumulate_text(
events: AsyncIterator[StreamEvent],
base_widget: TWidget,
) -> AsyncIterator[TWidget]:
) -> AsyncIterator[Markdown | Text]:
text = ""
yield base_widget
async for event in events:
Expand Down Expand Up @@ -698,7 +699,7 @@ def workflow_to_input(
Convert a TaskItem into input item(s) to send to the model.
Returns WorkflowItem.response_items by default.
"""
messages = []
messages: list[TResponseInputItem] = []
for task in item.workflow.tasks:
if task.type != "custom" or (not task.title and not task.content):
continue
Expand All @@ -719,7 +720,7 @@ def workflow_to_input(
role="user",
)
)
return messages
return messages if messages else None

def widget_to_input(
self, item: WidgetItem
Expand Down
57 changes: 29 additions & 28 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,18 +333,19 @@ async def _process_non_streaming(
return self._serialize(self._to_thread_response(thread))
case ThreadsListReq():
params = request.params
threads = await self.store.load_threads(
thread_meta_page = await self.store.load_threads(
limit=params.limit or DEFAULT_PAGE_SIZE,
after=params.after,
order=params.order,
context=context,
)
return self._serialize(
Page(
has_more=threads.has_more,
after=threads.after,
has_more=thread_meta_page.has_more,
after=thread_meta_page.after,
data=[
self._to_thread_response(thread) for thread in threads.data
self._to_thread_response(thread_meta)
for thread_meta in thread_meta_page.data
],
)
)
Expand Down Expand Up @@ -388,12 +389,12 @@ async def _process_non_streaming(
]
return self._serialize(items)
case ThreadsUpdateReq():
thread = await self.store.load_thread(
thread_meta = await self.store.load_thread(
request.params.thread_id, context=context
)
thread.title = request.params.title
await self.store.save_thread(thread, context=context)
return self._serialize(self._to_thread_response(thread))
thread_meta.title = request.params.title
await self.store.save_thread(thread_meta, context=context)
return self._serialize(self._to_thread_response(thread_meta))
case ThreadsDeleteReq():
await self.store.delete_thread(
request.params.thread_id, context=context
Expand Down Expand Up @@ -436,25 +437,25 @@ async def _process_streaming_impl(
yield event

case ThreadsAddUserMessageReq():
thread = await self.store.load_thread(
thread_meta = await self.store.load_thread(
request.params.thread_id, context=context
)
user_message = await self._build_user_message_item(
request.params.input, thread, context
request.params.input, thread_meta, context
)
async for event in self._process_new_thread_item_respond(
thread,
thread_meta,
user_message,
context,
):
yield event

case ThreadsAddClientToolOutputReq():
thread = await self.store.load_thread(
thread_meta = await self.store.load_thread(
request.params.thread_id, context=context
)
items = await self.store.load_thread_items(
thread.id, None, 1, "desc", context
thread_meta.id, None, 1, "desc", context
)
tool_call = next(
(
Expand All @@ -467,29 +468,29 @@ async def _process_streaming_impl(
)
if not tool_call:
raise ValueError(
f"Last thread item in {thread.id} was not a ClientToolCallItem"
f"Last thread item in {thread_meta.id} was not a ClientToolCallItem"
)

tool_call.output = request.params.result
tool_call.status = "completed"

await self.store.save_item(thread.id, tool_call, context=context)
await self.store.save_item(thread_meta.id, tool_call, context=context)

# Safety against dangling pending tool calls if there are
# multiple in a row, which should be impossible, and
# integrations should ultimately filter out pending tool calls
# when creating input response messages.
await self._cleanup_pending_client_tool_call(thread, context)
await self._cleanup_pending_client_tool_call(thread_meta, context)

async for event in self._process_events(
thread,
thread_meta,
context,
lambda: self.respond(thread, None, context),
lambda: self.respond(thread_meta, None, context),
):
yield event

case ThreadsRetryAfterItemReq():
thread_metadata = await self.store.load_thread(
thread_meta = await self.store.load_thread(
request.params.thread_id, context=context
)

Expand All @@ -515,29 +516,29 @@ async def _process_streaming_impl(
request.params.thread_id, item.id, context=context
)
async for event in self._process_events(
thread_metadata,
thread_meta,
context,
lambda: self.respond(
thread_metadata,
thread_meta,
user_message_item,
context,
),
):
yield event
case ThreadsCustomActionReq():
thread_metadata = await self.store.load_thread(
thread_meta = await self.store.load_thread(
request.params.thread_id, context=context
)

item: ThreadItem | None = None
sender: ThreadItem | None = None
if request.params.item_id:
item = await self.store.load_item(
sender = await self.store.load_item(
request.params.thread_id,
request.params.item_id,
context=context,
)

if item and not isinstance(item, WidgetItem):
if sender and not isinstance(sender, WidgetItem):
# shouldn't happen if the caller is using the API correctly.
yield ErrorEvent(
code=ErrorCode.STREAM_ERROR,
Expand All @@ -546,12 +547,12 @@ async def _process_streaming_impl(
return

async for event in self._process_events(
thread_metadata,
thread_meta,
context,
lambda: self.action(
thread_metadata,
thread_meta,
request.params.action,
item,
sender if isinstance(sender, WidgetItem) else None,
context,
),
):
Expand Down