diff --git a/haiku/_src/base.py b/haiku/_src/base.py index 70163117d..c94f02e6a 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -461,6 +461,11 @@ def reset(): return wrapped +def throw_if_run(shape, dtype): + del shape, dtype + raise ValueError("Initializer must be specified.") + + @replaceable def get_parameter( name: str, @@ -493,7 +498,8 @@ def get_parameter( assert_context("get_parameter") assert_jax_usage("get_parameter") - init = check_not_none(init, "Initializer must be specified.") + if init is None: + init = throw_if_run bundle_name = current_name() frame = current_frame() diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index dd7090c8b..016a1537b 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -134,6 +134,21 @@ def test_get_parameter_wrong_shape(self): base.get_parameter("w", (1,), init=jnp.zeros) base.get_parameter("w", (2,), init=jnp.zeros) + def test_get_parameter_no_init(self): + with base.new_context(): + with self.assertRaisesRegex(ValueError, "Initializer must be specified."): + base.get_parameter("w", []) + + def test_get_parameter_no_init_during_init_second_call(self): + with base.new_context(): + w = base.get_parameter("w", [], init=jnp.zeros) + self.assertIs(base.get_parameter("w", []), w) + + def test_get_parameter_no_init_during_apply(self): + w = jnp.zeros([]) + with base.new_context(params={"~": {"w": w}}): + self.assertIs(base.get_parameter("w", []), w) + @parameterized.parameters(base.next_rng_key, lambda: base.next_rng_keys(1)) def test_rng_no_transform(self, f): with self.assertRaisesRegex(ValueError,