diff --git a/vizier/_src/benchmarks/experimenters/shifting_experimenter.py b/vizier/_src/benchmarks/experimenters/shifting_experimenter.py index ab9e401a6..095297aec 100644 --- a/vizier/_src/benchmarks/experimenters/shifting_experimenter.py +++ b/vizier/_src/benchmarks/experimenters/shifting_experimenter.py @@ -54,8 +54,10 @@ def __init__( exptr_problem_statement = exptr.problem_statement() if exptr_problem_statement.search_space.is_conditional: - raise ValueError('Search space should not have conditional' - f' parameters {exptr_problem_statement}') + raise ValueError( + 'Search space should not have conditional' + f' parameters {exptr_problem_statement}' + ) dimension = len(exptr_problem_statement.search_space.parameters) if dimension <= 0: raise ValueError(f'Invalid dimension: {dimension}') @@ -64,8 +66,8 @@ def __init__( self._shift = np.broadcast_to(shift, (dimension,)) except ValueError as broadcast_err: raise ValueError( - f'Shift {shift} is not broadcastable for dim: {dimension}.' - '\n') from broadcast_err + f'Shift {shift} is not broadcastable for dim: {dimension}.\n' + ) from broadcast_err # Converter should be in the underlying extpr space. self._converter = converters.TrialToArrayConverter.from_study_config( @@ -83,26 +85,28 @@ def __init__( ): if parameter.type != pyvizier.ParameterType.DOUBLE: raise ValueError(f'Non-double parameters {parameter}') - if (bounds := parameter.bounds) is not None: - if abs(shift) >= bounds[1] - bounds[0]: - raise ValueError( - f'Bounds {bounds} may need to be extended' - f'as shift {shift} is too large ' - ) - # Shift the bounds to maintain valid bounds. - if shift >= 0: - new_bounds = (bounds[0] + shift, bounds[1]) - else: - new_bounds = (bounds[0], bounds[1] + shift) - self._problem_statement.search_space.add( - pyvizier.ParameterConfig.factory( - name=parameter.name, - bounds=new_bounds, - scale_type=parameter.scale_type, - default_value=parameter.default_value, - external_type=parameter.external_type, - ) + if (bounds := parameter.bounds) is None: + raise ValueError(f'Parameter {parameter} has no bounds') + + if abs(shift) >= bounds[1] - bounds[0]: + raise ValueError( + f'Bounds {bounds} may need to be extended' + f'as shift {shift} is too large ' ) + # Shift the bounds to maintain valid bounds. + if shift >= 0: + new_bounds = (bounds[0] + shift, bounds[1]) + else: + new_bounds = (bounds[0], bounds[1] + shift) + self._problem_statement.search_space.add( + pyvizier.ParameterConfig.factory( + name=parameter.name, + bounds=new_bounds, + scale_type=parameter.scale_type, + default_value=parameter.default_value, + external_type=parameter.external_type, + ), + ) def problem_statement(self) -> pyvizier.ProblemStatement: return copy.deepcopy(self._problem_statement) @@ -116,8 +120,9 @@ def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None: for parameters, suggestion in zip(previous_parameters, suggestions): suggestion.parameters = parameters - def _offset(self, suggestions: Sequence[pyvizier.Trial], - shift: np.ndarray) -> None: + def _offset( + self, suggestions: Sequence[pyvizier.Trial], shift: np.ndarray + ) -> None: """Offsets parameter values (OOB values are clipped).""" for suggestion in suggestions: features = self._converter.to_features([suggestion])