
Documentation for passing in trees #2367

Closed
srush opened this issue Mar 6, 2020 · 7 comments
Labels: bug (Something isn't working), question (Questions for the JAX team)

Comments

srush commented Mar 6, 2020

I would like to pass a dictionary through a vmap.

I made an attempt to understand https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html but couldn't figure out whether this is something I need to understand.

vmap(loss, [None, 0, 0])(dictionary, X[i:i+batch], y[i:i+batch])
vmap(loss, [None, 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
vmap(loss, [tree_flatten(dict_none), 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])

Either way, I get an error like this:

ValueError: Expected list, got (([<object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>], <object object at 0x7fcb21ef58b0>), <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>).

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/api.py in _flatten_axes(treedef, axis_tree)
    714     msg = ("axes specification must be a tree prefix of the corresponding "
    715            "value, got specification {} for value {}.")
--> 716     raise ValueError(msg.format(axis_tree, treedef))
    717   axes = [None if a is proxy else a for a in axes]
    718   assert len(axes) == treedef.num_leaves

ValueError: axes specification must be a tree prefix of the corresponding value, got specification [([], PyTreeDef(dict[['dense1', 'dense2', 'dense3']], [PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], [])])), 0, 0] for value PyTreeDef(tuple, [PyTreeDef(tuple, [PyTreeDef(list, [*,*,*,*,*,*]),*]),*,*]).

shoyer (Member) commented Mar 6, 2020

I suspect you could make a dictionary with the right keys/values by using the fromkeys constructor but it's a little hard to say for sure without a working example to test: vmap(loss, [dict.fromkeys(dictionary), 0, 0])(dictionary, X[i:i+batch], y[i:i+batch])
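A minimal sketch of the `dict.fromkeys` idea, assuming a hypothetical two-layer `loss` and parameter dict (the real `loss` and `dictionary` from the question aren't shown). `dict.fromkeys(params)` builds a dict with the same top-level keys and `None` values, which is a valid tree prefix for the parameter argument. The axes are written as a tuple, which older JAX versions required:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and loss, standing in for the ones in the question.
params = {
    "dense1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "dense2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}

def loss(params, x, y):
    h = jnp.tanh(x @ params["dense1"]["w"] + params["dense1"]["b"])
    pred = h @ params["dense2"]["w"] + params["dense2"]["b"]
    return jnp.sum((pred - y) ** 2)

# {'dense1': None, 'dense2': None}: every subtree of params is left unmapped.
param_axes = dict.fromkeys(params)

X = jnp.ones((8, 3))
Y = jnp.zeros((8, 1))

# One axes entry per positional argument, as a tuple rather than a list.
per_example = jax.vmap(loss, in_axes=(param_axes, 0, 0))(params, X, Y)
```

Here `per_example` holds one loss value per row of `X`.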

That said, if you don't want to include an argument in a JAX transformation, the easiest way is often just to define a new function with that argument already applied, e.g.,

from functools import partial
vmap(partial(loss, dictionary))(X[i:i+batch], y[i:i+batch])
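The `partial` approach can be sketched end to end as follows, with a hypothetical `loss` and `dictionary`. Because `dictionary` is baked into the function, `vmap` only sees the two batched arguments, so the default `in_axes=0` applies to both:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the loss and parameter dictionary in the question.
dictionary = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

def loss(params, x, y):
    pred = jnp.dot(params["w"], x) + params["b"]
    return (pred - y) ** 2

X = jnp.arange(10.0).reshape(5, 2)
y = jnp.ones(5)

# dictionary is fixed by partial; vmap maps X and y over their leading axis.
batch_loss = jax.vmap(partial(loss, dictionary))
out = batch_loss(X, y)
```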

srush (Author) commented Mar 6, 2020

Thanks that's a helpful tip!

mattjj added the question (Questions for the JAX team) label Mar 10, 2020

mattjj (Member) commented Mar 10, 2020

Thanks for raising this, Sasha!

We don't want to make you understand the details of pytrees; we want vmap and all the JAX transformations to "just work" on (nested) standard Python containers. In particular, you should not need to read the pytrees docs (and it does say in bold at the top, "This is primarily JAX internal documentation, end-users are not supposed to need to understand this to use JAX, except when registering new user-defined container types with JAX.")

To more precisely answer your question, we might need a runnable repro. But this works:

import jax.numpy as np
from jax import vmap

dictionary = {'a': 5., 'b': np.ones(2)}
x = np.zeros(3)
y = np.arange(3.)


def f(dct, x, y):
  return dct['a'] + dct['b'] + x + y

result = vmap(f, (None, 0, 0))(dictionary, x, y)
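To make the mapping concrete: with `in_axes=(None, 0, 0)`, `dictionary` is passed unchanged to every call while `x` and `y` are sliced along axis 0, so each call adds the scalars `x[i]` and `y[i]` to the length-2 `dct['b']`. A sketch that checks this against an explicit Python loop (using the `jnp` alias for `jax.numpy`):

```python
import jax
import jax.numpy as jnp

dictionary = {'a': 5., 'b': jnp.ones(2)}
x = jnp.zeros(3)
y = jnp.arange(3.)

def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y

result = jax.vmap(f, in_axes=(None, 0, 0))(dictionary, x, y)

# Equivalent explicit loop: one call per element of x and y, stacked on axis 0.
expected = jnp.stack([f(dictionary, x[i], y[i]) for i in range(3)])
```

Each row of `result` is `5 + 1 + x[i] + y[i]` repeated twice, giving shape `(3, 2)`.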

Yet, perhaps surprisingly (and understandably so), this doesn't work!

result = vmap(f, [None, 0, 0])(dictionary, x, y)

# ValueError: axes specification must be a tree prefix
# of the corresponding value, got specification [None, 0, 0]
# for value PyTreeDef(tuple, [PyTreeDef(dict[['a', 'b']], [*,*]),*,*]).

The issue is that the axis specification has to be a tree prefix of the args tuple. That means it is either a single leaf (an int or None, applied to every argument), or a tuple with one entry per positional argument, where each entry is itself a prefix of that argument (an int, None, or another pytree such as a dict of axes). It can't be a list!
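One way to see the tree-prefix requirement is to inspect the argument structure with `jax.tree_util.tree_structure`. The sketch below redefines `dictionary`, `x`, `y` and `f` so it runs on its own, and shows three axis specifications that all have the same outer tuple node as the args. As far as I know, the fix that closed this issue makes `vmap` convert a list `in_axes` to a tuple; converting explicitly with `tuple(...)`, as done here, works on both older and newer versions:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_structure

dictionary = {'a': 5., 'b': jnp.ones(2)}
x = jnp.zeros(3)
y = jnp.arange(3.)

def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y

args = (dictionary, x, y)
# Positional args always form a tuple; its leaves are 'a', 'b', x and y.
args_tree = tree_structure(args)

# A leaf (None) for the whole dict, or a dict of leaves matching its keys.
axes_coarse = (None, 0, 0)
axes_fine = ({'a': None, 'b': None}, 0, 0)
axes_list = [None, 0, 0]  # a list node, not the tuple node args uses

out_coarse = jax.vmap(f, in_axes=axes_coarse)(*args)
out_fine = jax.vmap(f, in_axes=axes_fine)(*args)
out_list = jax.vmap(f, in_axes=tuple(axes_list))(*args)  # explicit conversion
```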

I think this behavior is surprising because we're so used to treating lists and tuples interchangeably in Python APIs, like we treat 'foo' and "foo" string quoting interchangeably.

I want to fix this!

mattjj self-assigned this Mar 10, 2020
mattjj added the bug (Something isn't working) label Mar 10, 2020
mattjj closed this as completed in ebbcbad Mar 10, 2020

srush (Author) commented Mar 10, 2020

Oh wow. I think this error message sent me down the rabbit hole of thinking I needed to understand PyTrees :) This is a much simpler error.

srvasude pushed a commit to srvasude/jax that referenced this issue May 5, 2020

lukas-braun commented May 19, 2020

I have a related question: Is it possible to use vmap on an array that is in a dictionary? And if yes, how?

To stick with your example: how would I vmap over 'b', which is an array within a dictionary?

import jax.numpy as np
from jax import vmap

dictionary = {'a': 5., 'b': np.arange(5)}
c = 1.
d = 2.

def f(dct, x, y):
  return dct['a'] + dct['b'] + x + y  # use the x and y arguments, not the globals c and d

result = vmap(f, magic)(dictionary, c, d)

I don't understand how to construct the magic axes tuple for this example, or whether it is even possible.
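A sketch of one possible `magic`, assuming a JAX version in which `in_axes` may be any nested container that is a tree prefix of the arguments (as described for `jax.vmap`): give the dictionary argument a dict of axes that maps only `'b'`, and leave the scalars unmapped. Here `f` adds its own `x` and `y` arguments:

```python
import jax
import jax.numpy as jnp

dictionary = {'a': 5., 'b': jnp.arange(5)}
c = 1.
d = 2.

def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y

# Map axis 0 of dictionary['b']; pass 'a', c and d unchanged to every call.
magic = ({'a': None, 'b': 0}, None, None)
result = jax.vmap(f, in_axes=magic)(dictionary, c, d)
```

Each call sees one element of `'b'`, so `result` has one entry per element: `5 + b[i] + 1 + 2`.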

mattjj (Member) commented May 19, 2020

@lukasbrauncom can you open a new issue? I can answer your question, but with a new issue it'll be more discoverable :)

cottrell (Contributor) commented Sep 3, 2021

I'll ask in a separate thread, but it seems the docs are missing entries on how to create pytrees from existing ones, for example to manipulate params, such as when you want to see sensitivities w.r.t. params.
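A minimal sketch of both ideas, using `jax.tree_util.tree_map` and `jax.grad` with a hypothetical `params` dict and `model` function: `tree_map` builds a new pytree with the same structure from one or more existing ones, and `jax.grad` with respect to a dict of params returns a dict of sensitivities with matching keys:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

# Hypothetical parameters and model, for illustration only.
params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}

def model(params, x):
    return jnp.dot(params['w'], x) + params['b']

x = jnp.array([3.0, 4.0])

# Sensitivities: the gradient is a pytree with the same structure as params.
sens = jax.grad(model)(params, x)

# New pytrees built leaf by leaf from existing ones.
zeros = tree_map(jnp.zeros_like, params)
perturbed = tree_map(lambda p, g: p + 0.1 * g, params, sens)
```

Because the model is linear, `sens['w']` equals `x` and `sens['b']` is `1.0`.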

5 participants