Skip to content

Commit

Permalink
SimplifyParameterConstraints (#2326)
Browse files Browse the repository at this point in the history
Summary:

Remove parameter constraints that can be trivially converted into an updated lower/upper bound

Reviewed By: SebastianAment

Differential Revision: D55718753
  • Loading branch information
David Eriksson authored and facebook-github-bot committed Apr 5, 2024
1 parent 05eb25f commit 6cab8e3
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 0 deletions.
69 changes: 69 additions & 0 deletions ax/modelbridge/transforms/simplify_parameter_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import math
from typing import List, TYPE_CHECKING

from ax.core.parameter import FixedParameter, ParameterType, RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.utils.common.typeutils import checked_cast

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


class SimplifyParameterConstraints(Transform):
"""Convert parameter constraints on one parameter to an updated bound.
This transform converts parameter constraints on only one parameter into an updated
upper or lower bound. Note that this transform will convert parameters that can only
take on one value into a `FixedParameter`. Make sure this transform is applied
before `RemoveFixed` if you want to remove all fixed parameters.
"""

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
# keeps track of the constraints that cannot be converted to bounds
nontrivial_constraints: List[ParameterConstraint] = []
for pc in search_space.parameter_constraints:
if len(pc.constraint_dict) == 1:
# This can be turned into an updated bound since only one variable is
# involved in the constraint.
[(p_name, weight)] = pc.constraint_dict.items()
# NOTE: We only allow parameter constraints on range parameters
p = checked_cast(RangeParameter, search_space.parameters[p_name])
lb, ub = p.lower, p.upper
if weight == 0 and pc.bound < 0: # Cannot be satisfied
raise ValueError(
"Parameter constraint cannot be satisfied since the weight "
"is zero and the bound is negative."
)
elif weight == 0: # Constraint is always satisfied
continue
elif weight > 0: # New upper bound
ub = float(pc.bound) / float(weight)
if p.parameter_type == ParameterType.INT:
ub = math.floor(ub) # Round down
else: # New lower bound
lb = float(pc.bound) / float(weight)
if p.parameter_type == ParameterType.INT:
lb = math.ceil(lb) # Round up

if lb == ub: # Need to turn this into a fixed parameter
search_space.parameters[p_name] = FixedParameter(
name=p_name, parameter_type=p.parameter_type, value=lb
)
elif weight > 0:
p._upper = ub
else:
p._lower = lb
else:
nontrivial_constraints.append(pc)
search_space.set_parameter_constraints(nontrivial_constraints)
return search_space
122 changes: 122 additions & 0 deletions ax/modelbridge/transforms/tests/test_simplify_parameter_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from copy import deepcopy
from typing import List

from ax.core.observation import ObservationFeatures
from ax.core.parameter import (
ChoiceParameter,
FixedParameter,
Parameter,
ParameterType,
RangeParameter,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.modelbridge.transforms.simplify_parameter_constraints import (
SimplifyParameterConstraints,
)
from ax.utils.common.testutils import TestCase


class SimplifyParameterConstraintsTest(TestCase):
def setUp(self) -> None:
self.parameters: List[Parameter] = [
RangeParameter("x", lower=1, upper=3, parameter_type=ParameterType.FLOAT),
RangeParameter("y", lower=2, upper=5, parameter_type=ParameterType.INT),
ChoiceParameter(
"z", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
]
self.observation_features = [
ObservationFeatures(parameters={"x": 2, "y": 2, "z": "b"})
]

def test_transform_no_constraints(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(parameters=self.parameters)
ss_transformed = t.transform_search_space(search_space=ss)
self.assertEqual(ss, ss_transformed)
self.assertEqual(
self.observation_features,
t.transform_observation_features(self.observation_features),
)

def test_transform_weight_zero(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 0}, bound=1)
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual(ss.parameters, ss_transformed.parameters)
ss_raises = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 0}, bound=-1)
],
)
with self.assertRaisesRegex(
ValueError, "Parameter constraint cannot be satisfied since the weight"
):
ss_transformed = t.transform_search_space(search_space=deepcopy(ss_raises))

def test_transform_search_space(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 1}, bound=2), # x <= 2
ParameterConstraint(constraint_dict={"y": -1}, bound=-4), # y => 4
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(
{
**ss.parameters,
"x": RangeParameter(
"x", parameter_type=ParameterType.FLOAT, lower=1, upper=2
),
"y": RangeParameter(
"y", parameter_type=ParameterType.INT, lower=4, upper=5
),
},
ss_transformed.parameters,
)
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual( # No-op
self.observation_features,
t.transform_observation_features(self.observation_features),
)

def test_transform_to_fixed(self) -> None:
t = SimplifyParameterConstraints()
ss = SearchSpace(
parameters=self.parameters,
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": 1}, bound=1), # x == 1
ParameterConstraint(constraint_dict={"y": -1}, bound=-5), # y == 5
],
)
ss_transformed = t.transform_search_space(search_space=deepcopy(ss))
self.assertEqual(
{
**ss.parameters,
"x": FixedParameter("x", parameter_type=ParameterType.FLOAT, value=1),
"y": FixedParameter("y", parameter_type=ParameterType.INT, value=5),
},
ss_transformed.parameters,
)
self.assertEqual(ss_transformed.parameter_constraints, [])
self.assertEqual( # No-op
self.observation_features,
t.transform_observation_features(self.observation_features),
)
4 changes: 4 additions & 0 deletions ax/storage/transform_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
)
from ax.modelbridge.transforms.remove_fixed import RemoveFixed
from ax.modelbridge.transforms.search_space_to_choice import SearchSpaceToChoice
from ax.modelbridge.transforms.simplify_parameter_constraints import (
SimplifyParameterConstraints,
)
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
from ax.modelbridge.transforms.task_encode import TaskEncode
Expand Down Expand Up @@ -79,6 +82,7 @@
LogY: 23,
Relativize: 24,
RelativizeWithConstantControl: 25,
SimplifyParameterConstraints: 26,
}


Expand Down
8 changes: 8 additions & 0 deletions sphinx/source/modelbridge.rst
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,14 @@ Transforms
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.simplify_parameter_constraints`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.modelbridge.transforms.simplify_parameter_constraints
:members:
:undoc-members:
:show-inheritance:

`ax.modelbridge.transforms.standardize\_y`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down

0 comments on commit 6cab8e3

Please sign in to comment.