Skip to content

Commit

Permalink
refactor(rules): generalize field referencing using rlz.ref()
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs committed Sep 16, 2022
1 parent 4d63280 commit 0afb8b9
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 150 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,7 @@ def compile_window_op(t, op, **kwargs):
# Timestamp needs to be cast to long for window bounds in spark
ordering_keys = [
F.col(sort.resolve_name()).cast('long')
if isinstance(sort.expr.output_dtype, dt.Timestamp)
if isinstance(sort.output_dtype, dt.Timestamp)
else sort.resolve_name()
for sort in window._order_by
]
Expand Down
6 changes: 3 additions & 3 deletions ibis/common/tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def fn(x, this):
return int(x) + this['other']

p = Parameter('novalidator')
assert p.validate({}, 'value') == 'value'
assert p.validate('value', this={}) == 'value'

p = Parameter('test', validator=fn)

assert p.validator is fn
assert p.default is inspect.Parameter.empty
assert p.validate({'other': 1}, '2') == 3
assert p.validate('2', this={'other': 1}) == 3

with pytest.raises(TypeError):
p.validate({}, valid=inspect.Parameter.empty)
Expand All @@ -78,7 +78,7 @@ def fn(x, this):
op = Parameter('test', validator=ofn)
assert op.validator is ofn
assert op.default is None
assert op.validate({'other': 1}, None) is None
assert op.validate(None, this={'other': 1}) is None


def test_signature():
Expand Down
21 changes: 13 additions & 8 deletions ibis/common/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
def validator(self):
return self._validator

def validate(self, this, arg):
def validate(self, arg, *, this):
if self.validator is EMPTY:
return arg
else:
Expand All @@ -100,18 +100,20 @@ class Signature(inspect.Signature):
Primarly used in the implementation of ibis.common.grounds.Annotable.
"""

def validate(self, *args, **kwargs):
def apply(self, *args, **kwargs):
bound = self.bind(*args, **kwargs)
bound.apply_defaults()
return bound.arguments

def validate(self, *args, **kwargs):
# bind the signature to the passed arguments and apply the validators
# before passing the arguments, so self.__init__() receives already
# validated arguments as keywords
this = {}
for name, value in bound.arguments.items():
for name, value in self.apply(*args, **kwargs).items():
param = self.parameters[name]
# TODO(kszucs): provide more error context on failure
this[name] = param.validate(this, value)
this[name] = param.validate(value, this=this)

return this

Expand All @@ -137,8 +139,11 @@ def __call__(self, instance):


@validator
def noop(arg, **kwargs):
return arg
def ref(key, *, this):
try:
return this[key]
except KeyError:
raise IbisTypeError(f"Could not get `{key}` from {this}")


@validator
Expand Down Expand Up @@ -166,7 +171,7 @@ def one_of(inners, arg, **kwargs):


@validator
def compose_of(inners, arg, **kwargs):
def all_of(inners, arg, **kwargs):
"""All of the inner validators must pass.
The order of inner validators matters.
Expand All @@ -184,7 +189,7 @@ def compose_of(inners, arg, **kwargs):
arg : Any
Value maybe coerced by inner validators to the appropiate types
"""
for inner in reversed(inners):
for inner in inners:
arg = inner(arg, **kwargs)
return arg

Expand Down
12 changes: 6 additions & 6 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,10 @@ def desc(expr: ir.Column | str) -> ir.SortExpr | ops.DeferredSortKey:
ir.SortExpr | ops.DeferredSortKey
A sort expression or deferred sort key
"""
if not isinstance(expr, Expr):
return ops.DeferredSortKey(expr, ascending=False)
else:
if isinstance(expr, Expr):
return ops.SortKey(expr, ascending=False).to_expr()
else:
return ops.DeferredSortKey(expr, ascending=False)


def asc(expr: ir.Column | str) -> ir.SortExpr | ops.DeferredSortKey:
Expand Down Expand Up @@ -544,10 +544,10 @@ def asc(expr: ir.Column | str) -> ir.SortExpr | ops.DeferredSortKey:
ir.SortExpr | ops.DeferredSortKey
A sort expression or deferred sort key
"""
if not isinstance(expr, Expr):
return ops.DeferredSortKey(expr, ascending=True)
else:
if isinstance(expr, Expr):
return ops.SortKey(expr, ascending=True).to_expr()
else:
return ops.DeferredSortKey(expr, ascending=True)


def and_(*predicates: ir.BooleanValue) -> ir.BooleanValue:
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ibis.common.exceptions import IbisTypeError, InputTypeError
from ibis.common.grounds import Annotable, Comparable, Singleton
from ibis.common.validators import (
compose_of,
all_of,
instance_of,
isin,
map_to,
Expand Down Expand Up @@ -491,7 +491,7 @@ class Interval(DataType):
"""The time unit of the interval."""

value_type = optional(
compose_of([datatype, instance_of(Integer)]), default=Int32()
all_of([datatype, instance_of(Integer)]), default=Int32()
)
"""The underlying type of the stored values."""

Expand Down
6 changes: 4 additions & 2 deletions ibis/expr/operations/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@public
class Window(Value):
expr = rlz.analytic
window = rlz.window(from_base_table_of="expr")
window = rlz.window_from(rlz.base_table_of(rlz.ref("expr"), strict=False))

output_dtype = rlz.dtype_like("expr")
output_shape = rlz.Shape.COLUMNAR
Expand Down Expand Up @@ -235,7 +235,9 @@ class NthValue(Analytic):
class TopK(Node):
arg = rlz.column(rlz.any)
k = rlz.non_negative_integer
by = rlz.one_of((rlz.function_of(rlz.base_table_of("arg")), rlz.any))
by = rlz.one_of(
(rlz.function_of(rlz.base_table_of(rlz.ref("arg"))), rlz.any)
)

def to_expr(self):
import ibis.expr.types as ir
Expand Down
20 changes: 10 additions & 10 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ class Projection(TableNode):
rlz.one_of(
(
rlz.table,
rlz.column_from("table"),
rlz.function_of("table"),
rlz.column_from(rlz.ref("table")),
rlz.function_of(rlz.ref("table")),
rlz.any,
)
)
Expand Down Expand Up @@ -342,7 +342,7 @@ def schema(self):
class Selection(Projection):
predicates = rlz.optional(rlz.tuple_of(rlz.boolean), default=())
sort_keys = rlz.optional(
rlz.tuple_of(rlz.sort_key_from("table")), default=()
rlz.tuple_of(rlz.sort_key_from(rlz.ref("table"))), default=()
)

def __init__(self, table, selections, predicates, sort_keys, **kwargs):
Expand Down Expand Up @@ -371,7 +371,7 @@ def __init__(self, table, selections, predicates, sort_keys, **kwargs):
def sort_by(self, sort_exprs):
from ibis.expr.analysis import shares_all_roots

keys = rlz.tuple_of(rlz.sort_key_from(self), sort_exprs)
keys = rlz.tuple_of(rlz.sort_key_from(rlz.just(self)), sort_exprs)

if not self.selections:
if shares_all_roots(keys, self.table):
Expand Down Expand Up @@ -403,7 +403,7 @@ class Aggregation(TableNode):
rlz.one_of(
(
rlz.function_of(
"table",
rlz.ref("table"),
output_rule=rlz.one_of(
(rlz.reduction, rlz.scalar(rlz.any))
),
Expand All @@ -421,8 +421,8 @@ class Aggregation(TableNode):
rlz.tuple_of(
rlz.one_of(
(
rlz.function_of("table"),
rlz.column_from("table"),
rlz.function_of(rlz.ref("table")),
rlz.column_from(rlz.ref("table")),
rlz.column(rlz.any),
)
)
Expand All @@ -434,7 +434,7 @@ class Aggregation(TableNode):
rlz.one_of(
(
rlz.function_of(
"table", output_rule=rlz.scalar(rlz.boolean)
rlz.ref("table"), output_rule=rlz.scalar(rlz.boolean)
),
rlz.scalar(rlz.boolean),
)
Expand Down Expand Up @@ -489,7 +489,7 @@ def schema(self):
def sort_by(self, sort_exprs):
from ibis.expr.analysis import shares_all_roots

keys = rlz.tuple_of(rlz.sort_key_from(self), sort_exprs)
keys = rlz.tuple_of(rlz.sort_key_from(rlz.just(self)), sort_exprs)

if shares_all_roots(keys, self.table):
return Aggregation(
Expand Down Expand Up @@ -560,7 +560,7 @@ class DropNa(TableNode):

table = rlz.table
how = rlz.isin({'any', 'all'})
subset = rlz.optional(rlz.tuple_of(rlz.column_from("table")))
subset = rlz.optional(rlz.tuple_of(rlz.column_from(rlz.ref("table"))))

@property
def schema(self):
Expand Down
14 changes: 6 additions & 8 deletions ibis/expr/operations/sortkeys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Callable

from public import public

import ibis.expr.rules as rlz
from ibis.common.grounds import Annotable
from ibis.expr.operations.core import Value


Expand Down Expand Up @@ -37,11 +40,6 @@ def to_expr(self):


@public
class DeferredSortKey:
def __init__(self, what, ascending=True):
self.what = what
self.ascending = ascending

def resolve(self, parent):
what = parent.to_expr()._ensure_expr(self.what)
return SortKey(what, ascending=self.ascending)
class DeferredSortKey(Annotable):
what = rlz.instance_of((int, str, Callable))
ascending = rlz.instance_of(bool)
Loading

0 comments on commit 0afb8b9

Please sign in to comment.