Skip to content

Commit

Permalink
fix: Expr bugs (#649)
Browse files Browse the repository at this point in the history
* fix: expression short-circuiting was buggy
  • Loading branch information
gaurav274 committed Apr 13, 2023
1 parent 612b590 commit 9ff5a3c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 67 deletions.
2 changes: 0 additions & 2 deletions eva/expression/constant_value_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, value: Any, v_type: ColumnType = ColumnType.INTEGER):

def evaluate(self, batch: Batch, **kwargs):
batch = Batch(pd.DataFrame({0: [self._value] * len(batch)}))
if "mask" in kwargs:
batch = batch[kwargs["mask"]]
return batch

def signature(self) -> str:
Expand Down
18 changes: 8 additions & 10 deletions eva/expression/logical_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,24 @@ def __init__(
exp_type, rtype=ExpressionReturnType.BOOLEAN, children=children
)

def evaluate(self, *args, **kwargs):
def evaluate(self, batch, **kwargs):
if self.get_children_count() == 2:

left_batch = self.get_child(0).evaluate(*args, **kwargs)
left_batch = self.get_child(0).evaluate(batch, **kwargs)
if self.etype == ExpressionType.LOGICAL_AND:

if left_batch.all_false(): # check if all are false
return left_batch
kwargs["mask"] = left_batch.create_mask()
mask = left_batch.create_mask()
elif self.etype == ExpressionType.LOGICAL_OR:
if left_batch.all_true(): # check if all are true
return left_batch
kwargs["mask"] = left_batch.create_inverted_mask()
right_batch = self.get_child(1).evaluate(*args, **kwargs)
left_batch.update_indices(kwargs["mask"], right_batch)
mask = left_batch.create_inverted_mask()

right_batch = self.get_child(1).evaluate(batch[mask], **kwargs)
left_batch.update_indices(mask, right_batch)

return left_batch
else:
batch = self.get_child(0).evaluate(*args, **kwargs)
batch = self.get_child(0).evaluate(batch, **kwargs)
if self.etype == ExpressionType.LOGICAL_NOT:
batch.invert()
return batch
Expand All @@ -65,7 +64,6 @@ def __eq__(self, other):
return is_subtree_equal and self.etype == other.etype

def get_symbol(self) -> str:

if self.etype == ExpressionType.LOGICAL_AND:
return "AND"
elif self.etype == ExpressionType.LOGICAL_OR:
Expand Down
2 changes: 0 additions & 2 deletions eva/expression/tuple_value_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ def col_alias(self, value: str):
self._col_alias = value

def evaluate(self, batch: Batch, *args, **kwargs):
if "mask" in kwargs:
batch = batch[kwargs["mask"]]
return batch.project([self.col_alias])

def signature(self):
Expand Down
84 changes: 82 additions & 2 deletions test/expression/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_short_circuiting_and_partial(self):
self.assertEqual(
[True, False, False, False], logical_exp.evaluate(tuples).frames[0].tolist()
)
comp_exp_r.evaluate.assert_called_once_with(tuples, mask=[0, 1])
comp_exp_r.evaluate.assert_called_once_with(tuples[[0, 1]])

def test_short_circuiting_or_partial(self):
# tests whether right-hand side is partially executed with or
Expand All @@ -169,4 +169,84 @@ def test_short_circuiting_or_partial(self):
self.assertEqual(
[True, False, True, True], logical_exp.evaluate(tuples).frames[0].tolist()
)
comp_exp_r.evaluate.assert_called_once_with(tuples, mask=[0, 1])
comp_exp_r.evaluate.assert_called_once_with(tuples[[0, 1]])

def test_multiple_logical(self):
batch = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))

# col > 1
comp_left = ComparisonExpression(
ExpressionType.COMPARE_GREATER,
TupleValueExpression(col_alias="col"),
ConstantValueExpression(1),
)

batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[list(range(2, 10))]
batch_copy.drop_zero(comp_left.evaluate(batch))
self.assertEqual(batch_copy, expected)

# col < 8
comp_right = ComparisonExpression(
ExpressionType.COMPARE_LESSER,
TupleValueExpression(col_alias="col"),
ConstantValueExpression(8),
)

batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[list(range(0, 8))]
batch_copy.drop_zero(comp_right.evaluate(batch))
self.assertEqual(batch_copy, expected)

# col >= 5
comp_expr = ComparisonExpression(
ExpressionType.COMPARE_GEQ,
TupleValueExpression(col_alias="col"),
ConstantValueExpression(5),
)
batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[list(range(5, 10))]
batch_copy.drop_zero(comp_expr.evaluate(batch))
self.assertEqual(batch_copy, expected)

# (col >= 5) AND (col > 1 AND col < 8)
l_expr = LogicalExpression(ExpressionType.LOGICAL_AND, comp_left, comp_right)
root_l_expr = LogicalExpression(ExpressionType.LOGICAL_AND, comp_expr, l_expr)
batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[[5, 6, 7]]
batch_copy.drop_zero(root_l_expr.evaluate(batch))
self.assertEqual(batch_copy, expected)

# (col > 1 AND col < 8) AND (col >= 5)
root_l_expr = LogicalExpression(ExpressionType.LOGICAL_AND, l_expr, comp_expr)
batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[[5, 6, 7]]
batch_copy.drop_zero(root_l_expr.evaluate(batch))
self.assertEqual(batch_copy, expected)

# (col >=4 AND col <= 7) AND (col > 1 AND col < 8)
between_4_7 = LogicalExpression(
ExpressionType.LOGICAL_AND,
ComparisonExpression(
ExpressionType.COMPARE_GEQ,
TupleValueExpression(col_alias="col"),
ConstantValueExpression(4),
),
ComparisonExpression(
ExpressionType.COMPARE_LEQ,
TupleValueExpression(col_alias="col"),
ConstantValueExpression(7),
),
)
test_expr = LogicalExpression(ExpressionType.LOGICAL_AND, between_4_7, l_expr)
batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[[4, 5, 6, 7]]
batch_copy.drop_zero(test_expr.evaluate(batch))
self.assertEqual(batch_copy, expected)

# (col >=4 AND col <= 7) OR (col > 1 AND col < 8)
test_expr = LogicalExpression(ExpressionType.LOGICAL_OR, between_4_7, l_expr)
batch_copy = Batch(pd.DataFrame({"col": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
expected = batch[[2, 3, 4, 5, 6, 7]]
batch_copy.drop_zero(test_expr.evaluate(batch))
self.assertEqual(batch_copy, expected)
51 changes: 0 additions & 51 deletions test/expression/test_tuple_value.py

This file was deleted.

0 comments on commit 9ff5a3c

Please sign in to comment.