Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 82 additions & 79 deletions src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import json
import logging
import threading
from contextlib import closing
from pathlib import Path
from typing import Any, Union, cast
Expand Down Expand Up @@ -146,7 +145,7 @@ def _get_all_items_sync():
"""Synchronous helper to get all items for a branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
with closing(conn.cursor()) as cursor:
if limit is None:
cursor.execute(
Expand Down Expand Up @@ -191,7 +190,7 @@ def _get_items_sync():
"""Synchronous helper to get items for a specific branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
with closing(conn.cursor()) as cursor:
# Get message IDs in correct order for this branch
if limit is None:
Expand Down Expand Up @@ -261,18 +260,19 @@ def _get_next_turn_number(self, branch_id: str) -> int:
The next available turn number for the specified branch.
"""
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(user_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, branch_id),
)
result = cursor.fetchone()
max_turn = result[0] if result else 0
return max_turn + 1
with self._lock:
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(user_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, branch_id),
)
result = cursor.fetchone()
max_turn = result[0] if result else 0
return max_turn + 1

def _get_next_branch_turn_number(self, branch_id: str) -> int:
"""Get the next branch turn number for a specific branch.
Expand All @@ -284,18 +284,19 @@ def _get_next_branch_turn_number(self, branch_id: str) -> int:
The next available branch turn number for the specified branch.
"""
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(branch_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, branch_id),
)
result = cursor.fetchone()
max_turn = result[0] if result else 0
return max_turn + 1
with self._lock:
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(branch_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, branch_id),
)
result = cursor.fetchone()
max_turn = result[0] if result else 0
return max_turn + 1

def _get_current_turn_number(self) -> int:
"""Get the current turn number for the current branch.
Expand All @@ -304,17 +305,18 @@ def _get_current_turn_number(self) -> int:
The current turn number for the active branch.
"""
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(user_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, self._current_branch_id),
)
result = cursor.fetchone()
return result[0] if result else 0
with self._lock:
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT COALESCE(MAX(user_turn_number), 0)
FROM message_structure
WHERE session_id = ? AND branch_id = ?
""",
(self.session_id, self._current_branch_id),
)
result = cursor.fetchone()
return result[0] if result else 0

async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None:
"""Extract structure metadata with branch-aware turn tracking.
Expand All @@ -333,7 +335,7 @@ def _add_structure_sync():
"""Synchronous helper to add structure metadata to database."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
# Get the IDs of messages we just inserted, in order
with closing(conn.cursor()) as cursor:
cursor.execute(
Expand Down Expand Up @@ -439,7 +441,7 @@ def _cleanup_sync():
"""Synchronous helper to cleanup orphaned messages."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
with closing(conn.cursor()) as cursor:
# Find messages without structure metadata
cursor.execute(
Expand Down Expand Up @@ -586,7 +588,7 @@ def _validate_turn():
except Exception:
return "Unable to parse content"

turn_content = await asyncio.to_thread(_validate_turn)
turn_content = await self._to_thread_with_lock(_validate_turn)

# Generate branch name if not provided
if branch_name is None:
Expand Down Expand Up @@ -655,7 +657,7 @@ def _validate_branch():
if count == 0:
raise ValueError(f"Branch '{branch_id}' does not exist")

await asyncio.to_thread(_validate_branch)
await self._to_thread_with_lock(_validate_branch)

old_branch = self._current_branch_id
self._current_branch_id = branch_id
Expand Down Expand Up @@ -694,7 +696,7 @@ def _delete_sync():
"""Synchronous helper to delete branch and associated data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
with closing(conn.cursor()) as cursor:
# First verify the branch exists
cursor.execute(
Expand Down Expand Up @@ -756,36 +758,37 @@ async def list_branches(self) -> list[dict[str, Any]]:
def _list_branches_sync():
"""Synchronous helper to list all branches."""
conn = self._get_connection()
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT
ms.branch_id,
COUNT(*) as message_count,
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
MIN(ms.created_at) as created_at
FROM message_structure ms
WHERE ms.session_id = ?
GROUP BY ms.branch_id
ORDER BY created_at
""",
(self.session_id,),
)

branches = []
for row in cursor.fetchall():
branch_id, msg_count, user_turns, created_at = row
branches.append(
{
"branch_id": branch_id,
"message_count": msg_count,
"user_turns": user_turns,
"is_current": branch_id == self._current_branch_id,
"created_at": created_at,
}
with self._lock:
with closing(conn.cursor()) as cursor:
cursor.execute(
"""
SELECT
ms.branch_id,
COUNT(*) as message_count,
COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns,
MIN(ms.created_at) as created_at
FROM message_structure ms
WHERE ms.session_id = ?
GROUP BY ms.branch_id
ORDER BY created_at
""",
(self.session_id,),
)

return branches
branches = []
for row in cursor.fetchall():
branch_id, msg_count, user_turns, created_at = row
branches.append(
{
"branch_id": branch_id,
"message_count": msg_count,
"user_turns": user_turns,
"is_current": branch_id == self._current_branch_id,
"created_at": created_at,
}
)

return branches

return await asyncio.to_thread(_list_branches_sync)

Expand All @@ -801,7 +804,7 @@ def _copy_sync():
"""Synchronous helper to copy messages to new branch."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
with closing(conn.cursor()) as cursor:
# Get all messages before the branch point
cursor.execute(
Expand Down Expand Up @@ -928,7 +931,7 @@ def _get_turns_sync():

return turns

return await asyncio.to_thread(_get_turns_sync)
return await self._to_thread_with_lock(_get_turns_sync)

async def find_turns_by_content(
self, search_term: str, branch_id: str | None = None
Expand Down Expand Up @@ -984,7 +987,7 @@ def _search_sync():

return matches

return await asyncio.to_thread(_search_sync)
return await self._to_thread_with_lock(_search_sync)

async def get_conversation_by_turns(
self, branch_id: str | None = None
Expand Down Expand Up @@ -1022,7 +1025,7 @@ def _get_conversation_sync():
turns[turn_num].append({"type": msg_type, "tool_name": tool_name})
return turns

return await asyncio.to_thread(_get_conversation_sync)
return await self._to_thread_with_lock(_get_conversation_sync)

async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]:
"""Get all tool usage by turn for specified branch.
Expand Down Expand Up @@ -1056,7 +1059,7 @@ def _get_tool_usage_sync():
)
return cursor.fetchall()

return await asyncio.to_thread(_get_tool_usage_sync)
return await self._to_thread_with_lock(_get_tool_usage_sync)

async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None:
"""Get cumulative usage for session or specific branch.
Expand All @@ -1072,7 +1075,7 @@ def _get_usage_sync():
"""Synchronous helper to get session usage data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
if branch_id:
# Branch-specific usage
query = """
Expand Down Expand Up @@ -1220,7 +1223,7 @@ def _get_turn_usage_sync():
)
return results

result = await asyncio.to_thread(_get_turn_usage_sync)
result = await self._to_thread_with_lock(_get_turn_usage_sync)

return cast(Union[list[dict[str, Any]], dict[str, Any]], result)

Expand All @@ -1236,7 +1239,7 @@ def _update_sync():
"""Synchronous helper to update turn usage data."""
conn = self._get_connection()
# TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501
with self._lock if self._is_memory_db else threading.Lock():
with self._lock:
# Serialize token details as JSON
input_details_json = None
output_details_json = None
Expand Down
Loading