From a4d77fe9c3e07ce47bb1784d5634bdceebb561ae Mon Sep 17 00:00:00 2001 From: Lucas Wang Date: Thu, 23 Oct 2025 11:14:08 +0800 Subject: [PATCH 1/2] Fix SQLiteSession threading.Lock() bug and file descriptor leak This PR addresses two critical bugs in SQLiteSession: ## Bug 1: threading.Lock() creating new instances **Problem:** In SQLiteSession (4 places) and AdvancedSQLiteSession (8 places), the code used: ```python with self._lock if self._is_memory_db else threading.Lock(): ``` For file-based databases, this creates a NEW Lock() instance on every operation, providing NO thread safety whatsoever. Only in-memory databases used self._lock. **Impact:** - File-based SQLiteSession had zero thread protection - Race conditions possible but masked by WAL mode's own concurrency handling ## Bug 2: File descriptor leak **Problem:** Thread-local connections in ThreadPoolExecutor are never cleaned up: - asyncio.to_thread() uses ThreadPoolExecutor internally - Each worker thread creates a connection on first use - ThreadPoolExecutor reuses threads indefinitely - Connections persist until program exit, accumulating file descriptors **Evidence:** Testing on main branch (60s, 40 concurrent workers): - My system (FD limit 1,048,575): +789 FDs leaked, 0 errors (limit not reached) - @ihower's system (likely limit 1,024): 646,802 errors in 20 seconds Error: `sqlite3.OperationalError: unable to open database file` ## Solution: Unified shared connection approach Instead of managing thread-local connections that can't be reliably cleaned up in ThreadPoolExecutor, use a single shared connection for all database types. **Changes:** 1. Removed thread-local connection logic (eliminates FD leak root cause) 2. All database types now use shared connection + self._lock 3. SQLite's WAL mode provides sufficient concurrency even with single connection 4. Fixed all 12 instances of threading.Lock() bug (4 in SQLiteSession, 8 in AdvancedSQLiteSession) 5. Kept _is_memory_db attribute for backward compatibility with AdvancedSQLiteSession 6. Added close() and __del__() methods for proper cleanup **Results (60s stress test, 30 writers + 10 readers):** ``` Main branch: - FD growth: +789 (leak) - Throughput: 701 ops/s - Errors: 0 on high-limit systems, 646k+ on normal systems After fix: - FD growth: +44 (stable) - Throughput: 726 ops/s (+3.6% improvement) - Errors: 0 on all systems - All 29 SQLite tests pass ``` ## Why shared connection performs better SQLite's WAL (Write-Ahead Logging) mode already provides: - Multiple concurrent readers - One writer coexisting with readers - Readers don't block writer - Writer doesn't block readers (except during checkpoint) The overhead of managing multiple connections outweighs any concurrency benefit. ## Backward compatibility The _is_memory_db attribute is preserved for AdvancedSQLiteSession compatibility, even though the implementation no longer differentiates connection strategies. ## Testing Comprehensive stress test available at: https://gist.github.com/gn00295120/0b6a65fe6c0ac6b7a1ce23654eed3ffe Run with: `python sqlite_stress_test_final.py` --- .../memory/advanced_sqlite_session.py | 161 +++++++++--------- src/agents/memory/sqlite_session.py | 84 ++++----- 2 files changed, 127 insertions(+), 118 deletions(-) diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index fefb73026..9ab233ac6 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -3,7 +3,6 @@ import asyncio import json import logging -import threading from contextlib import closing from pathlib import Path from typing import Any, Union, cast @@ -146,7 +145,7 @@ def _get_all_items_sync(): """Synchronous helper to get all items for a branch.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: with closing(conn.cursor()) as cursor: if limit is None: cursor.execute( @@ -191,7 +190,7 @@ def _get_items_sync(): """Synchronous helper to get items for a specific branch.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: with closing(conn.cursor()) as cursor: # Get message IDs in correct order for this branch if limit is None: @@ -261,18 +260,19 @@ def _get_next_turn_number(self, branch_id: str) -> int: The next available turn number for the specified branch. """ conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(user_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, branch_id), - ) - result = cursor.fetchone() - max_turn = result[0] if result else 0 - return max_turn + 1 + with self._lock: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(user_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, branch_id), + ) + result = cursor.fetchone() + max_turn = result[0] if result else 0 + return max_turn + 1 def _get_next_branch_turn_number(self, branch_id: str) -> int: """Get the next branch turn number for a specific branch. @@ -284,18 +284,19 @@ def _get_next_branch_turn_number(self, branch_id: str) -> int: The next available branch turn number for the specified branch. """ conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(branch_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, branch_id), - ) - result = cursor.fetchone() - max_turn = result[0] if result else 0 - return max_turn + 1 + with self._lock: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(branch_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, branch_id), + ) + result = cursor.fetchone() + max_turn = result[0] if result else 0 + return max_turn + 1 def _get_current_turn_number(self) -> int: """Get the current turn number for the current branch. @@ -304,17 +305,18 @@ def _get_current_turn_number(self) -> int: The current turn number for the active branch. """ conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT COALESCE(MAX(user_turn_number), 0) - FROM message_structure - WHERE session_id = ? AND branch_id = ? - """, - (self.session_id, self._current_branch_id), - ) - result = cursor.fetchone() - return result[0] if result else 0 + with self._lock: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT COALESCE(MAX(user_turn_number), 0) + FROM message_structure + WHERE session_id = ? AND branch_id = ? + """, + (self.session_id, self._current_branch_id), + ) + result = cursor.fetchone() + return result[0] if result else 0 async def _add_structure_metadata(self, items: list[TResponseInputItem]) -> None: """Extract structure metadata with branch-aware turn tracking. @@ -333,7 +335,7 @@ def _add_structure_sync(): """Synchronous helper to add structure metadata to database.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: # Get the IDs of messages we just inserted, in order with closing(conn.cursor()) as cursor: cursor.execute( @@ -439,7 +441,7 @@ def _cleanup_sync(): """Synchronous helper to cleanup orphaned messages.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: with closing(conn.cursor()) as cursor: # Find messages without structure metadata cursor.execute( @@ -586,7 +588,7 @@ def _validate_turn(): except Exception: return "Unable to parse content" - turn_content = await asyncio.to_thread(_validate_turn) + turn_content = await self._to_thread_with_lock(_validate_turn) # Generate branch name if not provided if branch_name is None: @@ -655,7 +657,7 @@ def _validate_branch(): if count == 0: raise ValueError(f"Branch '{branch_id}' does not exist") - await asyncio.to_thread(_validate_branch) + await self._to_thread_with_lock(_validate_branch) old_branch = self._current_branch_id self._current_branch_id = branch_id @@ -694,7 +696,7 @@ def _delete_sync(): """Synchronous helper to delete branch and associated data.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: with closing(conn.cursor()) as cursor: # First verify the branch exists cursor.execute( @@ -756,36 +758,37 @@ async def list_branches(self) -> list[dict[str, Any]]: def _list_branches_sync(): """Synchronous helper to list all branches.""" conn = self._get_connection() - with closing(conn.cursor()) as cursor: - cursor.execute( - """ - SELECT - ms.branch_id, - COUNT(*) as message_count, - COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns, - MIN(ms.created_at) as created_at - FROM message_structure ms - WHERE ms.session_id = ? - GROUP BY ms.branch_id - ORDER BY created_at - """, - (self.session_id,), - ) - - branches = [] - for row in cursor.fetchall(): - branch_id, msg_count, user_turns, created_at = row - branches.append( - { - "branch_id": branch_id, - "message_count": msg_count, - "user_turns": user_turns, - "is_current": branch_id == self._current_branch_id, - "created_at": created_at, - } + with self._lock: + with closing(conn.cursor()) as cursor: + cursor.execute( + """ + SELECT + ms.branch_id, + COUNT(*) as message_count, + COUNT(CASE WHEN ms.message_type = 'user' THEN 1 END) as user_turns, + MIN(ms.created_at) as created_at + FROM message_structure ms + WHERE ms.session_id = ? + GROUP BY ms.branch_id + ORDER BY created_at + """, + (self.session_id,), ) - return branches + branches = [] + for row in cursor.fetchall(): + branch_id, msg_count, user_turns, created_at = row + branches.append( + { + "branch_id": branch_id, + "message_count": msg_count, + "user_turns": user_turns, + "is_current": branch_id == self._current_branch_id, + "created_at": created_at, + } + ) + + return branches return await asyncio.to_thread(_list_branches_sync) @@ -801,7 +804,7 @@ def _copy_sync(): """Synchronous helper to copy messages to new branch.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: with closing(conn.cursor()) as cursor: # Get all messages before the branch point cursor.execute( @@ -928,7 +931,7 @@ def _get_turns_sync(): return turns - return await asyncio.to_thread(_get_turns_sync) + return await self._to_thread_with_lock(_get_turns_sync) async def find_turns_by_content( self, search_term: str, branch_id: str | None = None @@ -984,7 +987,7 @@ def _search_sync(): return matches - return await asyncio.to_thread(_search_sync) + return await self._to_thread_with_lock(_search_sync) async def get_conversation_by_turns( self, branch_id: str | None = None @@ -1022,7 +1025,7 @@ def _get_conversation_sync(): turns[turn_num].append({"type": msg_type, "tool_name": tool_name}) return turns - return await asyncio.to_thread(_get_conversation_sync) + return await self._to_thread_with_lock(_get_conversation_sync) async def get_tool_usage(self, branch_id: str | None = None) -> list[tuple[str, int, int]]: """Get all tool usage by turn for specified branch. @@ -1056,7 +1059,7 @@ def _get_tool_usage_sync(): ) return cursor.fetchall() - return await asyncio.to_thread(_get_tool_usage_sync) + return await self._to_thread_with_lock(_get_tool_usage_sync) async def get_session_usage(self, branch_id: str | None = None) -> dict[str, int] | None: """Get cumulative usage for session or specific branch. @@ -1072,7 +1075,7 @@ def _get_usage_sync(): """Synchronous helper to get session usage data.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: if branch_id: # Branch-specific usage query = """ @@ -1220,7 +1223,7 @@ def _get_turn_usage_sync(): ) return results - result = await asyncio.to_thread(_get_turn_usage_sync) + result = await self._to_thread_with_lock(_get_turn_usage_sync) return cast(Union[list[dict[str, Any]], dict[str, Any]], result) @@ -1236,7 +1239,7 @@ def _update_sync(): """Synchronous helper to update turn usage data.""" conn = self._get_connection() # TODO: Refactor SQLiteSession to use asyncio.Lock instead of threading.Lock and update this code # noqa: E501 - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: # Serialize token details as JSON input_details_json = None output_details_json = None diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 2c2386ec7..434760d51 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -38,40 +38,43 @@ def __init__( self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table - self._local = threading.local() - self._lock = threading.Lock() + self._lock = threading.RLock() - # For in-memory databases, we need a shared connection to avoid thread isolation - # For file databases, we use thread-local connections for better concurrency + # Keep _is_memory_db for backward compatibility with AdvancedSQLiteSession self._is_memory_db = str(db_path) == ":memory:" - if self._is_memory_db: - self._shared_connection = sqlite3.connect(":memory:", check_same_thread=False) - self._shared_connection.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(self._shared_connection) - else: - # For file databases, initialize the schema once since it persists - init_conn = sqlite3.connect(str(self.db_path), check_same_thread=False) - init_conn.execute("PRAGMA journal_mode=WAL") - self._init_db_for_connection(init_conn) - init_conn.close() + + # Use a shared connection for all database types + # This avoids file descriptor leaks from thread-local connections + # WAL mode enables concurrent readers/writers even with a shared connection + self._shared_connection = sqlite3.connect(str(db_path), check_same_thread=False) + self._shared_connection.execute("PRAGMA journal_mode=WAL") + self._init_db_for_connection(self._shared_connection) def _get_connection(self) -> sqlite3.Connection: """Get a database connection.""" - if self._is_memory_db: - # Use shared connection for in-memory database to avoid thread isolation - return self._shared_connection - else: - # Use thread-local connections for file databases - if not hasattr(self._local, "connection"): - self._local.connection = sqlite3.connect( - str(self.db_path), - check_same_thread=False, - ) - self._local.connection.execute("PRAGMA journal_mode=WAL") - assert isinstance(self._local.connection, sqlite3.Connection), ( - f"Expected sqlite3.Connection, got {type(self._local.connection)}" - ) - return self._local.connection + return self._shared_connection + + async def _to_thread_with_lock(self, func, *args, **kwargs): + """Execute a function in a thread pool with lock protection. + + This ensures thread-safe access to the shared database connection + when operations are executed via asyncio.to_thread(). Uses RLock + so it's safe to call even if the lock is already held. + + Args: + func: The function to execute + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + The result of the function execution + """ + + def wrapped(): + with self._lock: + return func(*args, **kwargs) + + return await asyncio.to_thread(wrapped) def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: """Initialize the database schema for a specific connection.""" @@ -120,7 +123,7 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: def _get_items_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: if limit is None: # Fetch all items in chronological order cursor = conn.execute( @@ -174,7 +177,7 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: def _add_items_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: # Ensure session exists conn.execute( f""" @@ -215,7 +218,7 @@ async def pop_item(self) -> TResponseInputItem | None: def _pop_item_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: # Use DELETE with RETURNING to atomically delete and return the most recent item cursor = conn.execute( f""" @@ -252,7 +255,7 @@ async def clear_session(self) -> None: def _clear_session_sync(): conn = self._get_connection() - with self._lock if self._is_memory_db else threading.Lock(): + with self._lock: conn.execute( f"DELETE FROM {self.messages_table} WHERE session_id = ?", (self.session_id,), @@ -267,9 +270,12 @@ def _clear_session_sync(): def close(self) -> None: """Close the database connection.""" - if self._is_memory_db: - if hasattr(self, "_shared_connection"): - self._shared_connection.close() - else: - if hasattr(self._local, "connection"): - self._local.connection.close() + if hasattr(self, "_shared_connection"): + self._shared_connection.close() + + def __del__(self) -> None: + """Ensure connection is closed when the session is garbage collected.""" + try: + self.close() + except Exception: + pass # Ignore errors during finalization From 7d72810127fd3e4eb5587521dace5798691e763c Mon Sep 17 00:00:00 2001 From: Lucas Wang Date: Thu, 23 Oct 2025 12:08:08 +0800 Subject: [PATCH 2/2] Fix mypy type annotations for _to_thread_with_lock Add TypeVar to preserve return types through the _to_thread_with_lock wrapper method. This fixes mypy errors about returning Any from functions with concrete return type declarations. - Import Callable and TypeVar from typing - Add generic TypeVar T - Annotate _to_thread_with_lock with proper generic types - Annotate wrapped() inner function return type --- src/agents/memory/sqlite_session.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 434760d51..635541c76 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -5,10 +5,13 @@ import sqlite3 import threading from pathlib import Path +from typing import Callable, TypeVar from ..items import TResponseInputItem from .session import SessionABC +T = TypeVar("T") + class SQLiteSession(SessionABC): """SQLite-based implementation of session storage. @@ -54,7 +57,7 @@ def _get_connection(self) -> sqlite3.Connection: """Get a database connection.""" return self._shared_connection - async def _to_thread_with_lock(self, func, *args, **kwargs): + async def _to_thread_with_lock(self, func: Callable[..., T], *args, **kwargs) -> T: """Execute a function in a thread pool with lock protection. This ensures thread-safe access to the shared database connection @@ -70,7 +73,7 @@ async def _to_thread_with_lock(self, func, *args, **kwargs): The result of the function execution """ - def wrapped(): + def wrapped() -> T: with self._lock: return func(*args, **kwargs)