Add support for conditionals in TrialToArrayConverter.
PiperOrigin-RevId: 586099252
chansoo-google authored and Copybara-Service committed Nov 28, 2023
1 parent c07bfe5 commit 2938d30
Showing 4 changed files with 100 additions and 44 deletions.
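In short: after this change, a TrialToArrayConverter can be built from a search space that contains conditional (child) parameters. A minimal sketch of the new behavior, mirroring the test added in core_test.py below; the class, helper, and method names come from the diff, while the exact import paths are assumptions:

from vizier import pyvizier
from vizier.pyvizier.converters import core
from vizier.testing import test_studies

# Build a converter for a study whose search space has conditional parameters.
space = test_studies.conditional_automl_space()
converter = core.TrialToArrayConverter.from_study_config(
    pyvizier.ProblemStatement(search_space=space)
)
# For an empty trial, the new test expects a single feature row mixing NaN
# (continuous parameters) and one-hot columns (categorical parameters).
features = converter.to_features([pyvizier.Trial()])
print(features.shape)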
64 changes: 47 additions & 17 deletions vizier/_src/pyvizier/shared/parameter_config.py
@@ -17,7 +17,7 @@
"""ParameterConfig wraps ParameterConfig and ParameterSpec protos."""

import collections
from typing import Collection, Set as AbstractSet, Sized
from typing import Iterable, Set as AbstractSet, Sized
import copy
import enum
import json
@@ -410,7 +410,7 @@ def child_parameter_configs(self) -> List['ParameterConfig']:

def subspaces(
self,
) -> Collection[Tuple[ParameterValueTypes, 'SearchSpace']]:
) -> Iterable[Tuple[ParameterValueTypes, 'SearchSpace']]:
return self._children.items()

# TODO: TO BE DEPRECATED.
@@ -724,24 +724,26 @@ def subspace(self, value: ParameterValueTypes) -> 'SearchSpace':
return self._children[value]


ParameterConfigOrConfigs = Union[ParameterConfig, Collection[ParameterConfig]]


@attr.define(init=False)
class ParameterConfigSelector(Sized):
class ParameterConfigSelector(Iterable[ParameterConfig], Sized):
"""Holds a reference to ParameterConfigs."""

# Selected configs.
_selected: tuple[ParameterConfig] = attr.field(init=True)
_selected: tuple[ParameterConfig] = attr.field(init=True, converter=tuple)

def __iter__(self) -> Iterator[ParameterConfig]:
return iter(self._selected)

def __len__(self) -> int:
return len(self._selected)

def __init__(self, selected: ParameterConfigOrConfigs):
if isinstance(selected, Collection):
self.__attrs_init__(tuple(selected))
else:
def __init__(
self, selected: Union[ParameterConfig, Iterable[ParameterConfig]], /
):
if isinstance(selected, ParameterConfig):
self.__attrs_init__(tuple([selected]))
else:
self.__attrs_init__(tuple(selected))

def select_values(
self, values: MonotypeParameterSequence
@@ -762,13 +764,29 @@ def select_values(
spaces.append(config.subspace(value))
return SearchSpaceSelector(spaces)

def merge(self) -> 'ParameterConfigSelector':
"""Merge by taking the union of the parameter configs with the same name.
Returns:
The returned ParameterConfigSelector does not contain parameters with
duplicate names. Their feasible set (either as a range or discrete set) is
the union of all feasible sets under the same parameter name.
"""
merged_configs = {}
for parameter_config in self:
name = parameter_config.name # Alias
existing_config = merged_configs.setdefault(name, parameter_config)
merged_configs[name] = ParameterConfig.merge(
existing_config, parameter_config
)
return ParameterConfigSelector(merged_configs.values())


class InvalidParameterError(ValueError):
"""Error thrown when parameter values are invalid."""


################### Main Classes ###################
SearchSpaceOrSpaces = Union['SearchSpace', Collection['SearchSpace']]


@attr.define(init=False)
@@ -783,11 +801,13 @@ class SearchSpaceSelector:
def __len__(self) -> int:
return len(self._selected)

def __init__(self, selected: SearchSpaceOrSpaces):
if isinstance(selected, Collection):
self.__attrs_init__(tuple(selected))
else:
def __init__(
self, selected: Union['SearchSpace', Iterable['SearchSpace']], /
):
if isinstance(selected, SearchSpace):
self.__attrs_init__(tuple([selected]))
else:
self.__attrs_init__(tuple(selected))

def add_float_param(
self,
@@ -1239,7 +1259,7 @@ def parse_multi_dimensional_parameter_name(

# TODO: Add def extend(space: SearchSpace)
def _add_parameters(
self, parameters: List[ParameterConfig]
self, parameters: Iterable[ParameterConfig]
) -> ParameterConfigSelector:
"""Adds deepcopy of the ParameterConfigs.
@@ -1249,6 +1269,7 @@ def _add_parameters(
Returns:
A ParameterConfigSelector, with one entry per parameter added.
"""
parameters = list(parameters)
logging.info(
'Adding child parameters %s to %s subspaces ',
set(p.name for p in parameters),
@@ -1262,6 +1283,15 @@

return ParameterConfigSelector(added)

def select_all(self) -> ParameterConfigSelector:
"""Select all parameters at all levels."""
all_parameter_configs = []
for space in self._selected:
for top_level_config in space.parameters:
all_parameter_configs.extend(list(top_level_config.traverse()))

return ParameterConfigSelector(all_parameter_configs)


@attr.define(frozen=False, init=True, slots=True, kw_only=True)
class SearchSpace:
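The selector additions above give SearchSpace a one-call way to flatten a conditional space: select_all() walks every level of the parameter tree, and merge() collapses same-named configs into one. A small usage sketch based on the new test (the test_studies import path is taken from the test file below):

from vizier.testing import test_studies

space = test_studies.conditional_automl_space()
# Select every ParameterConfig at every level of the conditional tree...
selected = space.root.select_all()
# ...then collapse same-named configs into one, unioning their feasible sets.
merged = selected.merge()
print(sorted(p.name for p in merged))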
20 changes: 20 additions & 0 deletions vizier/_src/pyvizier/shared/parameter_config_test.py
@@ -18,8 +18,10 @@

from typing import Any

from absl import logging
from vizier._src.pyvizier.shared import parameter_config as pc
from vizier._src.pyvizier.shared import trial
from vizier.testing import test_studies

from absl.testing import absltest
from absl.testing import parameterized
@@ -557,6 +559,24 @@ def testValidateCategoricalInput(self):
root.add_categorical_param('categorical', ['3.2', '2', 5])


class FlattenAndMergeTest(absltest.TestCase):

def testFlattenAndMerge(self):
space = test_studies.conditional_automl_space()
parameters = space.root.select_all().merge()
logging.info('Merged: %s', parameters)
self.assertCountEqual(
[p.name for p in parameters],
[
'model_type',
'learning_rate',
'optimizer_type',
'use_special_logic',
'special_logic_parameter',
],
)


class SearchSpaceContainsTest(absltest.TestCase):

def _space(self):
44 changes: 17 additions & 27 deletions vizier/pyvizier/converters/core.py
@@ -20,7 +20,6 @@
import copy
import dataclasses
import enum
import itertools
from typing import Any, Callable, Collection, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union

from absl import logging
@@ -1147,29 +1146,17 @@ def from_study_configs(
Returns:
`DefaultTrialConverter`.
"""
# Cache ParameterConfigs.
# Traverse through all parameter configs and merge the same-named ones.
parameter_configs: Dict[str, pyvizier.ParameterConfig] = dict()
for study_config in study_configs:
all_parameter_configs = itertools.chain.from_iterable(
[
top_level_config.traverse()
for top_level_config in study_config.search_space.parameters
]
)
for parameter_config in all_parameter_configs:
name = parameter_config.name # Alias
existing_config = parameter_configs.get(name, None)
if existing_config is None:
parameter_configs[name] = parameter_config
else:
parameter_configs[name] = pyvizier.ParameterConfig.merge(
existing_config, parameter_config
)
# Merge parameter configs by name.
merged_configs = list(
pyvizier.SearchSpaceSelector([sc.search_space for sc in study_configs])
.select_all()
.merge()
)

parameter_converters = []
for pc in parameter_configs.values():
parameter_converters.append(DefaultModelInputConverter(pc))
merged_configs = {pc.name: pc for pc in merged_configs}
parameter_converters = [
DefaultModelInputConverter(pc) for pc in merged_configs.values()
]

# Append study id feature if configured to do so.
if use_study_id_feature:
@@ -1188,17 +1175,17 @@
'had study id configured.'
)
use_study_id_feature = False
elif STUDY_ID_FIELD in parameter_configs:
elif STUDY_ID_FIELD in merged_configs:
raise ValueError(
'Dataset name conflicts with a ParameterConfig '
'that already exists: {}'.format(parameter_configs[STUDY_ID_FIELD])
'that already exists: {}'.format(merged_configs[STUDY_ID_FIELD])
)

# Create new parameter config.
parameter_config = pyvizier.ParameterConfig.factory(
STUDY_ID_FIELD, feasible_values=list(study_ids)
)
parameter_configs[STUDY_ID_FIELD] = parameter_config
merged_configs[STUDY_ID_FIELD] = parameter_config
logging.info('Created a new ParameterConfig %s', parameter_config)

# Create converter.
@@ -1309,7 +1296,10 @@ def create_output_converter(metric):

sc = study_config # alias, to keep pylint quiet in the next line.
converter = DefaultTrialConverter(
[create_input_converter(p) for p in sc.search_space.parameters],
[
create_input_converter(p)
for p in sc.search_space.root.select_all().merge()
],
[create_output_converter(m) for m in sc.metric_information],
)
return cls(converter)
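For multi-study conversion, from_study_configs now delegates the same flatten-and-merge step to the selector API instead of hand-rolling the traversal with itertools. A hedged sketch of the equivalent standalone computation (study_configs here stands for any iterable of study configs exposing a .search_space; pyvizier.SearchSpaceSelector is the class used in the diff above, and the helper name is hypothetical):

from vizier import pyvizier

def merged_parameter_configs(study_configs):
  # One ParameterConfig per unique name; feasible sets are unioned across all
  # studies' search spaces, including conditional children.
  selector = pyvizier.SearchSpaceSelector(
      [sc.search_space for sc in study_configs]
  )
  return {pc.name: pc for pc in selector.select_all().merge()}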
16 changes: 16 additions & 0 deletions vizier/pyvizier/converters/core_test.py
@@ -31,6 +31,22 @@
Trial = pyvizier.Trial


class TrialToArrayConverterConditionalSpaceTest(parameterized.TestCase):

def test_automl_study(self):
space = test_studies.conditional_automl_space()
converter = core.TrialToArrayConverter.from_study_config(
pyvizier.ProblemStatement(search_space=space)
)
features = converter.to_features([pyvizier.Trial()])
np.testing.assert_equal(
features,
np.array(
[[0.0, 0.0, 1.0, np.nan, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, np.nan]]
),
)


class TrialToArrayConverterTest(parameterized.TestCase):
"""Test TrialToArrayConverter class."""

