Skip to content

Commit

Permalink
Add the ability for parameters to not be stored.
Browse files Browse the repository at this point in the history
By default, Haiku will put the value returned from `hk.get_parameter`,
`hk.get_state` and `hk.set_state` into the dictionaries returned by `init`. This
is not always desirable.

For example, a user may want to have part of their network come from a
pretrained checkpoint, and they may want to freeze those values (aka. have them
not appear in the params dict passed later to `grad`). You can achieve this by
manipulating the params dict, however sometimes it is more convenient to do this
using custom creators/getters/setters.

Consider the following function:

>>> def f(x):
...   x = hk.Linear(300, name='torso')(x)
...   x = hk.Linear(10, name='tail')(x)
...   return x

Imagine you have a pre-trained set of weights for the torso:

>>> pretrained = {'torso': {'w': jnp.ones([28 * 28, 300]),
...                         'b': jnp.ones([300])}}

First we define a creator, that tells Haiku to not store any parameters that are
part of the pretrained dict:

>>> def my_creator(next_creator, shape, dtype, init, context):
...   if context.module_name in pretrained:
...     return hk.experimental.DO_NOT_STORE
...   return next_creator(shape, dtype, init)

Then we need a getter that provides the parameter value from the pretrained
dict:

>>> def my_getter(next_getter, value, context):
...   if context.module_name in pretrained:
...     assert value is hk.experimental.DO_NOT_STORE
...     value = pretrained[context.module_name][context.name]
...   return next_getter(value)

Finally we'll wrap our function in context managers activating our creator and
getter:

>>> def f_with_pretrained_torso(x):
...   with hk.custom_creator(my_creator), \
...        hk.custom_getter(my_getter):
...     return f(x)

You can see that when we run our function we only get parameters from modules
that were not in the pretrained dict:

>>> f_with_pretrained_torso = hk.transform(f_with_pretrained_torso)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 28 * 28])
>>> params = f_with_pretrained_torso.init(rng, x)
>>> assert list(params) == ['tail']

This value can be used in initialisers, `hk.custom_creator` or
`hk.custom_setter`.

PiperOrigin-RevId: 450009234
  • Loading branch information
tomhennigan authored and Copybara-Service committed May 20, 2022
1 parent 6b0c22c commit 2a6c034
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 24 deletions.
6 changes: 6 additions & 0 deletions docs/api.rst
Expand Up @@ -1117,6 +1117,7 @@ Managing State
transparent_lift_with_state
LiftWithStateUpdater
check_jax_usage
DO_NOT_STORE

current_name
~~~~~~~~~~~~
Expand Down Expand Up @@ -1158,6 +1159,11 @@ check_jax_usage

.. autofunction:: check_jax_usage

DO_NOT_STORE
~~~~~~~~~~~~

.. autoclass:: DO_NOT_STORE

Optimizations
-------------

Expand Down
159 changes: 135 additions & 24 deletions haiku/_src/base.py
Expand Up @@ -37,6 +37,12 @@
from jax import core as jax_core
import jax.numpy as jnp

try:
from typing import final # pylint: disable=g-import-not-at-top
except ImportError:
# Pre Python 3.8.
final = lambda cls: cls

DEFAULT_PRNG_RESERVE_SIZE = 1

Stack = data_structures.Stack
Expand Down Expand Up @@ -352,6 +358,90 @@ def params_frozen():
return current_frame().params_frozen


@final
class DoNotStore:
r"""Causes a parameter or state value to not be stored.
By default, Haiku will put the value returned from
:func:`~haiku.get_parameter`, :func:`~haiku.get_state` and
:func:`~haiku.set_state` into the dictionaries returned by ``init``. This is
not always desirable.
For example, a user may want to have part of their network come from a
pretrained checkpoint, and they may want to freeze those values (aka. have
them not appear in the params dict passed later to ``grad``). You can achieve
this by manipulating the params dict, however sometimes it is more convenient
to do this using custom creators/getters/setters.
Consider the following function:
>>> def f(x):
... x = hk.Linear(300, name='torso')(x)
... x = hk.Linear(10, name='tail')(x)
... return x
Imagine you have a pre-trained set of weights for the torso:
>>> pretrained = {'torso': {'w': jnp.ones([28 * 28, 300]),
... 'b': jnp.ones([300])}}
First we define a creator, that tells Haiku to not store any parameters that
are part of the pretrained dict:
>>> def my_creator(next_creator, shape, dtype, init, context):
... if context.module_name in pretrained:
... return hk.experimental.DO_NOT_STORE
... return next_creator(shape, dtype, init)
Then we need a getter that provides the parameter value from the pretrained
dict:
>>> def my_getter(next_getter, value, context):
... if context.module_name in pretrained:
... assert value is hk.experimental.DO_NOT_STORE
... value = pretrained[context.module_name][context.name]
... return next_getter(value)
Finally we'll wrap our function in context managers activating our creator and
getter:
>>> def f_with_pretrained_torso(x):
... with hk.custom_creator(my_creator), \
... hk.custom_getter(my_getter):
... return f(x)
You can see that when we run our function we only get parameters from modules
that were not in the pretrained dict:
>>> f_with_pretrained_torso = hk.transform(f_with_pretrained_torso)
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.ones([1, 28 * 28])
>>> params = f_with_pretrained_torso.init(rng, x)
>>> assert list(params) == ['tail']
This value can be used in initialisers, :func:`~haiku.custom_creator` or
:func:`~haiku.custom_setter`.
"""

@property
def shape(self):
raise ValueError("DO_NOT_STORE does not have a shape.")

@property
def dtype(self):
raise ValueError("DO_NOT_STORE does not have a dtype.")

DO_NOT_STORE = DoNotStore()

T = TypeVar("T")


def check_not_none(value: Optional[T], msg: str) -> T:
if value is None:
raise ValueError(msg)
return value


def get_parameter(
name: str,
shape: Sequence[int],
Expand Down Expand Up @@ -383,49 +473,62 @@ def get_parameter(
assert_context("get_parameter")
assert_jax_usage("get_parameter")

if init is None:
raise ValueError("Initializer must be specified.")
init = check_not_none(init, "Initializer must be specified.")

bundle_name = current_bundle_name()
frame = current_frame()

if frame.params_frozen and bundle_name not in frame.params:
raise ValueError(
"Unable to retrieve parameter {!r} for module {!r}. "
"All parameters must be created as part of `init`.".format(
name, bundle_name))

params = frame.params[bundle_name]
param = params.get(name)
fq_name = bundle_name + "/" + name
context = GetterContext(full_name=fq_name, module=current_module(),
original_dtype=dtype, original_shape=shape)
if param is None:
if frame.params_frozen:
raise ValueError(
"Unable to retrieve parameter {!r} for module {!r}. "
"All parameters must be created as part of `init`.".format(
name, bundle_name))
original_dtype=dtype, original_shape=shape,
original_init=init)

if bundle_name not in frame.params:
param = None
else:
param = frame.params[bundle_name].get(name)

if param is None:
if param_creator_stack:
param = run_creators(param_creator_stack, context, shape, dtype, init)
else:
param = init(shape, dtype)
params[name] = param # pytype: disable=unsupported-operands

if param is DO_NOT_STORE:
# Initializers or custom creators that return `DO_NOT_STORE` are required
# to produce a value for the parameter via a custom getter.
remove_if_empty(frame.params, bundle_name)
else:
if frame.params_frozen:
# Throw if we needed to re-init the parameter during apply.
raise ValueError(
f"Unable to retrieve parameter {name!r} for module "
f"{bundle_name!r} All parameters must be created as part of "
"`init`.")

param = check_not_none(param, "Parameters cannot be `None`.")
frame.params[bundle_name][name] = param

# Custom getters allow a hook for users to customize the value returned by
# get_parameter. For example casting values to some dtype.
if param_getter_stack:
param = run_getters(param_getter_stack, context, param)

param = check_not_none(param, "Parameters cannot be `None`.")

if param.shape != tuple(shape):
raise ValueError(
"{!r} with retrieved shape {!r} does not match shape={!r} dtype={!r}"
.format(fq_name, param.shape, shape, dtype))
f"{fq_name!r} with retrieved shape {param.shape!r} does not match "
f"shape={shape!r} dtype={dtype!r}")

return param


def remove_if_empty(bundle, key):
if key in bundle and not bundle[key]:
del bundle[key]


class GetterContext(NamedTuple):
"""Context about where parameters are being created.
Expand All @@ -444,6 +547,7 @@ class GetterContext(NamedTuple):
module: Optional[Module]
original_dtype: Any
original_shape: Sequence[int]
original_init: Optional[Initializer]

@property
def module_name(self):
Expand Down Expand Up @@ -968,7 +1072,7 @@ def get_state(
bundle_name = current_bundle_name()
state = current_frame().state[bundle_name]
fq_name = f"{bundle_name}/{name}"
context = GetterContext(fq_name, current_module(), dtype, shape)
context = GetterContext(fq_name, current_module(), dtype, shape, init)

value = state.get(name, None)
if value is None:
Expand All @@ -984,7 +1088,8 @@ def get_state(
else:
value = init(shape, dtype)

state[name] = StatePair(value, value)
if value is not DO_NOT_STORE:
state[name] = StatePair(value, value)
else:
value = value.current

Expand Down Expand Up @@ -1027,16 +1132,22 @@ def set_state(name: str, value):
"""
assert_context("set_state")
assert_jax_usage("set_state")
state = current_frame().state[current_bundle_name()]
frame = current_frame()
bundle_name = current_bundle_name()
state = frame.state[bundle_name]

if state_setter_stack:
shape = jax.tree_map(maybe_shape, value)
dtype = jax.tree_map(maybe_dtype, value)
fq_name = current_bundle_name() + "/" + name
fq_name = bundle_name + "/" + name
context = SetterContext(full_name=fq_name, module=current_module(),
original_dtype=dtype, original_shape=shape)
value = run_setters(state_setter_stack, context, value)

if value is DO_NOT_STORE:
remove_if_empty(frame.state, bundle_name)
return

if name in state:
initial, _ = state[name]
current = value
Expand Down
31 changes: 31 additions & 0 deletions haiku/_src/base_test.py
Expand Up @@ -611,5 +611,36 @@ def test_unsafe_use_of_jax(self, haiku_side_effect_fn, jax_fn):
with self.assertRaises(base.JaxUsageError):
f(x)

def test_do_not_store(self):
def my_creator(next_creator, shape, dtype, init, context):
del next_creator, shape, dtype, init, context
return base.DO_NOT_STORE

def my_getter(next_getter, value, context):
assert value is base.DO_NOT_STORE
return next_getter(
context.original_init(context.original_shape, context.original_dtype))

def my_setter(next_setter, value, context):
del next_setter, value, context
return base.DO_NOT_STORE

with base.new_context() as ctx:
with base.custom_creator(my_creator, state=True), \
base.custom_getter(my_getter, state=True), \
base.custom_setter(my_setter):
self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
base.set_state("s2", jnp.ones([]))

self.assertEmpty(ctx.collect_params())
self.assertEmpty(ctx.collect_state())

def test_do_not_store_array_like(self):
with self.assertRaises(ValueError):
base.DO_NOT_STORE.shape # pylint: disable=pointless-statement # pytype: disable=attribute-error
with self.assertRaises(ValueError):
base.DO_NOT_STORE.dtype # pylint: disable=pointless-statement # pytype: disable=attribute-error

if __name__ == "__main__":
absltest.main()
9 changes: 9 additions & 0 deletions haiku/_src/integration/doctest_test.py
Expand Up @@ -90,7 +90,16 @@ def run_test():
test_name = name + "_" + attr_name
test_names.append(test_name)
module.__test__[test_name] = attr_value
elif (isinstance(value, str) or inspect.isfunction(value) or
inspect.ismethod(value) or inspect.isclass(value)):
test_names.append(name)
module.__test__[name] = value
elif hasattr(value, "__doc__"):
test_names.append(name)
module.__test__[name] = value.__doc__
else:
# This will probably fail, DocTestFinder.find: __test__ values must be
# strings, functions, methods, classes, or modules
test_names.append(name)
module.__test__[name] = value

Expand Down
29 changes: 29 additions & 0 deletions haiku/_src/transform_test.py
Expand Up @@ -528,6 +528,35 @@ def __init__(self):
self.assertNotIsInstance(m, Mapping)
self.assertIs(transform.check_mapping("params", m), m)

def test_do_not_store(self):
def my_creator(next_creator, shape, dtype, init, context):
del next_creator, shape, dtype, init, context
return base.DO_NOT_STORE

def my_getter(next_getter, value, context):
assert value is base.DO_NOT_STORE
return next_getter(
context.original_init(context.original_shape, context.original_dtype))

def my_setter(next_setter, value, context):
del next_setter, value, context
return base.DO_NOT_STORE

def f():
with base.custom_creator(my_creator, state=True), \
base.custom_getter(my_getter, state=True), \
base.custom_setter(my_setter):
self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
base.set_state("s2", jnp.ones([]))

f = transform.transform_with_state(f)
params, state = f.init(None)
self.assertEmpty(params)
self.assertEmpty(state)
_, state = f.apply({}, {}, None)
self.assertEmpty(state)


class ObjectWithTransform:

Expand Down
3 changes: 3 additions & 0 deletions haiku/experimental/__init__.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-importing-member
"""Experimental features developed by the Haiku core team.
Features may be removed or modified at any time.
Expand All @@ -20,6 +21,7 @@
from haiku._src.base import current_name
from haiku._src.base import custom_creator
from haiku._src.base import custom_getter
from haiku._src.base import DO_NOT_STORE
from haiku._src.base import GetterContext
from haiku._src.config import check_jax_usage
from haiku._src.config import module_auto_repr
Expand Down Expand Up @@ -58,6 +60,7 @@
"current_name",
"custom_creator",
"custom_getter",
"DO_NOT_STORE",
"intercept_methods",
"jaxpr_info",
"layer_stack",
Expand Down

0 comments on commit 2a6c034

Please sign in to comment.