diff --git a/flax/core/lift.py b/flax/core/lift.py index 72ffc2c54..e54f1c7d6 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -35,6 +35,9 @@ from .scope import Scope, KindFilter, in_kind_filter, group_kinds from .named_call import named_call_p +from . import unified_transforms +from .unified_transforms import broadcast + scan_variable_modes = set(['carry', 'broadcast', 'scan', None]) ScanVariableMode = Union[str, Tuple[str, str]] @@ -99,8 +102,8 @@ def repack(inner_scope_tree): inner_scope.invalidate() inner_scope._validate_trace_level() mutable_variables = {key: val for key, val - in inner_scope._variables.items() - if not isinstance(val, FrozenDict)} + in inner_scope._variables.items() + if not isinstance(val, FrozenDict)} out_variable_groups = group_kinds( mutable_variables, tuple(out_variable_filters) + (True,)) remainder = tuple(out_variable_groups[-1].keys()) @@ -242,143 +245,71 @@ def mapped(variable_groups_xs, rng_groups_xs, args): inner, variable_in_groups, variable_out_groups, rng_groups) -def scan( - fn: Callable[..., Any], scope: 'Scope', init_carry: Any, xs: Any, - variable_modes: Mapping[KindFilter, ScanVariableMode], - split_rngs: Mapping[KindFilter, bool], - length: Optional[int] = None, reverse: bool = False) -> Callable[..., Any]: - """Wraps jax.lax.scan.""" - # TODO(jheek) scan can be simplified dramatically be making a version of lax.scan - # that follows an in_axis, out_axis api. 
- if length is None: - length, = set(x.shape[0] for x in jax.tree_leaves(xs)) - variable_groups, variable_modes = _unzip2(variable_modes.items()) - - def parse_mode(mode): - if isinstance(mode, str): - mode = (mode, mode) - mode_in, mode_out = mode - if mode_in not in scan_variable_modes or mode_out not in scan_variable_modes: - raise ValueError(f'illegal scan variable mode: {mode}') - return mode - variable_modes = tuple(parse_mode(m) for m in variable_modes) - +def scan(fn: Callable[..., Any], + variable_in_axes: Mapping[KindFilter, Any] = {}, + variable_out_axes: Mapping[KindFilter, Any] = {}, + variable_carry: KindFilter = False, + split_rngs: Mapping[KindFilter, bool] = {}, + in_axes=0, out_axes=0, + length: Optional[int] = None, + reverse: bool = False) -> Callable[..., Any]: + """Wraps jax.lax.scan.""" + variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) + variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) rng_groups, rng_splits = _unzip2(split_rngs.items()) - variable_in_groups = tuple( - False if mode[0] is None else group - for group, mode in zip(variable_groups, variable_modes)) - variable_out_groups = tuple( - False if mode[1] is None else group - for group, mode in zip(variable_groups, variable_modes)) - - def split(variable_groups_xs, i): - scan_vars_xs = [] - carry_vars_xs = [] - broadcast_vars_xs = [] - for variable_groups in variable_groups_xs: - scan_vars = tuple( - group if mode[i] == 'scan' else {} - for group, mode in zip(variable_groups, variable_modes)) - carry_vars = tuple( - group if mode[i] == 'carry' else {} - for group, mode in zip(variable_groups, variable_modes)) - broadcast_vars = tuple( - group if mode[i] == 'broadcast' else {} - for group, mode in zip(variable_groups, variable_modes)) - scan_vars_xs.append(scan_vars) - carry_vars_xs.append(carry_vars) - broadcast_vars_xs.append(broadcast_vars) - return tuple(scan_vars_xs), tuple(carry_vars_xs), tuple(broadcast_vars_xs) - - def 
combine(*variable_groups_xs): - combined_groups_xs = [] - for groups_xs in zip(*variable_groups_xs): - combined_groups = [] - for groups in zip(*groups_xs): - result = {} - for group in groups: - result.update(group) - combined_groups.append(result) - combined_groups_xs.append(tuple(combined_groups)) - return tuple(combined_groups_xs) - - def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs): + rng_axes = tuple(0 if rng_split else broadcast for rng_split in rng_splits) + + def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, init, *args): + def find_length(axis, x): + if axis is not None: + leaves = jax.tree_leaves(x) + if leaves: + return leaves[0].shape[axis] + return () # split rngs - split_fn = lambda rng: random.split(rng, length) - broadcast_rngs_xs = [] - scan_rngs_xs = [] - for rng_groups in rng_groups_xs: - broadcast_rngs_xs.append(tuple( - rng_group for rng_group, split - in zip(rng_groups, rng_splits) if not split)) - scan_rngs_xs.append(tuple( - jax.tree_map(split_fn, rng_group) - for rng_group, split in zip(rng_groups, rng_splits) if split)) - - def body(carry, xs, init_mode=False): - carry_vars_xs, c = carry - scan_vars_xs, scan_rngs_xs, x = xs - variable_groups_xs = combine(scan_vars_xs, carry_vars_xs, broadcast_vars_xs) - rng_groups_xs = [] - for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs, scan_rngs_xs): - rng_groups_xs.append(broadcast_rngs + scan_rngs) + lengths = jax.tree_multimap(find_length, in_axes, args) + if length is None: + d_length, = set(jax.tree_leaves(lengths)) + else: + d_length = length + split_fn = lambda rng: random.split(rng, d_length) + + def split_rngs(rng_groups): + return tuple( + jax.tree_map(split_fn, rng_group) if split else rng_group + for rng_group, split in zip(rng_groups, rng_splits)) + + rng_groups_xs = tuple(map(split_rngs, rng_groups_xs)) + + n = len(variable_groups_xs) + variable_in_axes_xs = (variable_in_axes,) * n + variable_out_axes_xs = (variable_out_axes,) * n + rng_axes_xs = 
(rng_axes,) * n + + @functools.partial(unified_transforms.scan, + in_axes=(variable_in_axes_xs, rng_axes_xs, in_axes), + out_axes=(out_axes, variable_out_axes_xs), + length=length, reverse=reverse) + def scanned(carry, variable_groups_xs, rng_groups_xs, args): + carry_vars, c = carry scope = scope_fn(variable_groups_xs, rng_groups_xs) - carry, y = fn(scope, c, x) - out_vars = repack_fn(scope) - scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(out_vars, 1) - - # TODO(jheek) more informative error check - def check_shapes(c_in, c_out): - if not isinstance(c_in, jnp.ndarray) or not isinstance(c_out, jnp.ndarray): - return - if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(c_in) != jnp.dtype(c_out): - raise ValueError() - try: - jax.tree_multimap(check_shapes, carry_vars_xs, carry_vars_out_xs) - except ValueError: - raise ValueError('carry variables must have the same shape and dtype before and after scan.') - - if init_mode: - return broadcast_vars_out_xs - else: - return (carry_vars_out_xs, carry), (scan_vars_xs, y) - broadcast_body = functools.partial(body, init_mode=True) - - scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(variable_groups_xs, 0) - carry0 = (carry_vars_xs, init_carry) - xxs = (scan_vars_xs, scan_rngs_xs, xs) - - # use partial evaluation to find the variables that are broadcasted out - # an error is thrown if a broadcasted output has a dependency on any scan variables - carry_pvals = jax.tree_map( - lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)), - carry0) - scan_pvals = jax.tree_map( - lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape[1:], x.dtype)), - xxs) - input_pvals = (carry_pvals, scan_pvals) - in_pvals, in_tree = jax.tree_flatten(input_pvals) - f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(broadcast_body), in_tree) - - _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) - # _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True) - - out_flat = [] - for pv, 
const in out_pvals: - if pv is not None: - raise ValueError('broadcasted variable has a data dependency on the scan body.') - out_flat.append(const) - - (carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan( - body, carry0, xxs, length=length, reverse=reverse) - - broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat) - - out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs) - return (carry, ys), out_vars_xs + c, y = fn(scope, c, *args) + out_vars_xs = repack_fn(scope) + out_vars_xs_t = tuple(zip(*out_vars_xs)) + carry_vars = out_vars_xs_t[0] + scan_vars = out_vars_xs_t[1:] + return (carry_vars, c), (y, scan_vars) + + variable_groups_xs_t = tuple(zip(*variable_groups_xs)) + carry_vars = variable_groups_xs_t[0] + scan_vars = variable_groups_xs_t[1:] + (carry_vars, c), (ys, scan_vars) = scanned((carry_vars, init), scan_vars, rng_groups_xs, args) + out_vars_xs_t = (carry_vars,) + scan_vars + out_vars_xs = tuple(zip(*out_vars_xs_t)) + return (c, ys), out_vars_xs return pack( - inner, variable_in_groups, variable_out_groups, rng_groups)(scope) + inner, (variable_carry,) + variable_in_groups, (variable_carry,) + variable_out_groups, rng_groups) def custom_vjp(module_fn: Callable[..., Any], backward_fn: Callable[..., Any], @@ -466,27 +397,34 @@ def jitted(variable_groups_xs, rng_groups_xs, *args): def remat_scan(body_fn: Callable[..., Any], scope: Scope, carry: Any, lengths: Sequence[int], - variable_modes: Mapping[KindFilter, ScanVariableMode], - split_rngs: Mapping[KindFilter, bool]): + variable_carry: KindFilter = False, + variable_in_axes: Mapping[KindFilter, Any] = {}, + variable_out_axes: Mapping[KindFilter, Any] = {}, + split_rngs: Mapping[KindFilter, bool] = {}): # TODO(jheek) should remat scan have scan inputs/outputs? 
if len(lengths) == 1: - def wrapper(scope, carry, _): + def wrapper(scope, carry): return body_fn(scope, carry), () carry, _ = scan( - wrapper, scope, carry, (), + wrapper, length=lengths[0], - variable_modes=variable_modes, - split_rngs=split_rngs) + variable_carry=variable_carry, + variable_in_axes=variable_in_axes, + variable_out_axes=variable_out_axes, + split_rngs=split_rngs)(scope, carry) else: @remat - def inner_loop(scope, carry, _): - carry = remat_scan(body_fn, scope, carry, lengths[1:], variable_modes, split_rngs) + def inner_loop(scope, carry): + carry = remat_scan(body_fn, scope, carry, lengths[1:], + variable_carry, variable_in_axes, variable_out_axes, split_rngs) return carry, () carry, _ = scan( - inner_loop, scope, carry, (), + inner_loop, length=lengths[0], - variable_modes=variable_modes, - split_rngs=split_rngs) + variable_carry=variable_carry, + variable_in_axes=variable_in_axes, + variable_out_axes=variable_out_axes, + split_rngs=split_rngs)(scope, carry) return carry diff --git a/flax/core/tests/scope_test.py b/flax/core/tests/scope_test.py index 6ed742ef3..ff5f89cac 100644 --- a/flax/core/tests/scope_test.py +++ b/flax/core/tests/scope_test.py @@ -32,6 +32,8 @@ def f(scope): init(f)(random.PRNGKey(0)) + + if __name__ == '__main__': absltest.main() diff --git a/flax/core/unified_transforms.py b/flax/core/unified_transforms.py index aa217854b..299c9f3f3 100644 --- a/flax/core/unified_transforms.py +++ b/flax/core/unified_transforms.py @@ -13,34 +13,113 @@ # limitations under the License. 
from dataclasses import dataclass +import functools import jax +import jax.numpy as jnp from jax import lax +from jax.interpreters import partial_eval as pe +from jax import linear_util as lu + from typing import Union, Optional, Callable, Any +import numpy as np + + @dataclass(frozen=True) class Scan: axis: int ScanAxis = Optional[int] +class _Broadcast: + pass + +broadcast = _Broadcast() + def scan( fn: Callable[..., Any], - scan_in_axis: Any, - scan_out_axis: Any): + in_axes: Any, + out_axes: Any, + length: Optional[int] = None, + reverse: bool = False): + + def transpose_to_front(axis, xs): + if axis is broadcast: + return () + def trans(x): + perm = tuple(range(x.ndim)) + perm = (axis,) + tuple(np.delete(perm, axis)) + return jnp.transpose(x, perm) + return jax.tree_map(trans, xs) - def body_fn(c, x): - jax.tree_multimap() - c, y = fn(c, x) - return c, y + def transpose_from_front(axis, xs): + if axis is broadcast: + return () + def trans(x): + if axis < 0: + ax = x.ndim - axis + else: + ax = axis + assert ax < x.ndim + perm = tuple(range(1, ax + 1)) + (0,) + tuple(range(ax + 1, x.ndim)) + return jnp.transpose(x, perm) + return jax.tree_map(trans, xs) + def scan_fn(init, *args): + xs = jax.tree_multimap(transpose_to_front, in_axes, args) + def body_fn(c, xs, init_mode=False): + # inject constants + xs = jax.tree_multimap(lambda ax, arg, x: (arg if ax is broadcast else x), + in_axes, args, xs) + c, ys = fn(c, *xs) + + if init_mode: + ys = jax.tree_multimap(lambda ax, y: (y if ax is broadcast else ()), + out_axes, ys) + return ys + else: + ys = jax.tree_multimap(lambda ax, y: (() if ax is broadcast else y), + out_axes, ys) + return c, ys + return c, ys + broadcast_body = functools.partial(body_fn, init_mode=True) - def scan_fn(init, *args, - length: Optional[int] = None, reverse: bool = False): + carry_pvals = jax.tree_map( + lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), + init) + scan_pvals = jax.tree_map( + lambda x: 
pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))), + xs) + input_pvals = (carry_pvals, scan_pvals) + in_pvals, in_tree = jax.tree_flatten(input_pvals) + f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(broadcast_body), in_tree) + _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals) - return lax.scan(body_fn, init, args, length=length, reverse=reverse) + out_flat = [] + for pv, const in out_pvals: + if pv is not None: + raise ValueError('broadcasted variable has a data dependency on the scan body.') + out_flat.append(const) + constants_out = jax.tree_unflatten(out_tree(), out_flat) + + c, ys = lax.scan(body_fn, init, xs, length=length, reverse=reverse) + ys = jax.tree_multimap(transpose_from_front, out_axes, ys) + ys = jax.tree_multimap(lambda ax, const, y: (const if ax is broadcast else y), + out_axes, constants_out, ys) + return c, ys return scan_fn + + +# def loop(c, x, y): +# print(c, x, y) +# return c + 1, (x * 2, y * 2) + + +# f = scan(loop, in_axes=(broadcast, 1), out_axes=(broadcast, 1)) +# c, (xs, ys) = f(0, 1., jnp.arange(3)[None]) +# print(c, xs, ys) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 273e78932..3fb2ee486 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -28,6 +28,6 @@ from .pooling import avg_pool, max_pool from .recurrent import GRUCell, LSTMCell from .stochastic import Dropout -from .transforms import jit, named_call, remat, scan, vmap +from .transforms import jit, named_call, remat, scan, vmap, broadcast # pylint: enable=g-multiple-import diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 6d1384189..bb0ce04bc 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -21,6 +21,7 @@ from flax.linen.module import wrap_method import jax +from flax.core.lift import broadcast # Utils # ----------------------------------------------------------------------------- @@ -199,101 +200,4 @@ def lift_transform(transform, target, 
*trafo_args, methods=None, **trafo_kwargs) vmap = functools.partial(lift_transform, lift.vmap) jit = functools.partial(lift_transform, lift.jit) remat = functools.partial(lift_transform, lift.remat) - - -# Scan specific class lifting -# ----------------------------------------------------------------------------- -def module_class_scan_transform( - module_class, - *trafo_args, - methods=None, - **trafo_kwargs): - # Prepare per-method transform args, kwargs. - if methods is None: - # Default case, just transform __call__ - class_trafo_args = {'__call__': (trafo_args, trafo_kwargs)} - elif isinstance(methods, (list, tuple)): - # Transform every method in methods with given args, kwargs. - class_trafo_args = {m: (trafo_args, trafo_kwargs) for m in methods} - elif isinstance(methods, dict): - # Pass different trafo args per each method. - assert trafo_args == () and trafo_kwargs == {}, ( - f"""When passing different scan args per method, - all args must be passed via methods kwarg.""") - class_trafo_args = {k: ((), v) for k, v in methods.items()} - - # Build the actual transformed class. 
- transformed_fns = {} - # for each of the specified methods: - for fn_name, fn_trafo_args in class_trafo_args.items(): - # get existing unbound method from class - fn = getattr(module_class, fn_name) - trafo_args, trafo_kwargs = fn_trafo_args - # we need to create a scope-function from our class for the given method - @functools.wraps(fn) - def wrapped_fn(self, *args, **kwargs): - if len(args) != 2: - raise ValueError('scan requires a Module taking two arguments, ' - 'a carry and an input.') - # make a scope-function to transform - def core_fn(scopes, *args_inner): - # make a clone of self using its arguments - attrs = {f.name: getattr(self, f.name) - for f in dataclasses.fields(self) if f.name != 'parent'} - # we reference module_class, not self.__class__ to avoid infinite loop - cloned = module_class(parent=None, **attrs) - cloned = set_module_scopes(cloned, scopes) - res = getattr(cloned, fn_name)(*args_inner, **kwargs) - # preserve submodule-tree stripped of scopes/tracers for introspection - object.__setattr__(self, 'children', clean_clone(cloned).children) - return res - # here we apply the given lifting transform to the scope-ingesting fn - return lift.scan(core_fn, get_module_scopes(self), - *args, *trafo_args, **trafo_kwargs) - transformed_fns[fn_name] = wrapped_fn - # construct new dynamic class w. transformed methods - return type('Scan' + module_class.__name__, - (module_class,), - transformed_fns) - - -# Scan as decorator on methods __inside__ class definition. -# ----------------------------------------------------------------------------- -def decorator_scan_transform(class_fn, *trafo_args, **trafo_kwargs): - # NB: due to the ordering of method decorators, we must re-wrap the class_fn - # to maintain Module state correctly for multiple invocations. If we want to - # save another stacktrace entry we could instead replicate its logic below. 
- rewrapped_fn = wrap_method(class_fn) - @functools.wraps(class_fn) - def wrapped_fn(self, *args, **kwargs): - if len(args) != 2: - raise ValueError('scan requires a Module taking two arguments, ' - 'a carry and an input.') - # make a scope-function to transform - def core_fn(scopes, *args_inner): - cloned = set_module_scopes(self, scopes) - res = rewrapped_fn(cloned, *args_inner, **kwargs) - # preserve submodule-tree stripped of scopes/tracers for introspection - object.__setattr__(self, 'children', clean_clone(cloned).children) - return res - # here we apply the given lifting transform to the scope-ingesting fn - return lift.scan(core_fn, get_module_scopes(self), - *args, *trafo_args, **trafo_kwargs) - return wrapped_fn - - -# scan wraps a class or used as decorator in def of class method. -# ----------------------------------------------------------------------------- -def scan(target, *trafo_args, methods=None, **trafo_kwargs): - """Applies scan to a Module class or as a decorator on class functions.""" - if inspect.isclass(target) and issubclass(target, Module): - return module_class_scan_transform( - target, *trafo_args, methods=methods, **trafo_kwargs) - # we presume this is being used as a function decorator in class definition - elif inspect.isfunction(target): - return decorator_scan_transform( - target, *trafo_args, **trafo_kwargs) - else: - raise ValueError( - 'Can only scan a Module subclass or decorate a function' - ' with scan in class definition.') +scan = functools.partial(lift_transform, lift.scan) diff --git a/linen_examples/core_design_test/big_resnets.py b/linen_examples/core_design_test/big_resnets.py index ff8f872fb..1d15cbaaf 100644 --- a/linen_examples/core_design_test/big_resnets.py +++ b/linen_examples/core_design_test/big_resnets.py @@ -59,7 +59,8 @@ def body_fn(scope, x): return lift.remat_scan( body_fn, scope, x, lengths=blocks, - variable_modes={'param': 'scan', 'batch_stats': 'scan'}, + variable_in_axes={'param': 0, 'batch_stats': 
0}, + variable_out_axes={'param': 0, 'batch_stats': 0}, split_rngs={'param': True}) if __name__ == "__main__": diff --git a/linen_examples/core_design_test/scan.py b/linen_examples/core_design_test/scan.py index 6fc77ef44..cea202113 100644 --- a/linen_examples/core_design_test/scan.py +++ b/linen_examples/core_design_test/scan.py @@ -38,14 +38,18 @@ def body_fn(scope, c, x): if share_params: carry, ys = lift.scan( - body_fn, scope, (), xs, - variable_modes={'param': 'broadcast', 'counter': 'carry'}, - split_rngs={'param': False}) + body_fn, + variable_carry='counter', + variable_in_axes={'param': lift.broadcast}, + variable_out_axes={'param': lift.broadcast}, + split_rngs={'param': False})(scope, (), xs) else: carry, ys = lift.scan( - body_fn, scope, (), xs, - variable_modes={'param': 'scan', 'counter': 'carry'}, - split_rngs={'param': True}) + body_fn, + variable_carry='counter', + variable_in_axes={'param': 0}, + variable_out_axes={'param': 0}, + split_rngs={'param': True})(scope, (), xs) # output layer return carry, ys diff --git a/linen_examples/seq2seq/train.py b/linen_examples/seq2seq/train.py index c4a5dd185..c45eb5d29 100644 --- a/linen_examples/seq2seq/train.py +++ b/linen_examples/seq2seq/train.py @@ -168,7 +168,8 @@ class EncoderLSTM(nn.Module): @functools.partial( nn.transforms.scan, - variable_modes={'param': 'broadcast'}, + variable_in_axes={'param': nn.broadcast}, + variable_out_axes={'param': nn.broadcast}, split_rngs={'param': False}) @nn.compact def __call__(self, carry, x): @@ -215,7 +216,8 @@ class DecoderLSTM(nn.Module): @functools.partial( nn.transforms.scan, - variable_modes={'param': 'broadcast'}, + variable_in_axes={'param': nn.broadcast}, + variable_out_axes={'param': nn.broadcast}, split_rngs={'param': False}) @nn.compact def __call__(self, carry, x): diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 714145fb9..8994afd0c 100644 --- a/tests/linen/linen_transforms_test.py +++ 
b/tests/linen/linen_transforms_test.py @@ -161,7 +161,8 @@ class SimpleScan(nn.Module): @nn.compact def __call__(self, c, xs): LSTM = nn.scan(nn.LSTMCell, - variable_modes={'param': 'broadcast'}, + variable_in_axes={'param': nn.broadcast}, + variable_out_axes={'param': nn.broadcast}, split_rngs={'param': False}) return LSTM(name="lstm_cell")(c, xs) @@ -190,7 +191,8 @@ def __call__(self, c, xs): def test_scan_decorated(self): class SimpleScan(nn.Module): @partial(nn.scan, - variable_modes={'param': 'broadcast'}, + variable_in_axes={'param': nn.broadcast}, + variable_out_axes={'param': nn.broadcast}, split_rngs={'param': False}) @nn.compact def __call__(self, c, xs):