Skip to content

Commit

Permalink
Make typing_extensions optional
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Sep 12, 2022
1 parent 79f0329 commit 396d762
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 34 deletions.
14 changes: 8 additions & 6 deletions optuna/_convert_positional_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,39 @@
from typing import Any
from typing import Callable
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar
import warnings

from typing_extensions import ParamSpec

if TYPE_CHECKING:
from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")
_P = ParamSpec("_P")
_T = TypeVar("_T")


def convert_positional_args(
*,
previous_positional_arg_names: Sequence[str],
warning_stacklevel: int = 2,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
) -> "Callable[[Callable[_P, _T]], Callable[_P, _T]]":
"""Convert positional arguments to keyword arguments.
Args:
previous_positional_arg_names: List of names previously given as positional arguments.
warning_stacklevel: Level of the stack trace where decorated function locates.
"""

def converter_decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
def converter_decorator(func: "Callable[_P, _T]") -> "Callable[_P, _T]":

assert set(previous_positional_arg_names).issubset(set(signature(func).parameters)), (
f"{set(previous_positional_arg_names)} is not a subset of"
f" {set(signature(func).parameters)}"
)

@wraps(func)
def converter_wrapper(*args: Any, **kwargs: Any) -> _T:
def converter_wrapper(*args: Any, **kwargs: Any) -> "_T":

if len(args) >= 1:
warnings.warn(
Expand Down
23 changes: 13 additions & 10 deletions optuna/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
from typing import Any
from typing import Callable
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
import warnings

from packaging import version
from typing_extensions import ParamSpec

from optuna._experimental import _get_docstring_indent
from optuna._experimental import _validate_version


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")
if TYPE_CHECKING:
from typing_extensions import ParamSpec

FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")


_DEPRECATION_NOTE_TEMPLATE = """
Expand Down Expand Up @@ -52,7 +55,7 @@ def deprecated_func(
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
) -> "Callable[[Callable[FP, FT]], Callable[FP, FT]]":
"""Decorate function as deprecated.
Args:
Expand Down Expand Up @@ -81,7 +84,7 @@ def deprecated_func(
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
def decorator(func: "Callable[FP, FT]") -> "Callable[FP, FT]":
if func.__doc__ is None:
func.__doc__ = ""

Expand All @@ -92,7 +95,7 @@ def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
def wrapper(*args: Any, **kwargs: Any) -> "FT":
"""Decorates a function as deprecated.
This decorator is supposed to be applied to the deprecated function.
Expand All @@ -119,7 +122,7 @@ def deprecated_class(
removed_version: str,
name: Optional[str] = None,
text: Optional[str] = None,
) -> Callable[[CT], CT]:
) -> "Callable[[CT], CT]":
"""Decorate class as deprecated.
Args:
Expand Down Expand Up @@ -148,8 +151,8 @@ def deprecated_class(
_validate_version(removed_version)
_validate_two_version(deprecated_version, removed_version)

def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
def decorator(cls: "CT") -> "CT":
def wrapper(cls: "CT") -> "CT":
"""Decorates a class as deprecated.
This decorator is supposed to be applied to the deprecated class.
Expand Down
25 changes: 14 additions & 11 deletions optuna/_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from typing import Any
from typing import Callable
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
import warnings

from typing_extensions import ParamSpec

from optuna.exceptions import ExperimentalWarning


FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")
if TYPE_CHECKING:
from typing_extensions import ParamSpec

FT = TypeVar("FT")
FP = ParamSpec("FP")
CT = TypeVar("CT")


_EXPERIMENTAL_NOTE_TEMPLATE = """
Expand All @@ -40,7 +43,7 @@ def _get_docstring_indent(docstring: str) -> str:
def experimental_func(
version: str,
name: Optional[str] = None,
) -> Callable[[Callable[FP, FT]], Callable[FP, FT]]:
) -> "Callable[[Callable[FP, FT]], Callable[FP, FT]]":
"""Decorate function as experimental.
Args:
Expand All @@ -50,7 +53,7 @@ def experimental_func(

_validate_version(version)

def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
def decorator(func: "Callable[FP, FT]") -> "Callable[FP, FT]":
if func.__doc__ is None:
func.__doc__ = ""

Expand All @@ -59,7 +62,7 @@ def decorator(func: Callable[FP, FT]) -> Callable[FP, FT]:
func.__doc__ = func.__doc__.strip() + textwrap.indent(note, indent) + indent

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> FT:
def wrapper(*args: Any, **kwargs: Any) -> "FT":
warnings.warn(
"{} is experimental (supported from v{}). "
"The interface can change in the future.".format(
Expand All @@ -79,7 +82,7 @@ def wrapper(*args: Any, **kwargs: Any) -> FT:
def experimental_class(
version: str,
name: Optional[str] = None,
) -> Callable[[CT], CT]:
) -> "Callable[[CT], CT]":
"""Decorate class as experimental.
Args:
Expand All @@ -89,8 +92,8 @@ def experimental_class(

_validate_version(version)

def decorator(cls: CT) -> CT:
def wrapper(cls: CT) -> CT:
def decorator(cls: "CT") -> "CT":
def wrapper(cls: "CT") -> "CT":
"""Decorates a class as experimental.
This decorator is supposed to be applied to the experimental class.
Expand Down
14 changes: 8 additions & 6 deletions optuna/integration/pytorch_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar

from typing_extensions import ParamSpec

import optuna
from optuna._deprecated import deprecated_func
from optuna._experimental import experimental_class
Expand All @@ -23,16 +22,19 @@
import torch.distributed as dist


_T = TypeVar("_T")
_P = ParamSpec("_P")
if TYPE_CHECKING:
from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")


_suggest_deprecated_msg = (
"Use :func:`~optuna.integration.TorchDistributedTrial.suggest_float` instead."
)


def broadcast_properties(f: Callable[_P, _T]) -> Callable[_P, _T]:
def broadcast_properties(f: "Callable[_P, _T]") -> "Callable[_P, _T]":
"""Method decorator to fetch updated trial properties from rank 0 after ``f`` is run.
This decorator ensures trial properties (params, distributions, etc.) on all distributed
Expand All @@ -42,7 +44,7 @@ def broadcast_properties(f: Callable[_P, _T]) -> Callable[_P, _T]:
"""

@functools.wraps(f)
def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_T":
# TODO(nlgranger): Remove type ignore after mypy includes
# https://github.com/python/mypy/pull/12668
self: TorchDistributedTrial = args[0] # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def get_install_requires() -> List[str]:
"scipy!=1.4.0" if sys.version[:3] == "3.6" else "scipy>=1.7.0",
"sqlalchemy>=1.1.0",
"tqdm",
"typing_extensions>=3.10.0.0",
"PyYAML", # Only used in `optuna/cli.py`.
]
return requirements
Expand All @@ -66,6 +65,7 @@ def get_extras_require() -> Dict[str, List[str]]:
"types-PyYAML",
"types-redis",
"types-setuptools",
"typing_extensions>=3.10.0.0",
],
"document": [
"cma",
Expand Down

0 comments on commit 396d762

Please sign in to comment.