Skip to content

Commit

Permalink
fix(compiler): fix bool bool comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed May 26, 2022
1 parent 8b26832 commit 1ac9a9e
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 30 deletions.
54 changes: 28 additions & 26 deletions ibis/backends/base/sql/registry/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,34 @@ def quote_identifier(name, quotechar='`', force=False):
return name


def needs_parens(op):
if isinstance(op, ir.Expr):
op = op.op()
op_klass = type(op)
# function calls don't need parens
return op_klass in {
ops.Negate,
ops.IsNull,
ops.NotNull,
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Power,
ops.Modulus,
ops.Equals,
ops.NotEquals,
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.IdenticalTo,
ops.And,
ops.Or,
ops.Xor,
}
_NEEDS_PARENS_OPS = (
ops.Negate,
ops.IsNull,
ops.NotNull,
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Power,
ops.Modulus,
ops.Equals,
ops.NotEquals,
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.IdenticalTo,
ops.And,
ops.Or,
ops.Xor,
)


def needs_parens(expr: ir.Expr):
op = expr.op()
if isinstance(op, ops.Alias):
op = op.arg.op()
return isinstance(op, _NEEDS_PARENS_OPS)


parenthesize = '({})'.format
Expand Down
30 changes: 26 additions & 4 deletions ibis/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,8 @@ def test_filter_subquery_derived_reduction(filter_subquery_derived_reduction):
expr3, expr4 = filter_subquery_derived_reduction

result = Compiler.to_sql(expr3)
expected = """SELECT *
expected = """\
SELECT *
FROM star1
WHERE `f` > ln((
SELECT avg(`f`) AS `mean`
Expand All @@ -793,13 +794,14 @@ def test_filter_subquery_derived_reduction(filter_subquery_derived_reduction):
assert result == expected

result = Compiler.to_sql(expr4)
expected = """SELECT *
expected = """\
SELECT *
FROM star1
WHERE `f` > ln((
WHERE `f` > (ln((
SELECT avg(`f`) AS `mean`
FROM star1
WHERE `foo_id` = 'foo'
)) + 1"""
)) + 1)"""
assert result == expected


Expand Down Expand Up @@ -931,6 +933,26 @@ def test_topk_to_aggregate():
assert result == expected


def test_bool_bool():
import ibis
from ibis.backends.base.sql.compiler import Compiler

t = ibis.table(
[('dest', 'string'), ('origin', 'string'), ('arrdelay', 'int32')],
'airlines',
)

x = ibis.literal(True)
top = t[(t.dest.cast('int64') == 0) == x]

result = Compiler.to_sql(top)
expected = """\
SELECT *
FROM airlines
WHERE (CAST(`dest` AS bigint) = 0) = TRUE"""
assert result == expected


def test_case_in_projection(alltypes):
t = alltypes

Expand Down

0 comments on commit 1ac9a9e

Please sign in to comment.