diff --git a/chatkit/agents.py b/chatkit/agents.py index 9f6d48e..b56c657 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -56,6 +56,7 @@ EndOfTurnItem, FileSource, HiddenContextItem, + SDKHiddenContextItem, Task, TaskItem, ThoughtTask, @@ -712,6 +713,30 @@ async def hidden_context_to_input( role="user", ) + async def sdk_hidden_context_to_input( + self, item: SDKHiddenContextItem + ) -> TResponseInputItem | list[TResponseInputItem] | None: + """ + Convert a SDKHiddenContextItem into input item to send to the model. + This is used by the ChatKit Python SDK for storing additional context + for internal operations. + Override if you want to wrap the content in a different format. + """ + text = ( + "Hidden context for the agent (not shown to the user):\n" + f"\n{item.content}\n" + ) + return Message( + type="message", + content=[ + ResponseInputTextParam( + type="input_text", + text=text, + ) + ], + role="user", + ) + async def task_to_input( self, item: TaskItem ) -> TResponseInputItem | list[TResponseInputItem] | None: @@ -948,6 +973,9 @@ async def _thread_item_to_input_item( case HiddenContextItem(): out = await self.hidden_context_to_input(item) or [] return out if isinstance(out, list) else [out] + case SDKHiddenContextItem(): + out = await self.sdk_hidden_context_to_input(item) or [] + return out if isinstance(out, list) else [out] case _: assert_never(item) diff --git a/chatkit/server.py b/chatkit/server.py index 66d7b58..f038dfa 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -27,6 +27,12 @@ from .store import AttachmentStore, Store, StoreItemType, default_generate_id from .types import ( Action, + AssistantMessageContent, + AssistantMessageContentPartAdded, + AssistantMessageContentPartAnnotationAdded, + AssistantMessageContentPartDone, + AssistantMessageContentPartTextDelta, + AssistantMessageItem, AttachmentsCreateReq, AttachmentsDeleteReq, ChatKitReq, @@ -39,7 +45,10 @@ ItemsListReq, NonStreamingReq, Page, + SDKHiddenContextItem, StreamingReq, + StreamOptions, + StreamOptionsEvent, Thread, ThreadCreatedEvent, ThreadItem, @@ -66,6 +75,9 @@ WidgetItem, WidgetRootUpdated, WidgetStreamingTextValueDelta, + WorkflowItem, + WorkflowTaskAdded, + WorkflowTaskUpdated, is_streaming_req, ) from .version import __version__ @@ -308,6 +320,56 @@ def action( "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions" ) + def get_stream_options( + self, thread: ThreadMetadata, context: TContext + ) -> StreamOptions: + """ + Return stream-level runtime options. Allows the user to cancel the stream by default. + Override this method to customize behavior. + """ + return StreamOptions(allow_cancel=True) + + async def handle_stream_cancelled( + self, + thread: ThreadMetadata, + pending_items: list[ThreadItem], + context: TContext, + ): + """Perform custom cleanup / stop inference when a stream is cancelled. + Updates you make here will not be reflected in the UI until a reload. + + The default implementation persists any non-empty pending assistant messages + to the thread but does not auto-save pending widget items or workflow items. + + Args: + thread: The thread that was being processed. + pending_items: Items that were not done streaming at cancellation time. + context: Arbitrary per-request context provided by the caller. 
+ """ + pending_assistant_message_items: list[AssistantMessageItem] = [ + item for item in pending_items if isinstance(item, AssistantMessageItem) + ] + for item in pending_assistant_message_items: + is_empty = len(item.content) == 0 or all( + (not content.text.strip()) for content in item.content + ) + if not is_empty: + await self.store.add_thread_item(thread.id, item, context=context) + + # Add a hidden context item to the thread to indicate that the stream was cancelled. + # Otherwise, depending on the timing of the cancellation, subsequent responses may + # attempt to continue the cancelled response. + await self.store.add_thread_item( + thread.id, + SDKHiddenContextItem( + thread_id=thread.id, + created_at=datetime.now(), + id=self.store.generate_item_id("sdk_hidden_context", thread, context), + content="The user cancelled the stream. Stop responding to the prior request.", + ), + context=context, + ) + async def process( self, request: str | bytes | bytearray, context: TContext ) -> StreamingResult | NonStreamingResult: @@ -379,11 +441,11 @@ async def _process_non_streaming( after=items_list_params.after, context=context, ) - # filter out HiddenContextItems + # filter out hidden context items items.data = [ item for item in items.data - if not isinstance(item, HiddenContextItem) + if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem)) ] return self._serialize(items) case ThreadsUpdateReq(): @@ -408,6 +470,9 @@ async def _process_streaming( async for event in self._process_streaming_impl(request, context): b = self._serialize(event) yield b"data: " + b + b"\n\n" + except asyncio.CancelledError: + # Let cancellation bubble up without logging as an error. + raise except Exception: logger.exception("Error while generating streamed response") raise @@ -604,29 +669,51 @@ async def _process_events( ) -> AsyncIterator[ThreadStreamEvent]: await asyncio.sleep(0) # allow the response to start streaming + # Send initial stream options + yield StreamOptionsEvent( + stream_options=self.get_stream_options(thread, context) + ) + last_thread = thread.model_copy(deep=True) + # Keep track of items that were streamed but not yet saved + # so that we can persist them when the stream is cancelled. + pending_items: dict[str, ThreadItem] = {} + try: with agents_sdk_user_agent_override(): async for event in stream(): + if isinstance(event, ThreadItemAddedEvent): + pending_items[event.item.id] = event.item + match event: case ThreadItemDoneEvent(): await self.store.add_thread_item( thread.id, event.item, context=context ) + pending_items.pop(event.item.id, None) case ThreadItemRemovedEvent(): await self.store.delete_thread_item( thread.id, event.item_id, context=context ) + pending_items.pop(event.item_id, None) case ThreadItemReplacedEvent(): await self.store.save_item( thread.id, event.item, context=context ) + pending_items.pop(event.item.id, None) + case ThreadItemUpdatedEvent(): + # Keep pending assistant message and workflow items up to date + # so that we have a reference to the latest version of these pending items + # when the stream is cancelled. 
+ self._update_pending_items(pending_items, event) # special case - don't send hidden context items back to the client should_swallow_event = isinstance( event, ThreadItemDoneEvent - ) and isinstance(event.item, HiddenContextItem) + ) and isinstance( + event.item, (HiddenContextItem, SDKHiddenContextItem) + ) if not should_swallow_event: yield event @@ -643,6 +730,11 @@ async def _process_events( last_thread = thread.model_copy(deep=True) await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) + except asyncio.CancelledError: + await self.handle_stream_cancelled( + thread, list(pending_items.values()), context + ) + raise except CustomStreamError as e: yield ErrorEvent( code="custom", @@ -666,6 +758,69 @@ async def _process_events( await self.store.save_thread(thread, context=context) yield ThreadUpdatedEvent(thread=self._to_thread_response(thread)) + def _apply_assistant_message_update( + self, + item: AssistantMessageItem, + update: AssistantMessageContentPartAdded + | AssistantMessageContentPartTextDelta + | AssistantMessageContentPartAnnotationAdded + | AssistantMessageContentPartDone, + ) -> AssistantMessageItem: + updated = item.model_copy(deep=True) + + # Pad the content list so the requested content_index exists before we write into it. + # (Streaming updates can arrive for an index that hasn’t been created yet) + while len(updated.content) <= update.content_index: + updated.content.append(AssistantMessageContent(text="", annotations=[])) + + match update: + case AssistantMessageContentPartAdded(): + updated.content[update.content_index] = update.content + case AssistantMessageContentPartTextDelta(): + updated.content[update.content_index].text += update.delta + case AssistantMessageContentPartAnnotationAdded(): + annotations = updated.content[update.content_index].annotations + if update.annotation_index <= len(annotations): + annotations.insert(update.annotation_index, update.annotation) + else: + annotations.append(update.annotation) + case AssistantMessageContentPartDone(): + updated.content[update.content_index] = update.content + return updated + + def _update_pending_items( + self, + pending_items: dict[str, ThreadItem], + event: ThreadItemUpdatedEvent, + ): + updated_item = pending_items.get(event.item_id) + update = event.update + match updated_item: + case AssistantMessageItem(): + if isinstance( + update, + ( + AssistantMessageContentPartAdded, + AssistantMessageContentPartTextDelta, + AssistantMessageContentPartAnnotationAdded, + AssistantMessageContentPartDone, + ), + ): + pending_items[updated_item.id] = ( + self._apply_assistant_message_update(updated_item, update) + ) + case WorkflowItem(): + if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)): + match update: + case WorkflowTaskUpdated(): + updated_item.workflow.tasks[update.task_index] = update.task + case WorkflowTaskAdded(): + updated_item.workflow.tasks.append(update.task) + + pending_items[updated_item.id] = updated_item + case _: + pass + async def _build_user_message_item( self, input: UserMessageInput, thread: ThreadMetadata, context: TContext ) -> UserMessageItem: @@ -714,7 +869,7 @@ def _serialize(self, obj: BaseModel) -> bytes: def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread: def is_hidden(item: ThreadItem) -> bool: - return isinstance(item, HiddenContextItem) + return isinstance(item, (HiddenContextItem, SDKHiddenContextItem)) items = thread.items if isinstance(thread, Thread) else Page() items.data = 
[item for item in items.data if not is_hidden(item)] diff --git a/chatkit/store.py b/chatkit/store.py index 1c907c8..90be0b0 100644 --- a/chatkit/store.py +++ b/chatkit/store.py @@ -15,7 +15,13 @@ TContext = TypeVar("TContext", default=Any) StoreItemType = Literal[ - "thread", "message", "tool_call", "task", "workflow", "attachment" + "thread", + "message", + "tool_call", + "task", + "workflow", + "attachment", + "sdk_hidden_context", ] @@ -26,6 +32,7 @@ "workflow": "wf", "task": "tsk", "attachment": "atc", + "sdk_hidden_context": "shcx", } diff --git a/chatkit/types.py b/chatkit/types.py index eb7b7f3..3ba2a1c 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -317,6 +317,20 @@ class ThreadItemReplacedEvent(BaseModel): item: ThreadItem +class StreamOptions(BaseModel): + """Settings that control runtime stream behavior.""" + + allow_cancel: bool + """Allow the client to request cancellation mid-stream.""" + + +class StreamOptionsEvent(BaseModel): + """Event emitted to set stream options at runtime.""" + + type: Literal["stream_options"] = "stream_options" + stream_options: StreamOptions + + class ProgressUpdateEvent(BaseModel): """Event providing incremental progress from the assistant.""" @@ -354,6 +368,7 @@ class NoticeEvent(BaseModel): | ThreadItemUpdated | ThreadItemRemovedEvent | ThreadItemReplacedEvent + | StreamOptionsEvent | ProgressUpdateEvent | ErrorEvent | NoticeEvent, @@ -576,12 +591,25 @@ class EndOfTurnItem(ThreadItemBase): class HiddenContextItem(ThreadItemBase): - """HiddenContext is never sent to the client. It's not officially part of ChatKit. It is only used internally to store additional context in a specific place in the thread.""" + """ + HiddenContext is never sent to the client. It's not officially part of ChatKit.js. + It is only used internally to store additional context in a specific place in the thread. + """ type: Literal["hidden_context_item"] = "hidden_context_item" content: Any +class SDKHiddenContextItem(ThreadItemBase): + """ + Hidden context that is used by the ChatKit Python SDK for storing additional context + for internal operations. 
+ """ + + type: Literal["sdk_hidden_context"] = "sdk_hidden_context" + content: str + + ThreadItem = Annotated[ UserMessageItem | AssistantMessageItem @@ -590,6 +618,7 @@ class HiddenContextItem(ThreadItemBase): | WorkflowItem | TaskItem | HiddenContextItem + | SDKHiddenContextItem | EndOfTurnItem, Field(discriminator="type"), ] diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 3526532..d7ac167 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -1,3 +1,4 @@ +import asyncio import sqlite3 from contextlib import contextmanager from datetime import datetime @@ -15,9 +16,15 @@ StreamingResult, stream_widget, ) -from chatkit.store import AttachmentStore, NotFoundError +from chatkit.store import ( + AttachmentStore, + NotFoundError, + StoreItemType, + default_generate_id, +) from chatkit.types import ( AssistantMessageContent, + AssistantMessageContentPartTextDelta, AssistantMessageItem, Attachment, AttachmentCreateParams, @@ -205,6 +212,139 @@ async def process_non_streaming(self, request_obj, context: Any | None = None): db.close() +async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + yield ThreadItemAddedEvent( + item=AssistantMessageItem( + id="assistant-message-pending", + created_at=datetime.now(), + content=[AssistantMessageContent(text="Hello, ")], + thread_id=thread.id, + ) + ) + yield ThreadItemUpdatedEvent( + item_id="assistant-message-pending", + update=AssistantMessageContentPartTextDelta( + content_index=0, + delta="World!", + ), + ) + raise asyncio.CancelledError() + + with make_server(responder) as server: + # Allow hidden_context id generation in this test store + original_generate_item_id = server.store.generate_item_id + + def generate_item_id( + item_type: StoreItemType, thread: ThreadMetadata, context: Any + ): + if item_type == "sdk_hidden_context": + return default_generate_id("sdk_hidden_context") + return original_generate_item_id(item_type, thread, context) + + server.store.generate_item_id = generate_item_id # type: ignore[method-assign] + + stream = await server.process( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ).model_dump_json(), + DEFAULT_CONTEXT, + ) + assert isinstance(stream, StreamingResult) + + events: list[ThreadStreamEvent] = [] + with pytest.raises(asyncio.CancelledError): # noqa: PT012 + async for raw in stream.json_events: + events.append(decode_event(raw)) + + thread = next(e.thread for e in events if e.type == "thread.created") + items = await server.store.load_thread_items( + thread.id, None, 1, "desc", DEFAULT_CONTEXT + ) + hidden_context_item = items.data[-1] + assert hidden_context_item.type == "sdk_hidden_context" + assert ( + hidden_context_item.content + == "The user cancelled the stream. Stop responding to the prior request." + ) + + assistant_message_item = await server.store.load_item( + thread.id, "assistant-message-pending", DEFAULT_CONTEXT + ) + assert assistant_message_item.type == "assistant_message" + assert assistant_message_item.content[0].text == "Hello, World!" 
+ + +async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + yield ThreadItemAddedEvent( + item=AssistantMessageItem( + id="assistant-message-pending", + created_at=datetime.now(), + content=[], + thread_id=thread.id, + ) + ) + raise asyncio.CancelledError() + + with make_server(responder) as server: + original_generate_item_id = server.store.generate_item_id + + def generate_item_id( + item_type: StoreItemType, thread: ThreadMetadata, context: Any + ): + if item_type == "sdk_hidden_context": + return default_generate_id("sdk_hidden_context") + return original_generate_item_id(item_type, thread, context) + + server.store.generate_item_id = generate_item_id # type: ignore[method-assign] + + stream = await server.process( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ).model_dump_json(), + DEFAULT_CONTEXT, + ) + assert isinstance(stream, StreamingResult) + + events: list[ThreadStreamEvent] = [] + with pytest.raises(asyncio.CancelledError): # noqa: PT012 + async for raw in stream.json_events: + events.append(decode_event(raw)) + + thread = next(e.thread for e in events if e.type == "thread.created") + items = await server.store.load_thread_items( + thread.id, None, 1, "desc", DEFAULT_CONTEXT + ) + hidden_context_item = items.data[-1] + assert hidden_context_item.type == "sdk_hidden_context" + assert ( + hidden_context_item.content + == "The user cancelled the stream. Stop responding to the prior request." + ) + + with pytest.raises(NotFoundError): + await server.store.load_item( + thread.id, "assistant-message-pending", DEFAULT_CONTEXT + ) + + async def test_flows_context_to_responder(): responder_context = None add_feedback_context = None @@ -509,19 +649,21 @@ async def responder( ) ) - assert len(events) == 3 + assert len(events) == 4 assert events[0].type == "thread.created" thread = events[0].thread assert events[1].type == "thread.item.done" assert events[1].item.type == "user_message" - assert events[2].type == "thread.item.done" - assert events[2].item.type == "client_tool_call" - assert events[2].item.id == "msg_1" - assert events[2].item.name == "tool_call_1" - assert events[2].item.arguments == {"arg1": "val1", "arg2": False} - assert events[2].item.call_id == "tool_call_1" + assert events[2].type == "stream_options" + + assert events[3].type == "thread.item.done" + assert events[3].item.type == "client_tool_call" + assert events[3].item.id == "msg_1" + assert events[3].item.name == "tool_call_1" + assert events[3].item.arguments == {"arg1": "val1", "arg2": False} + assert events[3].item.call_id == "tool_call_1" events = await server.process_streaming( ThreadsAddClientToolOutputReq( @@ -532,9 +674,10 @@ async def responder( ) ) - assert len(events) == 1 - assert events[0].type == "thread.item.done" - assert events[0].item.type == "assistant_message" + assert len(events) == 2 + assert events[0].type == "stream_options" + assert events[1].type == "thread.item.done" + assert events[1].item.type == "assistant_message" async def test_removes_tool_call_if_no_output_provided(): @@ -630,11 +773,12 @@ async def responder( ) ) - assert len(events) == 4 + assert len(events) == 5 assert events[0].type == "thread.created" assert events[1].type == "thread.item.done" - assert 
events[2].type == "progress_update" - assert events[3].type == "thread.item.done" + assert events[2].type == "stream_options" + assert events[3].type == "progress_update" + assert events[4].type == "thread.item.done" async def test_list_threads_response(): @@ -805,11 +949,12 @@ async def action( widget_item, ) - assert len(events) == 1 - assert events[0].type == "thread.item.updated" - assert isinstance(events[0], ThreadItemUpdatedEvent) - assert events[0].update.type == "widget.root.updated" - assert events[0].update.widget == Card(children=[Text(value="Email sent!")]) + assert len(events) == 2 + assert events[0].type == "stream_options" + assert events[1].type == "thread.item.updated" + assert isinstance(events[1], ThreadItemUpdatedEvent) + assert events[1].update.type == "widget.root.updated" + assert events[1].update.widget == Card(children=[Text(value="Email sent!")]) async def test_add_feedback(): @@ -1469,11 +1614,12 @@ async def responder( ) # Verify the retry generated new response - assert len(retry_events) == 1 - assert retry_events[0].type == "thread.item.done" - assert retry_events[0].item.type == "assistant_message" - assert retry_events[0].item.content[0].type == "output_text" - assert retry_events[0].item.content[0].text == "Retried response" + assert len(retry_events) == 2 + assert retry_events[0].type == "stream_options" + assert retry_events[1].type == "thread.item.done" + assert retry_events[1].item.type == "assistant_message" + assert retry_events[1].item.content[0].type == "output_text" + assert retry_events[1].item.content[0].text == "Retried response" # Verify the responder was called twice with the same user message assert len(responder_calls) == 2 @@ -1616,8 +1762,9 @@ async def responder( ) ) - assert len(retry_events) == 1 - assert retry_events[0].type == "thread.item.done" + assert len(retry_events) == 2 + assert retry_events[0].type == "stream_options" + assert retry_events[1].type == "thread.item.done" # Verify retry used the second user message assert len(responder_calls) == 3 # Original 2 + 1 retry
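
A minimal usage sketch of the new cancellation hooks, not part of the patch: it assumes the server base class exported by chatkit.server is ChatKitServer and that StreamOptions, ThreadItem, and ThreadMetadata are importable from chatkit.types; adjust the names to your setup.

from typing import Any

from chatkit.server import ChatKitServer
from chatkit.types import StreamOptions, ThreadItem, ThreadMetadata


class MyServer(ChatKitServer):
    # respond() and the other required overrides are omitted here.

    def get_stream_options(self, thread: ThreadMetadata, context: Any) -> StreamOptions:
        # Only advertise cancellation for interactive sessions; the default
        # implementation in this patch always returns StreamOptions(allow_cancel=True).
        return StreamOptions(allow_cancel=bool(getattr(context, "interactive", True)))

    async def handle_stream_cancelled(
        self,
        thread: ThreadMetadata,
        pending_items: list[ThreadItem],
        context: Any,
    ) -> None:
        # Application-specific cleanup would go here (for example, signalling an
        # in-flight inference task to stop), then defer to the default behavior,
        # which persists non-empty pending assistant messages and appends the
        # SDKHiddenContextItem cancellation marker to the thread.
        await super().handle_stream_cancelled(thread, pending_items, context)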