This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Stacking of Treeo.Tree #23

Closed
peterroelants opened this issue Dec 22, 2022 · 3 comments

Comments

@peterroelants

peterroelants commented Dec 22, 2022

I'm running into some issues when trying to stack a list of Treeo.Tree objects into a single object. I've made a short example:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to

@dataclass
class Person(to.Tree):
    height: jnp.ndarray = to.field(node=True)  # I am a node field!
    age_static: jnp.ndarray = to.field(node=False)  # I am a static field! I should not be updated.
    name: str = to.field(node=False)  # I am a static field!

persons = [
    Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
    Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
    Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
]

# Stack (struct of arrays instead of list of structs)
jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

However, this fails with the following exception:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 18
     11     name: str = to.field(node=False) # I am a static field!
     13 persons = [
     14     Person(height=jnp.array(1.8), age_static=jnp.array(25.), name="John"),
     15     Person(height=jnp.array(1.7), age_static=jnp.array(100.), name="Wald"),
     16     Person(height=jnp.array(2.1), age_static=jnp.array(50.), name="Karen")
     17 ]
---> 18 jax.tree_map(lambda *values: jnp.stack(values, axis=0), *persons)

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in tree_map(f, tree, is_leaf, *rest)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/workspace/lcms_polymer_model/env/env_conda_local/lcms_polymer_model_env/lib/python3.10/site-packages/jax/_src/tree_util.py:199, in <listcomp>(.0)
    166 """Maps a multi-input function over pytree args to produce a new pytree.
    167 
    168 Args:
   (...)
    196   [[5, 7, 9], [6, 1, 2]]
    197 """
    198 leaves, treedef = tree_flatten(tree, is_leaf)
--> 199 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    200 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Mismatch custom node data: {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(25., dtype=float32, weak_type=True), 'name': 'John'} != {'_field_metadata': {'height': <treeo.types.FieldMetadata object at 0x7fb8b898ba00>, 'age_static': <treeo.types.FieldMetadata object at 0x7fb8b90c0a90>, 'name': <treeo.types.FieldMetadata object at 0x7fb8b8bf9db0>, '_field_metadata': <treeo.types.FieldMetadata object at 0x7fb8b89b56f0>, '_factory_fields': <treeo.types.FieldMetadata object at 0x7fb8b89b5750>, '_default_field_values': <treeo.types.FieldMetadata object at 0x7fb8b89b5660>, '_subtrees': <treeo.types.FieldMetadata object at 0x7fb8b89b5720>}, 'age_static': DeviceArray(100., dtype=float32, weak_type=True), 'name': 'Wald'}; value: Person(height=DeviceArray(1.7, dtype=float32, weak_type=True), age_static=DeviceArray(100., dtype=float32, weak_type=True), name='Wald').

Versions used:

  • JAX: 0.3.20
  • Treeo: 0.0.10

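From a certain perspective this is expected: jax.tree_map only maps over node fields, while static (node=False) fields are stored in the PyTreeDef's auxiliary ("custom node") data. When jax.tree_map receives several trees, all of them must have the same PyTreeDef. Here the static values (age_static and name) differ between the persons, which is exactly what the "Mismatch custom node data" error reports. So in this sense, this might not really be an issue with Treeo. However, I'm looking for some guidance on how to still be able to stack objects like this with static fields.
Has anyone tried something similar and come up with a nice solution?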
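
To illustrate the explanation above without depending on Treeo internals, here is a minimal sketch that uses only jax.tree_util. The class PersonLike is a hypothetical stand-in for a treeo.Tree: height plays the role of a node field and name the role of a static field, returned as auxiliary data from tree_flatten. Assuming Treeo stores static fields the same way (the error message suggests it does), two objects with different static values produce unequal treedefs, and a multi-tree tree_map raises ValueError.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_structure


@register_pytree_node_class
class PersonLike:
    """Hypothetical stand-in for a treeo.Tree: `height` is a node field (leaf),
    `name` is a static field returned as auxiliary data of the treedef."""

    def __init__(self, height, name):
        self.height = height
        self.name = name

    def tree_flatten(self):
        # (children, aux_data): only `height` is a leaf, `name` is static.
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)


a = PersonLike(jnp.array(1.8), "John")
b = PersonLike(jnp.array(1.7), "Wald")
c = PersonLike(jnp.array(2.1), "John")

# Different static values -> different treedefs.
print(tree_structure(a) == tree_structure(b))  # False
# Identical static values -> identical treedefs.
print(tree_structure(a) == tree_structure(c))  # True


def stack(*xs):
    # Stack corresponding leaves along a new leading axis.
    return jnp.stack(xs, axis=0)


# Multi-tree tree_map requires identical treedefs, so this raises ValueError.
try:
    jax.tree_util.tree_map(stack, a, b)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# With matching static values, stacking works as intended.
stacked = jax.tree_util.tree_map(stack, a, c)
print(stacked.height.shape, stacked.name)  # (2,) John
```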

@peterroelants
Author

It seems to me that, to fix this, I would somehow need to map over PyTreeDefs. I created a question on the JAX GitHub here: google/jax#13768

@cgarciae
Owner

I don't think you can collectively tree_map Pytrees with different static fields (node=False). There are ways to work around this, but the outcome is ill-defined (which static value should be chosen?).
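
To make this point concrete, here is one such workaround, sketched with plain jax.tree_util rather than Treeo: flatten every tree, stack the leaves, and rebuild the result with the first tree's PyTreeDef. PersonLike and stack_keep_first_static are hypothetical names used only for illustration. Picking the first tree's static values is an arbitrary choice, which is exactly the ambiguity mentioned above: the result is a valid pytree, but all other static values are dropped without any error.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten


@register_pytree_node_class
class PersonLike:
    """Hypothetical stand-in for a treeo.Tree (node field `height`,
    static field `name` kept in the treedef's auxiliary data)."""

    def __init__(self, height, name):
        self.height = height
        self.name = name

    def tree_flatten(self):
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)


def stack_keep_first_static(trees):
    """Hypothetical helper: stack all leaves along a new leading axis and
    rebuild with the FIRST tree's treedef, so every static field silently
    takes the first tree's value and the others are discarded."""
    flat = [tree_flatten(t) for t in trees]
    first_leaves, first_def = flat[0]
    for leaves, _ in flat[1:]:
        # Weak sanity check: only the leaf count is compared; static data is ignored.
        if len(leaves) != len(first_leaves):
            raise ValueError("trees have different numbers of leaves")
    stacked_leaves = [
        jnp.stack(xs, axis=0) for xs in zip(*(leaves for leaves, _ in flat))
    ]
    return first_def.unflatten(stacked_leaves)


persons = [
    PersonLike(jnp.array(1.8), "John"),
    PersonLike(jnp.array(1.7), "Wald"),
    PersonLike(jnp.array(2.1), "Karen"),
]
stacked = stack_keep_first_static(persons)
print(stacked.height.shape, stacked.name)  # (3,) John
```

Applied to the original Person example, the same approach would presumably give every stacked person age_static=25 and name="John", which is why it is hard to call this a general solution.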

@peterroelants
Author

I don't think you can collectively tree_map Pytrees with different static fields (node=False).

It seems from the discussion at google/jax#13768 that this is indeed not possible.
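
Since a single treedef cannot hold several different values for the same static field, one possible alternative (an assumption, not something proposed in the thread) is to build the stacked object by hand and keep all per-item static values as a tuple. A tuple of strings is hashable, so it remains valid auxiliary data. PersonLike and stack_persons are hypothetical names, and this sketch has not been checked against Treeo itself.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PersonLike:
    """Hypothetical stand-in for a treeo.Tree with one node field (`height`)
    and one static field (`name`) kept in the treedef's auxiliary data."""

    def __init__(self, height, name):
        self.height = height
        self.name = name

    def tree_flatten(self):
        return (self.height,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)


def stack_persons(persons):
    """Hypothetical helper: stack the node field and keep ALL static values
    as a tuple, which is hashable and therefore usable as auxiliary data."""
    heights = jnp.stack([p.height for p in persons], axis=0)
    names = tuple(p.name for p in persons)
    return PersonLike(heights, names)


persons = [
    PersonLike(jnp.array(1.8), "John"),
    PersonLike(jnp.array(1.7), "Wald"),
    PersonLike(jnp.array(2.1), "Karen"),
]
stacked = stack_persons(persons)
print(stacked.height.shape, stacked.name)  # (3,) ('John', 'Wald', 'Karen')

# The stacked object is an ordinary pytree, so jax.vmap can map over its
# leaves. Note that inside the mapped function `p.name` is the whole tuple,
# not the name of the current person: static data is never batched.
doubled = jax.vmap(lambda p: PersonLike(p.height * 2, p.name))(stacked)
```

The trade-off is that the static field changes meaning (one value per item instead of one value per object), so code that reads it has to know whether it is looking at a single or a stacked object.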


2 participants