Skip to content

Commit

Permalink
Small update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635500763
  • Loading branch information
xingyousong authored and Copybara-Service committed May 20, 2024
1 parent c48220d commit 0b4577c
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions vizier/_src/benchmarks/experimenters/shifting_experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def __init__(
exptr_problem_statement = exptr.problem_statement()

if exptr_problem_statement.search_space.is_conditional:
raise ValueError('Search space should not have conditional'
f' parameters {exptr_problem_statement}')
raise ValueError(
'Search space should not have conditional'
f' parameters {exptr_problem_statement}'
)
dimension = len(exptr_problem_statement.search_space.parameters)
if dimension <= 0:
raise ValueError(f'Invalid dimension: {dimension}')
Expand All @@ -64,8 +66,8 @@ def __init__(
self._shift = np.broadcast_to(shift, (dimension,))
except ValueError as broadcast_err:
raise ValueError(
f'Shift {shift} is not broadcastable for dim: {dimension}.'
'\n') from broadcast_err
f'Shift {shift} is not broadcastable for dim: {dimension}.\n'
) from broadcast_err

# Converter should be in the underlying extpr space.
self._converter = converters.TrialToArrayConverter.from_study_config(
Expand All @@ -83,26 +85,28 @@ def __init__(
):
if parameter.type != pyvizier.ParameterType.DOUBLE:
raise ValueError(f'Non-double parameters {parameter}')
if (bounds := parameter.bounds) is not None:
if abs(shift) >= bounds[1] - bounds[0]:
raise ValueError(
f'Bounds {bounds} may need to be extended'
f'as shift {shift} is too large '
)
# Shift the bounds to maintain valid bounds.
if shift >= 0:
new_bounds = (bounds[0] + shift, bounds[1])
else:
new_bounds = (bounds[0], bounds[1] + shift)
self._problem_statement.search_space.add(
pyvizier.ParameterConfig.factory(
name=parameter.name,
bounds=new_bounds,
scale_type=parameter.scale_type,
default_value=parameter.default_value,
external_type=parameter.external_type,
)
if (bounds := parameter.bounds) is None:
raise ValueError(f'Parameter {parameter} has no bounds')

if abs(shift) >= bounds[1] - bounds[0]:
raise ValueError(
f'Bounds {bounds} may need to be extended'
f'as shift {shift} is too large '
)
# Shift the bounds to maintain valid bounds.
if shift >= 0:
new_bounds = (bounds[0] + shift, bounds[1])
else:
new_bounds = (bounds[0], bounds[1] + shift)
self._problem_statement.search_space.add(
pyvizier.ParameterConfig.factory(
name=parameter.name,
bounds=new_bounds,
scale_type=parameter.scale_type,
default_value=parameter.default_value,
external_type=parameter.external_type,
),
)

def problem_statement(self) -> pyvizier.ProblemStatement:
return copy.deepcopy(self._problem_statement)
Expand All @@ -116,8 +120,9 @@ def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None:
for parameters, suggestion in zip(previous_parameters, suggestions):
suggestion.parameters = parameters

def _offset(self, suggestions: Sequence[pyvizier.Trial],
shift: np.ndarray) -> None:
def _offset(
self, suggestions: Sequence[pyvizier.Trial], shift: np.ndarray
) -> None:
"""Offsets parameter values (OOB values are clipped)."""
for suggestion in suggestions:
features = self._converter.to_features([suggestion])
Expand Down

0 comments on commit 0b4577c

Please sign in to comment.