WIP: tree vectorizing transformation #3263

Closed
wants to merge 56 commits into from

Conversation

shoyer (Member) commented May 31, 2020

Fixes #1012

The idea behind this transformation is that every argument to a function decorated by @tree_vectorize is virtually flattened and concatenated into a 1D vector, e.g., {'x': 1, 'y': array([[2, 3]])} becomes array([1, 2, 3]).

However, under the covers the original arrays are still preserved with their original shapes. This has several advantages:

  1. we don't actually need to copy the arrays into a giant vector (avoiding unnecessary memory copies)
  2. all operations happen on arrays with the original shapes (useful for predictable performance on TPUs)
  3. we can restore the tree structure on the outputs of the function

The hope is that this should make it considerably easier to write efficient numerical algorithms that work on pytrees (e.g., for optimization, linear equation solving or ODE solving).
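
For intuition, the virtual flattening corresponds to what jax.flatten_util.ravel_pytree does explicitly (a sketch of the semantics only, not of how tree_vectorize is implemented):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Explicitly flatten a pytree into a single 1D vector, and recover it.
tree = {'x': 1, 'y': jnp.array([[2, 3]])}
flat, unravel = ravel_pytree(tree)
print(flat)           # [1 2 3]
print(unravel(flat))  # original tree structure and shapes restored
```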

Here are a few simple examples:

from jax import tree_vectorize
import jax.numpy as jnp

x = {'a': 1, 'b': jnp.array([2, 3])}
y = 1
add = tree_vectorize(jnp.add)
print(add(x, y))
# {'a': DeviceArray(2, dtype=int32), 'b': DeviceArray([3, 4], dtype=int32)}

@tree_vectorize
def norm(x):
  assert x.shape == (3,)
  return jnp.sqrt(jnp.sum(x ** 2))

print(norm(x))
# 3.7416575

Even though the inputs are always treated as vectors, outputs can have any number of dimensions, e.g., producing nested pytrees-of-pytrees, as shown in this example from the unit tests:

import pprint

@tree_vectorize
@shapecheck(['n', 'm'], '(n, m)')  # just to show that it works!
def add_outer(x, y):
  return jnp.expand_dims(x, 1) + jnp.expand_dims(y, 0)

tree1 = {'a': 1, 'b': jnp.array([2, 3])}
tree2 = {'c': 10, 'd': jnp.array([20, 30])}
pprint.pprint(add_outer(tree1, tree2))
# {'a': {'c': DeviceArray(11, dtype=int32),
#        'd': DeviceArray([21, 31], dtype=int32)},
#  'b': {'c': DeviceArray([12, 13], dtype=int32),
#        'd': DeviceArray([[22, 32],
#              [23, 33]], dtype=int32)}}

TODO:

  • missing ops for cg:
    • jit
    • while_loop
  • further missing ops for odeint:
    • concatenate
    • gather
    • scatter
  • measure performance improvement for odeint!
  • figure out the API, e.g.,
    • explicit API for indicating functions, especially with multiple outputs?
    • explicit API for indicating non-vectorized arguments?
  • make things more robust
    • tree_call works with vmap, jvp and vjp (i.e., to implement jacfwd and jacrev)
    • missing Trace methods like process_call, post_process_call
    • support matching against explicit axes with trivial treedefs, e.g., so x + jnp.zeros(...) and explicit broadcasting work.
  • more comprehensive tests
    • check exceptions
    • make a reference implementation for a subset of tree_vectorize use cases that works by explicitly unraveling/concatenating (see the sketch after this list)
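
A rough sketch of what such a reference implementation could look like for the single-argument case (hypothetical names, using jax.flatten_util.ravel_pytree; not code from this PR):

```python
import functools

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def tree_vectorize_reference(fun):
  """Hypothetical reference: explicitly ravel/unravel instead of tracing."""
  @functools.wraps(fun)
  def wrapper(tree):
    flat, unravel = ravel_pytree(tree)  # copy the leaves into one 1D vector
    out = fun(flat)                     # run the function on the flat vector
    # Scalars pass through unchanged; 1D outputs of matching size are
    # restored to the input's tree structure.
    return out if jnp.ndim(out) == 0 else unravel(out)
  return wrapper

# Matches the norm example above:
norm_ref = tree_vectorize_reference(lambda v: jnp.sqrt(jnp.sum(v ** 2)))
print(norm_ref({'a': 1, 'b': jnp.array([2, 3])}))  # ~3.7416575
```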

@shoyer shoyer requested a review from mattjj May 31, 2020 05:54
@shoyer shoyer mentioned this pull request May 31, 2020
@shoyer shoyer marked this pull request as ready for review June 17, 2020 23:13
@shoyer shoyer added the pull ready Ready for copybara import and testing label Jan 14, 2021
Comment on lines +500 to +501
basis = jnp.eye(x.size, dtype=x.dtype)
y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))

shoyer (Member Author):

I think this is very close to working, except we don't handle basis properly yet because it has a different tree structure. (This works in other cases, but not for transformed tree_call functions yet.)


@tree_vectorize
def f(g, x):
  return jax.grad(g)(x)

shoyer (Member Author):

Gradient support needs JaxprTrace.process_tree_call (which seems straightforward, but tedious).

Comment on lines +332 to +333
@lu.transformation_with_aux
def _tree_batch_subtrace(main, in_dims_trees, *in_vals_trees, **params):

shoyer (Member Author):

@apaszke This is starting to feel pretty repetitive, especially when we consider doing the same thing for every higher-order function/transformation in JAX (e.g., partial eval, scan, while_loop, etc.). It feels like there must be a better way to leverage existing support for wrapping trees...

shoyer added a commit to shoyer/jax that referenced this pull request Nov 10, 2021
Fixes google#1012

The idea here is a simpler alternative to google#3263, based on a custom
pytree (`tree_math.Vector`) rather than introducing a new final-style
transformation.

Eventually, it should be possible to write something like:
```python
import functools

from jax import lax
import tree_math

@functools.partial(tree_math.wrap, vector_argnums=(1,))
def odeint_midpoint(f, y0, h, steps):
    F = tree_math.unwrap(f, vector_argnums=(0,))
    def step_fun(i, y):
        return y + h * F(y + h * F(y, i * h) / 2, (i + 0.5) * h)
    return lax.fori_loop(0, steps, step_fun, y0)
```
Aside from `wrap` and `unwrap`, this is exactly how you would write a
simple ODE solver in JAX without support for PyTrees. We currently
[do something very similar](https://github.com/google/jax-cfd/blob/a92862fb757d122e7c5eee23a3b783997a2aeafe/jax_cfd/spectral/time_stepping.py#L92)
for implementing ODE solvers in JAX-CFD.

The upside of this approach is that writing a custom pytree is easy and
entirely decoupled from JAX's other transformations. So we get support
for JAX transformations (e.g., `grad`/`jit`/`vmap`) and control flow
primitives (e.g., `while_loop`) for free.

The downside is that it won't be possible to use standard `jax.numpy`
functions on `Vector` objects, unless we make `jax.numpy` aware of
`tree_math` (a la google#8381). Instead, I've added a few specialized helper
functions like `where` and `maximum`. For now, this seems like an
acceptable trade-off, given that the core use case for tree math is to
make it easy to write new algorithms for scientific computing, for
which infix arithmetic is important enough that it should easily justify
minor deviations from standard `jax.numpy`.

The other change from google#3263 is that I've only written a simple "Vector"
class for now rather than making an attempt to support arbitrary higher
dimensional arrays. This also considerably simplifies the
implementation. In the future, we can imagine adding a "Matrix" class,
for methods that need to keep track of multiple vectors (e.g., ODE
solvers, GMRES, L-BFGS, Lanczos, etc).
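
To illustrate the "for free" point above, here is a minimal sketch of such a custom pytree wrapper (a simplified stand-in for illustration, not tree_math's actual implementation): a class registered as a pytree node whose infix arithmetic maps elementwise over the wrapped tree.

```python
import operator

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Vec:
  """Toy tree-as-vector wrapper with elementwise infix arithmetic."""

  def __init__(self, tree):
    self.tree = tree

  def tree_flatten(self):
    leaves, treedef = jax.tree_util.tree_flatten(self.tree)
    return leaves, treedef

  @classmethod
  def tree_unflatten(cls, treedef, leaves):
    return cls(jax.tree_util.tree_unflatten(treedef, leaves))

  def _binop(self, op, other):
    if isinstance(other, Vec):
      return Vec(jax.tree_util.tree_map(op, self.tree, other.tree))
    return Vec(jax.tree_util.tree_map(lambda x: op(x, other), self.tree))

  def __add__(self, other):
    return self._binop(operator.add, other)

  __radd__ = __add__

  def __mul__(self, other):
    return self._binop(operator.mul, other)

  __rmul__ = __mul__

# Because Vec is an ordinary pytree, jit/grad/while_loop handle it unchanged:
y = Vec({'a': 1.0, 'b': jnp.array([2.0, 3.0])})
print(jax.jit(lambda v: 2.0 * v + v)(y).tree)  # {'a': 3.0, 'b': [6., 9.]}
```
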
shoyer (Member Author) commented Nov 10, 2021

Closing this in favor of #8504

@shoyer shoyer closed this Nov 10, 2021
@shoyer shoyer changed the title from "Tree vectorizing transformation" to "WIP: tree vectorizing transformation" Nov 10, 2021
Labels: cla: yes, pull ready (Ready for copybara import and testing)

Successfully merging this pull request may close these issues: pytree transformation

3 participants