Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions ax/storage/json_store/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@
from botorch.models.transforms.input import ChainedInputTransform, InputTransform
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.priors import Prior
from gpytorch.priors.utils import BUFFERED_PREFIX
from pyre_extensions import assert_is_instance
from torch.distributions.transformed_distribution import TransformedDistribution

logger: logging.Logger = get_logger(__name__)

Expand Down Expand Up @@ -369,8 +369,11 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) ->
for k, v in state_dict.items()
}
)
if issubclass(botorch_class, TransformedDistribution):
# Extract the buffered attributes for transformed priors.
if issubclass(botorch_class, Prior):
# Extract the buffered attributes for priors. Some priors (e.g.
# BetaPrior, LogNormalPrior) store parameters with BUFFERED_PREFIX
# because their underlying distribution uses @property descriptors
# that cannot be deleted by _bufferize_attributes.
for k in list(state_dict.keys()):
if k.startswith(BUFFERED_PREFIX):
state_dict[k[len(BUFFERED_PREFIX) :]] = state_dict.pop(k)
Expand Down
24 changes: 24 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,30 @@ def test_BadStateDict(self) -> None:
del expected_json["state_dict"]["lower_bound"]
botorch_component_from_json(interval.__class__, expected_json)

def test_prior_roundtrip_serialization(self) -> None:
"""Test encode/decode roundtrip for priors with buffered attributes.

Priors whose underlying distribution uses @property descriptors
(e.g. BetaPrior via Dirichlet, LogNormalPrior via TransformedDistribution)
store state_dict keys with BUFFERED_PREFIX. The decoder must strip
the prefix to match __init__ arg names.
"""
from botorch.models.utils.priors import BetaPrior
from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior, NormalPrior

priors = [
("BetaPrior", BetaPrior(concentration1=2.5, concentration0=1.5)),
("GammaPrior", GammaPrior(concentration=2.0, rate=1.0)),
("NormalPrior", NormalPrior(loc=0.0, scale=1.0)),
("LogNormalPrior", LogNormalPrior(loc=0.0, scale=1.0)),
]
for name, prior in priors:
with self.subTest(prior=name):
encoded = botorch_component_to_dict(prior)
decoded = botorch_component_from_json(prior.__class__, encoded)
self.assertIsInstance(decoded, prior.__class__)
self.assertEqual(decoded.state_dict(), prior.state_dict())

def test_observation_features_backward_compatibility(self) -> None:
json = {
"__type": "ObservationFeatures",
Expand Down
Loading