Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/fraiseql/sql/where_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@
class DynamicType(Protocol):
"""Protocol for dynamic filter types convertible to SQL WHERE clause strings."""

def to_sql(self) -> Composed | None:
def to_sql(self, parent_path: str | None = None) -> Composed | None:
"""Return a properly parameterized SQL snippet representing this filter.

Args:
parent_path: Optional JSONB path from parent for nested objects.

Returns:
A psycopg Composed object with parameterized SQL, or None if no condition.
"""
Expand Down Expand Up @@ -70,6 +73,21 @@ def build_operator_composed(
return registry.build_sql(path_sql, op, val, field_type)


def _build_nested_path(parent_path: str | None, field_name: str) -> str:
"""Build a JSONB path for nested object fields.

Args:
parent_path: The parent JSONB path (e.g., "data -> 'parent'")
field_name: The field name to append to the path

Returns:
A JSONB path string for the nested field
"""
if parent_path:
return f"{parent_path} -> '{field_name}'"
return f"data -> '{field_name}'"


def _make_filter_field_composed(
name: str,
valdict: dict[str, object],
Expand Down Expand Up @@ -118,7 +136,7 @@ def _build_where_to_sql(
fields: list[str],
type_hints: dict[str, type] | None = None,
graphql_info: Any | None = None,
) -> Callable[[object], Composed | None]:
) -> Callable[[object, str | None], Composed | None]:
"""Build a `to_sql` method for a dynamic filter dataclass.

Args:
Expand All @@ -127,10 +145,10 @@ def _build_where_to_sql(
graphql_info: Optional GraphQL resolve info context for field type extraction.

Returns:
A function suitable as a `to_sql(self)` method returning Composed SQL.
A function suitable as a `to_sql(self, parent_path)` method returning Composed SQL.
"""

def to_sql(self: object) -> Composed | None:
def to_sql(self: object, parent_path: str | None = None) -> Composed | None:
# Enhance type hints with GraphQL context if available
enhanced_type_hints = type_hints
if graphql_info:
Expand Down Expand Up @@ -161,33 +179,36 @@ def to_sql(self: object) -> Composed | None:
if isinstance(val, list):
for item in val:
if hasattr(item, "to_sql"):
item_sql = item.to_sql()
item_sql = item.to_sql(parent_path)
if item_sql:
logical_or.append(item_sql)
elif name == "AND":
if isinstance(val, list):
for item in val:
if hasattr(item, "to_sql"):
item_sql = item.to_sql()
item_sql = item.to_sql(parent_path)
if item_sql:
logical_and.append(item_sql)
elif name == "NOT":
if hasattr(val, "to_sql"):
not_sql = val.to_sql()
not_sql = val.to_sql(parent_path)
if not_sql:
logical_not = Composed([SQL("NOT ("), not_sql, SQL(")")])
# Handle regular fields
elif hasattr(val, "to_sql"):
# Assume val is another DynamicType
sql = val.to_sql()
# For nested objects, build the JSONB path by appending the field name
nested_path = _build_nested_path(parent_path, name)
sql = val.to_sql(nested_path)
if sql:
conditions.append(sql)
elif isinstance(val, dict):
field_type = enhanced_type_hints.get(name) if enhanced_type_hints else None
# Use parent_path if provided, otherwise default to "data"
json_path = parent_path if parent_path else "data"
cond = _make_filter_field_composed(
name,
cast("dict[str, object]", val),
"data",
json_path,
field_type,
)
if cond:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def test_nested_filter_conversion_to_sql(self):
AllocationWhereInput = create_graphql_where_input(Allocation)

# Create a nested filter
test_machine_id = uuid.uuid4()
where_input = AllocationWhereInput(
machine=MachineWhereInput(
is_current=BooleanFilter(eq=True), name=StringFilter(contains="Server")
id=UUIDFilter(eq=test_machine_id),
is_current=BooleanFilter(eq=True),
name=StringFilter(contains="Server")
),
status=StringFilter(eq="active"),
)
Expand All @@ -60,10 +63,27 @@ def test_nested_filter_conversion_to_sql(self):
assert sql_where.machine is not None
assert sql_where.status == {"eq": "active"}

# Generate SQL to ensure it doesn't error
# Generate SQL and validate its correctness
sql = sql_where.to_sql()
assert sql is not None

# To properly check the generated SQL, we need to examine the SQL components
# Check that the nested path is correctly constructed as SQL("data -> 'machine'")
sql_str = str(sql)

# The SQL object should contain the nested path for machine fields
# Looking for SQL("data -> 'machine'") in the representation
assert 'SQL("data -> \'machine\'")' in sql_str, \
f"Expected nested JSONB path for machine fields, but got: {sql_str}"

# Root level status filter should just use 'data'
# Count occurrences - should have both nested and root level paths
assert sql_str.count('SQL("data -> \'machine\'")') == 3, \
f"Expected 3 nested machine paths (for id, name, is_current), but got: {sql_str}"

assert 'SQL(\'data\')' in sql_str, \
f"Expected root-level data access for status field, but got: {sql_str}"

def test_nested_filter_with_none_values(self):
"""Test that None values in nested filters are handled correctly."""
create_graphql_where_input(Machine)
Expand Down Expand Up @@ -118,6 +138,20 @@ class AllocationDeep:
assert hasattr(sql_where, "machine")
assert sql_where.machine is not None

# Generate SQL and verify deep nesting paths
sql = sql_where.to_sql()
assert sql is not None
sql_str = str(sql)

# Check that deeply nested paths are correctly generated
# Machine name should be at: data -> 'machine' ->> 'name'
assert 'SQL("data -> \'machine\'")' in sql_str, \
f"Expected nested path for machine.name, but got: {sql_str}"

# Location city should be at: data -> 'machine' -> 'location' ->> 'city'
assert 'SQL("data -> \'machine\' -> \'location\'")' in sql_str, \
f"Expected deeply nested path for machine.location.city, but got: {sql_str}"

def test_mixed_scalar_and_nested_filters(self):
"""Test mixing scalar and nested object filters."""
MachineWhereInput = create_graphql_where_input(Machine)
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/database/sql/test_where_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,8 @@ class Parent:

# Validate complete SQL - adjusted for our casting approach
assert "((data ->> 'id'))::numeric = 1" in sql_str
assert "(data ->> 'name') = 'test'" in sql_str
# Child's name should now be accessed via nested path: data -> 'child' ->> 'name'
assert "(data -> 'child' ->> 'name') = 'test'" in sql_str


class TestEdgeCases:
Expand Down
Loading