Skip to content

Commit

Permalink
Merge pull request #3841 from not522/refactor-tell-with-warning
Browse files Browse the repository at this point in the history
Refactor `_tell.py`
  • Loading branch information
knshnb committed Sep 8, 2022
2 parents 312f9fd + a1a3442 commit 79f0329
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 118 deletions.
6 changes: 5 additions & 1 deletion optuna/study/_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@ def _run_trial(
# `_tell_with_warning` may raise during trial post-processing.
try:
frozen_trial = _tell_with_warning(
study=study, trial=trial, values=value_or_values, state=state, suppress_warning=True
study=study,
trial=trial,
value_or_values=value_or_values,
state=state,
suppress_warning=True,
)
except Exception:
frozen_trial = study._storage.get_trial(trial._trial_id)
Expand Down
221 changes: 107 additions & 114 deletions optuna/study/_tell.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import copy
import math
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import warnings

Expand All @@ -22,100 +20,8 @@
_logger = logging.get_logger(__name__)


def _check_and_convert_to_values(
n_objectives: int, original_value: Optional[Union[float, Sequence[float]]], trial_number: int
) -> Tuple[Optional[List[float]], Optional[str]]:
if isinstance(original_value, Sequence):
_original_values: Sequence[Optional[float]] = list(original_value)
else:
_original_values = [original_value]

_checked_values = []
for v in _original_values:
checked_v, failure_message = _check_single_value(v, trial_number)
if failure_message is not None:
# TODO(Imamura): Construct error message taking into account all values and do not
# early return
# `value` is assumed to be ignored on failure so we can set it to any value.
return None, failure_message
elif isinstance(checked_v, float):
_checked_values.append(checked_v)
else:
assert False

if n_objectives != len(_original_values):
return (
None,
(
f"The number of the values "
f"{len(_checked_values)} did not match the number of the objectives "
f"{n_objectives}."
),
)

return _checked_values, None


def _check_single_value(
original_value: Optional[float], trial_number: int
) -> Tuple[Optional[float], Optional[str]]:
value = None
failure_message = None

try:
value = float(original_value) # type: ignore
except (
ValueError,
TypeError,
):
failure_message = f"The value {repr(original_value)} could not be cast to float."

if value is not None and math.isnan(value):
value = None
failure_message = f"The value {original_value} is not acceptable."

return value, failure_message


def _tell_with_warning(
study: "optuna.Study",
trial: Union[trial_module.Trial, int],
values: Optional[Union[float, Sequence[float]]] = None,
state: Optional[TrialState] = None,
skip_if_finished: bool = False,
suppress_warning: bool = False,
) -> FrozenTrial:
"""Internal method of :func:`~optuna.study.Study.tell`.
Refer to the document for :func:`~optuna.study.Study.tell` for the reference.
This method has one additional parameter ``suppress_warning``.
Args:
suppress_warning:
If :obj:`True`, tell will not show warnings when tell receives an invalid
values. This flag is expected to be :obj:`True` only when it is invoked by
Study.optimize.
"""

if not isinstance(trial, (trial_module.Trial, int)):
raise TypeError("Trial must be a trial object or trial number.")

if state == TrialState.COMPLETE:
if values is None:
raise ValueError(
"No values were told. Values are required when state is TrialState.COMPLETE."
)
elif state in (TrialState.PRUNED, TrialState.FAIL):
if values is not None:
raise ValueError(
"Values were told. Values cannot be specified when state is "
"TrialState.PRUNED or TrialState.FAIL."
)
elif state is not None:
raise ValueError(f"Cannot tell with state {state}.")

def _get_frozen_trial(study: "optuna.Study", trial: Union[trial_module.Trial, int]) -> FrozenTrial:
if isinstance(trial, trial_module.Trial):
trial_number = trial.number
trial_id = trial._trial_id
elif isinstance(trial, int):
trial_number = trial
Expand Down Expand Up @@ -146,57 +52,144 @@ def _tell_with_warning(
"created."
) from e
else:
assert False, "Should not reach."
raise TypeError("Trial must be a trial object or trial number.")

return study._storage.get_trial(trial_id)

frozen_trial = study._storage.get_trial(trial_id)
warning_message = None

def _check_state_and_values(
state: Optional[TrialState], values: Optional[Union[float, Sequence[float]]]
) -> None:
if state == TrialState.COMPLETE:
if values is None:
raise ValueError(
"No values were told. Values are required when state is TrialState.COMPLETE."
)
elif state in (TrialState.PRUNED, TrialState.FAIL):
if values is not None:
raise ValueError(
"Values were told. Values cannot be specified when state is "
"TrialState.PRUNED or TrialState.FAIL."
)
elif state is not None:
raise ValueError(f"Cannot tell with state {state}.")


def _check_values_are_feasible(study: "optuna.Study", values: Sequence[float]) -> Optional[str]:
for v in values:
# TODO(Imamura): Construct error message taking into account all values and do not early
# return `value` is assumed to be ignored on failure so we can set it to any value.
try:
float(v)
except (ValueError, TypeError):
return f"The value {repr(v)} could not be cast to float."

if math.isnan(v):
return f"The value {v} is not acceptable."

if len(study.directions) != len(values):
return (
f"The number of the values {len(values)} did not match the number of the objectives "
f"{len(study.directions)}."
)

return None


def _tell_with_warning(
study: "optuna.Study",
trial: Union[trial_module.Trial, int],
value_or_values: Optional[Union[float, Sequence[float]]] = None,
state: Optional[TrialState] = None,
skip_if_finished: bool = False,
suppress_warning: bool = False,
) -> FrozenTrial:
"""Internal method of :func:`~optuna.study.Study.tell`.
Refer to the document for :func:`~optuna.study.Study.tell` for the reference.
This method has one additional parameter ``suppress_warning``.
Args:
suppress_warning:
If :obj:`True`, tell will not show warnings when tell receives an invalid
values. This flag is expected to be :obj:`True` only when it is invoked by
Study.optimize.
"""

# Validate the trial argument.
frozen_trial = _get_frozen_trial(study, trial)
if frozen_trial.state.is_finished() and skip_if_finished:
_logger.info(
f"Skipped telling trial {trial_number} with values "
f"{values} and state {state} since trial was already finished. "
f"Skipped telling trial {frozen_trial.number} with values "
f"{value_or_values} and state {state} since trial was already finished. "
f"Finished trial has values {frozen_trial.values} and state {frozen_trial.state}."
)
return copy.deepcopy(frozen_trial)
elif frozen_trial.state == TrialState.WAITING:
raise ValueError("Cannot tell a waiting trial.")
elif frozen_trial.state != TrialState.RUNNING:
raise ValueError(f"Cannot tell a {frozen_trial.state.name} trial.")

# Validate the state and values arguments.
values: Optional[Sequence[float]]
if value_or_values is None:
values = None
elif isinstance(value_or_values, Sequence):
values = value_or_values
else:
values = [value_or_values]

_check_state_and_values(state, values)

if state == TrialState.PRUNED:
warning_message = None

if state == TrialState.COMPLETE:
assert values is not None

values_conversion_failure_message = _check_values_are_feasible(study, values)
if values_conversion_failure_message is not None:
raise ValueError(values_conversion_failure_message)
elif state == TrialState.PRUNED:
# Register the last intermediate value if present as the value of the trial.
# TODO(hvy): Whether a pruned trials should have an actual value can be discussed.
assert values is None

last_step = frozen_trial.last_step
if last_step is not None:
values = [frozen_trial.intermediate_values[last_step]]

values, values_conversion_failure_message = _check_and_convert_to_values(
len(study.directions), values, trial_number
)

if state == TrialState.COMPLETE and values_conversion_failure_message is not None:
raise ValueError(values_conversion_failure_message)
last_intermediate_value = frozen_trial.intermediate_values[last_step]
# intermediate_values can be unacceptable value, i.e., NaN.
if _check_values_are_feasible(study, [last_intermediate_value]) is None:
values = [last_intermediate_value]
elif state is None:
if values is None:
values_conversion_failure_message = "The value None could not be cast to float."
else:
values_conversion_failure_message = _check_values_are_feasible(study, values)

if state is None:
if values_conversion_failure_message is None:
state = TrialState.COMPLETE
else:
state = TrialState.FAIL
values = None
if not suppress_warning:
warnings.warn(values_conversion_failure_message)
else:
warning_message = values_conversion_failure_message

assert state is not None

# Cast values to list of floats.
if values is not None:
# values have beed checked to be castable to floats in _check_values_are_feasible.
values = [float(value) for value in values]

# Post-processing and storing the trial.
try:
# Sampler defined trial post-processing.
study = pruners._filter_study(study, frozen_trial)
study.sampler.after_trial(study, frozen_trial, state, values)
finally:
study._storage.set_trial_state_values(trial_id, state, values)
study._storage.set_trial_state_values(frozen_trial._trial_id, state, values)

frozen_trial = copy.deepcopy(study._storage.get_trial(trial_id))
frozen_trial = copy.deepcopy(study._storage.get_trial(frozen_trial._trial_id))

if warning_message is not None:
frozen_trial.set_system_attr(STUDY_TELL_WARNING_KEY, warning_message)
Expand Down
6 changes: 5 additions & 1 deletion optuna/study/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,11 @@ def df(x):
"""

return _tell_with_warning(
study=self, trial=trial, values=values, state=state, skip_if_finished=skip_if_finished
study=self,
trial=trial,
value_or_values=values,
state=state,
skip_if_finished=skip_if_finished,
)

def set_user_attr(self, key: str, value: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/study_tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def func_numerical(trial: Trial) -> float:
mock_obj.assert_called_once_with(
study=mock.ANY,
trial=mock.ANY,
values=mock.ANY,
value_or_values=mock.ANY,
state=mock.ANY,
suppress_warning=True,
)
2 changes: 1 addition & 1 deletion tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ def test_tell_duplicate_tell() -> None:
# Should not panic when passthrough is enabled.
study.tell(trial, 1.0, skip_if_finished=True)

with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
study.tell(trial, 1.0, skip_if_finished=False)


Expand Down

0 comments on commit 79f0329

Please sign in to comment.