diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..3b0a5c0356 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -786,14 +786,23 @@ def _delete_sync(): structure_deleted = cursor.rowcount + # Drop messages that only this branch referenced. The helper + # is scoped to ``self.session_id`` and removes only rows + # without any ``message_structure`` reference, so it keeps + # messages that the main branch or other branches still + # share via ``_copy_messages_to_new_branch`` (#3346). + orphaned_deleted = self._cleanup_orphaned_messages_sync(conn) + conn.commit() - return usage_deleted, structure_deleted + return usage_deleted, structure_deleted, orphaned_deleted - usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync) + usage_deleted, structure_deleted, orphaned_deleted = await asyncio.to_thread(_delete_sync) self._logger.info( - f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501 + f"Deleted branch '{branch_id}': {structure_deleted} message entries, " + f"{orphaned_deleted} orphaned messages, " + f"{usage_deleted} usage entries" ) async def list_branches(self) -> list[dict[str, Any]]: diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index ad4b5c4d86..cf8bd5b602 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -1422,3 +1422,102 @@ async def test_output_tokens_details_persisted_when_input_details_missing(): assert turn_usage["output_tokens_details"] == {"reasoning_tokens": 42} assert turn_usage["input_tokens_details"] is None session.close() + + +async def test_delete_branch_cleans_branch_only_messages(): + """delete_branch must drop messages that were only referenced by that branch. + + Regression test for https://github.com/openai/openai-agents-python/issues/3346. + Previously ``delete_branch`` removed ``turn_usage`` and ``message_structure`` + rows for the deleted branch but left the underlying ``agent_messages`` rows + in the table when no other branch still referenced them. After this fix the + orphaned rows are cleaned up while messages still shared with main or other + branches are preserved. + """ + session_id = "delete_branch_orphan_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Two messages on main, then a branch-only pair that only the branch sees. + await session.add_items( + [ + {"role": "user", "content": "main question"}, + {"role": "assistant", "content": "main answer"}, + ] + ) + + await session.create_branch_from_turn(1, "branch_only") + await session.add_items( + [ + {"role": "user", "content": "branch-only question"}, + {"role": "assistant", "content": "branch-only answer"}, + ] + ) + + await session.delete_branch("branch_only", force=True) + + # After deletion only the two main-branch messages should remain in the + # base ``agent_messages`` table; ``message_structure`` should only have the + # main-branch references. + with session._locked_connection() as conn: + message_rows = conn.execute( + f""" + SELECT id, message_data + FROM {session.messages_table} + WHERE session_id = ? + ORDER BY id + """, + (session.session_id,), + ).fetchall() + structure_rows = conn.execute( + """ + SELECT branch_id, message_id + FROM message_structure + WHERE session_id = ? + ORDER BY message_id + """, + (session.session_id,), + ).fetchall() + + message_contents = [json.loads(row[1]).get("content") for row in message_rows] + assert message_contents == ["main question", "main answer"] + + branch_ids = {row[0] for row in structure_rows} + assert branch_ids == {"main"} + + session.close() + + +async def test_delete_branch_preserves_messages_shared_with_main(): + """delete_branch must keep messages that the main branch still references.""" + session_id = "delete_branch_shared_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + # Two user-led turns on main so a branch can fork at turn 2 while still + # sharing turn 1's messages. + await session.add_items( + [ + {"role": "user", "content": "main question"}, + {"role": "assistant", "content": "main answer"}, + ] + ) + await session.add_items( + [ + {"role": "user", "content": "main follow-up"}, + {"role": "assistant", "content": "main follow-up answer"}, + ] + ) + + await session.create_branch_from_turn(2, "experiment") + + # Delete the branch. Messages still referenced by main must stay. + await session.delete_branch("experiment", force=True) + + main_items = await session.get_items() + assert [item.get("content") for item in main_items] == [ + "main question", + "main answer", + "main follow-up", + "main follow-up answer", + ] + + session.close()