Skip to content

Commit

Permalink
BUG: Fix equality
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 5, 2018
1 parent eece41b commit 2b0ceba
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 73 deletions.
6 changes: 3 additions & 3 deletions ibis/clickhouse/tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,16 +383,16 @@ def test_where_use_if(con, alltypes, translate):
# con.explain(expr)


@pytest.mark.xfail(
raises=com.RelationError, reason='Expression equality is broken')
def test_filter_predicates(diamonds):
t = diamonds

predicates = [
lambda x: x.color.lower().like('%de%'),
# lambda x: x.color.lower().contains('de'),
lambda x: x.color.lower().rlike('.*ge.*')
]

expr = t
expr = diamonds
for pred in predicates:
expr = expr[pred(expr)].projection([expr])

Expand Down
11 changes: 4 additions & 7 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def _check_fusion(self, root):

# a * projection
if (isinstance(val, ir.TableExpr) and
(self.parent.op().equals(val.op()) or
(self.parent.op().compatible_with(val.op()) or
# gross we share the same table root. Better way to
# detect?
len(roots) == 1 and val._root_tables()[0] is roots[0])):
Expand Down Expand Up @@ -851,19 +851,16 @@ def roots_shared(self, node):

def shares_some_roots(self, expr):
expr_roots = expr._root_tables()
return any(self._among_roots(root)
for root in expr_roots)
return any(self._among_roots(root) for root in expr_roots)

def shares_one_root(self, expr):
expr_roots = expr._root_tables()
total = sum(self.roots_shared(root)
for root in expr_roots)
total = sum(self.roots_shared(root) for root in expr_roots)
return total == 1

def shares_multiple_roots(self, expr):
expr_roots = expr._root_tables()
total = sum(self.roots_shared(expr_roots)
for root in expr_roots)
total = sum(self.roots_shared(expr_roots) for root in expr_roots)
return total > 1

def validate_all(self, exprs):
Expand Down
54 changes: 41 additions & 13 deletions ibis/expr/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def equals(self, other, cache=None):
cache[(self, other)] = True
return True

def compatible_with(self, other):
return self.equals(other)

def is_ancestor(self, other):
if isinstance(other, ir.Expr):
other = other.op()
Expand Down Expand Up @@ -139,8 +142,12 @@ def has_resolved_name(self):

def all_equal(left, right, cache=None):
if util.is_iterable(left):
return util.is_iterable(right) and all(
map(functools.partial(all_equal, cache=cache), left, right))
return util.is_iterable(right) and len(left) == len(right) and all(
itertools.starmap(
functools.partial(all_equal, cache=cache),
zip(left, right)
)
)

if hasattr(left, 'equals'):
return left.equals(right, cache=cache)
Expand Down Expand Up @@ -168,6 +175,22 @@ def aggregate(self, this, metrics, by=None, having=None):
def sort_by(self, expr, sort_exprs):
return Selection(expr, [], sort_keys=sort_exprs)

def is_ancestor(self, other):
import ibis.expr.lineage as lin

if isinstance(other, ir.Expr):
other = other.op()

if self.equals(other):
return True

fn = lambda e: (lin.proceed, e.op()) # noqa: E731
expr = self.to_expr()
for child in lin.traverse(fn, expr):
if child.equals(other):
return True
return False


class TableColumn(ValueOp):
"""Selects a column from a TableExpr"""
Expand Down Expand Up @@ -1862,22 +1885,27 @@ def root_tables(self):
def can_add_filters(self, wrapped_expr, predicates):
pass

def is_ancestor(self, other):
import ibis.expr.lineage as lin

if isinstance(other, ir.Expr):
other = other.op()
@staticmethod
def compare_argument_sequences(lefts, rights):
return not lefts or not rights or all_equal(lefts, rights)

def compatible_with(self, other):
# self and other are equivalent except for predicates, selections, or
# sort keys any of which is allowed to be empty. If both are not empty
# then they must be equal
if self.equals(other):
return True

expr = self.to_expr()
fn = lambda e: (lin.proceed, e.op()) # noqa: E731
for child in lin.traverse(fn, expr):
if child.equals(other):
return True
if not isinstance(other, type(self)):
return False

return False
return self.table.equals(other.table) and (
self.compare_argument_sequences(
self.predicates, other.predicates) and
self.compare_argument_sequences(
self.selections, other.selections) and
self.compare_argument_sequences(
self.sort_keys, other.sort_keys))

# Operator combination / fusion logic

Expand Down
30 changes: 0 additions & 30 deletions ibis/expr/tests/test_analysis.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2014 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import ibis
Expand All @@ -29,22 +15,6 @@
# Place to collect esoteric expression analysis bugs and tests


def test_rewrite_substitute_distinct_tables(con):
t = con.table('test1')
tt = con.table('test1')

expr = t[t.c > 0]
expr2 = tt[tt.c > 0]

metric = t.f.sum().name('metric')
expr3 = expr.aggregate(metric)

result = L.sub_for(expr3, [(expr2, t)])
expected = t.aggregate(metric)

assert_equal(result, expected)


def test_rewrite_join_projection_without_other_ops(con):
# See #790, predicate pushdown in joins not supported

Expand Down
163 changes: 157 additions & 6 deletions ibis/impala/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def test_nested_join_multiple_ctes():
top_user_old_movie_ids = joined3.filter([
joined3.userid == 118205,
joined3.datetime.year() < 2009
])[joined3.movieid]
])[['movieid']] # projection from a filter was hiding an insidious bug,
# so we're disabling that for now see issue #1295
cond = joined3.movieid.isin(top_user_old_movie_ids.movieid)
result = joined3[cond]

Expand All @@ -146,15 +147,14 @@ def test_nested_join_multiple_ctes():
FROM t0
INNER JOIN movies t5
ON t0.`movieid` = t5.`movieid`
),
t2 AS (
)
SELECT t2.*
FROM (
SELECT t1.*
FROM t1
WHERE (t1.`userid` = 118205) AND
(extract(t1.`datetime`, 'year') > 2001)
)
SELECT t2.*
FROM t2
) t2
WHERE t2.`movieid` IN (
SELECT `movieid`
FROM (
Expand Down Expand Up @@ -362,3 +362,154 @@ def test_multiple_filters2():
)) AND
(`b` = 'a')"""
assert result == expected


@pytest.fixture
def region():
return ibis.table(
[('r_regionkey', 'int16'),
('r_name', 'string'),
('r_comment', 'string')],
name='tpch_region'
)


@pytest.fixture
def nation():
return ibis.table(
[('n_nationkey', 'int32'),
('n_name', 'string'),
('n_regionkey', 'int32'),
('n_comment', 'string')],
name='tpch_nation'
)


@pytest.fixture
def customer():
return ibis.table(
[('c_custkey', 'int64'),
('c_name', 'string'),
('c_address', 'string'),
('c_nationkey', 'int32'),
('c_phone', 'string'),
('c_acctbal', 'decimal(12, 2)'),
('c_mktsegment', 'string'),
('c_comment', 'string')],
name='tpch_customer',
)


@pytest.fixture
def orders():
return ibis.table(
[('o_orderkey', 'int64'),
('o_custkey', 'int64'),
('o_orderstatus', 'string'),
('o_totalprice', 'decimal(12, 2)'),
('o_orderdate', 'string'),
('o_orderpriority', 'string'),
('o_clerk', 'string'),
('o_shippriority', 'int32'),
('o_comment', 'string')],
name='tpch_orders'
)


@pytest.fixture
def tpch(region, nation, customer, orders):
fields_of_interest = [customer,
region.r_name.name('region'),
orders.o_totalprice,
orders.o_orderdate.cast('timestamp').name('odate')]

return (region.join(nation, region.r_regionkey == nation.n_regionkey)
.join(customer, customer.c_nationkey == nation.n_nationkey)
.join(orders, orders.o_custkey == customer.c_custkey)
[fields_of_interest])


def test_join_key_name(tpch):
year = tpch.odate.year().name('year')

pre_sizes = tpch.group_by(year).size()
t2 = tpch.view()
conditional_avg = t2[t2.region == tpch.region].o_totalprice.mean()
amount_filter = tpch.o_totalprice > conditional_avg
post_sizes = tpch[amount_filter].group_by(year).size()

percent = ((post_sizes['count'] / pre_sizes['count'].cast('double'))
.name('fraction'))

expr = (pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year)
[pre_sizes.year,
pre_sizes['count'].name('pre_count'),
post_sizes['count'].name('post_count'),
percent])
result = ibis.impala.compile(expr)
expected = """\
WITH t0 AS (
SELECT t5.*, t3.`r_name` AS `region`, t6.`o_totalprice`,
CAST(t6.`o_orderdate` AS timestamp) AS `odate`
FROM tpch_region t3
INNER JOIN tpch_nation t4
ON t3.`r_regionkey` = t4.`n_regionkey`
INNER JOIN tpch_customer t5
ON t5.`c_nationkey` = t4.`n_nationkey`
INNER JOIN tpch_orders t6
ON t6.`o_custkey` = t5.`c_custkey`
)
SELECT t1.`year`, t1.`count` AS `pre_count`, t2.`count` AS `post_count`,
t2.`count` / CAST(t1.`count` AS double) AS `fraction`
FROM (
SELECT extract(`odate`, 'year') AS `year`, count(*) AS `count`
FROM t0
GROUP BY 1
) t1
INNER JOIN (
SELECT extract(t0.`odate`, 'year') AS `year`, count(*) AS `count`
FROM t0
WHERE t0.`o_totalprice` > (
SELECT avg(t7.`o_totalprice`) AS `mean`
FROM t0 t7
WHERE t7.`region` = t0.`region`
)
GROUP BY 1
) t2
ON t1.`year` = t2.`year`"""
assert result == expected


def test_join_key_name2(tpch):
year = tpch.odate.year().name('year')

pre_sizes = tpch.group_by(year).size()
post_sizes = tpch.group_by(year).size().view()

expr = (pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year)
[pre_sizes.year,
pre_sizes['count'].name('pre_count'),
post_sizes['count'].name('post_count')])
result = ibis.impala.compile(expr)
expected = """\
WITH t0 AS (
SELECT t5.*, t3.`r_name` AS `region`, t6.`o_totalprice`,
CAST(t6.`o_orderdate` AS timestamp) AS `odate`
FROM tpch_region t3
INNER JOIN tpch_nation t4
ON t3.`r_regionkey` = t4.`n_regionkey`
INNER JOIN tpch_customer t5
ON t5.`c_nationkey` = t4.`n_nationkey`
INNER JOIN tpch_orders t6
ON t6.`o_custkey` = t5.`c_custkey`
),
t1 AS (
SELECT extract(`odate`, 'year') AS `year`, count(*) AS `count`
FROM t0
GROUP BY 1
)
SELECT t1.`year`, t1.`count` AS `pre_count`, t2.`count` AS `post_count`
FROM t1
INNER JOIN t1 t2
ON t1.`year` = t2.`year`"""
assert result == expected
14 changes: 0 additions & 14 deletions ibis/sql/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
# Copyright 2014 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
Expand Down

0 comments on commit 2b0ceba

Please sign in to comment.