Skip to content

Commit

Permalink
Add support for init=None in apply.
Browse files Browse the repository at this point in the history
Fixes #588.

PiperOrigin-RevId: 500935075
  • Loading branch information
tomhennigan authored and Copybara-Service committed Jan 10, 2023
1 parent 057878b commit da41cf3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
8 changes: 7 additions & 1 deletion haiku/_src/base.py
Expand Up @@ -461,6 +461,11 @@ def reset():
return wrapped


def throw_if_run(shape, dtype):
del shape, dtype
raise ValueError("Initializer must be specified.")


@replaceable
def get_parameter(
name: str,
Expand Down Expand Up @@ -493,7 +498,8 @@ def get_parameter(
assert_context("get_parameter")
assert_jax_usage("get_parameter")

init = check_not_none(init, "Initializer must be specified.")
if init is None:
init = throw_if_run

bundle_name = current_name()
frame = current_frame()
Expand Down
15 changes: 15 additions & 0 deletions haiku/_src/base_test.py
Expand Up @@ -134,6 +134,21 @@ def test_get_parameter_wrong_shape(self):
base.get_parameter("w", (1,), init=jnp.zeros)
base.get_parameter("w", (2,), init=jnp.zeros)

def test_get_parameter_no_init(self):
with base.new_context():
with self.assertRaisesRegex(ValueError, "Initializer must be specified."):
base.get_parameter("w", [])

def test_get_parameter_no_init_during_init_second_call(self):
with base.new_context():
w = base.get_parameter("w", [], init=jnp.zeros)
self.assertIs(base.get_parameter("w", []), w)

def test_get_parameter_no_init_during_apply(self):
w = jnp.zeros([])
with base.new_context(params={"~": {"w": w}}):
self.assertIs(base.get_parameter("w", []), w)

@parameterized.parameters(base.next_rng_key, lambda: base.next_rng_keys(1))
def test_rng_no_transform(self, f):
with self.assertRaisesRegex(ValueError,
Expand Down

0 comments on commit da41cf3

Please sign in to comment.