make custom_transforms handle pytrees, add api.defvjp #818

Merged: 10 commits, Jun 11, 2019

Member

@mattjj mattjj commented Jun 5, 2019

Before this PR, using custom_transforms (from #419) with functions that took tuples/lists/dicts (i.e. pytrees) as arguments would silently fail. Moreover, custom rules that took pytree values as input or returned them as output didn't work. Finally, custom_transforms functions didn't automatically support vmap.

This PR fixes all that!

This PR also improves the API by adding defjvp and defvjp (cf. #636) functions (and their variants) to api.py, which can be called directly on a custom_transforms function rather than having to dig out a reference to the underlying primitive. There's also a custom_gradient convenience wrapper, which works like tf.custom_gradient.

This API will likely be the main mechanism by which users will define custom gradients.
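
For readers on current JAX releases, where `custom_transforms` and `api.defvjp` are no longer available, the same idea can be sketched with `jax.custom_vjp`. This is an illustration under that assumption, not the API added in this PR; the function `scaled_sq` and its dict argument are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: takes a dict (a pytree) as its only argument.
@jax.custom_vjp
def scaled_sq(params):
    return params["scale"] * jnp.sum(params["x"] ** 2)

def scaled_sq_fwd(params):
    # Forward pass: return the output plus the residuals the backward pass needs.
    return scaled_sq(params), params

def scaled_sq_bwd(params, g):
    # Backward pass: return a tuple with one entry per positional argument.
    # The entry mirrors the pytree structure of `params`.
    grad_params = {
        "scale": g * jnp.sum(params["x"] ** 2),
        "x": g * params["scale"] * 2.0 * params["x"],
    }
    return (grad_params,)

scaled_sq.defvjp(scaled_sq_fwd, scaled_sq_bwd)

params = {"scale": jnp.array(3.0), "x": jnp.array([1.0, 2.0])}
grads = jax.grad(scaled_sq)(params)  # a dict with the same keys as `params`
```

The gradient comes back as a dict with the same structure as the input, which is the pytree behavior described above.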

TODO:

  • add docstrings

In the future we should add a tutorial explanation of how to do this in an .md file, like in Autograd's tutorial, at which point we can finally close #116.

It would also be good to add more tests, especially for error messages.

@fehiepsi
Member

fehiepsi commented Jun 5, 2019

Hi @mattjj, will this support batching rules for transforms with two arguments? Currently, we have to use the following pattern for a custom transform with two arguments:

def f_batching_rule(batched_args, batch_dims):
    x, y = batched_args
    bx, by = batch_dims
    # promote shapes
    sx, sy = np.shape(x), np.shape(y)
    nx = len(sx) + int(bx is None)
    ny = len(sy) + int(by is None)
    nd = max(nx, ny)
    x = np.reshape(x, (1,) * (nd - len(sx)) + sx)
    y = np.reshape(y, (1,) * (nd - len(sy)) + sy)
    # correct bx, by due to promoting
    bx = bx + nd - len(sx) if bx is not None else nd - len(sx) - 1
    by = by + nd - len(sy) if by is not None else nd - len(sy) - 1
    # move bx, by to front
    x = batching.move_dim_to_front(x, bx)
    y = batching.move_dim_to_front(y, by)
    return f(x, y), 0

Basically, the code promotes x and y to the same ndim by prepending singleton dimensions, then corrects the vmap dims to account for the promotion. Finally, it moves the vmap dims to the front.
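
The promote-then-move step can be sketched in isolation as follows. This is a minimal illustration, assuming `jnp.moveaxis` as a stand-in for `batching.move_dim_to_front`; the helper name `promote_and_move` is hypothetical and not part of JAX.

```python
import jax.numpy as jnp

def promote_and_move(x, bdim, nd):
    """Pad `x` to `nd` dims with leading singletons, then move its batch dim to axis 0.

    `bdim` is the batch axis of `x`, or None if `x` is not batched. For an
    unbatched input the caller must pick `nd` larger than `x.ndim`, as the
    rule above does via `int(bx is None)`.
    """
    shape = jnp.shape(x)
    pad = nd - len(shape)
    x = jnp.reshape(x, (1,) * pad + shape)
    # A batched axis shifts right by `pad`; an unbatched input uses the last
    # of the newly added singleton axes as its size-1 batch axis.
    bdim = bdim + pad if bdim is not None else pad - 1
    return jnp.moveaxis(x, bdim, 0)

x = jnp.ones((3,))        # batched along axis 0
y = jnp.ones((4, 3, 2))   # batched along axis 1
nd = max(x.ndim, y.ndim)  # both inputs are batched, so no extra dim is needed
xb = promote_and_move(x, 0, nd)  # shape (3, 1, 1)
yb = promote_and_move(y, 1, nd)  # shape (3, 4, 2)
out = xb * jnp.log(yb)           # broadcasts with the batch dim in front
```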

@mattjj
Member Author

mattjj commented Jun 5, 2019

It should handle arbitrary numbers of arguments according to standard vmap semantics. Do you have a test case I can try?

@fehiepsi
Member

fehiepsi commented Jun 6, 2019

Sure, for example,

import jax.numpy as np
from jax import custom_transforms, vmap

@custom_transforms
def xlogy(x, y):
    return x * np.log(y)

vmap(xlogy)(np.ones(3), np.ones((3, 2)))

@mattjj
Member Author

mattjj commented Jun 6, 2019

That seems to work on this branch: it produces a zeros array of shape (3, 2). Is that the result you expect?
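
On current JAX releases, where `custom_transforms` has been removed, a similar check can be sketched with `jax.custom_jvp`. This is a stand-in under that assumption, not the code path exercised on this branch; the hand-written JVP rule is added only so the function carries a custom rule.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def xlogy(x, y):
    return x * jnp.log(y)

# Custom JVP rule: the ordinary derivative of x * log(y), written out by hand.
@xlogy.defjvp
def xlogy_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    return xlogy(x, y), x_dot * jnp.log(y) + y_dot * x / y

# No batching rule is written by hand; vmap maps over axis 0 of both inputs.
out = jax.vmap(xlogy)(jnp.ones(3), jnp.ones((3, 2)))
```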

@fehiepsi
Member

fehiepsi commented Jun 6, 2019

Yup, that is the expected result. It is great to see that vmap is working for these custom transforms now.

Collaborator

@skye skye left a comment

Still reviewing, gonna send these out in case you wanna pipeline reviewing + responding...

Collaborator

@skye skye left a comment

Finished reviewing the original PR, gonna look at changes now...

    val_ans, grad_ans = api.value_and_grad(foo)(3., 4.)
    self.assertAllClose(val_ans, 3.**2 + 4.**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + 3. + 4., check_dtypes=False)

    ans = api.grad(foo, (0, 1))(3., 4.)
    self.assertAllClose(ans, (1. + 3. + 4., 1. * 3. * 9.), check_dtypes=False)

  def test_defvjp_all(self):
    @api.custom_transforms
    def foo(x):
Collaborator

Maybe add a test for a function that takes multiple and/or pytree arguments? (Sorry if I missed it below)

Collaborator

(This can be a follow-up change, but I think is worth doing, at the very least to make sure we have a runnable example of multiple args.)

Member Author

Will do in a follow-up; left a todo.

@mattjj
Member Author

mattjj commented Jun 11, 2019

@skye convinced me she was right about not having api.defjvp2/api.defvjp2 on API minimalism grounds (rather than readability grounds). I updated the PR to only have a defjvp and defvjp, though both act like the "2" variants in that they take ans as an argument (which can be ignored).

Currently that means that the api.py functions are slightly off from the corresponding ad.py versions, but we'll remedy that in a follow-up.

EDIT: one other thing to mention: it's easy to add convenience variants if we decide we want them, but it's hard to remove them once added.
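
The "takes ans as an argument (which can be ignored)" convention survives in current JAX as `jax.custom_jvp.defjvps`. The sketch below uses that modern API as an assumption, not the `api.defjvp` signature from this PR; the function `foo` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def foo(x, y):
    return jnp.exp(x) * y

# One rule per positional argument. Each rule receives that argument's tangent,
# then the primal output `ans`, then the primal inputs.
foo.defjvps(
    lambda x_dot, ans, x, y: x_dot * ans,         # d/dx of exp(x) * y equals ans, so reuse it
    lambda y_dot, ans, x, y: y_dot * jnp.exp(x),  # `ans` is accepted but ignored here
)

gx, gy = jax.grad(foo, argnums=(0, 1))(1.0, 2.0)
```

Passing `ans` to every rule means a single signature covers both the rules that reuse the output and the rules that don't, which is the minimalism argument made above.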

@mattjj mattjj merged commit 581d24c into master Jun 11, 2019
@mattjj mattjj deleted the custom-transforms branch October 22, 2019 19:37
Successfully merging this pull request may close these issues.

Easy api for custom primitives and vjps
5 participants