Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
OptimizationConfig,
)
from ax.core.parameter import DerivedParameter, Parameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.runner import Runner
from ax.core.search_space import SearchSpace
from ax.core.trial import Trial
Expand Down Expand Up @@ -362,6 +363,7 @@ def add_parameters_to_search_space(
self,
parameters: Sequence[Parameter],
status_quo_values: TParameterization | None = None,
parameter_constraints: Sequence[ParameterConstraint] | None = None,
) -> None:
"""
Add new parameters to the experiment's search space. This allows extending
Expand All @@ -376,6 +378,8 @@ def add_parameters_to_search_space(
space.
status_quo_values: Optional parameter values for the new parameters to
use in the status quo (baseline) arm, if one is defined.
parameter_constraints: Optional sequence of typed ParameterConstraint
objects to add to the search space after the parameters are added.
"""
status_quo_values = status_quo_values or {}

Expand Down Expand Up @@ -429,6 +433,10 @@ def add_parameters_to_search_space(
# Add parameters to search space
self._search_space.add_parameters(parameters)

# Add parameter constraints to search space
if parameter_constraints:
self._search_space.add_parameter_constraints(list(parameter_constraints))

def disable_parameters_in_search_space(
self, default_parameter_values: TParameterization
) -> None:
Expand Down
26 changes: 26 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ParameterType,
RangeParameter,
)
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.search_space import SearchSpace
from ax.core.types import ComparisonOp
from ax.exceptions.core import (
Expand Down Expand Up @@ -427,6 +428,31 @@ def test_add_search_space_parameters(self) -> None:
self.assertIn("new_param", experiment.status_quo.parameters)
self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0)

with self.subTest("Add parameter with parameter constraints"):
experiment = self.experiment.clone_with(trial_indices=[])
num_existing_constraints = len(
experiment.search_space.parameter_constraints
)
constraint = ParameterConstraint(
inequality="new_param + w <= 5.0",
)
experiment.add_parameters_to_search_space(
parameters=[new_param],
status_quo_values={new_param.name: 0.0},
parameter_constraints=[constraint],
)
# Verify parameter was added
self.assertIn("new_param", experiment.search_space.parameters)
# Verify constraint was added
self.assertEqual(
len(experiment.search_space.parameter_constraints),
num_existing_constraints + 1,
)
added_constraint = experiment.search_space.parameter_constraints[-1]
self.assertIn("new_param", added_constraint.constraint_dict)
self.assertIn("w", added_constraint.constraint_dict)
self.assertEqual(added_constraint.bound, 5.0)

def test_add_derived_parameter_to_search_space_with_trials(self) -> None:
"""Test adding DerivedParameters to an experiment that has existing trials.

Expand Down
24 changes: 23 additions & 1 deletion ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ax.core.objective import Objective
from ax.core.observation import ObservationFeatures
from ax.core.parameter import RangeParameter
from ax.core.parameter_constraint import ParameterConstraint
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
Expand Down Expand Up @@ -555,6 +556,7 @@ def add_parameters(
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
backfill_values: TParameterization,
status_quo_values: TParameterization | None = None,
parameter_constraints: list[str] | None = None,
) -> None:
"""
Add new parameters to the experiment's search space. This allows extending
Expand All @@ -574,6 +576,10 @@ def add_parameters(
status_quo_values: Optional parameter values for the new parameters to
use in the status quo (baseline) arm, if one is defined. If None,
the backfill values will be used for the status quo.
parameter_constraints: Optional list of string representations of
parameter constraints to add (e.g., ``"x1 + x2 <= 5.0"``
or ``"x1 <= x2"``). May reference both existing and new
parameters.
"""
parameters_to_add = [
parameter_from_config(parameter_config) for parameter_config in parameters
Expand All @@ -594,9 +600,25 @@ def add_parameters(
for parameter in parameters_to_add:
if parameter.name in backfill_values:
parameter._backfill_value = backfill_values[parameter.name]

# Convert string constraints to typed ParameterConstraint objects.
typed_parameter_constraints: list[ParameterConstraint] = []
if parameter_constraints:
# Build a parameter map with both existing and new parameters so
# constraints can reference either.
parameter_map = {
**self.experiment.search_space.parameters,
**{p.name: p for p in parameters_to_add},
}
typed_parameter_constraints = [
InstantiationBase.constraint_from_str(c, parameter_map)
for c in parameter_constraints
]

self.experiment.add_parameters_to_search_space(
parameters=parameters_to_add,
status_quo_values=status_quo_values,
status_quo_values=status_quo_values or backfill_values,
parameter_constraints=typed_parameter_constraints or None,
)
self._save_experiment_to_db_if_possible(experiment=self.experiment)

Expand Down
138 changes: 138 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,144 @@ def test_add_parameters(self) -> None:
assert isinstance(param_x3, ChoiceParameter)
self.assertEqual(param_x3.values, ["a", "b", "c"])

def test_add_parameters_backfill_values_used_for_status_quo(self) -> None:
"""Test that backfill_values are used for the status quo arm when
status_quo_values is not provided.
"""
ax_client = AxClient()
ax_client.create_experiment(
name="test_experiment",
parameters=[
{
"name": "x1",
"type": "range",
"bounds": [0.0, 1.0],
"value_type": "float",
},
],
status_quo={"x1": 0.5},
is_test=True,
immutable_search_space_and_opt_config=False,
)

ax_client.add_parameters(
parameters=[
RangeParameterConfig(
name="x2",
bounds=(0.0, 10.0),
parameter_type="float",
),
ChoiceParameterConfig(
name="x3",
values=["a", "b", "c"],
parameter_type="str",
),
],
backfill_values={"x2": 5.0, "x3": "a"},
)

# Verify the status quo arm was updated with backfill_values
status_quo = ax_client.experiment.status_quo
self.assertIsNotNone(status_quo)
assert status_quo is not None
self.assertEqual(status_quo.parameters["x1"], 0.5)
self.assertEqual(status_quo.parameters["x2"], 5.0)
self.assertEqual(status_quo.parameters["x3"], "a")

def test_add_parameters_with_constraints(self) -> None:
"""Test that add_parameters correctly adds parameter constraints."""
ax_client = AxClient()
ax_client.create_experiment(
name="test_experiment",
parameters=[
{
"name": "x1",
"type": "range",
"bounds": [0.0, 10.0],
"value_type": "float",
},
],
is_test=True,
immutable_search_space_and_opt_config=False,
)

with self.subTest("Sum constraint on new parameters"):
ax_client.add_parameters(
parameters=[
RangeParameterConfig(
name="x2",
bounds=(0.0, 10.0),
parameter_type="float",
),
],
backfill_values={"x2": 5.0},
parameter_constraints=["x1 + x2 <= 5.0"],
)
search_space = ax_client.experiment.search_space
self.assertIn("x2", search_space.parameters)
self.assertEqual(len(search_space.parameter_constraints), 1)
constraint = search_space.parameter_constraints[0]
self.assertIn("x1", constraint.constraint_dict)
self.assertIn("x2", constraint.constraint_dict)
self.assertEqual(constraint.bound, 5.0)

with self.subTest("Order constraint referencing existing and new parameter"):
ax_client_2 = AxClient()
ax_client_2.create_experiment(
name="test_experiment_2",
parameters=[
{
"name": "x1",
"type": "range",
"bounds": [0.0, 10.0],
"value_type": "float",
},
],
is_test=True,
immutable_search_space_and_opt_config=False,
)
ax_client_2.add_parameters(
parameters=[
RangeParameterConfig(
name="x2",
bounds=(0.0, 10.0),
parameter_type="float",
),
],
backfill_values={"x2": 5.0},
parameter_constraints=["x1 <= x2"],
)
search_space = ax_client_2.experiment.search_space
self.assertEqual(len(search_space.parameter_constraints), 1)

with self.subTest("Constraint referencing non-existent parameter"):
ax_client_3 = AxClient()
ax_client_3.create_experiment(
name="test_experiment_3",
parameters=[
{
"name": "x1",
"type": "range",
"bounds": [0.0, 10.0],
"value_type": "float",
},
],
is_test=True,
immutable_search_space_and_opt_config=False,
)
with self.assertRaises(ValueError):
ax_client_3.add_parameters(
parameters=[
RangeParameterConfig(
name="x2",
bounds=(0.0, 10.0),
parameter_type="float",
),
],
backfill_values={"x2": 5.0},
parameter_constraints=["x1 + nonexistent <= 5.0"],
)

def test_disable_parameters(self) -> None:
"""Test that disable_parameters correctly disables parameters in the search
space."""
Expand Down