-
Notifications
You must be signed in to change notification settings - Fork 226
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add experimental custom_getter support to Haiku.
We now have two methods that interact with `hk.get_parameter` in Haiku. `custom_creator`s are run _before_ parameters are created (e.g. as part of init) and can change the dtype or init function for a given parameter. Creators influence what ends up in the "params" dictionary returned by `f.init(rng, ..)`. `custom_getter`s (introduced in this change) allow you to intercept the parameter when the user calls `get_parameter` _after_ the parameter is created. The result of `custom_getter` is only passed to the caller and does not change what ends up in the `params` dict returned by `init`. As a concrete example: ```python def my_creator(next_creator, shape, dtype, init, context): print('running my_creator') # Change any of `shape`, `dtype` or `init` here. return next_creator(shape, dtype, init) def my_getter(next_getter, value, context): print('running my_getter') # Apply any changes to `value` here. return next_getter(value) def f(): with hk.experimental.custom_creator(my_creator), \ hk.experimental.custom_getter(my_getter): w = hk.get_parameter("w", [], init=jnp.zeros) w = hk.get_parameter("w", [], init=jnp.zeros) return w f = hk.transform(f, apply_rng=True) params = f.init(None) # running my_creator ParamContext(full_name='~/w', module=None) # running my_getter ParamContext(full_name='~/w', module=None) # running my_getter ParamContext(full_name='~/w', module=None) f.apply(params, None) # running my_getter ParamContext(full_name='~/w', module=None) # running my_getter ParamContext(full_name='~/w', module=None) ``` Ping #32. PiperOrigin-RevId: 308408822 Change-Id: I526d8299f75810bf2c5985eb56d274ed6e39cac6
- Loading branch information
1 parent
49b21f7
commit 67c510c
Showing
7 changed files
with
262 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.