Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@

from .frozen_dict import FrozenDict, freeze, unfreeze
from .tracers import current_trace, trace_level, check_trace_level
from .scope import in_kind_filter, Scope, Array, apply, init
from .scope import Scope, Array, apply, init
from .lift import scan, vmap, jit
4 changes: 2 additions & 2 deletions flax/core/flax_functional_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@
{
"output_type": "stream",
"name": "stdout",
"text": "FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})\n"
"text": "FrozenDict({'params': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})\n"
},
{
"output_type": "execute_result",
"data": {
"text/plain": "(DeviceArray([[-0.00302252]], dtype=float32),\n FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674, 0.33191404],\n [-0.7799348 , 0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))"
"text/plain": "(DeviceArray([[-0.00302252]], dtype=float32),\n FrozenDict({'params': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674, 0.33191404],\n [-0.7799348 , 0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))"
},
"metadata": {},
"execution_count": 4
Expand Down
77 changes: 38 additions & 39 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .frozen_dict import FrozenDict
from .frozen_dict import unfreeze

from .scope import Scope, KindFilter, in_kind_filter, group_kinds
from .scope import Scope, CollectionFilter, PRNGSequenceFilter, in_filter, group_collections
from .named_call import named_call_p

from . import unified_transforms
Expand Down Expand Up @@ -69,9 +69,9 @@ def _dup_scopes(orig_scopes, scopes, paths):
return scopes

def pack(fn: Callable[..., Any],
in_variable_filters: Sequence[KindFilter],
out_variable_filters: Sequence[KindFilter],
rng_filters: Sequence[KindFilter]) -> Callable[..., Any]:
in_variable_filters: Sequence[CollectionFilter],
out_variable_filters: Sequence[CollectionFilter],
rng_filters: Sequence[PRNGSequenceFilter]) -> Callable[..., Any]:
"""Pack variables and rngs for functional transformations."""
@functools.wraps(fn)
def wrapper(scope: Scope, *args):
Expand All @@ -83,20 +83,20 @@ def wrapper(scope: Scope, *args):

for scope in scopes:
scope._validate_trace_level()
scope._populate_kinds()
variable_groups_xs.append(group_kinds(scope._variables, in_variable_filters))
# Make sure in only variable kinds are frozen
scope._populate_collections()
variable_groups_xs.append(group_collections(scope._variables, in_variable_filters))
# Make sure in only variable collections are frozen
for variable_groups in variable_groups_xs:
for variable_group in variable_groups:
for kind, kind_variables in variable_group.items():
kind_in_out = any(
in_kind_filter(kind_filter, kind)
for kind_filter in out_variable_filters)
if not kind_in_out:
variable_group[kind] = freeze(kind_variables)
for col_name, collection in variable_group.items():
col_in_out = any(
in_filter(col_filter, col_name)
for col_filter in out_variable_filters)
if not col_in_out:
variable_group[col_name] = freeze(collection)
rng_groups_xs = []
for scope in scopes:
rng_groups = group_kinds(scope.rngs, rng_filters)
rng_groups = group_collections(scope.rngs, rng_filters)
for rng_group in rng_groups:
for kind in rng_group:
rng_group[kind] = scope.make_rng(kind)
Expand Down Expand Up @@ -134,7 +134,7 @@ def repack(inner_scope_tree):
mutable_variables = {key: val for key, val
in inner_scope._variables.items()
if not isinstance(val, FrozenDict)}
out_variable_groups = group_kinds(
out_variable_groups = group_collections(
mutable_variables, tuple(out_variable_filters) + (True,))
remainder = tuple(out_variable_groups[-1].keys())
if remainder:
Expand All @@ -149,21 +149,21 @@ def repack(inner_scope_tree):
inner_scope.invalidate()
for scope, out_variable_groups in zip(scopes, out_variable_groups_xs):
for out_variable_group in out_variable_groups:
for kind, kind_variables in out_variable_group.items():
for name, value in kind_variables.items():
scope.put_variable(kind, name, value)
for col_name, collection in out_variable_group.items():
for name, value in collection.items():
scope.put_variable(col_name, name, value)
return y
return wrapper

id_fn = lambda x: x

def transform_module(fn: Callable[..., Any],
target: KindFilter = 'param',
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also considering making the "standard" collections into constance, e.g. flax.linen.PARAMS = 'params' and flax.linen.normalization.BATCH_STATS = 'batch_stats'?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably will do in seperate PR though.

target: CollectionFilter = 'params',
trans_in_fn: Callable[..., Any] = id_fn,
trans_out_fn: Callable[..., Any] = id_fn,
init: bool = True, mutable: bool = False,
rngs: KindFilter = True,
variables: KindFilter = True):
rngs: PRNGSequenceFilter = True,
variables: CollectionFilter = True):
def wrapper(scope, *args, **kwargs):
if init:
vs = scope.variables()
Expand All @@ -182,11 +182,11 @@ def wrapper(scope, *args, **kwargs):


def transform(
target: KindFilter,
target: CollectionFilter,
trans_in_fn: Callable[..., Any] = id_fn,
trans_out_fn: Callable[..., Any] = id_fn,
init: bool = False, mutable: bool = False,
rngs: KindFilter = True, variables: KindFilter = True):
rngs: PRNGSequenceFilter = True, variables: CollectionFilter = True):
def wrapper(scope_fn, repack, variable_groups_xs, rng_groups_xs, fn, *args):
assert len(variable_groups_xs) == 1, 'transform does not support multi-scope lifting.'
target, variables = variable_groups_xs[0]
Expand Down Expand Up @@ -231,7 +231,7 @@ class Out(Generic[T]):
axis: T


def _split_in_out_axes(xs: Mapping[KindFilter, Any]):
def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]):
unpack = lambda v: v.axis if isinstance(v, (In, Out)) else v
in_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, Out)}
out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)}
Expand All @@ -243,8 +243,8 @@ def _split_in_out_axes(xs: Mapping[KindFilter, Any]):


def vmap(fn: Callable[..., Any],
variable_axes: Mapping[KindFilter, InOutAxis],
split_rngs: Mapping[KindFilter, bool],
variable_axes: Mapping[CollectionFilter, InOutAxis],
split_rngs: Mapping[PRNGSequenceFilter, bool],
in_axes=0, out_axes=0, axis_size=None) -> Callable[..., Any]:
"""Wraps jax.vmap."""
variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes)
Expand Down Expand Up @@ -301,9 +301,9 @@ def mapped(variable_groups_xs, rng_groups_xs, args):


def scan(fn: Callable[..., Any],
variable_axes: Mapping[KindFilter, InOutScanAxis] = {},
variable_carry: KindFilter = False,
split_rngs: Mapping[KindFilter, bool] = {},
variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {},
variable_carry: CollectionFilter = False,
split_rngs: Mapping[PRNGSequenceFilter, bool] = {},
in_axes=0, out_axes=0,
length: Optional[int] = None,
reverse: bool = False) -> Callable[..., Any]:
Expand Down Expand Up @@ -368,7 +368,7 @@ def scanned(carry, variable_groups_xs, rng_groups_xs, args):


def custom_vjp(module_fn: Callable[..., Any], backward_fn: Callable[..., Any],
grad_kind: KindFilter='param',
grad_kind: CollectionFilter='params',
nondiff_argnums=()):
def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, *args):
assert len(variable_groups_xs) == 1, 'transform does not support multi-scope lifting.'
Expand Down Expand Up @@ -409,8 +409,8 @@ def f_bwd(*args):


def remat(fn: Callable[..., Any],
variables: KindFilter = True,
rngs: KindFilter = True) -> Callable[..., Any]:
variables: CollectionFilter = True,
rngs: PRNGSequenceFilter = True) -> Callable[..., Any]:
"""Wraps jax.jit."""
def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args):
@jax.remat
Expand All @@ -428,9 +428,8 @@ def jit(fn: Callable[..., Any],
static_argnums: Union[int, Iterable[int]] = (),
device=None,
backend: Union[str, None] = None,
in_variables: KindFilter = True,
out_variables: KindFilter = True,
rngs: KindFilter = True) -> Callable[..., Any]:
variables: CollectionFilter = True,
rngs: PRNGSequenceFilter = True) -> Callable[..., Any]:
"""Wraps jax.jit."""
if not isinstance(static_argnums, Iterable):
static_argnums = (static_argnums,)
Expand All @@ -447,14 +446,14 @@ def jitted(variable_groups_xs, rng_groups_xs, *args):

return jitted(variable_groups_xs, rng_groups_xs, *args)

return pack(inner, (in_variables,), (out_variables,), (rngs,))
return pack(inner, (variables,), (variables,), (rngs,))


def remat_scan(body_fn: Callable[..., Any], scope: Scope, carry: Any,
lengths: Sequence[int],
variable_carry: KindFilter = False,
variable_axes: Mapping[KindFilter, InOutScanAxis] = {},
split_rngs: Mapping[KindFilter, bool] = {}):
variable_carry: CollectionFilter = False,
variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {},
split_rngs: Mapping[PRNGSequenceFilter, bool] = {}):
# TODO(jheek) should remat scan have scan inputs/outputs?
if len(lengths) == 1:
def wrapper(scope, carry):
Expand Down
Loading