This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Use field kinds within tree_map #2

Closed
thomaspinder opened this issue Oct 21, 2021 · 10 comments

Comments

@thomaspinder
Contributor

Firstly, thanks for creating Treeo - it's a fantastic package.

Is there a way to use methods defined within a field's kind object within a tree_map call? For example, consider the following MWE

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    def transform(self):
        return jnp.exp(self)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )

Is there a way that I could do something similar to the following pseudocode snippet?

m = Model()
jax.tree_map(lambda x: x.transform(), to.filter(m, Parameter))
@cgarciae
Owner

Hey @thomaspinder! Thanks for the kind words.

Let me first rule out the easy solution. Would this be enough?

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to


class Parameter:
    pass


def transform(x):
    return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()
jax.tree_map(transform, to.filter(m, Parameter))

One thing to notice is that Kinds are just types that serve as metadata linked to a field; it's not expected that they will be instantiated.
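
To illustrate the idea of kinds as uninstantiated marker types, here is a minimal sketch that uses only the standard library's dataclasses rather than Treeo. The "kind" metadata key and the fields_of_kind helper are hypothetical names chosen for this example; it does not reproduce how to.field or to.filter are implemented.

```python
from dataclasses import dataclass, field, fields


class Parameter:
    """Marker type: used only as a label, never instantiated."""


class Hyperparameter:
    """A second marker type, to show that the filter is selective."""


@dataclass
class Model:
    # The kind is attached as field metadata under a hypothetical "kind" key.
    lengthscale: float = field(default=1.0, metadata={"kind": Parameter})
    learning_rate: float = field(default=0.1, metadata={"kind": Hyperparameter})


def fields_of_kind(obj, kind):
    """Return {name: value} for fields whose kind is `kind` or a subclass of it."""
    return {
        f.name: getattr(obj, f.name)
        for f in fields(obj)
        if issubclass(f.metadata.get("kind", type(None)), kind)
    }


m = Model()
params = fields_of_kind(m, Parameter)
print(params)
```

The kind classes are only ever compared with issubclass, so they need no constructor arguments or instance state.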

@thomaspinder
Contributor Author

Hey @cgarciae, thanks, but perhaps my MWE was an oversimplification. The reason for defining the transform as a method of the kind class is that there can be numerous kind classes, e.g.,

class PositiveParameter():
    def transform(self):
        return jnp.abs(self)

class NegativeParameter():
    def transform(self):
        return jnp.array(-1.) * self 

and so on...

This makes the solution you've proposed a little more tricky, as it would need an awkward mapping from each kind to its transform function.

@thomaspinder
Contributor Author

thomaspinder commented Oct 21, 2021

Based on the tidy solution you've provided in #3 , one possible solution to this problem could be the following. Do you see any issues with this?

from dataclasses import dataclass
from typing import Set

import jax
import jax.numpy as jnp
import treeo as to
from treeo.utils import field


class KindOne:
    def transform(self):
        def transform_fn(x):
            return jnp.abs(x)
        return transform_fn

class KindTwo: 
    def transform(self):
        def transform_fn(x):
            return jnp.array(-1.) * x
        return transform_fn


@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)

@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)

def unique_kinds(tree: to.Tree) -> Set[type]:
    kinds = set()

    def add_subtree_kinds(subtree: to.Tree):
        for field in subtree.field_metadata.values():
            if field.kind is not type(None):
                kinds.add(field.kind)

    to.apply(add_subtree_kinds, tree)

    # Return the set itself so the result matches the annotated return type.
    return kinds


sub_m = SubModel()
m = Model(submodel=sub_m)


for kind in unique_kinds(m):
    transform = kind().transform()
    m = to.map(transform, m, kind)

@cgarciae
Owner

@thomaspinder I was guessing you were trying to do this 😅

Here is the solution:

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import treeo as to
from treeo.utils import field


class Parameter:
    @staticmethod
    def transform(x):
        return jnp.exp(x)


@dataclass
class Model(to.Tree):
    lengthscale: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=Parameter
    )


m = Model()

with to.add_field_info():
    m2 = jax.tree_map(lambda field: field.kind.transform(field.value), to.filter(m, Parameter))

print(m2)

The add_field_info function probably needs a section in the User Guide. When a Tree is flattened inside this context manager, its leaves will all be of a type called FieldInfo, which among other things contains the kind and value attributes that you can use to achieve what you want. Note that I've changed transform to be a staticmethod.
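
As a rough illustration of why leaves that carry their own kind make per-kind dispatch easy, here is a standard-library sketch that does not use Treeo. The FieldInfo class below is a hypothetical stand-in with only name, value, and kind attributes, and leaves_with_info is an invented helper; neither is Treeo's actual implementation.

```python
import math
from dataclasses import dataclass, field, fields
from typing import Any


class Positive:
    @staticmethod
    def transform(x):
        return math.exp(x)


class Negative:
    @staticmethod
    def transform(x):
        return -math.exp(x)


@dataclass
class FieldInfo:
    """Hypothetical stand-in for a leaf that carries its value and its kind."""
    name: str
    value: Any
    kind: type


@dataclass
class Model:
    lengthscale: float = field(default=0.0, metadata={"kind": Positive})
    offset: float = field(default=0.0, metadata={"kind": Negative})


def leaves_with_info(obj):
    """Flatten a dataclass into FieldInfo leaves instead of bare values."""
    return [
        FieldInfo(f.name, getattr(obj, f.name), f.metadata["kind"])
        for f in fields(obj)
    ]


m = Model()
# Each leaf knows its own kind, so no kind-to-function lookup table is needed.
transformed = {
    leaf.name: leaf.kind.transform(leaf.value) for leaf in leaves_with_info(m)
}
print(transformed)
```

Because transform is a staticmethod, it is called on the kind class directly and the kind is never instantiated.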

My thinking is that if this pattern becomes more widespread, it would be convenient to add an add_field_info: bool argument to to.map so you could write something like this:

m2 = to.map(
    lambda field: field.kind.transform(field.value), 
    to.filter(m, Parameter), 
    add_field_info=True,
)

@cgarciae
Owner

BTW: Not sure if this is relevant to you but if you are doing something like this:

params = jax.tree_map(some_function, to.filter(m, Parameter))
m = to.merge(m, params)

You can simply use:

m = to.map(some_function, m, Parameter)
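
The equivalence between the two forms can be pictured with a toy model, sketched here in plain Python under loud assumptions: the tree is a dict of (kind, value) pairs, MISSING is a hypothetical sentinel for filtered-out leaves, and filter_kind, tree_map, merge, and map_kind are invented names that only mimic the shape of the calls above. This is not how Treeo implements them.

```python
MISSING = object()  # hypothetical sentinel for leaves removed by the filter


class Parameter:
    """Marker kind for trainable values."""


class State:
    """Marker kind for everything else."""


def filter_kind(tree, kind):
    """Keep values whose kind matches; replace the others with MISSING."""
    return {
        name: (k, v if issubclass(k, kind) else MISSING)
        for name, (k, v) in tree.items()
    }


def tree_map(fn, tree):
    """Apply fn to every value that is present, skipping MISSING ones."""
    return {
        name: (k, v if v is MISSING else fn(v)) for name, (k, v) in tree.items()
    }


def merge(base, update):
    """Take values from update where present, otherwise fall back to base."""
    return {
        name: (k, v if update[name][1] is MISSING else update[name][1])
        for name, (k, v) in base.items()
    }


def map_kind(fn, tree, kind):
    """One-step form: filter, map, and merge back in a single call."""
    return merge(tree, tree_map(fn, filter_kind(tree, kind)))


def double(x):
    return x * 2


m = {"lengthscale": (Parameter, 2.0), "step": (State, 5)}

two_step = merge(m, tree_map(double, filter_kind(m, Parameter)))
one_step = map_kind(double, m, Parameter)
print(one_step == two_step)
```

In this toy model only the Parameter leaf is changed, and the State leaf passes through untouched in both forms.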

@cgarciae
Owner

cgarciae commented Oct 21, 2021

@thomaspinder sure! That solution based on #3 works. For ergonomics you can convert transform to be a staticmethod so you don't have to instantiate the kind.

@thomaspinder
Contributor Author

Thanks so much. The solution you give using with to.add_field_info()... is the perfect solution to my problem. Adding the additional argument to to.map() would be really great - if you ever want a hand with this e.g., writing tests/documentation, then I'd be happy to help you out.

@cgarciae
Owner

@thomaspinder happy to guide you if you want to contribute 🌝
This issue looks very self-contained, so it can be a good starting point.

Ping me if you need anything.

@thomaspinder
Contributor Author

Sure! I'd be happy to contribute. Are you able to outline the main steps that I should be mindful of when doing this?

@cgarciae
Owner

cgarciae commented Nov 1, 2021

I think adding a field_info: bool argument to map and then conditionally using the add_field_info context manager over this line should be enough:

https://github.com/cgarciae/treeo/blob/master/treeo/api.py#L197

Also try to add a test 😃. Sadly we don't have a contributing document yet, but to start developing, do the following:

  1. Install poetry
  2. Run poetry install to install dependencies
  3. Run poetry shell to activate environment.
  4. Run pre-commit install to install precommit hooks.
  5. Run pytest to run tests.


2 participants