Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typehint for deprecated and experimental #3575

Merged
merged 20 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 5 additions & 4 deletions optuna/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Optional

import optuna
from optuna._experimental import experimental
from optuna._experimental import experimental_class
from optuna._experimental import experimental_func
from optuna.trial import FrozenTrial
from optuna.trial import TrialState

Expand Down Expand Up @@ -59,7 +60,7 @@ def __call__(self, study: "optuna.study.Study", trial: FrozenTrial) -> None:
study.stop()


@experimental("2.8.0")
@experimental_class("2.8.0")
class RetryFailedTrialCallback:
"""Retry a failed trial up to a maximum number of times.

Expand Down Expand Up @@ -134,7 +135,7 @@ def __call__(self, study: "optuna.study.Study", trial: FrozenTrial) -> None:
)

@staticmethod
@experimental("2.8.0")
@experimental_func("2.8.0")
def retried_trial_number(trial: FrozenTrial) -> Optional[int]:
"""Return the number of the original trial being retried.

Expand All @@ -150,7 +151,7 @@ def retried_trial_number(trial: FrozenTrial) -> Optional[int]:
return trial.system_attrs.get("failed_trial", None)

@staticmethod
@experimental("3.0.0")
@experimental_func("3.0.0")
def retry_history(trial: FrozenTrial) -> List[int]:
"""Return the list of retried trial numbers with respect to the specified trial.

Expand Down
106 changes: 73 additions & 33 deletions optuna/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import functools
import inspect
import textwrap
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar
import warnings

from packaging import version
from typing_extensions import ParamSpec
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


from optuna._experimental import _get_docstring_indent
from optuna._experimental import _validate_version


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")


_DEPRECATION_NOTE_TEMPLATE = """

.. warning::
Expand Down Expand Up @@ -41,13 +47,13 @@ def _format_text(text: str) -> str:
return "\n\n" + textwrap.indent(text.strip(), " ") + "\n"


def deprecated(
def deprecated_func(
deprecated_version: str,
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Any:
"""Decorate class or function as deprecated.
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
"""Decorate function as deprecated.

Args:
deprecated_version:
Expand Down Expand Up @@ -75,52 +81,86 @@ def deprecated(
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def _deprecated_wrapper(f: Any) -> Any:
# f is either func or class.
def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
if func.__doc__ is None:
func.__doc__ = ""

note = _DEPRECATION_NOTE_TEMPLATE.format(d_ver=deprecated_version, r_ver=removed_version)
if text is not None:
note += _format_text(text)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

def _deprecated_func(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
"""Decorates a function as deprecated.

This decorator is supposed to be applied to the deprecated function.
"""
if func.__doc__ is None:
func.__doc__ = ""

note = _DEPRECATION_NOTE_TEMPLATE.format(
d_ver=deprecated_version, r_ver=removed_version
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else func.__name__),
d_ver=deprecated_version,
r_ver=removed_version,
)
if text is not None:
note += _format_text(text)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent
message += " " + text
warnings.warn(message, FutureWarning, stacklevel=2)

# TODO(mamu): Annotate this correctly.
@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any) -> Any:
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else func.__name__),
d_ver=deprecated_version,
r_ver=removed_version,
)
if text is not None:
message += " " + text
warnings.warn(message, FutureWarning, stacklevel=2)
return func(*args, **kwargs)

return wrapper

return decorator


def deprecated_class(
deprecated_version: str,
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Callable[[CT], CT]:
"""Decorate class as deprecated.

Args:
deprecated_version:
The version in which the target feature is deprecated.
removed_version:
The version in which the target feature will be removed.
name:
The name of the feature. Defaults to the function or class name. Optional.
himkt marked this conversation as resolved.
Show resolved Hide resolved
text:
The additional text for the deprecation note. The default note is build using specified
``deprecated_version`` and ``removed_version``. If you want to provide additional
information, please specify this argument yourself.

.. note::
The default deprecation note is as follows: "Deprecated in v{d_ver}. This feature
will be removed in the future. The removal of this feature is currently scheduled
for v{r_ver}, but this schedule is subject to change. See
https://github.com/optuna/optuna/releases/tag/v{d_ver}."

return func(*args, **kwargs) # type: ignore
.. note::
The specified text is concatenated after the default deprecation note.
"""

return new_func
_validate_version(deprecated_version)
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def _deprecated_class(cls: Any) -> Any:
def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
"""Decorates a class as deprecated.

This decorator is supposed to be applied to the deprecated class.
"""
_original_init = cls.__init__
_original_init = getattr(cls, "__init__")
_original_name = getattr(cls, "__name__")

@functools.wraps(_original_init)
def wrapped_init(self, *args, **kwargs) -> None: # type: ignore
message = _DEPRECATION_WARNING_TEMPLATE.format(
name=(name if name is not None else cls.__name__),
name=(name if name is not None else _original_name),
d_ver=deprecated_version,
r_ver=removed_version,
)
Expand All @@ -134,7 +174,7 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

_original_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
setattr(cls, "__init__", wrapped_init)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python/mypy#2427 (comment)
(But some linter is not happy with the hack: python/mypy#2427 (comment), I'm not sure whether it is preferred or not)


if cls.__doc__ is None:
cls.__doc__ = ""
Expand All @@ -149,6 +189,6 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

return cls

return _deprecated_class(f) if inspect.isclass(f) else _deprecated_func(f)
return wrapper(cls)

return _deprecated_wrapper
return decorator
78 changes: 47 additions & 31 deletions optuna/_experimental.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import functools
import inspect
import textwrap
from typing import Any
from typing import Callable
from typing import Optional
from typing import TypeVar
import warnings

from typing_extensions import ParamSpec

from optuna.exceptions import ExperimentalWarning


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")

_EXPERIMENTAL_NOTE_TEMPLATE = """

.. note::
Expand All @@ -31,8 +37,11 @@ def _get_docstring_indent(docstring: str) -> str:
return docstring.split("\n")[-1] if "\n" in docstring else ""


def experimental(version: str, name: Optional[str] = None) -> Any:
"""Decorate class or function as experimental.
def experimental_func(
version: str,
name: Optional[str] = None,
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
"""Decorate function as experimental.

Args:
version: The first version that supports the target feature.
Expand All @@ -41,55 +50,62 @@ def experimental(version: str, name: Optional[str] = None) -> Any:

_validate_version(version)

def _experimental_wrapper(f: Any) -> Any:
# f is either func or class.
def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
if func.__doc__ is None:
func.__doc__ = ""

note = _EXPERIMENTAL_NOTE_TEMPLATE.format(ver=version)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else func.__name__, version
),
ExperimentalWarning,
stacklevel=2,
)

def _experimental_func(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
return func(*args, **kwargs)

if func.__doc__ is None:
func.__doc__ = ""
return wrapper

note = _EXPERIMENTAL_NOTE_TEMPLATE.format(ver=version)
indent = _get_docstring_indent(func.__doc__)
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent
return decorator

# TODO(crcrpar): Annotate this correctly.
@functools.wraps(func)
def new_func(*args: Any, **kwargs: Any) -> Any:
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else func.__name__, version
),
ExperimentalWarning,
stacklevel=2,
)

return func(*args, **kwargs) # type: ignore
def experimental_class(
version: str,
name: Optional[str] = None,
) -> Callable[[CT], CT]:

return new_func
_validate_version(version)

def _experimental_class(cls: Any) -> Any:
def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
"""Decorates a class as experimental.

This decorator is supposed to be applied to the experimental class.
"""
_original_init = cls.__init__
_original_init = getattr(cls, "__init__")
_original_name = getattr(cls, "__name__")

@functools.wraps(_original_init)
def wrapped_init(self, *args, **kwargs) -> None: # type: ignore
def wrapped_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
name if name is not None else cls.__name__, version
name if name is not None else _original_name, version
),
ExperimentalWarning,
stacklevel=2,
)

_original_init(self, *args, **kwargs)

cls.__init__ = wrapped_init
setattr(cls, "__init__", wrapped_init)

if cls.__doc__ is None:
cls.__doc__ = ""
Expand All @@ -100,6 +116,6 @@ def wrapped_init(self, *args, **kwargs) -> None: # type: ignore

return cls

return _experimental_class(f) if inspect.isclass(f) else _experimental_func(f)
return wrapper(cls)

return _experimental_wrapper
return decorator
12 changes: 6 additions & 6 deletions optuna/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Union
import warnings

from optuna._deprecated import deprecated
from optuna._deprecated import deprecated_class


CategoricalChoiceType = Union[None, bool, int, float, str]
Expand Down Expand Up @@ -186,7 +186,7 @@ def _contains(self, param_value_in_internal_repr: float) -> bool:
return self.low <= value <= self.high and abs(k - round(k)) < 1.0e-8


@deprecated("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
@deprecated_class("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
class UniformDistribution(FloatDistribution):
"""A uniform distribution in the linear domain.

Expand All @@ -213,7 +213,7 @@ def _asdict(self) -> Dict:
return d


@deprecated("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
@deprecated_class("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
class LogUniformDistribution(FloatDistribution):
"""A uniform distribution in the log domain.

Expand All @@ -240,7 +240,7 @@ def _asdict(self) -> Dict:
return d


@deprecated("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
@deprecated_class("3.0.0", "6.0.0", text=_float_distribution_deprecated_msg)
class DiscreteUniformDistribution(FloatDistribution):
"""A discretized uniform distribution in the linear domain.

Expand Down Expand Up @@ -378,7 +378,7 @@ def _contains(self, param_value_in_internal_repr: float) -> bool:
return self.low <= value <= self.high and (value - self.low) % self.step == 0


@deprecated("3.0.0", "6.0.0", text=_int_distribution_deprecated_msg)
@deprecated_class("3.0.0", "6.0.0", text=_int_distribution_deprecated_msg)
class IntUniformDistribution(IntDistribution):
"""A uniform distribution on integers.

Expand Down Expand Up @@ -412,7 +412,7 @@ def _asdict(self) -> Dict:
return d


@deprecated("3.0.0", "6.0.0", text=_int_distribution_deprecated_msg)
@deprecated_class("3.0.0", "6.0.0", text=_int_distribution_deprecated_msg)
class IntLogUniformDistribution(IntDistribution):
"""A uniform distribution on integers in the log domain.

Expand Down