Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the ability for parameters to not be stored.
By default, Haiku will put the value returned from `hk.get_parameter`, `hk.get_state` and `hk.set_state` into the dictionaries returned by `init`. This is not always desirable. For example, a user may want to have part of their network come from a pretrained checkpoint, and they may want to freeze those values (aka. have them not appear in the params dict passed later to `grad`). You can achieve this by manipulating the params dict, however sometimes it is more convenient to do this using custom creators/getters/setters. Consider the following function: >>> def f(x): ... x = hk.Linear(300, name='torso')(x) ... x = hk.Linear(10, name='tail')(x) ... return x Imagine you have a pre-trained set of weights for the torso: >>> pretrained = {'torso': {'w': jnp.ones([28 * 28, 300]), ... 'b': jnp.ones([300])}} First we define a creator, that tells Haiku to not store any parameters that are part of the pretrained dict: >>> def my_creator(next_creator, shape, dtype, init, context): ... if context.module_name in pretrained: ... return hk.experimental.DO_NOT_STORE ... return next_creator(shape, dtype, init) Then we need a getter that provides the parameter value from the pretrained dict: >>> def my_getter(next_getter, value, context): ... if context.module_name in pretrained: ... assert value is hk.experimental.DO_NOT_STORE ... value = pretrained[context.module_name][context.name] ... return next_getter(value) Finally we'll wrap our function in context managers activating our creator and getter: >>> def f_with_pretrained_torso(x): ... with hk.custom_creator(my_creator), \ ... hk.custom_getter(my_getter): ... return f(x) You can see that when we run our function we only get parameters from modules that were not in the pretrained dict: >>> f_with_pretrained_torso = hk.transform(f_with_pretrained_torso) >>> rng = jax.random.PRNGKey(42) >>> x = jnp.ones([1, 28 * 28]) >>> params = f_with_pretrained_torso.init(rng, x) >>> assert list(params) == ['tail'] This value can be used in initialisers, `hk.custom_creator` or `hk.custom_setter`. PiperOrigin-RevId: 450009234
- Loading branch information
1 parent
6b0c22c
commit 2a6c034
Showing
6 changed files
with
213 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters