Skip to content

Commit

Permalink
Use ModelSpec.cross_validate() in ModelSelectionGenerationNode
Browse files Browse the repository at this point in the history
Summary: See title

Reviewed By: ldworkin

Differential Revision: D32838687

fbshipit-source-id: 6ac6829d75d514226aa5080fb4f44ed2a2b6a639
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Dec 16, 2021
1 parent 6f22a62 commit d8c52c2
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.cross_validation import (
cross_validate,
compute_diagnostics,
BestModelSelector,
)
from ax.modelbridge.cross_validation import BestModelSelector
from ax.modelbridge.model_spec import ModelSpec, FactoryFunctionModelSpec
from ax.modelbridge.registry import (
ModelRegistryBase,
Expand Down Expand Up @@ -50,7 +46,6 @@ def __init__(
self,
model_specs: List[ModelSpec],
best_model_selector: Optional[BestModelSelector] = None,
cvkwargs: Optional[Dict[str, Any]] = None,
) -> None:
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
# and `_pick_fitted_model_to_gen_from` methods, we validate the
Expand All @@ -59,7 +54,6 @@ def __init__(
# method to bypass that validation.
self.model_specs = model_specs
self.best_model_selector = best_model_selector
self.cvkwargs = cvkwargs

def fit(
self,
Expand Down Expand Up @@ -135,14 +129,11 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec:
raise NotImplementedError(CANNOT_SELECT_ONE_MODEL_MSG)
return self.model_specs[0]

cvkwargs = self.cvkwargs or {}
cv_diagnostics = []
for model_spec in self.model_specs:
cv_result = cross_validate(model_spec.fitted_model, **cvkwargs)
cv_diagnostics.append(compute_diagnostics(cv_result))
model_spec.cross_validate()

best_model_index = not_none(self.best_model_selector).best_diagnostic(
cv_diagnostics
diagnostics=[not_none(m.diagnostics) for m in self.model_specs],
)
return self.model_specs[best_model_index]

Expand Down

0 comments on commit d8c52c2

Please sign in to comment.