Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,21 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
.order_by(self._messages.c.created_at.asc())
.order_by(
self._messages.c.created_at.asc(),
self._messages.c.id.asc(),
)
)
else:
stmt = (
select(self._messages.c.message_data)
.where(self._messages.c.session_id == self.session_id)
# Use DESC + LIMIT to get the latest N
# then reverse later for chronological order.
.order_by(self._messages.c.created_at.desc())
.order_by(
self._messages.c.created_at.desc(),
self._messages.c.id.desc(),
)
.limit(limit)
)

Expand Down Expand Up @@ -278,7 +284,10 @@ async def pop_item(self) -> TResponseInputItem | None:
subq = (
select(self._messages.c.id)
.where(self._messages.c.session_id == self.session_id)
.order_by(self._messages.c.created_at.desc())
.order_by(
self._messages.c.created_at.desc(),
self._messages.c.id.desc(),
)
.limit(1)
)
res = await sess.execute(subq)
Expand Down
240 changes: 240 additions & 0 deletions tests/extensions/memory/test_sqlalchemy_session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from __future__ import annotations

import json
from collections.abc import Iterable, Sequence
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import Any, cast

import pytest
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_output_text_param import ResponseOutputTextParam
from openai.types.responses.response_reasoning_item_param import (
ResponseReasoningItemParam,
Summary,
)
from sqlalchemy import select, text, update
from sqlalchemy.sql import Select

pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed

Expand All @@ -16,6 +30,40 @@
DB_URL = "sqlite+aiosqlite:///:memory:"


def _make_message_item(item_id: str, text_value: str) -> TResponseInputItem:
content: ResponseOutputTextParam = {
"type": "output_text",
"text": text_value,
"annotations": [],
}
message: ResponseOutputMessageParam = {
"id": item_id,
"type": "message",
"role": "assistant",
"status": "completed",
"content": [content],
}
return cast(TResponseInputItem, message)


def _make_reasoning_item(item_id: str, summary_text: str) -> TResponseInputItem:
summary: Summary = {"type": "summary_text", "text": summary_text}
reasoning: ResponseReasoningItemParam = {
"id": item_id,
"type": "reasoning",
"summary": [summary],
}
return cast(TResponseInputItem, reasoning)


def _item_ids(items: Sequence[TResponseInputItem]) -> list[str]:
result: list[str] = []
for item in items:
item_dict = cast(dict[str, Any], item)
result.append(cast(str, item_dict["id"]))
return result


@pytest.fixture
def agent() -> Agent:
"""Fixture for a basic agent with a fake model."""
Expand Down Expand Up @@ -151,3 +199,195 @@ async def test_add_empty_items_list():

items_after_add = await session.get_items()
assert len(items_after_add) == 0


async def test_get_items_same_timestamp_consistent_order():
"""Test that items with identical timestamps keep insertion order."""
session_id = "same_timestamp_test"
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

older_item = _make_message_item("older_same_ts", "old")
reasoning_item = _make_reasoning_item("rs_same_ts", "...")
message_item = _make_message_item("msg_same_ts", "...")
await session.add_items([older_item])
await session.add_items([reasoning_item, message_item])

async with session._session_factory() as sess:
rows = await sess.execute(
select(session._messages.c.id, session._messages.c.message_data).where(
session._messages.c.session_id == session.session_id
)
)
id_map = {
json.loads(message_json)["id"]: row_id
for row_id, message_json in rows.fetchall()
}
shared = datetime(2025, 10, 15, 17, 26, 39, 132483)
older = shared - timedelta(milliseconds=1)
await sess.execute(
update(session._messages)
.where(session._messages.c.id.in_(
[
id_map["rs_same_ts"],
id_map["msg_same_ts"],
]
))
.values(created_at=shared)
)
await sess.execute(
update(session._messages)
.where(session._messages.c.id == id_map["older_same_ts"])
.values(created_at=older)
)
await sess.commit()

real_factory = session._session_factory

class FakeResult:
def __init__(self, rows: Iterable[Any]):
self._rows = list(rows)

def all(self) -> list[Any]:
return list(self._rows)

def needs_shuffle(statement: Any) -> bool:
if not isinstance(statement, Select):
return False
orderings = list(statement._order_by_clause)
if not orderings:
return False
id_asc = session._messages.c.id.asc()
id_desc = session._messages.c.id.desc()

def references_id(clause) -> bool:
try:
return bool(clause.compare(id_asc) or clause.compare(id_desc))
except AttributeError:
return False

if any(references_id(clause) for clause in orderings):
return False
# Only shuffle queries that target the messages table.
target_tables: set[str] = set()
for from_clause in statement.get_final_froms():
name_attr = getattr(from_clause, "name", None)
if isinstance(name_attr, str):
target_tables.add(name_attr)
table_name_obj = getattr(session._messages, "name", "")
table_name = table_name_obj if isinstance(table_name_obj, str) else ""
return bool(table_name in target_tables)

@asynccontextmanager
async def shuffled_session():
async with real_factory() as inner:
original_execute = inner.execute

async def execute_with_shuffle(statement: Any, *args: Any, **kwargs: Any) -> Any:
result = await original_execute(statement, *args, **kwargs)
if needs_shuffle(statement):
rows = result.all()
shuffled = list(rows)
shuffled.reverse()
return FakeResult(shuffled)
return result

cast(Any, inner).execute = execute_with_shuffle
try:
yield inner
finally:
cast(Any, inner).execute = original_execute

session._session_factory = cast(Any, shuffled_session)
try:
retrieved = await session.get_items()
assert _item_ids(retrieved) == ["older_same_ts", "rs_same_ts", "msg_same_ts"]

latest_two = await session.get_items(limit=2)
assert _item_ids(latest_two) == ["rs_same_ts", "msg_same_ts"]
finally:
session._session_factory = real_factory


async def test_pop_item_same_timestamp_returns_latest():
"""Test that pop_item returns the newest item when timestamps tie."""
session_id = "same_timestamp_pop_test"
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

reasoning_item = _make_reasoning_item("rs_pop_same_ts", "...")
message_item = _make_message_item("msg_pop_same_ts", "...")
await session.add_items([reasoning_item, message_item])

async with session._session_factory() as sess:
await sess.execute(
text(
"UPDATE agent_messages "
"SET created_at = :created_at "
"WHERE session_id = :session_id"
),
{
"created_at": "2025-10-15 17:26:39.132483",
"session_id": session.session_id,
},
)
await sess.commit()

popped = await session.pop_item()
assert popped is not None
assert cast(dict[str, Any], popped)["id"] == "msg_pop_same_ts"

remaining = await session.get_items()
assert _item_ids(remaining) == ["rs_pop_same_ts"]


async def test_get_items_orders_by_id_for_ties():
"""Test that get_items adds id ordering to break timestamp ties."""
session_id = "order_by_id_test"
session = SQLAlchemySession.from_url(session_id, url=DB_URL, create_tables=True)

await session.add_items(
[
_make_reasoning_item("rs_first", "..."),
_make_message_item("msg_second", "..."),
]
)

real_factory = session._session_factory
recorded: list[Any] = []

@asynccontextmanager
async def wrapped_session():
async with real_factory() as inner:
original_execute = inner.execute

async def recording_execute(statement: Any, *args: Any, **kwargs: Any) -> Any:
recorded.append(statement)
return await original_execute(statement, *args, **kwargs)

cast(Any, inner).execute = recording_execute
try:
yield inner
finally:
cast(Any, inner).execute = original_execute

session._session_factory = cast(Any, wrapped_session)
try:
retrieved_full = await session.get_items()
retrieved_limited = await session.get_items(limit=2)
finally:
session._session_factory = real_factory

assert len(recorded) >= 2
orderings_full = [str(clause) for clause in recorded[0]._order_by_clause]
assert orderings_full == [
"agent_messages.created_at ASC",
"agent_messages.id ASC",
]

orderings_limited = [str(clause) for clause in recorded[1]._order_by_clause]
assert orderings_limited == [
"agent_messages.created_at DESC",
"agent_messages.id DESC",
]

assert _item_ids(retrieved_full) == ["rs_first", "msg_second"]
assert _item_ids(retrieved_limited) == ["rs_first", "msg_second"]