Skip to content

Commit

Permalink
Type puzzling
Browse files Browse the repository at this point in the history
  • Loading branch information
himkt committed May 20, 2022
1 parent 03c81e4 commit f98a6de
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
18 changes: 11 additions & 7 deletions optuna/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Callable
from typing import Optional
from typing import overload
from typing import Type
from typing import TypeVar
from typing import Union
import warnings
Expand All @@ -17,10 +16,12 @@
from optuna._experimental import _validate_version


T = TypeVar("T")
FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")


_DEPRECATION_NOTE_TEMPLATE = """
.. warning::
Expand Down Expand Up @@ -55,7 +56,7 @@ def deprecated(
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Union[Callable[[Type[CT]], Type[CT]], Callable[FP, FT]]:
) -> Union[Callable[[Callable[FP, FT]], Callable[FP, FT]], Callable[[CT], CT]]:
"""Decorate class or function as deprecated.
Args:
Expand Down Expand Up @@ -85,14 +86,14 @@ def deprecated(
_validate_two_version(deprecated_version, removed_version)

@overload
def _deprecated_wrapper(f: Callable[FP, FT]) -> Callable[FP, FT]:
def _deprecated_wrapper(f: Callable[FP, FT]) -> T:
...

@overload
def _deprecated_wrapper(f: Type[CT]) -> Type[CT]:
def _deprecated_wrapper(f: CT) -> T:
...

def _deprecated_wrapper(f: Any) -> Any:
def _deprecated_wrapper(f: Union[Callable[FP, FT], CT]) -> Union[Callable[FP, FT], CT]:
# f is either func or class.

def _deprecated_func(func: Callable[FP, FT]) -> Callable[FP, FT]:
Expand Down Expand Up @@ -127,7 +128,7 @@ def new_func(*args: Any, **kwargs: Any) -> Any:

return new_func

def _deprecated_class(cls: Type[CT]) -> Type[CT]:
def _deprecated_class(cls: CT) -> CT:
"""Decorates a class as deprecated.
This decorator is supposed to be applied to the deprecated class.
Expand Down Expand Up @@ -166,6 +167,9 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

return cls

return _deprecated_class(f) if inspect.isclass(f) else _deprecated_func(f)
if inspect.isclass(f):
return _deprecated_class(f)
else:
return _deprecated_func(f)

return _deprecated_wrapper
5 changes: 2 additions & 3 deletions optuna/_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Callable
from typing import Optional
from typing import overload
from typing import Type
from typing import TypeVar
from typing import Union
import warnings
Expand Down Expand Up @@ -44,7 +43,7 @@ def _get_docstring_indent(docstring: str) -> str:
def experimental(
version: str,
name: Optional[str] = None,
) -> Union[Callable[[Type[CT]], Type[CT]], Callable[FP, FT]]:
) -> Union[Callable[[CT], CT], Callable[FP, FT]]:
"""Decorate class or function as experimental.
Args:
Expand Down Expand Up @@ -90,7 +89,7 @@ def new_func(*args: Any, **kwargs: Any) -> Any:

return new_func

def _experimental_class(cls: Type[CT]) -> Type[CT]:
def _experimental_class(cls: CT) -> CT:
"""Decorates a class as experimental.
This decorator is supposed to be applied to the experimental class.
Expand Down

0 comments on commit f98a6de

Please sign in to comment.