Skip to content

Commit

Permalink
Add argcount to the function wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
DeviousStoat committed Jan 19, 2024
1 parent fcf3dd4 commit eef539a
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
54 changes: 41 additions & 13 deletions src/pydash/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import time
import typing as t

from typing_extensions import Concatenate, Literal, ParamSpec
from typing_extensions import Concatenate, Literal, ParamSpec, Protocol

import pydash as pyd
from pydash.helpers import getargcount


__all__ = (
Expand Down Expand Up @@ -50,7 +51,15 @@
P = ParamSpec("P")


class After(t.Generic[P, T]):
class _WithArgCount(Protocol):
func: t.Callable

@property
def _argcount(self) -> t.Optional[int]:
return getargcount(self.func, None)


class After(_WithArgCount, t.Generic[P, T]):
"""Wrap a function in an after context."""

def __init__(self, func: t.Callable[P, T], n: t.SupportsInt) -> None:
Expand All @@ -73,7 +82,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> t.Union[T, None]:
return None


class Ary(t.Generic[T]):
class Ary(_WithArgCount, t.Generic[T]):
"""Wrap a function in an ary context."""

def __init__(self, func: t.Callable[..., T], n: t.Union[t.SupportsInt, None]) -> None:
Expand Down Expand Up @@ -166,12 +175,11 @@ def __init__(self, *funcs, from_right: bool = True) -> None: # type: ignore
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Return results of composing :attr:`funcs`."""
funcs = list(self.funcs)
from_index = -1 if self.from_right else 0

result = None

while funcs:
result = funcs.pop(from_index)(*args, **kwargs)
result = funcs.pop(self._from_index)(*args, **kwargs)
# Incompatible type in assignements but needed here
# type safety is ensured from the `__init__` signature
args = (result,) # type: ignore
Expand All @@ -180,6 +188,14 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
# type safety is ensured from the `__init__` signature
return result # type: ignore

@property
def _from_index(self) -> int:
return -1 if self.from_right else 0

@property
def _argcount(self) -> t.Optional[int]:
return getargcount(self.funcs[self._from_index], None)


class Conjoin(t.Generic[T]):
"""Wrap a set of functions in a conjoin context."""
Expand Down Expand Up @@ -398,7 +414,7 @@ def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs) # pragma: no cover


class Debounce(t.Generic[P, T]):
class Debounce(_WithArgCount, t.Generic[P, T]):
"""Wrap a function in a debounce context."""

def __init__(
Expand Down Expand Up @@ -451,7 +467,7 @@ def iteratee(item: T) -> bool:
return pyd.some(obj, iteratee)


class Flip(object):
class Flip(_WithArgCount):
"""Wrap a function in a flip context."""

def __init__(self, func: t.Callable) -> None:
Expand Down Expand Up @@ -494,8 +510,12 @@ def __init__(self, *funcs: t.Callable[P, T]) -> None:
def __call__(self, *objs: P.args, **kwargs: P.kwargs) -> t.List[T]:
return [func(*objs, **kwargs) for func in self.funcs]

@property
def _argcount(self) -> t.Optional[int]:
return getargcount(self.funcs[0], None) if self.funcs else None


class OverArgs(object):
class OverArgs(_WithArgCount):
"""Wrap a function in an over_args context."""

def __init__(self, func: t.Callable, *transforms: t.Callable) -> None:
Expand All @@ -507,7 +527,7 @@ def __call__(self, *args):
return self.func(*args)


class Negate(t.Generic[P]):
class Negate(_WithArgCount, t.Generic[P]):
"""Wrap a function in a negate context."""

def __init__(self, func: t.Callable[P, t.Any]) -> None:
Expand All @@ -518,7 +538,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> bool:
return not self.func(*args, **kwargs)


class Once(t.Generic[P, T]):
class Once(_WithArgCount, t.Generic[P, T]):
"""Wrap a function in a once context."""

def __init__(self, func: t.Callable[P, T]) -> None:
Expand All @@ -536,7 +556,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self.result # type: ignore


class Partial(t.Generic[T]):
class Partial(_WithArgCount, t.Generic[T]):
"""Wrap a function in a partial context."""

def __init__(
Expand All @@ -562,8 +582,16 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> T:

return self.func(*args, **kwargs)

@property
def _argcount(self) -> t.Optional[int]:
func_argcount = getargcount(self.func, None)
if func_argcount is None:
return None
argcount = func_argcount - len(self.args) - len(self.kwargs)
return argcount if argcount >= 0 else None


class Rearg(t.Generic[P, T]):
class Rearg(_WithArgCount, t.Generic[P, T]):
"""Wrap a function in a rearg context."""

def __init__(self, func: t.Callable[P, T], *indexes: int) -> None:
Expand Down Expand Up @@ -608,7 +636,7 @@ def __call__(self, args: t.Iterable) -> T:
return self.func(*args)


class Throttle(t.Generic[P, T]):
class Throttle(_WithArgCount, t.Generic[P, T]):
"""Wrap a function in a throttle context."""

def __init__(self, func: t.Callable[P, T], wait: int) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/pydash/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def getargcount(iteratee, maxargs):
if hasattr(iteratee, "_argcount"):
# Optimization feature where argcount of iteratee is known and properly
# set by initiator.
return iteratee._argcount
# It should always be right, but it can be `None` for the function wrappers
# in `pydash.function` as the wrapped functions are out of our control and
# can support an unknown number of arguments.
argcount = iteratee._argcount
return argcount if argcount is not None else maxargs

if isinstance(iteratee, type) or pyd.is_builtin(iteratee):
# Only pass single argument to type iteratees or builtins.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,30 @@ def test_unary(case, args, kwargs, expected):
)
def test_wrap(case, args, expected):
assert _.wrap(*case)(*args) == expected


def test_flow_argcount():
assert _.flow(lambda x, y: x + y, lambda x: x * 2)._argcount == 2


def test_flow_right_argcount():
assert _.flow_right(lambda x: x * 2, lambda x, y: x + y)._argcount == 2


def test_juxtapose_argcount():
assert _.juxtapose(lambda x, y, z: x + y + z, lambda x, y, z: x * y * z)._argcount == 3


def test_partial_argcount():
assert _.partial(lambda x, y, z: x + y + z, 1, 2)._argcount == 1


def test_partial_right_argcount():
assert _.partial_right(lambda x, y, z: x + y + z, 1, 2)._argcount == 1


def test_can_be_used_as_predicate_argcount_is_known():
def is_positive(x: int) -> bool:
return x > 0

assert _.filter_([-1, 0, 1], _.negate(is_positive)) == [-1, 0]

0 comments on commit eef539a

Please sign in to comment.