Skip to content

Commit

Permalink
Add a wrapper to allow fairer comparisons between Vizier and other ra…
Browse files Browse the repository at this point in the history
…ytune algorithms.

This wrapper ensures that all algorithms start with the same default parameters.

PiperOrigin-RevId: 635475355
  • Loading branch information
chansoo-google authored and Copybara-Service committed May 20, 2024
1 parent e1f2347 commit c48220d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions vizier/_src/pythia/suggest_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_T = TypeVar('_T')


def _get_default_parameters(search_space: vz.SearchSpace) -> vz.ParameterDict:
def get_default_parameters(search_space: vz.SearchSpace) -> vz.ParameterDict:
"""Gets the default parameters for the given search space."""
builder = vz.SequentialParameterBuilder(search_space)

Expand Down Expand Up @@ -94,7 +94,7 @@ def wrapper_fn(self: Policy, request: SuggestRequest) -> SuggestDecision:
if request.max_trial_id > 0:
return suggest_fn(self, request)

default_parameters = _get_default_parameters(
default_parameters = get_default_parameters(
request.study_config.search_space
)
decision = SuggestDecision([vz.TrialSuggestion(default_parameters)])
Expand Down
10 changes: 5 additions & 5 deletions vizier/_src/pythia/suggest_default_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class GetDefaultParametersTest(absltest.TestCase):
def test_double_user_default(self):
ss = vz.SearchSpace()
ss.root.add_float_param('x', 0.0, 1.0, default_value=0.2)
params = suggest_default._get_default_parameters(ss)
params = suggest_default.get_default_parameters(ss)
self.assertEqual(params.get_value('x'), 0.2)

@unittest.skip('TODO')
Expand All @@ -48,25 +48,25 @@ def test_double_logscale(self):
ss.root.add_float_param(
'x', np.exp(-2), np.exp(2), scale_type=vz.ScaleType.LOG
)
params = suggest_default._get_default_parameters(ss)
params = suggest_default.get_default_parameters(ss)
self.assertEqual(params.get_value('x'), 1.0)

def test_double_fixed(self):
ss = vz.SearchSpace()
ss.root.add_float_param('x', 1.0, 1.0)
params = suggest_default._get_default_parameters(ss)
params = suggest_default.get_default_parameters(ss)
self.assertEqual(params.get_value('x'), 1.0)

def test_discrete(self):
ss = vz.SearchSpace()
ss.root.add_discrete_param('x', [1, 2, 3, 6])
params = suggest_default._get_default_parameters(ss)
params = suggest_default.get_default_parameters(ss)
self.assertEqual(params.get_value('x'), 3)

def test_categorical(self):
ss = vz.SearchSpace()
ss.root.add_categorical_param('x', ['a', 'b', 'c', 'd'])
params = suggest_default._get_default_parameters(ss)
params = suggest_default.get_default_parameters(ss)
self.assertEqual(params.get_value('x'), 'c')


Expand Down

0 comments on commit c48220d

Please sign in to comment.