-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make custom_transforms handle pytrees, add api.defvjp #818
Conversation
With advice from @dougalm!
Hi @mattjj , will this support batching rule for transforms with 2 arguments? Currently, we have to use the following pattern for a custom transform with 2 arguments
Basically, the code promote x,y to have the same |
It should handle arbitrary numbers of arguments according to standard vmap semantics. Do you have a test case I can try? |
Sure, for example,
|
That seems to work on this branch: it produces a zeros array of shape (3, 2). Is that the result you expect? |
Yup, that is the expected result. It is great to see that vmap is working for these custom transforms now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still reviewing, gonna send these out in case you wanna pipeline reviewing + responding...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finished reviewing the original PR, gonna look at changes now...
val_ans, grad_ans = api.value_and_grad(foo)(3., 4.) | ||
self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False) | ||
self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False) | ||
|
||
ans = api.grad(foo, (0, 1))(3., 4.) | ||
self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False) | ||
|
||
def test_defvjp_all(self): | ||
@api.custom_transforms | ||
def foo(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a test for a function that takes multiple and/or pytree arguments? (Sorry if I missed it below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(This can be a follow-up change, but I think is worth doing, at the very least to make sure we have a runnable example of multiple args.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do in a follow-up; left a todo.
@skye convinced me she was right about not having Currently that means that the api.py functions are slightly off from the corresponding ad.py versions, but we'll remedy that in a follow-up. EDIT: one other thing to mention: it's easy to add convenience variants if we decide we want them, but it's hard to remove them once added. |
Before this PR, using
custom_transforms
(from #419) with functions that took tuples/lists/dicts (i.e. pytrees) as arguments would silently fail. Moreover, defining custom rules that took as input or returned as output any pytree values wouldn't work. Finally,custom_transforms
functions didn't automatically support vmap.This PR fixes all that!
This PR also improves the API by adding
defjvp
anddefvjp
(cf. #636) functions (and their variants) to api.py, which can be called directly on acustom_transforms
function rather than having to dig out a reference to the underlying primitive. There's also acustom_gradient
convenience wrapper, which works liketf.custom_gradient
.This API will likely be the main mechanism by which users will define custom gradients.
TODO:
In the future we should add a tutorial explanation of how to do this in an .md file, like in Autograd's tutorial, at which point we can finally close #116.
It would also be good to add more tests, especially for error messages.