Skip to content

Commit

Permalink
Create (internal only for now) context manager API for Haiku modules.
Browse files Browse the repository at this point in the history
This makes transform trivial (the actual impl is only longer because of error
checking):

    def transform_with_state(f):
      def init_fn(rng, x):
        with new_context(rng=rng) as ctx:
          f(x)
        return ctx.collect_params(), ctx.collect_state()

      def apply_fn(params, state, rng, x):
        with new_context(params=params, state=state, rng=rng) as ctx:
          return f(x), ctx.collect_state()

      return init_fn, apply_fn

But it also means we could in theory offer a fully imperative API:

    with new_context(rng=rng) as ctx:
      mod = hk.nets.MLP([300, 100, 10])
      mod(example)
      params = ctx.collect_params()

    .. at some point later ..

    with new_context(params=params):
      out = mod(x)

My motivation for exploring this API is that users very commonly want to be able
to use their Haiku modules without having to wrap them in `hk.transform`, and
some advanced users would like to be able to produce functions that look like
`hk.transform` but with a different contract (e.g. producing multiple apply
methods).

My gut feeling is that this API, while strictly more flexible, is also quite
dangerous in JAX (e.g. JAX makes many assumptions about functional purity which
the results of `hk.transform` satisfy but this API does not). I would like
to make this available for folks to play with, but not expose it or promote it
widely for now.

Ping #16.

PiperOrigin-RevId: 301013361
Change-Id: Ia5766a470ef08109c90be6d07f1226b742880c2b
  • Loading branch information
tomhennigan authored and Copybara-Service committed Mar 15, 2020
1 parent 3a34141 commit 7560469
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 12 deletions.
107 changes: 95 additions & 12 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,97 @@ def module(self, module_state: ModuleState):
current_frame = frame_stack.peek


class HaikuContext(object):
  """Context manager that injects values into, and collects them from, Haiku.

  Entering the context builds a `Frame` from the params, state and rng held by
  this object and pushes it onto the global `frame_stack`; leaving the context
  pops that frame again. Every frame created by one instance is handed the
  same name set and the same counter, so they persist across repeated entries.
  """

  __slots__ = (
      "__params",
      "__state",
      "__rng",
      "__expected_stack",
      "__names",
      "__counter",
  )

  def __init__(
      self,
      params: Union[Params, MutableParams],
      state: Union[State, MutableState],
      rng: Optional["PRNGSequence"],
  ):
    """Stores the values to inject and creates empty bookkeeping containers.

    Args:
      params: Parameters handed to every frame created by this context.
      state: State handed to every frame created by this context.
      rng: Optional rng sequence handed to every frame created by this context.
    """
    # NOTE: Double underscores (not single) are deliberate: these attributes
    # are "really" private, touching them directly could break behaviour.
    self.__params = params
    self.__state = state
    self.__rng = rng
    # Bookkeeping that lives for as long as this context object does.
    self.__names = set()
    self.__counter = collections.Counter()
    self.__expected_stack = ThreadLocalStack()

  def collect_params(self) -> Params:
    """Returns the held params converted with `to_immutable_dict`."""
    return data_structures.to_immutable_dict(self.__params)

  def collect_initial_state(self) -> State:
    """Returns `_extract_state` of the held state with `initial=True`."""
    return _extract_state(self.__state, initial=True)

  def collect_state(self) -> State:
    """Returns `_extract_state` of the held state with `initial=False`."""
    return _extract_state(self.__state, initial=False)

  def __enter__(self):
    """Creates a frame, pushes it onto `frame_stack` and returns `self`."""
    entered = Frame.create(
        params=self.__params,
        state=self.__state,
        rng=self.__rng,
    )
    # Share this context's name set and counter with the new frame.
    entered.used_names_stack.push(self.__names)
    entered.counter_stack.push(self.__counter)
    # Remember which frame we pushed so `__exit__` can verify it pops that one.
    self.__expected_stack.push(entered)
    frame_stack.push(entered)
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Pops the frame pushed by `__enter__`; exceptions are not suppressed."""
    # Both pops are plain statements so they still run when asserts are
    # disabled; only the identity check itself is an assertion.
    popped = frame_stack.pop()
    remembered = self.__expected_stack.pop()
    assert popped is remembered


def new_context(
    *,
    params: Optional[Params] = None,
    state: Optional[State] = None,
    rng: Optional[Union[PRNGKey, PRNGSeed]] = None,
) -> HaikuContext:
  """Builds a context collecting hk.{get,set}_{parameter,state} results.

  >>> with new_context(rng=jax.random.PRNGKey(42)) as ctx:
  ...   mod = hk.nets.MLP([300, 100, 10])
  ...   y1 = mod(jnp.ones([1, 1]))
  >>> assert len(jax.tree_leaves(ctx.collect_params())) == 6
  >>> with ctx:
  ...   y2 = mod(jnp.ones([1, 1]))

  The same module instance in the same context will produce the same value:

  >>> assert (y1 == y2).all()

  Args:
    params: Optional parameter values to inject. When omitted, an empty
      `defaultdict(dict)` is used instead.
    state: Optional state values to inject. When omitted, an empty
      `defaultdict(dict)` is used instead.
    rng: Optional rng to inject. Anything that is not already a
      `PRNGSequence` is wrapped in one.

  Returns:
    Context manager which closes over mutable Haiku internal state.
  """
  if params is not None:
    params = data_structures.to_immutable_dict(params)
  else:
    params = collections.defaultdict(dict)

  if state is not None:
    # Each value is stored twice in its `StatePair`.
    # NOTE(review): presumably the two slots are (initial, current) - confirm
    # against `StatePair` / `_extract_state`.
    paired = {}
    for module_name, bundle in state.items():
      paired[module_name] = {
          name: StatePair(value, value) for name, value in bundle.items()}
    state = paired
  else:
    state = collections.defaultdict(dict)

  already_sequence = isinstance(rng, PRNGSequence)
  if not (rng is None or already_sequence):
    rng = PRNGSequence(rng)

  return HaikuContext(params, state, rng)


def inside_transform():
  """Reports whether the global `frame_stack` is currently truthy.

  NOTE(review): presumably a truthy stack means at least one Haiku frame is
  active (e.g. inside a transformed function or an entered context) - confirm
  against the stack implementation.
  """
  return True if frame_stack else False

Expand Down Expand Up @@ -318,8 +409,6 @@ def init_fn(
**kwargs,
):
"""Initializes your function collecting parameters and state."""
params = collections.defaultdict(dict)
state = collections.defaultdict(dict)
if rng is not None:
try:
rng = PRNGSequence(rng)
Expand All @@ -328,12 +417,10 @@ def init_fn(
"Init must be called with an RNG as the first argument, the "
"required signature is: `init(rng, *a, **k)`") from e

with frame_stack(Frame.create(params=params, state=state, rng=rng)):
with new_context(rng=rng) as ctx:
f(*args, **kwargs)

params = data_structures.to_immutable_dict(params)
state = _extract_state(state, initial=True)
return params, state
return ctx.collect_params(), ctx.collect_initial_state()

# EXPERIMENTAL: Expose the original function as a private attribute.
init_fn._original_fn = f # pylint: disable=protected-access
Expand All @@ -360,9 +447,6 @@ def apply_fn(
):
"""Applies your function injecting parameters and state."""
# TODO(tomhennigan) Remove support for `None` params (used in tests).
params = data_structures.to_immutable_dict(params)
state = {m: {k: StatePair(v, v) for k, v in p.items()}
for m, p in state.items()}
if rng is not None:
try:
rng = PRNGSequence(rng)
Expand All @@ -375,11 +459,10 @@ def apply_fn(
f"Apply must be called with an RNG as the {position} argument, "
f"the required signature is: `{signature}`") from e

with frame_stack(Frame.create(params=params, state=state, rng=rng)):
with new_context(params=params, state=state, rng=rng) as ctx:
out = f(*args, **kwargs)

state = _extract_state(state, initial=False)
return out, state
return out, ctx.collect_state()

# EXPERIMENTAL: Expose the original function as a private attribute.
apply_fn._original_fn = f # pylint: disable=protected-access
Expand Down
18 changes: 18 additions & 0 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,24 @@ def with_decorator():
self.assertNotEqual(without_decorator_out, expected_output)
self.assertEqual(with_decorator_out, expected_output)

def test_new_context(self):
  """A context in which nothing happened collects no params and no state."""
  with base.new_context() as ctx:
    pass

  # Call each collector in turn and check that its result is empty.
  collectors = (
      ctx.collect_params,
      ctx.collect_initial_state,
      ctx.collect_state,
  )
  for collect in collectors:
    self.assertEmpty(collect())

def test_context_copies_input(self):
  """The dict passed as params/state is neither aliased nor mutated."""
  supplied = {"~": {"w": jnp.array(1.)}}
  with base.new_context(params=supplied, state=supplied) as ctx:
    base.get_parameter("w", [], init=jnp.ones)
    base.set_state("w", jnp.array(2.))

  # Params come back with the value that was passed in.
  collected_params = ctx.collect_params()
  self.assertEqual(collected_params, {"~": {"w": jnp.array(1.)}})

  # The initial state is equal to, but not the same object as, the input.
  self.assertIsNot(ctx.collect_initial_state(), supplied)
  self.assertEqual(ctx.collect_initial_state(), supplied)

  # The current state reflects the `set_state` call made inside the context.
  collected_state = ctx.collect_state()
  self.assertEqual(collected_state, {"~": {"w": jnp.array(2.)}})

  # The caller's dict still holds its original value.
  self.assertEqual(supplied, {"~": {"w": jnp.array(1.)}})


class ObjectWithTransform(object):

Expand Down
30 changes: 30 additions & 0 deletions haiku/_src/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,36 @@ def method_hook(mod, method_name):
[("scalar_module", "__call__", ("scalar_module/w",)),
("captures_module", "__call__", ("scalar_module/w",))])

def test_context_reuse_same_instance(self):
  """Modules give the same child values when their context is re-entered."""
  params = {
      "parent_module/~/child_module": {"w": jnp.array(2.)},
      "parent_module/~/child_module_1": {"w": jnp.array(3.)},
      "parent_module_1/~/child_module": {"w": jnp.array(4.)},
      "parent_module_1/~/child_module_1": {"w": jnp.array(5.)},
  }
  # Expected (child1, child2) outputs for the first and second parent module.
  expected_children = ((2., 3.), (4., 5.))

  with base.new_context(params=params) as ctx:
    mod1 = ParentModule()
    mod2 = ParentModule()
    self.assertEqual(mod1.module_name, "parent_module")
    self.assertEqual(mod2.module_name, "parent_module_1")
    for parent, (c1, c2) in zip((mod1, mod2), expected_children):
      self.assertEqual(parent.child1(), c1)
      self.assertEqual(parent.child2(), c2)

  # Re-entering the same context with the same module instances.
  with ctx:
    for parent, (c1, c2) in zip((mod1, mod2), expected_children):
      self.assertEqual(parent.child1(), c1)
      self.assertEqual(parent.child2(), c2)

  # Creating a new context should not be a problem.
  with base.new_context(params=ctx.collect_params()) as ctx:
    mod1 = ParentModule()
    mod2 = ParentModule()
    self.assertEqual(mod1.module_name, "parent_module")
    self.assertEqual(mod2.module_name, "parent_module_1")
    for parent, (c1, c2) in zip((mod1, mod2), expected_children):
      self.assertEqual(parent.child1(), c1)
      self.assertEqual(parent.child2(), c2)


class CapturesModule(module.Module):

Expand Down

0 comments on commit 7560469

Please sign in to comment.