Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix equality #1600

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
"expr = (t2\n",
" [t2.bigint_col > 30]\n",
" .group_by('string_col')\n",
" .aggregate([t2.foo.min().name('min_foo'),\n",
" t2.foo.max().name('max_foo'),\n",
" t2.foo.sum().name('sum_foo')]))\n",
" .aggregate(min_foo=lambda t: t.foo.min(),\n",
" max_foo=lambda t: t.foo.max(),\n",
" sum_foo=lambda t: t.foo.sum()))\n",
"expr"
]
},
Expand Down
6 changes: 3 additions & 3 deletions ibis/clickhouse/tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,16 @@ def test_where_use_if(con, alltypes, translate):
# con.explain(expr)


@pytest.mark.xfail(
raises=com.RelationError, reason='Expression equality is broken')
def test_filter_predicates(diamonds):
t = diamonds

predicates = [
lambda x: x.color.lower().like('%de%'),
# lambda x: x.color.lower().contains('de'),
lambda x: x.color.lower().rlike('.*ge.*')
]

expr = t
expr = diamonds
for pred in predicates:
expr = expr[pred(expr)].projection([expr])

Expand Down
11 changes: 4 additions & 7 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _check_fusion(self, root):

# a * projection
if (isinstance(val, ir.TableExpr) and
(self.parent.op().equals(val.op()) or
(self.parent.op().compatible_with(val.op()) or
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
# gross we share the same table root. Better way to
# detect?
len(roots) == 1 and val._root_tables()[0] is roots[0])):
Expand Down Expand Up @@ -851,19 +851,16 @@ def roots_shared(self, node):

def shares_some_roots(self, expr):
expr_roots = expr._root_tables()
return any(self._among_roots(root)
for root in expr_roots)
return any(self._among_roots(root) for root in expr_roots)

def shares_one_root(self, expr):
expr_roots = expr._root_tables()
total = sum(self.roots_shared(root)
for root in expr_roots)
total = sum(self.roots_shared(root) for root in expr_roots)
return total == 1

def shares_multiple_roots(self, expr):
expr_roots = expr._root_tables()
total = sum(self.roots_shared(expr_roots)
for root in expr_roots)
total = sum(self.roots_shared(expr_roots) for root in expr_roots)
return total > 1

def validate_all(self, exprs):
Expand Down
71 changes: 54 additions & 17 deletions ibis/expr/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def equals(self, other, cache=None):
cache[(self, other)] = True
return True

def compatible_with(self, other):
return self.equals(other)

def is_ancestor(self, other):
if isinstance(other, ir.Expr):
other = other.op()
Expand Down Expand Up @@ -138,13 +141,27 @@ def has_resolved_name(self):


def all_equal(left, right, cache=None):
"""Check whether two objects `left` and `right` are equal.

Parameters
----------
left : Union[object, Expr, Node]
right : Union[object, Expr, Node]
cache : Optional[Dict[Tuple[Node, Node], bool]]
A dictionary indicating whether two Nodes are equal
"""
if util.is_iterable(left):
return util.is_iterable(right) and all(
map(functools.partial(all_equal, cache=cache), left, right))
# check that left and right are equal length iterables and that all
# of their elements are equal
return util.is_iterable(right) and len(left) == len(right) and all(
itertools.starmap(
functools.partial(all_equal, cache=cache),
zip(left, right)
)
)

if hasattr(left, 'equals'):
return left.equals(right, cache=cache)
return left == right
return (hasattr(left, 'equals') and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the previous version was more readable.

left.equals(right, cache=cache)) or left == right


_table_names = ('unbound_table_{:d}'.format(i) for i in itertools.count())
Expand All @@ -155,7 +172,6 @@ def genname():


class TableNode(Node):

def get_type(self, name):
return self.schema[name]

Expand All @@ -168,6 +184,22 @@ def aggregate(self, this, metrics, by=None, having=None):
def sort_by(self, expr, sort_exprs):
return Selection(expr, [], sort_keys=sort_exprs)

def is_ancestor(self, other):
import ibis.expr.lineage as lin

if isinstance(other, ir.Expr):
other = other.op()

if self.equals(other):
return True

fn = lambda e: (lin.proceed, e.op()) # noqa: E731
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it'd be useful to factor this block out into a lin.flatten function. I'm starting to like lin.traverse even more :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, can you open an issue for that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done #1612

expr = self.to_expr()
for child in lin.traverse(fn, expr):
if child.equals(other):
return True
return False


class TableColumn(ValueOp):
"""Selects a column from a TableExpr"""
Expand Down Expand Up @@ -1862,22 +1894,27 @@ def root_tables(self):
def can_add_filters(self, wrapped_expr, predicates):
pass

def is_ancestor(self, other):
import ibis.expr.lineage as lin

if isinstance(other, ir.Expr):
other = other.op()
@staticmethod
def compare_argument_sequences(lefts, rights):
return not lefts or not rights or all_equal(lefts, rights)

def compatible_with(self, other):
# self and other are compatible if they are equivalent except for their
# predicates, selections, or sort keys, any of which is allowed to be empty.
# If a given argument is non-empty on both sides, the two must be equal.
if self.equals(other):
return True

expr = self.to_expr()
fn = lambda e: (lin.proceed, e.op()) # noqa: E731
for child in lin.traverse(fn, expr):
if child.equals(other):
return True
if not isinstance(other, type(self)):
return False

return False
return self.table.equals(other.table) and (
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
self.compare_argument_sequences(
self.predicates, other.predicates) and
self.compare_argument_sequences(
self.selections, other.selections) and
self.compare_argument_sequences(
self.sort_keys, other.sort_keys))

# Operator combination / fusion logic

Expand Down
30 changes: 0 additions & 30 deletions ibis/expr/tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2014 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import ibis
Expand All @@ -29,22 +15,6 @@
# Place to collect esoteric expression analysis bugs and tests


def test_rewrite_substitute_distinct_tables(con):
t = con.table('test1')
tt = con.table('test1')

expr = t[t.c > 0]
expr2 = tt[tt.c > 0]

metric = t.f.sum().name('metric')
expr3 = expr.aggregate(metric)

result = L.sub_for(expr3, [(expr2, t)])
expected = t.aggregate(metric)

assert_equal(result, expected)


def test_rewrite_join_projection_without_other_ops(con):
# See #790, predicate pushdown in joins not supported

Expand Down
4 changes: 4 additions & 0 deletions ibis/impala/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ibis import literal as L
from ibis.expr.datatypes import Category

from ibis.common import RelationError
from ibis.compat import StringIO, Decimal
from ibis.expr.tests.mocks import MockConnection

Expand Down Expand Up @@ -1354,6 +1355,9 @@ def test_div_floordiv(con, expr, expected):
assert result == expected


@pytest.mark.xfail(
raises=RelationError,
reason='Equality was broken, and fixing it broke this test')
def test_filter_predicates(con):
t = con.table('tpch_nation')

Expand Down
Loading