Skip to content

Commit

Permalink
Move instantiation.py functions to mixin
Browse files Browse the repository at this point in the history
Summary: This allows them to be overridden by the PTSClient

Reviewed By: lena-kashtelyan

Differential Revision: D34430873

fbshipit-source-id: 32be72b932a0e99db116816f2027a2559b1f5782
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Mar 7, 2022
1 parent df637c6 commit 50a4d32
Show file tree
Hide file tree
Showing 4 changed files with 747 additions and 707 deletions.
15 changes: 5 additions & 10 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,7 @@
from ax.plot.helper import _format_dict, _get_in_sample_arms
from ax.plot.trace import optimization_trace_single_method
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.instantiation import (
data_and_evaluations_from_raw_data,
make_experiment,
ObjectiveProperties,
build_objective_threshold,
)
from ax.service.utils.instantiation import ObjectiveProperties, InstantiationBase
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
from ax.storage.json_store.decoder import (
Expand Down Expand Up @@ -91,7 +86,7 @@
)


class AxClient(WithDBSettingsBase, BestPointMixin):
class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
"""
Convenience handler for management of experimentation cycle through a
service-like API. External system manages scheduling of the cycle and makes
Expand Down Expand Up @@ -321,7 +316,7 @@ def create_experiment(
for objective, properties in objectives.items()
}
objective_kwargs["objective_thresholds"] = [
build_objective_threshold(objective, properties)
self.build_objective_threshold(objective, properties)
for objective, properties in objectives.items()
if properties.threshold is not None
]
Expand All @@ -333,7 +328,7 @@ def create_experiment(
category=DeprecationWarning,
)

experiment = make_experiment(
experiment = self.make_experiment(
name=name,
parameters=parameters,
parameter_constraints=parameter_constraints,
Expand Down Expand Up @@ -1442,7 +1437,7 @@ def _make_evaluations_and_data(
"""
raw_data_by_arm = self._raw_data_by_arm(trial=trial, raw_data=raw_data)

evaluations, data = data_and_evaluations_from_raw_data(
evaluations, data = self.data_and_evaluations_from_raw_data(
raw_data=raw_data_by_arm,
metric_names=self.objective_names,
trial_index=trial.index,
Expand Down
10 changes: 3 additions & 7 deletions ax/service/managed_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@
get_best_parameters_from_model_predictions,
get_best_raw_objective_point,
)
from ax.service.utils.instantiation import (
make_experiment,
TParameterRepresentation,
data_and_evaluations_from_raw_data,
)
from ax.service.utils.instantiation import TParameterRepresentation, InstantiationBase
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -99,7 +95,7 @@ def with_evaluation_function(
) -> "OptimizationLoop":
"""Constructs a synchronous `OptimizationLoop` using an evaluation
function."""
experiment = make_experiment(
experiment = InstantiationBase.make_experiment(
name=experiment_name,
parameters=parameters,
objective_name=objective_name,
Expand Down Expand Up @@ -204,7 +200,7 @@ def run_trial(self) -> None:
trial = self._get_new_trial()

trial.mark_running(no_runner_required=True)
_, data = data_and_evaluations_from_raw_data(
_, data = InstantiationBase.data_and_evaluations_from_raw_data(
raw_data={
arm.name: self._call_evaluation_function(arm.parameters, weight)
for arm, weight in self._get_weights_by_arm(trial)
Expand Down
78 changes: 36 additions & 42 deletions ax/service/tests/test_instantiation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,7 @@
from ax.core.parameter import ParameterType, RangeParameter, FixedParameter
from ax.core.search_space import HierarchicalSearchSpace
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.service.utils.instantiation import (
_get_parameter_type,
constraint_from_str,
outcome_constraint_from_str,
make_experiment,
make_objectives,
make_search_space,
make_optimization_config,
raw_data_to_evaluation,
parameter_from_json,
)
from ax.service.utils.instantiation import InstantiationBase
from ax.utils.common.testutils import TestCase


Expand All @@ -28,16 +18,16 @@ class TestInstantiationtUtils(TestCase):

def test_parameter_type_validation(self):
with self.assertRaisesRegex(ValueError, "No AE parameter type"):
_get_parameter_type(list)
InstantiationBase._get_parameter_type(list)

def test_constraint_from_str(self):
with self.assertRaisesRegex(ValueError, "Bound for the constraint"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 + x2 <= not_numerical_bound", {"x1": None, "x2": None}
)
with self.assertRaisesRegex(ValueError, "Outcome constraint bound"):
outcome_constraint_from_str("m1 <= not_numerical_bound")
three_val_constaint = constraint_from_str(
InstantiationBase.outcome_constraint_from_str("m1 <= not_numerical_bound")
three_val_constaint = InstantiationBase.constraint_from_str(
"x1 + x2 + x3 <= 3",
{
"x1": RangeParameter(
Expand All @@ -54,12 +44,14 @@ def test_constraint_from_str(self):

self.assertEqual(three_val_constaint.bound, 3.0)
with self.assertRaisesRegex(ValueError, "Parameter constraint should"):
constraint_from_str("x1 + x2 + <= 3", {"x1": None, "x2": None, "x3": None})
InstantiationBase.constraint_from_str(
"x1 + x2 + <= 3", {"x1": None, "x2": None, "x3": None}
)
with self.assertRaisesRegex(ValueError, "Parameter constraint should"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 + x2 + x3 = 3", {"x1": None, "x2": None, "x3": None}
)
three_val_constaint2 = constraint_from_str(
three_val_constaint2 = InstantiationBase.constraint_from_str(
"-x1 + 2.1*x2 - 4*x3 <= 3",
{
"x1": RangeParameter(
Expand All @@ -79,39 +71,39 @@ def test_constraint_from_str(self):
three_val_constaint2.constraint_dict, {"x1": -1.0, "x2": 2.1, "x3": -4.0}
)
with self.assertRaisesRegex(ValueError, "Multiplier should be float"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 - e*x2 + x3 <= 3", {"x1": None, "x2": None, "x3": None}
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 - 2 *x2 + 3 *x3 <= 3", {"x1": None, "x2": None, "x3": None}
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 - 2* x2 + 3* x3 <= 3", {"x1": None, "x2": None, "x3": None}
)
with self.assertRaisesRegex(ValueError, "A linear constraint should be"):
constraint_from_str(
InstantiationBase.constraint_from_str(
"x1 - 2 * x2 + 3*x3 <= 3", {"x1": None, "x2": None, "x3": None}
)

def test_objective_validation(self):
with self.assertRaisesRegex(UnsupportedError, "Ambiguous objective definition"):
make_experiment(
InstantiationBase.make_experiment(
parameters={"name": "x", "type": "range", "bounds": [0, 1]},
objective_name="branin",
objectives={"branin": "minimize", "currin": "maximize"},
)

def test_add_tracking_metrics(self):
experiment = make_experiment(
experiment = InstantiationBase.make_experiment(
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
tracking_metric_names=None,
)
self.assertDictEqual(experiment._tracking_metrics, {})

metrics_names = ["metric_1", "metric_2"]
experiment = make_experiment(
experiment = InstantiationBase.make_experiment(
parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}],
tracking_metric_names=metrics_names,
)
Expand All @@ -122,8 +114,10 @@ def test_add_tracking_metrics(self):

def test_make_objectives(self):
with self.assertRaisesRegex(ValueError, "specify 'minimize' or 'maximize'"):
make_objectives({"branin": "unknown"})
objectives = make_objectives({"branin": "minimize", "currin": "maximize"})
InstantiationBase.make_objectives({"branin": "unknown"})
objectives = InstantiationBase.make_objectives(
{"branin": "minimize", "currin": "maximize"}
)
branin_metric = [o.minimize for o in objectives if o.metric.name == "branin"]
self.assertTrue(branin_metric[0])
currin_metric = [o.minimize for o in objectives if o.metric.name == "currin"]
Expand All @@ -134,7 +128,7 @@ def test_make_optimization_config(self):
objective_thresholds = ["branin <= 0", "currin >= 0"]
with self.subTest("Single-objective optimizations with objective thresholds"):
with self.assertRaisesRegex(ValueError, "not specify objective thresholds"):
make_optimization_config(
InstantiationBase.make_optimization_config(
{"branin": "minimize"},
objective_thresholds,
outcome_constraints=[],
Expand All @@ -143,15 +137,15 @@ def test_make_optimization_config(self):

with self.subTest("MOO missing objective thresholds"):
with self.assertRaises(UserInputError):
multi_optimization_config = make_optimization_config(
multi_optimization_config = InstantiationBase.make_optimization_config(
objectives,
objective_thresholds=objective_thresholds[:1],
outcome_constraints=[],
status_quo_defined=False,
)

with self.subTest("MOO with all objective threshold"):
multi_optimization_config = make_optimization_config(
multi_optimization_config = InstantiationBase.make_optimization_config(
objectives,
objective_thresholds,
outcome_constraints=[],
Expand All @@ -163,7 +157,7 @@ def test_make_optimization_config(self):
with self.subTest(
"Single-objective optimizations without objective thresholds"
):
single_optimization_config = make_optimization_config(
single_optimization_config = InstantiationBase.make_optimization_config(
{"branin": "minimize"},
objective_thresholds=[],
outcome_constraints=[],
Expand All @@ -177,7 +171,7 @@ def test_single_valued_choice_to_fixed_param_conversion(self):
"type": "choice",
"values": [1.0],
}
output = parameter_from_json(representation)
output = InstantiationBase.parameter_from_json(representation)
self.assertIsInstance(output, FixedParameter)
self.assertEqual(output.value, 1.0)

Expand Down Expand Up @@ -210,7 +204,7 @@ def test_hss(self):
},
{"name": "another_int", "type": "fixed", "value": "2"},
]
search_space = make_search_space(
search_space = InstantiationBase.make_search_space(
parameters=parameter_dicts, parameter_constraints=[]
)
self.assertIsInstance(search_space, HierarchicalSearchSpace)
Expand All @@ -220,13 +214,13 @@ def test_hss(self):
class TestRawDataToEvaluation(TestCase):
def test_raw_data_is_not_dict_of_dicts(self):
with self.assertRaises(ValueError):
raw_data_to_evaluation(
InstantiationBase.raw_data_to_evaluation(
raw_data={"arm_0": {"objective_a": 6}},
metric_names=["objective_a"],
)

def test_it_converts_to_floats_in_dict_and_leaves_tuples(self):
result = raw_data_to_evaluation(
result = InstantiationBase.raw_data_to_evaluation(
raw_data={
"objective_a": 6,
"objective_b": 1.0,
Expand All @@ -240,51 +234,51 @@ def test_it_converts_to_floats_in_dict_and_leaves_tuples(self):

def test_dict_entries_must_be_int_float_or_tuple(self):
with self.assertRaises(ValueError):
raw_data_to_evaluation(
InstantiationBase.raw_data_to_evaluation(
raw_data={"objective_a": [6.0, None]},
metric_names=["objective_a"],
)

def test_it_requires_a_dict_for_multi_objectives(self):
with self.assertRaises(ValueError):
raw_data_to_evaluation(
InstantiationBase.raw_data_to_evaluation(
raw_data=(6.0, None),
metric_names=["objective_a", "objective_b"],
)

def test_it_accepts_a_list_for_single_objectives(self):
raw_data = [({"arm__0": {}}, {"objective_a": (1.4, None)})]
result = raw_data_to_evaluation(
result = InstantiationBase.raw_data_to_evaluation(
raw_data=raw_data,
metric_names=["objective_a"],
)
self.assertEqual(raw_data, result)

def test_it_turns_a_tuple_into_a_dict(self):
raw_data = (1.4, None)
result = raw_data_to_evaluation(
result = InstantiationBase.raw_data_to_evaluation(
raw_data=raw_data,
metric_names=["objective_a"],
)
self.assertEqual(result["objective_a"], raw_data)

def test_it_turns_an_int_into_a_dict_of_tuple(self):
result = raw_data_to_evaluation(
result = InstantiationBase.raw_data_to_evaluation(
raw_data=1,
metric_names=["objective_a"],
)
self.assertEqual(result["objective_a"], (1.0, None))

def test_it_turns_a_float_into_a_dict_of_tuple(self):
result = raw_data_to_evaluation(
result = InstantiationBase.raw_data_to_evaluation(
raw_data=1.6,
metric_names=["objective_a"],
)
self.assertEqual(result["objective_a"], (1.6, None))

def test_it_raises_for_unexpected_types(self):
with self.assertRaises(ValueError):
raw_data_to_evaluation(
InstantiationBase.raw_data_to_evaluation(
raw_data="1.6",
metric_names=["objective_a"],
)
Loading

0 comments on commit 50a4d32

Please sign in to comment.