Skip to content

Commit

Permalink
refactor(ir): decompose Contains into InValues and InColumn
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent e826037 commit fe9a289
Show file tree
Hide file tree
Showing 28 changed files with 182 additions and 218 deletions.
20 changes: 18 additions & 2 deletions ibis/backends/base/sql/alchemy/registry.py
Expand Up @@ -208,6 +208,22 @@ def translate(t, op):
return translate


def _in_values(t, op):
if not op.options:
return sa.literal(False)
value = t.translate(op.value)
options = [t.translate(x) for x in op.options]
return value.in_(options)


def _in_column(t, op):
value = t.translate(op.value)
options = t.translate(ops.TableArrayView(op.options.to_expr().as_table()))
if not isinstance(options, sa.sql.Selectable):
options = sa.select(options)
return value.in_(options)


def _alias(t, op):
# just compile the underlying argument because the naming is handled
# by the translator for the top level expression
Expand Down Expand Up @@ -552,8 +568,8 @@ class array_filter(FunctionElement):
ops.Cast: _cast,
ops.Coalesce: varargs(sa.func.coalesce),
ops.NullIf: fixed_arity(sa.func.nullif, 2),
ops.Contains: _contains(lambda left, right: left.in_(right)),
ops.NotContains: _contains(lambda left, right: left.notin_(right)),
ops.InValues: _in_values,
ops.InColumn: _in_column,
ops.Count: reduction(sa.func.count),
ops.CountStar: _count_star,
ops.CountDistinctStar: _count_distinct_star,
Expand Down
74 changes: 38 additions & 36 deletions ibis/backends/base/sql/registry/binary_infix.py
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Literal

import ibis.expr.analysis as an
from ibis.backends.base.sql.registry import helpers

Expand Down Expand Up @@ -50,37 +48,41 @@ def xor(translator, op):
return "({0} OR {1}) AND NOT ({0} AND {1})".format(left_arg, right_arg)


def contains(op_string: Literal["IN", "NOT IN"]) -> str:
def translate(translator, op):
from ibis.backends.base.sql.registry.main import table_array_view

if isinstance(op.options, tuple) and not op.options:
return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string]

left = translator.translate(op.value)
if helpers.needs_parens(op.value):
left = helpers.parenthesize(left)

ctx = translator.context

if isinstance(op.options, tuple):
values = [translator.translate(x) for x in op.options]
right = helpers.parenthesize(", ".join(values))
elif op.options.shape.is_columnar():
right = translator.translate(op.options)
if not any(
ctx.is_foreign_expr(leaf)
for leaf in an.find_immediate_parent_tables(op.options)
):
array = op.options.to_expr().as_table().to_array().op()
right = table_array_view(translator, array)
else:
right = translator.translate(op.options)
else:
right = translator.translate(op.options)

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{left} {op_string} {right}"

return translate
def in_values(translator, op):
if not op.options:
return "FALSE"

left = translator.translate(op.value)
if helpers.needs_parens(op.value):
left = helpers.parenthesize(left)

values = [translator.translate(x) for x in op.options]
right = helpers.parenthesize(", ".join(values))

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{left} IN {right}"


def in_column(translator, op):
from ibis.backends.base.sql.registry.main import table_array_view

ctx = translator.context

left = translator.translate(op.value)
if helpers.needs_parens(op.value):
left = helpers.parenthesize(left)

right = translator.translate(op.options)
if not any(
ctx.is_foreign_expr(leaf)
for leaf in an.find_immediate_parent_tables(op.options)
):
array = op.options.to_expr().as_table().to_array().op()
right = table_array_view(translator, array)
else:
right = translator.translate(op.options)

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{left} IN {right}"
9 changes: 4 additions & 5 deletions ibis/backends/base/sql/registry/main.py
Expand Up @@ -47,9 +47,8 @@ def is_null(translator, op):


def not_(translator, op):
(arg,) = op.args
formatted_arg = translator.translate(arg)
if helpers.needs_parens(arg):
formatted_arg = translator.translate(op.arg)
if helpers.needs_parens(op.arg):
formatted_arg = helpers.parenthesize(formatted_arg)
return f"NOT {formatted_arg}"

Expand Down Expand Up @@ -351,8 +350,8 @@ def count_star(translator, op):
ops.Least: varargs("least"),
ops.Where: fixed_arity("if", 3),
ops.Between: between,
ops.Contains: binary_infix.contains("IN"),
ops.NotContains: binary_infix.contains("NOT IN"),
ops.InValues: binary_infix.in_values,
ops.InColumn: binary_infix.in_column,
ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.TableColumn: table_column,
Expand Down
50 changes: 24 additions & 26 deletions ibis/backends/clickhouse/compiler/values.py
Expand Up @@ -5,7 +5,7 @@
import functools
from functools import partial
from operator import add, mul, sub
from typing import Any, Literal, Mapping
from typing import Any, Mapping

import sqlglot as sg
from sqlglot.dialects.dialect import rename_func
Expand Down Expand Up @@ -835,38 +835,36 @@ def _string_contains(op, **kw):
return f"locate({haystack}, {needle}) > 0"


def contains(op_string: Literal["IN", "NOT IN"]) -> str:
def tr(op, *, cache, **kw):
from ibis.backends.clickhouse.compiler import translate
@translate_val.register(ops.InValues)
def _in_values(op, **kw):
# TODO(kszucs): move this optimization to expression construction
if not op.options:
return "FALSE"

value = op.value
options = op.options
if isinstance(options, tuple) and not options:
return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string]
value = translate_val(op.value, **kw)
if helpers.needs_parens(op.value):
value = helpers.parenthesize(value)

left_arg = translate_val(value, **kw)
if helpers.needs_parens(value):
left_arg = helpers.parenthesize(left_arg)
options = _sql(translate_val(op.options, **kw))

# special case non-foreign isin/notin expressions
if not isinstance(options, tuple) and options.shape.is_columnar():
# this will fail to execute if there's a correlation, but it's too
# annoying to detect so we let it through to enable the
# uncorrelated use case (pandas-style `.isin`)
subquery = translate(options.to_expr().as_table().op(), {})
right_arg = f"({_sql(subquery)})"
else:
right_arg = _sql(translate_val(options, cache=cache, **kw))
# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{value} IN {options}"

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{left_arg} {op_string} {right_arg}"

return tr
@translate_val.register(ops.InColumn)
def _in_column(op, **kw):
from ibis.backends.clickhouse.compiler import translate

value = translate_val(op.value, **kw)
if helpers.needs_parens(op.value):
value = helpers.parenthesize(value)

translate_val.register(ops.Contains)(contains("IN"))
translate_val.register(ops.NotContains)(contains("NOT IN"))
# this will fail to execute if there's a correlation, but it's too
# annoying to detect so we let it through to enable the
# uncorrelated use case (pandas-style `.isin`)
subquery = translate(op.options.to_expr().as_table().op(), {})
return f"{value} IN ({_sql(subquery)})"


@translate_val.register(ops.DayOfWeekName)
Expand Down
Expand Up @@ -3,5 +3,5 @@ SELECT
SELECT
arrayJoin(t0.ids) AS ids
FROM way_view AS t0
) AS "Contains(id, ids)"
) AS "InColumn(id, ids)"
FROM node_view AS t0
14 changes: 9 additions & 5 deletions ibis/backends/clickhouse/tests/test_operators.py
Expand Up @@ -145,18 +145,22 @@ def test_string_temporal_compare_between_datetimes(con, left, right):


@pytest.mark.parametrize("container", [list, tuple, set])
def test_field_in_literals(con, alltypes, translate, container):
values = {"foo", "bar", "baz"}
def test_field_in_literals(con, alltypes, df, translate, container):
values = {"1", "2", "3", "5", "7"}
foobar = container(values)
expected = tuple(values)

expr = alltypes.string_col.isin(foobar)
result_col = con.execute(expr.name("result"))
expected_col = df.string_col.isin(foobar).rename("result")
assert translate(expr.op()) == f"string_col IN {expected}"
assert len(con.execute(expr))
tm.assert_series_equal(result_col, expected_col)

expr = alltypes.string_col.notin(foobar)
assert translate(expr.op()) == f"string_col NOT IN {expected}"
assert len(con.execute(expr))
result_col = con.execute(expr.name("result"))
expected_col = (~df.string_col.isin(foobar)).rename("result")
assert translate(expr.op()) == f"NOT string_col IN {expected}"
tm.assert_series_equal(result_col, expected_col)


@pytest.mark.parametrize(
Expand Down
16 changes: 4 additions & 12 deletions ibis/backends/dask/execution/generic.py
Expand Up @@ -56,14 +56,12 @@
execute_intersection_dataframe_dataframe,
execute_isinf,
execute_isnan,
execute_node_contains_series_nodes,
execute_node_contains_series_sequence,
execute_node_column_in_column,
execute_node_column_in_values,
execute_node_dropna_dataframe,
execute_node_fillna_dataframe_dict,
execute_node_fillna_dataframe_scalar,
execute_node_ifnull_series,
execute_node_not_contains_series_nodes,
execute_node_not_contains_series_sequence,
execute_node_nullif_scalar_series,
execute_node_nullif_series,
execute_node_self_reference_dataframe,
Expand Down Expand Up @@ -150,14 +148,8 @@
ops.IsNan: [((dd.Series,), execute_isnan)],
ops.IsInf: [((dd.Series,), execute_isinf)],
ops.SelfReference: [((dd.DataFrame,), execute_node_self_reference_dataframe)],
ops.Contains: [
((dd.Series, tuple), execute_node_contains_series_nodes),
((dd.Series, dd.Series), execute_node_contains_series_sequence),
],
ops.NotContains: [
((dd.Series, tuple), execute_node_not_contains_series_nodes),
((dd.Series, dd.Series), execute_node_not_contains_series_sequence),
],
ops.InValues: [((dd.Series, tuple), execute_node_column_in_values)],
ops.InColumn: [((dd.Series, dd.Series), execute_node_column_in_column)],
ops.IfNull: [
((dd.Series, simple_types), execute_node_ifnull_series),
((dd.Series, dd.Series), execute_node_ifnull_series),
Expand Down
11 changes: 2 additions & 9 deletions ibis/backends/datafusion/compiler.py
Expand Up @@ -452,20 +452,13 @@ def stddev(op, **kw):
raise ValueError(f"Unrecognized how value: {op.how}")


@translate.register(ops.Contains)
def contains(op, **kw):
@translate.register(ops.InValues)
def in_values(op, **kw):
value = translate(op.value, **kw)
options = list(map(partial(translate, **kw), op.options))
return df.functions.in_list(value, options, negated=False)


@translate.register(ops.NotContains)
def not_contains(op, **kw):
value = translate(op.value, **kw)
options = list(map(partial(translate, **kw), op.options))
return df.functions.in_list(value, options, negated=True)


@translate.register(ops.Negate)
def negate(op, **kw):
return df.lit(-1) * translate(op.arg, **kw)
Expand Down
@@ -1 +1 @@
`g` NOT IN ('foo', 'bar', 'baz')
NOT `g` IN ('foo', 'bar', 'baz')
@@ -1,3 +1,3 @@
SELECT t0.*
FROM `alltypes` t0
WHERE t0.`g` NOT IN ('foo', 'bar')
WHERE NOT t0.`g` IN ('foo', 'bar')
@@ -1 +1 @@
2 NOT IN (`a`, `b`, `c`)
NOT 2 IN (`a`, `b`, `c`)

0 comments on commit fe9a289

Please sign in to comment.