No more tuples #1224

Merged: @mattjj merged 6 commits into master from no-more-tuples on Aug 21, 2019

@mattjj (Collaborator) commented Aug 21, 2019

This PR removes tuples from the jaxpr language, and JaxTuples / AbstractTuples / DeviceTuples / ShardedDeviceTuples from the implementation. Instead of tuples, jaxprs and primitives have multiple outputs.
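
To make the "multiple outputs instead of tuples" idea concrete, here is a minimal toy sketch in plain Python. It is not the code in this PR: `Eqn`, `eval_eqns`, `sincos`, and `mul` are hypothetical names invented for illustration. The point is only that each equation binds a flat list of output variables, so the interpreter never needs tuple pack/unpack equations.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class Eqn:
    """Toy equation (hypothetical): one primitive, flat inputs, flat outputs."""
    prim: Callable[..., Sequence[float]]  # always returns a sequence of results
    invars: List[str]
    outvars: List[str]


def eval_eqns(eqns, invars, outvars, args):
    """Evaluate equations in order, binding every output variable by name."""
    env = dict(zip(invars, args))
    for eqn in eqns:
        results = eqn.prim(*(env[v] for v in eqn.invars))
        assert len(results) == len(eqn.outvars)
        env.update(zip(eqn.outvars, results))
    return [env[v] for v in outvars]


def sincos(x):
    # A primitive with two outputs; no tuple value appears in the program.
    return [math.sin(x), math.cos(x)]


def mul(a, b):
    # A single-output primitive still returns a one-element list.
    return [a * b]


eqns = [
    Eqn(sincos, ["x"], ["s", "c"]),  # s, c = sincos(x)
    Eqn(mul, ["s", "c"], ["y"]),     # y = mul(s, c)
]
outs = eval_eqns(eqns, invars=["x"], outvars=["y", "c"], args=[0.0])
```

In this sketch the program itself also has several outputs (`y` and `c`), mirroring the statement that both jaxprs and primitives return multiple results.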

Here are the main reasons for removing tuples:

  1. every tracer needed its own transparent tuple type (e.g. JaxprTracerTuple, TangentTuple) so that packing into a tuple with a traced value didn't irreversibly lift quantities into a trace (e.g. so that tuple pack/unpack didn't cause broadcasting of unmapped values for a vmap trace),
  2. allowing jaxpr variables to be bound to tuples meant types (e.g. for linearity, both in the AD sense and the PRNG sense) and lattice representations (e.g. for the initial-style analog of the batched/unbatched problem above) were more complex and the logic dealing with them had recursions everywhere,
  3. relatedly, to support lax.scan's mixed linear/nonlinear extensive arguments we had to introduce pat_fmap pattern destructuring and JaxprEqn.restructure, which polluted all jaxprs and jaxpr interpreters,
  4. tuple munging wasn't actually a good way to do bookkeeping in e.g. control flow primitives and was leading to extra jaxpr munging work,
  5. tuples basically doubled the number of data types we had to support (e.g. with device-backed versions), including in things like dispatch logic,
  6. no user ever wanted to see or learn about JaxTuples, yet it was one of the first things they would encounter in api.py.
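
As a rough illustration of point 2, consider tracking which values are batched. The sketch below is a plain-Python toy, not JAX's batching code; `join_nested` and `join_flat` are hypothetical helpers. When variables can hold tuples, the bookkeeping has to mirror the nesting and recurse; with flat multiple outputs it is an elementwise operation over a list.

```python
def join_nested(a, b):
    """Join batched-ness markers that mirror a nested tuple structure."""
    if isinstance(a, tuple):
        # Both sides must have the same shape, and we must recurse into it.
        assert isinstance(b, tuple) and len(a) == len(b)
        return tuple(join_nested(x, y) for x, y in zip(a, b))
    assert not isinstance(b, tuple)
    return a or b


def join_flat(a, b):
    """Join batched-ness markers for a flat list of outputs."""
    assert len(a) == len(b)
    return [x or y for x, y in zip(a, b)]


# The same three values, once grouped as (v0, (v1, v2)) and once kept flat.
nested = join_nested((True, (False, False)), (False, (False, True)))
flat = join_flat([True, False, False], [False, False, True])
```

The flat version needs no structural checks beyond a length match, which is the kind of simplification the list above is describing.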

So this PR gets rid of them! We also took the opportunity to:

  • revise control flow (including a more conservative and efficient vmap-of-while rule, a simpler implementation of cond, a removal of all jaxpr eqn munging, better compilation caching, and more modular code)
  • make xla.py and pxla.py support the introduction of user-defined types like the rest of the system (to be used in follow-up work on PRNG key linearity checking)
  • clean up the batching.py utility functions
  • de-register None as a pytree so that users can handle it as they like (and e.g. use it as their own sentinel value).
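
The last bullet can be illustrated with a toy flattening function. This is a hypothetical stand-in written in plain Python, not `jax.tree_util`, and JAX's actual treatment of None may differ between versions. It shows the difference between treating None as an empty container (it disappears when flattening) and treating it as a leaf (it survives, so user code can see it and use it as a sentinel).

```python
def flatten(tree, none_is_node):
    """Toy pytree flattening over lists, tuples, and dicts (illustrative only)."""
    if tree is None:
        # As a node, None has no children; as a leaf, it is returned as-is.
        return [] if none_is_node else [None]
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in flatten(child, none_is_node)]
    if isinstance(tree, dict):
        # Visit keys in sorted order so the result is deterministic.
        return [leaf for key in sorted(tree)
                for leaf in flatten(tree[key], none_is_node)]
    return [tree]


params = {"bias": None, "weight": 2.0}
as_node = flatten(params, none_is_node=True)   # None vanishes from the leaves
as_leaf = flatten(params, none_is_node=False)  # None is visible to the caller
```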

I checked for performance regressions on pmap using a benchmark from @hawkinsp and @ibab. In 2 of my 3 runs the time went from ~15.5ms to ~14.5ms, and in the third run both versions took ~15ms, so I don't think this introduces a significant performance regression (it may already be faster, and in any case I'm optimistic it opens up more opportunities for improvement).

Co-authored by @dougalm, particularly the hard control flow parts. @dougalm also got us started by de-tupling jaxprs and ad.py.

Two things that will probably come in smaller follow-up PRs:

  1. look at optimizing the jit dispatch path like we did for the pmap path, specifically by special-casing the handling of DeviceArray as an inlined fast path
  2. remove Primitive.multiple_results in favor of making all primitives return multiple results (unless there are performance implications, which seems unlikely).

dougalm and others added 4 commits Jul 26, 2019

Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>

De-dup equations with multiple lhs vars when creating a jaxpr
Co-authored-by: Matthew Johnson <mattjj@google.com>

@googlebot googlebot added the cla: yes label Aug 21, 2019

@mattjj mattjj force-pushed the no-more-tuples branch from d53013f to 5fbf0c0 Aug 21, 2019

De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>

@mattjj mattjj force-pushed the no-more-tuples branch from 5fbf0c0 to b702f8d Aug 21, 2019

@mattjj mattjj requested a review from dougalm Aug 21, 2019

@mattjj mattjj merged commit a8e0c25 into master Aug 21, 2019

3 checks passed

cla/google All necessary CLAs are signed
continuous-integration/travis-ci/pr The Travis CI build passed
continuous-integration/travis-ci/push The Travis CI build passed

@mattjj mattjj deleted the no-more-tuples branch Aug 21, 2019

@mattjj (Collaborator, Author) commented Aug 21, 2019

b702f8d was the result of a rebase so that the commit history would look like this:

[image: commit history after the rebase]
