Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flowquery-py/src/parsing/data_structures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from .associative_array import AssociativeArray
from .json_array import JSONArray
from .key_value_pair import KeyValuePair
from .list_comprehension import ListComprehension
from .lookup import Lookup
from .range_lookup import RangeLookup

__all__ = [
"AssociativeArray",
"JSONArray",
"KeyValuePair",
"ListComprehension",
"Lookup",
"RangeLookup",
]
90 changes: 90 additions & 0 deletions flowquery-py/src/parsing/data_structures/list_comprehension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Represents a Cypher-style list comprehension in the AST.

List comprehensions allow mapping and filtering arrays inline using the syntax:
[variable IN list | expression]
[variable IN list WHERE condition | expression]
[variable IN list WHERE condition]
[variable IN list]

Example:
[n IN [1, 2, 3] WHERE n > 1 | n * 2] => [4, 6]
"""

from typing import Any, List, Optional

from ..ast_node import ASTNode
from ..expressions.expression import Expression
from ..functions.value_holder import ValueHolder
from ..operations.where import Where


class ListComprehension(ASTNode):
"""Represents a list comprehension expression.

Children layout:
- Child 0: Reference (iteration variable)
- Child 1: Expression (source array)
- Child 2 (optional): Where (filter condition) or Expression (mapping)
- Child 3 (optional): Expression (mapping, when Where is child 2)
"""

def __init__(self) -> None:
super().__init__()
self._value_holder = ValueHolder()

@property
def reference(self) -> ASTNode:
"""The iteration variable reference."""
return self.first_child()

@property
def array(self) -> ASTNode:
"""The source array expression (unwrapped from its Expression wrapper)."""
return self.get_children()[1].first_child()

@property
def _return(self) -> Optional[Expression]:
"""The mapping expression, or None if not specified."""
children = self.get_children()
if len(children) <= 2:
return None
last = children[-1]
if isinstance(last, Where):
return None
return last if isinstance(last, Expression) else None

@property
def where(self) -> Optional[Where]:
"""The optional WHERE filter condition."""
for child in self.get_children():
if isinstance(child, Where):
return child
return None

def value(self) -> List[Any]:
"""Evaluate the list comprehension.

Iterates over the source array, applies the optional filter,
and maps each element through the return expression.

Returns:
The resulting filtered/mapped array.
"""
ref = self.reference
if hasattr(ref, "referred"):
ref.referred = self._value_holder
array = self.array.value()
if array is None or not isinstance(array, list):
raise ValueError("Expected array for list comprehension")
result: List[Any] = []
for item in array:
self._value_holder.holder = item
if self.where is None or self.where.value():
if self._return is not None:
result.append(self._return.value())
else:
result.append(item)
return result

def __str__(self) -> str:
return "ListComprehension"
86 changes: 86 additions & 0 deletions flowquery-py/src/parsing/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .data_structures.associative_array import AssociativeArray
from .data_structures.json_array import JSONArray
from .data_structures.key_value_pair import KeyValuePair
from .data_structures.list_comprehension import ListComprehension
from .data_structures.lookup import Lookup
from .data_structures.range_lookup import RangeLookup
from .expressions.expression import Expression
Expand Down Expand Up @@ -877,6 +878,13 @@ def _parse_operand(self, expression: Expression) -> bool:
lookup = self._parse_lookup(sub)
expression.add_node(lookup)
return True
elif self.token.is_opening_bracket() and self._looks_like_list_comprehension():
list_comp = self._parse_list_comprehension()
if list_comp is None:
raise ValueError("Expected list comprehension")
lookup = self._parse_lookup(list_comp)
expression.add_node(lookup)
return True
elif self.token.is_opening_brace() or self.token.is_opening_bracket():
json = self._parse_json()
if json is None:
Expand Down Expand Up @@ -1290,6 +1298,84 @@ def _parse_function_parameters(self) -> Iterator[ASTNode]:
break
self.set_next_token()

def _looks_like_list_comprehension(self) -> bool:
"""Peek ahead from an opening bracket to determine whether the
upcoming tokens form a list comprehension (e.g. ``[n IN list | n.name]``)
rather than a plain JSON array literal (e.g. ``[1, 2, 3]``).

The heuristic is: ``[`` identifier ``IN`` -> list comprehension.
"""
saved_index = self._token_index
self.set_next_token() # skip '['
self._skip_whitespace_and_comments()

if not self.token.is_identifier_or_keyword():
self._token_index = saved_index
return False

self.set_next_token() # skip identifier
self._skip_whitespace_and_comments()
result = self.token.is_in()
self._token_index = saved_index
return result

def _parse_list_comprehension(self) -> Optional[ListComprehension]:
"""Parse a list comprehension expression.

Syntax: ``[variable IN list [WHERE condition] [| expression]]``
"""
if not self.token.is_opening_bracket():
return None

list_comp = ListComprehension()
self.set_next_token() # skip '['
self._skip_whitespace_and_comments()

# Parse iteration variable
if not self.token.is_identifier_or_keyword():
raise ValueError("Expected identifier")
reference = Reference(self.token.value or "")
self._state.variables[reference.identifier] = reference
list_comp.add_child(reference)
self.set_next_token()
self._expect_and_skip_whitespace_and_comments()

# Parse IN keyword
if not self.token.is_in():
raise ValueError("Expected IN")
self.set_next_token()
self._expect_and_skip_whitespace_and_comments()

# Parse source array expression
array_expr = self._parse_expression()
if array_expr is None:
raise ValueError("Expected expression")
list_comp.add_child(array_expr)

# Optional WHERE clause
self._skip_whitespace_and_comments()
where = self._parse_where()
if where is not None:
list_comp.add_child(where)

# Optional | mapping expression
self._skip_whitespace_and_comments()
if self.token.is_pipe():
self.set_next_token()
self._skip_whitespace_and_comments()
return_expr = self._parse_expression()
if return_expr is None:
raise ValueError("Expected expression after |")
list_comp.add_child(return_expr)

self._skip_whitespace_and_comments()
if not self.token.is_closing_bracket():
raise ValueError("Expected closing bracket")
self.set_next_token()

del self._state.variables[reference.identifier]
return list_comp

def _parse_json(self) -> Optional[ASTNode]:
if self.token.is_opening_brace():
return self._parse_associative_array()
Expand Down
77 changes: 77 additions & 0 deletions flowquery-py/tests/compute/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,83 @@ async def test_range_function(self):
assert len(results) == 1
assert results[0] == {"range": [1, 2, 3]}

@pytest.mark.asyncio
async def test_list_comprehension_with_mapping(self):
"""Test list comprehension with mapping expression."""
runner = Runner("RETURN [n IN [1, 2, 3] | n * 2] AS doubled")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"doubled": [2, 4, 6]}

@pytest.mark.asyncio
async def test_list_comprehension_with_where_filter(self):
"""Test list comprehension with WHERE filter."""
runner = Runner("RETURN [n IN [1, 2, 3, 4, 5] WHERE n > 2] AS filtered")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"filtered": [3, 4, 5]}

@pytest.mark.asyncio
async def test_list_comprehension_with_where_and_mapping(self):
"""Test list comprehension with WHERE and mapping."""
runner = Runner("RETURN [n IN [1, 2, 3, 4] WHERE n > 1 | n ^ 2] AS result")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"result": [4, 9, 16]}

@pytest.mark.asyncio
async def test_list_comprehension_identity(self):
"""Test list comprehension identity (no WHERE, no mapping)."""
runner = Runner("RETURN [n IN [10, 20, 30]] AS result")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"result": [10, 20, 30]}

@pytest.mark.asyncio
async def test_list_comprehension_with_variable_reference(self):
"""Test list comprehension with variable reference."""
runner = Runner("WITH [1, 2, 3] AS nums RETURN [n IN nums | n + 10] AS result")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"result": [11, 12, 13]}

@pytest.mark.asyncio
async def test_list_comprehension_with_property_access(self):
"""Test list comprehension with property access."""
runner = Runner(
'WITH [{name: "Alice", age: 30}, {name: "Bob", age: 25}] AS people '
'RETURN [p IN people | p.name] AS names'
)
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"names": ["Alice", "Bob"]}

@pytest.mark.asyncio
async def test_list_comprehension_with_function_source(self):
"""Test list comprehension with function as source."""
runner = Runner("RETURN [n IN range(1, 5) WHERE n > 3 | n * 10] AS result")
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"result": [40, 50]}

@pytest.mark.asyncio
async def test_list_comprehension_with_size(self):
"""Test list comprehension composed with size."""
runner = Runner(
"RETURN size([n IN [1, 2, 3, 4, 5] WHERE n > 2]) AS count"
)
await runner.run()
results = runner.results
assert len(results) == 1
assert results[0] == {"count": 3}

@pytest.mark.asyncio
async def test_range_function_with_unwind_and_case(self):
"""Test range function with unwind and case."""
Expand Down
36 changes: 36 additions & 0 deletions flowquery-py/tests/parsing/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,3 +1235,39 @@ def test_order_by_expression_with_limit(self):
"return x order by toLower(x) asc limit 2"
)
assert ast is not None

def test_list_comprehension_with_mapping(self):
"""Test list comprehension with mapping parses correctly."""
parser = Parser()
ast = parser.parse("RETURN [n IN [1, 2, 3] | n * 2] AS doubled")
assert "ListComprehension" in ast.print()

def test_list_comprehension_with_where_and_mapping(self):
"""Test list comprehension with WHERE and mapping."""
parser = Parser()
ast = parser.parse("RETURN [n IN [1, 2, 3] WHERE n > 1 | n * 2] AS result")
output = ast.print()
assert "ListComprehension" in output
assert "Where" in output

def test_list_comprehension_with_where_only(self):
"""Test list comprehension with WHERE only."""
parser = Parser()
ast = parser.parse("RETURN [n IN [1, 2, 3, 4] WHERE n > 2] AS filtered")
output = ast.print()
assert "ListComprehension" in output
assert "Where" in output

def test_list_comprehension_identity(self):
"""Test list comprehension identity."""
parser = Parser()
ast = parser.parse("RETURN [n IN [1, 2, 3]] AS result")
assert "ListComprehension" in ast.print()

def test_regular_json_array_still_parses(self):
"""Regular JSON array still parses correctly alongside list comprehension."""
parser = Parser()
ast = parser.parse("RETURN [1, 2, 3] AS arr")
output = ast.print()
assert "JSONArray" in output
assert "ListComprehension" not in output
Loading