Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions ax/core/parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,54 @@
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.string_utils import unsanitize_name
from ax.utils.common.sympy import extract_coefficient_dict_from_inequality
from ax.utils.common.sympy import (
extract_coefficient_dict_from_equality,
extract_coefficient_dict_from_inequality,
)
from pyre_extensions import none_throws


class ParameterConstraint(SortableBase):
"""Base class for linear parameter constraints.

Constraints are expressed as a SymPy parsable inequality string.
Supports both inequality constraints (``w^T x <= b``) and equality
constraints (``w^T x == b``). Exactly one of ``inequality`` or
``equality`` must be provided.
"""

def __init__(self, inequality: str) -> None:
"""Initialize ParameterConstraint
def __init__(
self,
inequality: str | None = None,
*,
equality: str | None = None,
) -> None:
"""Initialize ParameterConstraint.

Args:
inequality: String representation of the constraint. At this point in time
Ax only accepts linear inequality constraints.
inequality: String representation of a linear inequality
constraint, e.g. ``"x1 + x2 <= 3"``.
equality: String representation of a linear equality
constraint, e.g. ``"x1 + x2 == 3"``.

Exactly one of ``inequality`` or ``equality`` must be provided.
"""
self._inequality_str = inequality
if (inequality is None) == (equality is None):
raise UserInputError(
"Exactly one of `inequality` or `equality` must be provided."
)

coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=inequality
)
self.is_equality: bool = equality is not None
constraint_str = none_throws(equality if self.is_equality else inequality)
self._constraint_str: str = constraint_str

if self.is_equality:
coefficient_dict = extract_coefficient_dict_from_equality(
equality_str=constraint_str
)
else:
coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=constraint_str
)

# Iterate through the coefficients to extract the parameter names and weights
# and the bound
Expand All @@ -47,9 +74,12 @@ def __init__(self, inequality: str) -> None:
# Invert because we are "moving" the bound to the right hand side
self._bound = -1 * coefficient
else:
constraint_type = "equality" if self.is_equality else "inequality"
other_type = "inequality" if self.is_equality else "equality"
raise UserInputError(
"Only linear inequality parameter constraints are supported, found "
f"{inequality}"
f"Only linear {constraint_type} parameter constraints are "
f"supported, found {constraint_str}. "
f"Did you mean to use the `{other_type}` argument?"
)

@property
Expand All @@ -59,7 +89,7 @@ def constraint_dict(self) -> dict[str, float]:

@property
def bound(self) -> float:
"""Get bound of the inequality of the constraint."""
"""Get bound of the constraint."""
return self._bound

@bound.setter
Expand All @@ -71,7 +101,9 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool:
"""Whether or not the set of parameter values satisfies the constraint.

Does a weighted sum of the parameter values based on the constraint_dict
and checks that the sum is less than the bound.
and checks against the bound. For inequality constraints checks
``weighted_sum <= bound``; for equality constraints checks
``|weighted_sum - bound| <= tolerance``.

Args:
parameter_dict: Map from parameter name to parameter value.
Expand All @@ -87,13 +119,15 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool:
float(parameter_dict[param]) * weight
for param, weight in self.constraint_dict.items()
)
# Expected `int` for 2nd anonymous parameter to call `int.__le__` but got
# `float`.
if self.is_equality:
return abs(weighted_sum - self._bound) <= 1e-8
return weighted_sum <= self._bound + 1e-8 # allow for numerical imprecision

def clone(self) -> ParameterConstraint:
"""Clone."""
return ParameterConstraint(inequality=self._inequality_str)
if self.is_equality:
return ParameterConstraint(equality=self._constraint_str)
return ParameterConstraint(inequality=self._constraint_str)

def clone_with_transformed_parameters(
self, transformed_parameters: dict[str, Parameter]
Expand All @@ -102,10 +136,11 @@ def clone_with_transformed_parameters(
return self.clone()

def __repr__(self) -> str:
op = "==" if self.is_equality else "<="
return (
"ParameterConstraint("
+ " + ".join(f"{v}*{k}" for k, v in sorted(self.constraint_dict.items()))
+ f" <= {self._bound})"
+ f" {op} {self._bound})"
)

@property
Expand Down
70 changes: 70 additions & 0 deletions ax/core/tests/test_parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def setUp(self) -> None:
super().setUp()
self.constraint = ParameterConstraint(inequality="2 * x - 3 * y <= 6.0")
self.constraint_repr = "ParameterConstraint(2.0*x + -3.0*y <= 6.0)"
self.eq_constraint = ParameterConstraint(equality="x + y == 1.0")
self.eq_constraint_repr = "ParameterConstraint(1.0*x + 1.0*y == 1.0)"

def test_constraint_dict_and_bounds(self) -> None:
constraint = ParameterConstraint(inequality="x1 + x2 <= 1")
Expand Down Expand Up @@ -123,6 +125,74 @@ def test_Sortable(self) -> None:
)
self.assertTrue(constraint1 < constraint2)

def test_equality_constraint_init(self) -> None:
cases = [
("x + y == 1.0", {"x": 1.0, "y": 1.0}, 1.0),
("2 * x + 3 * y == 5.0", {"x": 2.0, "y": 3.0}, 5.0),
("-x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0),
("- x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0),
]
for expr, expected_dict, expected_bound in cases:
with self.subTest(expr=expr):
c = ParameterConstraint(equality=expr)
self.assertTrue(c.is_equality)
self.assertEqual(c.constraint_dict, expected_dict)
self.assertEqual(c.bound, expected_bound)

c_ineq = ParameterConstraint(inequality="x + y <= 1.0")
self.assertFalse(c_ineq.is_equality)

def test_equality_constraint_must_provide_one(self) -> None:
# Neither provided
with self.assertRaisesRegex(UserInputError, "Exactly one"):
ParameterConstraint()

# Both provided
with self.assertRaisesRegex(UserInputError, "Exactly one"):
ParameterConstraint(inequality="x <= 1", equality="x == 1")

def test_equality_constraint_check(self) -> None:
c = ParameterConstraint(equality="x + y == 1.0")

cases = [
({"x": 0.5, "y": 0.5}, True, "exact"),
({"x": 0.5, "y": 0.5 + 0.5e-9}, True, "within tolerance"),
({"x": 0.5, "y": 0.6}, False, "violated above"),
({"x": 0.5, "y": 0.4}, False, "violated below"),
]
for params, expected, label in cases:
with self.subTest(label=label):
self.assertEqual(c.check(params), expected)

def test_equality_constraint_repr(self) -> None:
self.assertEqual(str(self.eq_constraint), self.eq_constraint_repr)

def test_equality_constraint_clone(self) -> None:
clone = self.eq_constraint.clone()
self.assertTrue(clone.is_equality)
self.assertEqual(clone.constraint_dict, self.eq_constraint.constraint_dict)
self.assertEqual(clone.bound, self.eq_constraint.bound)

# Mutation of clone doesn't affect original
clone._bound = 99.0
self.assertNotEqual(self.eq_constraint.bound, clone.bound)

def test_equality_constraint_clone_with_transformed_parameters(self) -> None:
clone = self.eq_constraint.clone_with_transformed_parameters(
transformed_parameters={}
)
self.assertTrue(clone.is_equality)
self.assertEqual(clone.bound, self.eq_constraint.bound)

def test_equality_constraint_eq(self) -> None:
c1 = ParameterConstraint(equality="x + y == 1.0")
c2 = ParameterConstraint(equality="x + y == 1.0")
self.assertEqual(c1, c2)

# Different from inequality with same coefficients
c3 = ParameterConstraint(inequality="x + y <= 1.0")
self.assertNotEqual(c1, c3)


class ValidateConstraintParametersTest(TestCase):
def test_validate_constraint_parameters(self) -> None:
Expand Down
39 changes: 33 additions & 6 deletions ax/service/tests/test_instantiation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ def test_constraint_from_str(self) -> None:
)
with self.assertRaisesRegex(
ValueError,
(
r"Received invalid parameter constraint format: "
r"`x1 \+ x2 \+ x3 = 3`\. "
r"Please use one of the following forms:\n"
r".*\n.*\n.*\nAcceptable comparison operators are \">=\" and \"<=\"\."
),
r"Received invalid parameter constraint format: `x1 \+ x2 \+ x3 = 3`",
):
InstantiationBase.constraint_from_str(
"x1 + x2 + x3 = 3", {"x1": x1, "x2": x2, "x3": x3}
Expand Down Expand Up @@ -193,6 +188,38 @@ def test_constraint_from_str(self) -> None:
"x1 + x2 / 2.0 + x3 >= 3", {"x1": x1, "x2": x2, "x3": x3}
)

# --- Equality constraints ---
eq_constraint = InstantiationBase.constraint_from_str(
"x1 + x2 == 1", {"x1": x1, "x2": x2}
)
self.assertTrue(eq_constraint.is_equality)
self.assertEqual(eq_constraint.constraint_dict, {"x1": 1.0, "x2": 1.0})
self.assertEqual(eq_constraint.bound, 1.0)

# Weighted equality
eq_weighted = InstantiationBase.constraint_from_str(
"2*x1 + 3*x2 == 5", {"x1": x1, "x2": x2}
)
self.assertTrue(eq_weighted.is_equality)
self.assertEqual(eq_weighted.constraint_dict, {"x1": 2.0, "x2": 3.0})
self.assertEqual(eq_weighted.bound, 5.0)

# Single parameter equality
eq_single = InstantiationBase.constraint_from_str(
"x1 == 3", {"x1": x1, "x2": x2}
)
self.assertTrue(eq_single.is_equality)
self.assertEqual(eq_single.constraint_dict, {"x1": 1.0})
self.assertEqual(eq_single.bound, 3.0)

# Order equality constraint should error
with self.assertRaisesRegex(ValueError, "DerivedParameter"):
InstantiationBase.constraint_from_str("x1 == x2", {"x1": x1, "x2": x2})

# Linear equality that equates two params should also error
with self.assertRaisesRegex(ValueError, "DerivedParameter"):
InstantiationBase.constraint_from_str("x1 - x2 == 0", {"x1": x1, "x2": x2})

def test_spaces_in_metric_and_parameter_names(self) -> None:
# Metric and parameter names with spaces are allowed everywhere
# except in constraint string parsing, where split() would break.
Expand Down
Loading
Loading