Skip to content

Commit

Permalink
Add experimental custom_getter support to Haiku.
Browse files Browse the repository at this point in the history
We now have two methods that interact with `hk.get_parameter` in Haiku.

`custom_creator`s are run _before_ parameters are created (e.g. as part of init)
and can change the dtype or init function for a given parameter. Creators
influence what ends up in the "params" dictionary returned by `f.init(rng, ..)`.

`custom_getter`s (introduced in this change) allow you to intercept the
parameter when the user calls `get_parameter` _after_ the parameter is created.
The result of `custom_getter` is only passed to the caller and does not change
what ends up in the `params` dict returned by `init`. As a concrete example:

```python
def my_creator(next_creator, shape, dtype, init, context):
  print('running my_creator')
  # Change any of `shape`, `dtype` or `init` here.
  return next_creator(shape, dtype, init)

def my_getter(next_getter, value, context):
  print('running my_getter')
  # Apply any changes to `value` here.
  return next_getter(value)

def f():
  with hk.experimental.custom_creator(my_creator), \
       hk.experimental.custom_getter(my_getter):
    w = hk.get_parameter("w", [], init=jnp.zeros)
    w = hk.get_parameter("w", [], init=jnp.zeros)
    return w

f = hk.transform(f, apply_rng=True)

params = f.init(None)
# running my_creator ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)

f.apply(params, None)
# running my_getter ParamContext(full_name='~/w', module=None)
# running my_getter ParamContext(full_name='~/w', module=None)
```

Ping #32.

PiperOrigin-RevId: 308408822
Change-Id: I526d8299f75810bf2c5985eb56d274ed6e39cac6
  • Loading branch information
tomhennigan authored and Copybara-Service committed Apr 25, 2020
1 parent 49b21f7 commit 67c510c
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 83 deletions.
27 changes: 25 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ Parameters and State

.. autofunction:: set_state

.. autofunction:: custom_creator

.. autofunction:: transparent

Random Number Generators
Expand Down Expand Up @@ -345,6 +343,31 @@ ResNet50
.. autoclass:: ResNet50
:members:

Experimental
------------

.. automodule:: haiku.experimental

custom_creator
~~~~~~~~~~~~~

.. autofunction:: custom_creator

custom_getter
~~~~~~~~~~~~~

.. autofunction:: custom_getter

lift
~~~~

.. autofunction:: lift

to_dot
~~~~~~

.. autofunction:: to_dot

References
----------

Expand Down
147 changes: 121 additions & 26 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@

import collections
import contextlib
import functools
from typing import (Iterator, Iterable, MutableMapping, NamedTuple, Optional,
Set, Tuple, Union)
from typing import (Callable, Iterator, Iterable, MutableMapping, NamedTuple,
Optional, Set, Tuple, Union)

from haiku._src import data_structures
from haiku._src.typing import (Shape, DType, ParamName, Initializer, Params, # pylint: disable=g-multiple-import
State, ParamCreator, PRNGKey, PRNGSeed)
from haiku._src.typing import ( # pylint: disable=g-multiple-import
Shape,
DType,
ParamName,
Initializer,
Params,
State,
Module,
PRNGKey,
PRNGSeed,
)
import jax
import jax.numpy as jnp

Expand All @@ -37,7 +45,8 @@

# TODO(tomhennigan) Should creator_stack be part of frame?
frame_stack = ThreadLocalStack() # type: ThreadLocalStack["Frame"]
creator_stack = ThreadLocalStack() # type: ThreadLocalStack[ParamCreator]
creator_stack = ThreadLocalStack() # type: ThreadLocalStack["ParamCreator"]
getter_stack = ThreadLocalStack() # type: ThreadLocalStack["ParamGetter"]


class Frame(NamedTuple):
Expand Down Expand Up @@ -187,20 +196,26 @@ def inside_transform():
return bool(frame_stack)


def safe_get_module_name(module) -> str:
def safe_get_module_name(module: Module) -> str:
# TODO(tomhennigan) Module specific code should be part of `module.py`.
if not hasattr(module, "module_name"):
raise ValueError("The super constructor must be called before you create "
"parameters or submodules.")
return module.module_name


def current_bundle_name():
def current_module() -> Optional[Module]:
frame = current_frame()
if frame.module_stack:
module = frame.module_stack.peek().module
module_name = safe_get_module_name(module)
return module_name
return frame.module_stack.peek().module
else:
return None


def current_bundle_name():
module = current_module()
if module is not None:
return safe_get_module_name(module)
else:
# Any parameters defined outside an `hk.Module` are put in the same group.
return "~"
Expand Down Expand Up @@ -272,6 +287,10 @@ def get_parameter(
param = create_parameter(fq_name, shape, dtype, init)
params[name] = param # pytype: disable=unsupported-operands

# Custom getters allow a hook for users to customize the value returned by
# get_parameter. For example casting values to some dtype.
param = run_custom_getters(fq_name, param)

assert param.shape == tuple(shape), (
"{!r} with shape {!r} does not match shape={!r} dtype={!r}".format(
fq_name, param.shape, shape, dtype))
Expand All @@ -280,23 +299,24 @@ def get_parameter(


def create_parameter(
original_name: ParamName,
full_name: ParamName,
shape: Shape,
dtype: DType = jnp.float32,
init: Initializer = None,
) -> jnp.ndarray:
"""Creates a parameter by running user defined creators then init.
>>> def fp16_creator(next_creator, name, shape, dtype):
... return next_creator(name, shape, jnp.float16)
>>> def zeros_creator(next_creator, shape, dtype, init, context):
... init = jnp.zeros
... return next_creator(shape, dtype, init)
>>> with hk.experimental.custom_creator(fp16_creator):
>>> with hk.experimental.custom_creator(zeros_creator):
... w = hk.get_parameter("w", [], jnp.float32, init=jnp.ones)
>>> w.dtype
dtype('float16')
Args:
original_name: Name of the parameter, including parent module name.
full_name: Name of the parameter, including parent module name.
shape: The shape of the parameter.
dtype: The dtype of the parameter.
init: A callable of shape, dtype to generate an initial value for the
Expand All @@ -308,20 +328,27 @@ def create_parameter(
if not creator_stack:
return init(shape, dtype)

def next_creator(name, shape, dtype, init):
if name != original_name:
raise ValueError(
"Modifying variable `name` in a custom creator is not supported.")
context = ParamContext(full_name=full_name, module=current_module())
creator_stack_copy = creator_stack.clone()

def next_creator(shape, dtype, init):
if creator_stack_copy:
return creator_stack_copy.popleft()(name, shape, dtype, init)
return creator_stack_copy.popleft()(next_creator, shape, dtype, init,
context)
else:
return init(shape, dtype)

creator_stack_copy = creator_stack.map(
lambda c: functools.partial(c, next_creator))
return next_creator(shape, dtype, init)


return creator_stack_copy.popleft()(original_name, shape, dtype, init)
class ParamContext(NamedTuple):
"""Read only state showing where parameters are being created."""
full_name: str
module: Optional[Module]

NextCreator = Callable[[Shape, DType, Initializer], jnp.ndarray]
ParamCreator = Callable[[NextCreator, Shape, DType, Initializer, ParamContext],
jnp.ndarray]


def custom_creator(creator: ParamCreator):
Expand All @@ -330,8 +357,9 @@ def custom_creator(creator: ParamCreator):
When new parameters are created via :func:`get_parameter` we first run custom
creators passing user defined values through. For example:
>>> def zeros_creator(next_creator, name, shape, dtype, init):
... return next_creator(name, shape, dtype, init=jnp.zeros)
>>> def zeros_creator(next_creator, shape, dtype, init, context):
... init = jnp.zeros
... return next_creator(shape, dtype, init)
>>> with hk.experimental.custom_creator(zeros_creator):
... w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
Expand All @@ -344,9 +372,76 @@ def custom_creator(creator: ParamCreator):
Returns:
Context manager under which the creator is active.
"""
assert_context("experimental.custom_creator")
return creator_stack(creator)


def run_custom_getters(
full_name: ParamName,
value: jnp.ndarray,
) -> jnp.ndarray:
"""Creates a parameter by running user defined creators then init.
>>> def bfloat16_scope(next_getter, value, context):
... if value.dtype == jnp.float32:
... value = value.astype(jnp.bfloat16)
... return next_getter(value)
>>> with hk.experimental.custom_getter(bfloat16_scope):
... w = hk.get_parameter("w", [], jnp.float32, init=jnp.ones)
>>> w.dtype
dtype('bfloat16')
Args:
full_name: Name of the parameter, including parent module name.
value: The current value of the parameter.
Returns:
A jnp.ndarray with the parameter of the given shape/dtype.
"""
if not getter_stack:
return value

context = ParamContext(full_name=full_name, module=current_module())
getter_stack_copy = getter_stack.clone()

def next_creator(value):
if getter_stack_copy:
return getter_stack_copy.popleft()(next_creator, value, context)
else:
return value

return next_creator(value)

NextGetter = Callable[[ParamName, jnp.ndarray], jnp.ndarray]
ParamGetter = Callable[[NextGetter, jnp.ndarray, ParamContext], jnp.ndarray]


def custom_getter(getter: ParamGetter):
"""Registers a custom parameter getter.
When parameters are retrieved using :func:`get_parameter` we always run all
custom getters before returning a value to the user.
>>> def bf16_getter(next_getter, value, context):
... value = value.astype(jnp.bfloat16)
... return next_getter(value)
>>> with hk.experimental.custom_getter(bf16_getter):
... w = hk.get_parameter("w", [], jnp.float32, jnp.ones)
>>> w.dtype
dtype(bfloat16)
Args:
getter: A parameter getter.
Returns:
Context manager under which the getter is active.
"""
assert_context("experimental.custom_getter")
return getter_stack(getter)


def assert_is_prng_key(key: PRNGKey):
"""Asserts that the given input looks like a `jax.random.PRNGKey`."""
if not hasattr(key, "shape") or not hasattr(key, "dtype"):
Expand Down
79 changes: 65 additions & 14 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,35 +129,27 @@ def maybe_three():
self.assertTrue(jnp.all(jnp.array(rngs) == jnp.array(maybes)))

def test_init_custom_creator(self):
def zeros_creator(next_creator, name, shape, dtype, init):
self.assertEqual(name, "~/w")
def zeros_creator(next_creator, shape, dtype, init, context):
self.assertEqual(context.full_name, "~/w")
self.assertEqual(shape, [])
self.assertEqual(dtype, jnp.float32)
self.assertEqual(init, jnp.ones)
return next_creator(name, shape, dtype, jnp.zeros)
return next_creator(shape, dtype, jnp.zeros)

with base.new_context() as ctx:
with base.custom_creator(zeros_creator):
base.get_parameter("w", [], init=jnp.ones)

self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.zeros([])}})

def test_unable_to_mutate_name(self):
def mutates_name(next_creator, name, shape, dtype, init):
next_creator(name + "_foo", shape, dtype, init)

with base.new_context(), base.custom_creator(mutates_name):
with self.assertRaisesRegex(ValueError,
"Modifying .*name.* not supported"):
base.get_parameter("w", [], init=jnp.ones)

def test_nested_creators(self):
log = []

def logging_creator(log_msg):
def _logging_creator(next_creator, name, shape, dtype, init):
def _logging_creator(next_creator, shape, dtype, init, context):
del context
log.append(log_msg)
return next_creator(name, shape, dtype, init)
return next_creator(shape, dtype, init)
return _logging_creator

with base.new_context():
Expand All @@ -168,6 +160,65 @@ def _logging_creator(next_creator, name, shape, dtype, init):

self.assertEqual(log, ["a", "b", "c"])

def test_custom_getter_bf16(self):
def bf16_getter(next_getter, value, context):
del context
if value.dtype == jnp.float32:
value = value.astype(jnp.bfloat16)
return next_getter(value)

with base.new_context() as ctx:
with base.custom_getter(bf16_getter):
f = base.get_parameter("f", [], jnp.float32, init=jnp.ones)
i = base.get_parameter("i", [], jnp.int32, init=jnp.ones)

params = ctx.collect_params()
self.assertEqual(params["~"]["f"].dtype, jnp.float32)
self.assertEqual(f.dtype, jnp.bfloat16)
self.assertEqual(params["~"]["i"].dtype, jnp.int32)
self.assertEqual(i.dtype, jnp.int32)

def test_nested_getters(self):
log = []

def logging_getter(log_msg, dtype_in, dtype_out):
def _logging_getter(next_getter, value, context):
del context
log.append(log_msg)
self.assertEqual(value.dtype, dtype_in)
value = value.astype(dtype_out)
return next_getter(value)
return _logging_getter

with base.new_context():
with base.custom_getter(logging_getter("a", jnp.float32, jnp.bfloat16)), \
base.custom_getter(logging_getter("b", jnp.bfloat16, jnp.int32)), \
base.custom_getter(logging_getter("c", jnp.int32, jnp.int8)):
w = base.get_parameter("w", [], init=jnp.ones)

self.assertEqual(w.dtype, jnp.int8)
self.assertEqual(log, ["a", "b", "c"])

def test_creator_requires_context(self):
def my_creator(next_creator, shape, dtype, init, context):
del context
return next_creator(shape, dtype, init)

with self.assertRaisesRegex(ValueError,
"must be used as part of an `hk.transform`"):
with base.custom_creator(my_creator):
pass

def test_getter_requires_context(self):
def my_getter(next_getter, value, context):
del context
return next_getter(value)

with self.assertRaisesRegex(ValueError,
"must be used as part of an `hk.transform`"):
with base.custom_getter(my_getter):
pass

def test_get_state_no_init_raises(self):
with base.new_context():
with self.assertRaisesRegex(ValueError, "set an init function"):
Expand Down

0 comments on commit 67c510c

Please sign in to comment.