From 0fa7084d706408bd95422a0af72db648d25ad09f Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 20 Apr 2026 09:40:02 -0700 Subject: [PATCH 1/2] Support equality parameter constraints in ParameterConstraint Summary: Add support for linear equality constraints (`w^T x == b`) alongside existing inequality constraints (`w^T x <= b`) in Ax's `ParameterConstraint` class. This is the first diff in a stack that threads equality constraints all the way down to BoTorch's `optimize_acqf`. Changes: - Add `extract_coefficient_dict_from_equality` to `ax/utils/common/sympy.py` for parsing `"expr == bound"` strings (SymPy can't parse `==` directly since Python evaluates it as a boolean). - Extend `ParameterConstraint.__init__` to accept `equality=` kwarg alongside existing `inequality=` kwarg. Exactly one must be provided. - Add `is_equality` property. - Update `check()` to use `|w^T x - b| <= tol` for equality constraints. - Update `__repr__`, `clone()`, `clone_with_transformed_parameters()`. - Add comprehensive tests for equality constraints. Differential Revision: D100256486 --- ax/core/parameter_constraint.py | 71 ++++++++++++++++------ ax/core/tests/test_parameter_constraint.py | 70 +++++++++++++++++++++ ax/utils/common/equality.py | 6 +- ax/utils/common/sympy.py | 35 +++++++++++ ax/utils/common/tests/test_sympy.py | 51 ++++++++++++++++ 5 files changed, 212 insertions(+), 21 deletions(-) create mode 100644 ax/utils/common/tests/test_sympy.py diff --git a/ax/core/parameter_constraint.py b/ax/core/parameter_constraint.py index bfdb8790859..cb55a9d209f 100644 --- a/ax/core/parameter_constraint.py +++ b/ax/core/parameter_constraint.py @@ -14,27 +14,54 @@ from ax.exceptions.core import UserInputError from ax.utils.common.base import SortableBase from ax.utils.common.string_utils import unsanitize_name -from ax.utils.common.sympy import extract_coefficient_dict_from_inequality +from ax.utils.common.sympy import ( + extract_coefficient_dict_from_equality, + extract_coefficient_dict_from_inequality, +) +from pyre_extensions import none_throws class ParameterConstraint(SortableBase): """Base class for linear parameter constraints. - Constraints are expressed as a SymPy parsable inequality string. + Supports both inequality constraints (``w^T x <= b``) and equality + constraints (``w^T x == b``). Exactly one of ``inequality`` or + ``equality`` must be provided. """ - def __init__(self, inequality: str) -> None: - """Initialize ParameterConstraint + def __init__( + self, + inequality: str | None = None, + *, + equality: str | None = None, + ) -> None: + """Initialize ParameterConstraint. Args: - inequality: String representation of the constraint. At this point in time - Ax only accepts linear inequality constraints. + inequality: String representation of a linear inequality + constraint, e.g. ``"x1 + x2 <= 3"``. + equality: String representation of a linear equality + constraint, e.g. ``"x1 + x2 == 3"``. + + Exactly one of ``inequality`` or ``equality`` must be provided. """ - self._inequality_str = inequality + if (inequality is None) == (equality is None): + raise UserInputError( + "Exactly one of `inequality` or `equality` must be provided." + ) - coefficient_dict = extract_coefficient_dict_from_inequality( - inequality_str=inequality - ) + self.is_equality: bool = equality is not None + constraint_str = none_throws(equality if self.is_equality else inequality) + self._constraint_str: str = constraint_str + + if self.is_equality: + coefficient_dict = extract_coefficient_dict_from_equality( + equality_str=constraint_str + ) + else: + coefficient_dict = extract_coefficient_dict_from_inequality( + inequality_str=constraint_str + ) # Iterate through the coefficients to extract the parameter names and weights # and the bound @@ -47,9 +74,12 @@ def __init__(self, inequality: str) -> None: # Invert because we are "moving" the bound to the right hand side self._bound = -1 * coefficient else: + constraint_type = "equality" if self.is_equality else "inequality" + other_type = "inequality" if self.is_equality else "equality" raise UserInputError( - "Only linear inequality parameter constraints are supported, found " - f"{inequality}" + f"Only linear {constraint_type} parameter constraints are " + f"supported, found {constraint_str}. " + f"Did you mean to use the `{other_type}` argument?" ) @property @@ -59,7 +89,7 @@ def constraint_dict(self) -> dict[str, float]: @property def bound(self) -> float: - """Get bound of the inequality of the constraint.""" + """Get bound of the constraint.""" return self._bound @bound.setter @@ -71,7 +101,9 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool: """Whether or not the set of parameter values satisfies the constraint. Does a weighted sum of the parameter values based on the constraint_dict - and checks that the sum is less than the bound. + and checks against the bound. For inequality constraints checks + ``weighted_sum <= bound``; for equality constraints checks + ``|weighted_sum - bound| <= tolerance``. Args: parameter_dict: Map from parameter name to parameter value. @@ -87,13 +119,15 @@ def check(self, parameter_dict: dict[str, int | float]) -> bool: float(parameter_dict[param]) * weight for param, weight in self.constraint_dict.items() ) - # Expected `int` for 2nd anonymous parameter to call `int.__le__` but got - # `float`. + if self.is_equality: + return abs(weighted_sum - self._bound) <= 1e-8 return weighted_sum <= self._bound + 1e-8 # allow for numerical imprecision def clone(self) -> ParameterConstraint: """Clone.""" - return ParameterConstraint(inequality=self._inequality_str) + if self.is_equality: + return ParameterConstraint(equality=self._constraint_str) + return ParameterConstraint(inequality=self._constraint_str) def clone_with_transformed_parameters( self, transformed_parameters: dict[str, Parameter] @@ -102,10 +136,11 @@ def clone_with_transformed_parameters( return self.clone() def __repr__(self) -> str: + op = "==" if self.is_equality else "<=" return ( "ParameterConstraint(" + " + ".join(f"{v}*{k}" for k, v in sorted(self.constraint_dict.items())) - + f" <= {self._bound})" + + f" {op} {self._bound})" ) @property diff --git a/ax/core/tests/test_parameter_constraint.py b/ax/core/tests/test_parameter_constraint.py index 30bc1be4c2c..8425f7a1d02 100644 --- a/ax/core/tests/test_parameter_constraint.py +++ b/ax/core/tests/test_parameter_constraint.py @@ -20,6 +20,8 @@ def setUp(self) -> None: super().setUp() self.constraint = ParameterConstraint(inequality="2 * x - 3 * y <= 6.0") self.constraint_repr = "ParameterConstraint(2.0*x + -3.0*y <= 6.0)" + self.eq_constraint = ParameterConstraint(equality="x + y == 1.0") + self.eq_constraint_repr = "ParameterConstraint(1.0*x + 1.0*y == 1.0)" def test_constraint_dict_and_bounds(self) -> None: constraint = ParameterConstraint(inequality="x1 + x2 <= 1") @@ -123,6 +125,74 @@ def test_Sortable(self) -> None: ) self.assertTrue(constraint1 < constraint2) + def test_equality_constraint_init(self) -> None: + cases = [ + ("x + y == 1.0", {"x": 1.0, "y": 1.0}, 1.0), + ("2 * x + 3 * y == 5.0", {"x": 2.0, "y": 3.0}, 5.0), + ("-x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0), + ("- x + y == 1.0", {"x": -1.0, "y": 1.0}, 1.0), + ] + for expr, expected_dict, expected_bound in cases: + with self.subTest(expr=expr): + c = ParameterConstraint(equality=expr) + self.assertTrue(c.is_equality) + self.assertEqual(c.constraint_dict, expected_dict) + self.assertEqual(c.bound, expected_bound) + + c_ineq = ParameterConstraint(inequality="x + y <= 1.0") + self.assertFalse(c_ineq.is_equality) + + def test_equality_constraint_must_provide_one(self) -> None: + # Neither provided + with self.assertRaisesRegex(UserInputError, "Exactly one"): + ParameterConstraint() + + # Both provided + with self.assertRaisesRegex(UserInputError, "Exactly one"): + ParameterConstraint(inequality="x <= 1", equality="x == 1") + + def test_equality_constraint_check(self) -> None: + c = ParameterConstraint(equality="x + y == 1.0") + + cases = [ + ({"x": 0.5, "y": 0.5}, True, "exact"), + ({"x": 0.5, "y": 0.5 + 0.5e-9}, True, "within tolerance"), + ({"x": 0.5, "y": 0.6}, False, "violated above"), + ({"x": 0.5, "y": 0.4}, False, "violated below"), + ] + for params, expected, label in cases: + with self.subTest(label=label): + self.assertEqual(c.check(params), expected) + + def test_equality_constraint_repr(self) -> None: + self.assertEqual(str(self.eq_constraint), self.eq_constraint_repr) + + def test_equality_constraint_clone(self) -> None: + clone = self.eq_constraint.clone() + self.assertTrue(clone.is_equality) + self.assertEqual(clone.constraint_dict, self.eq_constraint.constraint_dict) + self.assertEqual(clone.bound, self.eq_constraint.bound) + + # Mutation of clone doesn't affect original + clone._bound = 99.0 + self.assertNotEqual(self.eq_constraint.bound, clone.bound) + + def test_equality_constraint_clone_with_transformed_parameters(self) -> None: + clone = self.eq_constraint.clone_with_transformed_parameters( + transformed_parameters={} + ) + self.assertTrue(clone.is_equality) + self.assertEqual(clone.bound, self.eq_constraint.bound) + + def test_equality_constraint_eq(self) -> None: + c1 = ParameterConstraint(equality="x + y == 1.0") + c2 = ParameterConstraint(equality="x + y == 1.0") + self.assertEqual(c1, c2) + + # Different from inequality with same coefficients + c3 = ParameterConstraint(inequality="x + y <= 1.0") + self.assertNotEqual(c1, c3) + class ValidateConstraintParametersTest(TestCase): def test_validate_constraint_parameters(self) -> None: diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index 595bed819e4..38088ad60e7 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -207,9 +207,9 @@ def object_attribute_dicts_find_unequal_fields( equal = one_val is other_val is None or (one_val.db_id == other_val.db_id) elif field == "_db_id": equal = skip_db_id_check or one_val == other_val - # Do not check the inequality_str for ParameterConstraints, checking the bound - # and coefficients dict is sufficient. - elif field == "_inequality_str": + # Do not check the constraint string for ParameterConstraints, checking + # the bound and coefficients dict is sufficient. + elif field == "_constraint_str": equal = True else: equal = is_ax_equal(one_val, other_val) diff --git a/ax/utils/common/sympy.py b/ax/utils/common/sympy.py index 4eeffa624fd..4a870f6a321 100644 --- a/ax/utils/common/sympy.py +++ b/ax/utils/common/sympy.py @@ -45,6 +45,41 @@ def extract_coefficient_dict_from_inequality( } +def extract_coefficient_dict_from_equality( + equality_str: str, +) -> dict[Symbol, float]: + """ + Parse a string of the form ``"expr == bound"`` into a coefficient dictionary. + + SymPy's ``sympify`` does not produce an ``Equality`` from ``==`` (Python + evaluates ``==`` as a boolean), so we split the string on ``" == "`` and + sympify each side independently. + + All terms are moved to the left side (``lhs - rhs``), and the result is + returned as a coefficient dictionary -- the same format as + ``extract_coefficient_dict_from_inequality``. + """ + parts = equality_str.split("==") + if len(parts) != 2: + raise UserInputError( + f"Expected an equality constraint containing '==', found {equality_str}" + ) + + try: + lhs = sympify(sanitize_name(parts[0].strip(), sanitize_parens=True)) + rhs = sympify(sanitize_name(parts[1].strip(), sanitize_parens=True)) + except SympifyError: + raise UserInputError(f"Could not parse equality constraint: {equality_str}") + + if not isinstance(lhs, Expr) or not isinstance(rhs, Expr): + raise UserInputError(f"Could not parse equality constraint: {equality_str}") + + expression = lhs - rhs + return { + key: float(value) for key, value in expression.as_coefficients_dict().items() + } + + def parse_objective_expression(expression_str: str) -> Expr | tuple[Expr, ...]: """Sanitize and sympify an objective expression string. diff --git a/ax/utils/common/tests/test_sympy.py b/ax/utils/common/tests/test_sympy.py new file mode 100644 index 00000000000..f740b692710 --- /dev/null +++ b/ax/utils/common/tests/test_sympy.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.exceptions.core import UserInputError +from ax.utils.common.sympy import extract_coefficient_dict_from_equality +from ax.utils.common.testutils import TestCase +from sympy import Symbol + + +def _to_str_keys(d: dict[Symbol, float]) -> dict[str, float]: + return {str(k): v for k, v in d.items()} + + +class ExtractCoefficientDictFromEqualityTest(TestCase): + def test_valid_expressions(self) -> None: + cases = [ + ("x + y == 1.0", {"x": 1.0, "y": 1.0, "1": -1.0}), + ("2 * x + 3 * y == 5.0", {"x": 2.0, "y": 3.0, "1": -5.0}), + ("x == 3", {"x": 1.0, "1": -3.0}), + ("-x + y == 0", {"x": -1.0, "y": 1.0}), + ("- x + y == 1.0", {"x": -1.0, "y": 1.0, "1": -1.0}), + ("x + y == 1", {"x": 1.0, "y": 1.0, "1": -1.0}), + ] + for expr, expected in cases: + with self.subTest(expr=expr): + result = _to_str_keys(extract_coefficient_dict_from_equality(expr)) + self.assertEqual(result, expected) + + def test_sanitized_names(self) -> None: + result = _to_str_keys( + extract_coefficient_dict_from_equality("foo.bar + foo.baz == 1") + ) + self.assertEqual(result.pop("1"), -1.0) + self.assertEqual(len(result), 2) + + def test_error_missing_operator(self) -> None: + with self.assertRaisesRegex(UserInputError, "=="): + extract_coefficient_dict_from_equality("x + y <= 1") + + def test_error_multiple_operators(self) -> None: + with self.assertRaisesRegex(UserInputError, "=="): + extract_coefficient_dict_from_equality("x == y == 1") + + def test_error_empty_string(self) -> None: + with self.assertRaisesRegex(UserInputError, "=="): + extract_coefficient_dict_from_equality("") From e715983c79932e8cee205c60dcbcff484a2debec Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 20 Apr 2026 09:45:25 -0700 Subject: [PATCH 2/2] Support equality constraints in service layer string parsing (#5174) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5174 Add support for parsing equality constraint strings (e.g. `"x1 + x2 == 3"`) in `constraint_from_str`. This extends the existing `<=`/`>=` parsing to also accept `==` as a comparison operator. - Add `_process_equality_constraint` function (analogous to `_process_linear_constraint`) that constructs `ParameterConstraint(equality=...)`. - Detect `==` in `constraint_from_str` and route to the new function. - Reject equality order constraints (`"x1 == x2"`) with a clear error message. - Update `INVALID_CONSTRAINT_ERROR_MSG` to document `==` support. Reviewed By: bletham Differential Revision: D100256487 --- ax/service/tests/test_instantiation_utils.py | 39 +++++- ax/service/utils/instantiation.py | 123 ++++++++++++++++--- 2 files changed, 139 insertions(+), 23 deletions(-) diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index f9b66862cc4..4c7f53ce511 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -115,12 +115,7 @@ def test_constraint_from_str(self) -> None: ) with self.assertRaisesRegex( ValueError, - ( - r"Received invalid parameter constraint format: " - r"`x1 \+ x2 \+ x3 = 3`\. " - r"Please use one of the following forms:\n" - r".*\n.*\n.*\nAcceptable comparison operators are \">=\" and \"<=\"\." - ), + r"Received invalid parameter constraint format: `x1 \+ x2 \+ x3 = 3`", ): InstantiationBase.constraint_from_str( "x1 + x2 + x3 = 3", {"x1": x1, "x2": x2, "x3": x3} @@ -193,6 +188,38 @@ def test_constraint_from_str(self) -> None: "x1 + x2 / 2.0 + x3 >= 3", {"x1": x1, "x2": x2, "x3": x3} ) + # --- Equality constraints --- + eq_constraint = InstantiationBase.constraint_from_str( + "x1 + x2 == 1", {"x1": x1, "x2": x2} + ) + self.assertTrue(eq_constraint.is_equality) + self.assertEqual(eq_constraint.constraint_dict, {"x1": 1.0, "x2": 1.0}) + self.assertEqual(eq_constraint.bound, 1.0) + + # Weighted equality + eq_weighted = InstantiationBase.constraint_from_str( + "2*x1 + 3*x2 == 5", {"x1": x1, "x2": x2} + ) + self.assertTrue(eq_weighted.is_equality) + self.assertEqual(eq_weighted.constraint_dict, {"x1": 2.0, "x2": 3.0}) + self.assertEqual(eq_weighted.bound, 5.0) + + # Single parameter equality + eq_single = InstantiationBase.constraint_from_str( + "x1 == 3", {"x1": x1, "x2": x2} + ) + self.assertTrue(eq_single.is_equality) + self.assertEqual(eq_single.constraint_dict, {"x1": 1.0}) + self.assertEqual(eq_single.bound, 3.0) + + # Order equality constraint should error + with self.assertRaisesRegex(ValueError, "DerivedParameter"): + InstantiationBase.constraint_from_str("x1 == x2", {"x1": x1, "x2": x2}) + + # Linear equality that equates two params should also error + with self.assertRaisesRegex(ValueError, "DerivedParameter"): + InstantiationBase.constraint_from_str("x1 - x2 == 0", {"x1": x1, "x2": x2}) + def test_spaces_in_metric_and_parameter_names(self) -> None: # Metric and parameter names with spaces are allowed everywhere # except in constraint string parsing, where split() would break. diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 0b9f017cf76..6e2399967de 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -86,6 +86,8 @@ } +COMPARISON_OPS_WITH_EQ: set[str] = {"<=", ">=", "=="} + INVALID_CONSTRAINT_ERROR_MSG = ( "Received invalid parameter constraint format: `{}`. " "Please use one of the following forms:\n" @@ -96,7 +98,10 @@ "* Weighted linear constraints: `* >= ` or " "`* + * <= `, where you can add one or more weighted terms on " "the left side, and there should be no spaces between weights and parameter " - 'names.\nAcceptable comparison operators are ">=" and "<=".' + "names.\n" + "* Equality constraints: ` + == ` or `* + * == `, " + "same as linear constraints but using `==` instead of `<=` or `>=`.\n" + 'Acceptable comparison operators are ">=", "<=", and "==".' ) @@ -440,27 +445,42 @@ def constraint_from_str( last_token_is_numeric = True except ValueError: last_token_is_numeric = False + + # Identify the comparison operator (second-to-last for linear, middle + # for order constraints). is_order_constraint = ( len(tokens) == 3 - and tokens[1] in COMPARISON_OPS + and tokens[1] in COMPARISON_OPS_WITH_EQ and not last_token_is_numeric ) is_linear_constraint = ( - # if len == 3, then this is a single parameter bound constraint, otherwise - # it corresponds to a numerical bound on a sum of parameters + # if len == 3, then this is a single parameter bound constraint, + # otherwise it corresponds to a numerical bound on a sum of + # parameters len(tokens) >= 3 and len(tokens) % 2 == 1 - and tokens[-2] in COMPARISON_OPS + and tokens[-2] in COMPARISON_OPS_WITH_EQ and last_token_is_numeric ) if is_order_constraint: # e.g. "x1 >= x2" + if tokens[1] == "==": + raise ValueError( + "Equality order constraints (e.g. 'x1 == x2') are not " + "supported. Use a DerivedParameter to express that two " + "parameters must be equal." + ) return _process_order_constraint( tokens=tokens, parameters=parameters, ) - if is_linear_constraint: # e.g. "x1 + x2 >= 3" + if is_linear_constraint: # e.g. "x1 + x2 >= 3" or "x1 + x2 == 3" + if tokens[-2] == "==": + return _process_equality_constraint( + tokens=tokens, + parameters=parameters, + ) return _process_linear_constraint( tokens=tokens, parameters=parameters, @@ -1042,32 +1062,41 @@ def _process_order_constraint( ) -def _process_linear_constraint( +def _parse_linear_constraint_tokens( tokens: Sequence[str], parameters: Mapping[str, Parameter], -) -> ParameterConstraint: - """Processes a linear constraint, e.g. "x1 + x2 <= 3". The last token is expected - to be a numeric constant, and the other tokens are expected to be parameters, their - multiplicative coefficients (e.g."2.5*x1") and "+" or "-" operators (e.g. "+"). + operator_str: str, +) -> tuple[dict[str, float], float]: + """Parse tokens of a linear constraint into parameter weights and bound. + + Shared helper for ``_process_linear_constraint`` and + ``_process_equality_constraint``. Validates ``*`` placement, processes + alternating monomials / operators, and returns the raw + ``parameter_weights`` dict and ``bound``. Args: tokens: A list of tokens in the constraint string. parameters: A mapping from parameter names to their definitions. + operator_str: The comparison operator string (e.g. ``"<="``/``">="`` + /``"=="``), used only for error messages. Returns: - A ParameterConstraint object representing the linear constraint. + A tuple of (parameter_weights, bound). """ parameter_names = parameters.keys() bound = float(tokens[-1]) if any(token[0] == "*" or token[-1] == "*" for token in tokens): raise ValueError( - "A linear constraint should be the form a*x + b*y - c*z <= d" - ", where a,b,c,d are float constants and x,y,z are parameters. " - "There should be no space in each term around the operator `*`, and " - "there should be a single space around each operator +, -, <= and >=." + f"A linear constraint should be the form " + f"a*x + b*y - c*z {operator_str} d" + ", where a,b,c,d are float constants and x,y,z are " + "parameters. There should be no space in each term " + "around the operator `*`, and there should be a " + f"single space around each operator +, -, " + f"and {operator_str}." ) - parameter_weights = {} + parameter_weights: dict[str, float] = {} current_sign = 1.0 # Determines whether the operator is + or - # tokens are alternating monomials and operators for idx, token in enumerate(tokens[:-2]): @@ -1091,6 +1120,27 @@ def _process_linear_constraint( raise ValueError( f"Expected a mixed constraint, found operator `{token}`." ) + return parameter_weights, bound + + +def _process_linear_constraint( + tokens: Sequence[str], + parameters: Mapping[str, Parameter], +) -> ParameterConstraint: + """Processes a linear constraint, e.g. "x1 + x2 <= 3". The last token is expected + to be a numeric constant, and the other tokens are expected to be parameters, their + multiplicative coefficients (e.g."2.5*x1") and "+" or "-" operators (e.g. "+"). + + Args: + tokens: A list of tokens in the constraint string. + parameters: A mapping from parameter names to their definitions. + + Returns: + A ParameterConstraint object representing the linear constraint. + """ + parameter_weights, bound = _parse_linear_constraint_tokens( + tokens=tokens, parameters=parameters, operator_str="<= or >=" + ) # tokens[-2] is checked to be either LEQ or GEQ if sum_const is True comparison_multiplier = ( 1.0 if COMPARISON_OPS[tokens[-2]] is ComparisonOp.LEQ else -1.0 @@ -1104,6 +1154,45 @@ def _process_linear_constraint( ) +def _process_equality_constraint( + tokens: Sequence[str], + parameters: Mapping[str, Parameter], +) -> ParameterConstraint: + """Processes a linear equality constraint, e.g. "x1 + x2 == 3". + + The last token is expected to be a numeric constant, the second-to-last + is ``"=="``, and the other tokens are parameters, their multiplicative + coefficients (e.g. ``"2.5*x1"``) and ``"+"`` or ``"-"`` operators. + + Args: + tokens: A list of tokens in the constraint string. + parameters: A mapping from parameter names to their definitions. + + Returns: + A ParameterConstraint with ``equality=...``. + """ + parameter_weights, bound = _parse_linear_constraint_tokens( + tokens=tokens, parameters=parameters, operator_str="==" + ) + # Reject equality constraints that equate two parameters + # (e.g. "x1 - x2 == 0"). DerivedParameter is the correct tool. + if ( + bound == 0.0 + and len(parameter_weights) == 2 + and set(parameter_weights.values()) == {1.0, -1.0} + ): + params = list(parameter_weights.keys()) + raise ValueError( + f"Equality constraint '{' '.join(tokens)}' is equivalent to " + f"'{params[0]} == {params[1]}'. Use a DerivedParameter to " + "express that two parameters must be equal." + ) + expr = " + ".join( + f"{coeff} * {param}" for param, coeff in parameter_weights.items() + ) + return ParameterConstraint(equality=f"{expr} == {bound}") + + def _process_monomial(monomial_str: str) -> tuple[float, str]: """Process a monomial in a linear constraint.