
Initial implementation of vmap #389

Closed · wants to merge 2 commits into from

Conversation

@seanmor5 (Collaborator) commented May 4, 2021

WIP

This is an attempt to resolve #174. I've implemented vmap in basically the same style as grad, except it supports multiple vectorized arguments. We don't include out_axes like JAX does, but that's easy to add, either with an additional transpose or by propagating the output axes during the expression transformation.
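To make the grad-style shape of the API concrete, here is a hypothetical usage sketch. The exact signature and the in_axes option are assumptions for illustration, not code from this PR:

```elixir
defmodule MyModel do
  import Nx.Defn

  defn predict(w, x), do: Nx.dot(x, w)

  # Hypothetical call: vectorize predict/2 over axis 0 of `xs`, leaving
  # `w` unbatched. Whether in_axes looks exactly like this is one of
  # the open questions listed below.
  defn batched_predict(w, xs) do
    vmap(&predict/2, [w, xs], in_axes: [nil, 0])
  end
end
```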

There are still a lot of checks and validations that need to be done:

  • Normalize in_axes (a sketch of one possible normalization follows this list)
  • Do we support multiple in_axes?
  • What constraints should be placed on in_axes?
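For the normalization item, a minimal sketch of one option (a hypothetical helper, not from this PR): broadcast a single axis spec across all arguments JAX-style, and validate explicit lists against the argument count.

```elixir
defmodule Vmap.Sketch do
  # Hypothetical: a single integer (or nil) applies to every argument.
  def normalize_in_axes(axis, args) when is_integer(axis) or is_nil(axis) do
    List.duplicate(axis, length(args))
  end

  # Hypothetical: an explicit list must have one entry per argument.
  def normalize_in_axes(axes, args) when is_list(axes) do
    if length(axes) != length(args) do
      raise ArgumentError,
            "expected #{length(args)} entries in in_axes, got: #{length(axes)}"
    end

    axes
  end
end
```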

I wanted to open this to collect feedback on the foundation before moving forward with implementing some additional batch rules. Building on this PR, we'd then have to revisit some Nx implementations to support batching. Off the top of my head, I believe most of the LinAlg operators will need to support batched matrix operations.

I think we can also discuss the possibility of implementing automatic vectorization based on input names (e.g. an axis named :batch is always treated as a batch dimension).

```elixir
      {res, cache}

    %{} ->
      {res, cache} = vectorize(op, args, in_axes, expr, cache)
```
Collaborator commented on the diff above:

You are missing a `traverse_args` somewhere that would effectively make this recursive. :)
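A sketch of what that suggestion might look like in the cache-miss branch. The helper names (`traverse_args`, `vectorize_expr`) and the `id` cache key are assumed from the diff context, not verified against the PR:

```elixir
    %{} ->
      # Recurse into each argument first so nested expressions are also
      # vectorized, then apply the batch rule for this op and cache it.
      {args, cache} = traverse_args(args, cache, &vectorize_expr(&1, in_axes, &2))
      {res, cache} = vectorize(op, args, in_axes, expr, cache)
      {res, Map.put(cache, id, res)}
```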

```diff
@@ -233,6 +233,13 @@ defmodule Nx.Defn.Kernel do
     grad
   end

+  @doc """
+  Vectorizes `fun` over given `args`.
```
Collaborator commented on the diff above:

Can we add some examples and more explanation? :) I am struggling to wrap my head around the current API. I assume that `fun` receives the vectorized tensor? Are there any restrictions on `args`? What are the possible values for `in_axes`?

Collaborator commented:

Perhaps handling `dot` is a good example of how the code will change with and without vmap. :)
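For instance, an illustrative sketch of the user-facing difference (not code from the PR; the vmap signature is the same assumption as above):

```elixir
# Without vmap: a plain 2-D matrix product.
a = Nx.iota({3, 4})
b = Nx.iota({4, 5})
Nx.dot(a, b)                               # shape {3, 5}

# With vmap (hypothetical API): the same dot mapped over axis 0 of
# both arguments, producing a batched result of shape {8, 3, 5}.
as = Nx.iota({8, 3, 4})
bs = Nx.iota({8, 4, 5})
vmap(&Nx.dot/2, [as, bs], in_axes: [0, 0])
```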

@josevalim (Collaborator) commented:
Just so everyone is on the same page, I talked to @seanmor5 and we will try to implement this as tensor axes metadata instead of a transform.
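Roughly, the metadata direction means the batch axis lives on the tensor itself rather than in a rewritten expression. A hypothetical sketch, not an API from this PR:

```elixir
# Hypothetical: tag the :batch axis as vectorized once, then call the
# unbatched function directly; every op would treat the tagged axis as
# an implicit batch dimension instead of requiring a transform.
xs = Nx.iota({32, 10}, names: [:batch, :x])
xs = vectorize(xs, :batch)   # assumed helper that stores axes metadata
```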

@seanmor5 seanmor5 closed this Jun 18, 2021
@seanmor5 seanmor5 deleted the sm-vmap branch June 18, 2021 16:57
Development

Successfully merging this pull request may close these issues: Support vmap