Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

# pyre-strict

import logging

from math import ceil
from typing import Any, cast, Mapping

Expand Down Expand Up @@ -53,6 +55,8 @@
from sqlalchemy.orm import defaultload, joinedload, lazyload, noload
from sqlalchemy.orm.exc import DetachedInstanceError

logger: logging.Logger = logging.getLogger(__name__)


# ---------------------------- Loading `Experiment`. ---------------------------

Expand Down Expand Up @@ -537,21 +541,32 @@ def _load_generation_strategy_by_id(
def get_generation_strategy_id(experiment_name: str, decoder: Decoder) -> int | None:
"""Get DB ID of the generation strategy, associated with the experiment
with the given name if its in DB, return None otherwise.

If multiple generation strategies are associated with the experiment,
returns the latest one (highest DB ID).
"""
exp_sqa_class = decoder.config.class_to_sqa_class[Experiment]
gs_sqa_class = decoder.config.class_to_sqa_class[GenerationStrategy]
with session_scope() as session:
sqa_gs_id = (
sqa_gs_ids = (
session.query(gs_sqa_class.id) # pyre-ignore[16]
.join(exp_sqa_class.generation_strategy) # pyre-ignore[16]
# pyre-fixme[16]: `SQABase` has no attribute `name`.
.filter(exp_sqa_class.name == experiment_name)
.one_or_none()
.order_by(gs_sqa_class.id.desc())
.all()
)

if sqa_gs_id is None:
if not sqa_gs_ids:
return None
return sqa_gs_id[0]

if len(sqa_gs_ids) > 1:
logger.warning(
f"Found {len(sqa_gs_ids)} generation strategies for experiment "
f"{experiment_name}. Loading the latest one (id={sqa_gs_ids[0][0]})."
)

return sqa_gs_ids[0][0]


def get_generation_strategy_sqa(
Expand Down
5 changes: 0 additions & 5 deletions ax/storage/sqa_store/sqa_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,13 @@
class BaseNullableEnum(types.TypeDecorator):
cache_ok = True

# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def __init__(self, enum: Any, *arg: list[Any], **kw: dict[Any, Any]) -> None:
types.TypeDecorator.__init__(self, *arg, **kw)
# pyre-fixme[4]: Attribute must be annotated.
self._member_map = enum._member_map_
# pyre-fixme[4]: Attribute must be annotated.
self._value2member_map = enum._value2member_map_

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def process_bind_param(self, value: Any, dialect: Any) -> Any:
if value is None:
return value
Expand All @@ -40,8 +37,6 @@ def process_bind_param(self, value: Any, dialect: Any) -> Any:
)
return val._value_

# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def process_result_value(self, value: Any, dialect: Any) -> Any:
if value is None:
return value
Expand Down
28 changes: 28 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2910,6 +2910,34 @@ def test_delete_generation_strategy_max_gs_to_delete(self) -> None:
# Full GS fails the equality check
self.assertEqual(str(generation_strategy), str(loaded_generation_strategy))

def test_load_latest_generation_strategy_when_multiple_exist(self) -> None:
experiment = get_branin_experiment()
gs1 = choose_generation_strategy_legacy(experiment.search_space)
gs1.experiment = experiment
save_experiment(experiment)
save_generation_strategy(generation_strategy=gs1)
self.assertEqual(
gs1.db_id,
load_generation_strategy_by_experiment_name(experiment.name).db_id,
)

# create a second generation strategy for the experiment
gs2 = choose_generation_strategy_legacy(experiment.search_space)
gs2._name = "second_gs"
gs2.experiment = experiment
save_generation_strategy(generation_strategy=gs2)

# check that the latest generation stragey is loaded
with self.assertLogs(
"ax.storage.sqa_store.load", level=logging.WARNING
) as logs:
loaded_gs = load_generation_strategy_by_experiment_name(experiment.name)
self.assertEqual(loaded_gs.db_id, gs2.db_id)
self.assertEqual(loaded_gs.name, gs2.name)
self.assertTrue(
any("Found 2 generation strategies" in log for log in logs.output)
)

def test_query_historical_experiments_given_parameters(self) -> None:
# This test validates the query behavior for historical experiments.
config = SQAConfig(experiment_type_enum=TestExperimentTypeEnum)
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


# pyre-fixme[3]: Return annotation cannot be `Any`.
def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any:
"""Ensure that exactly one of `exactly_one_fields` has a value set."""
values = [getattr(instance, field) is not None for field in exactly_one_fields]
Expand Down