Skip to content

Commit

Permalink
Merge pull request #5441 from eukaryo/typing-brute-force
Browse files Browse the repository at this point in the history
Simplify annotations in `_brute_force.py`
  • Loading branch information
not522 committed May 15, 2024
2 parents ab958d4 + fa6b7d7 commit 2b6a36f
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions optuna/samplers/_brute_force.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sequence
from dataclasses import dataclass
import decimal
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -38,10 +34,10 @@ class _TreeNode:
# 2. Leaf. This is represented by children={} and param_name=None.
# 3. Normal node. It has a param_name and non-empty children.

param_name: Optional[str] = None
children: Optional[Dict[float, "_TreeNode"]] = None
param_name: str | None = None
children: dict[float, "_TreeNode"] | None = None

def expand(self, param_name: Optional[str], search_space: Iterable[float]) -> None:
def expand(self, param_name: str | None, search_space: Iterable[float]) -> None:
# If the node is unexpanded, expand it.
# Otherwise, check if the node is compatible with the given search space.
if self.children is None:
Expand All @@ -60,8 +56,8 @@ def set_leaf(self) -> None:
self.expand(None, [])

def add_path(
self, params_and_search_spaces: Iterable[Tuple[str, Iterable[float], float]]
) -> Optional["_TreeNode"]:
self, params_and_search_spaces: Iterable[tuple[str, Iterable[float], float]]
) -> "_TreeNode" | None:
# Add a path (i.e. a list of suggested parameters in one trial) to the tree.
current_node = self
for param_name, search_space, value in params_and_search_spaces:
Expand Down Expand Up @@ -136,25 +132,25 @@ def objective(trial):
suggestions during distributed optimization.
"""

def __init__(self, seed: Optional[int] = None) -> None:
def __init__(self, seed: int | None = None) -> None:
self._rng = LazyRandomState(seed)

def infer_relative_search_space(
self, study: Study, trial: FrozenTrial
) -> Dict[str, BaseDistribution]:
) -> dict[str, BaseDistribution]:
return {}

def sample_relative(
self, study: Study, trial: FrozenTrial, search_space: Dict[str, BaseDistribution]
) -> Dict[str, Any]:
self, study: Study, trial: FrozenTrial, search_space: dict[str, BaseDistribution]
) -> dict[str, Any]:
return {}

@staticmethod
def _populate_tree(
tree: _TreeNode, trials: Iterable[FrozenTrial], params: Dict[str, Any]
tree: _TreeNode, trials: Iterable[FrozenTrial], params: dict[str, Any]
) -> None:
# Populate tree under given params from the given trials.
incomplete_leaves: List[_TreeNode] = []
incomplete_leaves: list[_TreeNode] = []
for trial in trials:
if not all(p in trial.params and trial.params[p] == v for p, v in params.items()):
continue
Expand Down Expand Up @@ -214,7 +210,7 @@ def after_trial(
study: Study,
trial: FrozenTrial,
state: TrialState,
values: Optional[Sequence[float]],
values: Sequence[float] | None,
) -> None:
trials = study.get_trials(
deepcopy=False,
Expand Down

0 comments on commit 2b6a36f

Please sign in to comment.