Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
WIP
This is an attempt to resolve #174. I've implemented
vmap
in basically the same style asgrad
except it supports multiple vectorized arguments. We don't includeout_axes
like Jax, but that's easy to add either with an additionaltranspose
or to propagate during the expression transformation.There are still a lot of checks and validations that need to be done:
I wanted to open this to collect feedback on the foundation before moving forward with implementing some additional batch rules. Based on this PR, we'd then have to revisit some Nx implementations to support batching. Off the top of my head I believe most of the LinAlg operators will need to support batched matrix operations.
I think we can also discuss the possibility of implementing automatic vectorization based on input names (e.g.
:batch
name is always treated as a batch dimension).