
how to define pallas kernels usable in both fwd and bwd mode? #19146

Open
GallagherCommaJack opened this issue Dec 30, 2023 · 3 comments
Labels
enhancement New feature or request pallas Issues pertaining to Pallas (GPU or TPU)


@GallagherCommaJack
Contributor

Right now there's clear documentation on how to use custom_jvp and custom_vjp, but the automatic transposition of a custom_jvp rule isn't necessarily going to work well for Pallas kernels (is it even defined?), and a function defined with custom_vjp can't be used with jvp.
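A minimal sketch of the two options, using `jnp.sin` as a stand-in for a Pallas kernel (the names `kernel`, `f_jvp` and `f_vjp` are hypothetical). With `jax.custom_jvp`, forward mode uses the rule directly and `jax.grad` works by transposing the tangent part of the rule, which only succeeds if that part is linear in the tangent and built from transposable operations. With `jax.custom_vjp`, `jax.grad` works but `jax.jvp` raises a `TypeError`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a Pallas kernel; jnp.sin keeps the sketch runnable anywhere.
def kernel(x):
    return jnp.sin(x)

# Option A: custom_jvp. Forward mode uses the rule directly; reverse mode is
# derived by JAX, which linearizes the rule and transposes its tangent part.
@jax.custom_jvp
def f_jvp(x):
    return kernel(x)

@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (t,) = primals, tangents
    # The tangent output must be linear in `t` and built from transposable ops,
    # otherwise jax.grad cannot transpose it.
    return kernel(x), jnp.cos(x) * t

# Option B: custom_vjp. Reverse mode uses fwd/bwd; forward mode is rejected.
@jax.custom_vjp
def f_vjp(x):
    return kernel(x)

def f_vjp_fwd(x):
    return kernel(x), jnp.cos(x)  # save cos(x) as the residual

def f_vjp_bwd(res, g):
    return (res * g,)  # cotangent for x

f_vjp.defvjp(f_vjp_fwd, f_vjp_bwd)

x = jnp.float32(0.5)
print(jax.jvp(f_jvp, (x,), (jnp.float32(1.0),)))  # forward mode: works
print(jax.grad(f_jvp)(x))  # reverse mode via transposition: works
print(jax.grad(f_vjp)(x))  # reverse mode via fwd/bwd: works

try:
    jax.jvp(f_vjp, (x,), (jnp.float32(1.0),))
    jvp_of_custom_vjp_ok = True
except TypeError as err:  # forward mode through custom_vjp is not supported
    jvp_of_custom_vjp_ok = False
    print("jvp of custom_vjp failed:", err)
```

With a real Pallas kernel, the tangent in `f_jvp_rule` would itself typically be a kernel call taking `t` as input, and whether JAX can transpose that call is exactly the open question here.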

@GallagherCommaJack GallagherCommaJack added the enhancement New feature or request label Dec 30, 2023
@patrick-kidger
Collaborator

#17840 might be of relevance here, as this PR makes it possible to do jvp-of-custom_vjp.

In practice it's still a little buggy and I haven't had the chance to fix it up yet, but it might be the essence of the solution needed.
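Why jvp-of-custom_vjp is possible in principle: a `bwd` rule is linear in the cotangent, so its transpose is a JVP. The sketch below derives forward mode from a hand-written `bwd` using `jax.linear_transpose`. It is only an illustration under stated assumptions, not how #17840 is implemented: `fwd`, `bwd` and `jvp_from_bwd` are hypothetical names, `jnp.sin` stands in for the kernels, and it assumes `bwd` is written in ops JAX can transpose (whether a Pallas kernel inside `bwd` could be transposed this way is not established here).

```python
import jax
import jax.numpy as jnp

# Hypothetical custom_vjp pieces; jnp ops stand in for Pallas fwd/bwd kernels.
def fwd(x):
    return jnp.sin(x), jnp.cos(x)  # (output, residual)

def bwd(res, g):
    # Linear in the cotangent g: computes J^T g.
    return (res * g,)

def jvp_from_bwd(x, t):
    y, res = fwd(x)
    # bwd(res, .) is linear in g, so its transpose maps a tangent t to J t.
    # linear_transpose needs an example value with the cotangent's shape/dtype (y).
    bwd_t = jax.linear_transpose(lambda g: bwd(res, g), y)
    (y_dot,) = bwd_t((t,))  # input matches bwd's output structure: a 1-tuple
    return y, y_dot

x, t = jnp.float32(0.5), jnp.float32(1.0)
y, y_dot = jvp_from_bwd(x, t)
y_ref, y_dot_ref = jax.jvp(jnp.sin, (x,), (t,))
print(y_dot, y_dot_ref)  # the two tangents should agree
```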

@skye
Member

skye commented Jan 2, 2024

cc @sharadmv @apaszke @mattjj (I suspect one of you would be able to comment on this...)

@sharadmv
Collaborator

sharadmv commented Jan 2, 2024

> but the automatic transposition of custom_jvp isn't necessarily going to be very good for pallas kernels (is it even defined?) and using custom_vjp isn't compatible with jvp.

As Patrick implied, this seems to be more of a JAX issue than a Pallas-specific one. If JAX supported defining both a custom VJP and a custom JVP for the same function, this Pallas use case would work. @froystig

@superbobry superbobry added the pallas Issues pertaining to Pallas (GPU or TPU) label Mar 12, 2024