
fix(api): ensure that window functions are propagated
cpcloud authored and gforsyth committed Oct 28, 2022
1 parent 1a5d5b9 commit 4fb1106
Showing 15 changed files with 130 additions and 93 deletions.
ibis/backends/base/sql/alchemy/registry.py (1 addition, 1 deletion)

@@ -65,7 +65,7 @@ def _varargs_call(sa_func, t, args):

 def varargs(sa_func):
     def formatter(t, op):
-        return _varargs_call(sa_func, t, op.arg.values)
+        return _varargs_call(sa_func, t, op.args)

     return formatter
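Note: this same one-line change repeats across the backends below. A minimal
sketch of the access-pattern shift, assuming the `Variadic.args` property added
in ibis/expr/operations/core.py at the end of this diff (`varargs_formatter`
and `translate` are hypothetical stand-ins for a backend formatter):

    def varargs_formatter(translate, op):
        # before this commit: the operands of a variadic op lived on a nested
        # ValueList/NodeList node, so formatters reached through op.arg.values
        # after: Variadic.args returns the operand tuple directly
        return ", ".join(map(translate, op.args))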
ibis/backends/base/sql/registry/main.py (2 additions, 2 deletions)

@@ -131,7 +131,7 @@ def cast(translator, op):

 def varargs(func_name):
     def varargs_formatter(translator, op):
-        return helpers.format_call(translator, func_name, *op.arg.values)
+        return helpers.format_call(translator, func_name, *op.args)

     return varargs_formatter

@@ -218,7 +218,7 @@ def hash(translator, op):


 def concat(translator, op):
-    joined_args = ', '.join(map(translator.translate, op.arg.values))
+    joined_args = ', '.join(map(translator.translate, op.args))
     return f"concat({joined_args})"
ibis/backends/clickhouse/registry.py (5 additions, 11 deletions)

@@ -7,6 +7,7 @@
 import ibis.expr.types as ir
 import ibis.util as util
 from ibis.backends.base.sql.registry import binary_infix, window
+from ibis.backends.base.sql.registry.main import varargs
 from ibis.backends.clickhouse.datatypes import serialize
 from ibis.backends.clickhouse.identifiers import quote_identifier

@@ -172,13 +173,6 @@ def _xor(translator, op):
     return f'xor({left_}, {right_})'


-def _varargs(func_name):
-    def varargs_formatter(translator, op):
-        return _call(translator, func_name, *op.arg.values)
-
-    return varargs_formatter
-
-
 def _arbitrary(translator, op):
     functions = {
         None: 'any',

@@ -549,7 +543,7 @@ def _string_join(translator, op):


 def _string_concat(translator, op):
-    args_formatted = ", ".join(map(translator.translate, op.arg.values))
+    args_formatted = ", ".join(map(translator.translate, op.args))
     return f"arrayStringConcat([{args_formatted}])"


@@ -784,8 +778,8 @@ def _sort_key(translator, op):
     ops.Cast: _cast,
     # for more than 2 args this should be arrayGreatest|Least(array([]))
     # because clickhouse's greatest and least doesn't support varargs
-    ops.Greatest: _varargs('greatest'),
-    ops.Least: _varargs('least'),
+    ops.Greatest: varargs('greatest'),
+    ops.Least: varargs('least'),
     ops.Where: _fixed_arity('if', 3),
     ops.Between: _between,
     ops.SimpleCase: _simple_case,

@@ -876,7 +870,7 @@ def _day_of_week_index(translator, op):
     ops.NotNull: _unary('isNotNull'),
     ops.IfNull: _fixed_arity('ifNull', 2),
     ops.NullIf: _fixed_arity('nullIf', 2),
-    ops.Coalesce: _varargs('coalesce'),
+    ops.Coalesce: varargs('coalesce'),
     ops.NullIfZero: _null_if_zero,
     ops.ZeroIfNull: _zero_if_null,
     ops.DayOfWeekIndex: _day_of_week_index,
ibis/backends/dask/execution/reductions.py (35 additions, 15 deletions)

@@ -10,7 +10,6 @@
 dispatcher since the top level container is a list.
 """

-import collections
 import functools
 from collections.abc import Sized

@@ -19,6 +18,7 @@
 import dask.dataframe.groupby as ddgb
 import numpy as np
 import toolz
+from multipledispatch.variadic import Variadic

 import ibis.expr.operations as ops
 from ibis.backends.dask.dispatch import execute_node

@@ -43,28 +43,48 @@ def pairwise_reducer(func, values):
     return functools.reduce(lambda x, y: func(x, y), values)


-def compute_row_reduction(func, value):
-    final_sizes = {len(x) for x in value if isinstance(x, Sized)}
+def compute_row_reduction(func, values):
+    final_sizes = {len(x) for x in values if isinstance(x, Sized)}
     if not final_sizes:
-        return func(value)
+        return func(values)
     (final_size,) = final_sizes
-    arrays = list(map(promote_to_sequence(final_size), value))
+    arrays = list(map(promote_to_sequence(final_size), values))
     raw = pairwise_reducer(func, arrays)
     return dd.from_array(raw).squeeze()


-@execute_node.register(ops.Greatest, collections.abc.Sequence)
-def dask_execute_node_greatest_list(op, value, **kwargs):
-    if all(type(v) != dd.Series for v in value):
-        return execute_node_greatest_list(op, value, **kwargs)
-    return compute_row_reduction(da.maximum, value)
+# XXX: there's non-determinism in the dask and pandas dispatch registration of
+# Greatest/Least/Coalesce, because 1) dask and pandas share `execute_node`
+# which is a design flaw and 2) greatest/least/coalesce need to handle
+# mixed-type (the Series types plus any related scalar type) inputs so `object`
+# is used as a possible input type.
+#
+# Here we remove the dispatch for pandas if it exists because the dask rule
+# handles both cases.
+try:
+    del execute_node[ops.Greatest, Variadic[object]]
+except KeyError:
+    pass


-@execute_node.register(ops.Least, collections.abc.Sequence)
-def dask_execute_node_least_list(op, value, **kwargs):
-    if all(type(v) != dd.Series for v in value):
-        return execute_node_least_list(op, value, **kwargs)
-    return compute_row_reduction(da.minimum, value)
+try:
+    del execute_node[ops.Least, Variadic[object]]
+except KeyError:
+    pass
+
+
+@execute_node.register(ops.Greatest, [(object, dd.Series)])
+def dask_execute_node_greatest_list(op, *values, **kwargs):
+    if all(type(v) != dd.Series for v in values):
+        return execute_node_greatest_list(op, *values, **kwargs)
+    return compute_row_reduction(da.maximum, values)
+
+
+@execute_node.register(ops.Least, [(object, dd.Series)])
+def dask_execute_node_least_list(op, *values, **kwargs):
+    if all(type(v) != dd.Series for v in values):
+        return execute_node_least_list(op, *values, **kwargs)
+    return compute_row_reduction(da.minimum, values)


 @execute_node.register(ops.Reduction, ddgb.SeriesGroupBy, type(None))
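The [(object, dd.Series)] registrations and the Variadic[object] deletions
above lean on multipledispatch's variadic signatures: a list inside a
signature means "any number of arguments matching these types", and such
signatures are keyed with Variadic[...] in the dispatcher's registry. A
minimal self-contained sketch of that mechanism (the dispatcher name `demo`
is hypothetical, not part of ibis):

    from multipledispatch import Dispatcher

    demo = Dispatcher("demo")

    @demo.register(str, [int])  # one str followed by any number of ints
    def _join_ints(prefix, *nums):
        return prefix + ":" + ",".join(map(str, nums))

    print(demo("x", 1, 2, 3))  # prints x:1,2,3

Deleting the Variadic[object]-keyed entries is what lets the dask rules shadow
the pandas ones when both backends register against the shared execute_node.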
ibis/backends/pandas/execution/generic.py (18 additions, 15 deletions)

@@ -961,8 +961,8 @@ def execute_node_value_list(op, _, **kwargs):
     return [execute(arg, **kwargs) for arg in op.values]


-@execute_node.register(ops.StringConcat, collections.abc.Sequence)
-def execute_node_string_concat(op, args, **kwargs):
+@execute_node.register(ops.StringConcat, [object])
+def execute_node_string_concat(op, *args, **kwargs):
     return functools.reduce(operator.add, args)

@@ -1168,30 +1168,33 @@ def coalesce(values):

 @toolz.curry
 def promote_to_sequence(length, obj):
-    return obj.values if isinstance(obj, pd.Series) else np.repeat(obj, length)
+    try:
+        return obj.values
+    except AttributeError:
+        return np.repeat(obj, length)


-def compute_row_reduction(func, value, **kwargs):
-    final_sizes = {len(x) for x in value if isinstance(x, Sized)}
+def compute_row_reduction(func, values, **kwargs):
+    final_sizes = {len(x) for x in values if isinstance(x, Sized)}
     if not final_sizes:
-        return func(value)
+        return func(values)
     (final_size,) = final_sizes
-    raw = func(list(map(promote_to_sequence(final_size), value)), **kwargs)
+    raw = func(list(map(promote_to_sequence(final_size), values)), **kwargs)
     return pd.Series(raw).squeeze()


-@execute_node.register(ops.Greatest, collections.abc.Sequence)
-def execute_node_greatest_list(op, value, **kwargs):
-    return compute_row_reduction(np.maximum.reduce, value, axis=0)
+@execute_node.register(ops.Greatest, [object])
+def execute_node_greatest_list(op, *values, **kwargs):
+    return compute_row_reduction(np.maximum.reduce, values, axis=0)


-@execute_node.register(ops.Least, collections.abc.Sequence)
-def execute_node_least_list(op, value, **kwargs):
-    return compute_row_reduction(np.minimum.reduce, value, axis=0)
+@execute_node.register(ops.Least, [object])
+def execute_node_least_list(op, *values, **kwargs):
+    return compute_row_reduction(np.minimum.reduce, values, axis=0)


-@execute_node.register(ops.Coalesce, collections.abc.Sequence)
-def execute_node_coalesce(op, values, **kwargs):
+@execute_node.register(ops.Coalesce, [object])
+def execute_node_coalesce(op, *values, **kwargs):
     # TODO: this is slow
     return compute_row_reduction(coalesce, values)
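Concretely, the row reduction broadcasts scalar operands to the common Series
length and then reduces elementwise. A minimal sketch of what
execute_node_greatest_list computes for mixed Series/scalar inputs (the values
are hypothetical; promote_to_sequence is inlined as a conditional here):

    import numpy as np
    import pandas as pd

    values = (pd.Series([1, 5, 3]), 4)  # one column plus one scalar operand
    (length,) = {len(v) for v in values if isinstance(v, pd.Series)}
    arrays = [
        v.values if isinstance(v, pd.Series) else np.repeat(v, length)
        for v in values
    ]
    # elementwise maximum across the stacked rows
    print(pd.Series(np.maximum.reduce(arrays, axis=0)))  # 4, 5, 4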
ibis/backends/polars/compiler.py (3 additions, 3 deletions)

@@ -314,19 +314,19 @@ def searched_case(op):

 @translate.register(ops.Coalesce)
 def coalesce(op):
-    arg = translate(op.arg)
+    arg = translate(ops.NodeList(*op.args))
     return pl.coalesce(arg)


 @translate.register(ops.Least)
 def least(op):
-    arg = [translate(arg) for arg in op.arg]
+    arg = [translate(arg) for arg in op.args]
     return pl.min(arg)


 @translate.register(ops.Greatest)
 def greatest(op):
-    arg = [translate(arg) for arg in op.arg]
+    arg = [translate(arg) for arg in op.args]
     return pl.max(arg)
ibis/backends/pyspark/compiler.py (4 additions, 4 deletions)

@@ -649,7 +649,7 @@ def compile_arbitrary(t, op, **kwargs):

 @compiles(ops.Coalesce)
 def compile_coalesce(t, op, **kwargs):
-    src_columns = t.translate(op.arg, **kwargs)
+    src_columns = t.translate(ops.NodeList(*op.args), **kwargs)
     if len(src_columns) == 1:
         return src_columns[0]
     else:

@@ -658,7 +658,7 @@ def compile_coalesce(t, op, **kwargs):

 @compiles(ops.Greatest)
 def compile_greatest(t, op, **kwargs):
-    src_columns = t.translate(op.arg, **kwargs)
+    src_columns = t.translate(ops.NodeList(*op.args), **kwargs)
     if len(src_columns) == 1:
         return src_columns[0]
     else:

@@ -667,7 +667,7 @@ def compile_greatest(t, op, **kwargs):

 @compiles(ops.Least)
 def compile_least(t, op, **kwargs):
-    src_columns = t.translate(op.arg, **kwargs)
+    src_columns = t.translate(ops.NodeList(*op.args), **kwargs)
     if len(src_columns) == 1:
         return src_columns[0]
     else:

@@ -1006,7 +1006,7 @@ def compile_string_split(t, op, **kwargs):

 @compiles(ops.StringConcat)
 def compile_string_concat(t, op, **kwargs):
-    src_columns = t.translate(op.arg, **kwargs)
+    src_columns = [t.translate(arg, **kwargs) for arg in op.args]
     return F.concat(*src_columns)
ibis/backends/sqlite/registry.py (1 addition, 5 deletions)

@@ -238,11 +238,7 @@ def _string_join(t, op):


 def _string_concat(t, op):
-    # yes, `arg`. for variadic functions `arg` is the list of arguments.
-    #
-    # `args` is always the list of values of the fields declared in the
-    # operation
-    return functools.reduce(operator.add, map(t.translate, op.arg.values))
+    return functools.reduce(operator.add, map(t.translate, op.args))


 def _date_from_ymd(t, op):
ibis/backends/tests/test_window.py (25 additions, 0 deletions)

@@ -686,3 +686,28 @@ def test_percent_rank_whole_table_no_order_by(backend, alltypes, df):
     expected = df.assign(val=column).set_index('id').sort_index()

     backend.assert_series_equal(result.val, expected.val)
+
+
+@pytest.mark.notimpl(["dask", "datafusion", "polars"])
+def test_grouped_ordered_window_coalesce(backend, alltypes, df):
+    t = alltypes
+    expr = (
+        t.group_by("month")
+        .order_by(["int_col", "id"])
+        .mutate(lagged_value=t.bigint_col.lag())[["id", "lagged_value"]]
+    )
+    result = expr.execute().sort_values(["id"]).lagged_value.reset_index(drop=True)
+
+    def agg(df):
+        df = df.sort_values(["int_col", "id"], kind="mergesort")
+        df = df.assign(bigint_col=lambda df: df.bigint_col.shift())
+        return df
+
+    expected = (
+        df.groupby("month")
+        .apply(agg)
+        .sort_values(["id"])
+        .reset_index(drop=True)
+        .bigint_col.rename("lagged_value")
+    )
+    backend.assert_series_equal(result, expected)
ibis/expr/operations/core.py (19 additions, 0 deletions)

@@ -6,6 +6,7 @@
 from public import public

 import ibis.expr.rules as rlz
+from ibis.common.annotations import attribute
 from ibis.common.grounds import Concrete
 from ibis.expr.rules import Shape
 from ibis.util import UnnamedMarker, deprecated

@@ -85,6 +86,20 @@ def to_expr(self):
         return self.output_dtype.scalar(self)


+@public
+class Variadic(Value):
+    output_shape = rlz.shape_like('arg')
+    output_dtype = rlz.dtype_like('arg')
+
+    @attribute.default
+    def output_shape(self):
+        return rlz.highest_precedence_shape(self.args)
+
+    @property
+    def args(self):
+        return self.arg
+
+
 @public
 class Alias(Value):
     arg = rlz.any

@@ -147,5 +162,9 @@ def to_expr(self):

         return ir.List(self)

+    @property
+    def args(self):
+        return self.values
+

 public(ValueOp=Value, UnaryOp=Unary, BinaryOp=Binary, ValueList=NodeList)
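With Variadic.args and NodeList.args both in place, backend translators can
iterate a single attribute whether an operation stores its operands directly
or behind a NodeList wrapper (the polars and pyspark compilers above rebuild a
NodeList from op.args where a list translation is still needed). A hedged
sketch against ibis's public expression API, assuming Coalesce is variadic
after this commit (the table and column names are hypothetical):

    import ibis

    t = ibis.table({"a": "int64", "b": "int64"}, name="t")
    expr = ibis.coalesce(t.a, t.b, 0)
    print(expr.op().args)  # the coalesced operands, in order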