Recommended way to associate metadata with parameters #371

Closed
davisyoshida opened this issue Apr 13, 2022 · 5 comments

@davisyoshida

I have a project where I'm using custom_creator to compute some auxiliary information for each parameter. I'd like to end up with two pytrees, params and aux, that have the same structure, and then do something like jax.tree_map(my_function, params, aux).
(It's not possible to compute the auxiliary info directly from the params pytree, as it needs some contextual information that's only available from knowing which modules are being used, etc.)
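The goal above relies on the multi-tree form of tree_map, which walks several pytrees of identical structure in lockstep and passes the matching leaves to the function together. A minimal sketch, using jax.tree_util.tree_map (the function the issue calls as jax.tree_map); the parameter layout mimics Haiku's {module_name: {param_name: array}} dicts, while the metadata values and the body of my_function are hypothetical:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters, laid out like Haiku's {module_name: {param_name: array}}.
params = {
    "linear": {"w": jnp.ones((7, 17)), "b": jnp.zeros((17,))},
}

# Hypothetical per-parameter metadata with exactly the same structure as params.
aux = {
    "linear": {"w": 2.0, "b": 3.0},
}


def my_function(param, scale):
    # Illustrative only: combine each parameter with its own metadata entry.
    return param * scale


# tree_map pairs up the leaves of params and aux position by position.
scaled = jax.tree_util.tree_map(my_function, params, aux)
```

If the two trees' structures differ, tree_map raises an error rather than guessing how to align them, which is why the aux tree has to mirror params exactly.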

My current solution is to use a custom creator which returns a named tuple having both the params and the extra info, and a custom getter which ignores that info:

from collections import namedtuple

import haiku as hk
import jax
import jax.numpy as jnp

# Each created parameter is stored together with its auxiliary info.
ParamAndAux = namedtuple('ParamAndAux', ['param', 'aux'])

def my_creator(next_creator, shape, dtype, init, context):
    # Create the parameter as usual, then attach the auxiliary info to it.
    param = next_creator(shape, dtype, init)
    aux = 12345  # Replace this with actually doing something interesting
    return ParamAndAux(param=param, aux=aux)

def my_getter(next_getter, value, context):
    # Strip the aux info so the module receives only the raw array.
    if isinstance(value, ParamAndAux):
        return value.param
    return next_getter(value)

def split(params_and_aux):
    # Inner structure: the two fields (param, aux) of each entry.
    inner_structure = jax.tree_util.tree_structure((0, 0))
    # Outer structure: the usual params nesting, stopping at each ParamAndAux.
    outer_structure = jax.tree_util.tree_structure(
        params_and_aux,
        is_leaf=lambda n: isinstance(n, ParamAndAux),
    )
    # Turn a tree of (param, aux) pairs into a pair of trees.
    return jax.tree_util.tree_transpose(
        outer_structure,
        inner_structure,
        params_and_aux)


def main():
    def f(x):
        with hk.custom_creator(my_creator), hk.custom_getter(my_getter):
            return hk.Linear(17)(x)

    model = hk.without_apply_rng(hk.transform(f))
    params_and_aux = model.init(jax.random.PRNGKey(0), jnp.zeros(7))

    params, aux = split(params_and_aux)  # params and aux end up with the same structure as desired

if __name__ == '__main__':
    main()

Is there a better way to accomplish this? One issue with this approach is that if any other getters or creators are added after mine, they probably won't work.
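The split function in the issue can be exercised without Haiku. The sketch below hand-builds a small hypothetical tree of ParamAndAux leaves (the module and parameter names and the aux values are made up) and separates it in two ways: with tree_transpose, taking the inner structure from a ParamAndAux instance so the result is itself a ParamAndAux of two trees, and with two tree_map calls that use is_leaf to stop at each ParamAndAux. It is an illustrative comparison, not an approach endorsed in the thread.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

ParamAndAux = namedtuple("ParamAndAux", ["param", "aux"])


def is_param_and_aux(node):
    # Treat each ParamAndAux as a single leaf when computing the outer structure.
    return isinstance(node, ParamAndAux)


def split_transpose(tree):
    # Outer structure: module/param nesting, with ParamAndAux entries as leaves.
    outer = jax.tree_util.tree_structure(tree, is_leaf=is_param_and_aux)
    # Inner structure: one ParamAndAux node with two leaves.
    inner = jax.tree_util.tree_structure(ParamAndAux(0, 0))
    # Returns ParamAndAux(param=<tree>, aux=<tree>).
    return jax.tree_util.tree_transpose(outer, inner, tree)


def split_tree_map(tree):
    # Pull out each field separately, stopping descent at ParamAndAux nodes.
    params = jax.tree_util.tree_map(lambda n: n.param, tree, is_leaf=is_param_and_aux)
    aux = jax.tree_util.tree_map(lambda n: n.aux, tree, is_leaf=is_param_and_aux)
    return params, aux


# Hypothetical tree in the shape Haiku's init would return with the custom creator.
tree = {
    "linear": {
        "w": ParamAndAux(param=jnp.ones((7, 17)), aux=1),
        "b": ParamAndAux(param=jnp.zeros((17,)), aux=2),
    }
}

params_t, aux_t = split_transpose(tree)
params_m, aux_m = split_tree_map(tree)
```

Either way, this assumes every aux value is an actual leaf (not None, which JAX treats as an empty subtree), so that params and aux really do end up with the same structure.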

@tomhennigan
Collaborator

Hi @davisyoshida, I can't think of a better solution. Someone internally is using this exact pattern, so it may be reassuring to know that at least one other person has settled on the same approach.

One enhancement you might consider would be for your type to implement __jax_array__. Then JAX operations would understand how to unpack it, and you might be able to avoid the custom getter:

import chex
import jax.numpy as jnp

@chex.dataclass
class Box:
  value: jnp.ndarray

  def __jax_array__(self):
    return self.value

a = Box(value=jnp.ones([]))
x = jnp.ones([])
a + x  # works

Regarding the order of creators/getters, I understand this might be a source of issues (I suspect only for the getter), but in practice I think it is usually quite easy for folks to re-order getters in their program (and it is not typical to have many of them), so I would hope this causes only a small amount of friction.

@davisyoshida
Author

Thanks for the info. I didn't know about __jax_array__ before, and haven't been able to find any documentation about it. Is there somewhere I can read about how JAX makes use of it?

@IanQS

IanQS commented Apr 23, 2022

+1 I'd never heard of __jax_array__ prior to this thread

@tomhennigan
Collaborator

I think it is still very experimental, but you can find the PR adding it here: google/jax#5660

I believe there have been prior attempts to land this which were rolled back (see google/jax#5356).

@davisyoshida
Author

@tomhennigan Thanks for the pointers. Unfortunately it looks like they're debating removing it (see here: google/jax#10065), so I think I'll just avoid it for now.
