Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions ax/adapter/tests/test_torch_moo_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,7 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None:
)
fixed_features = ObservationFeatures(parameters={"x1": 0.0})
search_space = exp.search_space.clone()
param_constraints = [
ParameterConstraint(constraint_dict={"x1": 1.0}, bound=10.0)
]
param_constraints = [ParameterConstraint(inequality="x1 <= 10")]
search_space.add_parameter_constraints(param_constraints)
oc = none_throws(exp.optimization_config).clone()
oc.objective._objectives[0].minimize = True
Expand Down
4 changes: 1 addition & 3 deletions ax/adapter/transforms/tests/test_choice_encode_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ def setUp(self) -> None:
sort_values=False,
),
],
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
],
parameter_constraints=[ParameterConstraint(inequality="-0.5*x + a <= 0.5")],
)
self.t = self.t_class(search_space=self.search_space)
input_params: TParameterization = {
Expand Down
4 changes: 1 addition & 3 deletions ax/adapter/transforms/tests/test_one_hot_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def setUp(self) -> None:
is_ordered=True,
),
],
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5)
],
parameter_constraints=[ParameterConstraint(inequality="-0.5*x + a <= 0.5")],
)
self.t = OneHot(search_space=self.search_space)
self.t2 = OneHot(
Expand Down
8 changes: 4 additions & 4 deletions ax/adapter/transforms/tests/test_unit_x_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def setUp(self) -> None:
),
],
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": -0.5, "y": 1}, bound=0.5),
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
ParameterConstraint(inequality="-0.5*x + y <= 0.5"),
ParameterConstraint(inequality="-0.5*x + a <= 0.5"),
],
)
self.t = UnitX(search_space=self.search_space)
Expand Down Expand Up @@ -157,8 +157,8 @@ def test_TransformNewSearchSpace(self) -> None:
),
],
parameter_constraints=[
ParameterConstraint(constraint_dict={"x": -0.5, "y": 1}, bound=0.5),
ParameterConstraint(constraint_dict={"x": -0.5, "a": 1}, bound=0.5),
ParameterConstraint(inequality="-0.5*x + y <= 0.5"),
ParameterConstraint(inequality="-0.5*x + a <= 0.5"),
],
)
self.t.transform_search_space(new_ss)
Expand Down
8 changes: 7 additions & 1 deletion ax/adapter/transforms/unit_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
bound -= w * l
else:
constraint_dict[p_name] = w

expr = " + ".join(
f"{coeff} * {param}" for param, coeff in constraint_dict.items()
)
new_constraints.append(
ParameterConstraint(constraint_dict=constraint_dict, bound=bound)
ParameterConstraint(
inequality=f"{expr} <= {bound}",
)
)
search_space.set_parameter_constraints(new_constraints)
return search_space
Expand Down
5 changes: 1 addition & 4 deletions ax/analysis/healthcheck/tests/test_search_space_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ def test_search_space_boundary_proportions(self) -> None:
),
],
parameter_constraints=[
ParameterConstraint(
constraint_dict={"float_range_1": 1.0, "float_range_2": 1.0},
bound=4.0,
)
ParameterConstraint(inequality="float_range_1 + float_range_2 <= 4")
],
)

Expand Down
2 changes: 1 addition & 1 deletion ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_configure_experiment(self) -> None:
],
parameter_constraints=[
ParameterConstraint(
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
inequality="int_param <= float_param",
)
],
),
Expand Down
63 changes: 2 additions & 61 deletions ax/api/utils/instantiation/from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from collections.abc import Sequence

from ax.core.map_metric import MapMetric

from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
Expand All @@ -20,14 +19,13 @@
OutcomeConstraint,
ScalarizedOutcomeConstraint,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.exceptions.core import UserInputError
from ax.utils.common.string_utils import sanitize_name, unsanitize_name
from ax.utils.common.sympy import extract_coefficient_dict_from_inequality
from pyre_extensions import assert_is_instance, none_throws
from sympy.core.add import Add
from sympy.core.expr import Expr
from sympy.core.mul import Mul
from sympy.core.relational import GreaterThan, LessThan
from sympy.core.symbol import Symbol
from sympy.core.sympify import sympify

Expand Down Expand Up @@ -95,35 +93,6 @@ def optimization_config_from_string(
)


def parse_parameter_constraint(constraint_str: str) -> ParameterConstraint:
"""
Parse a parameter constraint string into a ParameterConstraint object using SymPy.
Currently only supports linear constraints of the form "a * x + b * y >= k" or
"a * x + b * y <= k".
"""
coefficient_dict = _extract_coefficient_dict_from_inequality(
inequality_str=constraint_str
)

# Iterate through the coefficients to extract the parameter names and weights and
# the bound
constraint_dict = {}
bound = 0
for term, coefficient in coefficient_dict.items():
if term.is_symbol:
constraint_dict[unsanitize_name(term.name)] = coefficient
elif term.is_number:
# Invert because we are "moving" the bound to the right hand side
bound = -1 * coefficient
else:
raise UserInputError(
"Only linear inequality parameter constraints are supported, found "
f"{constraint_str}"
)

return ParameterConstraint(constraint_dict=constraint_dict, bound=bound)


def parse_objective(objective_str: str) -> Objective:
"""
Parse an objective string into an Objective object using SymPy.
Expand Down Expand Up @@ -154,7 +123,7 @@ def parse_outcome_constraint(constraint_str: str) -> OutcomeConstraint:
multiply your bound by "baseline". For example "qps >= 0.95 * baseline" will
constrain such that the QPS is at least 95% of the baseline arm's QPS.
"""
coefficient_dict = _extract_coefficient_dict_from_inequality(
coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=constraint_str
)

Expand Down Expand Up @@ -248,31 +217,3 @@ def _create_single_objective(expression: Expr) -> Objective:
)

raise UserInputError(f"Only linear objectives are supported, found {expression}.")


def _extract_coefficient_dict_from_inequality(
inequality_str: str,
) -> dict[Symbol, float]:
"""
Use SymPy to parse a string into an inequality, invert if necessary to enforce a
less-than relationship, move all terms to the left side, and return the
coefficients as a dictionary. This is useful for parsing parameter and outcome
constraints.
"""
# Parse the constraint string into a SymPy inequality
inequality = sympify(sanitize_name(inequality_str))

# Check the SymPy object is a valid inequality
if not isinstance(inequality, GreaterThan | LessThan):
raise UserInputError(f"Expected an inequality, found {inequality_str}")

# Move all terms to the left side of the inequality and invert if necessary to
# enforce a less-than relationship
if isinstance(inequality, LessThan):
expression = inequality.lhs - inequality.rhs
else:
expression = inequality.rhs - inequality.lhs

return {
key: float(value) for key, value in expression.as_coefficients_dict().items()
}
10 changes: 6 additions & 4 deletions ax/api/utils/instantiation/from_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
# pyre-strict

from ax.api.utils.instantiation.from_config import parameter_from_config
from ax.api.utils.instantiation.from_string import parse_parameter_constraint
from ax.api.utils.structs import ExperimentStruct
from ax.core.evaluations_to_data import DataType
from ax.core.experiment import Experiment
from ax.core.parameter_constraint import validate_constraint_parameters
from ax.core.parameter_constraint import (
ParameterConstraint,
validate_constraint_parameters,
)
from ax.core.search_space import SearchSpace


Expand All @@ -22,8 +24,8 @@ def experiment_from_struct(struct: ExperimentStruct) -> Experiment:
]

constraints = [
parse_parameter_constraint(constraint_str=constraint_str)
for constraint_str in struct.parameter_constraints
ParameterConstraint(inequality=inequality)
for inequality in struct.parameter_constraints
]

# Ensure that all ParameterConstraints are valid and acting on existing parameters
Expand Down
4 changes: 2 additions & 2 deletions ax/api/utils/instantiation/tests/test_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def test_experiment_from_config(self) -> None:
],
parameter_constraints=[
ParameterConstraint(
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
inequality="int_param <= float_param",
)
],
),
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_experiment_from_config(self) -> None:
],
parameter_constraints=[
ParameterConstraint(
constraint_dict={"int_param": 1, "float_param": -1}, bound=0
inequality="int_param <= float_param",
)
],
),
Expand Down
42 changes: 0 additions & 42 deletions ax/api/utils/instantiation/tests/test_from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
optimization_config_from_string,
parse_objective,
parse_outcome_constraint,
parse_parameter_constraint,
)
from ax.core.map_metric import MapMetric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
Expand All @@ -22,7 +21,6 @@
OutcomeConstraint,
ScalarizedOutcomeConstraint,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase

Expand Down Expand Up @@ -97,46 +95,6 @@ def test_optimization_config_from_string(self) -> None:
),
)

def test_parse_parameter_constraint(self) -> None:
constraint = parse_parameter_constraint(constraint_str="x1 + x2 <= 1")
self.assertEqual(
constraint,
ParameterConstraint(constraint_dict={"x1": 1, "x2": 1}, bound=1.0),
)

with_coefficients = parse_parameter_constraint(
constraint_str="2 * x1 + 3 * x2 <= 1"
)
self.assertEqual(
with_coefficients,
ParameterConstraint(constraint_dict={"x1": 2, "x2": 3}, bound=1.0),
)

flipped_sign = parse_parameter_constraint(constraint_str="x1 + x2 >= 1")
self.assertEqual(
flipped_sign,
ParameterConstraint(constraint_dict={"x1": -1, "x2": -1}, bound=-1.0),
)

weird = parse_parameter_constraint(constraint_str="x1 + x2 <= 1.5 * x3 + 2")
self.assertEqual(
weird,
ParameterConstraint(
constraint_dict={"x1": 1, "x2": 1, "x3": -1.5}, bound=2.0
),
)

with self.assertRaisesRegex(UserInputError, "Only linear"):
parse_parameter_constraint(constraint_str="x1 * x2 <= 1")
# test with sanitization
constraint = parse_parameter_constraint(constraint_str="foo.bar + foo.baz <= 1")
self.assertEqual(
constraint,
ParameterConstraint(
constraint_dict={"foo.bar": 1, "foo.baz": 1}, bound=1.0
),
)

def test_parse_objective(self) -> None:
single_objective = parse_objective(objective_str="ne")
self.assertEqual(
Expand Down
44 changes: 29 additions & 15 deletions ax/core/parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,46 @@

from ax.core.parameter import Parameter, RangeParameter
from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.string_utils import unsanitize_name
from ax.utils.common.sympy import extract_coefficient_dict_from_inequality


class ParameterConstraint(SortableBase):
"""Base class for linear parameter constraints.

Constraints are expressed using a map from parameter name to weight
followed by a bound.

The constraint is satisfied if sum_i(w_i * v_i) <= b where:
w is the vector of parameter weights.
v is a vector of parameter values.
b is the specified bound.
Constraints are expressed as a SymPy parsable inequality string.
"""

def __init__(self, constraint_dict: dict[str, float], bound: float) -> None:
def __init__(self, inequality: str) -> None:
"""Initialize ParameterConstraint

Args:
constraint_dict: Map from parameter name to weight.
bound: Bound of the inequality of the constraint.
inequality: String representation of the constraint. At this point in time
Ax only accepts linear inequality constraints.
"""
self._constraint_dict = constraint_dict
self._bound = bound
self._inequality_str = inequality

coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=inequality
)

# Iterate through the coefficients to extract the parameter names and weights
# and the bound
self._constraint_dict: dict[str, float] = {}
self._bound: float = 0.0
for term, coefficient in coefficient_dict.items():
if term.is_symbol:
self._constraint_dict[unsanitize_name(term.name)] = coefficient
elif term.is_number:
# Invert because we are "moving" the bound to the right hand side
self._bound = -1 * coefficient
else:
raise UserInputError(
"Only linear inequality parameter constraints are supported, found "
f"{inequality}"
)

@property
def constraint_dict(self) -> dict[str, float]:
Expand Down Expand Up @@ -78,9 +94,7 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool:

def clone(self) -> ParameterConstraint:
"""Clone."""
return ParameterConstraint(
constraint_dict=self._constraint_dict.copy(), bound=self._bound
)
return ParameterConstraint(inequality=self._inequality_str)

def clone_with_transformed_parameters(
self, transformed_parameters: dict[str, Parameter]
Expand Down
Loading