This repository has been archived by the owner on Feb 26, 2023. It is now read-only.
I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:
```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


@dataclass
class Person(to.Tree):
    height: jnp.array = to.field(node=True)  # I am a node field!
    age_static: jnp.array = to.field(node=False)  # I am a static field! I should not be updated.
    name: str = to.field(node=False)  # I am a static field!


persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen"),
]

# Stack (struct of arrays instead of list of structs)
jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)
```
However, this fails with the following exception:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 18
     11     name: str = to.field(node=False)  # I am a static field!
     13 persons = [
     14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
     15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
     16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
     17 ]
---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').
```
Versions used:
JAX: 0.3.20
Treeo: 0.0.10
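From a certain perspective this is expected: static (`node=False`) fields are not leaves but part of the pytree's auxiliary data (its treedef), and `jax.tree_map` requires every additional tree to have exactly the same treedef as the first one. Here `age_static` and `name` differ between the persons, hence the `Mismatch custom node data` error. So in this sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this with static fields.
Has anyone tried something similar and come up with a nice solution?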
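A minimal sketch of the mechanism behind the error, using plain `jax.tree_util` instead of Treeo so that it runs on its own. `StaticPerson` is a hypothetical stand-in for `Person`, and the sketch assumes (as the `Mismatch custom node data` message suggests) that Treeo keeps `node=False` fields in the treedef's auxiliary data in a similar way:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class StaticPerson:
    """Hypothetical stand-in for the Treeo `Person`: `height` is a leaf,
    while `age_static` and `name` are static (auxiliary) data."""

    def __init__(self, height, age_static, name):
        self.height = height
        self.age_static = age_static
        self.name = name

    def tree_flatten(self):
        # (children, aux_data): JAX maps over the children only; the aux data
        # becomes part of the treedef and must compare equal across trees.
        return (self.height,), (self.age_static, self.name)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


john = StaticPerson(jnp.array(1.8), 25.0, "John")
wald = StaticPerson(jnp.array(1.7), 100.0, "Wald")

# Only `height` shows up as a leaf.
print(jax.tree_util.tree_leaves(john))

# The static values differ, so the two treedefs differ...
print(jax.tree_util.tree_structure(john) == jax.tree_util.tree_structure(wald))  # False

# ...and a multi-tree map over them is rejected with a ValueError.
try:
    jax.tree_util.tree_map(lambda *v: jnp.stack(v, axis=0), john, wald)
except ValueError as err:
    print("tree_map failed:", err)
```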
I don't think you can collectively `tree_map` pytrees with different static fields (`node=False`). There are ways around this, but the outcome is ill-defined: which static value should the stacked object keep?
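One such workaround, sketched with a hypothetical `StaticPerson` stand-in and plain `jax.tree_util` (this is not a Treeo API, and `stack_keep_first_static` is a made-up helper name): stack the leaves of every tree along a new axis, then rebuild a single object from one treedef. Since only one treedef can be used, the static values of the chosen tree win; here that is the first person, which makes the ambiguity explicit:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class StaticPerson:
    """Hypothetical stand-in for the Treeo `Person`: `height` is a leaf,
    while `age_static` and `name` are static (auxiliary) data."""

    def __init__(self, height, age_static, name):
        self.height = height
        self.age_static = age_static
        self.name = name

    def tree_flatten(self):
        return (self.height,), (self.age_static, self.name)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


def stack_keep_first_static(trees):
    """Hypothetical helper: stack leaves along axis 0 and rebuild with the
    FIRST tree's treedef, so its static values are kept for the whole batch.
    Assumes all trees have the same leaf layout (not checked here)."""
    treedef = jax.tree_util.tree_structure(trees[0])
    # Leaves of every tree, each list in the same flattening order.
    all_leaves = [jax.tree_util.tree_leaves(t) for t in trees]
    # Group the i-th leaf of every tree and stack each group.
    stacked = [jnp.stack(group, axis=0) for group in zip(*all_leaves)]
    return jax.tree_util.tree_unflatten(treedef, stacked)


persons = [
    StaticPerson(jnp.array(1.8), 25.0, "John"),
    StaticPerson(jnp.array(1.7), 100.0, "Wald"),
    StaticPerson(jnp.array(2.1), 50.0, "Karen"),
]

batch = stack_keep_first_static(persons)
print(batch.height.shape)  # (3,)
print(batch.name, batch.age_static)  # John 25.0
```

If the static values genuinely differ per instance (like `age_static` here), they are arguably data rather than configuration: declaring them as node fields (`node=True`) would let them be stacked like `height`, while non-array values such as `name` cannot go through `jnp.stack` and would have to live outside the stacked object, for example in a plain Python list.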