From 1e50d7a47debe06f0509b93aa9cd44599158f356 Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Tue, 7 Oct 2025 01:27:16 -0700 Subject: [PATCH] initial commit --- chatkit/actions.py | 2 +- chatkit/agents.py | 101 +++++++++++++++++++++++---------------------- chatkit/server.py | 57 ++++++++++++------------- 3 files changed, 81 insertions(+), 79 deletions(-) diff --git a/chatkit/actions.py b/chatkit/actions.py index 689a8f1..1658a14 100644 --- a/chatkit/actions.py +++ b/chatkit/actions.py @@ -23,7 +23,7 @@ class ActionConfig(BaseModel): class Action(BaseModel, Generic[TType, TPayload]): - type: TType = Field(default=TType, frozen=True) # pyright: ignore + type: TType = Field(frozen=True) # pyright: ignore payload: TPayload @classmethod diff --git a/chatkit/agents.py b/chatkit/agents.py index 349904a..df044f7 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -386,41 +386,41 @@ def end_workflow(item: WorkflowItem): async for event in _merge_generators(result.stream_events(), queue_iterator): # Events emitted from agent context helpers if isinstance(event, _EventWrapper): - event = event.event + unwrapped_event = event.event if ( - event.type == "thread.item.added" - or event.type == "thread.item.done" + unwrapped_event.type == "thread.item.added" + or unwrapped_event.type == "thread.item.done" ): # End the current workflow if visual item is added after it if ( ctx.workflow_item - and ctx.workflow_item.id != event.item.id - and event.item.type != "client_tool_call" - and event.item.type != "hidden_context_item" + and ctx.workflow_item.id != unwrapped_event.item.id + and unwrapped_event.item.type != "client_tool_call" + and unwrapped_event.item.type != "hidden_context_item" ): yield end_workflow(ctx.workflow_item) # track the current workflow if one is added if ( - event.type == "thread.item.added" - and event.item.type == "workflow" + unwrapped_event.type == "thread.item.added" + and unwrapped_event.item.type == "workflow" ): - ctx.workflow_item = event.item + ctx.workflow_item = unwrapped_event.item # track integration produced items so we can clean them up if # there is a guardrail tripwire - produced_items.add(event.item.id) - yield event + produced_items.add(unwrapped_event.item.id) + yield unwrapped_event continue if event.type == "run_item_stream_event": - event = event.item + run_item = event.item if ( - event.type == "tool_call_item" - and event.raw_item.type == "function_call" + run_item.type == "tool_call_item" + and run_item.raw_item.type == "function_call" ): - current_tool_call = event.raw_item.call_id - current_item_id = event.raw_item.id + current_tool_call = run_item.raw_item.call_id + current_item_id = run_item.raw_item.id assert current_item_id produced_items.add(current_item_id) continue @@ -430,42 +430,42 @@ def end_workflow(item: WorkflowItem): continue # Handle Responses events - event = event.data - if event.type == "response.content_part.added": - if event.part.type == "reasoning_text": + response_event = event.data + if response_event.type == "response.content_part.added": + if response_event.part.type == "reasoning_text": continue - content = _convert_content(event.part) + content = _convert_content(response_event.part) yield ThreadItemUpdated( - item_id=event.item_id, + item_id=response_event.item_id, update=AssistantMessageContentPartAdded( - content_index=event.content_index, + content_index=response_event.content_index, content=content, ), ) - elif event.type == "response.output_text.delta": + elif response_event.type == "response.output_text.delta": yield ThreadItemUpdated( - item_id=event.item_id, + item_id=response_event.item_id, update=AssistantMessageContentPartTextDelta( - content_index=event.content_index, - delta=event.delta, + content_index=response_event.content_index, + delta=response_event.delta, ), ) - elif event.type == "response.output_text.done": + elif response_event.type == "response.output_text.done": yield ThreadItemUpdated( - item_id=event.item_id, + item_id=response_event.item_id, update=AssistantMessageContentPartDone( - content_index=event.content_index, + content_index=response_event.content_index, content=AssistantMessageContent( - text=event.text, + text=response_event.text, annotations=[], ), ), ) - elif event.type == "response.output_text.annotation.added": + elif response_event.type == "response.output_text.annotation.added": # Ignore annotation-added events; annotations are reflected in the final item content. continue - elif event.type == "response.output_item.added": - item = event.item + elif response_event.type == "response.output_item.added": + item = response_event.item if item.type == "reasoning" and not ctx.workflow_item: ctx.workflow_item = WorkflowItem( id=ctx.generate_id("workflow"), @@ -488,7 +488,7 @@ def end_workflow(item: WorkflowItem): created_at=datetime.now(), ), ) - elif event.type == "response.reasoning_summary_text.delta": + elif response_event.type == "response.reasoning_summary_text.delta": if not ctx.workflow_item: continue @@ -498,9 +498,9 @@ def end_workflow(item: WorkflowItem): and len(ctx.workflow_item.workflow.tasks) == 0 ): streaming_thought = StreamingThoughtTracker( - item_id=event.item_id, - index=event.summary_index, - task=ThoughtTask(content=event.delta), + item_id=response_event.item_id, + index=response_event.summary_index, + task=ThoughtTask(content=response_event.delta), ) ctx.workflow_item.workflow.tasks.append(streaming_thought.task) yield ThreadItemUpdated( @@ -513,10 +513,10 @@ def end_workflow(item: WorkflowItem): elif ( streaming_thought and streaming_thought.task in ctx.workflow_item.workflow.tasks - and event.item_id == streaming_thought.item_id - and event.summary_index == streaming_thought.index + and response_event.item_id == streaming_thought.item_id + and response_event.summary_index == streaming_thought.index ): - streaming_thought.task.content += event.delta + streaming_thought.task.content += response_event.delta yield ThreadItemUpdated( item_id=ctx.workflow_item.id, update=WorkflowTaskUpdated( @@ -526,23 +526,24 @@ def end_workflow(item: WorkflowItem): ), ), ) - elif event.type == "response.reasoning_summary_text.done": + elif response_event.type == "response.reasoning_summary_text.done": if ctx.workflow_item: + update: WorkflowTaskUpdated | WorkflowTaskAdded if ( streaming_thought and streaming_thought.task in ctx.workflow_item.workflow.tasks - and event.item_id == streaming_thought.item_id - and event.summary_index == streaming_thought.index + and response_event.item_id == streaming_thought.item_id + and response_event.summary_index == streaming_thought.index ): task = streaming_thought.task - task.content = event.text + task.content = response_event.text streaming_thought = None update = WorkflowTaskUpdated( task=task, task_index=ctx.workflow_item.workflow.tasks.index(task), ) else: - task = ThoughtTask(content=event.text) + task = ThoughtTask(content=response_event.text) ctx.workflow_item.workflow.tasks.append(task) update = WorkflowTaskAdded( task=task, @@ -552,8 +553,8 @@ def end_workflow(item: WorkflowItem): item_id=ctx.workflow_item.id, update=update, ) - elif event.type == "response.output_item.done": - item = event.item + elif response_event.type == "response.output_item.done": + item = response_event.item if item.type == "message": produced_items.add(item.id) yield ThreadItemDoneEvent( @@ -614,7 +615,7 @@ def end_workflow(item: WorkflowItem): async def accumulate_text( events: AsyncIterator[StreamEvent], base_widget: TWidget, -) -> AsyncIterator[TWidget]: +) -> AsyncIterator[Markdown | Text]: text = "" yield base_widget async for event in events: @@ -698,7 +699,7 @@ def workflow_to_input( Convert a TaskItem into input item(s) to send to the model. Returns WorkflowItem.response_items by default. """ - messages = [] + messages: list[TResponseInputItem] = [] for task in item.workflow.tasks: if task.type != "custom" or (not task.title and not task.content): continue @@ -719,7 +720,7 @@ def workflow_to_input( role="user", ) ) - return messages + return messages if messages else None def widget_to_input( self, item: WidgetItem diff --git a/chatkit/server.py b/chatkit/server.py index fa405a2..fb92d9c 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -323,7 +323,7 @@ async def _process_non_streaming( return self._serialize(self._to_thread_response(thread)) case ThreadsListReq(): params = request.params - threads = await self.store.load_threads( + thread_meta_page = await self.store.load_threads( limit=params.limit or DEFAULT_PAGE_SIZE, after=params.after, order=params.order, @@ -331,10 +331,11 @@ async def _process_non_streaming( ) return self._serialize( Page( - has_more=threads.has_more, - after=threads.after, + has_more=thread_meta_page.has_more, + after=thread_meta_page.after, data=[ - self._to_thread_response(thread) for thread in threads.data + self._to_thread_response(thread_meta) + for thread_meta in thread_meta_page.data ], ) ) @@ -378,12 +379,12 @@ async def _process_non_streaming( ] return self._serialize(items) case ThreadsUpdateReq(): - thread = await self.store.load_thread( + thread_meta = await self.store.load_thread( request.params.thread_id, context=context ) - thread.title = request.params.title - await self.store.save_thread(thread, context=context) - return self._serialize(self._to_thread_response(thread)) + thread_meta.title = request.params.title + await self.store.save_thread(thread_meta, context=context) + return self._serialize(self._to_thread_response(thread_meta)) case ThreadsDeleteReq(): await self.store.delete_thread( request.params.thread_id, context=context @@ -426,25 +427,25 @@ async def _process_streaming_impl( yield event case ThreadsAddUserMessageReq(): - thread = await self.store.load_thread( + thread_meta = await self.store.load_thread( request.params.thread_id, context=context ) user_message = await self._build_user_message_item( - request.params.input, thread, context + request.params.input, thread_meta, context ) async for event in self._process_new_thread_item_respond( - thread, + thread_meta, user_message, context, ): yield event case ThreadsAddClientToolOutputReq(): - thread = await self.store.load_thread( + thread_meta = await self.store.load_thread( request.params.thread_id, context=context ) items = await self.store.load_thread_items( - thread.id, None, 1, "desc", context + thread_meta.id, None, 1, "desc", context ) tool_call = next( ( @@ -457,29 +458,29 @@ async def _process_streaming_impl( ) if not tool_call: raise ValueError( - f"Last thread item in {thread.id} was not a ClientToolCallItem" + f"Last thread item in {thread_meta.id} was not a ClientToolCallItem" ) tool_call.output = request.params.result tool_call.status = "completed" - await self.store.save_item(thread.id, tool_call, context=context) + await self.store.save_item(thread_meta.id, tool_call, context=context) # Safety against dangling pending tool calls if there are # multiple in a row, which should be impossible, and # integrations should ultimately filter out pending tool calls # when creating input response messages. - await self._cleanup_pending_client_tool_call(thread, context) + await self._cleanup_pending_client_tool_call(thread_meta, context) async for event in self._process_events( - thread, + thread_meta, context, - lambda: self.respond(thread, None, context), + lambda: self.respond(thread_meta, None, context), ): yield event case ThreadsRetryAfterItemReq(): - thread_metadata = await self.store.load_thread( + thread_meta = await self.store.load_thread( request.params.thread_id, context=context ) @@ -505,29 +506,29 @@ async def _process_streaming_impl( request.params.thread_id, item.id, context=context ) async for event in self._process_events( - thread_metadata, + thread_meta, context, lambda: self.respond( - thread_metadata, + thread_meta, user_message_item, context, ), ): yield event case ThreadsCustomActionReq(): - thread_metadata = await self.store.load_thread( + thread_meta = await self.store.load_thread( request.params.thread_id, context=context ) - item: ThreadItem | None = None + sender: ThreadItem | None = None if request.params.item_id: - item = await self.store.load_item( + sender = await self.store.load_item( request.params.thread_id, request.params.item_id, context=context, ) - if item and not isinstance(item, WidgetItem): + if sender and not isinstance(sender, WidgetItem): # shouldn't happen if the caller is using the API correctly. yield ErrorEvent( code=ErrorCode.STREAM_ERROR, @@ -536,12 +537,12 @@ async def _process_streaming_impl( return async for event in self._process_events( - thread_metadata, + thread_meta, context, lambda: self.action( - thread_metadata, + thread_meta, request.params.action, - item, + sender if isinstance(sender, WidgetItem) else None, context, ), ):