Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly separate JAX and non-JAX data during Jittable serialization. #177

Merged
merged 1 commit into from
Jul 21, 2022

Commits on Jul 21, 2022

  1. Explicitly separate JAX and non-JAX data during Jittable serialization.

    Previously, all properties of Jittable objects were considered non-JAX metadata for the purposes of JAX serialization. This kept bookkeeping to a minimum, but had several undesirable effects:
    
    1. `jax.tree_map` treated Jittable objects as pytree nodes, but did not recur into their parameters, resulting in the Jittable being invisible to the `tree_map`.
    2. Jitted functions would fail when called with a different distribution than the one they were initially compiled for, as the JAX parameters that differ between the instances would be considered different static metadata.
    3. Modifying Jittable properties inside a jitted function resulted in a tracer, which would then be added back to the object's metadata on exiting the function, subsequently leaking into the rest of the code.
    
    More detail on these issues is reported in this Github issue: #162
    
    This change addresses the issue by modifying the way that we serialize Jittables (which includes Distributions and Bijectors) to explicitly separate all JAX data in the `self.__dict__` from any metadata such as strings and primitives.
    
    PiperOrigin-RevId: 462347605
    Jake Bruce authored and DistraxDev committed Jul 21, 2022
    Configuration menu
    Copy the full SHA
    0ecad05 View commit details
    Browse the repository at this point in the history