This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Get all unique kinds #3

Closed
thomaspinder opened this issue Oct 21, 2021 · 1 comment

@thomaspinder (Contributor)

Hi,

Is there a way that I can get a list of all the unique kinds within a nested dataclass? For example:

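```python
from dataclasses import dataclass

import jax.numpy as jnp
import treeo as to


class KindOne: pass
class KindTwo: pass


@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindOne
    )


@dataclass
class Model(to.Tree):
    submodel: SubModel  # nested Tree whose field has KindOne
    parameter: jnp.ndarray = to.field(
        default=jnp.array([1.0]), node=True, kind=KindTwo
    )


m = Model(SubModel())

m.unique_kinds()  # desired behavior, not an existing method: [KindOne, KindTwo]
```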

@cgarciae (Owner)

Interesting! There is a way to do this by combining `to.apply` with the `.field_metadata` property:

```python
from dataclasses import dataclass
from typing import Set

import jax.numpy as jnp
import treeo as to


class KindOne: pass
class KindTwo: pass

@dataclass
class SubModel(to.Tree):
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindOne)

@dataclass
class Model(to.Tree):
    submodel: SubModel
    parameter: jnp.ndarray = to.field(default=jnp.array([1.0]), node=True, kind=KindTwo)

def unique_kinds(tree: to.Tree) -> Set[type]:
    kinds = set()

    # to.apply calls this on every Tree in the pytree (Model and its SubModel).
    def add_subtree_kinds(subtree: to.Tree):
        for metadata in subtree.field_metadata.values():
            # Fields without an explicit kind report type(None); skip them.
            if metadata.kind is not type(None):
                kinds.add(metadata.kind)

    to.apply(add_subtree_kinds, tree)

    return kinds

m = Model(SubModel())

print(unique_kinds(m))  # a set containing KindOne and KindTwo (order may vary)
```

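`to.apply` calls the given function on every `Tree` within a pytree, including nested ones such as `submodel`, and `.field_metadata` is a dictionary that maps each field name to its metadata, including the field's `kind`.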
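For readers who want to see the mechanism without treeo installed, here is a minimal sketch of the same idea using only the standard library. It assumes a hypothetical convention where each field's kind is stored in `dataclasses.field(metadata=...)` under a `"kind"` key, and a small recursive walker stands in for `to.apply` by visiting nested dataclasses, lists, tuples, and dict values. This illustrates the traversal idea only; it is not treeo's implementation. It also returns a list in first-seen order, matching the `[KindOne, KindTwo]` shape from the question.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, List


class KindOne: pass
class KindTwo: pass


@dataclass
class SubModel:
    # Hypothetical convention: a field's kind is stored under metadata["kind"].
    parameter: float = field(default=1.0, metadata={"kind": KindOne})


@dataclass
class Model:
    submodel: SubModel = field(default_factory=SubModel)  # no kind of its own
    parameter: float = field(default=1.0, metadata={"kind": KindTwo})


def unique_kinds(obj: Any) -> List[type]:
    """Return the distinct kinds in obj and everything nested inside it,
    in depth-first, first-seen order."""
    seen: Dict[type, None] = {}  # dict keys dedupe while keeping insertion order

    def visit(node: Any) -> None:
        # Stand-in for to.apply: descend through common containers.
        if isinstance(node, (list, tuple)):
            for item in node:
                visit(item)
            return
        if isinstance(node, dict):
            for item in node.values():
                visit(item)
            return
        # Only dataclass instances (not the classes themselves) carry fields.
        if not is_dataclass(node) or isinstance(node, type):
            return
        for f in fields(node):
            kind = f.metadata.get("kind")
            if kind is not None:
                seen.setdefault(kind, None)
            visit(getattr(node, f.name))  # recurse into nested values

    visit(obj)
    return list(seen)


print([k.__name__ for k in unique_kinds(Model())])  # ['KindOne', 'KindTwo']
```

Collecting into a dict instead of a set keeps the result deterministic, which helps when the kinds are compared in tests or printed.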
