-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Explicitly separate JAX and non-JAX data during Jittable serialization.
Previously, all properties of Jittable objects were considered non-JAX metadata for the purposes of JAX serialization. This kept bookkeeping to a minimum, but had several undesirable effects: 1. `jax.tree_map` treated Jittable objects as pytree nodes, but did not recur into their parameters, resulting in the Jittable being invisible to the `tree_map`. 2. Jitted functions would fail when called with a different distribution than the one they were initially compiled for, as the JAX parameters that differ between the instances would be considered different static metadata. 3. Modifying Jittable properties inside a jitted function resulted in a tracer, which would then be added back to the object's metadata on exiting the function, subsequently leaking into the rest of the code. More detail on these issues is reported in this Github issue: #162 This change addresses the issue by modifying the way that we serialize Jittables (which includes Distributions and Bijectors) to explicitly separate all JAX data in the `self.__dict__` from any metadata such as strings and primitives. PiperOrigin-RevId: 462347605
- Loading branch information
Jake Bruce
authored and
DistraxDev
committed
Jul 21, 2022
1 parent
b1cbbf8
commit 0ecad05
Showing
3 changed files
with
116 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters