Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,11 +1973,93 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t
contents.extend(coalesced_contents)


def _content_items_text(items: Any) -> str | None:
"""Return concatenated text when a content item list only contains text."""
if not isinstance(items, list):
return None
text_parts: list[str] = []
for item in items:
if not isinstance(item, Content) or item.type != "text":
return None
text_parts.append(item.text or "")
return "".join(text_parts)


def _merge_content_item_lists(existing: Any, incoming: Any) -> Any:
"""Merge streamed nested content lists, replacing deltas with a later full value when present."""
if incoming is None:
return existing
if existing is None:
return deepcopy(incoming)

existing_text = _content_items_text(existing)
incoming_text = _content_items_text(incoming)
if existing_text is not None and incoming_text is not None:
if incoming_text.startswith(existing_text):
return deepcopy(incoming)
if existing_text.startswith(incoming_text):
return existing

merged = deepcopy(existing[0])
merged.text = existing_text + incoming_text
return [merged]

if isinstance(existing, list) and isinstance(incoming, list):
return [*existing, *deepcopy(incoming)]
return deepcopy(incoming)


def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None:
"""Merge two code interpreter content items for the same logical call."""
existing.inputs = _merge_content_item_lists(existing.inputs, incoming.inputs)
existing.outputs = _merge_content_item_lists(existing.outputs, incoming.outputs)
existing.annotations = _combine_annotations(existing.annotations, incoming.annotations)
existing.additional_properties = {**existing.additional_properties, **incoming.additional_properties}
existing.raw_representation = _combine_raw_representations(existing.raw_representation, incoming.raw_representation)


def _code_interpreter_key(content: Content) -> tuple[str, str] | None:
"""Return the aggregation key for code interpreter call/result content."""
if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}:
return None
call_id = content.call_id or content.additional_properties.get("item_id")
if not isinstance(call_id, str) or not call_id:
return None
return content.type, call_id


def _coalesce_code_interpreter_content(contents: list[Content]) -> None:
"""Coalesce streaming code interpreter chunks by call id."""
if not contents:
return

coalesced_contents: list[Content] = []
seen: dict[tuple[str, str], Content] = {}
for content in contents:
key = _code_interpreter_key(content)
if key is None:
coalesced_contents.append(content)
continue

existing = seen.get(key)
if existing is None:
copied = deepcopy(content)
seen[key] = copied
coalesced_contents.append(copied)
continue

_merge_code_interpreter_content(existing, content)

contents.clear()
contents.extend(coalesced_contents)


def _finalize_response(response: ChatResponse | AgentResponse) -> None:
"""Finalizes the response by performing any necessary post-processing."""
for msg in response.messages:
_coalesce_text_content(msg.contents, "text")
_coalesce_text_content(msg.contents, "text_reasoning")
_coalesce_code_interpreter_content(msg.contents)


# region ContinuationToken
Expand Down
57 changes: 57 additions & 0 deletions python/packages/core/tests/core/test_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,63 @@ async def test_after_run_stores_inputs_and_responses(self) -> None:
assert provider.stored[0].text == "hello"
assert provider.stored[1].text == "hi"

async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None:
from agent_framework import AgentResponse, AgentResponseUpdate, Content

provider = ConcreteHistoryProvider("mem", store_inputs=False)
updates = [
AgentResponseUpdate(
role="assistant",
contents=[
Content.from_code_interpreter_tool_result(
call_id="ci_123",
outputs=[],
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import")],
additional_properties={"sequence_number": 1},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text=" pandas")],
additional_properties={"sequence_number": 2},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import pandas as pd")],
additional_properties={"sequence_number": 3},
)
],
),
]
ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])])
ctx._response = AgentResponse.from_updates(updates)

await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type]

assert len(provider.stored) == 1
stored_contents = provider.stored[0].contents
calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"]
results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"]
assert len(calls) == 1
assert len(results) == 1
assert calls[0].inputs is not None
assert len(calls[0].inputs) == 1
assert calls[0].inputs[0].text == "import pandas as pd"

async def test_after_run_skips_inputs_when_disabled(self) -> None:
from agent_framework import AgentResponse

Expand Down
Loading