Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 84 additions & 29 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _attr_repr(value: Any):
value_rep = repr(value)
return value_rep


def _module_repr(module: 'Module', num_spaces: int = 4):
"""Returns a pretty printed representation of the module"""
cls = type(module)
Expand All @@ -91,6 +92,7 @@ def _module_repr(module: 'Module', num_spaces: int = 4):
else:
return f'{cls_name}()'


# Track parent relationship across Modules.
# -----------------------------------------------------------------------------
class _DynamicContext:
Expand Down Expand Up @@ -205,23 +207,33 @@ def _get_local_method_names(cls: Any, exclude: Iterable[str] = ()) -> Tuple[str]
return tuple(true_methods.difference(set(exclude)))


def wrap_method(fun: Callable[..., Any]) -> Callable[..., Any]:
def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]:
"""Manages Module state for a given user-defined method.

Args:
fun: User-defined Module method to manage state for.
Returns:
Wrapped method.
"""
# Don't rewrap methods that have already had the state management wrapper
# applied in the decorator stack. This wrapper should always be applied
# before transformation wrappers.
if hasattr(fun, 'method_handler_wrapped'):
return fun

@functools.wraps(fun)
def wrapped_module_method(self, *args, **kwargs):
is_compact_method = hasattr(fun, 'compact')
is_setup_method = fun.__name__ == 'setup'
# We lazily call setup() only when needed.
if not is_setup_method:
if not self._state.setup_called:
self.setup()
self._state.setup_called = True

if is_compact_method:
if self.scope is None:
raise ValueError("Can't call compact methods on unbound modules")

self._state.in_compact_method = True
elif is_setup_method:
self._state.in_setup = True
Expand All @@ -234,7 +246,7 @@ def wrapped_module_method(self, *args, **kwargs):
object.__setattr__(self, 'scope', self.scope.rewound())
if is_compact_method or is_setup_method:
self._state.reset()

wrapped_module_method.method_handler_wrapped = True
return wrapped_module_method


Expand Down Expand Up @@ -271,24 +283,46 @@ class _ModuleInternalState:
"""Ephemeral Module Evaluation State.

For clarity, we collect all of the temporary flags and ephemeral state used by
Modules for autonaming and error messages here.
Modules for autonaming and error messages here, alongside the rules used
to pass this ephemeral state across transform boundaries.
"""
in_compact_method: bool = False
in_setup: bool = False
setup_called: bool = False
last_varname: Optional[str] = None
autoname_cursor: Optional[dict] = dataclasses.field(default_factory=dict)
frozen: bool = False
children: Dict[str, Union[str, 'Module']] = dataclasses.field(default_factory=dict)

def reset(self):
"""Resets transient state."""
self.in_compact_method = False
self.in_setup = False
self.last_varname = None
self.autoname_cursor = dict()

_uninitialized_module_internal_state = _ModuleInternalState(
False, False, None, None)
def export(self):
"""Exports transform-preserved state across transform boundary."""
cloned = _ModuleInternalState(
in_compact_method=self.in_compact_method,
in_setup=self.in_setup,
setup_called=False, # setup_called is object local, not shared.
last_varname=self.last_varname,
autoname_cursor=dict(self.autoname_cursor))
return cloned

def reimport(self, other):
"""Re-imports transform-preserved state from across transform boundary."""
self.in_compact_method = other.in_compact_method
self.in_setup = other.in_setup
self.last_varname = other.last_varname
self.autoname_cursor = dict(other.autoname_cursor)

_uninitialized_module_internal_state = _ModuleInternalState()


_UNDEFINED_COPY_PICKLE_METHODS = (
'__getstate__', '__setstate__', '__getnewargs_ex__',
'__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')

# Base Module definition.
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -396,11 +430,12 @@ def _wrap_module_methods(cls):
['__eq__', '__repr__', '__init__', '__hash__'])
for key in _get_local_method_names(cls, exclude=exclusions):
method = getattr(cls, key)
wrapped_method = wrap_method_once(method)
if _use_named_call and key != 'setup':
# We import named_call at runtime to avoid a circular import issue.
from flax.linen.transforms import named_call # pylint: disable=g-import-not-at-top
method = named_call(method)
setattr(cls, key, wrap_method(method))
wrapped_method = named_call(wrapped_method)
setattr(cls, key, wrapped_method)
return cls

def __setattr__(self, name: str, val: Any):
Expand All @@ -419,14 +454,14 @@ def __setattr__(self, name: str, val: Any):
name: Attribute to set.
val: Value of the attribute.
"""
if name != '_state' and self._state.frozen:
# raises a TypeError just like frozen python dataclasses
if name != '_state' and self._state.setup_called:
# Raises a TypeError just like frozen python dataclasses.
raise TypeError("Module instance is frozen outside of setup method.")

# We don't mess with the parent module.
if name == 'parent':
pass
# Modules have been passed in as dataclass args.
# Modules have been passed in as dataclass args and set in __init__.
elif name in self.__dataclass_fields__ and self.__dataclass_fields__[name].init: # pytype: disable=attribute-error
pass
# Submodules are being defined and attached in setup()
Expand Down Expand Up @@ -461,6 +496,28 @@ def __setattr__(self, name: str, val: Any):
# Finally, always run default __setattr__ to attach to self.__dict__.
object.__setattr__(self, name, val)

def __getattr__(self, name: str) -> Any:
Comment thread
jheek marked this conversation as resolved.
Outdated
"""Call setup() before getting any setup-defined attributes."""
# We don't want to return anything for python copy / pickle methods.
if name in _UNDEFINED_COPY_PICKLE_METHODS:
raise AttributeError()
# _state have class defaults to prevent infinite loop.
if self.parent and not self._state.setup_called and not self._state.in_setup:
self.setup()
self._state.setup_called = True
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'")

def __dir__(self) -> List[str]:
"""Call setup() before listing attributes."""
if self.parent and not self._state.setup_called and not self._state.in_setup:
self.setup()
self._state.setup_called = True
return object.__dir__(self) # pytype: disable=attribute-error

def __post_init__(self):
_check_omnistaging()
# In dataclasses, __init__ is overridden to process dataclass arguments,
Expand Down Expand Up @@ -513,43 +570,41 @@ def __post_init__(self):
else:
raise ValueError("parent must be None, Module or Scope")

# Call the user-defined initialization setup() function.
self.setup()
self._state.frozen = True

def __repr__(self):
return _module_repr(self)

def setup(self):
"""Initializes a Module (similar to ``__init__`` for non-dataclass Python classes).
"""Initializes a Module lazily (similar to a lazy ``__init__``).

``setup`` is called on a module instance at the moment it is safe to define
or access variables or submodules (once the module is "bound").
``setup`` is called once lazily on a module instance when a module
is bound, immediately before any other methods like ``__call__`` are
invoked, or before a ``setup``-defined attribute on `self` is accessed.

| This happens in three cases:
1. Immediately when invoking :meth:`apply`, :meth:`init` or
This can happen in three cases:

1. Immediately when invoking :meth:`apply`, :meth:`init` or
:meth:`init_and_output`.

2. When the module is given a name by being assigned to an attribute of
2. Once the module is given a name by being assigned to an attribute of
another module inside the other module's ``setup`` method
(see :meth:`__setattr__`)::

class MyModule(nn.Module):
def setup(self):
submodule = Conv(...)

# Accessing `submodule.variables` does not yet work here.
# Accessing `submodule` attributes does not yet work here.

# The following line invokes `self.__setattr__`, which gives
# `submodule` the name "conv1", which calls `submodule.setup`.
# `submodule` the name "conv1".
self.conv1 = submodule

# Accessing `submodule.variables` is now safe.

3. Immediately when a module is constructed inside a method wrapped with
:meth:`compact`.
# Accessing `submodule` attributes or methods is now safe and
# either causes setup() to be called once.

3. Once a module is constructed inside a method wrapped with
:meth:`compact`, immediately before another method is called or
``setup`` defined attribute is accessed.
"""
pass

Expand Down
39 changes: 18 additions & 21 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import inspect
from flax.core import lift, Scope
from flax.linen.module import Module
from flax.linen.module import wrap_method
from flax.linen.module import wrap_method_once
import jax

# Utils
Expand Down Expand Up @@ -134,13 +134,13 @@ def wrapped_fn(self, *args, **kwargs):
def core_fn(scopes, *args, **kwargs):
# make a clone of self using its arguments
attrs = {f.name: getattr(self, f.name)
for f in dataclasses.fields(self) if f.name != 'parent'}
for f in dataclasses.fields(self) if f.name != 'parent' and f.init}
# we reference module_class, not self.__class__ to avoid infinite loop
cloned = module_class(parent=None, **attrs)
cloned = set_module_scopes(cloned, scopes)
cloned._state = copy.deepcopy(self._state) # pylint: disable=protected-access
cloned._state = self._state.export() # pylint: disable=protected-access
res = fn(cloned, *args, **kwargs)
self._state = copy.deepcopy(cloned._state) # pylint: disable=protected-access
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
# here we apply the given lifting transform to the scope-ingesting fn
trafo_fn = transform(core_fn, *trafo_args, **trafo_kwargs)
Expand All @@ -149,7 +149,6 @@ def core_fn(scopes, *args, **kwargs):
return wrapped_fn
transformed_fns = {fn_name: create_trans_fn(fn_name, fn_trafo_args)
for fn_name, fn_trafo_args in class_trafo_args.items()}
transformed_fns['setup'] = lambda _: None
# construct new dynamic class w. transformed methods
transformed_cls = type(transform.__name__.capitalize() + module_class.__name__,
(module_class,),
Expand All @@ -160,18 +159,17 @@ def core_fn(scopes, *args, **kwargs):
# Function lifting as decorator on methods __inside__ class definition.
# -----------------------------------------------------------------------------
def decorator_lift_transform(transform, class_fn, *trafo_args, **trafo_kwargs):
# NB: due to the ordering of method decorators, we must re-wrap the class_fn
# to maintain Module state correctly for multiple invocations. If we want to
# save another stacktrace entry we could instead replicate its logic below.
rewrapped_fn = wrap_method(class_fn)
@functools.wraps(class_fn)
# Due to the ordering of method decorators, we must wrap the class_fn
# with the module state management wrapper first to maintain Module state correctly.
prewrapped_fn = wrap_method_once(class_fn)
@functools.wraps(prewrapped_fn)
def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned = set_module_scopes(self, scopes)
cloned._state = copy.deepcopy(self._state) # pylint: disable=protected-access
res = rewrapped_fn(cloned, *args, **kwargs)
self._state = copy.deepcopy(cloned._state) # pylint: disable=protected-access
cloned._state = self._state.export() # pylint: disable=protected-access
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
# here we apply the given lifting transform to the scope-ingesting fn
trafo_fn = transform(core_fn, *trafo_args, **trafo_kwargs)
Expand Down Expand Up @@ -206,11 +204,10 @@ def lift_transform(transform, target, *trafo_args, methods=None, **trafo_kwargs)
# Special case of decorator_lift_transform to handle named calls for profiling.
def named_call(class_fn):
"""Labels a method for labelled traces in profiles."""
# NB: due to the ordering of method decorators, we must re-wrap the class_fn
# to maintain Module state correctly for multiple invocations. If we want to
# save another stacktrace entry we could instead replicate its logic below.
rewrapped_fn = wrap_method(class_fn)
@functools.wraps(class_fn)
# Due to the ordering of method decorators, we must wrap the class_fn
# with the module state management wrapper first to maintain Module state correctly.
prewrapped_fn = wrap_method_once(class_fn)
@functools.wraps(prewrapped_fn)
def wrapped_fn(self, *args, **kwargs):
fn_name = class_fn.__name__
method_suffix = f'.{fn_name}' if fn_name != '__call__' else ''
Expand All @@ -219,9 +216,9 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned = set_module_scopes(self, scopes)
cloned._state = copy.deepcopy(self._state) # pylint: disable=protected-access
res = rewrapped_fn(cloned, *args, **kwargs)
self._state = copy.deepcopy(cloned._state) # pylint: disable=protected-access
cloned._state = self._state.export() # pylint: disable=protected-access
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
# here we apply the given lifting transform to the scope-ingesting fn
trafo_fn = lift.named_call(core_fn, full_name)
Expand Down
59 changes: 57 additions & 2 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,10 @@ def test_module_transform_with_setup(self):
class Foo(nn.Module):
def setup(self):
self.test = self.param('test', nn.initializers.ones, ())

def __call__(self, x):
return x * self.test

FooVmap = nn.vmap(Foo, in_axes=0, out_axes=0, variable_axes={'params': 0}, split_rngs={'params': True})
variables = FooVmap().init(random.PRNGKey(0), jnp.ones((4,)))
self.assertEqual(variables['params']['test'].shape, (4,))
Expand Down Expand Up @@ -555,6 +555,61 @@ def __call__(self, x):
variable_shapes['params']['A_1']['Dense_0']['bias'],
(10, 3))

def test_nested_setup_calls_count(self):
D = 3
N = 4
cntr = 0
class Repeat(nn.Module):
mdl_def: Any
def setup(self):
self.lyrs = [self.mdl_def() for _ in range(N)]
@nn.remat # we just use remat as a convenient test of transform logic
def __call__(self, x):
for lyr in self.lyrs:
lyr(x)
return x
class Counter(nn.Module):
def setup(self):
nonlocal cntr
cntr += 1
self.dense = nn.Dense(2, use_bias=False)
@nn.remat
def __call__(self, x):
return self.dense(x)

def nested_repeat(mdl):
for _ in range(D):
mdl = partial(Repeat, mdl)
return mdl()
_ = nested_repeat(Counter).init(random.PRNGKey(0), jnp.ones((2,)))
self.assertEqual(cntr, 64)

def test_multimethod_setup_calls(self):
cntr=0
class A(nn.Module):
def setup(self):
nonlocal cntr
cntr+=1
self.d = nn.Dense(2)
@nn.remat
def foo(self, x):
return self.d(x)
@nn.remat
def bar(self, x):
return self.d(x)
class B(nn.Module):
def setup(self):
self.a = A()
def __call__(self, x):
y1 = self.a.foo(x)
y2 = self.a.bar(x)
return y1, y2

key = random.PRNGKey(0)
x = jnp.ones((2,))
(y1, y2), _ = B().init_with_output(key, x)
np.testing.assert_array_equal(y1, y2)
self.assertEqual(cntr, 2)

if __name__ == '__main__':
absltest.main()
Loading