Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hcb] Simplifications to the host_calback API #8678

Merged
merged 1 commit into from
Dec 11, 2021

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Nov 24, 2021

  • dropping support for special AD handling for hcb.id_tap and id_print.
    From now on, only the primals are tapped.

This allows us to make some significant cleanup in the internals.

@google-cla google-cla bot added the cla: yes label Nov 24, 2021
@gnecula gnecula marked this pull request as draft November 24, 2021 11:02
@gnecula gnecula self-assigned this Nov 24, 2021
@gnecula gnecula added the pull ready Ready for copybara import and testing label Nov 24, 2021
@gnecula gnecula force-pushed the hcb_simplify branch 3 times, most recently from d2ab70d to bd8236c Compare November 24, 2021 15:44
@shoyer
Copy link
Member

shoyer commented Nov 25, 2021

I guess the idea is that users should host_callback inside a custom_vjp if they want to capture values from the reverse pass?

@gnecula
Copy link
Collaborator Author

gnecula commented Nov 25, 2021

I guess the idea is that users should host_callback inside a custom_vjp if they want to capture values from the reverse pass?

Indeed, that is the idea. In fact, some folks already do this because they want to pre-process the gradients on the device before sending them to the host.

@gnecula gnecula force-pushed the hcb_simplify branch 4 times, most recently from 6ed9f07 to a3f7040 Compare December 10, 2021 13:08
@gnecula gnecula marked this pull request as ready for review December 10, 2021 13:19
* dropping support for special AD handling for hcb.id_tap and id_print.
  From now on, only the primals are tapped. The old behavior can be
  obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
  environment variale, or the --flax_host_callback_ad_transforms flag.
  Additionally, added documentation for how to implement the old behavior
  using JAX custom AD APIs.

This allows us to make some significant cleanup in the internals.
@copybara-service copybara-service bot merged commit 466eb7f into google:main Dec 11, 2021
@gnecula gnecula deleted the hcb_simplify branch December 11, 2021 09:49
copybara-service bot pushed a commit that referenced this pull request Aug 16, 2023
…ack_ad_transforms

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

PiperOrigin-RevId: 557402331
copybara-service bot pushed a commit that referenced this pull request Aug 16, 2023
…back_ad_transforms.

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557402331
copybara-service bot pushed a commit that referenced this pull request Aug 16, 2023
…back_ad_transforms.

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557402331
copybara-service bot pushed a commit that referenced this pull request Aug 16, 2023
…back_ad_transforms.

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557402331
copybara-service bot pushed a commit that referenced this pull request Aug 16, 2023
…back_ad_transforms.

This flag was added in #8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Aug 16, 2023
…back_ad_transforms.

This flag was added in google#8678 in December 2021
when we changed the behavior of host_callback to not have special handling for autodiff. Nobody is using that flag now.

This is part of a longer project to replace uses of host_callback with jax.pure_callback and jax.experimental.io_callback.

PiperOrigin-RevId: 557520668
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants