Documentation for passing in trees #2367
I suspect you could make a dictionary with the right keys/values by using the … That said, if you don't want to include an argument in a JAX transformation, the easiest way is often just to define a new function with that argument already applied, e.g.:
from functools import partial
vmap(partial(loss, dictionary))(X[i:i+batch], y[i:i+batch])
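A minimal runnable sketch of the `partial` approach. The `loss` function, the `params` dictionary and the data below are hypothetical stand-ins for the ones in the question, chosen only so the example runs end to end.

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap


def loss(params, x, y):
    # Hypothetical per-example loss: a linear model with weight 'w' and bias 'b'.
    pred = jnp.dot(params['w'], x) + params['b']
    return (pred - y) ** 2


params = {'w': jnp.ones(3), 'b': 0.5}
X = jnp.arange(12.0).reshape(4, 3)  # batch of 4 examples, 3 features each
y = jnp.zeros(4)

# Bind the dictionary first, so vmap only sees the batched arguments,
# which are both mapped over axis 0 (the default in_axes).
per_example = vmap(partial(loss, params))(X, y)
print(per_example.shape)  # (4,)
```

Because the dictionary is closed over, `vmap` never needs an axis specification for it at all.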
Thanks, that's a helpful tip!
Thanks for raising this, Sasha! We don't want to make you understand the details of pytrees; we want …
To answer your question more precisely, we might need a runnable repro. But this works:
import jax.numpy as np
from jax import vmap
dictionary = {'a': 5., 'b': np.ones(2)}
x = np.zeros(3)
y = np.arange(3.)
def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y
result = vmap(f, (None, 0, 0))(dictionary, x, y)
Yet, perhaps surprisingly, this doesn't work:
result = vmap(f, [None, 0, 0])(dictionary, x, y)
# ValueError: axes specification must be a tree prefix
# of the corresponding value, got specification [None, 0, 0]
# for value PyTreeDef(tuple, [PyTreeDef(dict[['a', 'b']], [*,*]),*,*]).
The issue is that the axis specification has to be a tree prefix of the … I think this behavior is surprising because we're so used to treating lists and tuples interchangeably in Python APIs, like we treat … I want to fix this!
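A small sketch of the tree-prefix rule, using the same `f`, `dictionary`, `x` and `y` as the repro above. The `in_axes` specification is matched against the tuple of positional arguments `(dct, x, y)`, so a single `None` in the first slot covers every leaf of the dict. The second call, with a dict of axes mirroring the argument, is an assumed-equivalent spelling that names each leaf explicitly.

```python
import jax.numpy as jnp
from jax import vmap

dictionary = {'a': 5.0, 'b': jnp.ones(2)}
x = jnp.zeros(3)
y = jnp.arange(3.0)


def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y


# One entry per positional argument; None is a prefix of the whole dict,
# so neither 'a' nor 'b' is mapped.
out1 = vmap(f, in_axes=(None, 0, 0))(dictionary, x, y)

# The same specification written leaf by leaf: the dict of axes has the same keys as the argument.
out2 = vmap(f, in_axes=({'a': None, 'b': None}, 0, 0))(dictionary, x, y)

# 3 mapped examples, each broadcast against dct['b'] of shape (2,).
print(out1.shape)  # (3, 2)
```

The per-leaf form becomes useful when some entries of the dict should be mapped and others should not.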
Oh wow. I think this error message sent me down the rabbit hole of thinking I needed to understand PyTrees :) This is a much simpler error.
I have a related question: is it possible to use … To stick with your example, how to …
I don't understand how I would have to create the …
@lukasbrauncom, can you open a new issue? I can answer your question, but with a new issue it'll be more discoverable :)
I'll ask in a separate thread, but it seems the docs are missing entries on how to create new pytrees from existing ones, e.g. to manipulate params, for example if you want to see sensitivities w.r.t. params.
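A minimal sketch of building new pytrees from existing ones, assuming a hypothetical scalar `loss(params, x)` and made-up parameter values. `jax.tree_util.tree_map` applies a function to every leaf and returns a pytree with the same structure, and `jax.grad` with respect to a dict of params returns a dict of gradients with the same keys, which is one way to read off per-parameter sensitivities.

```python
import jax
import jax.numpy as jnp


def loss(params, x):
    # Hypothetical scalar loss, used only to illustrate gradients w.r.t. a dict of params.
    return jnp.sum((params['w'] * x + params['b']) ** 2)


params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}
x = jnp.array([1.0, -1.0])

# tree_map applies the function to every leaf and returns a new dict with the same keys.
scaled = jax.tree_util.tree_map(lambda p: 0.1 * p, params)

# grad w.r.t. the first argument returns a dict with the same keys and shapes as params,
# i.e. one sensitivity array per parameter.
sens = jax.grad(loss)(params, x)

# tree_map also accepts several trees with the same structure, e.g. for a gradient step.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, sens)
```

None of this requires flattening the dict by hand; the tree utilities handle the structure.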
I would like to pass a dictionary through a vmap.
I made an attempt to understand https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html but couldn't figure out whether this is something I need to understand.
vmap(loss, [None, 0, 0])(dictionary, X[i:i+batch], y[i:i+batch])
vmap(loss, [None, 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
vmap(loss, [tree_flatten(dict_none), 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
Either way, I get an error like this:
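For reference, a sketch of what `tree_flatten` actually returns and of a version of the call above that passes the dictionary directly. The `loss` function and data here are hypothetical. `tree_flatten` gives a `(leaves, treedef)` pair rather than a dict, so passing its output to `loss` changes what `loss` receives; `vmap` flattens the dict internally, so the plain dict plus a tuple `in_axes` is enough.

```python
import jax.numpy as jnp
from jax import vmap
from jax.tree_util import tree_flatten, tree_unflatten

dictionary = {'w': jnp.ones(3), 'b': jnp.array(0.5)}

# tree_flatten returns a (leaves, treedef) pair, not a dict.
# Dict leaves come out in sorted-key order: 'b' before 'w'.
leaves, treedef = tree_flatten(dictionary)
# tree_unflatten rebuilds the original dict from that pair.
restored = tree_unflatten(treedef, leaves)


def loss(params, x, y):
    # Hypothetical per-example loss; params is expected to be the dict itself.
    return (jnp.dot(params['w'], x) + params['b'] - y) ** 2


X = jnp.ones((4, 3))
y = jnp.zeros(4)
# No manual flattening: in_axes is a tuple with one entry per positional argument.
batched = vmap(loss, in_axes=(None, 0, 0))(dictionary, X, y)
```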