Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add register_hooks utility #247

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
convenience constructor
- Validators for `Campaign` attributes
_ `_optional` subpackage for managing optional dependencies
- `register_hooks` utility enabling user-defined augmentation of arbitrary callables
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved

### Changed
- Passing an `Objective` to `Campaign` is now optional
Expand Down
107 changes: 105 additions & 2 deletions baybe/utils/basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Collection of small basic utilities."""

import functools
import inspect
from collections.abc import Callable, Collection, Iterable, Sequence
from dataclasses import dataclass
from inspect import signature
from typing import Any, TypeVar

from baybe.exceptions import UnidentifiedSubclassError
Expand Down Expand Up @@ -139,7 +140,7 @@ def filter_attributes(
Returns:
A dictionary mapping the matched attribute names to their values.
"""
params = signature(callable_).parameters
params = inspect.signature(callable_).parameters
return {
p: getattr(object, p)
for p in params
Expand Down Expand Up @@ -181,3 +182,105 @@ def find_subclass(base: type, name_or_abbr: str, /):
f"The class name or abbreviation '{name_or_abbr}' does not refer to any "
f"of the subclasses of '{base.__name__}'."
)


def register_hooks(
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
target: Callable,
pre_hooks: Sequence[Callable] | None = None,
post_hooks: Sequence[Callable] | None = None,
) -> Callable:
"""Register custom hooks with a given target callable.

The provided hooks need to be "compatible" with the target in the sense that their
signatures can be aligned:

* The hook signature may only contain parameters that also exist in the target
callable (<-- basic requirement).
* However, parameters that are not needed by the hook can be omitted from
its signature. This requires that the parameters of both signatures can be matched
via their names. For simplicity, it is thus assumed that the hook has no
positional-only arguments.
* If an annotation is provided for a hook parameter, it must match its
target counterpart (<-- safety mechanism to prevent unintended argument use).
An exception is when the target parameter has no annotation, in which case
the hook annotation is unrestricted. This is particularly useful when registering
hooks with methods, since it offers the possibility to annotate the "self"
parameter bound to the method-carrying object, which is typically not annotated
in the target callable.

Args:
target: The callable to which the hooks are to be attached.
pre_hooks: Hooks to be executed before calling the target.
post_hooks: Hooks to be executed after calling the target.

Returns:
The wrapped callable with the hooks attached.

Raises:
TypeError: If any hook has positional-only arguments.
TypeError: If any hook expects parameters that are not present in the target.
TypeError: If any hook has a non-empty parameter annotation that does not
match with the corresponding annotation of the target.
"""
# Defaults
pre_hooks = pre_hooks or []
post_hooks = post_hooks or []

target_signature = inspect.signature(target, eval_str=True).parameters

# Validate hook signatures
for hook in [*pre_hooks, *post_hooks]:
hook_signature = inspect.signature(hook, eval_str=True).parameters

if any(
p.kind is inspect.Parameter.POSITIONAL_ONLY for p in hook_signature.values()
):
raise TypeError("The provided hooks cannot have position-only arguments.")
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved

if unrecognized := (set(hook_signature) - set(target_signature)):
raise TypeError(
f"The parameters expected by the hook '{hook.__name__}' must be a "
f"subset of the parameter of the target callable '{target.__name__}'. "
f"Unrecognized hook parameters: {unrecognized}."
)

for name, hook_param in hook_signature.items():
target_param = target_signature[name]

# If target parameter is not annotated, the hook annotation is unrestricted
if (t_hint := target_param.annotation) is inspect.Parameter.empty:
continue

# If target parameter is annotated, the hook annotation must be compatible,
# i.e., be identical or empty
if ((h_hint := hook_param.annotation) != t_hint) and (
h_hint is not inspect.Parameter.empty
):
raise TypeError(
f"The type annotation for '{name}' is not consistent between "
f"the given hook '{hook.__name__}' and the target callable "
f"'{target.__name__}'. Given: {h_hint}. Expected: {t_hint}."
)

def pass_args(hook: Callable, *args, **kwargs) -> None:
"""Call the hook with its requested subset of arguments."""
hook_signature = inspect.signature(hook, eval_str=True).parameters
matched_args = dict(zip(target_signature, args))
matched_kwargs = {
p: kwargs.get(p, target_signature[p].default)
for p in hook_signature
if p not in matched_args
}
passed_kwargs = {p: (matched_args | matched_kwargs)[p] for p in hook_signature}
hook(**passed_kwargs)

@functools.wraps(target)
def wraps(*args, **kwargs):
Scienfitz marked this conversation as resolved.
Show resolved Hide resolved
for hook in pre_hooks:
pass_args(hook, *args, **kwargs)
result = target(*args, **kwargs)
for hook in post_hooks:
pass_args(hook, *args, **kwargs)
return result

return wraps
4 changes: 4 additions & 0 deletions examples/Custom_Hooks/Custom_Hooks_Header.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Custom Hooks

These examples demonstrate how to register custom hooks using the
{func}`register_hooks <baybe.utils.basic.register_hooks>` utility.
120 changes: 120 additions & 0 deletions examples/Custom_Hooks/basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
## Registering Custom Hooks

# This example demonstrates the basic mechanics of the
# {func}`register_hooks <baybe.utils.basic.register_hooks>` utility,
# which lets you hook into any callable of your choice:
# * We define a hook that is compatible with the general
# {meth}`RecommenderProtocol.recommend <baybe.recommenders.base.RecommenderProtocol.recommend>`
# interface,
# * attach it to a recommender,
# * and watch it take action.


### Imports


from dataclasses import dataclass
from time import perf_counter
from types import MethodType

from baybe.parameters import NumericalDiscreteParameter
from baybe.recommenders import RandomRecommender
from baybe.searchspace import SearchSpace
from baybe.utils.basic import register_hooks

### Defining the Hooks

# We start by defining a simple hook that lets us inspect the names of the parameters
# involved in the recommendation process.
# For this purpose, we match its signature to that of
# {meth}`RecommenderProtocol.recommend <baybe.recommenders.base.RecommenderProtocol.recommend>`:
#

# ```{admonition} Signature components
# :class: note
# Note that you are flexible in designing the signature of your hooks.
# For instance, function parameters and type annotations that you do not need in the
# hook body can simply be omitted.
# The exact rules to follow are described {func}`here <baybe.utils.basic.register_hooks>`.
# ```


def print_parameter_names_hook(self: RandomRecommender, searchspace: SearchSpace):
"""Print the names of the parameters spanning the search space."""
print(f"Recommender type: {self.__class__.__name__}")
print(f"Search space parameters: {[p.name for p in searchspace.parameters]}")


# Additionally, we set up a class that provides a combination of hooks for measuring
# the time needed to compute the recommendations:


@dataclass
class ElapsedTimePrinter:
"""Helper class for measuring the time between two calls."""

last_call_time: float | None = None

def start(printer_instance):
"""Start the timer."""
printer_instance.last_call_time = perf_counter()

def measure(printer_instance, self: RandomRecommender):
"""Measure the elapsed time."""
if printer_instance.last_call_time is None:
raise RuntimeError("Must call `start` first!")
elapsed = perf_counter() - printer_instance.last_call_time
print(f"Consumed time of {self.__class__.__name__}: {elapsed}")


# ```{admonition} Hook instance vs. target instance
# :class: important
# Notice the difference between the object belonging to the hook-providing class
# (named `printer_instance`) and the object whose method we intend to override
# (named `self`). This distinction is necessary because of
# {ref}`the particular way <BOUND_METHODS>` we attach the hook below, which binds `self`
# to the object carrying the target callable as a method.
# ```


### Monkeypatching

# Next, we create our recommender and monkeypatch its `recommend` method:

timer = ElapsedTimePrinter()
recommender = RandomRecommender()
recommender.recommend = MethodType(
register_hooks(
RandomRecommender.recommend,
pre_hooks=[print_parameter_names_hook, timer.start],
post_hooks=[timer.measure],
),
recommender,
)

# (BOUND_METHODS)=
# ```{admonition} Bound methods
# :class: important
# Note that the explicit binding via `MethodType` above is required because we
# decorate the (unbound) `RandomRecommender.recommend` **function** with our hooks
# and attach it as an overridden **method** to the recommender instance.
#
# Alternatively, we could have ...
# * ... overridden the class callable itself via
# `RandomRecommender.recommend = register_hooks(RandomRecommender.recommend, ...)`
# which, however, would affect all instances of `RandomRecommender` or
# * ... used the bound method of the instance as reference via
# `recommender.recommend = register_hooks(recommender.recommend, ...)` but then
# the hooks would not have access to the recommender instance as it is not
# explicitly exposed in the method's signature.
# ```

### Triggering the Hooks

# When we now apply the recommender in a specific context, we immediately see the
# effect of the hooks:

temperature = NumericalDiscreteParameter("Temperature", values=[90, 105, 120])
concentration = NumericalDiscreteParameter("Concentration", values=[0.057, 0.1, 0.153])
searchspace = SearchSpace.from_product([temperature, concentration])
recommendation = recommender.recommend(batch_size=3, searchspace=searchspace)
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
74 changes: 74 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
"""Tests for utilities."""

from contextlib import nullcontext

import numpy as np
import pytest
from pytest import param

from baybe.utils.basic import register_hooks
from baybe.utils.memory import bytes_to_human_readable
from baybe.utils.numerical import closest_element

_TARGET = 1337
_CLOSEST = _TARGET + 0.1


def f_plain(arg1, arg2):
pass


def f_reduced_plain(arg1):
pass


def f_annotated(arg1: str, arg2: int):
pass


def f_annotated_one_default(arg1: str, arg2: int = 1):
pass


def f_reversed_annotated(arg2: int, arg1: str):
pass


def f2_plain(arg, arg3):
pass


@pytest.mark.parametrize(
"as_ndarray", [param(False, id="list"), param(True, id="array")]
)
Expand All @@ -35,3 +62,50 @@ def test_memory_human_readable_conversion():
assert bytes_to_human_readable(1024) == (1.0, "KB")
assert bytes_to_human_readable(1024**2) == (1.0, "MB")
assert bytes_to_human_readable(4.3 * 1024**4) == (4.3, "TB")


@pytest.mark.parametrize(
("target, hook, error"),
[
param(
f_annotated,
f_annotated_one_default,
None,
id="hook_with_defaults",
),
param(
f_annotated_one_default,
f_annotated,
None,
id="target_with_defaults",
),
param(
f_annotated,
f_plain,
None,
id="hook_without_annotations",
),
param(
f_annotated,
f_reversed_annotated,
TypeError,
id="different_order",
),
param(
f_annotated,
f2_plain,
TypeError,
id="different_names",
),
param(
f_annotated,
f_reduced_plain,
TypeError,
id="hook_missing_arguments",
),
],
)
def test_register_hook(target, hook, error):
"""Passing in-/consistent signatures to `register_hook` raises an/no error."""
with pytest.raises(error) if error is not None else nullcontext():
AdrianSosic marked this conversation as resolved.
Show resolved Hide resolved
register_hooks(target, [hook])
RimRihana marked this conversation as resolved.
Show resolved Hide resolved
Loading