Skip to content

Commit

Permalink
Do not persist entire AutoMLState in Searcher (#870)
Browse files Browse the repository at this point in the history
* Do not persist entire AutoMLState in Searcher

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>

* Fix tests

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
  • Loading branch information
Yard1 committed Jan 6, 2023
1 parent 90aea9c commit 5f67c0a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
6 changes: 3 additions & 3 deletions flaml/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,15 +491,15 @@ def _train_with_config(
return estimator, train_time


def size(state: AutoMLState, config: dict) -> float:
def size(learner_classes: dict, config: dict) -> float:
"""Size function.
Returns:
The mem size in bytes for a config.
"""
config = config.get("ml", config)
estimator = config["learner"]
learner_class = state.learner_classes.get(estimator)
learner_class = learner_classes.get(estimator)
return learner_class.size(config)


Expand Down Expand Up @@ -3125,7 +3125,7 @@ def _search_parallel(self):
min_resource=min_resource_all_estimator,
max_resource=self.max_resource,
config_constraints=[
(partial(size, self._state), "<=", self._mem_thres)
(partial(size, self._state.learner_classes), "<=", self._mem_thres)
],
metric_constraints=self.metric_constraints,
seed=self._seed,
Expand Down
8 changes: 6 additions & 2 deletions test/automl/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def test_metric_constraints():
min_resource=automl.min_resource,
max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget,
config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
config_constraints=[
(partial(size, automl._state.learner_classes), "<=", automl._mem_thres)
],
metric_constraints=automl.metric_constraints,
num_samples=5,
)
Expand Down Expand Up @@ -159,7 +161,9 @@ def test_metric_constraints_custom():
min_resource=automl.min_resource,
max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget,
config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
config_constraints=[
(partial(size, automl._state.learner_classes), "<=", automl._mem_thres)
],
metric_constraints=automl.metric_constraints,
num_samples=5,
)
Expand Down
6 changes: 5 additions & 1 deletion test/automl/test_python_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def test_logging_level(self):
min_resource=automl.min_resource,
max_resource=automl.max_resource,
config_constraints=[
(partial(size, automl._state), "<=", automl._mem_thres)
(
partial(size, automl._state.learner_classes),
"<=",
automl._mem_thres,
)
],
metric_constraints=automl.metric_constraints,
)
Expand Down
4 changes: 3 additions & 1 deletion test/automl/test_xgboost2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def test_simple(method=None):
min_resource=automl.min_resource,
max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget,
config_constraints=[(partial(size, automl._state), "<=", automl._mem_thres)],
config_constraints=[
(partial(size, automl._state.learner_classes), "<=", automl._mem_thres)
],
metric_constraints=automl.metric_constraints,
num_samples=5,
)
Expand Down

0 comments on commit 5f67c0a

Please sign in to comment.