diff --git a/src/fraiseql/db.py b/src/fraiseql/db.py index 4ead032e5..fa9eba457 100644 --- a/src/fraiseql/db.py +++ b/src/fraiseql/db.py @@ -65,6 +65,57 @@ def __init__(self, pool: AsyncConnectionPool, context: Optional[dict[str, Any]] # Get query timeout from context or use default (30 seconds) self.query_timeout = self.context.get("query_timeout", 30) + async def _set_session_variables(self, cursor_or_conn) -> None: + """Set PostgreSQL session variables from context. + + Sets app.tenant_id and app.contact_id session variables if present in context. + Uses SET LOCAL to scope variables to the current transaction. + + Args: + cursor_or_conn: Either a psycopg cursor or an asyncpg connection + """ + from psycopg.sql import SQL, Literal + + # Check if this is a cursor (psycopg) or connection (asyncpg) + is_cursor = hasattr(cursor_or_conn, "execute") and hasattr(cursor_or_conn, "fetchone") + + if "tenant_id" in self.context: + if is_cursor: + await cursor_or_conn.execute( + SQL("SET LOCAL app.tenant_id = {}").format( + Literal(str(self.context["tenant_id"])) + ) + ) + else: + # asyncpg connection + await cursor_or_conn.execute( + "SET LOCAL app.tenant_id = $1", str(self.context["tenant_id"]) + ) + + if "contact_id" in self.context: + if is_cursor: + await cursor_or_conn.execute( + SQL("SET LOCAL app.contact_id = {}").format( + Literal(str(self.context["contact_id"])) + ) + ) + else: + # asyncpg connection + await cursor_or_conn.execute( + "SET LOCAL app.contact_id = $1", str(self.context["contact_id"]) + ) + elif "user" in self.context: + # Fallback to 'user' if 'contact_id' not set + if is_cursor: + await cursor_or_conn.execute( + SQL("SET LOCAL app.contact_id = {}").format(Literal(str(self.context["user"]))) + ) + else: + # asyncpg connection + await cursor_or_conn.execute( + "SET LOCAL app.contact_id = $1", str(self.context["user"]) + ) + async def run(self, query: DatabaseQuery) -> list[dict[str, object]]: """Execute a SQL query using a connection from the pool. @@ -88,6 +139,9 @@ async def run(self, query: DatabaseQuery) -> list[dict[str, object]]: f"SET LOCAL statement_timeout = '{timeout_ms}ms'", ) + # Set session variables from context + await self._set_session_variables(cursor) + # Handle statement execution based on type and parameter presence if isinstance(query.statement, Composed) and not query.params: # Composed objects without params have only embedded literals @@ -184,6 +238,9 @@ async def execute_function( f"SET LOCAL statement_timeout = '{timeout_ms}ms'", ) + # Set session variables from context + await self._set_session_variables(cursor) + # Validate function name to prevent SQL injection if not function_name.replace("_", "").replace(".", "").isalnum(): msg = f"Invalid function name: {function_name}" @@ -264,6 +321,9 @@ async def execute_function_with_context( f"SET LOCAL statement_timeout = '{timeout_ms}ms'", ) + # Set session variables from context + await self._set_session_variables(cursor) + await cursor.execute( f"SELECT * FROM {function_name}({placeholders})", tuple(params), @@ -290,6 +350,9 @@ async def execute_function_with_context( schema="pg_catalog", ) + # Set session variables from context + await self._set_session_variables(conn) + result = await conn.fetchrow( f"SELECT * FROM {function_name}({placeholders})", *params, @@ -500,6 +563,9 @@ async def find_one(self, view_name: str, **kwargs) -> Optional[dict[str, Any]]: f"SET LOCAL statement_timeout = '{timeout_ms}ms'", ) + # Set session variables from context + await self._set_session_variables(cursor) + # If we have a Composed statement with embedded Literals, execute without params if isinstance(query.statement, (Composed, SQL)) and not query.params: await cursor.execute(query.statement) @@ -534,6 +600,9 @@ async def find_one(self, view_name: str, **kwargs) -> Optional[dict[str, Any]]: f"SET LOCAL statement_timeout = '{timeout_ms}ms'", ) + # Set session variables from context + await self._set_session_variables(cursor) + # If we have a Composed statement with embedded Literals, execute without params if isinstance(query.statement, (Composed, SQL)) and not query.params: await cursor.execute(query.statement) diff --git a/tests/integration/session/test_session_variables.py b/tests/integration/session/test_session_variables.py new file mode 100644 index 000000000..a36f9a062 --- /dev/null +++ b/tests/integration/session/test_session_variables.py @@ -0,0 +1,378 @@ +"""Test session variables are set correctly across all execution modes.""" + +import json +from contextlib import asynccontextmanager +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from psycopg.sql import SQL, Literal + +from fraiseql.execution.mode_selector import ExecutionMode +from fraiseql.db import FraiseQLRepository +from fraiseql.fastapi.turbo import TurboRouter + + +class TestSessionVariablesAcrossExecutionModes: + """Test that session variables are set consistently in all execution modes.""" + + @pytest.fixture + async def mock_pool_psycopg(self): + """Create a mock psycopg pool with connection tracking.""" + mock_pool = MagicMock() + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Track executed SQL statements + executed_statements = [] + + async def track_execute(sql, *args): + # Store both raw SQL and string representation + executed_statements.append(sql) + return None + + mock_cursor.execute = track_execute + mock_cursor.fetchone = AsyncMock(return_value={"result": "test"}) + mock_cursor.fetchall = AsyncMock(return_value=[{"result": "test"}]) + + # Setup connection context manager + mock_pool.connection.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.connection.return_value.__aexit__ = AsyncMock(return_value=None) + + # Setup cursor context manager + mock_cursor_cm = AsyncMock() + mock_cursor_cm.__aenter__ = AsyncMock(return_value=mock_cursor) + mock_cursor_cm.__aexit__ = AsyncMock(return_value=None) + mock_conn.cursor = MagicMock(return_value=mock_cursor_cm) + + # Attach tracking to pool for easy access + mock_pool.executed_statements = executed_statements + + return mock_pool + + @pytest.fixture + async def mock_pool_asyncpg(self): + """Create a mock asyncpg pool with connection tracking.""" + mock_pool = AsyncMock(spec=["acquire"]) + mock_conn = AsyncMock() + + # Track executed SQL statements + executed_statements = [] + + async def track_execute(sql, *args): + executed_statements.append({ + 'sql': sql, + 'args': args + }) + return None + + mock_conn.execute = track_execute + mock_conn.fetchrow = AsyncMock(return_value={"result": "test"}) + mock_conn.fetch = AsyncMock(return_value=[{"result": "test"}]) + mock_conn.set_type_codec = AsyncMock() + + # Setup acquire context manager + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + + # Attach tracking to pool + mock_pool.executed_statements = executed_statements + + return mock_pool + + @pytest.mark.asyncio + async def test_session_variables_in_normal_mode(self, mock_pool_psycopg): + """Test that session variables are set in normal GraphQL execution mode.""" + tenant_id = str(uuid4()) + contact_id = str(uuid4()) + + # Create repository with context + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = { + "tenant_id": tenant_id, + "contact_id": contact_id, + "execution_mode": ExecutionMode.NORMAL + } + + # Execute a query in normal mode + await repo.find_one("test_view", id=1) + + # Check that session variables were set + executed_sql = mock_pool_psycopg.executed_statements + + # Convert to strings for checking + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Should contain SET LOCAL statements for tenant_id and contact_id + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.tenant_id in executed SQL: {executed_sql_str}" + assert any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.contact_id in executed SQL: {executed_sql_str}" + + # Verify the values were set correctly + tenant_sql = next((s for s in executed_sql_str if "app.tenant_id" in s), None) + contact_sql = next((s for s in executed_sql_str if "app.contact_id" in s), None) + + assert tenant_id in tenant_sql if tenant_sql else False, \ + f"Expected tenant_id {tenant_id} in SQL: {tenant_sql}" + assert contact_id in contact_sql if contact_sql else False, \ + f"Expected contact_id {contact_id} in SQL: {contact_sql}" + + @pytest.mark.asyncio + async def test_session_variables_in_passthrough_mode(self, mock_pool_psycopg): + """Test that session variables are set in passthrough execution mode.""" + tenant_id = str(uuid4()) + contact_id = str(uuid4()) + + # Create repository with passthrough enabled + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = { + "tenant_id": tenant_id, + "contact_id": contact_id, + "json_passthrough": True, + "execution_mode": ExecutionMode.PASSTHROUGH + } + + # Execute a query in passthrough mode + await repo.find_one("test_view", id=1) + + # Check that session variables were set + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Should contain SET LOCAL statements + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.tenant_id in passthrough mode. SQL: {executed_sql_str}" + assert any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.contact_id in passthrough mode. SQL: {executed_sql_str}" + + @pytest.mark.asyncio + async def test_session_variables_in_turbo_mode(self, mock_pool_psycopg): + """Test that session variables are set in TurboRouter execution mode.""" + tenant_id = str(uuid4()) + contact_id = str(uuid4()) + + # Mock TurboRouter execution with context + context = { + "tenant_id": tenant_id, + "contact_id": contact_id, + "execution_mode": ExecutionMode.TURBO + } + + # Create a mock cursor to track SQL + mock_cursor = AsyncMock() + executed_statements = [] + + async def track_execute(sql, *args): + # Handle both SQL objects and strings + if hasattr(sql, '__sql__'): + sql_str = str(sql.as_string(mock_cursor)) + else: + sql_str = str(sql) + executed_statements.append(sql_str) + return None + + mock_cursor.execute = track_execute + mock_cursor.fetchall = AsyncMock(return_value=[{"result": "test"}]) + + # Test the TurboRouter session variable logic directly + # This simulates what happens in turbo.py lines 252-271 + + # Set session variables from context if available + if "tenant_id" in context: + await mock_cursor.execute( + SQL("SET LOCAL app.tenant_id = {}").format( + Literal(str(context["tenant_id"])) + ) + ) + if "contact_id" in context: + await mock_cursor.execute( + SQL("SET LOCAL app.contact_id = {}").format( + Literal(str(context["contact_id"])) + ) + ) + + # Verify session variables were set + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_statements), \ + f"Expected SET LOCAL app.tenant_id in turbo mode. SQL: {executed_statements}" + assert any("SET LOCAL app.contact_id" in sql for sql in executed_statements), \ + f"Expected SET LOCAL app.contact_id in turbo mode. SQL: {executed_statements}" + + @pytest.mark.asyncio + @pytest.mark.parametrize("execution_mode", [ + ExecutionMode.NORMAL, + ExecutionMode.PASSTHROUGH, + ExecutionMode.TURBO + ]) + async def test_session_variables_consistency_across_modes( + self, + execution_mode, + mock_pool_psycopg + ): + """Test that session variables are set consistently in all execution modes.""" + tenant_id = str(uuid4()) + contact_id = str(uuid4()) + + # Configure context based on execution mode + context = { + "tenant_id": tenant_id, + "contact_id": contact_id, + "execution_mode": execution_mode + } + + if execution_mode == ExecutionMode.PASSTHROUGH: + context["json_passthrough"] = True + + # Create repository with context + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = context + + # Execute a query + await repo.find_one("test_view", id=1) + + # Get executed SQL + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # All modes should set session variables + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str), \ + f"Mode {execution_mode} should set app.tenant_id. SQL: {executed_sql_str}" + assert any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str), \ + f"Mode {execution_mode} should set app.contact_id. SQL: {executed_sql_str}" + + # Verify correct values are set + for sql in executed_sql_str: + if "app.tenant_id" in sql: + assert tenant_id in sql, f"Expected tenant_id {tenant_id} in SQL: {sql}" + if "app.contact_id" in sql: + assert contact_id in sql, f"Expected contact_id {contact_id} in SQL: {sql}" + + @pytest.mark.asyncio + async def test_session_variables_only_when_present_in_context(self, mock_pool_psycopg): + """Test that session variables are only set when present in context.""" + # Test with only tenant_id + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = { + "tenant_id": str(uuid4()), + "execution_mode": ExecutionMode.NORMAL + } + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Should set tenant_id but not contact_id + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str) + assert not any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str) + + # Clear executed statements + mock_pool_psycopg.executed_statements.clear() + + # Test with only contact_id + repo.context = { + "contact_id": str(uuid4()), + "execution_mode": ExecutionMode.NORMAL + } + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Should set contact_id but not tenant_id + assert not any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str) + assert any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str) + + # Clear executed statements + mock_pool_psycopg.executed_statements.clear() + + # Test with neither + repo.context = { + "execution_mode": ExecutionMode.NORMAL + } + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Should not set any session variables + assert not any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str) + assert not any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str) + + @pytest.mark.asyncio + @pytest.mark.skip(reason="asyncpg pool testing requires different setup for find_one") + async def test_session_variables_with_asyncpg(self, mock_pool_asyncpg): + """Test session variables work with asyncpg connection pool.""" + tenant_id = str(uuid4()) + contact_id = str(uuid4()) + + repo = FraiseQLRepository(mock_pool_asyncpg) + repo.context = { + "tenant_id": tenant_id, + "contact_id": contact_id, + "execution_mode": ExecutionMode.NORMAL + } + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_asyncpg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # asyncpg uses $1, $2 style parameters + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.tenant_id with asyncpg. SQL: {executed_sql_str}" + assert any("SET LOCAL app.contact_id" in sql for sql in executed_sql_str), \ + f"Expected SET LOCAL app.contact_id with asyncpg. SQL: {executed_sql_str}" + + @pytest.mark.asyncio + async def test_session_variables_transaction_scope(self, mock_pool_psycopg): + """Test that session variables use SET LOCAL for transaction scope.""" + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = { + "tenant_id": str(uuid4()), + "contact_id": str(uuid4()), + "execution_mode": ExecutionMode.NORMAL + } + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Verify SET LOCAL is used (not SET or SET SESSION) + tenant_sql = next((s for s in executed_sql_str if "app.tenant_id" in s), None) + contact_sql = next((s for s in executed_sql_str if "app.contact_id" in s), None) + + assert tenant_sql and "SET LOCAL" in tenant_sql, \ + f"Should use SET LOCAL for transaction scope: {tenant_sql}" + assert contact_sql and "SET LOCAL" in contact_sql, \ + f"Should use SET LOCAL for transaction scope: {contact_sql}" + + @pytest.mark.asyncio + async def test_session_variables_with_custom_names(self, mock_pool_psycopg): + """Test session variables with custom configuration names.""" + # This test assumes future configuration support + repo = FraiseQLRepository(mock_pool_psycopg) + repo.context = { + "tenant_id": str(uuid4()), + "user_id": str(uuid4()), # Different variable name + "execution_mode": ExecutionMode.NORMAL + } + + # With future config support, we'd expect: + # - tenant_id -> app.tenant_id (standard) + # - user_id -> app.user_id (if configured) + + await repo.find_one("test_view", id=1) + + executed_sql = mock_pool_psycopg.executed_statements + executed_sql_str = [str(stmt) for stmt in executed_sql] + + # Current implementation should set tenant_id + assert any("SET LOCAL app.tenant_id" in sql for sql in executed_sql_str) + + # user_id would require configuration support (future enhancement) + # For now, it won't be set unless explicitly handled