Create custom `PyTreeDef` objects, or adapt existing ones #13768
-
Is it possible to create custom `PyTreeDef` objects, or adapt existing ones?

Context: the following example, based on Treeo, fails:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True)  # I am a node field!
    age_static: jnp.ndarray = to.field(node=False)  # I am a static field! I should not be updated.
    name: str = to.field(node=False)  # I am a static field!


persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.0), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.0), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.0), name="Karen"),
]

# Stack (struct of arrays instead of list of structs)
jax.tree_util.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
```

This code raises an exception.
-
I've not heard of the treeo package before. There has been discussion about automatically registering dataclasses as pytrees (see #2371), but currently that registration is not automatic, so you need to register your class yourself using JAX's pytree registration mechanisms (e.g. `jax.tree_util.register_pytree_node`). Alternatively, if you believe this is something that is supposed to happen automatically when you import treeo, …
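A minimal sketch of the manual registration route, assuming a plain `dataclasses.dataclass` instead of treeo's `to.Tree`. The `Person` class here is a hypothetical stand-in for the one in the question, and `_flatten`, `_unflatten` and `stack` are illustrative helper names. Whatever `_flatten` returns as auxiliary data becomes part of the `PyTreeDef`, so differing values there produce differing treedefs. Keeping `age_static` as a child (leaf) instead of auxiliary data is a design assumption made so that people with different ages can still be stacked.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Person:
    # Hypothetical stand-in for the treeo-based class in the question.
    height: jnp.ndarray      # child: a leaf visited by tree_map
    age_static: jnp.ndarray  # child too, so its value is not baked into the treedef
    name: str                # static: stored as aux data inside the PyTreeDef


def _flatten(p):
    children = (p.height, p.age_static)
    aux_data = p.name  # compared with == when JAX checks tree structures
    return children, aux_data


def _unflatten(aux_data, children):
    height, age_static = children
    return Person(height=height, age_static=age_static, name=aux_data)


jax.tree_util.register_pytree_node(Person, _flatten, _unflatten)


def stack(trees):
    # Struct of arrays: stack matching leaves along a new leading axis.
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=0), *trees)


same_name = [
    Person(jnp.array(1.8), jnp.array(25.0), "John"),
    Person(jnp.array(1.7), jnp.array(100.0), "John"),
]
batched = stack(same_name)
print(batched.height.shape)  # (2,)
print(batched.name)          # John
```

With the three persons from the question, whose names differ, the treedefs would not match and `tree_map` would raise an error, because every additional tree passed to it must have the same structure as the first one.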
-
From the discussion at #13768 (reply in thread):

It is not possible to create or edit `PyTreeDef` objects. `jaxlib.xla_extension.pytree.PyTreeDef` objects are implemented in C++ in the TensorFlow XLA library, and they do not seem to expose any methods that would help solve this issue.
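Although a `PyTreeDef` cannot be constructed or edited directly, the public `jax.tree_util` functions let you inspect and reuse the treedef that an existing tree produces. The sketch below uses plain dicts as hypothetical stand-ins for the `Person` objects: a dict's keys are part of its treedef (much like static fields), while its values are leaves. The usual workaround for "editing" a tree is then to flatten it, transform the leaves, and unflatten with the unchanged treedef.

```python
import jax
import jax.numpy as jnp

# Dict-based stand-ins: keys belong to the PyTreeDef, values are leaves.
a = {"height": jnp.array(1.8), "age": jnp.array(25.0)}
b = {"height": jnp.array(1.7), "age": jnp.array(100.0)}
c = {"height": jnp.array(2.1), "weight": jnp.array(80.0)}

treedef_a = jax.tree_util.tree_structure(a)
print(treedef_a == jax.tree_util.tree_structure(b))  # True: same keys
print(treedef_a == jax.tree_util.tree_structure(c))  # False: different keys

# Instead of editing a PyTreeDef, flatten, transform the leaves,
# and rebuild the tree with the original treedef.
leaves, treedef = jax.tree_util.tree_flatten(a)
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled["height"])
```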