Skip to content

Commit

Permalink
Merge pull request #3575 from himkt/decorator-typehint
Browse files Browse the repository at this point in the history
Add typehint for deprecated and experimental
  • Loading branch information
HideakiImamura committed Jun 7, 2022
2 parents 4fd8a9e + eb87f3d commit 07a8193
Show file tree
Hide file tree
Showing 54 changed files with 293 additions and 221 deletions.
9 changes: 5 additions & 4 deletions optuna/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Optional

import optuna
from optuna._experimental import experimental
from optuna._experimental import experimental_class
from optuna._experimental import experimental_func
from optuna.trial import FrozenTrial
from optuna.trial import TrialState

Expand Down Expand Up @@ -59,7 +60,7 @@ def __call__(self, study: "optuna.study.Study", trial: FrozenTrial) -> None:
study.stop()


@experimental("2.8.0")
@experimental_class("2.8.0")
class RetryFailedTrialCallback:
"""Retry a failed trial up to a maximum number of times.
Expand Down Expand Up @@ -134,7 +135,7 @@ def __call__(self, study: "optuna.study.Study", trial: FrozenTrial) -> None:
)

@staticmethod
@experimental("2.8.0")
@experimental_func("2.8.0")
def retried_trial_number(trial: FrozenTrial) -> Optional[int]:
"""Return the number of the original trial being retried.
Expand All @@ -150,7 +151,7 @@ def retried_trial_number(trial: FrozenTrial) -> Optional[int]:
return trial.system_attrs.get("failed_trial", None)

@staticmethod
@experimental("3.0.0")
@experimental_func("3.0.0")
def retry_history(trial: FrozenTrial) -> List[int]:
"""Return the list of retried trial numbers with respect to the specified trial.
Expand Down
108 changes: 74 additions & 34 deletions optuna/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import functools
import inspect
import textwrap
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar
import warnings

from packaging import version
from typing_extensions import ParamSpec

from optuna._experimental import _get_docstring_indent
from optuna._experimental import _validate_version


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")


_DEPRECATION_NOTE_TEMPLATE = """
.. warning::
Expand Down Expand Up @@ -41,21 +47,21 @@ def _format_text(text: str) -> str:
return "\n\n" + textwrap.indent(text.strip(), " ") + "\n"


def deprecated(
def deprecated_func(
deprecated_version: str,
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Any:
"""Decorate class or function as deprecated.
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
"""Decorate function as deprecated.
Args:
deprecated_version:
The version in which the target feature is deprecated.
removed_version:
The version in which the target feature will be removed.
name:
The name of the feature. Defaults to the function or class name. Optional.
The name of the feature. Defaults to the function name. Optional.
text:
The additional text for the deprecation note. The default note is build using specified
``deprecated_version`` and ``removed_version``. If you want to provide additional
Expand All @@ -75,52 +81,86 @@ def deprecated(
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def _deprecated_wrapper(f: Any) -> Any:
# f is either func or class.
def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
if func.__doc__ is None:
func.__doc__ = ""

note = _DEPRECATION_NOTE_TEMPLATE.format(d_ver=deprecated_version, r_ver=removed_version)
if text is not None:
note += _format_text(text)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

def _deprecated_func(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
"""Decorates a function as deprecated.
This decorator is supposed to be applied to the deprecated function.
"""
if func.__doc__ is None:
func.__doc__ = ""

note = _DEPRECATION_NOTE_TEMPLATE.format(
d_ver=deprecated_version, r_ver=removed_version
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else func.__name__),
d_ver=deprecated_version,
r_ver=removed_version,
)
if text is not None:
note += _format_text(text)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent
message += " " + text
warnings.warn(message, FutureWarning, stacklevel=2)

# TODO(mamu): Annotate this correctly.
@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any) -> Any:
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else func.__name__),
d_ver=deprecated_version,
r_ver=removed_version,
)
if text is not None:
message += " " + text
warnings.warn(message, FutureWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator

return func(*args, **kwargs) # type: ignore

return new_func
def deprecated_class(
deprecated_version: str,
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Callable[[CT], CT]:
"""Decorate class as deprecated.
Args:
deprecated_version:
The version in which the target feature is deprecated.
removed_version:
The version in which the target feature will be removed.
name:
The name of the feature. Defaults to the class name. Optional.
text:
The additional text for the deprecation note. The default note is build using specified
``deprecated_version`` and ``removed_version``. If you want to provide additional
information, please specify this argument yourself.
.. note::
The default deprecation note is as follows: "Deprecated in v{d_ver}. This feature
will be removed in the future. The removal of this feature is currently scheduled
for v{r_ver}, but this schedule is subject to change. See
https://github.com/optuna/optuna/releases/tag/v{d_ver}."
.. note::
The specified text is concatenated after the default deprecation note.
"""

_validate_version(deprecated_version)
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def _deprecated_class(cls: Any) -> Any:
def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
"""Decorates a class as deprecated.
This decorator is supposed to be applied to the deprecated class.
"""
_original_init = cls.__init__
_original_init = getattr(cls, "__init__")
_original_name = getattr(cls, "__name__")

@functools.wraps(_original_init)
def wrapped_init(self, *args, **kwargs) -> None: # type: ignore
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else cls.__name__),
name=(name if name is not None else _original_name),
d_ver=deprecated_version,
r_ver=removed_version,
)
Expand All @@ -134,7 +174,7 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

_original_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
setattr(cls, "__init__", wrapped_init)

if cls.__doc__ is None:
cls.__doc__ = ""
Expand All @@ -149,6 +189,6 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

return cls

return _deprecated_class(f) if inspect.isclass(f) else _deprecated_func(f)
return wrapper(cls)

return _deprecated_wrapper
return decorator
86 changes: 54 additions & 32 deletions optuna/_experimental.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import functools
import inspect
import textwrap
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar
import warnings

from typing_extensions import ParamSpec

from optuna.exceptions import ExperimentalWarning


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")

_EXPERIMENTAL_NOTE_TEMPLATE = """
.. note::
Expand All @@ -31,65 +37,81 @@ def _get_docstring_indent(docstring: str) -> str:
return docstring.split("\n")[-1] if "\n" in docstring else ""


def experimental(version: str, name: Optional[str] = None) -> Any:
"""Decorate class or function as experimental.
def experimental_func(
version: str,
name: Optional[str] = None,
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
"""Decorate function as experimental.
Args:
version: The first version that supports the target feature.
name: The name of the feature. Defaults to the function or class name. Optional.
name: The name of the feature. Defaults to the function name. Optional.
"""

_validate_version(version)

def _experimental_wrapper(f: Any) -> Any:
# f is either func or class.
def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
if func.__doc__ is None:
func.__doc__ = ""

note = _EXPERIMENTAL_NOTE_TEMPLATE.format(ver=version)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else func.__name__, version
),
ExperimentalWarning,
stacklevel=2,
)

def _experimental_func(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
return func(*args, **kwargs)

if func.__doc__ is None:
func.__doc__ = ""
return wrapper

note = _EXPERIMENTAL_NOTE_TEMPLATE.format(ver=version)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent
return decorator

# TODO(crcrpar): Annotate this correctly.
@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any) -> Any:
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else func.__name__, version
),
ExperimentalWarning,
stacklevel=2,
)

return func(*args, **kwargs) # type: ignore
def experimental_class(
version: str,
name: Optional[str] = None,
) -> Callable[[CT], CT]:
"""Decorate class as experimental.
Args:
version: The first version that supports the target feature.
name: The name of the feature. Defaults to the class name. Optional.
"""

return new_func
_validate_version(version)

def _experimental_class(cls: Any) -> Any:
def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
"""Decorates a class as experimental.
This decorator is supposed to be applied to the experimental class.
"""
_original_init = cls.__init__
_original_init = getattr(cls, "__init__")
_original_name = getattr(cls, "__name__")

@functools.wraps(_original_init)
def wrapped_init(self, *args, **kwargs) -> None: # type: ignore
def wrapped_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else cls.__name__, version
name if name is not None else _original_name, version
),
ExperimentalWarning,
stacklevel=2,
)

_original_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
setattr(cls, "__init__", wrapped_init)

if cls.__doc__ is None:
cls.__doc__ = ""
Expand All @@ -100,6 +122,6 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

return cls

return _experimental_class(f) if inspect.isclass(f) else _experimental_func(f)
return wrapper(cls)

return _experimental_wrapper
return decorator

0 comments on commit 07a8193

Please sign in to comment.