Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 83 additions & 145 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from .scope import Scope, KindFilter, in_kind_filter, group_kinds
from .named_call import named_call_p

from . import unified_transforms
from .unified_transforms import broadcast

scan_variable_modes = set(['carry', 'broadcast', 'scan', None])

ScanVariableMode = Union[str, Tuple[str, str]]
Expand Down Expand Up @@ -99,8 +102,8 @@ def repack(inner_scope_tree):
inner_scope.invalidate()
inner_scope._validate_trace_level()
mutable_variables = {key: val for key, val
in inner_scope._variables.items()
if not isinstance(val, FrozenDict)}
in inner_scope._variables.items()
if not isinstance(val, FrozenDict)}
out_variable_groups = group_kinds(
mutable_variables, tuple(out_variable_filters) + (True,))
remainder = tuple(out_variable_groups[-1].keys())
Expand Down Expand Up @@ -242,143 +245,71 @@ def mapped(variable_groups_xs, rng_groups_xs, args):
inner, variable_in_groups, variable_out_groups, rng_groups)


def scan(
fn: Callable[..., Any], scope: 'Scope', init_carry: Any, xs: Any,
variable_modes: Mapping[KindFilter, ScanVariableMode],
split_rngs: Mapping[KindFilter, bool],
length: Optional[int] = None, reverse: bool = False) -> Callable[..., Any]:
"""Wraps jax.lax.scan."""
# TODO(jheek) scan can be simplified dramatically by making a version of lax.scan
# that follows an in_axis, out_axis api.
if length is None:
length, = set(x.shape[0] for x in jax.tree_leaves(xs))
variable_groups, variable_modes = _unzip2(variable_modes.items())

def parse_mode(mode):
if isinstance(mode, str):
mode = (mode, mode)
mode_in, mode_out = mode
if mode_in not in scan_variable_modes or mode_out not in scan_variable_modes:
raise ValueError(f'illegal scan variable mode: {mode}')
return mode
variable_modes = tuple(parse_mode(m) for m in variable_modes)

def scan(fn: Callable[..., Any],
variable_in_axes: Mapping[KindFilter, Any] = {},
variable_out_axes: Mapping[KindFilter, Any] = {},
variable_carry: KindFilter = False,
split_rngs: Mapping[KindFilter, bool] = {},
in_axes=0, out_axes=0,
length: Optional[int] = None,
reverse: bool = False) -> Callable[..., Any]:
"""Wraps jax.lax.scan."""
variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items())
variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items())
rng_groups, rng_splits = _unzip2(split_rngs.items())
variable_in_groups = tuple(
False if mode[0] is None else group
for group, mode in zip(variable_groups, variable_modes))
variable_out_groups = tuple(
False if mode[1] is None else group
for group, mode in zip(variable_groups, variable_modes))

def split(variable_groups_xs, i):
scan_vars_xs = []
carry_vars_xs = []
broadcast_vars_xs = []
for variable_groups in variable_groups_xs:
scan_vars = tuple(
group if mode[i] == 'scan' else {}
for group, mode in zip(variable_groups, variable_modes))
carry_vars = tuple(
group if mode[i] == 'carry' else {}
for group, mode in zip(variable_groups, variable_modes))
broadcast_vars = tuple(
group if mode[i] == 'broadcast' else {}
for group, mode in zip(variable_groups, variable_modes))
scan_vars_xs.append(scan_vars)
carry_vars_xs.append(carry_vars)
broadcast_vars_xs.append(broadcast_vars)
return tuple(scan_vars_xs), tuple(carry_vars_xs), tuple(broadcast_vars_xs)

def combine(*variable_groups_xs):
combined_groups_xs = []
for groups_xs in zip(*variable_groups_xs):
combined_groups = []
for groups in zip(*groups_xs):
result = {}
for group in groups:
result.update(group)
combined_groups.append(result)
combined_groups_xs.append(tuple(combined_groups))
return tuple(combined_groups_xs)

def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs):
rng_axes = tuple(0 if rng_split else broadcast for rng_split in rng_splits)

def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, init, *args):
def find_length(axis, x):
if axis is not None:
leaves = jax.tree_leaves(x)
if leaves:
return leaves[0].shape[axis]
return ()
# split rngs
split_fn = lambda rng: random.split(rng, length)
broadcast_rngs_xs = []
scan_rngs_xs = []
for rng_groups in rng_groups_xs:
broadcast_rngs_xs.append(tuple(
rng_group for rng_group, split
in zip(rng_groups, rng_splits) if not split))
scan_rngs_xs.append(tuple(
jax.tree_map(split_fn, rng_group)
for rng_group, split in zip(rng_groups, rng_splits) if split))

def body(carry, xs, init_mode=False):
carry_vars_xs, c = carry
scan_vars_xs, scan_rngs_xs, x = xs
variable_groups_xs = combine(scan_vars_xs, carry_vars_xs, broadcast_vars_xs)
rng_groups_xs = []
for broadcast_rngs, scan_rngs in zip(broadcast_rngs_xs, scan_rngs_xs):
rng_groups_xs.append(broadcast_rngs + scan_rngs)
lengths = jax.tree_multimap(find_length, in_axes, args)
if length is None:
d_length, = set(jax.tree_leaves(lengths))
else:
d_length = length
split_fn = lambda rng: random.split(rng, d_length)

def split_rngs(rng_groups):
return tuple(
jax.tree_map(split_fn, rng_group) if split else rng_group
for rng_group, split in zip(rng_groups, rng_splits))

rng_groups_xs = tuple(map(split_rngs, rng_groups_xs))

n = len(variable_groups_xs)
variable_in_axes_xs = (variable_in_axes,) * n
variable_out_axes_xs = (variable_out_axes,) * n
rng_axes_xs = (rng_axes,) * n

@functools.partial(unified_transforms.scan,
in_axes=(variable_in_axes_xs, rng_axes_xs, in_axes),
out_axes=(out_axes, variable_out_axes_xs),
length=length, reverse=reverse)
def scanned(carry, variable_groups_xs, rng_groups_xs, args):
carry_vars, c = carry
scope = scope_fn(variable_groups_xs, rng_groups_xs)
carry, y = fn(scope, c, x)
out_vars = repack_fn(scope)
scan_vars_xs, carry_vars_out_xs, broadcast_vars_out_xs = split(out_vars, 1)

# TODO(jheek) more informative error check
def check_shapes(c_in, c_out):
if not isinstance(c_in, jnp.ndarray) or not isinstance(c_out, jnp.ndarray):
return
if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(c_in) != jnp.dtype(c_out):
raise ValueError()
try:
jax.tree_multimap(check_shapes, carry_vars_xs, carry_vars_out_xs)
except ValueError:
raise ValueError('carry variables must have the same shape and dtype before and after scan.')

if init_mode:
return broadcast_vars_out_xs
else:
return (carry_vars_out_xs, carry), (scan_vars_xs, y)
broadcast_body = functools.partial(body, init_mode=True)

scan_vars_xs, carry_vars_xs, broadcast_vars_xs = split(variable_groups_xs, 0)
carry0 = (carry_vars_xs, init_carry)
xxs = (scan_vars_xs, scan_rngs_xs, xs)

# Use partial evaluation to find the variables that are broadcasted out.
# An error is raised if a broadcasted output depends on any scan variables.
carry_pvals = jax.tree_map(
lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype)),
carry0)
scan_pvals = jax.tree_map(
lambda x: pe.PartialVal.unknown(jax.ShapedArray(x.shape[1:], x.dtype)),
xxs)
input_pvals = (carry_pvals, scan_pvals)
in_pvals, in_tree = jax.tree_flatten(input_pvals)
f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(lu.wrap_init(broadcast_body), in_tree)

_, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals)
# _, out_pvals, _ = pe.trace_to_jaxpr(f_flat, in_pvals, stage_out=True)

out_flat = []
for pv, const in out_pvals:
if pv is not None:
raise ValueError('broadcasted variable has a data dependency on the scan body.')
out_flat.append(const)

(carry_vars_xs, carry), (scan_vars_xs, ys) = lax.scan(
body, carry0, xxs, length=length, reverse=reverse)

broadcast_vars_xs = jax.tree_unflatten(out_tree(), out_flat)

out_vars_xs = combine(carry_vars_xs, scan_vars_xs, broadcast_vars_xs)
return (carry, ys), out_vars_xs
c, y = fn(scope, c, *args)
out_vars_xs = repack_fn(scope)
out_vars_xs_t = tuple(zip(*out_vars_xs))
carry_vars = out_vars_xs_t[0]
scan_vars = out_vars_xs_t[1:]
return (carry_vars, c), (y, scan_vars)

variable_groups_xs_t = tuple(zip(*variable_groups_xs))
carry_vars = variable_groups_xs_t[0]
scan_vars = variable_groups_xs_t[1:]
(carry_vars, c), (ys, scan_vars) = scanned((carry_vars, init), scan_vars, rng_groups_xs, args)
out_vars_xs_t = (carry_vars,) + scan_vars
out_vars_xs = tuple(zip(*out_vars_xs_t))
return (c, ys), out_vars_xs

return pack(
inner, variable_in_groups, variable_out_groups, rng_groups)(scope)
inner, (variable_carry,) + variable_in_groups, (variable_carry,) + variable_out_groups, rng_groups)


def custom_vjp(module_fn: Callable[..., Any], backward_fn: Callable[..., Any],
Expand Down Expand Up @@ -466,27 +397,34 @@ def jitted(variable_groups_xs, rng_groups_xs, *args):

def remat_scan(body_fn: Callable[..., Any], scope: Scope, carry: Any,
lengths: Sequence[int],
variable_modes: Mapping[KindFilter, ScanVariableMode],
split_rngs: Mapping[KindFilter, bool]):
variable_carry: KindFilter = False,
variable_in_axes: Mapping[KindFilter, Any] = {},
variable_out_axes: Mapping[KindFilter, Any] = {},
split_rngs: Mapping[KindFilter, bool] = {}):
# TODO(jheek) should remat scan have scan inputs/outputs?
if len(lengths) == 1:
def wrapper(scope, carry, _):
def wrapper(scope, carry):
return body_fn(scope, carry), ()
carry, _ = scan(
wrapper, scope, carry, (),
wrapper,
length=lengths[0],
variable_modes=variable_modes,
split_rngs=split_rngs)
variable_carry=variable_carry,
variable_in_axes=variable_in_axes,
variable_out_axes=variable_out_axes,
split_rngs=split_rngs)(scope, carry)
else:
@remat
def inner_loop(scope, carry, _):
carry = remat_scan(body_fn, scope, carry, lengths[1:], variable_modes, split_rngs)
def inner_loop(scope, carry):
carry = remat_scan(body_fn, scope, carry, lengths[1:],
variable_carry, variable_in_axes, variable_out_axes, split_rngs)
return carry, ()
carry, _ = scan(
inner_loop, scope, carry, (),
inner_loop,
length=lengths[0],
variable_modes=variable_modes,
split_rngs=split_rngs)
variable_carry=variable_carry,
variable_in_axes=variable_in_axes,
variable_out_axes=variable_out_axes,
split_rngs=split_rngs)(scope, carry)
return carry


Expand Down
2 changes: 2 additions & 0 deletions flax/core/tests/scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def f(scope):

init(f)(random.PRNGKey(0))




if __name__ == '__main__':
absltest.main()
Loading