diff --git a/plum/promotion.py b/plum/promotion.py index f2208c0..a7ad58f 100644 --- a/plum/promotion.py +++ b/plum/promotion.py @@ -1,3 +1,15 @@ +"""Promotion and conversion functions.""" + +__all__ = [ + "convert", + "add_conversion_method", + "conversion_method", + "add_promotion_rule", + "promote", +] + +from typing import Callable, Protocol, Type, TypeVar + from beartype.door import TypeHint import plum.function @@ -7,13 +19,8 @@ from .repr import repr_short from .type import resolve_type_hint -__all__ = [ - "convert", - "add_conversion_method", - "conversion_method", - "add_promotion_rule", - "promote", -] +T = TypeVar("T") +R = TypeVar("R") _dispatch = Dispatcher() @@ -40,13 +47,19 @@ def convert(obj, type_to): @_dispatch def _convert(obj, type_to): - if _is_bearable(obj, resolve_type_hint(type_to)): - return obj - else: + if not _is_bearable(obj, resolve_type_hint(type_to)): raise TypeError(f"Cannot convert `{obj}` to `{repr_short(type_to)}`.") + return obj + + +class _ConversionCallable(Protocol[T, R]): + def __call__(self, obj: T) -> R: + ... -def add_conversion_method(type_from, type_to, f): +def add_conversion_method( + type_from: Type[T], type_to: Type[R], f: _ConversionCallable +) -> None: """Add a conversion method to convert an object from one type to another. Args: @@ -61,7 +74,9 @@ def perform_conversion(obj: type_from, _: type_to): return f(obj) -def conversion_method(type_from, type_to): +def conversion_method( + type_from: Type[T], type_to: Type[R] +) -> Callable[[_ConversionCallable[T, R]], _ConversionCallable[T, R]]: """Decorator to add a conversion method to convert an object from one type to another. @@ -70,7 +85,7 @@ def conversion_method(type_from, type_to): type_to (type): Type to convert to. """ - def add_method(f): + def add_method(f: _ConversionCallable[T, R]) -> _ConversionCallable[T, R]: add_conversion_method(type_from, type_to, f) return f