From 5a4859d2c88c5e0119193851d57067514f5587a7 Mon Sep 17 00:00:00 2001 From: Flax Team Date: Tue, 7 Apr 2026 09:45:48 -0700 Subject: [PATCH] Fix the hook names in nnx.with_metadata so that they match the hook names used by the Variable class. PiperOrigin-RevId: 895947493 --- flax/nnx/variablelib.py | 127 ++++++++++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 45 deletions(-) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 1e71dba0a..151798446 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -256,11 +256,11 @@ def is_array_ref(x) -> tp.TypeGuard[Ref]: @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): raw_value: A - set_value_hooks: tuple[SetValueHook[A], ...] = () - get_value_hooks: tuple[GetValueHook[A], ...] = () - create_value_hooks: tuple[CreateValueHook[A], ...] = () - add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] = () - remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] = () + on_set_value: tuple[SetValueHook[A], ...] = () + on_get_value: tuple[GetValueHook[A], ...] = () + on_create_value: tuple[CreateValueHook[A], ...] = () + on_add_axis: tuple[AddAxisHook[Variable[A]], ...] = () + on_remove_axis: tuple[RemoveAxisHook[Variable[A]], ...] 
= () metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) @@ -1302,6 +1302,18 @@ def __init__( ) eager_sharding = aux_metadata['eager_sharding'] metadata.update(aux_metadata) + + if 'on_get_value' not in metadata and value.on_get_value: + metadata['on_get_value'] = value.on_get_value + if 'on_set_value' not in metadata and value.on_set_value: + metadata['on_set_value'] = value.on_set_value + if 'on_create_value' not in metadata and value.on_create_value: + metadata['on_create_value'] = value.on_create_value + if 'on_add_axis' not in metadata and value.on_add_axis: + metadata['on_add_axis'] = value.on_add_axis + if 'on_remove_axis' not in metadata and value.on_remove_axis: + metadata['on_remove_axis'] = value.on_remove_axis + value = tp.cast(A, value.raw_value) if hijax is None: @@ -1342,7 +1354,12 @@ def __init__( # run create_value hooks if 'on_create_value' in metadata: - value = metadata['on_create_value'](self, value) + hooks = metadata['on_create_value'] + if isinstance(hooks, tuple): + for hook in hooks: + value = hook(self, value) + else: + value = hooks(self, value) object.__setattr__(self, '_raw_value', value) # run create_value hook @@ -1656,13 +1673,23 @@ def get_value(self, *, index: tp.Any = MISSING) -> A: elif is_array_ref(value): value = value[...] 
if 'on_get_value' in self._var_metadata: - value = self._var_metadata['on_get_value'](self, value) + hooks = self._var_metadata['on_get_value'] + if isinstance(hooks, tuple): + for hook in hooks: + value = hook(self, value) + else: + value = hooks(self, value) return value # type: ignore def set_value(self, value: A, *, index: tp.Any = MISSING): value = jax.tree.map(lambda x: x, value) # make a copy if 'on_set_value' in self._var_metadata: - value = self._var_metadata['on_set_value'](self, value) + hooks = self._var_metadata['on_set_value'] + if isinstance(hooks, tuple): + for hook in hooks: + value = hook(self, value) + else: + value = hooks(self, value) # update _raw_value if is_array_ref(self._raw_value): if isinstance(index, Missing): @@ -1694,11 +1721,21 @@ def set_value(self, value: A, *, index: tp.Any = MISSING): def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_add_axis' in self._var_metadata: - self._var_metadata['on_add_axis'](self, axis_index, axis_name) + hooks = self._var_metadata['on_add_axis'] + if isinstance(hooks, tuple): + for hook in hooks: + hook(self, axis_index, axis_name) + else: + hooks(self, axis_index, axis_name) def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_remove_axis' in self._var_metadata: - self._var_metadata['on_remove_axis'](self, axis_index, axis_name) + hooks = self._var_metadata['on_remove_axis'] + if isinstance(hooks, tuple): + for hook in hooks: + hook(self, axis_index, axis_name) + else: + hooks(self, axis_index, axis_name) @tp.overload def copy(self, value: B, **kwargs) -> Variable[B]: ... 
@@ -2206,69 +2243,69 @@ class Perturbation(Intermediate[A]): def with_metadata( initializer: F, - set_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), - get_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), - create_value_hooks: tp.Union[ + on_set_value: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), + on_get_value: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (), + on_create_value: tp.Union[ CreateValueHook[A], tp.Sequence[CreateValueHook[A]] ] = (), - add_axis_hooks: tp.Union[ + on_add_axis: tp.Union[ AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]] ] = (), - remove_axis_hooks: tp.Union[ + on_remove_axis: tp.Union[ RemoveAxisHook[Variable[A]], tp.Sequence[RemoveAxisHook[Variable[A]]], ] = (), **metadata: tp.Any, ) -> F: - if set_value_hooks: - if callable(set_value_hooks): - set_value_hooks = (set_value_hooks,) + if on_set_value: + if callable(on_set_value): + on_set_value = (on_set_value,) else: - set_value_hooks = tuple(set_value_hooks) + on_set_value = tuple(on_set_value) else: - set_value_hooks = () + on_set_value = () - if get_value_hooks: - if callable(get_value_hooks): - get_value_hooks = (get_value_hooks,) + if on_get_value: + if callable(on_get_value): + on_get_value = (on_get_value,) else: - get_value_hooks = tuple(get_value_hooks) + on_get_value = tuple(on_get_value) else: - get_value_hooks = () + on_get_value = () - if create_value_hooks: - if callable(create_value_hooks): - create_value_hooks = (create_value_hooks,) + if on_create_value: + if callable(on_create_value): + on_create_value = (on_create_value,) else: - create_value_hooks = tuple(create_value_hooks) + on_create_value = tuple(on_create_value) else: - create_value_hooks = () + on_create_value = () - if add_axis_hooks: - if callable(add_axis_hooks): - add_axis_hooks = (add_axis_hooks,) + if on_add_axis: + if callable(on_add_axis): + on_add_axis = (on_add_axis,) else: - add_axis_hooks = 
tuple(add_axis_hooks) + on_add_axis = tuple(on_add_axis) else: - add_axis_hooks = () + on_add_axis = () - if remove_axis_hooks: - if callable(remove_axis_hooks): - remove_axis_hooks = (remove_axis_hooks,) + if on_remove_axis: + if callable(on_remove_axis): + on_remove_axis = (on_remove_axis,) else: - remove_axis_hooks = tuple(remove_axis_hooks) + on_remove_axis = tuple(on_remove_axis) else: - remove_axis_hooks = () + on_remove_axis = () @functools.wraps(initializer) def wrapper(*args): return VariableMetadata( initializer(*args), - set_value_hooks=set_value_hooks, - get_value_hooks=get_value_hooks, - create_value_hooks=create_value_hooks, - add_axis_hooks=add_axis_hooks, - remove_axis_hooks=remove_axis_hooks, + on_set_value=on_set_value, + on_get_value=on_get_value, + on_create_value=on_create_value, + on_add_axis=on_add_axis, + on_remove_axis=on_remove_axis, metadata=metadata, )