Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 11 additions & 26 deletions ax/adapter/tests/test_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,33 +327,18 @@ def get_adapter(min: float, max: float) -> TorchAdapter:

def test_arm_to_np_array(self) -> None:
    """Extracting a target point from an arm with valid parameters."""
    # Arm carrying the target parameter values.
    arm = Arm(parameters={"x1": 0.5, "x2": 1.5, "x3": 2.5})
    param_names = ["x1", "x2", "x3"]

    result = arm_to_np_array(arm=arm, parameters=param_names)

    # Values should come back in the order the parameter names were given.
    self.assertIsNotNone(result)
    np.testing.assert_array_equal(result, np.array([0.5, 1.5, 2.5]))

def test_extract_arm_to_np_array_different_parameter_order(self) -> None:
    """The `parameters` list, not the arm's dict order, drives output order."""
    arm = Arm(parameters={"x1": 0.5, "x2": 1.5, "x3": 2.5})
    # Request the values in a shuffled order relative to the arm's dict.
    ordering = ["x3", "x1", "x2"]

    result = arm_to_np_array(arm=arm, parameters=ordering)

    self.assertIsNotNone(result)
    # Output follows the requested ordering.
    np.testing.assert_array_equal(result, np.array([2.5, 0.5, 1.5]))
cases = [
# Values extracted in natural parameter order
(["x1", "x2", "x3"], np.array([0.5, 1.5, 2.5])),
# Values extracted in a different parameter order
(["x3", "x1", "x2"], np.array([2.5, 0.5, 1.5])),
]
for parameters, expected in cases:
with self.subTest(parameters=parameters):
actual = arm_to_np_array(arm=target_arm, parameters=parameters)
self.assertIsNotNone(actual)
np.testing.assert_array_equal(actual, expected)

def test_arm_to_np_array_none(self) -> None:
# Test that None is returned when target_arm is None
Expand Down
10 changes: 5 additions & 5 deletions ax/adapter/tests/test_base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,11 @@ def _test_init_with_data(self, multi_objective: bool) -> None:
search_space=search_space, experiment_data=experiment_data
)

def test_init_with_data_single_objective(self) -> None:
    """Single-objective case of the shared init-with-data check."""
    self._test_init_with_data(multi_objective=False)

def test_init_with_data_multi_objective(self) -> None:
    """Multi-objective case of the shared init-with-data check."""
    self._test_init_with_data(multi_objective=True)
def test_init_with_data(self) -> None:
    """Cover init-with-data for single- and multi-objective setups."""
    for multi_objective in (False, True):
        with self.subTest(multi_objective=multi_objective):
            self._test_init_with_data(multi_objective=multi_objective)

def test_fit_tracking_metrics(self) -> None:
# Test error when fit_tracking_metrics is False and optimization
Expand Down
33 changes: 16 additions & 17 deletions ax/adapter/tests/test_hierarchical_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,19 @@ def _base_test_predict_and_cv(
cv_res = cross_validate(adapter=mbm)
self.assertEqual(len(cv_res), len(experiment.trials))

def test_with_non_hierarchical_hss(self) -> None:
    """Gen plus predict/CV round-trip on the non-hierarchical HSS."""
    exp = self._test_gen_base(
        hss=self.non_hierarchical_hss, expected_num_candidate_params=[3]
    )
    self._base_test_predict_and_cv(experiment=exp)

def test_with_simple_hss(self) -> None:
    """Gen plus predict/CV round-trip on the simple HSS."""
    exp = self._test_gen_base(
        hss=self.simple_hss, expected_num_candidate_params=[2]
    )
    self._base_test_predict_and_cv(experiment=exp)

def test_with_complex_hss(self) -> None:
    """Gen plus predict/CV round-trip on the complex HSS."""
    exp = self._test_gen_base(
        hss=self.complex_hss, expected_num_candidate_params=[2, 4, 5]
    )
    self._base_test_predict_and_cv(experiment=exp)
def test_with_hss_variants(self) -> None:
    """Run the gen plus predict/CV round-trip across all HSS variants."""
    # (label, search space, expected candidate-parameter counts per node)
    variants = [
        ("non_hierarchical", self.non_hierarchical_hss, [3]),
        ("simple", self.simple_hss, [2]),
        ("complex", self.complex_hss, [2, 4, 5]),
    ]
    for label, hss, expected_counts in variants:
        with self.subTest(hss_variant=label):
            experiment = self._test_gen_base(
                hss=hss, expected_num_candidate_params=expected_counts
            )
            self._base_test_predict_and_cv(experiment=experiment)
132 changes: 53 additions & 79 deletions ax/adapter/tests/test_torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,90 +1256,35 @@ def test_pairwise_preference_generator(self) -> None:
X=X.expand(2, *X.shape), Y=comp_pair_Y.expand(2, *comp_pair_Y.shape)
)

def test_get_transformed_model_gen_args_with_target_point(self) -> None:
# Test that _get_transformed_model_gen_args correctly processes target_point

# Setup: create adapter with target arm in optimization config
experiment = get_branin_experiment(with_completed_trial=True)
pruning_target_parameterization = Arm(parameters={"x1": -5.0, "x2": 15.0})
optimization_config = none_throws(
experiment.optimization_config
).clone_with_args(
pruning_target_parameterization=pruning_target_parameterization
)

adapter = TorchAdapter(
generator=TorchGenerator(),
experiment=experiment,
transforms=Cont_X_trans,
)

# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
base_gen_args = adapter._get_transformed_gen_args(
search_space=experiment.search_space,
optimization_config=optimization_config,
pending_observations={},
)

search_space_digest, torch_opt_config = adapter._get_transformed_model_gen_args(
search_space=base_gen_args.search_space,
pending_observations=base_gen_args.pending_observations,
fixed_features=base_gen_args.fixed_features,
optimization_config=base_gen_args.optimization_config,
)

# Assert: confirm pruning_target_point is correctly extracted and transformed
self.assertIsNotNone(torch_opt_config.pruning_target_point)
expected_target = torch.tensor([0.0, 1.0], dtype=torch.double)
torch.testing.assert_close(
torch_opt_config.pruning_target_point, expected_target
def _test_get_transformed_model_gen_args_target_point(
self,
with_status_quo: bool,
pruning_target_params: dict[str, float] | None,
expected_target: torch.Tensor | None,
) -> None:
experiment = get_branin_experiment(
with_completed_trial=True,
with_status_quo=with_status_quo,
)

def test_get_transformed_model_gen_args_no_target_point(self) -> None:
# Test that _get_transformed_model_gen_args handles
# pruning_target_parameterization=None correctly
opt_config = none_throws(experiment.optimization_config)
if pruning_target_params is not None:
pruning_target = Arm(parameters=pruning_target_params)
opt_config = opt_config.clone_with_args(
pruning_target_parameterization=pruning_target
)
elif with_status_quo:
opt_config = opt_config.clone()

# Setup: create adapter without target arm (default case)
experiment = get_branin_experiment(with_completed_trial=True)
adapter = TorchAdapter(
generator=TorchGenerator(),
experiment=experiment,
transforms=Cont_X_trans,
)

# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
base_gen_args = adapter._get_transformed_gen_args(
search_space=experiment.search_space,
optimization_config=none_throws(experiment.optimization_config),
pending_observations={},
)

search_space_digest, torch_opt_config = adapter._get_transformed_model_gen_args(
search_space=base_gen_args.search_space,
pending_observations=base_gen_args.pending_observations,
fixed_features=base_gen_args.fixed_features,
optimization_config=base_gen_args.optimization_config,
)

# Assert: confirm target_point is None when no pruning_target_parameterization
# is provided
self.assertIsNone(torch_opt_config.pruning_target_point)

def test_get_transformed_model_gen_args_with_sq_as_target(self) -> None:
# Test that _get_transformed_model_gen_args correctly processes the status quo
# as the target point
experiment = get_branin_experiment(
with_completed_trial=True, with_status_quo=True
)

adapter = TorchAdapter(
generator=TorchGenerator(), experiment=experiment, transforms=Cont_X_trans
)
oc = none_throws(experiment.optimization_config).clone()
# Execute: call _get_transformed_gen_args then _get_transformed_model_gen_args
base_gen_args = adapter._get_transformed_gen_args(
search_space=experiment.search_space,
optimization_config=oc,
optimization_config=opt_config,
pending_observations={},
)

Expand All @@ -1350,12 +1295,41 @@ def test_get_transformed_model_gen_args_with_sq_as_target(self) -> None:
optimization_config=base_gen_args.optimization_config,
)

# Assert: confirm pruning_target_point is correctly extracted and transformed
self.assertIsNotNone(torch_opt_config.pruning_target_point)
expected_target = torch.tensor([1 / 3.0, 0.0], dtype=torch.double)
torch.testing.assert_close(
torch_opt_config.pruning_target_point, expected_target
)
if expected_target is None:
self.assertIsNone(torch_opt_config.pruning_target_point)
else:
self.assertIsNotNone(torch_opt_config.pruning_target_point)
torch.testing.assert_close(
torch_opt_config.pruning_target_point,
expected_target,
)

def test_get_transformed_model_gen_args_target_point(self) -> None:
    """Drive the target-point helper through each pruning-target scenario."""
    # (label, with_status_quo, pruning target params, expected tensor or None)
    scenarios = [
        # An explicit pruning target arm is transformed into tensor space.
        (
            "with_target_point",
            False,
            {"x1": -5.0, "x2": 15.0},
            torch.tensor([0.0, 1.0], dtype=torch.double),
        ),
        # Neither a pruning target nor a status quo: expect None.
        ("no_target_point", False, None, None),
        # The status quo stands in as the pruning target.
        (
            "sq_as_target",
            True,
            None,
            torch.tensor([1 / 3.0, 0.0], dtype=torch.double),
        ),
    ]
    for label, with_status_quo, target_params, expected in scenarios:
        with self.subTest(scenario=label):
            self._test_get_transformed_model_gen_args_target_point(
                with_status_quo=with_status_quo,
                pruning_target_params=target_params,
                expected_target=expected,
            )

@mock_botorch_optimize
def test_moo_with_derived_parameter(self) -> None:
Expand Down
17 changes: 8 additions & 9 deletions ax/adapter/transforms/tests/test_logit_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,15 @@ def test_InvalidSettings(self) -> None:
self._create_logit_parameter(lower=0.1, upper=0.9, log_scale=True)
self.assertEqual("x can't use both log and logit.", str(cm.exception))

# Each case violates "lower > 0 and upper < 1":
# (0.0, 0.5) -> lower == 0, (0.3, 1.0) -> upper == 1,
# (0.5, 10.0) -> upper >> 1
str_exc = "x logit requires lower > 0 and upper < 1"
with self.assertRaises(UserInputError) as cm:
self._create_logit_parameter(lower=0.0, upper=0.5)
self.assertEqual(str_exc, str(cm.exception))
with self.assertRaises(UserInputError) as cm:
self._create_logit_parameter(lower=0.3, upper=1.0)
self.assertEqual(str_exc, str(cm.exception))
with self.assertRaises(UserInputError) as cm:
self._create_logit_parameter(lower=0.5, upper=10.0)
self.assertEqual(str_exc, str(cm.exception))
for lower, upper in [(0.0, 0.5), (0.3, 1.0), (0.5, 10.0)]:
with self.subTest(lower=lower, upper=upper):
with self.assertRaises(UserInputError) as cm:
self._create_logit_parameter(lower=lower, upper=upper)
self.assertEqual(str_exc, str(cm.exception))

def test_TransformSearchSpace(self) -> None:
ss2 = deepcopy(self.search_space)
Expand Down
10 changes: 5 additions & 5 deletions ax/adapter/transforms/tests/test_map_key_to_float_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def _test_early_stopping(self, complete_with_progression: bool) -> None:
# Check that cross validation works.
cross_validate(adapter=adapter)

def test_no_early_stopping_with_progression(self) -> None:
    """No-early-stopping behavior when progression data is present."""
    self._test_no_early_stopping(with_progression=True)

def test_no_early_stopping_no_progression(self) -> None:
    """No-early-stopping behavior when progression data is absent."""
    self._test_no_early_stopping(with_progression=False)
def test_no_early_stopping(self) -> None:
    """No-early-stopping behavior both with and without progression data."""
    for with_progression in (True, False):
        with self.subTest(with_progression=with_progression):
            self._test_no_early_stopping(with_progression=with_progression)

def test_early_stopping_with_final_progression(self) -> None:
    """Early stopping when trials complete with a final progression."""
    self._test_early_stopping(complete_with_progression=True)
Expand Down
48 changes: 19 additions & 29 deletions ax/adapter/transforms/tests/test_objective_as_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,37 +338,27 @@ def test_relative_constraint_feasibility_check(self) -> None:

def test_leq_constraint_feasibility(self) -> None:
    """Feasibility checking with an LEQ constraint no observation satisfies."""
    # Constraint m2 <= 0.3 while both observations have m2 > 0.3 -> infeasible.
    _, adapter, experiment_data = self._make_experiment_adapter_and_data(
        observations=[[1.0, 0.5], [2.0, 5.0]],
        constraint_bound=0.3,
        constraint_op=ComparisonOp.LEQ,
    )

    transform = ObjectiveAsConstraint(
        search_space=adapter._experiment.search_space,
        experiment_data=experiment_data,
        adapter=adapter,
    )

    # Infeasible data means the objective must be added as a constraint.
    self.assertTrue(transform._should_add_constraint)

def test_leq_constraint_feasible(self) -> None:
"""Test that LEQ constraints with feasible points are correctly detected."""
# m2 <= 10.0 constraint. Both observations have m2 <= 10.0, so feasible.
_, adapter, experiment_data = self._make_experiment_adapter_and_data(
observations=[[1.0, 0.5], [2.0, 5.0]],
constraint_bound=10.0,
constraint_op=ComparisonOp.LEQ,
)
cases = [
# m2 <= 0.3: both obs have m2 > 0.3 -> infeasible, constraint added
(0.3, True, "infeasible"),
# m2 <= 10.0: both obs have m2 <= 10.0 -> feasible, no constraint
(10.0, False, "feasible"),
]
for bound, expected_should_add, label in cases:
with self.subTest(bound=bound, scenario=label):
_, adapter, experiment_data = self._make_experiment_adapter_and_data(
observations=[[1.0, 0.5], [2.0, 5.0]],
constraint_bound=bound,
constraint_op=ComparisonOp.LEQ,
)

t = ObjectiveAsConstraint(
search_space=adapter._experiment.search_space,
experiment_data=experiment_data,
adapter=adapter,
)
t = ObjectiveAsConstraint(
search_space=adapter._experiment.search_space,
experiment_data=experiment_data,
adapter=adapter,
)

self.assertFalse(t._should_add_constraint)
self.assertEqual(t._should_add_constraint, expected_should_add)

def test_no_op_for_experiment_data(self) -> None:
"""Test that transform_experiment_data is a no-op."""
Expand Down
Loading
Loading