Recommended way to associate metadata with parameters #371
Comments
Hi @davisyoshida, I can't think of a better solution, and someone internally is using this exact pattern, so it might be useful to know that at least one other person agrees with us. One enhancement you might consider is for your type to implement `__jax_array__`:

    import chex
    import jax.numpy as jnp

    @chex.dataclass
    class Box:
        value: jnp.ndarray

        def __jax_array__(self):
            # JAX uses this hook to convert a Box to its underlying array.
            return self.value

    a = Box(value=jnp.ones([]))
    x = jnp.ones([])
    a + x  # works

Re the order of creators/getters: I understand this might be a source of issues (I suspect just for the getter), but in practice I think it is usually quite easy for folks to re-order getters in their program (and it is not typical to have lots of them), so I would hope this would only be a small amount of friction.
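The ordering concern in the comment above can be illustrated with a small sketch. It assumes getters are applied one after another, each receiving the previous one's output. That is a simplification for illustration, not a statement about how Haiku composes getters internally. `ParamWithAux`, `unwrap_getter`, `halve_getter`, and `apply_getters` are hypothetical names and are not part of Haiku or JAX.

```python
from collections import namedtuple

# Hypothetical wrapper holding a parameter value plus extra info.
ParamWithAux = namedtuple("ParamWithAux", ["value", "aux"])


def unwrap_getter(param):
    """Drop the extra info so later getters see only the raw value."""
    return param.value if isinstance(param, ParamWithAux) else param


def halve_getter(param):
    """A getter that assumes it receives a plain number."""
    return param * 0.5


def apply_getters(getters, param):
    """Run getters in sequence, each receiving the previous one's output."""
    for getter in getters:
        param = getter(param)
    return param


stored = ParamWithAux(value=4.0, aux="created by module 'linear'")

# Unwrapping first: the second getter sees a plain float.
ok = apply_getters([unwrap_getter, halve_getter], stored)

# Reversed order: the numeric getter receives the wrapper and fails.
try:
    apply_getters([halve_getter, unwrap_getter], stored)
    wrong_order_failed = False
except TypeError:
    wrong_order_failed = True
```

Under this assumption, a getter that expects raw values only works if the unwrapping getter has already run, which is why re-ordering is the suggested fix.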
Thanks for the info. I didn't know about
+1 I'd never heard of
I think it is still very experimental, but you can find the PR adding it here: google/jax#5660. I believe there have been prior attempts to land this which were rolled back (see google/jax#5356).
@tomhennigan Thanks for the pointers. Unfortunately it looks like they're debating removing it (see here: google/jax#10065), so I think I'll just avoid it for now.
I have a project where I'm using `custom_creator` to compute some auxiliary information for each parameter. I'd like to end up with two pytrees with the same structure, `params` and `aux`, then do something like `jax.tree_map(my_function, params, aux)`.
(It's not possible to compute the auxiliary info directly from the `params` pytree, as it needs some contextual information that's only available from knowing which modules are being used, etc.)
My current solution is to use a custom creator which returns a named tuple having both the params and the extra info, and a custom getter which ignores that info.
Is there a better way to accomplish this? One issue with this approach is that if any other getters or creators are added after mine, they probably won't work.
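To make the pattern in this post concrete, here is a minimal sketch of the data flow only: each stored leaf is a named tuple carrying a value and its extra info, the tree is split into two structures of the same shape, and a function is mapped over matching leaves. Plain nested dicts stand in for Haiku parameter structures, and `ParamWithAux`, `split_tree`, and `tree_map2` are hypothetical names, not Haiku or JAX APIs. In a real program the last step would presumably be the `jax.tree_map` call mentioned above.

```python
from collections import namedtuple

# Hypothetical container: what a custom creator might return per parameter.
ParamWithAux = namedtuple("ParamWithAux", ["value", "aux"])


def split_tree(tree):
    """Split a nested dict of ParamWithAux leaves into (params, aux) dicts."""
    if isinstance(tree, ParamWithAux):
        return tree.value, tree.aux
    params, aux = {}, {}
    for key, subtree in tree.items():
        params[key], aux[key] = split_tree(subtree)
    return params, aux


def tree_map2(fn, left, right):
    """Apply fn to matching leaves of two nested dicts with the same structure."""
    if isinstance(left, dict):
        return {key: tree_map2(fn, left[key], right[key]) for key in left}
    return fn(left, right)


# Stand-in for the stored state: module name -> parameter name -> leaf.
stored = {
    "linear": {
        "w": ParamWithAux(value=[1.0, 2.0], aux=0.5),
        "b": ParamWithAux(value=[3.0], aux=2.0),
    },
}

params, aux = split_tree(stored)

# Example per-leaf function: scale each parameter by its auxiliary factor.
scaled = tree_map2(lambda p, a: [v * a for v in p], params, aux)
```

Because `params` and `aux` are built in the same traversal, they share the same keys at every level, which is the property the two-tree map relies on.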