Support for custom pytrees #32
Hi @awav, thanks for trying Haiku! I think there are two Haiku assumptions that you are challenging here:
I think we can make this work. Concretely, I would suggest:
Putting that all together:

```python
import jax
import jax.numpy as jnp
import haiku as hk
from typing import NamedTuple


class S(NamedTuple):
  x: jnp.ndarray
  y: jnp.ndarray

  @property
  def shape(self):
    # Hack to work around the fact that `get_parameter` checks tensor shapes.
    return ()


class SModule(hk.Module):
  def __init__(self, x, y, name=None):
    super().__init__(name=name)
    self.s = hk.get_parameter("struct", (), None, init=lambda *_: S(x, y))

  def __call__(self, x, a):
    return jnp.sqrt(self.s.x ** 2 * self.s.y ** 2) * x * a


def loss(x):
  s = SModule(1.0, 2.0)
  a = hk.get_parameter("free", shape=(), dtype=jnp.float32, init=jnp.ones)
  y = s(x, a)
  return jnp.sum(y)


loss = hk.transform(loss)
x = jnp.array([2.0])
key = jax.random.PRNGKey(42)
params = loss.init(key, x)
jax.grad(loss.apply)(params, x)
```

Output:
If this looks good then I'm happy to make a change to … WDYT?
@tomhennigan, for a very simple case, the namedtuple approach will work. However, the main challenge is the implementation of transformed parameters.

```python
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp


class Parameter:
  def __init__(self, init_constrained_value: jnp.ndarray, constraint: tfp.bijectors.Bijector):
    # NOTE: Compute gradients w.r.t. this unconstrained value!
    self._unconstrained_value = constraint.inverse(init_constrained_value)
    self._constraint = constraint

  def constrained_value(self):
    # NOTE: Convert the value in unconstrained space back to constrained space.
    return self._constraint.forward(self._unconstrained_value)

  def __call__(self):
    return self.constrained_value()


def loss(x):
  p = Parameter(1.0, tfp.bijectors.Exp())
  return jnp.square(p())


def loss_complex(x):
  class ProbModel:
    def __init__(self):
      self.variance = Parameter(1.0, tfp.bijectors.Exp())

    def __call__(self, x):
      pass

  m = ProbModel()
  return m(x)
```

After initialization, a researcher needs information about the passed bijector for various reasons, such as monitoring or debugging an algorithm. Does that make sense? Also, I don't really like …
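For illustration, here is a minimal sketch (not something Haiku provides) of how a `Parameter` like the one above could be registered as a custom pytree via the standard `jax.tree_util.register_pytree_node`, assuming TFP-on-JAX. Gradients then flow through the unconstrained value, while the bijector rides along as static metadata:

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp


def _flatten_parameter(p):
  # The unconstrained value is the differentiable leaf; the bijector is
  # static auxiliary data.
  return (p._unconstrained_value,), p._constraint


def _unflatten_parameter(constraint, leaves):
  # Rebuild without calling __init__ (which would re-apply the inverse).
  p = object.__new__(Parameter)
  p._unconstrained_value = leaves[0]
  p._constraint = constraint
  return p


jax.tree_util.register_pytree_node(
    Parameter, _flatten_parameter, _unflatten_parameter)

# With the registration in place, JAX transforms can traverse the object,
# and gradients are taken w.r.t. the unconstrained value as intended:
p = Parameter(1.0, tfp.bijectors.Exp())
grads = jax.grad(lambda q: jnp.square(q()))(p)
```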
@sharadmv has done a lot of thinking about probabilistic programming in JAX (outside of Haiku) and might have some useful input for us here.
Absolutely.
Agreed that it is ugly looking. I like your suggestion; I think we should probably call this … Is there anything else in Haiku getting in your way for this type of research?
@tomhennigan your …
Hey @mattwescott and @awav, sorry for the delay implementing this. Before adding to core I want to think carefully about how it will interact with JAX transforms, especially when those transforms are used inside a Haiku transformed function (e.g. via …).

For now you should be able to use this without needing changes in Haiku by adding the following utility function in your code and using it in your modules (it is slightly ugly since it adds a "Box" type around your type, but otherwise this should unblock you):

```python
from typing import Any, NamedTuple


class Box(NamedTuple):
  value: Any
  shape = property(fget=lambda _: ())


def get_parameter_tree(name, init):
  return hk.get_parameter(name, [], init=lambda *_: Box(init())).value
```

You can use it like so:

```python
>>> def f():
...   p = get_parameter_tree("w", lambda: (jnp.ones([]), jnp.zeros([])))
...   return p
>>> hk.transform(f, apply_rng=True).init(None)
frozendict({
  '~': frozendict({
    'w': Box(value=(DeviceArray(1., dtype=float32), DeviceArray(0., dtype=float32))),
  }),
})
```
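Tying this back to the constrained-parameter use case: a possible (untested) helper built on `get_parameter_tree`, assuming TFP-on-JAX. The name `positive_parameter` is hypothetical, not a Haiku API:

```python
import jax.numpy as jnp
import haiku as hk
from tensorflow_probability.substrates import jax as tfp


def positive_parameter(name, init_value):
  # Hypothetical helper: store the unconstrained value in params and
  # return the constrained (positive) view to the caller.
  bijector = tfp.bijectors.Exp()
  unconstrained = get_parameter_tree(
      name, lambda: bijector.inverse(jnp.asarray(init_value)))
  return bijector.forward(unconstrained)


def loss():
  # Must run inside hk.transform, like any get_parameter call.
  variance = positive_parameter("variance", 1.0)
  return jnp.square(variance)
```

Since the stored value here is a single array, plain `hk.get_parameter` would also do; `get_parameter_tree` earns its keep once `init` returns a structure such as a tuple of arrays.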
It isn't right now; the closest we have is:

```python
>>> inits = {}
>>> def creator(next_creator, name, shape, dtype, init):
...   inits[name] = init
...   return next_creator(name, shape, dtype, init)
>>> f = lambda: hk.nets.MLP([300, 100, 10])(jnp.ones([1, 1]))
>>> f = hk.transform(f, apply_rng=True)
>>> with hk.experimental.custom_creator(creator):
...   f.init(jax.random.PRNGKey(42))
>>> inits
{'mlp/~/linear_0/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df5f8>,
 'mlp/~/linear_0/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_1/w': <haiku._src.initializers.TruncatedNormal at 0x7f283125b048>,
 'mlp/~/linear_1/b': <function jax.numpy.lax_numpy.zeros>,
 'mlp/~/linear_2/w': <haiku._src.initializers.TruncatedNormal at 0x7f28476df358>,
 'mlp/~/linear_2/b': <function jax.numpy.lax_numpy.zeros>}
```

I could imagine extending this custom getter to also pass the …
@tomhennigan thanks for the examples.
This would be great, so much cleaner!
Would it be impractical to instead intercept module creation? With a mapping from module names to types, one could use … Either approach would likely be sufficient for me to adopt Haiku.
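A rough sketch of one way to intercept module methods and build a name-to-type mapping, using `hk.intercept_methods` (which landed in Haiku after this discussion); the recording logic is illustrative, not an established recipe:

```python
import jax
import jax.numpy as jnp
import haiku as hk

module_types = {}


def record_module(next_fun, args, kwargs, context):
  # Record the owning module's name and type whenever __call__ runs.
  if context.method_name == "__call__":
    module_types[context.module.module_name] = type(context.module)
  return next_fun(*args, **kwargs)


def f(x):
  with hk.intercept_methods(record_module):
    return hk.nets.MLP([300, 100, 10])(x)


f = hk.transform(f)
params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
# module_types now maps e.g. 'mlp/~/linear_0' to haiku.Linear.
```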
We now have two methods that interact with `hk.get_parameter` in Haiku. `custom_creator`s are run _before_ parameters are created (e.g. as part of init) and can change the dtype or init function for a given parameter. Creators influence what ends up in the "params" dictionary returned by `f.init(rng, ..)`. `custom_getter`s (introduced in this change) allow you to intercept the parameter when the user calls `get_parameter` _after_ the parameter is created. The result of a `custom_getter` is only passed to the caller and does not change what ends up in the `params` dict returned by `init`.

As a concrete example:

```python
def my_creator(next_creator, shape, dtype, init, context):
  print('running my_creator', context)
  # Change any of `shape`, `dtype` or `init` here.
  return next_creator(shape, dtype, init)

def my_getter(next_getter, value, context):
  print('running my_getter', context)
  # Apply any changes to `value` here.
  return next_getter(value)

def f():
  with hk.experimental.custom_creator(my_creator), \
       hk.experimental.custom_getter(my_getter):
    w = hk.get_parameter("w", [], init=jnp.zeros)
    w = hk.get_parameter("w", [], init=jnp.zeros)
  return w

f = hk.transform(f, apply_rng=True)

params = f.init(None)
# running my_creator ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)

f.apply(params, None)
# running my_getter ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)
```

Ping #32.

PiperOrigin-RevId: 308408822
Change-Id: I526d8299f75810bf2c5985eb56d274ed6e39cac6
Support for extracting module info in a creator has landed 😄 Here's an example colab using it to extract all info to a dict outside the function: https://colab.research.google.com/drive/1tt9ifYFsxvSSXaFAz_Oq59Im8QY4S16o

Using it inside a transformed function is documented here: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.experimental.custom_creator
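A small sketch of that pattern (reconstructed here, not copied from the colab): a creator that records each parameter's context into an outer dict, using the `(next_creator, shape, dtype, init, context)` signature shown in the change above:

```python
import jax
import jax.numpy as jnp
import haiku as hk

param_info = {}


def collecting_creator(next_creator, shape, dtype, init, context):
  # `context` carries the parameter's full name and its owning module.
  param_info[context.full_name] = (type(context.module).__name__, shape, dtype)
  return next_creator(shape, dtype, init)


def f(x):
  with hk.experimental.custom_creator(collecting_creator):
    return hk.nets.MLP([300, 100, 10])(x)


f = hk.transform(f, apply_rng=True)
params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1]))
# param_info maps e.g. 'mlp/~/linear_0/w' -> ('Linear', (1, 300), dtype).
```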
@tomhennigan, I found out that …
Whoa, that `struct.dataclass` is cool, and would solve the headaches of passing modules to functions and getting …
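For context, a minimal sketch of that pattern, assuming Flax's `struct` module is available: fields of a `struct.dataclass` become pytree leaves, so whole objects pass through `jax.grad` and friends. The `Gaussian` class and penalty function are illustrative:

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class Gaussian:
  mean: jnp.ndarray
  log_scale: jnp.ndarray  # unconstrained; jnp.exp gives a positive scale


def scale_penalty(g):
  return jnp.square(jnp.exp(g.log_scale))


g = Gaussian(mean=jnp.zeros([]), log_scale=jnp.zeros([]))
grads = jax.grad(scale_penalty)(g)  # a Gaussian holding the gradients
```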
Hello, haiku team! Thanks a lot for making the awesome haiku.

I'm interested in sequential probabilistic models. Normally, the parameters of probabilistic models are constrained; a simple example is variance, which can only be positive. I gave an example and explanation of constrained parameters in #16 (comment). Pytrees fit the described use case ideally: the user can create their own differentiable "vectors", and I would expect haiku to support these custom structures out of the box. This would allow a user to get back actual structures from transformed functions for printing, debugging, and plotting purposes (the list can be extended with other examples from academic needs). Unfortunately, custom differentiable structures don't work at the moment.

Failing example: …
Thanks