Skip to content

Commit

Permalink
add Zip
Browse files Browse the repository at this point in the history
  • Loading branch information
masonproffitt committed Dec 9, 2020
1 parent 94157c2 commit 67c9167
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 13 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ All defined s-expressions are listed here, though this specification will be exp
- Max: `(Max <source>)`
- Min: `(Min <source>)`
- Sum: `(Sum <source>)`
- Zip: `(Zip <source>)`


## Example
Expand Down
13 changes: 12 additions & 1 deletion qastle/linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def __init__(self, source):
self.source = source


class Zip(ast.AST):
def __init__(self, source):
self._fields = ['source']
self.source = source


linq_operator_names = ('Where',
'Select',
'SelectMany',
Expand All @@ -70,7 +76,8 @@ def __init__(self, source):
'Count',
'Max',
'Min',
'Sum')
'Sum',
'Zip')


class InsertLINQNodesTransformer(ast.NodeTransformer):
Expand Down Expand Up @@ -150,6 +157,10 @@ def visit_Call(self, node):
if len(args) != 0:
raise SyntaxError('Sum() call must have zero arguments')
return Sum(source=self.visit(source))
elif function_name == 'Zip':
if len(args) != 0:
raise SyntaxError('Zip() call must have zero arguments')
return Zip(source=self.visit(source))
else:
raise NameError('Unhandled LINQ operator: ' + function_name)

Expand Down
10 changes: 9 additions & 1 deletion qastle/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .linq_util import Where, Select, SelectMany, First, Aggregate, Count, Max, Min, Sum
from .linq_util import Where, Select, SelectMany, First, Aggregate, Count, Max, Min, Sum, Zip
from .ast_util import wrap_ast, unwrap_ast

import lark
Expand Down Expand Up @@ -200,6 +200,9 @@ def visit_Min(self, node):
def visit_Sum(self, node):
return self.make_composite_node_string('Sum', self.visit(node.source))

def visit_Zip(self, node):
return self.make_composite_node_string('Zip', self.visit(node.source))

def generic_visit(self, node):
raise SyntaxError('Unsupported node type: ' + str(type(node)))

Expand Down Expand Up @@ -430,5 +433,10 @@ def composite(self, children):
raise SyntaxError('Sum node must have one field; found ' + str(len(fields)))
return Sum(source=fields[0])

elif node_type == 'Zip':
if len(fields) != 1:
raise SyntaxError('Zip node must have one field; found ' + str(len(fields)))
return Zip(source=fields[0])

else:
raise SyntaxError('Unknown composite node type: ' + node_type)
6 changes: 6 additions & 0 deletions tests/test_ast_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,9 @@ def test_Sum():
first_node = Sum(source=unwrap_ast(ast.parse('data_source')))
assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node),
'(Sum data_source)')


def test_Zip():
first_node = Zip(source=unwrap_ast(ast.parse('data_source')))
assert_equivalent_python_ast_and_text_ast(wrap_ast(first_node),
'(Zip data_source)')
49 changes: 38 additions & 11 deletions tests/test_linq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ def test_where_bad():
insert_linq_nodes(ast.parse('the_source.Where(None)'))


def test_select():
initial_ast = ast.parse("the_source.Select('lambda row: row')")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Select(source=unwrap_ast(ast.parse('the_source')),
selector=unwrap_ast(ast.parse('lambda row: row'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_select_composite():
initial_ast = ast.parse("the_source.First().Select('lambda row: row')")
final_ast = insert_linq_nodes(initial_ast)
Expand Down Expand Up @@ -193,17 +201,17 @@ def test_max_bad():
insert_linq_nodes(ast.parse('the_source.Max(None)'))


def test_min_composite():
initial_ast = ast.parse("the_source.First().Min()")
def test_min():
initial_ast = ast.parse("the_source.Min()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Min(source=First(source=unwrap_ast(ast.parse('the_source')))))
expected_ast = wrap_ast(Min(source=unwrap_ast(ast.parse('the_source'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_min():
initial_ast = ast.parse("the_source.Min()")
def test_min_composite():
initial_ast = ast.parse("the_source.First().Min()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Min(source=unwrap_ast(ast.parse('the_source'))))
expected_ast = wrap_ast(Min(source=First(source=unwrap_ast(ast.parse('the_source')))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


Expand All @@ -212,20 +220,39 @@ def test_min_bad():
insert_linq_nodes(ast.parse('the_source.Min(None)'))


def test_sum():
initial_ast = ast.parse("the_source.Sum()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Sum(source=unwrap_ast(ast.parse('the_source'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_sum_composite():
initial_ast = ast.parse("the_source.First().Sum()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Sum(source=First(source=unwrap_ast(ast.parse('the_source')))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_sum():
initial_ast = ast.parse("the_source.Sum()")
def test_sum_bad():
with pytest.raises(SyntaxError):
insert_linq_nodes(ast.parse('the_source.Sum(None)'))


def test_zip():
initial_ast = ast.parse("the_source.Zip()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Sum(source=unwrap_ast(ast.parse('the_source'))))
expected_ast = wrap_ast(Zip(source=unwrap_ast(ast.parse('the_source'))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_sum_bad():
def test_zip_composite():
initial_ast = ast.parse("the_source.First().Zip()")
final_ast = insert_linq_nodes(initial_ast)
expected_ast = wrap_ast(Zip(source=First(source=unwrap_ast(ast.parse('the_source')))))
assert_ast_nodes_are_equal(final_ast, expected_ast)


def test_zip_bad():
with pytest.raises(SyntaxError):
insert_linq_nodes(ast.parse('the_source.Sum(None)'))
insert_linq_nodes(ast.parse('the_source.Zip(None)'))

0 comments on commit 67c9167

Please sign in to comment.