Skip to content

Commit

Permalink
Remove @cached_property (does not exist in python3.7) (#842)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #842

Unfortunately, cached_property doesnt exist in python3.7 so it must be removed from the OSS RegistryBundle. Not sure how the unit tests didnt catch this (I exported from phabricator to gh) so I've added some new unit tests to make sure this cant happen again.

Reviewed By: EugenHotaj

Differential Revision: D34753355

fbshipit-source-id: 3ead42bc86e9cafc9f737e19ad6cd91a79b33e93
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Mar 9, 2022
1 parent 4a30549 commit b0510ee
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
32 changes: 32 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
CORE_CLASS_DECODER_REGISTRY,
)
from ax.storage.json_store.save import save_experiment
from ax.storage.registry_bundle import RegistryBundle
from ax.utils.common.testutils import TestCase
from ax.utils.measurement.synthetic_functions import ackley, branin, from_botorch
from ax.utils.testing.benchmark_stubs import (
Expand Down Expand Up @@ -453,6 +454,37 @@ class MyMetric(Metric):
self.assertEqual(loaded_experiment, experiment)
os.remove(f.name)

def testRegistryBundle(self):
class MyMetric(Metric):
pass

class MyRunner(Runner):
def run():
pass

def staging_required():
return False

bundle = RegistryBundle(
metric_clss={MyMetric: 1998}, runner_clss={MyRunner: None}
)

experiment = get_experiment_with_batch_and_single_trial()
experiment.runner = MyRunner()
experiment.add_tracking_metric(MyMetric(name="my_metric"))
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".json") as f:
save_experiment(
experiment,
f.name,
encoder_registry=bundle.encoder_registry,
)
loaded_experiment = load_experiment(
f.name,
decoder_registry=bundle.decoder_registry,
)
self.assertEqual(loaded_experiment, experiment)
os.remove(f.name)

def testEncodeUnknownClassToDict(self):
# Cannot encode `UnknownClass` type because it is not registered in the
# CLASS_ENCODER_REGISTRY.
Expand Down
22 changes: 13 additions & 9 deletions ax/storage/registry_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

from abc import abstractproperty, ABC
from functools import cached_property
from typing import Any, Callable, Optional, Type, Dict

from ax.core.metric import Metric
Expand Down Expand Up @@ -136,10 +135,7 @@ def __init__(
json_decoder_registry=json_decoder_registry,
json_class_decoder_registry=json_class_decoder_registry,
)

@cached_property
def sqa_config(self) -> SQAConfig:
return SQAConfig(
self._sqa_config = SQAConfig(
json_encoder_registry={**self.encoder_registry, **CORE_ENCODER_REGISTRY},
json_decoder_registry={**self.decoder_registry, **CORE_DECODER_REGISTRY},
metric_registry=self.metric_registry,
Expand All @@ -148,10 +144,18 @@ def sqa_config(self) -> SQAConfig:
json_class_decoder_registry=self.class_decoder_registry,
)

@cached_property
self._encoder = Encoder(self._sqa_config)
self._decoder = Decoder(self._sqa_config)

# TODO[mpolson64] change @property to @cached_property once we deprecate 3.7
@property
def sqa_config(self) -> SQAConfig:
return self._sqa_config

@property
def encoder(self) -> Encoder:
return Encoder(self.sqa_config)
return self._encoder

@cached_property
@property
def decoder(self) -> Decoder:
return Decoder(self.sqa_config)
return self._decoder
23 changes: 23 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.modelbridge.registry import Models
from ax.runners.synthetic import SyntheticRunner
from ax.storage.metric_registry import register_metric, CORE_METRIC_REGISTRY
from ax.storage.registry_bundle import RegistryBundle
from ax.storage.runner_registry import register_runner, CORE_RUNNER_REGISTRY
from ax.storage.sqa_store.db import (
get_engine,
Expand Down Expand Up @@ -1056,6 +1057,28 @@ class MyMetric(Metric):
loaded_experiment = load_experiment(experiment.name, config=sqa_config)
self.assertEqual(loaded_experiment, experiment)

def testRegistryBundle(self):
class MyRunner(Runner):
def run():
pass

def staging_required():
return False

class MyMetric(Metric):
pass

bundle = RegistryBundle(
metric_clss={MyMetric: 1998}, runner_clss={MyRunner: None}
)

experiment = get_experiment_with_batch_trial()
experiment.runner = MyRunner()
experiment.add_tracking_metric(MyMetric(name="my_metric"))
save_experiment(experiment, config=bundle.sqa_config)
loaded_experiment = load_experiment(experiment.name, config=bundle.sqa_config)
self.assertEqual(loaded_experiment, experiment)

def testEncodeDecodeGenerationStrategy(self):
# Cannot load generation strategy before it has been saved
with self.assertRaises(ObjectNotFoundError):
Expand Down

0 comments on commit b0510ee

Please sign in to comment.