Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def get_target_trial_index(
return None

# trial indices that have data for required metrics
trial_indices_with_required_metrics = _get_trial_indices_with_required_metrics(
trial_indices_with_required_metrics = get_trial_indices_with_required_metrics(
experiment=experiment,
df=df,
require_data_for_all_metrics=require_data_for_all_metrics,
Expand Down Expand Up @@ -593,7 +593,7 @@ def get_target_trial_index(
return None


def _get_trial_indices_with_required_metrics(
def get_trial_indices_with_required_metrics(
experiment: Experiment,
df: "pd.DataFrame",
require_data_for_all_metrics: bool,
Expand Down
2 changes: 0 additions & 2 deletions ax/generation_strategy/center_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def gen(
skip_fit: bool = False,
data: Data | None = None,
n: int | None = None,
arms_per_node: dict[str, int] | None = None,
**gs_gen_kwargs: Any,
) -> GeneratorRun | None:
"""Generate candidates or skip if search space is exhausted.
Expand Down Expand Up @@ -100,7 +99,6 @@ def gen(
skip_fit=skip_fit,
data=data,
n=n,
arms_per_node=arms_per_node,
**gs_gen_kwargs,
)

Expand Down
30 changes: 7 additions & 23 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def gen(
skip_fit: bool = False,
data: Data | None = None,
n: int | None = None,
arms_per_node: dict[str, int] | None = None,
**gs_gen_kwargs: Any,
) -> GeneratorRun | None:
"""This method generates candidates using `self._gen` and handles deduplication
Expand All @@ -451,9 +450,6 @@ def gen(
data: Optional override for the experiment data used to generate candidates;
if not specified, will use ``experiment.lookup_data()`` (extracted in
``Adapter``).
arms_per_node: A manual override for users interacting with a gen. strategy
via a Python API; a mapping from node name to the specific number of
arms it should produce. Passed down here by `GenerationStrategy.gen`.
gs_gen_kwargs: Keyword arguments, passed to ``GenerationStrategy.gen``.
These might be modified by this node's input constructors, before
being passed down to ``ModelSpec.gen``, where these will override any
Expand Down Expand Up @@ -485,13 +481,6 @@ def gen(
logger.debug(f"Skipping generation for node {self.name}.")
return None

if arms_per_node:
if self.name not in arms_per_node:
raise UnsupportedError(
"If manually specifying arms per node, all nodes must be specified."
)
generator_gen_kwargs["n"] = arms_per_node[self.name]

# TODO[drfreund]: Move this to `Adapter` or another more suitable place.
# Keeping here for now to limit the scope of the current changeset.
generator_gen_kwargs["fixed_features"] = (
Expand Down Expand Up @@ -992,8 +981,6 @@ class GenerationStep:
generator_gen_kwargs: Each call to `generation_strategy.gen` performs a call
to the step's adapter's `gen` under the hood; `generator_gen_kwargs` will be
passed to the adapter's `gen` like: `adapter.gen(**generator_gen_kwargs)`.
completion_criteria: List of TransitionCriterion. All `is_met` must evaluate
True for the GenerationStrategy to move on to the next Step
index: Index of this generation step, for use internally in `Generation
Strategy`. Do not assign as it will be reassigned when instantiating
`GenerationStrategy` with a list of its steps.
Expand Down Expand Up @@ -1024,7 +1011,6 @@ def __new__(
num_trials: int,
generator_kwargs: dict[str, Any] | None = None,
generator_gen_kwargs: dict[str, Any] | None = None,
completion_criteria: Sequence[TransitionCriterion] | None = None,
min_trials_observed: int = 0,
max_parallelism: int | None = None,
enforce_num_trials: bool = True,
Expand Down Expand Up @@ -1061,7 +1047,6 @@ def __new__(

generator_kwargs = generator_kwargs or {}
generator_gen_kwargs = generator_gen_kwargs or {}
completion_criteria = completion_criteria or []

if (
enforce_num_trials
Expand Down Expand Up @@ -1101,47 +1086,46 @@ def __new__(
# is set in `GenerationStrategy` constructor, because only then is the order
# of the generation steps actually known.
transition_criteria: list[TransitionCriterion] = []
# Placeholder - will be overwritten in _validate_and_set_step_sequence in GS
placeholder_transition_to = f"GenerationStep_{str(index)}"

if num_trials != -1:
transition_criteria.append(
MinTrials(
threshold=num_trials,
transition_to=placeholder_transition_to,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
block_gen_if_met=enforce_num_trials,
block_transition_if_unmet=True,
use_all_trials_in_exp=use_all_trials_in_exp,
transition_to=None, # Re-set in GS constructor.
)
)

if min_trials_observed > 0:
transition_criteria.append(
MinTrials(
threshold=min_trials_observed,
transition_to=placeholder_transition_to,
only_in_statuses=[
TrialStatus.COMPLETED,
TrialStatus.EARLY_STOPPED,
],
threshold=min_trials_observed,
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=use_all_trials_in_exp,
transition_to=None, # Re-set in GS constructor.
)
)
if max_parallelism is not None:
transition_criteria.append(
MaxGenerationParallelism(
threshold=max_parallelism,
transition_to=placeholder_transition_to,
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
# MaxParallelism transitions to self,
# this will be confirmed in GS init
transition_to=f"GenerationStep_{str(index)}",
)
)

transition_criteria += list(completion_criteria)

# Create and return a GenerationNode instance
node = GenerationNode(
# NOTE: This name is a placeholder that will be overwritten in
Expand Down
12 changes: 4 additions & 8 deletions ax/generation_strategy/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,10 @@ def _get_default_n(experiment: Experiment, next_node: GenerationNode) -> int:
The default number of arms to generate from the next node, used if no n is
provided to the ``GenerationStrategy``'s gen call.
"""
# If the generator spec contains `n` use that value first
# TODO #1 [drfreund, mgarrard]: Eliminate the need to do this; the order should be:
# `arms_per_node[node_name]` > `input_constuctors(n)` > `gen_spec...kwargs["n"]`
# NOTE: We might need to simply disallow `n` in `gen_spec...kwargs`: it should
# probably never be hardcoded there. Without it, we can just enforce that at a
# point within generation strategy, an `n` is passed down to `gen_spec.gen`.
# And if we keep it, we don't have a clear point in this stack at which we are
# "no longer allowed to have a null `n`."
# If the generator spec contains `n` use that value first.
# TODO [drfreund, mgarrard]: Consider disallowing `n` in `gen_spec...kwargs`:
# it should probably never be hardcoded there. This would enforce that `n`
# is always passed down through the generation strategy at runtime.
if next_node.generator_spec_to_gen_from.generator_gen_kwargs.get("n") is not None:
return next_node.generator_spec_to_gen_from.generator_gen_kwargs["n"]

Expand Down
45 changes: 1 addition & 44 deletions ax/generation_strategy/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,7 @@
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.utils import extend_pending_observations, extract_pending_observations
from ax.exceptions.core import (
AxError,
DataRequiredError,
UnsupportedError,
UserInputError,
)
from ax.exceptions.core import AxError, DataRequiredError, UnsupportedError
from ax.exceptions.generation_strategy import (
GenerationStrategyCompleted,
GenerationStrategyMisconfiguredException,
Expand Down Expand Up @@ -247,7 +242,6 @@ def gen(
n: int | None = None,
fixed_features: ObservationFeatures | None = None,
num_trials: int = 1,
arms_per_node: dict[str, int] | None = None,
) -> list[list[GeneratorRun]]:
"""Produce GeneratorRuns for multiple trials at once with the possibility of
using multiple models per trial, getting multiple GeneratorRuns per trial.
Expand Down Expand Up @@ -275,12 +269,6 @@ def gen(
important to specify all necessary fixed features.
num_trials: Number of trials to generate generator runs for in this call.
If not provided, defaults to 1.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.

Returns:
A list of lists of lists generator runs. Each outer list represents
Expand All @@ -306,7 +294,6 @@ def gen(
data=data,
n=n,
pending_observations=pending_observations,
arms_per_node=arms_per_node,
fixed_features=fixed_features,
first_generation_in_multi=len(grs_for_multiple_trials) < 1,
)
Expand Down Expand Up @@ -467,24 +454,6 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:

self._curr = nodes[0]

def _validate_arms_per_node(self, arms_per_node: dict[str, int] | None) -> None:
"""Validate that the arms_per_node argument is valid if it is provided.

Args:
arms_per_node: A map from node name to the number of arms to
generate from that node.
"""
if arms_per_node is not None and not set(self.nodes_by_name).issubset(
arms_per_node
):
raise UserInputError(
"Each node defined in the `GenerationStrategy` must have an "
"associated number of arms to generate from that node defined "
f"in `arms_per_node`. {arms_per_node} does not include all of "
f"{self.nodes_by_name.keys()}. "
"It may help to double-check the spelling."
)

def _make_default_name(self) -> str:
"""Make a default name for this generation strategy; used when no name is passed
to the constructor. For node-based generation strategies, the name is
Expand Down Expand Up @@ -515,10 +484,6 @@ def _gen_with_multiple_nodes(
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
data: Data | None = None,
fixed_features: ObservationFeatures | None = None,
# TODO: Consider naming `arms_per_node` smtg like `arms_per_node_override`,
# to convey its manually-specified nature (if it's not specified, GS selects
# what to do on its own).
arms_per_node: dict[str, int] | None = None,
first_generation_in_multi: bool = True,
) -> list[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
Expand Down Expand Up @@ -548,12 +513,6 @@ def _gen_with_multiple_nodes(
passed down to the underlying nodes. Note: if provided this will
override any algorithmically determined fixed features so it is
important to specify all necessary fixed features.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.

Returns:
A list of ``GeneratorRuns`` for a single trial.
Expand All @@ -570,7 +529,6 @@ def _gen_with_multiple_nodes(
pending_observations if pending_observations is not None else {}
)
self.experiment = experiment
self._validate_arms_per_node(arms_per_node=arms_per_node)
pack_gs_gen_kwargs = {
"grs_this_gen": grs_this_gen,
"fixed_features": fixed_features,
Expand All @@ -596,7 +554,6 @@ def _gen_with_multiple_nodes(
pending_observations=pending_observations,
skip_fit=not (first_generation_in_multi or transitioned),
n=n,
arms_per_node=arms_per_node,
**pack_gs_gen_kwargs,
)
except DataRequiredError as err:
Expand Down
1 change: 0 additions & 1 deletion ax/generation_strategy/generator_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,5 +382,4 @@ def __eq__(self, other: GeneratorSpec) -> bool:
@property
def _unique_id(self) -> str:
"""Returns the unique ID of this model spec"""
# TODO @mgarrard verify that this is unique enough
return str(hash(self))
Loading