Skip to content

Commit

Permalink
BUG: Fix equality
Browse files Browse the repository at this point in the history
Closes #1599. This is a temporary fix that preserves most existing
behavior until we can refactor the compiler to separate compilation
from optimization.

Author: Phillip Cloud <cpcloud@gmail.com>

Closes #1600 from cpcloud/fix-equality and squashes the following commits:

92f9cbf [Phillip Cloud] Revert and use compat zip
05bab25 [Phillip Cloud] Rename compare_argument_sequences to empty_or_equal
346114b [Phillip Cloud] Clean up code
e83a5ea [Phillip Cloud] Xfail on newly broken test
9334c23 [Phillip Cloud] Fix notebook
75e8b65 [Phillip Cloud] Flake8
2b0ceba [Phillip Cloud] BUG: Fix equality
  • Loading branch information
cpcloud committed Sep 6, 2018
1 parent eece41b commit 591c058
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
"expr = (t2\n",
" [t2.bigint_col > 30]\n",
" .group_by('string_col')\n",
" .aggregate([t2.foo.min().name('min_foo'),\n",
" t2.foo.max().name('max_foo'),\n",
" t2.foo.sum().name('sum_foo')]))\n",
" .aggregate(min_foo=lambda t: t.foo.min(),\n",
" max_foo=lambda t: t.foo.max(),\n",
" sum_foo=lambda t: t.foo.sum()))\n",
"expr"
]
},
Expand Down
6 changes: 3 additions & 3 deletions ibis/clickhouse/tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,16 @@ def test_where_use_if(con, alltypes, translate):
# con.explain(expr)


@pytest.mark.xfail(
raises=com.RelationError, reason='Expression equality is broken')
def test_filter_predicates(diamonds):
t = diamonds

predicates = [
lambda x: x.color.lower().like('%de%'),
# lambda x: x.color.lower().contains('de'),
lambda x: x.color.lower().rlike('.*ge.*')
]

expr = t
expr = diamonds
for pred in predicates:
expr = expr[pred(expr)].projection([expr])

Expand Down
11 changes: 4 additions & 7 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _check_fusion(self, root):

# a * projection
if (isinstance(val, ir.TableExpr) and
(self.parent.op().equals(val.op()) or
(self.parent.op().compatible_with(val.op()) or
# gross we share the same table root. Better way to
# detect?
len(roots) == 1 and val._root_tables()[0] is roots[0])):
Expand Down Expand Up @@ -851,19 +851,16 @@ def roots_shared(self, node):

def shares_some_roots(self, expr):
    """Return whether any root table of ``expr`` is among this object's roots.

    Every table returned by ``expr._root_tables()`` is passed to
    ``self._among_roots``; the result is ``True`` as soon as one of those
    calls is truthy and ``False`` if none is (including when ``expr`` has
    no root tables).
    """
    for candidate in expr._root_tables():
        if self._among_roots(candidate):
            return True
    return False

def shares_one_root(self, expr):
    """Return whether ``expr`` shares exactly one root with this object.

    ``self.roots_shared`` is called once per table returned by
    ``expr._root_tables()`` and the results are added up; the method is
    ``True`` only when that sum equals 1.
    """
    shared_count = sum(map(self.roots_shared, expr._root_tables()))
    return shared_count == 1

def shares_multiple_roots(self, expr):
    """Return whether ``expr`` shares more than one root with this object.

    ``self.roots_shared`` is called once per table returned by
    ``expr._root_tables()`` and the results are added up; the method is
    ``True`` only when that sum is greater than 1.
    """
    # Pass each root individually, as shares_one_root does. Handing the
    # whole list of roots to roots_shared on every iteration left the loop
    # variable unused and compared a list against single nodes.
    total = sum(self.roots_shared(root) for root in expr._root_tables())
    return total > 1

def validate_all(self, exprs):
Expand Down
69 changes: 54 additions & 15 deletions ibis/expr/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import ibis.expr.datatypes as dt

from ibis import util, compat
from ibis.compat import functools, map
from ibis.compat import functools, map, zip
from ibis.expr.signature import Annotable, Argument as Arg


Expand Down Expand Up @@ -98,9 +98,13 @@ def equals(self, other, cache=None):
if not all_equal(left, right, cache=cache):
cache[(self, other)] = False
return False

cache[(self, other)] = True
return True

def compatible_with(self, other):
    """Return whether ``other`` is compatible with this node.

    In this implementation compatibility is exactly equality: the result
    is whatever ``self.equals(other)`` returns.
    """
    is_equal = self.equals(other)
    return is_equal

def is_ancestor(self, other):
if isinstance(other, ir.Expr):
other = other.op()
Expand Down Expand Up @@ -138,9 +142,24 @@ def has_resolved_name(self):


def all_equal(left, right, cache=None):
"""Check whether two objects `left` and `right` are equal.
Parameters
----------
left : Union[object, Expr, Node]
right : Union[object, Expr, Node]
cache : Optional[Dict[Tuple[Node, Node], bool]]
A dictionary indicating whether two Nodes are equal
"""
if util.is_iterable(left):
return util.is_iterable(right) and all(
map(functools.partial(all_equal, cache=cache), left, right))
# check that left and right are equal length iterables and that all
# of their elements are equal
return util.is_iterable(right) and len(left) == len(right) and all(
itertools.starmap(
functools.partial(all_equal, cache=cache),
zip(left, right)
)
)

if hasattr(left, 'equals'):
return left.equals(right, cache=cache)
Expand All @@ -155,7 +174,6 @@ def genname():


class TableNode(Node):

def get_type(self, name):
    """Return the entry stored under ``name`` in ``self.schema``.

    The lookup is a plain ``self.schema[name]`` subscription, so whatever
    error the schema raises for an unknown name propagates to the caller.
    """
    column_type = self.schema[name]
    return column_type

Expand All @@ -168,6 +186,22 @@ def aggregate(self, this, metrics, by=None, having=None):
def sort_by(self, expr, sort_exprs):
    """Return a ``Selection`` over ``expr`` ordered by ``sort_exprs``.

    The ``Selection`` is built with an empty list as its second positional
    argument and ``sort_exprs`` as its ``sort_keys``.
    """
    empty = []
    return Selection(expr, empty, sort_keys=sort_exprs)

def is_ancestor(self, other):
    """Return whether ``other`` equals this node or a node reached from it.

    ``other`` may be an expression, in which case its ``op()`` is used.
    The method returns ``True`` if this node equals ``other``, or if any
    node yielded by ``lin.traverse`` over ``self.to_expr()`` does;
    otherwise it returns ``False``.
    """
    # Imported locally, as in the original implementation.
    import ibis.expr.lineage as lin

    target = other.op() if isinstance(other, ir.Expr) else other

    if self.equals(target):
        return True

    def visit(expr):
        # Always continue the traversal and yield the expression's op.
        # NOTE(review): presumably this walks the graph beneath this node;
        # confirm against ibis.expr.lineage.traverse.
        return lin.proceed, expr.op()

    return any(
        node.equals(target) for node in lin.traverse(visit, self.to_expr())
    )


class TableColumn(ValueOp):
"""Selects a column from a TableExpr"""
Expand Down Expand Up @@ -1862,22 +1896,27 @@ def root_tables(self):
def can_add_filters(self, wrapped_expr, predicates):
# No-op as written: the body is only `pass`, so the call returns None and
# neither `wrapped_expr` nor `predicates` is inspected.
pass

def is_ancestor(self, other):
import ibis.expr.lineage as lin

if isinstance(other, ir.Expr):
other = other.op()
@staticmethod
def empty_or_equal(lefts, rights):
    """Return a truthy value if either sequence is empty or both are equal.

    When ``lefts`` or ``rights`` is falsy the result is ``True`` and no
    comparison is made; otherwise the result is whatever
    ``all_equal(lefts, rights)`` returns.
    """
    if not lefts or not rights:
        return True
    return all_equal(lefts, rights)

def compatible_with(self, other):
# self and other are equivalent except for predicates, selections, or
# sort keys, any of which is allowed to be empty. If a field is non-empty
# on both sides then the two values must be equal
if self.equals(other):
return True

expr = self.to_expr()
fn = lambda e: (lin.proceed, e.op()) # noqa: E731
for child in lin.traverse(fn, expr):
if child.equals(other):
return True
if not isinstance(other, type(self)):
return False

return False
return self.table.equals(other.table) and (
self.empty_or_equal(
self.predicates, other.predicates) and
self.empty_or_equal(
self.selections, other.selections) and
self.empty_or_equal(
self.sort_keys, other.sort_keys))

# Operator combination / fusion logic

Expand Down
30 changes: 0 additions & 30 deletions ibis/expr/tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2014 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import ibis
Expand All @@ -29,22 +15,6 @@
# Place to collect esoteric expression analysis bugs and tests


def test_rewrite_substitute_distinct_tables(con):
    """Substitute a filter built on a second handle to the same table.

    Two separate ``con.table('test1')`` handles are filtered identically.
    An aggregate over the first filter, after ``L.sub_for`` replaces the
    second filter with the bare first table, is expected to equal the
    aggregate taken directly on the first table.
    """
    first = con.table('test1')
    second = con.table('test1')

    filtered_first = first[first.c > 0]
    filtered_second = second[second.c > 0]

    metric = first.f.sum().name('metric')
    aggregated = filtered_first.aggregate(metric)

    result = L.sub_for(aggregated, [(filtered_second, first)])

    assert_equal(result, first.aggregate(metric))


def test_rewrite_join_projection_without_other_ops(con):
# See #790, predicate pushdown in joins not supported

Expand Down
4 changes: 4 additions & 0 deletions ibis/impala/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ibis import literal as L
from ibis.expr.datatypes import Category

from ibis.common import RelationError
from ibis.compat import StringIO, Decimal
from ibis.expr.tests.mocks import MockConnection

Expand Down Expand Up @@ -1354,6 +1355,9 @@ def test_div_floordiv(con, expr, expected):
assert result == expected


@pytest.mark.xfail(
raises=RelationError,
reason='Equality was broken, and fixing it broke this test')
def test_filter_predicates(con):
t = con.table('tpch_nation')

Expand Down
Loading

0 comments on commit 591c058

Please sign in to comment.