From cf271ba7fc8b522f83546183df51fa86286c0cbf Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 19:45:42 -0700 Subject: [PATCH] Fix prior deserialization for priors with buffered attributes (#5167) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5167 The Ax JSON decoder's `botorch_component_from_json` strips the `BUFFERED_PREFIX` from state_dict keys only for `TransformedDistribution` subclasses. This misses priors like `BetaPrior` whose underlying distribution (`Beta`) uses `property` descriptors delegating to an internal `Dirichlet`, causing `_bufferize_attributes` to use the prefix. Broaden the check from `TransformedDistribution` to `(TransformedDistribution, Prior)` so all gpytorch priors with buffered attributes deserialize correctly. Reviewed By: sdaulton Differential Revision: D100341242 --- ax/storage/json_store/decoders.py | 9 ++++--- .../json_store/tests/test_json_store.py | 24 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index accf67e599f..b31be6bb617 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -54,9 +54,9 @@ from botorch.models.transforms.input import ChainedInputTransform, InputTransform from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.utils.types import _DefaultType, DEFAULT +from gpytorch.priors import Prior from gpytorch.priors.utils import BUFFERED_PREFIX from pyre_extensions import assert_is_instance -from torch.distributions.transformed_distribution import TransformedDistribution logger: logging.Logger = get_logger(__name__) @@ -369,8 +369,11 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) -> for k, v in state_dict.items() } ) - if issubclass(botorch_class, TransformedDistribution): - # Extract the buffered attributes for transformed priors. + if issubclass(botorch_class, Prior): + # Extract the buffered attributes for priors. Some priors (e.g. + # BetaPrior, LogNormalPrior) store parameters with BUFFERED_PREFIX + # because their underlying distribution uses @property descriptors + # that cannot be deleted by _bufferize_attributes. for k in list(state_dict.keys()): if k.startswith(BUFFERED_PREFIX): state_dict[k[len(BUFFERED_PREFIX) :]] = state_dict.pop(k) diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 7cf66b43fc4..065f00dee28 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1069,6 +1069,30 @@ def test_BadStateDict(self) -> None: del expected_json["state_dict"]["lower_bound"] botorch_component_from_json(interval.__class__, expected_json) + def test_prior_roundtrip_serialization(self) -> None: + """Test encode/decode roundtrip for priors with buffered attributes. + + Priors whose underlying distribution uses @property descriptors + (e.g. BetaPrior via Dirichlet, LogNormalPrior via TransformedDistribution) + store state_dict keys with BUFFERED_PREFIX. The decoder must strip + the prefix to match __init__ arg names. + """ + from botorch.models.utils.priors import BetaPrior + from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior, NormalPrior + + priors = [ + ("BetaPrior", BetaPrior(concentration1=2.5, concentration0=1.5)), + ("GammaPrior", GammaPrior(concentration=2.0, rate=1.0)), + ("NormalPrior", NormalPrior(loc=0.0, scale=1.0)), + ("LogNormalPrior", LogNormalPrior(loc=0.0, scale=1.0)), + ] + for name, prior in priors: + with self.subTest(prior=name): + encoded = botorch_component_to_dict(prior) + decoded = botorch_component_from_json(prior.__class__, encoded) + self.assertIsInstance(decoded, prior.__class__) + self.assertEqual(decoded.state_dict(), prior.state_dict()) + def test_observation_features_backward_compatibility(self) -> None: json = { "__type": "ObservationFeatures",