Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly separate JAX and non-JAX data during Jittable serialization. #177

Merged
merged 1 commit into from
Jul 21, 2022

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Jul 8, 2022

Explicitly separate JAX and non-JAX data during Jittable serialization.

Previously, all properties of Jittable objects were considered non-JAX metadata for the purposes of JAX serialization. This kept bookkeeping to a minimum, but had several undesirable effects:

  1. jax.tree_map treated Jittable objects as pytree nodes, but did not recur into their parameters, resulting in the Jittable being invisible to the tree_map.
  2. Jitted functions would fail when called with a different distribution than the one they were initially compiled for, as the JAX parameters that differ between the instances would be considered different static metadata.
  3. Modifying Jittable properties inside a jitted function resulted in a tracer, which would then be added back to the object's metadata on exiting the function, subsequently leaking into the rest of the code.

More detail on these issues is reported in this Github issue: #162

This change addresses the issue by modifying the way that we serialize Jittables (which includes Distributions and Bijectors) to explicitly separate all JAX data in the self.__dict__ from any metadata such as strings and primitives.

@copybara-service copybara-service bot force-pushed the test_458051922 branch 4 times, most recently from 62bf00e to 4359409 Compare July 13, 2022 11:49
@copybara-service copybara-service bot force-pushed the test_458051922 branch 2 times, most recently from 93cb123 to 68655b1 Compare July 21, 2022 10:08
Previously, all properties of Jittable objects were considered non-JAX metadata for the purposes of JAX serialization. This kept bookkeeping to a minimum, but had several undesirable effects:

1. `jax.tree_map` treated Jittable objects as pytree nodes, but did not recur into their parameters, resulting in the Jittable being invisible to the `tree_map`.
2. Jitted functions would fail when called with a different distribution than the one they were initially compiled for, as the JAX parameters that differ between the instances would be considered different static metadata.
3. Modifying Jittable properties inside a jitted function resulted in a tracer, which would then be added back to the object's metadata on exiting the function, subsequently leaking into the rest of the code.

More detail on these issues is reported in this Github issue: #162

This change addresses the issue by modifying the way that we serialize Jittables (which includes Distributions and Bijectors) to explicitly separate all JAX data in the `self.__dict__` from any metadata such as strings and primitives.

PiperOrigin-RevId: 462347605
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

0 participants