diff --git a/src/fraiseql/sql/where_generator.py b/src/fraiseql/sql/where_generator.py index 546829606..926b25dcc 100644 --- a/src/fraiseql/sql/where_generator.py +++ b/src/fraiseql/sql/where_generator.py @@ -38,9 +38,12 @@ class DynamicType(Protocol): """Protocol for dynamic filter types convertible to SQL WHERE clause strings.""" - def to_sql(self) -> Composed | None: + def to_sql(self, parent_path: str | None = None) -> Composed | None: """Return a properly parameterized SQL snippet representing this filter. + Args: + parent_path: Optional JSONB path from parent for nested objects. + Returns: A psycopg Composed object with parameterized SQL, or None if no condition. """ @@ -70,6 +73,21 @@ def build_operator_composed( return registry.build_sql(path_sql, op, val, field_type) +def _build_nested_path(parent_path: str | None, field_name: str) -> str: + """Build a JSONB path for nested object fields. + + Args: + parent_path: The parent JSONB path (e.g., "data -> 'parent'") + field_name: The field name to append to the path + + Returns: + A JSONB path string for the nested field + """ + if parent_path: + return f"{parent_path} -> '{field_name}'" + return f"data -> '{field_name}'" + + def _make_filter_field_composed( name: str, valdict: dict[str, object], @@ -118,7 +136,7 @@ def _build_where_to_sql( fields: list[str], type_hints: dict[str, type] | None = None, graphql_info: Any | None = None, -) -> Callable[[object], Composed | None]: +) -> Callable[[object, str | None], Composed | None]: """Build a `to_sql` method for a dynamic filter dataclass. Args: @@ -127,10 +145,10 @@ def _build_where_to_sql( graphql_info: Optional GraphQL resolve info context for field type extraction. Returns: - A function suitable as a `to_sql(self)` method returning Composed SQL. + A function suitable as a `to_sql(self, parent_path)` method returning Composed SQL. """ - def to_sql(self: object) -> Composed | None: + def to_sql(self: object, parent_path: str | None = None) -> Composed | None: # Enhance type hints with GraphQL context if available enhanced_type_hints = type_hints if graphql_info: @@ -161,33 +179,36 @@ def to_sql(self: object) -> Composed | None: if isinstance(val, list): for item in val: if hasattr(item, "to_sql"): - item_sql = item.to_sql() + item_sql = item.to_sql(parent_path) if item_sql: logical_or.append(item_sql) elif name == "AND": if isinstance(val, list): for item in val: if hasattr(item, "to_sql"): - item_sql = item.to_sql() + item_sql = item.to_sql(parent_path) if item_sql: logical_and.append(item_sql) elif name == "NOT": if hasattr(val, "to_sql"): - not_sql = val.to_sql() + not_sql = val.to_sql(parent_path) if not_sql: logical_not = Composed([SQL("NOT ("), not_sql, SQL(")")]) # Handle regular fields elif hasattr(val, "to_sql"): - # Assume val is another DynamicType - sql = val.to_sql() + # For nested objects, build the JSONB path by appending the field name + nested_path = _build_nested_path(parent_path, name) + sql = val.to_sql(nested_path) if sql: conditions.append(sql) elif isinstance(val, dict): field_type = enhanced_type_hints.get(name) if enhanced_type_hints else None + # Use parent_path if provided, otherwise default to "data" + json_path = parent_path if parent_path else "data" cond = _make_filter_field_composed( name, cast("dict[str, object]", val), - "data", + json_path, field_type, ) if cond: diff --git a/tests/integration/database/repository/test_nested_object_filter_integration.py b/tests/integration/database/repository/test_nested_object_filter_integration.py index 2862846c2..2df3db5cb 100644 --- a/tests/integration/database/repository/test_nested_object_filter_integration.py +++ b/tests/integration/database/repository/test_nested_object_filter_integration.py @@ -42,9 +42,12 @@ def test_nested_filter_conversion_to_sql(self): AllocationWhereInput = create_graphql_where_input(Allocation) # Create a nested filter + test_machine_id = uuid.uuid4() where_input = AllocationWhereInput( machine=MachineWhereInput( - is_current=BooleanFilter(eq=True), name=StringFilter(contains="Server") + id=UUIDFilter(eq=test_machine_id), + is_current=BooleanFilter(eq=True), + name=StringFilter(contains="Server") ), status=StringFilter(eq="active"), ) @@ -60,10 +63,27 @@ def test_nested_filter_conversion_to_sql(self): assert sql_where.machine is not None assert sql_where.status == {"eq": "active"} - # Generate SQL to ensure it doesn't error + # Generate SQL and validate its correctness sql = sql_where.to_sql() assert sql is not None + # To properly check the generated SQL, we need to examine the SQL components + # Check that the nested path is correctly constructed as SQL("data -> 'machine'") + sql_str = str(sql) + + # The SQL object should contain the nested path for machine fields + # Looking for SQL("data -> 'machine'") in the representation + assert 'SQL("data -> \'machine\'")' in sql_str, \ + f"Expected nested JSONB path for machine fields, but got: {sql_str}" + + # Root level status filter should just use 'data' + # Count occurrences - should have both nested and root level paths + assert sql_str.count('SQL("data -> \'machine\'")') == 3, \ + f"Expected 3 nested machine paths (for id, name, is_current), but got: {sql_str}" + + assert 'SQL(\'data\')' in sql_str, \ + f"Expected root-level data access for status field, but got: {sql_str}" + def test_nested_filter_with_none_values(self): """Test that None values in nested filters are handled correctly.""" create_graphql_where_input(Machine) @@ -118,6 +138,20 @@ class AllocationDeep: assert hasattr(sql_where, "machine") assert sql_where.machine is not None + # Generate SQL and verify deep nesting paths + sql = sql_where.to_sql() + assert sql is not None + sql_str = str(sql) + + # Check that deeply nested paths are correctly generated + # Machine name should be at: data -> 'machine' ->> 'name' + assert 'SQL("data -> \'machine\'")' in sql_str, \ + f"Expected nested path for machine.name, but got: {sql_str}" + + # Location city should be at: data -> 'machine' -> 'location' ->> 'city' + assert 'SQL("data -> \'machine\' -> \'location\'")' in sql_str, \ + f"Expected deeply nested path for machine.location.city, but got: {sql_str}" + def test_mixed_scalar_and_nested_filters(self): """Test mixing scalar and nested object filters.""" MachineWhereInput = create_graphql_where_input(Machine) diff --git a/tests/integration/database/sql/test_where_generator.py b/tests/integration/database/sql/test_where_generator.py index e9083acd2..56918b9f4 100644 --- a/tests/integration/database/sql/test_where_generator.py +++ b/tests/integration/database/sql/test_where_generator.py @@ -475,7 +475,8 @@ class Parent: # Validate complete SQL - adjusted for our casting approach assert "((data ->> 'id'))::numeric = 1" in sql_str - assert "(data ->> 'name') = 'test'" in sql_str + # Child's name should now be accessed via nested path: data -> 'child' ->> 'name' + assert "(data -> 'child' ->> 'name') = 'test'" in sql_str class TestEdgeCases: