Skip to content

Conversation

@dfm
Copy link
Contributor

@dfm dfm commented May 7, 2025

Having separate primitives for initial and final style meant that we needed some duplicated logic so this should be a net win.

Most of the tests pass, but there are a couple of failures related to closed-over tracers.

All tests green after #28605 goes in.

@dfm dfm self-assigned this May 7, 2025
@dfm dfm force-pushed the consolidate-custom-vjp-primitives branch 2 times, most recently from 9ed9482 to 8941ee8 Compare May 7, 2025 19:44
@dfm dfm added the pull ready Ready for copybara import and testing label May 7, 2025
@dfm dfm force-pushed the consolidate-custom-vjp-primitives branch 4 times, most recently from df642f1 to a0cf239 Compare May 8, 2025 09:50
@dfm dfm requested a review from mattjj May 8, 2025 09:54
@dfm dfm force-pushed the consolidate-custom-vjp-primitives branch 4 times, most recently from 0132a9c to 883e49b Compare May 12, 2025 15:12
copybara-service bot pushed a commit to google/aqt that referenced this pull request May 12, 2025
In jax-ml/jax#28589, we're slightly tweaking the behavior of custom_vjp when staged out. Of relevance here, the Jaxpr parameter is now called `call_jaxpr` (for consistency with other internal higher-order primitives) instead of `fun_jaxpr`. This change supports old and new versions of JAX.

PiperOrigin-RevId: 757772896
Copy link
Collaborator

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!!!!

copybara-service bot pushed a commit to google/aqt that referenced this pull request May 12, 2025
In jax-ml/jax#28589, we're slightly tweaking the behavior of custom_vjp when staged out. Of relevance here, the Jaxpr parameter is now called `call_jaxpr` (for consistency with other internal higher-order primitives) instead of `fun_jaxpr`. This change supports old and new versions of JAX.

PiperOrigin-RevId: 757772896
copybara-service bot pushed a commit to google/aqt that referenced this pull request May 12, 2025
In jax-ml/jax#28589, we're slightly tweaking the behavior of custom_vjp when staged out. Of relevance here, the Jaxpr parameter is now called `call_jaxpr` (for consistency with other internal higher-order primitives) instead of `fun_jaxpr`. This change supports old and new versions of JAX.

PiperOrigin-RevId: 757883900
@dfm dfm force-pushed the consolidate-custom-vjp-primitives branch from 883e49b to 74938be Compare May 13, 2025 09:39
@copybara-service copybara-service bot merged commit 1ad9eae into jax-ml:main May 13, 2025
23 checks passed
copybara-service bot pushed a commit to jax-ml/jax-tpu-embedding that referenced this pull request May 15, 2025
The primitive name is changed in jax-ml/jax#28589

PiperOrigin-RevId: 759252787
copybara-service bot pushed a commit to jax-ml/jax-tpu-embedding that referenced this pull request May 15, 2025
The primitive name is changed in jax-ml/jax#28589

PiperOrigin-RevId: 759268475
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kokoro:force-run pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants