diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py
index e1d7f248d..e1fc885bb 100644
--- a/src/agents/extensions/memory/sqlalchemy_session.py
+++ b/src/agents/extensions/memory/sqlalchemy_session.py
@@ -195,7 +195,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
                 stmt = (
                     select(self._messages.c.message_data)
                     .where(self._messages.c.session_id == self.session_id)
-                    .order_by(self._messages.c.created_at.asc())
+                    .order_by(
+                        self._messages.c.created_at.asc(),
+                        self._messages.c.id.asc(),
+                    )
                 )
             else:
                 stmt = (
@@ -203,7 +206,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
                     .where(self._messages.c.session_id == self.session_id)
                     # Use DESC + LIMIT to get the latest N
                     # then reverse later for chronological order.
-                    .order_by(self._messages.c.created_at.desc())
+                    .order_by(
+                        self._messages.c.created_at.desc(),
+                        self._messages.c.id.desc(),
+                    )
                     .limit(limit)
                 )
 
@@ -278,7 +284,10 @@ async def pop_item(self) -> TResponseInputItem | None:
                 subq = (
                     select(self._messages.c.id)
                     .where(self._messages.c.session_id == self.session_id)
-                    .order_by(self._messages.c.created_at.desc())
+                    .order_by(
+                        self._messages.c.created_at.desc(),
+                        self._messages.c.id.desc(),
+                    )
                     .limit(1)
                 )
                 res = await sess.execute(subq)
diff --git a/tests/extensions/memory/test_sqlalchemy_session.py b/tests/extensions/memory/test_sqlalchemy_session.py
index e1ce3e10b..496d0b027 100644
--- a/tests/extensions/memory/test_sqlalchemy_session.py
+++ b/tests/extensions/memory/test_sqlalchemy_session.py
@@ -1,6 +1,23 @@
 from __future__ import annotations
 
+import json
+from collections.abc import Iterable, Sequence
+from contextlib import asynccontextmanager
+from datetime import datetime, timedelta
+from typing import Any, cast
+
 import pytest
+from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
+from openai.types.responses.response_output_text_param import ResponseOutputTextParam
+from openai.types.responses.response_reasoning_item_param import (
+    ResponseReasoningItemParam,
+    Summary,
+)
 
 pytest.importorskip("sqlalchemy")  # Skip tests if SQLAlchemy is not installed
+
+# Import SQLAlchemy only after importorskip so that a missing install skips
+# this module instead of failing collection with ImportError.
+from sqlalchemy import select, text, update  # noqa: E402
+from sqlalchemy.sql import Select  # noqa: E402
 
@@ -16,6 +33,43 @@
 DB_URL = "sqlite+aiosqlite:///:memory:"
 
 
+def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
+    """Build a completed assistant message item holding a single output_text part."""
+    content: ResponseOutputTextParam = {
+        "type": "output_text",
+        "text": text_value,
+        "annotations": [],
+    }
+    message: ResponseOutputMessageParam = {
+        "id": item_id,
+        "type": "message",
+        "role": "assistant",
+        "status": "completed",
+        "content": [content],
+    }
+    return cast(TResponseInputItem, message)
+
+
+def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
+    """Build a reasoning item whose summary is a single summary_text entry."""
+    summary: Summary = {"type": "summary_text", "text": summary_text}
+    reasoning: ResponseReasoningItemParam = {
+        "id": item_id,
+        "type": "reasoning",
+        "summary": [summary],
+    }
+    return cast(TResponseInputItem, reasoning)
+
+
+def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]:
+    """Return the "id" field of each item, preserving order."""
+    result: list[str] = []
+    for item in items:
+        item_dict = cast(dict[str, Any], item)
+        result.append(cast(str, item_dict["id"]))
+    return result
+
+
 @pytest.fixture
 def agent() -> Agent:
     """Fixture for a basic agent with a fake model."""
@@ -151,3 +205,200 @@ async def test_add_empty_items_list():
 
     items_after_add = await session.get_items()
     assert len(items_after_add) == 0
+
+
+async def test_get_items_same_timestamp_consistent_order():
+    """Test that items with identical timestamps keep insertion order."""
+    session_id = "same_timestamp_test"
+    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
+
+    older_item = _make_message_item("older_same_ts", "old")
+    reasoning_item = _make_reasoning_item("rs_same_ts", "...")
+    message_item = _make_message_item("msg_same_ts", "...")
+    await session.add_items([older_item])
+    await session.add_items([reasoning_item, message_item])
+
+    # Rewrite created_at so the two batched items tie and the first item is strictly older.
+    async with session._session_factory() as sess:
+        rows = await sess.execute(
+            select(session._messages.c.id, session._messages.c.message_data).where(
+                session._messages.c.session_id == session.session_id
+            )
+        )
+        id_map = {
+            json.loads(message_json)["id"]: row_id
+            for row_id, message_json in rows.fetchall()
+        }
+        shared = datetime(2025, 10, 15, 17, 26, 39, 132483)
+        older = shared - timedelta(milliseconds=1)
+        await sess.execute(
+            update(session._messages)
+            .where(session._messages.c.id.in_(
+                [
+                    id_map["rs_same_ts"],
+                    id_map["msg_same_ts"],
+                ]
+            ))
+            .values(created_at=shared)
+        )
+        await sess.execute(
+            update(session._messages)
+            .where(session._messages.c.id == id_map["older_same_ts"])
+            .values(created_at=older)
+        )
+        await sess.commit()
+
+    real_factory = session._session_factory
+
+    class FakeResult:
+        """Stand-in result that serves pre-computed rows through ``all()``."""
+        def __init__(self, rows: Iterable[Any]):
+            self._rows = list(rows)
+
+        def all(self) -> list[Any]:
+            return list(self._rows)
+
+    def needs_shuffle(statement: Any) -> bool:
+        """Return True for ordered SELECTs on the messages table with no id tie-breaker."""
+        if not isinstance(statement, Select):
+            return False
+        orderings = list(statement._order_by_clause)
+        if not orderings:
+            return False
+        id_asc = session._messages.c.id.asc()
+        id_desc = session._messages.c.id.desc()
+
+        def references_id(clause: Any) -> bool:
+            try:
+                return bool(clause.compare(id_asc) or clause.compare(id_desc))
+            except AttributeError:
+                return False
+
+        if any(references_id(clause) for clause in orderings):
+            return False
+        # Only shuffle queries that target the messages table.
+        target_tables: set[str] = set()
+        for from_clause in statement.get_final_froms():
+            name_attr = getattr(from_clause, "name", None)
+            if isinstance(name_attr, str):
+                target_tables.add(name_attr)
+        table_name_obj = getattr(session._messages, "name", "")
+        table_name = table_name_obj if isinstance(table_name_obj, str) else ""
+        return table_name in target_tables
+
+    @asynccontextmanager
+    async def shuffled_session():
+        async with real_factory() as inner:
+            original_execute = inner.execute
+
+            async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
+                result = await original_execute(statement, *args, **kwargs)
+                if needs_shuffle(statement):
+                    # Reverse the rows to mimic a backend returning ties in arbitrary order.
+                    rows = result.all()
+                    shuffled = list(rows)
+                    shuffled.reverse()
+                    return FakeResult(shuffled)
+                return result
+
+            cast(Any, inner).execute = execute_with_shuffle
+            try:
+                yield inner
+            finally:
+                cast(Any, inner).execute = original_execute
+
+    session._session_factory = cast(Any, shuffled_session)
+    try:
+        retrieved = await session.get_items()
+        assert _item_ids(retrieved) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]
+
+        latest_two = await session.get_items(limit=2)
+        assert _item_ids(latest_two) == ["rs_same_ts", "msg_same_ts"]
+    finally:
+        session._session_factory = real_factory
+
+
+async def test_pop_item_same_timestamp_returns_latest():
+    """Test that pop_item returns the newest item when timestamps tie."""
+    session_id = "same_timestamp_pop_test"
+    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
+
+    reasoning_item = _make_reasoning_item("rs_pop_same_ts", "...")
+    message_item = _make_message_item("msg_pop_same_ts", "...")
+    await session.add_items([reasoning_item, message_item])
+
+    async with session._session_factory() as sess:
+        await sess.execute(
+            text(
+                "UPDATE agent_messages "
+                "SET created_at = :created_at "
+                "WHERE session_id = :session_id"
+            ),
+            {
+                "created_at": "2025-10-15 17:26:39.132483",
+                "session_id": session.session_id,
+            },
+        )
+        await sess.commit()
+
+    popped = await session.pop_item()
+    assert popped is not None
+    assert cast(dict[str, Any], popped)["id"] == "msg_pop_same_ts"
+
+    remaining = await session.get_items()
+    assert _item_ids(remaining) == ["rs_pop_same_ts"]
+
+
+async def test_get_items_orders_by_id_for_ties():
+    """Test that get_items adds id ordering to break timestamp ties."""
+    session_id = "order_by_id_test"
+    session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)
+
+    await session.add_items(
+        [
+            _make_reasoning_item("rs_first", "..."),
+            _make_message_item("msg_second", "..."),
+        ]
+    )
+
+    real_factory = session._session_factory
+    recorded: list[Any] = []
+
+    @asynccontextmanager
+    async def wrapped_session():
+        async with real_factory() as inner:
+            original_execute = inner.execute
+
+            async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
+                recorded.append(statement)
+                return await original_execute(statement, *args, **kwargs)
+
+            cast(Any, inner).execute = recording_execute
+            try:
+                yield inner
+            finally:
+                cast(Any, inner).execute = original_execute
+
+    session._session_factory = cast(Any, wrapped_session)
+    try:
+        retrieved_full = await session.get_items()
+        retrieved_limited = await session.get_items(limit=2)
+    finally:
+        session._session_factory = real_factory
+
+    # Statements are recorded in call order: the unlimited query, then the limited one.
+    assert len(recorded) >= 2
+    orderings_full = [str(clause) for clause in recorded[0]._order_by_clause]
+    assert orderings_full == [
+        "agent_messages.created_at ASC",
+        "agent_messages.id ASC",
+    ]
+
+    orderings_limited = [str(clause) for clause in recorded[1]._order_by_clause]
+    assert orderings_limited == [
+        "agent_messages.created_at DESC",
+        "agent_messages.id DESC",
+    ]
+
+    assert _item_ids(retrieved_full) == ["rs_first", "msg_second"]
+    assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"]