Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions chatkit/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
EndOfTurnItem,
FileSource,
HiddenContextItem,
SDKHiddenContextItem,
Task,
TaskItem,
ThoughtTask,
Expand Down Expand Up @@ -712,6 +713,30 @@ async def hidden_context_to_input(
role="user",
)

async def sdk_hidden_context_to_input(
    self, item: SDKHiddenContextItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Build the model input for an SDKHiddenContextItem.

    These items are written by the ChatKit Python SDK itself to store
    additional context for internal operations. The default implementation
    returns a single user-role message whose text wraps ``item.content`` in
    ``<HiddenContext>`` tags.

    Override if you want to wrap the content in a different format.
    """
    # Assemble the text in two steps: the tagged payload, then the preamble
    # telling the model that the user never sees this content.
    tagged_payload = f"<HiddenContext>\n{item.content}\n</HiddenContext>"
    prompt_text = (
        "Hidden context for the agent (not shown to the user):\n" + tagged_payload
    )
    text_part = ResponseInputTextParam(type="input_text", text=prompt_text)
    return Message(
        type="message",
        content=[text_part],
        role="user",
    )

async def task_to_input(
self, item: TaskItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
Expand Down Expand Up @@ -948,6 +973,9 @@ async def _thread_item_to_input_item(
case HiddenContextItem():
out = await self.hidden_context_to_input(item) or []
return out if isinstance(out, list) else [out]
case SDKHiddenContextItem():
out = await self.sdk_hidden_context_to_input(item) or []
return out if isinstance(out, list) else [out]
case _:
assert_never(item)

Expand Down
163 changes: 159 additions & 4 deletions chatkit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
from .store import AttachmentStore, Store, StoreItemType, default_generate_id
from .types import (
Action,
AssistantMessageContent,
AssistantMessageContentPartAdded,
AssistantMessageContentPartAnnotationAdded,
AssistantMessageContentPartDone,
AssistantMessageContentPartTextDelta,
AssistantMessageItem,
AttachmentsCreateReq,
AttachmentsDeleteReq,
ChatKitReq,
Expand All @@ -39,7 +45,10 @@
ItemsListReq,
NonStreamingReq,
Page,
SDKHiddenContextItem,
StreamingReq,
StreamOptions,
StreamOptionsEvent,
Thread,
ThreadCreatedEvent,
ThreadItem,
Expand All @@ -66,6 +75,9 @@
WidgetItem,
WidgetRootUpdated,
WidgetStreamingTextValueDelta,
WorkflowItem,
WorkflowTaskAdded,
WorkflowTaskUpdated,
is_streaming_req,
)
from .version import __version__
Expand Down Expand Up @@ -308,6 +320,56 @@ def action(
"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
)

def get_stream_options(
    self, thread: ThreadMetadata, context: TContext
) -> StreamOptions:
    """
    Return the stream-level runtime options for a response stream.

    The default ignores ``thread`` and ``context`` and lets the user cancel
    the stream. Override this method to customize behavior.
    """
    default_options = StreamOptions(allow_cancel=True)
    return default_options

async def handle_stream_cancelled(
    self,
    thread: ThreadMetadata,
    pending_items: list[ThreadItem],
    context: TContext,
):
    """Perform custom cleanup / stop inference when a stream is cancelled.

    Updates you make here will not be reflected in the UI until a reload.

    The default implementation persists every pending assistant message that
    contains some non-whitespace text, and does not auto-save pending widget
    or workflow items. It then appends an SDK hidden context item recording
    the cancellation.

    Args:
        thread: The thread that was being processed.
        pending_items: Items that were not done streaming at cancellation time.
        context: Arbitrary per-request context provided by the caller.
    """
    # Snapshot the assistant drafts up front so the awaits below iterate a
    # fixed list.
    assistant_drafts = [
        candidate
        for candidate in pending_items
        if isinstance(candidate, AssistantMessageItem)
    ]
    for draft in assistant_drafts:
        # A draft is worth saving only if some content part has visible text;
        # a draft with no parts at all is skipped as well.
        if any(part.text.strip() for part in draft.content):
            await self.store.add_thread_item(thread.id, draft, context=context)

    # Record the cancellation as hidden context. Otherwise, depending on the
    # timing of the cancellation, subsequent responses may attempt to continue
    # the cancelled response.
    cancellation_note = SDKHiddenContextItem(
        thread_id=thread.id,
        created_at=datetime.now(),
        id=self.store.generate_item_id("sdk_hidden_context", thread, context),
        content="The user cancelled the stream. Stop responding to the prior request.",
    )
    await self.store.add_thread_item(
        thread.id,
        cancellation_note,
        context=context,
    )

async def process(
self, request: str | bytes | bytearray, context: TContext
) -> StreamingResult | NonStreamingResult:
Expand Down Expand Up @@ -379,11 +441,11 @@ async def _process_non_streaming(
after=items_list_params.after,
context=context,
)
# filter out HiddenContextItems
# filter out hidden context items
items.data = [
item
for item in items.data
if not isinstance(item, HiddenContextItem)
if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
]
return self._serialize(items)
case ThreadsUpdateReq():
Expand All @@ -408,6 +470,9 @@ async def _process_streaming(
async for event in self._process_streaming_impl(request, context):
b = self._serialize(event)
yield b"data: " + b + b"\n\n"
except asyncio.CancelledError:
# Let cancellation bubble up without logging as an error.
raise
except Exception:
logger.exception("Error while generating streamed response")
raise
Expand Down Expand Up @@ -604,29 +669,51 @@ async def _process_events(
) -> AsyncIterator[ThreadStreamEvent]:
await asyncio.sleep(0) # allow the response to start streaming

# Send initial stream options
yield StreamOptionsEvent(
stream_options=self.get_stream_options(thread, context)
)

last_thread = thread.model_copy(deep=True)

# Keep track of items that were streamed but not yet saved
# so that we can persist them when the stream is cancelled.
pending_items: dict[str, ThreadItem] = {}

try:
with agents_sdk_user_agent_override():
async for event in stream():
if isinstance(event, ThreadItemAddedEvent):
pending_items[event.item.id] = event.item

match event:
case ThreadItemDoneEvent():
await self.store.add_thread_item(
thread.id, event.item, context=context
)
pending_items.pop(event.item.id, None)
case ThreadItemRemovedEvent():
await self.store.delete_thread_item(
thread.id, event.item_id, context=context
)
pending_items.pop(event.item_id, None)
case ThreadItemReplacedEvent():
await self.store.save_item(
thread.id, event.item, context=context
)
pending_items.pop(event.item.id, None)
case ThreadItemUpdatedEvent():
# Keep pending assistant message and workflow items up to date
# so that we have a reference to the latest version of these pending items
# when the stream is cancelled.
self._update_pending_items(pending_items, event)

# special case - don't send hidden context items back to the client
should_swallow_event = isinstance(
event, ThreadItemDoneEvent
) and isinstance(event.item, HiddenContextItem)
) and isinstance(
event.item, (HiddenContextItem, SDKHiddenContextItem)
)

if not should_swallow_event:
yield event
Expand All @@ -643,6 +730,11 @@ async def _process_events(
last_thread = thread.model_copy(deep=True)
await self.store.save_thread(thread, context=context)
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
except asyncio.CancelledError:
await self.handle_stream_cancelled(
thread, list(pending_items.values()), context
)
raise
except CustomStreamError as e:
yield ErrorEvent(
code="custom",
Expand All @@ -666,6 +758,69 @@ async def _process_events(
await self.store.save_thread(thread, context=context)
yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))

def _apply_assistant_message_update(
    self,
    item: AssistantMessageItem,
    update: AssistantMessageContentPartAdded
    | AssistantMessageContentPartTextDelta
    | AssistantMessageContentPartAnnotationAdded
    | AssistantMessageContentPartDone,
) -> AssistantMessageItem:
    """
    Return a deep copy of ``item`` with one streamed content update applied.

    The input item is never mutated. The copy's content list is first padded
    with empty parts so that ``update.content_index`` exists, then:

    - part added / part done: the part at that index is replaced;
    - text delta: the delta is appended to the part's text;
    - annotation added: the annotation is inserted at ``annotation_index``.
    """
    patched = item.model_copy(deep=True)
    parts = patched.content
    target = update.content_index

    # Streaming updates can arrive for an index that hasn't been created yet,
    # so grow the list with empty placeholders until the target slot exists.
    # A non-positive count makes range() empty, so nothing is added then.
    shortfall = target + 1 - len(parts)
    parts.extend(
        AssistantMessageContent(text="", annotations=[]) for _ in range(shortfall)
    )

    if isinstance(
        update, (AssistantMessageContentPartAdded, AssistantMessageContentPartDone)
    ):
        # Both carry a complete content part that supersedes the slot.
        parts[target] = update.content
    elif isinstance(update, AssistantMessageContentPartTextDelta):
        parts[target].text += update.delta
    elif isinstance(update, AssistantMessageContentPartAnnotationAdded):
        # list.insert appends when the index is at or beyond the end, which
        # covers the "index past the current annotations" case.
        parts[target].annotations.insert(update.annotation_index, update.annotation)
    return patched

def _update_pending_items(
    self,
    pending_items: dict[str, ThreadItem],
    event: ThreadItemUpdatedEvent,
):
    """
    Fold a ThreadItemUpdatedEvent into the tracked pending item, if any.

    Only pending assistant messages and workflow items are kept up to date;
    updates for untracked items, other item types, or unrecognised update
    kinds are ignored. Assistant messages are replaced by an updated copy,
    while workflow items have their task list modified in place.
    """
    tracked = pending_items.get(event.item_id)
    change = event.update

    if isinstance(tracked, AssistantMessageItem):
        assistant_update_types = (
            AssistantMessageContentPartAdded,
            AssistantMessageContentPartTextDelta,
            AssistantMessageContentPartAnnotationAdded,
            AssistantMessageContentPartDone,
        )
        if isinstance(change, assistant_update_types):
            pending_items[tracked.id] = self._apply_assistant_message_update(
                tracked, change
            )
        return

    if isinstance(tracked, WorkflowItem):
        if isinstance(change, WorkflowTaskUpdated):
            tracked.workflow.tasks[change.task_index] = change.task
        elif isinstance(change, WorkflowTaskAdded):
            tracked.workflow.tasks.append(change.task)
        else:
            # Not a task-level update; leave the tracked workflow untouched.
            return
        pending_items[tracked.id] = tracked

async def _build_user_message_item(
self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
) -> UserMessageItem:
Expand Down Expand Up @@ -714,7 +869,7 @@ def _serialize(self, obj: BaseModel) -> bytes:

def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
def is_hidden(item: ThreadItem) -> bool:
return isinstance(item, HiddenContextItem)
return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))

items = thread.items if isinstance(thread, Thread) else Page()
items.data = [item for item in items.data if not is_hidden(item)]
Expand Down
9 changes: 8 additions & 1 deletion chatkit/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
TContext = TypeVar("TContext", default=Any)

StoreItemType = Literal[
"thread", "message", "tool_call", "task", "workflow", "attachment"
"thread",
"message",
"tool_call",
"task",
"workflow",
"attachment",
"sdk_hidden_context",
]


Expand All @@ -26,6 +32,7 @@
"workflow": "wf",
"task": "tsk",
"attachment": "atc",
"sdk_hidden_context": "shcx",
}


Expand Down
31 changes: 30 additions & 1 deletion chatkit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,20 @@ class ThreadItemReplacedEvent(BaseModel):
item: ThreadItem


class StreamOptions(BaseModel):
"""Settings that control runtime stream behavior."""

# Required field: no default value is declared on the model, so every
# construction site must pass it explicitly.
allow_cancel: bool
"""Allow the client to request cancellation mid-stream."""


class StreamOptionsEvent(BaseModel):
"""Event emitted to set stream options at runtime."""

# Fixed tag for this event: the only accepted value is "stream_options",
# which is also the default.
type: Literal["stream_options"] = "stream_options"
# The options payload carried by the event (required, no default).
stream_options: StreamOptions


class ProgressUpdateEvent(BaseModel):
"""Event providing incremental progress from the assistant."""

Expand Down Expand Up @@ -354,6 +368,7 @@ class NoticeEvent(BaseModel):
| ThreadItemUpdated
| ThreadItemRemovedEvent
| ThreadItemReplacedEvent
| StreamOptionsEvent
| ProgressUpdateEvent
| ErrorEvent
| NoticeEvent,
Expand Down Expand Up @@ -576,12 +591,25 @@ class EndOfTurnItem(ThreadItemBase):


class HiddenContextItem(ThreadItemBase):
"""HiddenContext is never sent to the client. It's not officially part of ChatKit. It is only used internally to store additional context in a specific place in the thread."""
"""
HiddenContext is never sent to the client. It's not officially part of ChatKit.js.
It is only used internally to store additional context in a specific place in the thread.
"""

type: Literal["hidden_context_item"] = "hidden_context_item"
content: Any


class SDKHiddenContextItem(ThreadItemBase):
"""
Hidden context that is used by the ChatKit Python SDK for storing additional context
for internal operations.
"""

# Fixed tag for this item type: the only accepted value is
# "sdk_hidden_context", which is also the default.
type: Literal["sdk_hidden_context"] = "sdk_hidden_context"
# Plain-text context payload (required, no default).
content: str


ThreadItem = Annotated[
UserMessageItem
| AssistantMessageItem
Expand All @@ -590,6 +618,7 @@ class HiddenContextItem(ThreadItemBase):
| WorkflowItem
| TaskItem
| HiddenContextItem
| SDKHiddenContextItem
| EndOfTurnItem,
Field(discriminator="type"),
]
Expand Down
Loading