Use field kinds within tree_map #2
Hey @thomaspinder! Thanks for the kind words. Let me first rule out the easy solution; I'm curious whether this is enough:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    pass


def transform(x):
    return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()
# Apply the transform only to the fields whose kind is Parameter.
jax.tree_map(transform, to.filter(m, Parameter))
```

One thing to notice is that kinds are just types that serve as metadata linked to a field; they are not expected to be instantiated.
Hey @cgarciae, thanks, but perhaps my MWE was an oversimplification. The reason for defining the transform as a method of the kind class is that there can be numerous kind classes, e.g.,
and so on. This makes the solution you've proposed a little trickier, as there would need to be some awkward function mappings.
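Attaching the transform to the kind class keeps each kind's behaviour next to its definition, so no separate kind-to-function table is needed. A minimal sketch, assuming two hypothetical kinds (`Positive`, using `exp`, and `UnitInterval`, using a sigmoid) and a plain dict of `(value, kind)` pairs rather than a Treeo model:

```python
import jax
import jax.numpy as jnp


class Positive:
    """Hypothetical kind: maps an unconstrained value to (0, inf)."""

    @staticmethod
    def transform(x):
        return jnp.exp(x)


class UnitInterval:
    """Hypothetical kind: maps an unconstrained value to (0, 1)."""

    @staticmethod
    def transform(x):
        return jax.nn.sigmoid(x)


# Field name -> (unconstrained value, kind); the kinds are never instantiated.
fields = {
    "lengthscale": (jnp.array([0.0]), Positive),
    "mixing": (jnp.array([0.0]), UnitInterval),
}

# Dispatch through the kind class itself: supporting a new kind only means
# defining another class with a `transform` staticmethod.
constrained = {name: kind.transform(value) for name, (value, kind) in fields.items()}
print(constrained)
```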
Based on the tidy solution you've provided in #3, one possible solution to this problem could be the following. Do you see any issues with this?
@thomaspinder I was guessing you were trying to do this 😅 Here is the solution:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()

# Inside this context, tree_map sees field-info objects (exposing .value and
# .kind) instead of raw leaves, so each kind's own transform can be called.
with to.add_field_info():
    m2 = jax.tree_map(lambda field: field.kind.transform(field.value), to.filter(m, Parameter))

print(m2)
```

My thought is that if this pattern becomes more widespread, it would be convenient to add an `add_field_info` argument to `to.map`:

```python
m2 = to.map(
    lambda field: field.kind.transform(field.value),
    to.filter(m, Parameter),
    add_field_info=True,
)
```
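To see why the lambda can read `field.kind` and `field.value`, here is a rough plain-JAX model of the idea: each leaf is wrapped in a small record carrying its value and kind, and `tree_map` is told via `is_leaf` to stop at those records instead of descending into them. `FieldInfo` and `map_with_field_info` are hypothetical names for illustration only; they are not Treeo's internals, and the proposed `to.map(..., add_field_info=True)` option may work differently.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)


class FieldInfo(NamedTuple):
    """Hypothetical record pairing a leaf value with its field kind."""
    value: Any
    kind: type


def map_with_field_info(f, tree):
    """Apply f to every FieldInfo record, treating each record as a leaf.

    Without is_leaf, tree_map would descend into the NamedTuple (itself a
    pytree) and call f on the raw value and on the kind separately.
    """
    return jax.tree_util.tree_map(f, tree, is_leaf=lambda x: isinstance(x, FieldInfo))


model = {"lengthscale": FieldInfo(jnp.array([1.0]), Parameter)}
m2 = map_with_field_info(lambda field: field.kind.transform(field.value), model)
print(m2)
```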
BTW: not sure if this is relevant to you, but if you are doing something like this:

```python
params = jax.tree_map(some_function, to.filter(m, Parameter))
m = to.merge(m, params)
```

you can simply use:

```python
m = to.map(some_function, m, Parameter)
```
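The shortcut works because `to.map(f, m, Parameter)` plays the same role as filtering, mapping and merging in sequence. A rough plain-JAX sketch of that equivalence, assuming a dict of `(value, kind)` pairs in place of a Treeo model; `filter_kind`, `merge` and `map_kind` are hypothetical helpers, not Treeo's implementations:

```python
import jax
import jax.numpy as jnp


class Parameter:
    pass


class Buffer:
    pass


def filter_kind(fields, kind):
    """Keep values of the given kind; other entries become None (no leaves)."""
    return {n: (v if issubclass(k, kind) else None) for n, (v, k) in fields.items()}


def merge(fields, updates):
    """Overwrite each value with its non-None entry from `updates`, if any."""
    return {
        n: ((updates[n] if updates.get(n) is not None else v), k)
        for n, (v, k) in fields.items()
    }


def map_kind(f, fields, kind):
    """Filter, map and merge in one call."""
    return merge(fields, jax.tree_util.tree_map(f, filter_kind(fields, kind)))


m = {
    "lengthscale": (jnp.array([1.0]), Parameter),
    "step": (jnp.array([3.0]), Buffer),
}

# Two-step version ...
params = jax.tree_util.tree_map(jnp.exp, filter_kind(m, Parameter))
m_two_step = merge(m, params)
# ... and the one-call version.
m_one_call = map_kind(jnp.exp, m, Parameter)
print(m_one_call)
```

Both `m_two_step` and `m_one_call` end up with `exp` applied to `lengthscale`, while `step` keeps its original value and kind.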
@thomaspinder Sure! That solution based on #3 works. For ergonomics, you can convert
Thanks so much. The solution you give using
@thomaspinder Happy to guide you if you want to contribute 🌝 Ping me if you need anything.
Sure! I'd be happy to contribute. Are you able to outline the main steps that I should be mindful of when doing this?
I think adding a https://github.com/cgarciae/treeo/blob/master/treeo/api.py#L197
Also try to add a test 😃. Sadly, we don't have a contributing document yet, but to start developing, do the following:
Firstly, thanks for creating Treeo; it's a fantastic package.
Is there a way to use methods defined within a field's `kind` object within a `tree_map` call? For example, consider the following MWE. Is there a way that I could do something similar to the following pseudocode snippet: