113 changes: 51 additions & 62 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
map_to,
one_of,
optional,
ref,
tuple_of,
validator,
)
Expand Down Expand Up @@ -84,6 +85,11 @@ def __call__(self, *args, **kwargs):
# Input type validators / coercion functions


@rule
def just(arg):
return lambda **_: arg


@rule
def value_list_of(inner, arg, **kwargs):
# TODO(kszucs): would be nice to remove ops.ValueList
Expand All @@ -103,17 +109,9 @@ def value_list_of(inner, arg, **kwargs):


@rule
def sort_key_from(table, key, *, this=None):
def sort_key_from(table_ref, key, **kwargs):
import ibis.expr.operations as ops

if isinstance(key, tuple):
column, order = key
else:
column, order = key, True

if isinstance(order, str):
order = order.lower()

is_ascending = {
"asc": True,
"ascending": True,
Expand All @@ -124,25 +122,35 @@ def sort_key_from(table, key, *, this=None):
False: False,
True: True,
}
ascending = map_to(is_ascending, order)
column = one_of(

if callable(key):
key = function_of(table_ref, key)

if isinstance(key, ops.SortKey):
return key
elif isinstance(key, ops.DeferredSortKey):
key, order = key.what, key.ascending
elif isinstance(key, tuple):
key, order = key
else:
key, order = key, True

key = one_of(
(
instance_of((ops.DeferredSortKey, ops.SortKey)),
column_from(table),
function_of(table),
any,
function_of(table_ref),
column_from(table_ref),
column(any),
instance_of(ops.RandomScalar),
),
column,
this=this,
key,
**kwargs,
)

if isinstance(column, ops.SortKey):
return column
elif isinstance(column, ops.DeferredSortKey):
table = this[table] if isinstance(table, str) else table
return column.resolve(table)
else:
return ops.SortKey(column, ascending=ascending)
if isinstance(order, str):
order = order.lower()
order = map_to(is_ascending, order)

return ops.SortKey(key, ascending=order)


@rule
Expand Down Expand Up @@ -430,8 +438,11 @@ def output_dtype(self):
return output_dtype


# TODO(kszucs): it could be as simple as rlz.instance_of(ops.TableNode)
# we have a single test case testing the schema superset condition, not
# used anywhere else
@rule
def table(arg, *, schema=None, **kwargs):
def table(arg, schema=None, **kwargs):
"""A table argument.
Parameters
Expand Down Expand Up @@ -468,7 +479,7 @@ def table(arg, *, schema=None, **kwargs):


@rule
def column_from(table, column, *, this=None):
def column_from(table_ref, column, **kwargs):
"""A column from a named table.
This validator accepts columns passed as string, integer, or column
Expand All @@ -478,14 +489,8 @@ def column_from(table, column, *, this=None):
"""
import ibis.expr.operations as ops

if isinstance(table, str):
if table not in this:
raise com.IbisTypeError(f"Could not get table {table} from {this}")
else:
table = this[table]

# TODO(kszucs): should avoid converting to TableExpr
table = table.to_expr()
table = table_ref(**kwargs).to_expr()

# TODO(kszucs): should avoid converting to a ColumnExpr
if isinstance(column, ops.Node):
Expand Down Expand Up @@ -516,39 +521,20 @@ def column_from(table, column, *, this=None):
)


# TODO(kszucs): consider to remove since it's only used by TopK
@rule
def base_table_of(name, *, this):
def base_table_of(table_ref, *, this, strict=True):
from ibis.expr.analysis import find_first_base_table

arg = this[name]
arg = table_ref(this=this)
base = find_first_base_table(arg)
if base is None:
if strict and base is None:
raise com.IbisTypeError(f"`{arg}` doesn't have a base table")

return base


@rule
def function_of(
arg,
fn,
*,
output_rule=any,
this=None,
):
import ibis.expr.operations as ops

if isinstance(arg, str):
arg = this[arg].to_expr()
elif callable(arg):
arg = arg(this=this).to_expr()
elif isinstance(arg, ops.Node):
arg = arg.to_expr()
else:
raise com.IbisTypeError(
'argument `arg` must be a string, inner rule or an operation'
)
def function_of(table_ref, fn, *, output_rule=any, this=None):
arg = table_ref(this=this).to_expr()

if util.is_function(fn):
arg = fn(arg)
Expand Down Expand Up @@ -584,8 +570,12 @@ def non_negative_integer(arg, **kwargs):


@rule
def pair(inner_left, inner_right, a, b, **kwargs):
return inner_left(a, **kwargs), inner_right(b, **kwargs)
def pair(inner_left, inner_right, arg, **kwargs):
try:
a, b = arg
except TypeError:
raise com.IbisTypeError(f"{arg} is not an iterable with two elements")
return inner_left(a[0], **kwargs), inner_right(b, **kwargs)


@rule
Expand All @@ -600,8 +590,7 @@ def analytic(arg, **kwargs):


@validator
def window(win, *, from_base_table_of, this):
from ibis.expr.analysis import find_first_base_table
def window_from(table_ref, win, **kwargs):
from ibis.expr.window import Window

if not isinstance(win, Window):
Expand All @@ -610,7 +599,7 @@ def window(win, *, from_base_table_of, this):
f"got type {type(win).__name__}"
)

table = find_first_base_table(this[from_base_table_of])
table = table_ref(**kwargs)
if table is not None:
win = win.bind(table.to_expr())

Expand Down
51 changes: 14 additions & 37 deletions ibis/expr/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,31 +107,16 @@ def __init__(
max_lookback=None,
how='rows',
):
import ibis.expr.operations as ops

self._group_by = tuple(
toolz.unique(
(
arg.op() if isinstance(arg, ir.Expr) else arg
for arg in util.promote_list(group_by)
),
key=lambda value: getattr(value, "_key", value),
arg.op() if isinstance(arg, ir.Expr) else arg
for arg in util.promote_list(group_by)
)
)

_order_by = []
for expr in util.promote_list(order_by):
try:
arg = expr.op()
except AttributeError:
arg = expr
if isinstance(expr, ir.Expr) and not isinstance(expr, ir.SortExpr):
arg = ops.SortKey(arg)
_order_by.append(arg)

self._order_by = tuple(
toolz.unique(
_order_by, key=lambda value: getattr(value, "_key", value)
arg.op() if isinstance(arg, ir.Expr) else arg
for arg in util.promote_list(order_by)
)
)

Expand Down Expand Up @@ -259,15 +244,13 @@ def bind(self, table):
# Internal API, ensure that any unresolved expr references (as strings,
# say) are bound to the table being windowed

import ibis.expr.operations as ops

groups = [
table._ensure_expr(
arg.to_expr() if isinstance(arg, ops.Node) else arg
).op()
for arg in self._group_by
]
sorts = rlz.tuple_of(rlz.sort_key_from(table), self._order_by)
groups = rlz.tuple_of(
rlz.one_of((rlz.column_from(rlz.just(table)), rlz.any)),
self._group_by,
)
sorts = rlz.tuple_of(
rlz.sort_key_from(rlz.just(table)), self._order_by
)

return self._replace(group_by=groups, order_by=sorts)

Expand Down Expand Up @@ -308,9 +291,7 @@ def order_by(self, expr):

def __equals__(self, other):
return (
len(self._group_by) == len(other._group_by)
and len(self._order_by) == len(other._order_by)
and self.max_lookback == other.max_lookback
self.max_lookback == other.max_lookback
and (
self.preceding.equals(other.preceding)
if isinstance(self.preceding, ir.Expr)
Expand All @@ -321,12 +302,8 @@ def __equals__(self, other):
if isinstance(self.following, ir.Expr)
else self.following == other.following
)
and all(
a.equals(b) for a, b in zip(self._group_by, other._group_by)
)
and all(
a.equals(b) for a, b in zip(self._order_by, other._order_by)
)
and self._group_by == other._group_by
and self._order_by == other._order_by
)

def equals(self, other):
Expand Down
10 changes: 5 additions & 5 deletions ibis/tests/expr/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,19 +299,19 @@ def test_invalid_column_or_scalar(validator, value, expected):
],
)
def test_valid_column_from(check_table, value, expected):
validator = rlz.column_from("table")
validator = rlz.column_from(rlz.ref("table"))
this = dict(table=table.op())
assert validator(value, this=this).equals(expected.op())


@pytest.mark.parametrize(
('check_table', 'validator', 'value'),
[
(table, rlz.column_from("not_table"), "int_col"),
(table, rlz.column_from("table"), "col_not_in_table"),
(table, rlz.column_from(rlz.ref("not_table")), "int_col"),
(table, rlz.column_from(rlz.ref("table")), "col_not_in_table"),
(
table,
rlz.column_from("table"),
rlz.column_from(rlz.ref("table")),
similar_table.int_col,
),
],
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_optional(validator, input):
def test_base_table_of_failure_mode():
class BrokenUseOfBaseTableOf(ops.Node):
arg = rlz.any
foo = rlz.function_of(rlz.base_table_of("arg"))
foo = rlz.function_of(rlz.base_table_of(rlz.ref("arg"), strict=True))

arg = ibis.literal("abc")

Expand Down
18 changes: 12 additions & 6 deletions ibis/tests/expr/test_window_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ibis
import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.types as ir
from ibis.expr.window import _determine_how, rows_with_max_lookback
Expand Down Expand Up @@ -146,8 +147,12 @@ def test_over_auto_bind(alltypes):

expr = t.f.lag().over(w)

actual_window = expr.op().args[1]
expected = ibis.window(group_by=t.g, order_by=t.f)
# TODO(kszucs): the window object doesn't apply the rules for the sorting
# keys, so need to wrap the expected order key with a SortKey for now
# on long term we should refactor the window object to a WindowFrame op
actual_window = expr.op().args[1] # noqa
expected = ibis.window(group_by=t.g, order_by=ops.SortKey(t.f)) # noqa

assert_equal(actual_window, expected)


Expand All @@ -159,8 +164,9 @@ def test_window_function_bind(alltypes):

expr = t.f.lag().over(w)

actual_window = expr.op().args[1]
expected = ibis.window(group_by=t.g, order_by=t.f)
actual_window = expr.op().args[1] # noqa
expected = ibis.window(group_by=t.g, order_by=ops.SortKey(t.f)) # noqa

assert_equal(actual_window, expected)


Expand Down Expand Up @@ -193,8 +199,8 @@ def test_window_bind_to_table(alltypes):
t = alltypes
w = ibis.window(group_by='g', order_by=ibis.desc('f'))

w2 = w.bind(alltypes)
expected = ibis.window(group_by=t.g, order_by=ibis.desc(t.f))
w2 = w.bind(alltypes) # noqa
expected = ibis.window(group_by=t.g, order_by=ibis.desc(t.f)) # noqa

assert_equal(w2, expected)

Expand Down