Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions ax/core/parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,54 @@
from ax.exceptions.core import UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.string_utils import unsanitize_name
from ax.utils.common.sympy import extract_coefficient_dict_from_inequality
from ax.utils.common.sympy import (
extract_coefficient_dict_from_equality,
extract_coefficient_dict_from_inequality,
)
from pyre_extensions import none_throws


class ParameterConstraint(SortableBase):
"""Base class for linear parameter constraints.

Constraints are expressed as a SymPy parsable inequality string.
Supports both inequality constraints (``w^T x <= b``) and equality
constraints (``w^T x == b``). Exactly one of ``inequality`` or
``equality`` must be provided.
"""

def __init__(self, inequality: str) -> None:
"""Initialize ParameterConstraint
def __init__(
self,
inequality: str | None = None,
*,
equality: str | None = None,
) -> None:
"""Initialize ParameterConstraint.

Args:
inequality: String representation of the constraint. At this point in time
Ax only accepts linear inequality constraints.
inequality: String representation of a linear inequality
constraint, e.g. ``"x1 + x2 <= 3"``.
equality: String representation of a linear equality
constraint, e.g. ``"x1 + x2 == 3"``.

Exactly one of ``inequality`` or ``equality`` must be provided.
"""
self._inequality_str = inequality
if (inequality is None) == (equality is None):
raise UserInputError(
"Exactly one of `inequality` or `equality` must be provided."
)

coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=inequality
)
self.is_equality: bool = equality is not None
constraint_str = none_throws(equality if self.is_equality else inequality)
self._constraint_str: str = constraint_str

if self.is_equality:
coefficient_dict = extract_coefficient_dict_from_equality(
equality_str=constraint_str
)
else:
coefficient_dict = extract_coefficient_dict_from_inequality(
inequality_str=constraint_str
)

# Iterate through the coefficients to extract the parameter names and weights
# and the bound
Expand All @@ -47,9 +74,12 @@ def __init__(self, inequality: str) -> None:
# Invert because we are "moving" the bound to the right hand side
self._bound = -1 * coefficient
else:
constraint_type = "equality" if self.is_equality else "inequality"
other_type = "inequality" if self.is_equality else "equality"
raise UserInputError(
"Only linear inequality parameter constraints are supported, found "
f"{inequality}"
f"Only linear {constraint_type} parameter constraints are "
f"supported, found {constraint_str}. "
f"Did you mean to use the `{other_type}` argument?"
)

@property
Expand All @@ -59,7 +89,7 @@ def constraint_dict(self) -> dict[str, float]:

@property
def bound(self) -> float:
"""Get bound of the inequality of the constraint."""
"""Get bound of the constraint."""
return self._bound

@bound.setter
Expand All @@ -71,7 +101,9 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool:
"""Whether or not the set of parameter values satisfies the constraint.

Does a weighted sum of the parameter values based on the constraint_dict
and checks that the sum is less than the bound.
and checks against the bound. For inequality constraints checks
``weighted_sum <= bound``; for equality constraints checks
``|weighted_sum - bound| <= tolerance``.

Args:
parameter_dict: Map from parameter name to parameter value.
Expand All @@ -87,13 +119,15 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool:
float(parameter_dict[param]) * weight
for param, weight in self.constraint_dict.items()
)
# Expected `int` for 2nd anonymous parameter to call `int.__le__` but got
# `float`.
if self.is_equality:
return abs(weighted_sum - self._bound) <= 1e-8
return weighted_sum <= self._bound + 1e-8 # allow for numerical imprecision

def clone(self) -> ParameterConstraint:
"""Clone."""
return ParameterConstraint(inequality=self._inequality_str)
if self.is_equality:
return ParameterConstraint(equality=self._constraint_str)
return ParameterConstraint(inequality=self._constraint_str)

def clone_with_transformed_parameters(
self, transformed_parameters: dict[str, Parameter]
Expand All @@ -102,10 +136,11 @@ def clone_with_transformed_parameters(
return self.clone()

def __repr__(self) -> str:
op = "==" if self.is_equality else "<="
return (
"ParameterConstraint("
+ " + ".join(f"{v}*{k}" for k, v in sorted(self.constraint_dict.items()))
+ f" <= {self._bound})"
+ f" {op} {self._bound})"
)

@property
Expand Down
70 changes: 70 additions & 0 deletions ax/core/tests/test_parameter_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def setUp(self) -> None:
super().setUp()
self.constraint = ParameterConstraint(inequality="2 * x - 3 * y <= 6.0")
self.constraint_repr = "ParameterConstraint(2.0*x + -3.0*y <= 6.0)"
self.eq_constraint = ParameterConstraint(equality="x + y == 1.0")
self.eq_constraint_repr = "ParameterConstraint(1.0*x + 1.0*y == 1.0)"

def test_constraint_dict_and_bounds(self) -> None:
constraint = ParameterConstraint(inequality="x1 + x2 <= 1")
Expand Down Expand Up @@ -123,6 +125,74 @@ def test_Sortable(self) -> None:
)
self.assertTrue(constraint1 < constraint2)

def test_equality_constraint_init(self) -> None:
cases = [
("x + y == 1.0", {"x": 1.0, "y": 1.0}, 1.0),
("2 * x + 3 * y == 5.0", {"x": 2.0, "y": 3.0}, 5.0),
("-x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0),
("- x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0),
]
for expr, expected_dict, expected_bound in cases:
with self.subTest(expr=expr):
c = ParameterConstraint(equality=expr)
self.assertTrue(c.is_equality)
self.assertEqual(c.constraint_dict, expected_dict)
self.assertEqual(c.bound, expected_bound)

c_ineq = ParameterConstraint(inequality="x + y <= 1.0")
self.assertFalse(c_ineq.is_equality)

def test_equality_constraint_must_provide_one(self) -> None:
# Neither provided
with self.assertRaisesRegex(UserInputError, "Exactly one"):
ParameterConstraint()

# Both provided
with self.assertRaisesRegex(UserInputError, "Exactly one"):
ParameterConstraint(inequality="x <= 1", equality="x == 1")

def test_equality_constraint_check(self) -> None:
c = ParameterConstraint(equality="x + y == 1.0")

cases = [
({"x": 0.5, "y": 0.5}, True, "exact"),
({"x": 0.5, "y": 0.5 + 0.5e-9}, True, "within tolerance"),
({"x": 0.5, "y": 0.6}, False, "violated above"),
({"x": 0.5, "y": 0.4}, False, "violated below"),
]
for params, expected, label in cases:
with self.subTest(label=label):
self.assertEqual(c.check(params), expected)

def test_equality_constraint_repr(self) -> None:
self.assertEqual(str(self.eq_constraint), self.eq_constraint_repr)

def test_equality_constraint_clone(self) -> None:
clone = self.eq_constraint.clone()
self.assertTrue(clone.is_equality)
self.assertEqual(clone.constraint_dict, self.eq_constraint.constraint_dict)
self.assertEqual(clone.bound, self.eq_constraint.bound)

# Mutation of clone doesn't affect original
clone._bound = 99.0
self.assertNotEqual(self.eq_constraint.bound, clone.bound)

def test_equality_constraint_clone_with_transformed_parameters(self) -> None:
clone = self.eq_constraint.clone_with_transformed_parameters(
transformed_parameters={}
)
self.assertTrue(clone.is_equality)
self.assertEqual(clone.bound, self.eq_constraint.bound)

def test_equality_constraint_eq(self) -> None:
c1 = ParameterConstraint(equality="x + y == 1.0")
c2 = ParameterConstraint(equality="x + y == 1.0")
self.assertEqual(c1, c2)

# Different from inequality with same coefficients
c3 = ParameterConstraint(inequality="x + y <= 1.0")
self.assertNotEqual(c1, c3)


class ValidateConstraintParametersTest(TestCase):
def test_validate_constraint_parameters(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def object_attribute_dicts_find_unequal_fields(
equal = one_val is other_val is None or (one_val.db_id == other_val.db_id)
elif field == "_db_id":
equal = skip_db_id_check or one_val == other_val
# Do not check the inequality_str for ParameterConstraints, checking the bound
# and coefficients dict is sufficient.
elif field == "_inequality_str":
# Do not check the constraint string for ParameterConstraints, checking
# the bound and coefficients dict is sufficient.
elif field == "_constraint_str":
equal = True
else:
equal = is_ax_equal(one_val, other_val)
Expand Down
35 changes: 35 additions & 0 deletions ax/utils/common/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,41 @@ def extract_coefficient_dict_from_inequality(
}


def extract_coefficient_dict_from_equality(
equality_str: str,
) -> dict[Symbol, float]:
"""
Parse a string of the form ``"expr == bound"`` into a coefficient dictionary.

SymPy's ``sympify`` does not produce an ``Equality`` from ``==`` (Python
evaluates ``==`` as a boolean), so we split the string on ``" == "`` and
sympify each side independently.

All terms are moved to the left side (``lhs - rhs``), and the result is
returned as a coefficient dictionary -- the same format as
``extract_coefficient_dict_from_inequality``.
"""
parts = equality_str.split("==")
if len(parts) != 2:
raise UserInputError(
f"Expected an equality constraint containing '==', found {equality_str}"
)

try:
lhs = sympify(sanitize_name(parts[0].strip(), sanitize_parens=True))
rhs = sympify(sanitize_name(parts[1].strip(), sanitize_parens=True))
except SympifyError:
raise UserInputError(f"Could not parse equality constraint: {equality_str}")

if not isinstance(lhs, Expr) or not isinstance(rhs, Expr):
raise UserInputError(f"Could not parse equality constraint: {equality_str}")

expression = lhs - rhs
return {
key: float(value) for key, value in expression.as_coefficients_dict().items()
}


def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]:
"""Sanitize and sympify an objective expression string.

Expand Down
51 changes: 51 additions & 0 deletions ax/utils/common/tests/test_sympy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.exceptions.core import UserInputError
from ax.utils.common.sympy import extract_coefficient_dict_from_equality
from ax.utils.common.testutils import TestCase
from sympy import Symbol


def _to_str_keys(d: dict[Symbol, float]) -> dict[str, float]:
return {str(k): v for k, v in d.items()}


class ExtractCoefficientDictFromEqualityTest(TestCase):
def test_valid_expressions(self) -> None:
cases = [
("x + y == 1.0", {"x": 1.0, "y": 1.0, "1": -1.0}),
("2 * x + 3 * y == 5.0", {"x": 2.0, "y": 3.0, "1": -5.0}),
("x == 3", {"x": 1.0, "1": -3.0}),
("-x + y == 0", {"x": -1.0, "y": 1.0}),
("- x + y == 1.0", {"x": -1.0, "y": 1.0, "1": -1.0}),
("x + y == 1", {"x": 1.0, "y": 1.0, "1": -1.0}),
]
for expr, expected in cases:
with self.subTest(expr=expr):
result = _to_str_keys(extract_coefficient_dict_from_equality(expr))
self.assertEqual(result, expected)

def test_sanitized_names(self) -> None:
result = _to_str_keys(
extract_coefficient_dict_from_equality("foo.bar + foo.baz == 1")
)
self.assertEqual(result.pop("1"), -1.0)
self.assertEqual(len(result), 2)

def test_error_missing_operator(self) -> None:
with self.assertRaisesRegex(UserInputError, "=="):
extract_coefficient_dict_from_equality("x + y <= 1")

def test_error_multiple_operators(self) -> None:
with self.assertRaisesRegex(UserInputError, "=="):
extract_coefficient_dict_from_equality("x == y == 1")

def test_error_empty_string(self) -> None:
with self.assertRaisesRegex(UserInputError, "=="):
extract_coefficient_dict_from_equality("")
Loading