From eee0a2f7b6f2b34513ed04cfd5293958d30638d0 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Fri, 21 Nov 2025 21:37:04 -0800 Subject: [PATCH 1/9] _process_events handles cancelled stream; add an overridable method for custom cleanup after cancellation --- chatkit/agents.py | 3 ++ chatkit/server.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++ chatkit/store.py | 3 +- 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/chatkit/agents.py b/chatkit/agents.py index 9f6d48e..61b2790 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -691,6 +691,9 @@ async def hidden_context_to_input( """ Convert a HiddenContextItem into input item(s) to send to the model. Required to override when HiddenContextItems with non-string content are used. + + ChatKitServer may save HiddenContextItems with string content; make sure your override + can also handle HiddenContextItems with string content. """ if not isinstance(item.content, str): raise NotImplementedError( diff --git a/chatkit/server.py b/chatkit/server.py index 66d7b58..b1ba2b2 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -27,6 +27,12 @@ from .store import AttachmentStore, Store, StoreItemType, default_generate_id from .types import ( Action, + AssistantMessageContent, + AssistantMessageContentPartAdded, + AssistantMessageContentPartAnnotationAdded, + AssistantMessageContentPartDone, + AssistantMessageContentPartTextDelta, + AssistantMessageItem, AttachmentsCreateReq, AttachmentsDeleteReq, ChatKitReq, @@ -47,6 +53,7 @@ ThreadItemDoneEvent, ThreadItemRemovedEvent, ThreadItemReplacedEvent, + ThreadItemUpdate, ThreadItemUpdatedEvent, ThreadMetadata, ThreadsAddClientToolOutputReq, @@ -66,6 +73,7 @@ WidgetItem, WidgetRootUpdated, WidgetStreamingTextValueDelta, + WorkflowItem, is_streaming_req, ) from .version import __version__ @@ -308,6 +316,25 @@ def action( "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions" ) + async def handle_stream_cancelled( + self, + thread: ThreadMetadata, + pending_items: list[ThreadItem], + context: TContext, + ): + """Perform custom cleanup / stop inference when a stream is cancelled. + + Args: + thread: The thread that was being processed. + pending_items: Items that were not done streaming at cancellation time. + By default, already-streamed assistant messages, widgets, and workflows are + saved to the store during error handling prior to this method being called. + If you want to remove them from the thread, you can do so here. + (Updates you make here will not be reflected in the UI until a reload.) + context: Arbitrary per-request context provided by the caller. + """ + pass + async def process( self, request: str | bytes | bytearray, context: TContext ) -> StreamingResult | NonStreamingResult: @@ -408,6 +435,9 @@ async def _process_streaming( async for event in self._process_streaming_impl(request, context): b = self._serialize(event) yield b"data: " + b + b"\n\n" + except asyncio.CancelledError: + # Let cancellation bubble up without logging as an error. + raise except Exception: logger.exception("Error while generating streamed response") raise @@ -606,22 +636,39 @@ async def _process_events( last_thread = thread.model_copy(deep=True) + # Keep track of items that were streamed but not yet saved + # so that we can persist them when the stream is cancelled. + pending_items: dict[str, ThreadItem] = {} + try: with agents_sdk_user_agent_override(): async for event in stream(): + if isinstance(event, ThreadItemAddedEvent): + pending_items[event.item.id] = event.item + match event: case ThreadItemDoneEvent(): await self.store.add_thread_item( thread.id, event.item, context=context ) + pending_items.pop(event.item.id, None) case ThreadItemRemovedEvent(): await self.store.delete_thread_item( thread.id, event.item_id, context=context ) + pending_items.pop(event.item_id, None) case ThreadItemReplacedEvent(): await self.store.save_item( thread.id, event.item, context=context ) + pending_items.pop(event.item.id, None) + case ThreadItemUpdatedEvent(): + # Keep the pending assistant message item up to date + # so that we can persist already-streamed partial content + # if the stream is cancelled. + self._update_pending_assistant_message_items( + pending_items, event + ) # special case - don't send hidden context items back to the client should_swallow_event = isinstance( @@ -643,6 +690,15 @@ async def _process_events( last_thread = thread.model_copy(deep=True) await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) + except asyncio.CancelledError: + # When a stream is cancelled, whether it's a deliberate stop request or due to a network issue, + # save already-streamed items to the thread. + await self._persist_cancelled_stream_state(thread, pending_items, context) + # Allow custom cleanup. + await self.handle_stream_cancelled( + thread, list(pending_items.values()), context + ) + raise except CustomStreamError as e: yield ErrorEvent( code="custom", @@ -666,6 +722,80 @@ async def _process_events( await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) + async def _persist_cancelled_stream_state( + self, + thread: ThreadMetadata, + pending_items: dict[str, ThreadItem], + context: TContext, + ): + # Persist any streamed items that the UI should keep when cancellation happens mid-stream. + for item in pending_items.values(): + if isinstance(item, (AssistantMessageItem, WidgetItem, WorkflowItem)): + await self.store.add_thread_item(thread.id, item, context=context) + + await self.store.add_thread_item( + thread.id, + HiddenContextItem( + thread_id=thread.id, + created_at=datetime.now(), + id=self.store.generate_item_id("hidden_context", thread, context), + content="SYSTEM: The user cancelled the stream.", + ), + context=context, + ) + + def _apply_assistant_message_update( + self, + item: AssistantMessageItem, + update: AssistantMessageContentPartAdded + | AssistantMessageContentPartTextDelta + | AssistantMessageContentPartAnnotationAdded + | AssistantMessageContentPartDone, + ) -> AssistantMessageItem: + updated = item.model_copy(deep=True) + + # Pad the content list so the requested content_index exists before we write into it. + # (Streaming updates can arrive for an index that hasn’t been created yet) + while len(updated.content) <= update.content_index: + updated.content.append(AssistantMessageContent(text="", annotations=[])) + + match update: + case AssistantMessageContentPartAdded(): + updated.content[update.content_index] = update.content + case AssistantMessageContentPartTextDelta(): + updated.content[update.content_index].text += update.delta + case AssistantMessageContentPartAnnotationAdded(): + annotations = updated.content[update.content_index].annotations + if update.annotation_index <= len(annotations): + annotations.insert(update.annotation_index, update.annotation) + else: + annotations.append(update.annotation) + case AssistantMessageContentPartDone(): + updated.content[update.content_index] = update.content + return updated + + def _update_pending_assistant_message_items( + self, + pending_items: dict[str, ThreadItem], + event: ThreadItemUpdatedEvent, + ): + if not isinstance( + event.update, + ( + AssistantMessageContentPartAdded, + AssistantMessageContentPartTextDelta, + AssistantMessageContentPartAnnotationAdded, + AssistantMessageContentPartDone, + ), + ): + return + + updated_item = pending_items.get(event.item_id) + if updated_item and isinstance(updated_item, AssistantMessageItem): + pending_items[updated_item.id] = self._apply_assistant_message_update( + updated_item, event.update + ) + async def _build_user_message_item( self, input: UserMessageInput, thread: ThreadMetadata, context: TContext ) -> UserMessageItem: diff --git a/chatkit/store.py b/chatkit/store.py index 1c907c8..5ad6e26 100644 --- a/chatkit/store.py +++ b/chatkit/store.py @@ -15,7 +15,7 @@ TContext = TypeVar("TContext", default=Any) StoreItemType = Literal[ - "thread", "message", "tool_call", "task", "workflow", "attachment" + "thread", "message", "tool_call", "task", "workflow", "attachment", "hidden_context" ] @@ -26,6 +26,7 @@ "workflow": "wf", "task": "tsk", "attachment": "atc", + "hidden_context": "hcx", } From db36b662163e2e55ec18597210669d97190a8044 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Fri, 21 Nov 2025 21:48:58 -0800 Subject: [PATCH 2/9] Added a unit test --- tests/test_chatkit_server.py | 77 +++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 3526532..28d3eda 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -1,3 +1,4 @@ +import asyncio import sqlite3 from contextlib import contextmanager from datetime import datetime @@ -15,9 +16,15 @@ StreamingResult, stream_widget, ) -from chatkit.store import AttachmentStore, NotFoundError +from chatkit.store import ( + AttachmentStore, + NotFoundError, + StoreItemType, + default_generate_id, +) from chatkit.types import ( AssistantMessageContent, + AssistantMessageContentPartTextDelta, AssistantMessageItem, Attachment, AttachmentCreateParams, @@ -205,6 +212,74 @@ async def process_non_streaming(self, request_obj, context: Any | None = None): db.close() +async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + yield ThreadItemAddedEvent( + item=AssistantMessageItem( + id="assistant-message-pending", + created_at=datetime.now(), + content=[AssistantMessageContent(text="Hello, ")], + thread_id=thread.id, + ) + ) + yield ThreadItemUpdatedEvent( + item_id="assistant-message-pending", + update=AssistantMessageContentPartTextDelta( + content_index=0, + delta="World!", + ), + ) + raise asyncio.CancelledError() + + with make_server(responder) as server: + # Allow hidden_context id generation in this test store + original_generate_item_id = server.store.generate_item_id + + def generate_item_id( + item_type: StoreItemType, thread: ThreadMetadata, context: Any + ): + if item_type == "hidden_context": + return default_generate_id("hidden_context") + return original_generate_item_id(item_type, thread, context) + + server.store.generate_item_id = generate_item_id # type: ignore[method-assign] + + stream = await server.process( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ).model_dump_json(), + DEFAULT_CONTEXT, + ) + assert isinstance(stream, StreamingResult) + + events: list[ThreadStreamEvent] = [] + with pytest.raises(asyncio.CancelledError): # noqa: PT012 + async for raw in stream.json_events: + events.append(decode_event(raw)) + + thread = next(e.thread for e in events if e.type == "thread.created") + items = await server.store.load_thread_items( + thread.id, None, 1, "desc", DEFAULT_CONTEXT + ) + hidden_context_item = items.data[-1] + assert hidden_context_item.type == "hidden_context_item" + assert hidden_context_item.content == "SYSTEM: The user cancelled the stream." + + assistant_message_item = await server.store.load_item( + thread.id, "assistant-message-pending", DEFAULT_CONTEXT + ) + assert assistant_message_item.type == "assistant_message" + assert assistant_message_item.content[0].text == "Hello, World!" + + async def test_flows_context_to_responder(): responder_context = None add_feedback_context = None From 942620a2280b97946d2f8c7d573b227e716328e2 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Fri, 21 Nov 2025 21:49:58 -0800 Subject: [PATCH 3/9] types --- chatkit/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chatkit/server.py b/chatkit/server.py index b1ba2b2..ff988b0 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -53,7 +53,6 @@ ThreadItemDoneEvent, ThreadItemRemovedEvent, ThreadItemReplacedEvent, - ThreadItemUpdate, ThreadItemUpdatedEvent, ThreadMetadata, ThreadsAddClientToolOutputReq, From e29b09e6f4b8b1a944669225924d61d02b6c504f Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Sat, 22 Nov 2025 09:54:23 -0800 Subject: [PATCH 4/9] Don't persist empty pending items --- chatkit/server.py | 17 ++++++++++- tests/test_chatkit_server.py | 59 ++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/chatkit/server.py b/chatkit/server.py index ff988b0..2cc2760 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -729,7 +729,9 @@ async def _persist_cancelled_stream_state( ): # Persist any streamed items that the UI should keep when cancellation happens mid-stream. for item in pending_items.values(): - if isinstance(item, (AssistantMessageItem, WidgetItem, WorkflowItem)): + if isinstance( + item, (AssistantMessageItem, WidgetItem, WorkflowItem) + ) and not self._is_streamed_item_empty(item): await self.store.add_thread_item(thread.id, item, context=context) await self.store.add_thread_item( @@ -743,6 +745,19 @@ async def _persist_cancelled_stream_state( context=context, ) + def _is_streamed_item_empty( + self, item: AssistantMessageItem | WorkflowItem | WidgetItem + ) -> bool: + if isinstance(item, AssistantMessageItem): + return len(item.content) == 0 or all( + (not content.text.strip()) for content in item.content + ) + if isinstance(item, WorkflowItem): + return len(item.workflow.tasks) == 0 and item.workflow.summary is None + + # Assume all WidgetItems are not empty + return False + def _apply_assistant_message_update( self, item: AssistantMessageItem, diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 28d3eda..817cc54 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -280,6 +280,65 @@ def generate_item_id( assert assistant_message_item.content[0].text == "Hello, World!" +async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + yield ThreadItemAddedEvent( + item=AssistantMessageItem( + id="assistant-message-pending", + created_at=datetime.now(), + content=[], + thread_id=thread.id, + ) + ) + raise asyncio.CancelledError() + + with make_server(responder) as server: + original_generate_item_id = server.store.generate_item_id + + def generate_item_id( + item_type: StoreItemType, thread: ThreadMetadata, context: Any + ): + if item_type == "hidden_context": + return default_generate_id("hidden_context") + return original_generate_item_id(item_type, thread, context) + + server.store.generate_item_id = generate_item_id # type: ignore[method-assign] + + stream = await server.process( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ).model_dump_json(), + DEFAULT_CONTEXT, + ) + assert isinstance(stream, StreamingResult) + + events: list[ThreadStreamEvent] = [] + with pytest.raises(asyncio.CancelledError): # noqa: PT012 + async for raw in stream.json_events: + events.append(decode_event(raw)) + + thread = next(e.thread for e in events if e.type == "thread.created") + items = await server.store.load_thread_items( + thread.id, None, 1, "desc", DEFAULT_CONTEXT + ) + hidden_context_item = items.data[-1] + assert hidden_context_item.type == "hidden_context_item" + assert hidden_context_item.content == "SYSTEM: The user cancelled the stream." + + with pytest.raises(NotFoundError): + await server.store.load_item( + thread.id, "assistant-message-pending", DEFAULT_CONTEXT + ) + + async def test_flows_context_to_responder(): responder_context = None add_feedback_context = None From 5a7244ba1337ba55dcb1f9e4a96cc5527b00bc93 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 24 Nov 2025 14:45:26 -0800 Subject: [PATCH 5/9] use new type of hidden context item; add stream options --- chatkit/agents.py | 30 +++++++++- chatkit/server.py | 145 ++++++++++++++++++++++++---------------------- chatkit/store.py | 10 +++- chatkit/types.py | 31 +++++++++- 4 files changed, 142 insertions(+), 74 deletions(-) diff --git a/chatkit/agents.py b/chatkit/agents.py index 61b2790..d810331 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -56,6 +56,7 @@ EndOfTurnItem, FileSource, HiddenContextItem, + SDKHiddenContextItem, Task, TaskItem, ThoughtTask, @@ -691,9 +692,6 @@ async def hidden_context_to_input( """ Convert a HiddenContextItem into input item(s) to send to the model. Required to override when HiddenContextItems with non-string content are used. - - ChatKitServer may save HiddenContextItems with string content; make sure your override - can also handle HiddenContextItems with string content. """ if not isinstance(item.content, str): raise NotImplementedError( @@ -715,6 +713,29 @@ async def hidden_context_to_input( role="user", ) + async def sdk_hidden_context_to_input( + self, item: SDKHiddenContextItem + ) -> TResponseInputItem | list[TResponseInputItem] | None: + """ + Convert a SDKHiddenContextItem into input item to send to the model. + This is used by the ChatKit Python SDK for storing additional context + for internal operations; you shouldn't need to override this. + """ + text = ( + "Hidden ChatKit SDK context for the agent (not shown to the user):\n" + f"\n{item.content}\n" + ) + return Message( + type="message", + content=[ + ResponseInputTextParam( + type="input_text", + text=text, + ) + ], + role="user", + ) + async def task_to_input( self, item: TaskItem ) -> TResponseInputItem | list[TResponseInputItem] | None: @@ -951,6 +972,9 @@ async def _thread_item_to_input_item( case HiddenContextItem(): out = await self.hidden_context_to_input(item) or [] return out if isinstance(out, list) else [out] + case SDKHiddenContextItem(): + out = await self.sdk_hidden_context_to_input(item) or [] + return out if isinstance(out, list) else [out] case _: assert_never(item) diff --git a/chatkit/server.py b/chatkit/server.py index 2cc2760..5097304 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -45,7 +45,10 @@ ItemsListReq, NonStreamingReq, Page, + SDKHiddenContextItem, StreamingReq, + StreamOptions, + StreamOptionsEvent, Thread, ThreadCreatedEvent, ThreadItem, @@ -73,6 +76,8 @@ WidgetRootUpdated, WidgetStreamingTextValueDelta, WorkflowItem, + WorkflowTaskAdded, + WorkflowTaskUpdated, is_streaming_req, ) from .version import __version__ @@ -315,6 +320,15 @@ def action( "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions" ) + def get_stream_options( + self, thread: ThreadMetadata, context: TContext + ) -> StreamOptions: + """ + Return stream-level runtime options. Allows the user to cancel the stream by default. + Override this method to customize behavior. + """ + return StreamOptions(allow_cancel=True) + async def handle_stream_cancelled( self, thread: ThreadMetadata, @@ -322,17 +336,39 @@ async def handle_stream_cancelled( context: TContext, ): """Perform custom cleanup / stop inference when a stream is cancelled. + Updates you make here will not be reflected in the UI until a reload. + + The default implementation persists any non-empty pending assistant messages + to the thread but does not auto-save pending widget items or workflow items. Args: thread: The thread that was being processed. pending_items: Items that were not done streaming at cancellation time. - By default, already-streamed assistant messages, widgets, and workflows are - saved to the store during error handling prior to this method being called. - If you want to remove them from the thread, you can do so here. - (Updates you make here will not be reflected in the UI until a reload.) context: Arbitrary per-request context provided by the caller. """ - pass + pending_assistant_message_items: list[AssistantMessageItem] = [ + item for item in pending_items if isinstance(item, AssistantMessageItem) + ] + for item in pending_assistant_message_items: + is_empty = len(item.content) == 0 or all( + (not content.text.strip()) for content in item.content + ) + if not is_empty: + await self.store.add_thread_item(thread.id, item, context=context) + + # Add a hidden context item to the thread to indicate that the stream was cancelled. + # Otherwise, depending on the timing of the cancellation, subsequent responses may + # attempt to continue the cancelled response. + await self.store.add_thread_item( + thread.id, + SDKHiddenContextItem( + thread_id=thread.id, + created_at=datetime.now(), + id=self.store.generate_item_id("sdk_hidden_context", thread, context), + content="The user cancelled the stream.", + ), + context=context, + ) async def process( self, request: str | bytes | bytearray, context: TContext @@ -633,6 +669,11 @@ async def _process_events( ) -> AsyncIterator[ThreadStreamEvent]: await asyncio.sleep(0) # allow the response to start streaming + # Send initial stream options + yield StreamOptionsEvent( + stream_options=self.get_stream_options(thread, context) + ) + last_thread = thread.model_copy(deep=True) # Keep track of items that were streamed but not yet saved @@ -662,12 +703,10 @@ async def _process_events( ) pending_items.pop(event.item.id, None) case ThreadItemUpdatedEvent(): - # Keep the pending assistant message item up to date - # so that we can persist already-streamed partial content - # if the stream is cancelled. - self._update_pending_assistant_message_items( - pending_items, event - ) + # Keep pending assistant message and workflow items up to date + # so that we have a reference to the latest version of these pending items + # when the stream is cancelled. + self._update_pending_items(pending_items, event) # special case - don't send hidden context items back to the client should_swallow_event = isinstance( @@ -690,10 +729,6 @@ async def _process_events( await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) except asyncio.CancelledError: - # When a stream is cancelled, whether it's a deliberate stop request or due to a network issue, - # save already-streamed items to the thread. - await self._persist_cancelled_stream_state(thread, pending_items, context) - # Allow custom cleanup. await self.handle_stream_cancelled( thread, list(pending_items.values()), context ) @@ -721,43 +756,6 @@ async def _process_events( await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) - async def _persist_cancelled_stream_state( - self, - thread: ThreadMetadata, - pending_items: dict[str, ThreadItem], - context: TContext, - ): - # Persist any streamed items that the UI should keep when cancellation happens mid-stream. - for item in pending_items.values(): - if isinstance( - item, (AssistantMessageItem, WidgetItem, WorkflowItem) - ) and not self._is_streamed_item_empty(item): - await self.store.add_thread_item(thread.id, item, context=context) - - await self.store.add_thread_item( - thread.id, - HiddenContextItem( - thread_id=thread.id, - created_at=datetime.now(), - id=self.store.generate_item_id("hidden_context", thread, context), - content="SYSTEM: The user cancelled the stream.", - ), - context=context, - ) - - def _is_streamed_item_empty( - self, item: AssistantMessageItem | WorkflowItem | WidgetItem - ) -> bool: - if isinstance(item, AssistantMessageItem): - return len(item.content) == 0 or all( - (not content.text.strip()) for content in item.content - ) - if isinstance(item, WorkflowItem): - return len(item.workflow.tasks) == 0 and item.workflow.summary is None - - # Assume all WidgetItems are not empty - return False - def _apply_assistant_message_update( self, item: AssistantMessageItem, @@ -788,27 +786,38 @@ def _apply_assistant_message_update( updated.content[update.content_index] = update.content return updated - def _update_pending_assistant_message_items( + def _update_pending_items( self, pending_items: dict[str, ThreadItem], event: ThreadItemUpdatedEvent, ): - if not isinstance( - event.update, - ( - AssistantMessageContentPartAdded, - AssistantMessageContentPartTextDelta, - AssistantMessageContentPartAnnotationAdded, - AssistantMessageContentPartDone, - ), - ): - return - updated_item = pending_items.get(event.item_id) - if updated_item and isinstance(updated_item, AssistantMessageItem): - pending_items[updated_item.id] = self._apply_assistant_message_update( - updated_item, event.update - ) + update = event.update + match updated_item: + case AssistantMessageItem(): + if isinstance( + update, + ( + AssistantMessageContentPartAdded, + AssistantMessageContentPartTextDelta, + AssistantMessageContentPartAnnotationAdded, + AssistantMessageContentPartDone, + ), + ): + pending_items[updated_item.id] = ( + self._apply_assistant_message_update(updated_item, update) + ) + case WorkflowItem(): + if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)): + match update: + case WorkflowTaskUpdated(): + updated_item.workflow.tasks[update.task_index] = update.task + case WorkflowTaskAdded(): + updated_item.workflow.tasks.append(update.task) + + pending_items[updated_item.id] = updated_item + case _: + pass async def _build_user_message_item( self, input: UserMessageInput, thread: ThreadMetadata, context: TContext diff --git a/chatkit/store.py b/chatkit/store.py index 5ad6e26..90be0b0 100644 --- a/chatkit/store.py +++ b/chatkit/store.py @@ -15,7 +15,13 @@ TContext = TypeVar("TContext", default=Any) StoreItemType = Literal[ - "thread", "message", "tool_call", "task", "workflow", "attachment", "hidden_context" + "thread", + "message", + "tool_call", + "task", + "workflow", + "attachment", + "sdk_hidden_context", ] @@ -26,7 +32,7 @@ "workflow": "wf", "task": "tsk", "attachment": "atc", - "hidden_context": "hcx", + "sdk_hidden_context": "shcx", } diff --git a/chatkit/types.py b/chatkit/types.py index eb7b7f3..3ba2a1c 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -317,6 +317,20 @@ class ThreadItemReplacedEvent(BaseModel): item: ThreadItem +class StreamOptions(BaseModel): + """Settings that control runtime stream behavior.""" + + allow_cancel: bool + """Allow the client to request cancellation mid-stream.""" + + +class StreamOptionsEvent(BaseModel): + """Event emitted to set stream options at runtime.""" + + type: Literal["stream_options"] = "stream_options" + stream_options: StreamOptions + + class ProgressUpdateEvent(BaseModel): """Event providing incremental progress from the assistant.""" @@ -354,6 +368,7 @@ class NoticeEvent(BaseModel): | ThreadItemUpdated | ThreadItemRemovedEvent | ThreadItemReplacedEvent + | StreamOptionsEvent | ProgressUpdateEvent | ErrorEvent | NoticeEvent, @@ -576,12 +591,25 @@ class EndOfTurnItem(ThreadItemBase): class HiddenContextItem(ThreadItemBase): - """HiddenContext is never sent to the client. It's not officially part of ChatKit. It is only used internally to store additional context in a specific place in the thread.""" + """ + HiddenContext is never sent to the client. It's not officially part of ChatKit.js. + It is only used internally to store additional context in a specific place in the thread. + """ type: Literal["hidden_context_item"] = "hidden_context_item" content: Any +class SDKHiddenContextItem(ThreadItemBase): + """ + Hidden context that is used by the ChatKit Python SDK for storing additional context + for internal operations. + """ + + type: Literal["sdk_hidden_context"] = "sdk_hidden_context" + content: str + + ThreadItem = Annotated[ UserMessageItem | AssistantMessageItem @@ -590,6 +618,7 @@ class HiddenContextItem(ThreadItemBase): | WorkflowItem | TaskItem | HiddenContextItem + | SDKHiddenContextItem | EndOfTurnItem, Field(discriminator="type"), ] From 4a4b81b082c0f5f4ea503656cb1207224e76b3f0 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 24 Nov 2025 15:22:37 -0800 Subject: [PATCH 6/9] update tests --- tests/test_chatkit_server.py | 73 ++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 817cc54..3f1380f 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -240,8 +240,8 @@ async def responder( def generate_item_id( item_type: StoreItemType, thread: ThreadMetadata, context: Any ): - if item_type == "hidden_context": - return default_generate_id("hidden_context") + if item_type == "sdk_hidden_context": + return default_generate_id("sdk_hidden_context") return original_generate_item_id(item_type, thread, context) server.store.generate_item_id = generate_item_id # type: ignore[method-assign] @@ -270,8 +270,8 @@ def generate_item_id( thread.id, None, 1, "desc", DEFAULT_CONTEXT ) hidden_context_item = items.data[-1] - assert hidden_context_item.type == "hidden_context_item" - assert hidden_context_item.content == "SYSTEM: The user cancelled the stream." + assert hidden_context_item.type == "sdk_hidden_context" + assert hidden_context_item.content == "The user cancelled the stream." assistant_message_item = await server.store.load_item( thread.id, "assistant-message-pending", DEFAULT_CONTEXT @@ -300,8 +300,8 @@ async def responder( def generate_item_id( item_type: StoreItemType, thread: ThreadMetadata, context: Any ): - if item_type == "hidden_context": - return default_generate_id("hidden_context") + if item_type == "sdk_hidden_context": + return default_generate_id("sdk_hidden_context") return original_generate_item_id(item_type, thread, context) server.store.generate_item_id = generate_item_id # type: ignore[method-assign] @@ -330,8 +330,8 @@ def generate_item_id( thread.id, None, 1, "desc", DEFAULT_CONTEXT ) hidden_context_item = items.data[-1] - assert hidden_context_item.type == "hidden_context_item" - assert hidden_context_item.content == "SYSTEM: The user cancelled the stream." + assert hidden_context_item.type == "sdk_hidden_context" + assert hidden_context_item.content == "The user cancelled the stream." with pytest.raises(NotFoundError): await server.store.load_item( @@ -643,19 +643,21 @@ async def responder( ) ) - assert len(events) == 3 + assert len(events) == 4 assert events[0].type == "thread.created" thread = events[0].thread assert events[1].type == "thread.item.done" assert events[1].item.type == "user_message" - assert events[2].type == "thread.item.done" - assert events[2].item.type == "client_tool_call" - assert events[2].item.id == "msg_1" - assert events[2].item.name == "tool_call_1" - assert events[2].item.arguments == {"arg1": "val1", "arg2": False} - assert events[2].item.call_id == "tool_call_1" + assert events[2].type == "stream_options" + + assert events[3].type == "thread.item.done" + assert events[3].item.type == "client_tool_call" + assert events[3].item.id == "msg_1" + assert events[3].item.name == "tool_call_1" + assert events[3].item.arguments == {"arg1": "val1", "arg2": False} + assert events[3].item.call_id == "tool_call_1" events = await server.process_streaming( ThreadsAddClientToolOutputReq( @@ -666,9 +668,10 @@ async def responder( ) ) - assert len(events) == 1 - assert events[0].type == "thread.item.done" - assert events[0].item.type == "assistant_message" + assert len(events) == 2 + assert events[0].type == "stream_options" + assert events[1].type == "thread.item.done" + assert events[1].item.type == "assistant_message" async def test_removes_tool_call_if_no_output_provided(): @@ -764,11 +767,12 @@ async def responder( ) ) - assert len(events) == 4 + assert len(events) == 5 assert events[0].type == "thread.created" assert events[1].type == "thread.item.done" - assert events[2].type == "progress_update" - assert events[3].type == "thread.item.done" + assert events[2].type == "stream_options" + assert events[3].type == "progress_update" + assert events[4].type == "thread.item.done" async def test_list_threads_response(): @@ -939,11 +943,12 @@ async def action( widget_item, ) - assert len(events) == 1 - assert events[0].type == "thread.item.updated" - assert isinstance(events[0], ThreadItemUpdatedEvent) - assert events[0].update.type == "widget.root.updated" - assert events[0].update.widget == Card(children=[Text(value="Email sent!")]) + assert len(events) == 2 + assert events[0].type == "stream_options" + assert events[1].type == "thread.item.updated" + assert isinstance(events[1], ThreadItemUpdatedEvent) + assert events[1].update.type == "widget.root.updated" + assert events[1].update.widget == Card(children=[Text(value="Email sent!")]) async def test_add_feedback(): @@ -1603,11 +1608,12 @@ async def responder( ) # Verify the retry generated new response - assert len(retry_events) == 1 - assert retry_events[0].type == "thread.item.done" - assert retry_events[0].item.type == "assistant_message" - assert retry_events[0].item.content[0].type == "output_text" - assert retry_events[0].item.content[0].text == "Retried response" + assert len(retry_events) == 2 + assert retry_events[0].type == "stream_options" + assert retry_events[1].type == "thread.item.done" + assert retry_events[1].item.type == "assistant_message" + assert retry_events[1].item.content[0].type == "output_text" + assert retry_events[1].item.content[0].text == "Retried response" # Verify the responder was called twice with the same user message assert len(responder_calls) == 2 @@ -1750,8 +1756,9 @@ async def responder( ) ) - assert len(retry_events) == 1 - assert retry_events[0].type == "thread.item.done" + assert len(retry_events) == 2 + assert retry_events[0].type == "stream_options" + assert retry_events[1].type == "thread.item.done" # Verify retry used the second user message assert len(responder_calls) == 3 # Original 2 + 1 retry From 1860286bf7fca86e6b70e7e87ef4681a0f482145 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 24 Nov 2025 15:34:50 -0800 Subject: [PATCH 7/9] filter out new hidden context item type --- chatkit/server.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/chatkit/server.py b/chatkit/server.py index 5097304..a63bff7 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -441,11 +441,11 @@ async def _process_non_streaming( after=items_list_params.after, context=context, ) - # filter out HiddenContextItems + # filter out hidden context items items.data = [ item for item in items.data - if not isinstance(item, HiddenContextItem) + if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem)) ] return self._serialize(items) case ThreadsUpdateReq(): @@ -711,7 +711,9 @@ async def _process_events( # special case - don't send hidden context items back to the client should_swallow_event = isinstance( event, ThreadItemDoneEvent - ) and isinstance(event.item, HiddenContextItem) + ) and isinstance( + event.item, (HiddenContextItem, SDKHiddenContextItem) + ) if not should_swallow_event: yield event @@ -867,7 +869,7 @@ def _serialize(self, obj: BaseModel) -> bytes: def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread: def is_hidden(item: ThreadItem) -> bool: - return isinstance(item, HiddenContextItem) + return isinstance(item, (HiddenContextItem, SDKHiddenContextItem)) items = thread.items if isinstance(thread, Thread) else Page() items.data = [item for item in items.data if not is_hidden(item)] From 6df9ee60da4de3f063e752e60bd3ceef920f4e78 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 24 Nov 2025 15:48:03 -0800 Subject: [PATCH 8/9] Make hidden context after cancellation more explicit --- chatkit/server.py | 2 +- tests/test_chatkit_server.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/chatkit/server.py b/chatkit/server.py index a63bff7..f038dfa 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -365,7 +365,7 @@ async def handle_stream_cancelled( thread_id=thread.id, created_at=datetime.now(), id=self.store.generate_item_id("sdk_hidden_context", thread, context), - content="The user cancelled the stream.", + content="The user cancelled the stream. Stop responding to the prior request.", ), context=context, ) diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 3f1380f..d7ac167 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -271,7 +271,10 @@ def generate_item_id( ) hidden_context_item = items.data[-1] assert hidden_context_item.type == "sdk_hidden_context" - assert hidden_context_item.content == "The user cancelled the stream." + assert ( + hidden_context_item.content + == "The user cancelled the stream. Stop responding to the prior request." + ) assistant_message_item = await server.store.load_item( thread.id, "assistant-message-pending", DEFAULT_CONTEXT @@ -331,7 +334,10 @@ def generate_item_id( ) hidden_context_item = items.data[-1] assert hidden_context_item.type == "sdk_hidden_context" - assert hidden_context_item.content == "The user cancelled the stream." + assert ( + hidden_context_item.content + == "The user cancelled the stream. Stop responding to the prior request." + ) with pytest.raises(NotFoundError): await server.store.load_item( From 7ae5852b5c01d76f1d1e72f38405144683e08277 Mon Sep 17 00:00:00 2001 From: Jiwon Kim Date: Mon, 24 Nov 2025 16:31:12 -0800 Subject: [PATCH 9/9] Make default converter more generic --- chatkit/agents.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chatkit/agents.py b/chatkit/agents.py index d810331..b56c657 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -719,11 +719,12 @@ async def sdk_hidden_context_to_input( """ Convert a SDKHiddenContextItem into input item to send to the model. This is used by the ChatKit Python SDK for storing additional context - for internal operations; you shouldn't need to override this. + for internal operations. + Override if you want to wrap the content in a different format. """ text = ( - "Hidden ChatKit SDK context for the agent (not shown to the user):\n" - f"\n{item.content}\n" + "Hidden context for the agent (not shown to the user):\n" + f"\n{item.content}\n" ) return Message( type="message",