Skip to content

Commit

Permalink
Merge branch 'master' into wbruinsma/overload
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 19, 2023
2 parents 23eb67b + 9273285 commit 3ae8083
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 62 deletions.
132 changes: 70 additions & 62 deletions plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,12 @@ def _enhance_exception(self, e: SomeExceptionType) -> SomeExceptionType:
return type(e)(prefix + message[0].lower() + message[1:])

def resolve_method(
self, target: Union[Tuple[object, ...], Signature], types: Tuple[TypeHint]
self, target: Union[Tuple[object, ...], Signature]
) -> Tuple[Callable, TypeHint]:
"""Find the method and return type for arguments.
Args:
target (object): Target.
types (tuple[type, ...]): Types of the arguments.
Returns:
function: Method.
Expand All @@ -342,70 +341,89 @@ def resolve_method(

except NotFoundLookupError as e:
e = self._enhance_exception(e) # Specify this function.
method, return_type = self._handle_not_found_lookup_error(e)

if not self.owner:
# Not in a class. Nothing we can do.
raise e
return method, return_type

def _handle_not_found_lookup_error(
self, ex: NotFoundLookupError
) -> Tuple[Callable, TypeHint]:
if not self.owner:
# Not in a class. Nothing we can do.
raise ex

# In a class. Walk through the classes in the class's MRO, except for this
# class, and try to get the method.
method = None
return_type = object

for c in self.owner.__mro__[1:]:
# Skip the top of the type hierarchy given by `object` and `type`. We do
# not suddenly want to fall back to any unexpected default behaviour.
if c in {object, type}:
continue

# We need to check `c.__dict__` here instead of using `hasattr` since e.g.
# `c.__le__` will return even if `c` does not implement `__le__`!
if self._f.__name__ in c.__dict__:
method = getattr(c, self._f.__name__)
else:
# In a class. Walk through the classes in the class's MRO, except for
# this class, and try to get the method.
# For some reason, coverage fails to catch the `continue` below. Add
# the do-nothing `_ = None` fixes this.
# TODO: Remove this once coverage properly catches this.
_ = None
continue

# Ignore abstract methods.
if getattr(method, "__isabstractmethod__", False):
method = None
return_type = object

for c in self.owner.__mro__[1:]:
# Skip the top of the type hierarchy given by `object` and `type`.
# We do not suddenly want to fall back to any unexpected default
# behaviour.
if c in {object, type}:
continue

# We need to check `c.__dict__` here instead of using `hasattr`
# since e.g. `c.__le__` will return even if `c` does not implement
# `__le__`!
if self._f.__name__ in c.__dict__:
method = getattr(c, self._f.__name__)
else:
# For some reason, coverage fails to catch the `continue`
# below. Add the do-nothing `_ = None` fixes this.
# TODO: Remove this once coverage properly catches this.
_ = None
continue

# Ignore abstract methods.
if getattr(method, "__isabstractmethod__", False):
method = None
continue

# We found a good candidate. Break.
break

if not method:
# If no method has been found after walking through the MRO, raise
# the original exception.
raise e

# If the resolver is faithful, then we can perform caching using the types of
# the arguments. If the resolver is not faithful, then we cannot.
if self._resolver.is_faithful:
self._cache[types] = method, return_type
continue

# We found a good candidate. Break.
break

if not method:
# If no method has been found after walking through the MRO, raise the
# original exception.
raise ex
return method, return_type

def __call__(self, *args, **kw_args):
method, return_type = self._resolve_method_with_cache(args=args)
return _convert(method(*args, **kw_args), return_type)

def _resolve_method_with_cache(
self,
args: Union[Tuple[object, ...], Signature, None] = None,
types: Optional[Tuple[TypeHint, ...]] = None,
) -> Tuple[Callable, TypeHint]:
if args is None and types is None:
raise ValueError(
"Arguments `args` and `types` cannot both be `None`. "
"This should never happen!"
)

# Before attempting to use the cache, resolve any unresolved registrations. Use
# an `if`-statement to speed up the common case.
if self._pending:
self._resolve_pending_registrations()

# Attempt to use the cache based on the types of the arguments.
types = tuple(map(type, args))
if types is None:
# Attempt to use the cache based on the types of the arguments.
types = tuple(map(type, args))
try:
method, return_type = self._cache[types]
return self._cache[types]
except KeyError:
# Cache miss. Run the resolver based on the arguments.
method, return_type = self.resolve_method(args, types)
if args is None:
args = Signature(*(resolve_type_hint(t) for t in types))

return _convert(method(*args, **kw_args), return_type)
# Cache miss. Run the resolver based on the arguments.
method, return_type = self.resolve_method(args)
# If the resolver is faithful, then we can perform caching using the types
# of the arguments. If the resolver is not faithful, then we cannot.
if self._resolver.is_faithful:
self._cache[types] = method, return_type
return method, return_type

def invoke(self, *types: TypeHint) -> Callable:
"""Invoke a particular method.
Expand All @@ -416,17 +434,7 @@ def invoke(self, *types: TypeHint) -> Callable:
Returns:
function: Method.
"""
# Do this before attempting to cache. See above.
if self._pending:
self._resolve_pending_registrations()

# Attempt to use the cache based on the types.
try:
method, return_type = self._cache[types]
except KeyError:
# Cache miss. Run the resolver based on the types.
sig_types = Signature(*(resolve_type_hint(t) for t in types))
method, return_type = self.resolve_method(sig_types, types)
method, return_type = self._resolve_method_with_cache(types=types)

@wraps(self._f)
def wrapped_method(*args, **kw_args):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def f(x):
assert Function(f, owner="A").owner is A


def test_resolve_method_with_cache_no_arguments():
def f(x):
pass

with pytest.raises(ValueError, match="`args` and `types` cannot both be `None`"):
Function(f)._resolve_method_with_cache()


@pytest.fixture()
def owner_transfer():
# Save and clear.
Expand Down

0 comments on commit 3ae8083

Please sign in to comment.