diff --git a/flax/core/scope.py b/flax/core/scope.py index c69011556..f4eb487e2 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -833,7 +833,7 @@ def param(self, name: str, init_fn: Callable[..., T], *init_args, # for inference to a half float type for example. if jnp.shape(val) != jnp.shape(abs_val): raise errors.ScopeParamShapeError(name, self.path_text, - jnp.shape(val), jnp.shape(abs_val)) + jnp.shape(abs_val), jnp.shape(val)) else: if not self.is_mutable_collection('params'): if self.is_collection_empty('params'): diff --git a/flax/errors.py b/flax/errors.py index 79eb88ee0..24aa8516e 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -260,9 +260,9 @@ def __call__(self, x): """ def __init__(self, param_name, scope_path, value_shape, init_shape): - super().__init__('Inconsistent shapes between value and initializer ' - f'for parameter "{param_name}" in "{scope_path}": ' - f'{value_shape}, {init_shape}.') + super().__init__(f'Initializer expected to generate shape {init_shape} ' + f'but got shape {value_shape} instead for parameter ' + f'"{param_name}" in "{scope_path}".') class ScopeVariableNotFoundError(FlaxError): diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index cda54824d..5994a2890 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -107,7 +107,7 @@ def test_inconsistent_param_shapes(self): def f(scope): scope.param('test', nn.initializers.ones_init(), (4,)) - msg = r'Inconsistent shapes between value and initializer for parameter "test" in "/": \(2,\), \(4,\).' + msg = r'Initializer expected to generate shape \(2,\) but got shape \(4,\) instead for parameter "test" in "/"' with self.assertRaisesRegex(errors.ScopeParamShapeError, msg): apply(f)(freeze({'params': {'test': np.ones((2,))}}))