Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/fraiseql/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,57 @@ def __init__(self, pool: AsyncConnectionPool, context: Optional[dict[str, Any]]
# Get query timeout from context or use default (30 seconds)
self.query_timeout = self.context.get("query_timeout", 30)

async def _set_session_variables(self, cursor_or_conn) -> None:
"""Set PostgreSQL session variables from context.

Sets app.tenant_id and app.contact_id session variables if present in context.
Uses SET LOCAL to scope variables to the current transaction.

Args:
cursor_or_conn: Either a psycopg cursor or an asyncpg connection
"""
from psycopg.sql import SQL, Literal

# Check if this is a cursor (psycopg) or connection (asyncpg)
is_cursor = hasattr(cursor_or_conn, "execute") and hasattr(cursor_or_conn, "fetchone")

if "tenant_id" in self.context:
if is_cursor:
await cursor_or_conn.execute(
SQL("SET LOCAL app.tenant_id = {}").format(
Literal(str(self.context["tenant_id"]))
)
)
else:
# asyncpg connection
await cursor_or_conn.execute(
"SET LOCAL app.tenant_id = $1", str(self.context["tenant_id"])
)

if "contact_id" in self.context:
if is_cursor:
await cursor_or_conn.execute(
SQL("SET LOCAL app.contact_id = {}").format(
Literal(str(self.context["contact_id"]))
)
)
else:
# asyncpg connection
await cursor_or_conn.execute(
"SET LOCAL app.contact_id = $1", str(self.context["contact_id"])
)
elif "user" in self.context:
# Fallback to 'user' if 'contact_id' not set
if is_cursor:
await cursor_or_conn.execute(
SQL("SET LOCAL app.contact_id = {}").format(Literal(str(self.context["user"])))
)
else:
# asyncpg connection
await cursor_or_conn.execute(
"SET LOCAL app.contact_id = $1", str(self.context["user"])
)

async def run(self, query: DatabaseQuery) -> list[dict[str, object]]:
"""Execute a SQL query using a connection from the pool.

Expand All @@ -88,6 +139,9 @@ async def run(self, query: DatabaseQuery) -> list[dict[str, object]]:
f"SET LOCAL statement_timeout = '{timeout_ms}ms'",
)

# Set session variables from context
await self._set_session_variables(cursor)

# Handle statement execution based on type and parameter presence
if isinstance(query.statement, Composed) and not query.params:
# Composed objects without params have only embedded literals
Expand Down Expand Up @@ -184,6 +238,9 @@ async def execute_function(
f"SET LOCAL statement_timeout = '{timeout_ms}ms'",
)

# Set session variables from context
await self._set_session_variables(cursor)

# Validate function name to prevent SQL injection
if not function_name.replace("_", "").replace(".", "").isalnum():
msg = f"Invalid function name: {function_name}"
Expand Down Expand Up @@ -264,6 +321,9 @@ async def execute_function_with_context(
f"SET LOCAL statement_timeout = '{timeout_ms}ms'",
)

# Set session variables from context
await self._set_session_variables(cursor)

await cursor.execute(
f"SELECT * FROM {function_name}({placeholders})",
tuple(params),
Expand All @@ -290,6 +350,9 @@ async def execute_function_with_context(
schema="pg_catalog",
)

# Set session variables from context
await self._set_session_variables(conn)

result = await conn.fetchrow(
f"SELECT * FROM {function_name}({placeholders})",
*params,
Expand Down Expand Up @@ -500,6 +563,9 @@ async def find_one(self, view_name: str, **kwargs) -> Optional[dict[str, Any]]:
f"SET LOCAL statement_timeout = '{timeout_ms}ms'",
)

# Set session variables from context
await self._set_session_variables(cursor)

# If we have a Composed statement with embedded Literals, execute without params
if isinstance(query.statement, (Composed, SQL)) and not query.params:
await cursor.execute(query.statement)
Expand Down Expand Up @@ -534,6 +600,9 @@ async def find_one(self, view_name: str, **kwargs) -> Optional[dict[str, Any]]:
f"SET LOCAL statement_timeout = '{timeout_ms}ms'",
)

# Set session variables from context
await self._set_session_variables(cursor)

# If we have a Composed statement with embedded Literals, execute without params
if isinstance(query.statement, (Composed, SQL)) and not query.params:
await cursor.execute(query.statement)
Expand Down
Loading
Loading