Skip to content

Commit

Permalink
refactor: add deferrable decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Sep 27, 2023
1 parent be1fd65 commit b09d978
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
33 changes: 32 additions & 1 deletion ibis/expr/deferred.py
@@ -1,7 +1,9 @@
from __future__ import annotations

import functools
import inspect
import operator
from typing import Any, Callable, NoReturn
from typing import Any, Callable, NoReturn, TypeVar

_BINARY_OPS: dict[str, Callable[[Any, Any], Any]] = {
"+": operator.add,
Expand Down Expand Up @@ -275,6 +277,35 @@ def _resolve(expr: Deferred, param: Any) -> Any:
return expr


F = TypeVar("F", bound=Callable)


def deferrable(func: F) -> F:
"""Wrap a top-level expr function to support deferred arguments.
When a deferrable function is called, if any of the direct args or kwargs
is a `Deferred` value, then the result is also `Deferred`. Otherwise the
function is called directly.
"""
# Parse the signature of func so we can validate deferred calls eagerly,
# erroring for invalid/missing arguments at call time not resolve time.
sig = inspect.signature(func)

@functools.wraps(func)
def inner(*args, **kwargs):
is_deferred = any(isinstance(a, Deferred) for a in args) or any(
isinstance(v, Deferred) for v in kwargs.values()
)
if is_deferred:
# Try to bind the arguments now, raising a nice error
# immediately if the function was called incorrectly
sig.bind(*args, **kwargs)
return deferred_apply(func, *args, **kwargs)
return func(*args, **kwargs)

return inner


def deferred_apply(func: Callable, *args: Any, **kwargs: Any) -> Deferred:
"""Construct a deferred call from a callable and arguments.
Expand Down
26 changes: 25 additions & 1 deletion ibis/expr/tests/test_deferred.py
Expand Up @@ -8,7 +8,7 @@

import ibis
from ibis import _
from ibis.expr.deferred import deferred_apply
from ibis.expr.deferred import deferrable, deferred_apply


@pytest.fixture
Expand Down Expand Up @@ -195,3 +195,27 @@ def test_deferred_is_not_iterable(obj):

with pytest.raises(TypeError, match="is not an iterator"):
next(obj)


def test_deferrable(table):
@deferrable
def f(a, b, c=3):
return a + b + c

assert f(table.a, table.b).equals(table.a + table.b + 3)
assert f(table.a, table.b, c=4).equals(table.a + table.b + 4)

expr = f(_.a, _.b)
sol = table.a + table.b + 3
res = expr.resolve(table)
assert res.equals(sol)
assert repr(expr) == "f(_.a, _.b)"

expr = f(1, 2, c=_.a)
sol = 3 + table.a
res = expr.resolve(table)
assert res.equals(sol)
assert repr(expr) == "f(1, 2, c=_.a)"

with pytest.raises(TypeError, match="unknown"):
f(_.a, _.b, unknown=3) # invalid calls caught at call time

0 comments on commit b09d978

Please sign in to comment.