Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 82 additions & 45 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def is_array_ref(x) -> tp.TypeGuard[Ref]:
@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
    """A raw value bundled with hook tuples and free-form metadata.

    The hook fields use the ``on_*`` names; each defaults to an empty tuple,
    meaning no hooks of that kind are attached. The positional order of the
    fields is ``raw_value``, the five hook tuples, then ``metadata``.
    """

    # The wrapped value itself.
    raw_value: A
    # Hooks associated with setting the value.
    on_set_value: tuple[SetValueHook[A], ...] = ()
    # Hooks associated with reading the value.
    on_get_value: tuple[GetValueHook[A], ...] = ()
    # Hooks associated with creating the value.
    on_create_value: tuple[CreateValueHook[A], ...] = ()
    # Hooks associated with adding an axis.
    on_add_axis: tuple[AddAxisHook[Variable[A]], ...] = ()
    # Hooks associated with removing an axis.
    on_remove_axis: tuple[RemoveAxisHook[Variable[A]], ...] = ()
    # Arbitrary extra key/value metadata; a fresh dict per instance.
    metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict)


Expand Down Expand Up @@ -1302,6 +1302,18 @@ def __init__(
)
eager_sharding = aux_metadata['eager_sharding']
metadata.update(aux_metadata)

if 'on_get_value' not in metadata and value.on_get_value:
metadata['on_get_value'] = value.on_get_value
if 'on_set_value' not in metadata and value.on_set_value:
metadata['on_set_value'] = value.on_set_value
if 'on_create_value' not in metadata and value.on_create_value:
metadata['on_create_value'] = value.on_create_value
if 'on_add_axis' not in metadata and value.on_add_axis:
metadata['on_add_axis'] = value.on_add_axis
if 'on_remove_axis' not in metadata and value.on_remove_axis:
metadata['on_remove_axis'] = value.on_remove_axis

value = tp.cast(A, value.raw_value)

if hijax is None:
Expand Down Expand Up @@ -1342,7 +1354,12 @@ def __init__(

# run create_value hooks
if 'on_create_value' in metadata:
value = metadata['on_create_value'](self, value)
hooks = metadata['on_create_value']
if isinstance(hooks, tuple):
for hook in hooks:
value = hook(self, value)
else:
value = hooks(self, value)

object.__setattr__(self, '_raw_value', value)
# run create_value hook
Expand Down Expand Up @@ -1656,13 +1673,23 @@ def get_value(self, *, index: tp.Any = MISSING) -> A:
elif is_array_ref(value):
value = value[...]
if 'on_get_value' in self._var_metadata:
value = self._var_metadata['on_get_value'](self, value)
hooks = self._var_metadata['on_get_value']
if isinstance(hooks, tuple):
for hook in hooks:
value = hook(self, value)
else:
value = hooks(self, value)
return value # type: ignore

def set_value(self, value: A, *, index: tp.Any = MISSING):
value = jax.tree.map(lambda x: x, value) # make a copy
if 'on_set_value' in self._var_metadata:
value = self._var_metadata['on_set_value'](self, value)
hooks = self._var_metadata['on_set_value']
if isinstance(hooks, tuple):
for hook in hooks:
value = hook(self, value)
else:
value = hooks(self, value)
# update _raw_value
if is_array_ref(self._raw_value):
if isinstance(index, Missing):
Expand Down Expand Up @@ -1694,11 +1721,21 @@ def set_value(self, value: A, *, index: tp.Any = MISSING):

def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Run the ``on_add_axis`` hook(s) stored in this variable's metadata.

    Does nothing when ``self._var_metadata`` has no ``'on_add_axis'`` entry.
    The entry may be a single callable or a tuple of callables; a tuple is
    invoked in order, each hook exactly once. Hook return values are ignored.

    Args:
        axis_index: passed through unchanged to every hook.
        axis_name: passed through unchanged to every hook.
    """
    if 'on_add_axis' not in self._var_metadata:
        return
    hooks = self._var_metadata['on_add_axis']
    # Normalize a lone callable to a 1-tuple so there is a single call path;
    # calling the entry directly would fail when it is a tuple of hooks.
    if not isinstance(hooks, tuple):
        hooks = (hooks,)
    for hook in hooks:
        hook(self, axis_index, axis_name)

def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Run the ``on_remove_axis`` hook(s) stored in this variable's metadata.

    Does nothing when ``self._var_metadata`` has no ``'on_remove_axis'``
    entry. The entry may be a single callable or a tuple of callables; a
    tuple is invoked in order, each hook exactly once. Hook return values are
    ignored.

    Args:
        axis_index: passed through unchanged to every hook.
        axis_name: passed through unchanged to every hook.
    """
    if 'on_remove_axis' not in self._var_metadata:
        return
    hooks = self._var_metadata['on_remove_axis']
    # Normalize a lone callable to a 1-tuple so there is a single call path;
    # calling the entry directly would fail when it is a tuple of hooks.
    if not isinstance(hooks, tuple):
        hooks = (hooks,)
    for hook in hooks:
        hook(self, axis_index, axis_name)

@tp.overload
def copy(self, value: B, **kwargs) -> Variable[B]: ...
Expand Down Expand Up @@ -2206,69 +2243,69 @@ class Perturbation(Intermediate[A]):

def with_metadata(
initializer: F,
set_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (),
get_value_hooks: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (),
create_value_hooks: tp.Union[
on_set_value: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (),
on_get_value: tp.Union[SetValueHook[A], tp.Sequence[SetValueHook[A]]] = (),
on_create_value: tp.Union[
CreateValueHook[A], tp.Sequence[CreateValueHook[A]]
] = (),
add_axis_hooks: tp.Union[
on_add_axis: tp.Union[
AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]]
] = (),
remove_axis_hooks: tp.Union[
on_remove_axis: tp.Union[
RemoveAxisHook[Variable[A]],
tp.Sequence[RemoveAxisHook[Variable[A]]],
] = (),
**metadata: tp.Any,
) -> F:
if set_value_hooks:
if callable(set_value_hooks):
set_value_hooks = (set_value_hooks,)
if on_set_value:
if callable(on_set_value):
on_set_value = (on_set_value,)
else:
set_value_hooks = tuple(set_value_hooks)
on_set_value = tuple(on_set_value)
else:
set_value_hooks = ()
on_set_value = ()

if get_value_hooks:
if callable(get_value_hooks):
get_value_hooks = (get_value_hooks,)
if on_get_value:
if callable(on_get_value):
on_get_value = (on_get_value,)
else:
get_value_hooks = tuple(get_value_hooks)
on_get_value = tuple(on_get_value)
else:
get_value_hooks = ()
on_get_value = ()

if create_value_hooks:
if callable(create_value_hooks):
create_value_hooks = (create_value_hooks,)
if on_create_value:
if callable(on_create_value):
on_create_value = (on_create_value,)
else:
create_value_hooks = tuple(create_value_hooks)
on_create_value = tuple(on_create_value)
else:
create_value_hooks = ()
on_create_value = ()

if add_axis_hooks:
if callable(add_axis_hooks):
add_axis_hooks = (add_axis_hooks,)
if on_add_axis:
if callable(on_add_axis):
on_add_axis = (on_add_axis,)
else:
add_axis_hooks = tuple(add_axis_hooks)
on_add_axis = tuple(on_add_axis)
else:
add_axis_hooks = ()
on_add_axis = ()

if remove_axis_hooks:
if callable(remove_axis_hooks):
remove_axis_hooks = (remove_axis_hooks,)
if on_remove_axis:
if callable(on_remove_axis):
on_remove_axis = (on_remove_axis,)
else:
remove_axis_hooks = tuple(remove_axis_hooks)
on_remove_axis = tuple(on_remove_axis)
else:
remove_axis_hooks = ()
on_remove_axis = ()

@functools.wraps(initializer)
def wrapper(*args):
return VariableMetadata(
initializer(*args),
set_value_hooks=set_value_hooks,
get_value_hooks=get_value_hooks,
create_value_hooks=create_value_hooks,
add_axis_hooks=add_axis_hooks,
remove_axis_hooks=remove_axis_hooks,
on_set_value=on_set_value,
on_get_value=on_get_value,
on_create_value=on_create_value,
on_add_axis=on_add_axis,
on_remove_axis=on_remove_axis,
metadata=metadata,
)

Expand Down
Loading