Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 64 additions & 17 deletions ax/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from logging import Logger
from typing import Any
from typing import Any, cast

import numpy as np
from ax.adapter.data_utils import (
Expand All @@ -31,9 +31,9 @@
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.observation_utils import recombine_observations
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TCandidateMetadata, TModelPredict
from ax.core.types import TCandidateMetadata, TModelPredict, TParamValue
from ax.core.utils import get_target_trial_index, has_map_metrics
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.model import AdapterMethodNotImplementedError, ModelError
Expand Down Expand Up @@ -410,17 +410,33 @@ def _compute_in_design(
return search_space.check_membership_df(arm_data=experiment_data.arm_data)

def _set_model_space(self, arm_data: DataFrame) -> None:
"""Set model space, possibly expanding range parameters to cover data."""
"""Set model space, possibly expanding parameters to cover data.

For ``RangeParameter``, expand ``lower`` / ``upper`` bounds to cover the
range of values observed in ``arm_data``.

For ``ChoiceParameter``, append any observed values not already in
``p.values``. Expansion is restricted to numeric ordered choice
parameters: expanding unordered choices breaks ``OneHot`` at gen time
(the surrogate's ``self.parameters`` is frozen at fit and contains
one-hot dimensions for every expanded value, but
``OneHot.transform_search_space`` prunes back to the user-declared
values at gen, leaving stale entries that ``extract_search_space_digest``
fails to look up). Expanding non-numeric choices is unsupported for
a similar reason. If a numeric ordered choice parameter is declared
as ``INT`` but an observed value is non-integer, the model space copy
of the parameter is relaxed to ``FLOAT`` to preserve the observation
exactly. This only mutates the adapter-local ``_model_space``;
``experiment.search_space`` is untouched.
"""
# If fill for missing values, include those in expansion.
t = FillMissingParameters(
search_space=self._model_space,
config=self._transform_configs.get("FillMissingParameters", None),
)
fill_values = t._fill_values
# Update model space. Expand bounds as needed to cover the values found
# in the data. Only applies to range parameters.
for p_name, p in self._model_space.parameters.items():
if not isinstance(p, RangeParameter):
if not isinstance(p, (RangeParameter, ChoiceParameter)):
continue
if p_name in arm_data:
param_vals = arm_data[p_name].dropna().tolist()
Expand All @@ -430,19 +446,50 @@ def _set_model_space(self, arm_data: DataFrame) -> None:
param_vals.append(fill_values[p_name])
if len(param_vals) == 0:
continue
# For log_scale parameters, ensure lower bound is > 0
# as OOD arms may have values <= 0
if p.log_scale:
# Find the smallest positive value from param_vals
positive_vals = [v for v in param_vals if v > 0]
if positive_vals:
if isinstance(p, RangeParameter):
# For log_scale parameters, ensure lower bound is > 0
# as OOD arms may have values <= 0
if p.log_scale:
positive_vals = [v for v in param_vals if v > 0]
if not positive_vals:
# keep original lower bound
continue
p.lower = min(p.lower, min(positive_vals))
else:
# keep original lower bound
continue
p.lower = min(p.lower, min(param_vals))
p.upper = max(p.upper, max(param_vals))
else:
p.lower = min(p.lower, min(param_vals))
p.upper = max(p.upper, max(param_vals))
# ChoiceParameter. Only expand numeric ordered choice
# parameters; unordered / non-numeric choices break OneHot
# at gen time (stale one-hot dimensions in self.parameters).
if not (p.is_ordered and p.parameter_type.is_numeric):
continue
# If the parameter is declared INT but an observed value is
# non-integer, relax the model-space copy to FLOAT so the
# observation is preserved exactly rather than truncated by
# `_cast_values`. Safe because `_model_space` is adapter-local
# and not persisted; downstream consumers of ChoiceParameter
# do not branch on INT vs FLOAT.
if p.parameter_type == ParameterType.INT and any(
not float(v).is_integer() for v in param_vals
):
p._parameter_type = ParameterType.FLOAT
# Dedupe while preserving order: `set_values` (unlike
# `__init__`) does not dedupe its input, so duplicate observed
# values would otherwise corrupt downstream integer encodings.
existing = set(p.values)
extra_values: list[TParamValue] = []
for v in param_vals:
if v not in existing:
extra_values.append(v)
existing.add(v)
if not extra_values:
continue
# Numeric ordered choice parameters are enforced to have
# `sort_values=True` at construction time, so we can always
# sort here. Values are guaranteed numeric by the gate above,
# hence sortable.
p.set_values(sorted(cast(list[float], [*p.values, *extra_values])))
# Remove parameter constraints from the model space.
self._model_space.set_parameter_constraints([])

Expand Down
Loading
Loading