WIP: tree vectorizing transformation #3263
Conversation
Still doesn't work for nested calls to JVP.
```python
basis = jnp.eye(x.size, dtype=x.dtype)
y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
```
I think this is very close to working, except we don't handle `basis` properly yet because it has a different tree structure. (This works in other cases, but not for transformed `tree_call` functions yet.)
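For reference, the array-only version of this pattern (without `tree_call`) already works in standard JAX; it is the usual forward-mode Jacobian recipe that the snippet above generalizes:

```python
import functools
import jax
import jax.numpy as jnp

# The standard dense forward-mode Jacobian recipe: push forward each
# one-hot basis vector with jvp, batched via vmap.
def jacfwd_dense(f, x):
  pushfwd = functools.partial(jax.jvp, f, (x,))
  basis = jnp.eye(x.size, dtype=x.dtype)
  # out_axes=(None, 1): the primal output is shared across the batch,
  # while Jacobian columns are stacked along axis 1.
  y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,))
  return y, jac
```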
```python
@tree_vectorize
def f(g, x):
  return jax.grad(g)(x)
```
Gradient support needs `JaxprTrace.process_tree_call` (which seems straightforward, but tedious).
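For contrast, plain `jax.grad` already handles pytree-valued *inputs*; the missing piece here is differentiating through a function-valued argument `g` inside the transformation. A small hypothetical illustration of the former (not from this PR):

```python
import jax
import jax.numpy as jnp

# jax.grad works directly on pytree (dict) arguments: the gradient has
# the same dict structure as the input.
def loss(params):
  return jnp.sum(params['w'] ** 2) + params['b'] ** 2

grads = jax.grad(loss)({'w': jnp.array([1.0, 2.0]), 'b': jnp.array(3.0)})
```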
```python
@lu.transformation_with_aux
def _tree_batch_subtrace(main, in_dims_trees, *in_vals_trees, **params):
```
@apaszke This is starting to feel pretty repetitive, especially when we consider doing the same thing for every higher-order function/transformation in JAX (e.g., partial eval, scan, while_loop, etc.). It feels like there must be a better way to leverage existing support for wrapping trees...
A simpler alternative to google#3263
Fixes google#1012

The idea here is a simpler alternative to google#3263, based on a custom pytree (`tree_math.Vector`) rather than introducing a new final-style transformation. Eventually, it should be possible to write something like:

```python
@functools.partial(tree_math.wrap, vector_argnums=(1,))
def odeint_midpoint(f, y0, h, steps):
  F = tree_math.unwrap(f, vector_argnums=(0,))

  def step_fun(i, y):
    return y + h * F(y + h * F(y, i * h) / 2, (i + 0.5) * h)

  return lax.fori_loop(0, steps, step_fun, y0)
```

Aside from `wrap` and `unwrap`, this is exactly how you would write a simple ODE solver in JAX without support for pytrees. We currently [do something very similar](https://github.com/google/jax-cfd/blob/a92862fb757d122e7c5eee23a3b783997a2aeafe/jax_cfd/spectral/time_stepping.py#L92) for implementing ODE solvers in JAX-CFD.

The upside of this approach is that writing a custom pytree is easy and entirely decoupled from JAX's other transformations, so we get support for JAX transformations (e.g., `grad`/`jit`/`vmap`) and control flow primitives (e.g., `while_loop`) for free. The downside is that it won't be possible to use standard `jax.numpy` functions on `Vector` objects unless we make `jax.numpy` aware of `tree_math` (a la google#8381). Instead, I've added a few specialized helper functions like `where` and `maximum`. For now, this seems like an acceptable trade-off, given that the core use case for tree math is to make it easy to write new algorithms from scientific computing, for which infix arithmetic is important enough to justify minor deviations from standard `jax.numpy`.

The other change from google#3263 is that I've only written a simple `Vector` class for now, rather than attempting to support arbitrary higher-dimensional arrays. This also considerably simplifies the implementation.

In the future, we can imagine adding a `Matrix` class, for methods that need to keep track of multiple vectors (e.g., ODE solvers, GMRES, L-BFGS, Lanczos, etc).
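The custom-pytree idea can be sketched in a few lines (a hypothetical, much-reduced stand-in for `tree_math.Vector`, using only `jax.tree_util`):

```python
import jax
import jax.numpy as jnp

# A minimal sketch: arithmetic maps over pytree leaves, and registering
# the class as a pytree node is what makes jit/grad/vmap and control
# flow primitives treat it transparently.
@jax.tree_util.register_pytree_node_class
class Vector:
  def __init__(self, tree):
    self.tree = tree

  def tree_flatten(self):
    leaves, treedef = jax.tree_util.tree_flatten(self.tree)
    return leaves, treedef  # treedef is the (hashable) aux data

  @classmethod
  def tree_unflatten(cls, treedef, leaves):
    return cls(jax.tree_util.tree_unflatten(treedef, leaves))

  def __add__(self, other):
    return Vector(jax.tree_util.tree_map(
        lambda a, b: a + b, self.tree, other.tree))

  def __mul__(self, scalar):
    return Vector(jax.tree_util.tree_map(
        lambda a: a * scalar, self.tree))

v = Vector({'x': jnp.array(1.0), 'y': jnp.array([2.0, 3.0])})
w = (v + v) * 0.5  # elementwise over every leaf; w equals v
```

Because `Vector` is a registered pytree, it passes through `jax.jit` and friends without any extra machinery.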
Closing this in favor of #8504
Fixes #1012
The idea behind this transformation is that every argument to a function decorated by `@tree_vectorize` is virtually flattened and concatenated into a 1D vector, e.g., `{'x': 1, 'y': array([[2, 3]])}` becomes `array([1, 2, 3])`. However, under the covers the original arrays are still preserved with their original shapes. This has several advantages.
The hope is that this should make it considerably easier to write efficient numerical algorithms that work on pytrees (e.g., for optimization, linear equation solving or ODE solving).
Here are a few simple examples:
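As a sketch of the intended semantics, here is the same flattening done explicitly with the existing `jax.flatten_util.ravel_pytree` (this illustrates what the decorator does virtually; the transformation itself keeps the original arrays intact):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Explicitly flatten a pytree into one 1D vector, and recover it.
flat, unravel = ravel_pytree({'x': 1, 'y': jnp.array([[2, 3]])})
restored = unravel(flat)  # original pytree structure, shapes and all
```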
Even though the inputs are always treated as vectors, outputs can have any number of dimensions, e.g., producing nested pytrees-of-pytrees, as shown in this example from the unit tests:
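As a stand-in illustration (not the actual unit test), standard `jax.jacfwd` already produces such nested pytrees-of-pytrees when both the input and output are dicts:

```python
import jax
import jax.numpy as jnp

# The Jacobian of a dict-valued function of a dict argument is a
# dict-of-dicts, i.e., a pytree-of-pytrees.
def f(params):
  return {'out': params['x'] * params['y']}

jac = jax.jacfwd(f)({'x': jnp.array(2.0), 'y': jnp.array(3.0)})
# jac has the nested structure {'out': {'x': ..., 'y': ...}}
```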
TODO:

- `cg`: `jit`, `while_loop`
- `odeint`: `concatenate`, `gather`, `scatter`
- `odeint`!
- `tree_call` works with `vmap`, `jvp` and `vjp` (i.e., to implement `jacfwd` and `jacrev`)
- `Trace` methods like `process_call`, `post_process_call`
- `x + jnp.zeros(...)` and explicit broadcasting work
- `tree_vectorize` use cases that work by explicitly unraveling/concatenating
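For comparison, an explicit unraveling/concatenating baseline might look like this (a hypothetical helper built from `ravel_pytree` and `jax.scipy.sparse.linalg.cg`, not part of this PR):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.sparse.linalg import cg

# Flatten the pytree problem into a single 1D vector, solve there,
# and unflatten the result back into the original structure.
def tree_cg(matvec, b_tree):
  b_flat, unravel = ravel_pytree(b_tree)

  def matvec_flat(v_flat):
    out_flat, _ = ravel_pytree(matvec(unravel(v_flat)))
    return out_flat

  x_flat, _ = cg(matvec_flat, b_flat)
  return unravel(x_flat)
```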