Implementing vmap over list of PyTrees #13698
Replies: 1 comment
```python
import jax
import jax.numpy as jnp

points = Point3(jnp.array([1, 2]), jnp.array([2, 4]), jnp.array([0, 1]))
jax.vmap(lambda x: x + x)(points)
```

As you mentioned, this is a struct-of-arrays pattern, whereas your example used an array-of-structs pattern. If you're unable to switch to a struct-of-arrays pattern, you won't be able to use …

Regarding the error you're seeing: it's not directly related to the struct-of-arrays issue. Rather, to be used within JAX transforms, custom PyTrees must be flexible about input validation. JAX internals may call `tree_unflatten`, and therefore your constructor, with placeholder values such as `object()` instead of arrays, so a constructor that validates or converts its inputs can fail. This issue is discussed in the JAX documentation under Custom PyTrees and Initialization.
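To make both points concrete, here is a minimal sketch of what a struct-of-arrays `Point3` might look like. It is an assumption, not the class from the question: the field names `x`, `y`, `z`, the shape check in `__init__`, and bypassing `__init__` in `tree_unflatten` via `object.__new__` are all illustrative choices. Bypassing the constructor during unflattening is one way to keep user-facing validation while still letting JAX rebuild the object from placeholder leaves.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Point3:
    """Hypothetical point stored as struct-of-arrays: x, y, z may carry a batch axis."""

    def __init__(self, x, y, z):
        x, y, z = jnp.asarray(x), jnp.asarray(y), jnp.asarray(z)
        # User-facing validation: all coordinates must share one shape.
        if not (x.shape == y.shape == z.shape):
            raise ValueError("x, y and z must have the same shape")
        self.x, self.y, self.z = x, y, z

    def __add__(self, other):
        # Overloaded operator: element-wise sum of the coordinates.
        return Point3(self.x + other.x, self.y + other.y, self.z + other.z)

    def tree_flatten(self):
        # Children are the array leaves; no static auxiliary data.
        return (self.x, self.y, self.z), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may pass non-array placeholders (e.g. object())
        # as children, and jnp.asarray in __init__ would reject them.
        obj = object.__new__(cls)
        obj.x, obj.y, obj.z = children
        return obj


# One Point3 whose fields hold a batch of two points.
points = Point3(jnp.array([1, 2]), jnp.array([2, 4]), jnp.array([0, 1]))
doubled = jax.vmap(lambda p: p + p)(points)
print(doubled.x, doubled.y, doubled.z)  # [2 4] [4 8] [0 2]
```

Keeping `__init__` free of validation altogether is the simpler alternative; the bypass only matters if you want the check to stay in place for users who construct points directly.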
Hey there! I'm currently trying to build a didactic example for a class I'm teaching on Neural Rendering.
I want to focus less on numerical and implementation optimizations so that the graphics concepts are easier to follow. At the same time, I want to illustrate some of the nice things about JAX (`jit`, `grad`, `vmap`, `pmap`, etc.). Therefore, I went for an OOP-like approach, loosely inspired by this google-research repo.

I have searched for similar issues (#5322) and am aware that a Struct-of-Arrays layout is preferred in terms of efficiency. I am just curious whether there is any sensible way of implementing this toy example without resorting to pack/unpack tricks. (I feel that manually flattening and unflattening the tree inside specific ops would defeat the purpose of the example, and that I should then just use the SoA variants directly.)

My goal for one exercise would be to `vmap` over a list of such primitives while making use of the overloaded operators. However, when trying to `vmap` that function, I get the following error:

When inspecting what gets passed from the `tree_unflatten` operation into the constructor:

The full stack trace is the following:
Sorry if I broke any of the discussion formatting rules!
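For the array-of-structs case in the question, one option is to convert the list into a single batched object before calling `vmap`. The sketch below assumes a simplified, hypothetical `Point3` whose constructor does no validation. `tree_util.tree_map` stacks matching leaves across the list, after which the overloaded `+` works unchanged inside `vmap`. This is the kind of pack/unpack step the question hoped to avoid, but it is generic over any PyTree class and can be done once at the boundary rather than inside each op.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Point3:
    """Hypothetical point class; __init__ deliberately does no validation."""

    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def __add__(self, other):
        return Point3(self.x + other.x, self.y + other.y, self.z + other.z)

    def tree_flatten(self):
        return (self.x, self.y, self.z), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Safe to call __init__ here because it accepts any values.
        return cls(*children)


# Array-of-structs: a plain Python list of individual points.
point_list = [Point3(1.0, 2.0, 0.0), Point3(2.0, 4.0, 1.0)]

# Struct-of-arrays: stack corresponding leaves so each field gains a batch axis.
batched = tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *point_list)

# The overloaded + is applied per point inside vmap.
doubled = jax.vmap(lambda p: p + p)(batched)

# Optionally split the result back into a list of individual points.
doubled_list = [tree_util.tree_map(lambda leaf: leaf[i], doubled)
                for i in range(len(point_list))]
```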