diff --git a/flax/core/__init__.py b/flax/core/__init__.py index 99ff513ee..e62288834 100644 --- a/flax/core/__init__.py +++ b/flax/core/__init__.py @@ -14,5 +14,5 @@ from .frozen_dict import FrozenDict, freeze, unfreeze from .tracers import current_trace, trace_level, check_trace_level -from .scope import in_kind_filter, Scope, Array, apply, init +from .scope import Scope, Array, apply, init from .lift import scan, vmap, jit diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index 7a6a3ba12..e3abbf33e 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -92,12 +92,12 @@ { "output_type": "stream", "name": "stdout", - "text": "FrozenDict({'param': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})\n" + "text": "FrozenDict({'params': FrozenDict({'kernel': DeviceArray([[ 0.15374057, -0.6807397 , -1.3350962 ],\n [ 0.59940743, -0.69430196, -0.7663768 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)})})\n" }, { "output_type": "execute_result", "data": { - "text/plain": "(DeviceArray([[-0.00302252]], dtype=float32),\n FrozenDict({'param': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674, 0.33191404],\n [-0.7799348 , 0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))" + "text/plain": "(DeviceArray([[-0.00302252]], dtype=float32),\n FrozenDict({'params': FrozenDict({'hidden': FrozenDict({'kernel': DeviceArray([[-1.1642578 , -0.04300674, 0.33191404],\n [-0.7799348 , 0.24048047, -0.6054149 ]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}), 'out': FrozenDict({'kernel': DeviceArray([[ 
0.21448377],\n [-0.01530595],\n [ 0.14402702]], dtype=float32), 'bias': DeviceArray([0.], dtype=float32)})})}))" }, "metadata": {}, "execution_count": 4 diff --git a/flax/core/lift.py b/flax/core/lift.py index 8a12f2c74..822f78116 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -28,7 +28,7 @@ from .frozen_dict import FrozenDict from .frozen_dict import unfreeze -from .scope import Scope, KindFilter, in_kind_filter, group_kinds +from .scope import Scope, CollectionFilter, PRNGSequenceFilter, in_filter, group_collections from .named_call import named_call_p from . import unified_transforms @@ -69,9 +69,9 @@ def _dup_scopes(orig_scopes, scopes, paths): return scopes def pack(fn: Callable[..., Any], - in_variable_filters: Sequence[KindFilter], - out_variable_filters: Sequence[KindFilter], - rng_filters: Sequence[KindFilter]) -> Callable[..., Any]: + in_variable_filters: Sequence[CollectionFilter], + out_variable_filters: Sequence[CollectionFilter], + rng_filters: Sequence[PRNGSequenceFilter]) -> Callable[..., Any]: """Pack variables and rngs for functional transformations.""" @functools.wraps(fn) def wrapper(scope: Scope, *args): @@ -83,20 +83,20 @@ def wrapper(scope: Scope, *args): for scope in scopes: scope._validate_trace_level() - scope._populate_kinds() - variable_groups_xs.append(group_kinds(scope._variables, in_variable_filters)) - # Make sure in only variable kinds are frozen + scope._populate_collections() + variable_groups_xs.append(group_collections(scope._variables, in_variable_filters)) + # Make sure in only variable collections are frozen for variable_groups in variable_groups_xs: for variable_group in variable_groups: - for kind, kind_variables in variable_group.items(): - kind_in_out = any( - in_kind_filter(kind_filter, kind) - for kind_filter in out_variable_filters) - if not kind_in_out: - variable_group[kind] = freeze(kind_variables) + for col_name, collection in variable_group.items(): + col_in_out = any( + in_filter(col_filter, col_name) 
+ for col_filter in out_variable_filters) + if not col_in_out: + variable_group[col_name] = freeze(collection) rng_groups_xs = [] for scope in scopes: - rng_groups = group_kinds(scope.rngs, rng_filters) + rng_groups = group_collections(scope.rngs, rng_filters) for rng_group in rng_groups: for kind in rng_group: rng_group[kind] = scope.make_rng(kind) @@ -134,7 +134,7 @@ def repack(inner_scope_tree): mutable_variables = {key: val for key, val in inner_scope._variables.items() if not isinstance(val, FrozenDict)} - out_variable_groups = group_kinds( + out_variable_groups = group_collections( mutable_variables, tuple(out_variable_filters) + (True,)) remainder = tuple(out_variable_groups[-1].keys()) if remainder: @@ -149,21 +149,21 @@ def repack(inner_scope_tree): inner_scope.invalidate() for scope, out_variable_groups in zip(scopes, out_variable_groups_xs): for out_variable_group in out_variable_groups: - for kind, kind_variables in out_variable_group.items(): - for name, value in kind_variables.items(): - scope.put_variable(kind, name, value) + for col_name, collection in out_variable_group.items(): + for name, value in collection.items(): + scope.put_variable(col_name, name, value) return y return wrapper id_fn = lambda x: x def transform_module(fn: Callable[..., Any], - target: KindFilter = 'param', + target: CollectionFilter = 'params', trans_in_fn: Callable[..., Any] = id_fn, trans_out_fn: Callable[..., Any] = id_fn, init: bool = True, mutable: bool = False, - rngs: KindFilter = True, - variables: KindFilter = True): + rngs: PRNGSequenceFilter = True, + variables: CollectionFilter = True): def wrapper(scope, *args, **kwargs): if init: vs = scope.variables() @@ -182,11 +182,11 @@ def wrapper(scope, *args, **kwargs): def transform( - target: KindFilter, + target: CollectionFilter, trans_in_fn: Callable[..., Any] = id_fn, trans_out_fn: Callable[..., Any] = id_fn, init: bool = False, mutable: bool = False, - rngs: KindFilter = True, variables: KindFilter = True): + 
rngs: PRNGSequenceFilter = True, variables: CollectionFilter = True): def wrapper(scope_fn, repack, variable_groups_xs, rng_groups_xs, fn, *args): assert len(variable_groups_xs) == 1, 'transform does not support multi-scope lifting.' target, variables = variable_groups_xs[0] @@ -231,7 +231,7 @@ class Out(Generic[T]): axis: T -def _split_in_out_axes(xs: Mapping[KindFilter, Any]): +def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]): unpack = lambda v: v.axis if isinstance(v, (In, Out)) else v in_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, Out)} out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)} @@ -243,8 +243,8 @@ def _split_in_out_axes(xs: Mapping[KindFilter, Any]): def vmap(fn: Callable[..., Any], - variable_axes: Mapping[KindFilter, InOutAxis], - split_rngs: Mapping[KindFilter, bool], + variable_axes: Mapping[CollectionFilter, InOutAxis], + split_rngs: Mapping[PRNGSequenceFilter, bool], in_axes=0, out_axes=0, axis_size=None) -> Callable[..., Any]: """Wraps jax.vmap.""" variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) @@ -301,9 +301,9 @@ def mapped(variable_groups_xs, rng_groups_xs, args): def scan(fn: Callable[..., Any], - variable_axes: Mapping[KindFilter, InOutScanAxis] = {}, - variable_carry: KindFilter = False, - split_rngs: Mapping[KindFilter, bool] = {}, + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, + variable_carry: CollectionFilter = False, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, in_axes=0, out_axes=0, length: Optional[int] = None, reverse: bool = False) -> Callable[..., Any]: @@ -368,7 +368,7 @@ def scanned(carry, variable_groups_xs, rng_groups_xs, args): def custom_vjp(module_fn: Callable[..., Any], backward_fn: Callable[..., Any], - grad_kind: KindFilter='param', + grad_kind: CollectionFilter='params', nondiff_argnums=()): def inner(scope_fn, repack_fn, variable_groups_xs, rng_groups_xs, *args): assert len(variable_groups_xs) == 1, 'transform does 
not support multi-scope lifting.' @@ -409,8 +409,8 @@ def f_bwd(*args): def remat(fn: Callable[..., Any], - variables: KindFilter = True, - rngs: KindFilter = True) -> Callable[..., Any]: + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True) -> Callable[..., Any]: """Wraps jax.jit.""" def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): @jax.remat @@ -428,9 +428,8 @@ def jit(fn: Callable[..., Any], static_argnums: Union[int, Iterable[int]] = (), device=None, backend: Union[str, None] = None, - in_variables: KindFilter = True, - out_variables: KindFilter = True, - rngs: KindFilter = True) -> Callable[..., Any]: + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True) -> Callable[..., Any]: """Wraps jax.jit.""" if not isinstance(static_argnums, Iterable): static_argnums = (static_argnums,) @@ -447,14 +446,14 @@ def jitted(variable_groups_xs, rng_groups_xs, *args): return jitted(variable_groups_xs, rng_groups_xs, *args) - return pack(inner, (in_variables,), (out_variables,), (rngs,)) + return pack(inner, (variables,), (variables,), (rngs,)) def remat_scan(body_fn: Callable[..., Any], scope: Scope, carry: Any, lengths: Sequence[int], - variable_carry: KindFilter = False, - variable_axes: Mapping[KindFilter, InOutScanAxis] = {}, - split_rngs: Mapping[KindFilter, bool] = {}): + variable_carry: CollectionFilter = False, + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}): # TODO(jheek) should remat scan have scan inputs/outputs? 
if len(lengths) == 1: def wrapper(scope, carry): diff --git a/flax/core/scope.py b/flax/core/scope.py index 3373dc8c9..665ae424b 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -34,12 +34,13 @@ PRNGKey = Any Array = Any +Filter = Union[bool, str, Sequence[str]] +CollectionFilter = Filter +PRNGSequenceFilter = Filter -KindFilter = Union[bool, str, Sequence[str]] +MaybeFrozenCollection = Union[Dict[str, Any], FrozenDict[str, Any]] -MaybeFrozenKind = Union[Dict[str, Any], FrozenDict[str, Any]] - -Variables = Dict[str, MaybeFrozenKind] +Variables = Dict[str, MaybeFrozenCollection] def _fold_in_str(rng: PRNGKey, data: str) -> PRNGKey: @@ -51,48 +52,48 @@ def _fold_in_str(rng: PRNGKey, data: str) -> PRNGKey: return random.fold_in(rng, hash_int) -def in_kind_filter(kind_filter: KindFilter, kind: str) -> bool: - if isinstance(kind_filter, str): - return kind == kind_filter - if isinstance(kind_filter, Sequence) and not isinstance(kind_filter, str): - return kind in kind_filter - if isinstance(kind_filter, bool): - return kind_filter - raise TypeError('Invalid KindFilter') +def in_filter(filter: Filter, kind: str) -> bool: + if isinstance(filter, str): + return kind == filter + if isinstance(filter, Sequence) and not isinstance(filter, str): + return kind in filter + if isinstance(filter, bool): + return filter + raise TypeError('Invalid Filter') -def group_kinds(xs: Variables, - kind_filters: Sequence[KindFilter]) -> Sequence[Variables]: +def group_collections(xs: Variables, + col_filters: Sequence[CollectionFilter]) -> Sequence[Variables]: """Group variables by kind filters.""" - kinds = xs.keys() + cols = xs.keys() groups = [] - for kind_filter in kind_filters: - remaining_kinds = [] + for col_filter in col_filters: + remaining_cols = [] group = {} - for kind in kinds: - if in_kind_filter(kind_filter, kind): - group[kind] = jax.tree_map(lambda x: x, xs[kind]) + for col in cols: + if in_filter(col_filter, col): + group[col] = jax.tree_map(lambda x: x, xs[col]) 
else: - remaining_kinds.append(kind) - kinds = remaining_kinds + remaining_cols.append(col) + cols = remaining_cols groups.append(group) return tuple(groups) class Variable(Generic[T]): - def __init__(self, scope: 'Scope', kind: str, name: str): + def __init__(self, scope: 'Scope', collection: str, name: str): self.scope = scope - self.kind = kind + self.collection = collection self.name = name @property def value(self) -> T: - return self.scope.get_variable(self.kind, self.name) + return self.scope.get_variable(self.collection, self.name) @value.setter def value(self, value: T): - self.scope.put_variable(self.kind, self.name, value) + self.scope.put_variable(self.collection, self.name, value) import contextlib @@ -138,7 +139,7 @@ def invalidate(self): self._invalid = True def variables(self): - self._populate_kinds() + self._populate_collections() return freeze(self._variables) def _validate_trace_level(self): @@ -205,21 +206,21 @@ def wrapper(*args, **kwargs): return fn(scope.rewound(), *args, **kwargs) return wrapper - def get_kind(self, kind: str, mutable: bool = False) -> MaybeFrozenKind: - """Returns all variable of a given kind.""" - if kind not in self._variables: + def collection(self, col: str, mutable: bool = False) -> MaybeFrozenCollection: + """Returns a collection of variables.""" + if col not in self._variables: if self.parent: - parent_kind = self.parent.get_kind(kind, mutable) - if self.name not in parent_kind: - if isinstance(parent_kind, FrozenDict) or not mutable: + parent_col = self.parent.collection(col, mutable) + if self.name not in parent_col: + if isinstance(parent_col, FrozenDict) or not mutable: return FrozenDict() - parent_kind[self.name] = {} - self._variables[kind] = parent_kind[self.name] + parent_col[self.name] = {} + self._variables[col] = parent_col[self.name] elif mutable: - self._variables[kind] = {} + self._variables[col] = {} else: return FrozenDict() - return self._variables[kind] + return self._variables[col] def 
has_rng(self, kind: str) -> bool: return kind in self.rngs @@ -231,45 +232,45 @@ def make_rng(self, kind: str) -> PRNGKey: self.rng_counters[kind] += 1 return random.fold_in(self.rngs[kind], self.rng_counters[kind]) - def get_variable(self, kind: str, name: str, default: T = None) -> T: - variables = self.get_kind(kind) + def get_variable(self, col: str, name: str, default: T = None) -> T: + variables = self.collection(col) if name in variables: return variables[name] else: return default - def has_variable(self, kind: str, name: str) -> bool: - variables = self.get_kind(kind) + def has_variable(self, col: str, name: str) -> bool: + variables = self.collection(col) return name in variables - def put_variable(self, kind: str, name: str, value: Any): + def put_variable(self, col: str, name: str, value: Any): self._check_valid() self._validate_trace_level() - variables = self.get_kind(kind, mutable=True) + variables = self.collection(col, mutable=True) variables[name] = value - def variable(self, kind: str, name: str, init_fn: Callable[..., T], + def variable(self, col: str, name: str, init_fn: Callable[..., T], *init_args) -> Variable[T]: self.reserve(name) - if not self.has_variable(kind, name): + if not self.has_variable(col, name): init_value = init_fn(*init_args) - self.put_variable(kind, name, init_value) - return Variable(self, kind, name) + self.put_variable(col, name, init_value) + return Variable(self, col, name) def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T: - s_init_fn = lambda *args: init_fn(self.make_rng('param'), *init_args) - v = self.variable('param', name, s_init_fn, *init_args) + s_init_fn = lambda *args: init_fn(self.make_rng('params'), *init_args) + v = self.variable('params', name, s_init_fn, *init_args) return v.value - def _populate_kinds(self): - kinds = self.root._variables.keys() - for kind in kinds: - self.get_kind(kind) + def _populate_collections(self): + collections = self.root._variables.keys() + for col in 
collections: + self.collection(col) def _unfreeze_variables(variables, mutable): new_variables = {} for key, value in variables.items(): - if in_kind_filter(mutable, key): + if in_filter(mutable, key): new_variables[key] = unfreeze(value) else: new_variables[key] = value @@ -277,7 +278,7 @@ def _unfreeze_variables(variables, mutable): def apply(fn: Callable[..., Any], - mutable: KindFilter = False) -> Callable[..., Any]: + mutable: CollectionFilter = False) -> Callable[..., Any]: """Functionalize a module.""" @functools.wraps(fn) def wrapper(variables, *args, rngs=None, **kwargs): @@ -291,11 +292,11 @@ def wrapper(variables, *args, rngs=None, **kwargs): return wrapper -def init(fn: Callable[..., Any], mutable: KindFilter = True) -> Callable[..., Any]: +def init(fn: Callable[..., Any], mutable: CollectionFilter = True) -> Callable[..., Any]: @functools.wraps(fn) def wrapper(rngs, *args, **kwargs): if not isinstance(rngs, dict): assert rngs.shape == (2,) - rngs = {'param': rngs} + rngs = {'params': rngs} return apply(fn, mutable=mutable)({}, *args, rngs=rngs, **kwargs) return wrapper diff --git a/flax/linen/module.py b/flax/linen/module.py index 8b63c95b3..d17de887e 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -420,7 +420,7 @@ def variable(self, kind: str, name: str, init_fn, *init_args): return v def param(self, name: str, init_fn: Callable[..., T], *init_args, - kind='param') -> T: + kind='params') -> T: """Declare a parameter in this Module. 
Args: @@ -467,7 +467,7 @@ def init_with_output(self, rngs, *args, method=None, **kwargs): """Create initialized data for module and return it with output.""" if not isinstance(rngs, dict): assert rngs.shape == (2,) - rngs = {'param': rngs} + rngs = {'params': rngs} return self.apply( {}, *args, rngs=rngs, method=method, mutable=True, **kwargs) diff --git a/linen_examples/core_design_test/attention_simple.py b/linen_examples/core_design_test/attention_simple.py index ad245b66a..b786044eb 100644 --- a/linen_examples/core_design_test/attention_simple.py +++ b/linen_examples/core_design_test/attention_simple.py @@ -119,14 +119,14 @@ def multi_head_dot_product_attention( attn_fn, in_axes=(None, None, None), out_axes=-2, axis_size=num_heads, - variable_axes={'param': 0}, - split_rngs={'param': True, 'dropout': not broadcast_dropout}) + variable_axes={'params': 0}, + split_rngs={'params': True, 'dropout': not broadcast_dropout}) for axis in reversed(sorted(batch_axes)): attn_fn = lift.vmap( attn_fn, in_axes=(axis, axis, axis), out_axes=axis, - variable_axes={'param': None}, - split_rngs={'param': False, 'dropout': not broadcast_dropout}) + variable_axes={'params': None}, + split_rngs={'params': False, 'dropout': not broadcast_dropout}) y = attn_fn(scope, inputs_q, inputs_kv, bias) return y.mean(axis=-2) @@ -135,7 +135,7 @@ def multi_head_dot_product_attention( inputs = jnp.ones((2, 7, 16)) y, variables = init(multi_head_dot_product_attention)( - {'param': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}, + {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}, inputs, inputs, num_heads=2, batch_axes=(0,), diff --git a/linen_examples/core_design_test/big_resnets.py b/linen_examples/core_design_test/big_resnets.py index c7e32c025..d16304c4c 100644 --- a/linen_examples/core_design_test/big_resnets.py +++ b/linen_examples/core_design_test/big_resnets.py @@ -59,8 +59,8 @@ def body_fn(scope, x): return lift.remat_scan( body_fn, scope, x, lengths=blocks, - 
variable_axes={'param': 0, 'batch_stats': 0}, - split_rngs={'param': True}) + variable_axes={'params': 0, 'batch_stats': 0}, + split_rngs={'params': True}) if __name__ == "__main__": x = random.normal(random.PRNGKey(0), (1, 8, 8, 8)) diff --git a/linen_examples/core_design_test/resnet.py b/linen_examples/core_design_test/resnet.py index abf33d729..399bec98f 100644 --- a/linen_examples/core_design_test/resnet.py +++ b/linen_examples/core_design_test/resnet.py @@ -70,7 +70,7 @@ def resnet(scope: Scope, x, x = residual_block(block_scope, x, conv, norm, act, block_features, strides) # we can access parameters of the sub module by operating on the scope # Example: - # block_scope.get_kind('param')['conv_1']['kernel'] + # block_scope.collection('params')['conv_1']['kernel'] x = jnp.mean(x, (1, 2)) x = scope.child(nn.dense, 'out')(x, num_classes) return x diff --git a/linen_examples/core_design_test/scan.py b/linen_examples/core_design_test/scan.py index a4e868b87..0c5208195 100644 --- a/linen_examples/core_design_test/scan.py +++ b/linen_examples/core_design_test/scan.py @@ -40,14 +40,14 @@ def body_fn(scope, c, x): carry, ys = lift.scan( body_fn, variable_carry='counter', - variable_axes={'param': lift.broadcast}, - split_rngs={'param': False})(scope, (), xs) + variable_axes={'params': lift.broadcast}, + split_rngs={'params': False})(scope, (), xs) else: carry, ys = lift.scan( body_fn, variable_carry='counter', - variable_axes={'param': 0}, - split_rngs={'param': True})(scope, (), xs) + variable_axes={'params': 0}, + split_rngs={'params': True})(scope, (), xs) # output layer return carry, ys diff --git a/linen_examples/core_design_test/tied_autoencoder.py b/linen_examples/core_design_test/tied_autoencoder.py index 704edbbbd..81965b288 100644 --- a/linen_examples/core_design_test/tied_autoencoder.py +++ b/linen_examples/core_design_test/tied_autoencoder.py @@ -55,9 +55,9 @@ def _tied(self, fn, transpose=False): return fn def trans(variables): - if 'param' not in variables:
+ if 'params' not in variables: return variables - params = variables['param'] + params = variables['params'] params['kernel'] = params['kernel'].T return variables diff --git a/linen_examples/core_design_test/vmap.py b/linen_examples/core_design_test/vmap.py index bdde083a8..c9d8c5043 100644 --- a/linen_examples/core_design_test/vmap.py +++ b/linen_examples/core_design_test/vmap.py @@ -31,13 +31,13 @@ def mlp_vmap(scope: Scope, x: Array, if share_params: dense_vmap = lift.vmap(nn.dense, in_axes=(0, None), - variable_axes={'param': None}, - split_rngs={'param': False}) + variable_axes={'params': None}, + split_rngs={'params': False}) else: dense_vmap = lift.vmap(nn.dense, in_axes=(0, None), - variable_axes={'param': 0}, - split_rngs={'param': True}) + variable_axes={'params': 0}, + split_rngs={'params': True}) # hidden layers for size in sizes[:-1]: diff --git a/linen_examples/core_design_test/weight_std.py b/linen_examples/core_design_test/weight_std.py index b4c9bea6b..4727f6960 100644 --- a/linen_examples/core_design_test/weight_std.py +++ b/linen_examples/core_design_test/weight_std.py @@ -26,7 +26,7 @@ def weight_std(fn, kernel_name='kernel', eps=1e-8): def std(variables): - params = variables['param'] + params = variables['params'] assert kernel_name in params kernel = params[kernel_name] redux = tuple(range(kernel.ndim - 1)) diff --git a/linen_examples/imagenet/train.py b/linen_examples/imagenet/train.py index d7c3f4112..be6f7aeeb 100644 --- a/linen_examples/imagenet/train.py +++ b/linen_examples/imagenet/train.py @@ -94,7 +94,7 @@ def model(**kwargs): def initialized(key, image_size): input_shape = (1, image_size, image_size, 3) model_ = model() - return model_.init({"param": key}, jnp.ones(input_shape, model_.dtype)) + return model_.init({'params': key}, jnp.ones(input_shape, model_.dtype)) def cross_entropy_loss(logits, labels): @@ -135,10 +135,10 @@ def train_step(state, batch, learning_rate_fn): """Perform a single training step.""" def loss_fn(params): 
"""loss function used for training.""" - variables = {'param': params, 'batch_stats': state.batch_stats} + variables = {'params': params, 'batch_stats': state.batch_stats} logits, new_variables = model().apply(variables, batch['image'], mutable=['batch_stats']) loss = cross_entropy_loss(logits, batch['label']) - weight_penalty_params = jax.tree_leaves(variables['param']) + weight_penalty_params = jax.tree_leaves(variables['params']) weight_decay = 0.0001 weight_l2 = sum([jnp.sum(x ** 2) for x in weight_penalty_params @@ -183,7 +183,7 @@ def loss_fn(params): def eval_step(state, batch): params = state.optimizer.target - variables = {'param': params, 'batch_stats': state.batch_stats} + variables = {'params': params, 'batch_stats': state.batch_stats} logits = model(train=False).apply(variables, batch['image'], mutable=False) return compute_metrics(logits, batch['label']) @@ -285,7 +285,7 @@ def main(argv): base_learning_rate = FLAGS.learning_rate * batch_size / 256. variables = initialized(rng, image_size) - optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(variables['param']) + optimizer = optim.Momentum(beta=FLAGS.momentum, nesterov=True).create(variables['params']) state = TrainState(step=0, optimizer=optimizer, batch_stats=variables['batch_stats'], dynamic_scale=dynamic_scale) diff --git a/linen_examples/linen_design_test/attention_simple.py b/linen_examples/linen_design_test/attention_simple.py index 25c8b9dce..8408ee226 100644 --- a/linen_examples/linen_design_test/attention_simple.py +++ b/linen_examples/linen_design_test/attention_simple.py @@ -192,7 +192,7 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): if __name__ == '__main__': inputs = jnp.ones((8, 97, 256)) - rngs = {'param': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} + rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} model = MultiHeadDotProductAttention( broadcast_dropout=False, qkv_features=256, diff --git 
a/linen_examples/linen_design_test/autoencoder.py b/linen_examples/linen_design_test/autoencoder.py index 54461c895..47f7dbe56 100644 --- a/linen_examples/linen_design_test/autoencoder.py +++ b/linen_examples/linen_design_test/autoencoder.py @@ -72,7 +72,7 @@ def decode(self, z): # `ae.initialized` returnes a materialized copy of `ae` by # running through an input to create submodules defined lazily. params = ae.init( - {'param': random.PRNGKey(42)}, + {'params': random.PRNGKey(42)}, jnp.ones((1, 28, 28, 1))) @@ -82,7 +82,7 @@ def decode(self, z): # `ae.variables` is a frozen dict that looks like -# {"param": {"decoder": {"Dense_0": {"bias": ..., "kernel": ...}, ...}} +# {'params': {"decoder": {"Dense_0": {"bias": ..., "kernel": ...}, ...}} print("var shapes", jax.tree_map(jnp.shape, params)) diff --git a/linen_examples/linen_design_test/linear_regression.py b/linen_examples/linen_design_test/linear_regression.py index bd8a66b2d..c9a24c915 100644 --- a/linen_examples/linen_design_test/linear_regression.py +++ b/linen_examples/linen_design_test/linear_regression.py @@ -27,7 +27,7 @@ @jit def predict(params): - return model.apply({'param': params}, X) + return model.apply({'params': params}, X) @jit def loss_fn(params): @@ -35,8 +35,8 @@ def loss_fn(params): @jit def init_params(rng): - mlp_variables = model.init({'param': rng}, X) - return mlp_variables['param'] + mlp_variables = model.init({'params': rng}, X) + return mlp_variables['params'] # Get initial parameters params = init_params(jax.random.PRNGKey(42)) diff --git a/linen_examples/linen_design_test/mlp_explicit.py b/linen_examples/linen_design_test/mlp_explicit.py index 73aeb2257..2b1f242b9 100644 --- a/linen_examples/linen_design_test/mlp_explicit.py +++ b/linen_examples/linen_design_test/mlp_explicit.py @@ -42,7 +42,7 @@ def setup(self): # explicit instances are materialized immediately at init pprint(self.dense2.variables) - # {'param': {'bias': DeviceArray([0.], dtype=float32), + # {'params': {'bias': 
DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[ 0.6704609 ], # [-0.90477365]], dtype=float32)}} @@ -52,10 +52,10 @@ def __call__(self, x): # Return an initialized instance of MLP by only calling `setup`. rngkey = jax.random.PRNGKey(10) -init_variables = MLP().init({'param': rngkey}, jnp.ones((1, 3))) +init_variables = MLP().init({'params': rngkey}, jnp.ones((1, 3))) pprint(init_variables) -# {'param': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), +# {'params': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), # 'kernel': DeviceArray([[ 0.18307537, -0.38739476], # [-0.902451 , -0.5190721 ], # [ 0.51552075, 1.1169153 ]], dtype=float32)}, diff --git a/linen_examples/linen_design_test/mlp_inline.py b/linen_examples/linen_design_test/mlp_inline.py index 57b7950b3..6dd9cc1e5 100644 --- a/linen_examples/linen_design_test/mlp_inline.py +++ b/linen_examples/linen_design_test/mlp_inline.py @@ -45,7 +45,7 @@ def __call__(self, x): x = jnp.ones((1, 3)) mlp_variables = model.init(rngkey, x) print(mlp_variables) -# {'param': {'Dense_0': {'bias': DeviceArray([0.], dtype=float32), +# {'params': {'Dense_0': {'bias': DeviceArray([0.], dtype=float32), # 'kernel': DeviceArray([[-0.04267037], # [-0.51097125]], dtype=float32)}, # 'Dense_1': {'bias': DeviceArray([0., 0.], dtype=float32), diff --git a/linen_examples/linen_design_test/mlp_lazy.py b/linen_examples/linen_design_test/mlp_lazy.py index b76e3b9f6..8b3bd53eb 100644 --- a/linen_examples/linen_design_test/mlp_lazy.py +++ b/linen_examples/linen_design_test/mlp_lazy.py @@ -45,7 +45,7 @@ def __call__(self, x): mlp_variables = MLP().init(rngkey, jnp.zeros((1, 3))) pprint(mlp_variables) -# {'param': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), +# {'params': {'dense1': {'bias': DeviceArray([0., 0.], dtype=float32), # 'kernel': DeviceArray([[ 0.18307537, -0.38739476], # [-0.902451 , -0.5190721 ], # [ 0.51552075, 1.1169153 ]], dtype=float32)}, diff --git 
a/linen_examples/linen_design_test/tied_autoencoder.py b/linen_examples/linen_design_test/tied_autoencoder.py index 0752c3908..ec645eea1 100644 --- a/linen_examples/linen_design_test/tied_autoencoder.py +++ b/linen_examples/linen_design_test/tied_autoencoder.py @@ -32,7 +32,7 @@ # @property # def decoder(self): # return self.encoder.detached().attached(variables={ -# "param": {"kernel": self.encoder.variables['param']['kernel'].T}}) +# 'params': {"kernel": self.encoder.variables['params']['kernel'].T}}) # def __call__(self, x): # z = self.encoder(x) @@ -41,7 +41,7 @@ # tae = TiedAutoEncoder(parent=None) # tae = tae.initialized( -# {'param': random.PRNGKey(42)}, +# {'params': random.PRNGKey(42)}, # jnp.ones((1, 16))) # print("reconstruct", jnp.shape(tae(jnp.ones((1, 16))))) # print("var shapes", jax.tree_map(jnp.shape, tae.variables)) diff --git a/linen_examples/linen_design_test/weight_std.py b/linen_examples/linen_design_test/weight_std.py index 0598f109d..f16d0856e 100644 --- a/linen_examples/linen_design_test/weight_std.py +++ b/linen_examples/linen_design_test/weight_std.py @@ -43,15 +43,15 @@ def standardize(x, axis, eps=1e-8): # def __call__(self, x): # # TODO: Think about how this modifies other state -# if not 'param' in self.module.variables: +# if not 'params' in self.module.variables: # # initialize parameters # self.module(x) -# param = self.module.variables['param'] +# param = self.module.variables['params'] # # Make a copy because `param` is (and should be) frozen. We're only transforming # # the parameters, not mutating them. 
# std_param = param.copy(kernel=standardize(param['kernel'], axis=[0, 1])) -# return self.module.clone(parent=None).apply({"param": std_param}, x) +# return self.module.clone(parent=None).apply({'params': std_param}, x) # class MyModule(Module): # def __call__(self, x): @@ -59,5 +59,5 @@ def standardize(x, axis, eps=1e-8): # std_module = StdWeight(module) # return std_module(x) -# m_variables = MyModule().init({'param': jax.random.PRNGKey(10)}, jnp.ones((1, 4))) +# m_variables = MyModule().init({'params': jax.random.PRNGKey(10)}, jnp.ones((1, 4))) # print(m_variables) diff --git a/linen_examples/mnist/mnist_lib.py b/linen_examples/mnist/mnist_lib.py index f77948580..c354aa348 100644 --- a/linen_examples/mnist/mnist_lib.py +++ b/linen_examples/mnist/mnist_lib.py @@ -55,7 +55,7 @@ def __call__(self, x): def get_initial_params(key): init_shape = jnp.ones((1, 28, 28, 1), jnp.float32) - initial_params = CNN().init(key, init_shape)["param"] + initial_params = CNN().init(key, init_shape)['params'] return initial_params @@ -88,7 +88,7 @@ def compute_metrics(logits, labels): def train_step(optimizer, batch): """Train for a single step.""" def loss_fn(params): - logits = CNN().apply({'param': params}, batch['image']) + logits = CNN().apply({'params': params}, batch['image']) loss = cross_entropy_loss(logits, batch['label']) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) @@ -100,7 +100,7 @@ def loss_fn(params): @jax.jit def eval_step(params, batch): - logits = CNN().apply({'param': params}, batch['image']) + logits = CNN().apply({'params': params}, batch['image']) return compute_metrics(logits, batch['label']) diff --git a/linen_examples/mnist/mnist_lib_test.py b/linen_examples/mnist/mnist_lib_test.py index f58b24a61..ae0704728 100644 --- a/linen_examples/mnist/mnist_lib_test.py +++ b/linen_examples/mnist/mnist_lib_test.py @@ -40,7 +40,7 @@ def test_cnn(self): # TODO(mohitreddy): Consider creating a testing module which # gives a parameters overview 
including number of parameters. - self.assertLen(variables['param'], 4) + self.assertLen(variables['params'], 4) def test_train_and_evaluate(self): """Tests training and evaluation code by running a single step with diff --git a/linen_examples/pixelcnn/model_test.py b/linen_examples/pixelcnn/model_test.py index a13060e57..dd8fab73e 100644 --- a/linen_examples/pixelcnn/model_test.py +++ b/linen_examples/pixelcnn/model_test.py @@ -50,7 +50,7 @@ def assert_mean_and_variance(self, out): def test_conv(self): model = pixelcnn.ConvWeightNorm(features=4, kernel_size=(3, 2)) out, variables = model.init_with_output(self.rng, self.x) - params = variables['param']['weightnorm_params'] + params = variables['params']['weightnorm_params'] direction, scale, bias = self.get_weightnorm(params) self.assertEqual(direction.shape, (3, 2, 2, 4)) @@ -63,7 +63,7 @@ def test_conv(self): def test_conv_down(self): model = pixelcnn.ConvDown(features=4) out, variables = model.init_with_output(self.rng, self.x) - params = variables['param']['ConvWeightNorm_0']['weightnorm_params'] + params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] direction, scale, bias = self.get_weightnorm(params) self.assertEqual(direction.shape, (2, 3, 2, 4)) @@ -76,7 +76,7 @@ def test_conv_down(self): def test_conv_down_right(self): model = pixelcnn.ConvDownRight(features=4) out, variables = model.init_with_output(self.rng, self.x) - params = variables['param']['ConvWeightNorm_0']['weightnorm_params'] + params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] direction, scale, bias = self.get_weightnorm(params) self.assertEqual(direction.shape, (2, 2, 2, 4)) @@ -89,7 +89,7 @@ def test_conv_down_right(self): def test_conv_transpose(self): model = pixelcnn.ConvTranspose(features=4, kernel_size = (3, 2)) out, variables = model.init_with_output(self.rng, self.x) - params = variables['param']['weightnorm_params'] + params = variables['params']['weightnorm_params'] direction, scale, bias = 
self.get_weightnorm(params) self.assertEqual(direction.shape, (3, 2, 2, 4)) @@ -102,7 +102,7 @@ def test_conv_transpose(self): def test_conv_transpose_down(self): model = pixelcnn.ConvTransposeDown(features=4) out, variables = model.init_with_output(self.rng, self.x) - params = variables["param"]["ConvWeightNorm_0"]["weightnorm_params"] + params = variables['params']["ConvWeightNorm_0"]["weightnorm_params"] direction, scale, bias = self.get_weightnorm(params) self.assertEqual(direction.shape, (2, 3, 2, 4)) @@ -114,7 +114,7 @@ def test_conv_transpose_down(self): def test_conv_transpose_down_right(self): model = pixelcnn.ConvTransposeDownRight(features=4) out, variables = model.init_with_output(self.rng, self.x) - params = variables['param']['ConvWeightNorm_0']['weightnorm_params'] + params = variables['params']['ConvWeightNorm_0']['weightnorm_params'] direction, scale, bias = self.get_weightnorm(params) self.assertEqual(direction.shape, (2, 2, 2, 4)) diff --git a/linen_examples/pixelcnn/sample.py b/linen_examples/pixelcnn/sample.py index ca9507c24..198a2177d 100644 --- a/linen_examples/pixelcnn/sample.py +++ b/linen_examples/pixelcnn/sample.py @@ -53,9 +53,9 @@ def generate_sample(): init_batch = jnp.zeros((1, 32, 32, 3)) params = train.model().init({ - 'param': model_rng, + 'params': model_rng, 'dropout': dropout_rng - }, init_batch)['param'] + }, init_batch)['params'] optimizer_def = optim.Adam( learning_rate=FLAGS.learning_rate, beta1=0.95, beta2=0.9995) optimizer = optimizer_def.create(params) @@ -102,7 +102,7 @@ def sample_iteration(rng, params, sample): """PixelCNN++ sampling expressed as a fixed-point iteration. 
""" rng, dropout_rng = random.split(rng) - out = train.model().apply({'param': params}, sample, + out = train.model().apply({'params': params}, sample, rngs={'dropout': dropout_rng}) c_params = pixelcnn.conditional_params_from_outputs(out, sample) return conditional_params_to_sample(rng, c_params) diff --git a/linen_examples/pixelcnn/train.py b/linen_examples/pixelcnn/train.py index 3850e9e8a..6a15497d7 100644 --- a/linen_examples/pixelcnn/train.py +++ b/linen_examples/pixelcnn/train.py @@ -135,7 +135,7 @@ def train_step(optimizer, ema, batch, learning_rate_fn, dropout_rng): def loss_fn(params): """loss function used for training.""" pcnn_out = model(dropout_p=FLAGS.dropout_rate).apply( - {'param': params}, batch['image'], rngs={'dropout': dropout_rng}) + {'params': params}, batch['image'], rngs={'dropout': dropout_rng}) return neg_log_likelihood_loss(pcnn_out, batch['image']) lr = learning_rate_fn(optimizer.state.step) @@ -155,7 +155,7 @@ def loss_fn(params): def eval_step(params, batch): images = batch['image'] - pcnn_out = model(dropout_p=0.).apply({'param': params}, images) + pcnn_out = model(dropout_p=0.).apply({'params': params}, images) return {'loss': lax.pmean(neg_log_likelihood_loss(pcnn_out, images), 'batch')} @@ -217,9 +217,9 @@ def train(): rng, dropout_rng = random.split(rng) initial_variables = model().init({ - 'param': init_rng, + 'params': init_rng, 'dropout': dropout_rng - }, init_batch)['param'] + }, init_batch)['params'] optimizer_def = optim.Adam( learning_rate=FLAGS.learning_rate, beta1=0.95, beta2=0.9995) optimizer = optimizer_def.create(initial_variables) diff --git a/linen_examples/seq2seq/train.py b/linen_examples/seq2seq/train.py index c732583b4..f1adb7a75 100644 --- a/linen_examples/seq2seq/train.py +++ b/linen_examples/seq2seq/train.py @@ -168,8 +168,8 @@ class EncoderLSTM(nn.Module): @functools.partial( nn.transforms.scan, - variable_axes={'param': nn.broadcast}, - split_rngs={'param': False}) + variable_axes={'params': nn.broadcast}, 
+ split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): lstm_state, is_eos = carry @@ -215,8 +215,8 @@ class DecoderLSTM(nn.Module): @functools.partial( nn.transforms.scan, - variable_axes={'param': nn.broadcast}, - split_rngs={'param': False}) + variable_axes={'params': nn.broadcast}, + split_rngs={'params': False}) @nn.compact def __call__(self, carry, x): rng, lstm_state, last_prediction = carry @@ -303,8 +303,8 @@ def get_initial_params(key): vocab_size = CTABLE.vocab_size encoder_shape = jnp.ones((1, get_max_input_len(), vocab_size), jnp.float32) decoder_shape = jnp.ones((1, get_max_output_len(), vocab_size), jnp.float32) - return model().init({'param': key, 'lstm': key}, - encoder_shape, decoder_shape)['param'] + return model().init({'params': key, 'lstm': key}, + encoder_shape, decoder_shape)['params'] def get_examples(num_examples): @@ -359,7 +359,7 @@ def train_step(optimizer, batch, lstm_key): def loss_fn(params): """Compute cross-entropy loss.""" - logits, _ = model().apply({'param': params}, + logits, _ = model().apply({'params': params}, batch['query'], batch['answer'], rngs={'lstm': lstm_key}) @@ -386,7 +386,7 @@ def decode(params, inputs, key): init_decoder_input = onehot(CTABLE.encode('=')[0:1], CTABLE.vocab_size) init_decoder_inputs = jnp.tile(init_decoder_input, (inputs.shape[0], get_max_output_len(), 1)) - _, predictions = model(teacher_force=False).apply({'param': params}, + _, predictions = model(teacher_force=False).apply({'params': params}, inputs, init_decoder_inputs, rngs={'lstm': key}) diff --git a/linen_examples/vae/train.py b/linen_examples/vae/train.py index a5438c204..d9f7495ee 100644 --- a/linen_examples/vae/train.py +++ b/linen_examples/vae/train.py @@ -128,7 +128,7 @@ def model(): @jax.jit def train_step(optimizer, batch, z_rng): def loss_fn(params): - recon_x, mean, logvar = model().apply({'param': params}, batch, z_rng) + recon_x, mean, logvar = model().apply({'params': params}, batch, z_rng) bce_loss = 
binary_cross_entropy_with_logits(recon_x, batch).mean() kld_loss = kl_divergence(mean, logvar).mean() @@ -142,12 +142,12 @@ def loss_fn(params): @jax.jit def eval(params, images, z, z_rng): - recon_images, mean, logvar = model().apply({'param': params}, images, z_rng) + recon_images, mean, logvar = model().apply({'params': params}, images, z_rng) comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1), recon_images[:8].reshape(-1, 28, 28, 1)]) - generate_images = model().apply({'param': params}, z, method=VAE.generate) + generate_images = model().apply({'params': params}, z, method=VAE.generate) generate_images = generate_images.reshape(-1, 28, 28, 1) metrics = compute_metrics(recon_images, images, mean, logvar) @@ -181,7 +181,7 @@ def main(argv): test_ds = jax.device_put(test_ds) init_data = jnp.ones((FLAGS.batch_size, 784), jnp.float32) - params = model().init(key, init_data, rng)['param'] + params = model().init(key, init_data, rng)['params'] optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(params) optimizer = jax.device_put(optimizer) diff --git a/linen_examples/wmt/train.py b/linen_examples/wmt/train.py index 58f9d6c58..fc86a0d4c 100644 --- a/linen_examples/wmt/train.py +++ b/linen_examples/wmt/train.py @@ -354,7 +354,7 @@ def train_step(optimizer, def loss_fn(params): """loss function used for training.""" logits = models.Transformer(config).apply( - {'param': params}, + {'params': params}, inputs, targets, inputs_positions=inputs_positions, @@ -385,7 +385,7 @@ def eval_step(params, batch, config, label_smoothing=0.0): inputs, targets = batch['inputs'], batch['targets'] weights = jnp.where(targets > 0, 1.0, 0.0) logits = models.Transformer(config).apply( - {'param': params}, inputs, targets) + {'params': params}, inputs, targets) return compute_metrics(logits, targets, weights, label_smoothing) @@ -411,7 +411,7 @@ def predict_step(inputs, params, cache, eos_id, max_decode_len, config, # [el0, el1, el2] --> beamsize=2 --> 
[el0,el0,el1,el1,el2,el2] encoded_inputs = decode.flat_batch_beam_expand( models.Transformer(config).apply( - {'param': params}, inputs, method=models.Transformer.encode), + {'params': params}, inputs, method=models.Transformer.encode), beam_size) raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size) @@ -419,7 +419,7 @@ def tokens_ids_to_logits(flat_ids, flat_cache): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] flat_logits, new_vars = models.Transformer(config).apply( - {'param': params, 'cache': flat_cache}, + {'params': params, 'cache': flat_cache}, encoded_inputs, raw_inputs, # only needed for input padding mask flat_ids, @@ -578,7 +578,7 @@ def initialize_variables(rng): beta2=0.98, eps=1e-9, weight_decay=FLAGS.weight_decay) - optimizer = optimizer_def.create(initial_variables['param']) + optimizer = optimizer_def.create(initial_variables['params']) # We access model params only from optimizer below via optimizer.target. del initial_variables diff --git a/tests/core/scope_test.py b/tests/core/scope_test.py index ff5f89cac..ae453025e 100644 --- a/tests/core/scope_test.py +++ b/tests/core/scope_test.py @@ -25,9 +25,9 @@ class ScopeTest(absltest.TestCase): def test_rng(self): def f(scope): - self.assertTrue(scope.has_rng('param')) + self.assertTrue(scope.has_rng('params')) self.assertFalse(scope.has_rng('dropout')) - rng = scope.make_rng('param') + rng = scope.make_rng('params') self.assertTrue(np.all(rng == random.fold_in(random.PRNGKey(0), 1))) init(f)(random.PRNGKey(0)) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index d3b20be62..c0c181230 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -72,7 +72,7 @@ def test_multihead_self_attention_w_dropout(self): dropout_rate=0.1, ) rng1, rng2 = random.split(rng) - rngs = {'param': rng1, 'dropout': rng2} + rngs = {'params': rng1, 'dropout': rng2} y, _ = sa_module.init_with_output(rngs, x, x) 
self.assertEqual(y.shape, x.shape) diff --git a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index ba872665b..3d1f5c5f0 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -37,7 +37,7 @@ class LinearTest(parameterized.TestCase): def test_dense(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, @@ -49,7 +49,7 @@ def test_dense(self): np.testing.assert_allclose(y, np.full((1, 4), 4.)) def test_dense_extra_batch_dims(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 2, 3)) dense_module = nn.Dense( features=4, @@ -60,7 +60,7 @@ def test_dense_extra_batch_dims(self): np.testing.assert_allclose(y, np.full((1, 2, 4), 4.)) def test_dense_no_bias(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, @@ -77,18 +77,18 @@ def test_dense_is_dense_general(self): use_bias=True, bias_init=initializers.normal(), ) - y1, _ = dense_module.init_with_output(dict(param=random.PRNGKey(1)), x) + y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x) dg_module = nn.DenseGeneral( features=4, use_bias=True, bias_init=initializers.normal(), ) - y2, _ = dg_module.init_with_output(dict(param=random.PRNGKey(1)), x) + y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x) np.testing.assert_allclose(y1, y2) def test_dense_general_batch_dim_raises(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 3, 2, 5)) with self.assertRaises(ValueError): dg_module = nn.DenseGeneral( @@ -100,7 +100,7 @@ def test_dense_general_batch_dim_raises(self): dg_module.init_with_output(rng, x) def test_dense_general_two_out(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 3)) dg_module = nn.DenseGeneral( 
features=(2, 2), @@ -111,7 +111,7 @@ def test_dense_general_two_out(self): np.testing.assert_allclose(y, np.full((1, 2, 2), 4.)) def test_dense_general_two_in(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 2, 2)) dg_module = nn.DenseGeneral( features=3, @@ -123,7 +123,7 @@ def test_dense_general_two_in(self): np.testing.assert_allclose(y, np.full((1, 3), 5.)) def test_dense_general_batch_dim(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((2, 1, 3, 5)) state = {'counter': 0.} @@ -149,7 +149,7 @@ def _counter_init(rng, shape, dtype, state): ((3, -2), (), 'bijk,kjlm->bilm'), ((-2, 3), (0,), 'bijk,bjklm->bilm')]) def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((16, 8, 9, 10)) dg_module = nn.DenseGeneral( @@ -160,11 +160,11 @@ def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): kernel_init=initializers.normal(), ) y, initial_params = dg_module.init_with_output(rng, x) - target = np.einsum(einsum_expr, x, initial_params['param']['kernel']) + 1. + target = np.einsum(einsum_expr, x, initial_params['params']['kernel']) + 1. 
np.testing.assert_allclose(y, target, atol=1e-6) def test_conv(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 8, 3)) conv_module = nn.Conv( features=4, @@ -174,11 +174,11 @@ def test_conv(self): bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) - self.assertEqual(initial_params['param']['kernel'].shape, (3, 3, 4)) + self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) np.testing.assert_allclose(y, np.full((1, 6, 4), 10.)) def test_group_conv(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 8, 4)) conv_module = nn.Conv( features=4, @@ -189,11 +189,11 @@ def test_group_conv(self): bias_init=initializers.ones, ) y, initial_params = conv_module.init_with_output(rng, x) - self.assertEqual(initial_params['param']['kernel'].shape, (3, 2, 4)) + self.assertEqual(initial_params['params']['kernel'].shape, (3, 2, 4)) np.testing.assert_allclose(y, np.full((1, 6, 4), 7.)) def test_conv_transpose(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.ones((1, 8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, @@ -203,7 +203,7 @@ def test_conv_transpose(self): bias_init=initializers.ones, ) y, initial_params = conv_transpose_module.init_with_output(rng, x) - self.assertEqual(initial_params['param']['kernel'].shape, (3, 3, 4)) + self.assertEqual(initial_params['params']['kernel'].shape, (3, 3, 4)) correct_ans = np.array([[[ 4., 4., 4., 4.], [ 7., 7., 7., 7.], [10., 10., 10., 10.], @@ -217,7 +217,7 @@ def test_conv_transpose(self): np.testing.assert_allclose(y, correct_ans) def test_embed(self): - rng = dict(param=random.PRNGKey(0)) + rng = dict(params=random.PRNGKey(0)) x = jnp.arange(4)[None] dummy_embedding = jnp.broadcast_to( jnp.arange(4)[..., None], (4, 3)).astype(jnp.float32) diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 0d5b6876e..2ae78e507 
100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -188,7 +188,7 @@ def test_lstm(self): self.assertEqual(carry[0].shape, (2, 4)) self.assertEqual(carry[1].shape, (2, 4)) np.testing.assert_allclose(y, carry[1]) - param_shapes = jax.tree_map(np.shape, initial_params['param']) + param_shapes = jax.tree_map(np.shape, initial_params['params']) self.assertEqual(param_shapes, { 'ii': {'kernel': (3, 4)}, 'if': {'kernel': (3, 4)}, @@ -211,7 +211,7 @@ def test_gru(self): #gru = nn.Model(nn.GRUCell, initial_params) self.assertEqual(carry.shape, (2, 4)) np.testing.assert_allclose(y, carry) - param_shapes = jax.tree_map(np.shape, initial_params['param']) + param_shapes = jax.tree_map(np.shape, initial_params['params']) self.assertEqual(param_shapes, { 'ir': {'kernel': (3, 4), 'bias': (4,)}, 'iz': {'kernel': (3, 4), 'bias': (4,)}, diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 039c2bc18..dc01c3b33 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -124,8 +124,8 @@ def test_vmap(self): def vmap(cls): return nn.vmap(cls, in_axes=(0,), - variable_axes={'param': None}, - split_rngs={'param': False}) + variable_axes={'params': None}, + split_rngs={'params': False}) normal_model = TransformedMLP(features=[3, 4, 5]) vmap_model = TransformedMLP(features=[3, 4, 5], transform=vmap) init_variables = normal_model.init(key2, x) @@ -143,8 +143,8 @@ def test_vmap_decorated(self): def vmap(fn): return nn.vmap(fn, in_axes=(0,), - variable_axes={'param': None}, - split_rngs={'param': False}) + variable_axes={'params': None}, + split_rngs={'params': False}) normal_model = decorated_MLP()(features=[3, 4, 5]) vmap_model = decorated_MLP(vmap)(features=[3, 4, 5]) init_variables = normal_model.init(key2, x) @@ -159,8 +159,8 @@ class SimpleScan(nn.Module): @nn.compact def __call__(self, c, xs): LSTM = nn.scan(nn.LSTMCell, - variable_axes={'param': nn.broadcast}, - split_rngs={'param': 
False}) + variable_axes={'params': nn.broadcast}, + split_rngs={'params': False}) return LSTM(name="lstm_cell")(c, xs) key1, key2 = random.split(random.PRNGKey(0), 2) @@ -174,7 +174,7 @@ def __call__(self, c, xs): # simulate scan in python for comparison: c = init_carry ys = [] - lstmcell_variables = freeze({'param': init_variables['param']['lstm_cell']}) + lstmcell_variables = freeze({'params': init_variables['params']['lstm_cell']}) for i in range(xs.shape[0]): c, y = nn.LSTMCell().apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) @@ -188,8 +188,8 @@ def __call__(self, c, xs): def test_scan_decorated(self): class SimpleScan(nn.Module): @partial(nn.scan, - variable_axes={'param': nn.broadcast}, - split_rngs={'param': False}) + variable_axes={'params': nn.broadcast}, + split_rngs={'params': False}) @nn.compact def __call__(self, c, xs): return nn.LSTMCell(name="lstm_cell")(c, xs) @@ -205,7 +205,7 @@ def __call__(self, c, xs): # simulate scan in python for comparison: c = init_carry ys = [] - lstmcell_variables = freeze({'param': init_variables['param']['lstm_cell']}) + lstmcell_variables = freeze({'params': init_variables['params']['lstm_cell']}) for i in range(xs.shape[0]): c, y = nn.LSTMCell().apply(lstmcell_variables, c, xs[i]) ys.append(y[None, ...]) @@ -430,8 +430,8 @@ class Inner(nn.Module): outer_module: nn.Module @partial(nn.vmap, in_axes=(0,), - variable_axes={'param': 0}, - split_rngs={'param': True}) + variable_axes={'params': 0}, + split_rngs={'params': True}) @nn.jit @nn.compact def __call__(self, x): @@ -450,10 +450,10 @@ def __call__(self, x): init_vars = Test(None).init(rngs, x) y = Test(None).apply(init_vars, x) self.assertEqual( - init_vars['param']['outer']['Dense_0']['kernel'].shape, + init_vars['params']['outer']['Dense_0']['kernel'].shape, (3, 2, 5)) self.assertEqual( - init_vars['param']['outer']['Dense_0']['bias'].shape, + init_vars['params']['outer']['Dense_0']['bias'].shape, (3, 5)) self.assertEqual(y.shape, (3, 1, 5)) diff --git 
a/tests/linen/module_test.py b/tests/linen/module_test.py index d2ecd9301..32a116b75 100644 --- a/tests/linen/module_test.py +++ b/tests/linen/module_test.py @@ -57,9 +57,9 @@ class ModuleTest(absltest.TestCase): def test_init_module(self): rngkey = jax.random.PRNGKey(0) x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) y = DummyModule(parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = DummyModule(parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) onp.testing.assert_allclose(y, jnp.array([2.])) @@ -68,9 +68,9 @@ def test_init_module(self): def test_arg_module(self): rngkey = jax.random.PRNGKey(0) x = jnp.ones((10,)) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) y = Dense(3, parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = Dense(3, parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) self.assertEqual(params['kernel'].shape, (10, 3)) @@ -86,9 +86,9 @@ def __call__(self, x): def _mydense(self, x): return Dense(3)(x) x = jnp.ones((10,)) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) y = MLP(parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = MLP(parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) param_shape = jax.tree_map(jnp.shape, params) @@ -114,9 +114,9 @@ def __call__(self, x): z = mlp(x) return y + z x = jnp.ones((10,)) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) y = Top(parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = Top(parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) param_shape = jax.tree_map(jnp.shape, params) @@ -137,9 +137,9 @@ def __call__(self, x): #w = self.lyrs2[0](x) return z x = jnp.ones((10,)) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': 
rngkey}) y = MLP(parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = MLP(parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) param_shape = jax.tree_map(jnp.shape, params) @@ -175,14 +175,14 @@ def setup(self): def __call__(self): return self.outer() - scope = Scope({'param': {}}, rngs={'param': rngkey}) + scope = Scope({'params': {}}, rngs={'params': rngkey}) # Make sure this doesn't raise "Can't attach to remote parent" wrapper = Wrapper(parent=scope) wrapper() # Make sure that variables are registered at the level of the # Wrapper submodule, not the Outer submodule. - self.assertEqual(40, scope.variables()['param']['inner']['x']) + self.assertEqual(40, scope.variables()['params']['inner']['x']) def test_param_in_setup(self): rngkey = jax.random.PRNGKey(0) @@ -193,9 +193,9 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) y = DummyModule(x.shape, parent=scope)(x) - params = scope.variables()['param'] + params = scope.variables()['params'] y2 = DummyModule(x.shape, parent=scope.rewound())(x) onp.testing.assert_allclose(y, y2) onp.testing.assert_allclose(y, jnp.array([2.])) @@ -208,7 +208,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): y = DummyModule(parent=scope)(x) @@ -223,7 +223,7 @@ def foo(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'must be initialized.*setup'): y = Dummy(parent=scope).foo(x) @@ -238,7 +238,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, x.shape) return x + 
self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'bias already in use'): y = Dummy(x.shape, parent=scope)(x) @@ -252,7 +252,7 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'bias already in use'): y = Dummy(x.shape, parent=scope)(x) @@ -266,7 +266,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, self.xshape) return x + bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'bias already in use'): y = Dummy(x.shape, parent=scope)(x) @@ -279,7 +279,7 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'notbias.*must equal.*bias'): y = Dummy(x.shape, parent=scope)(x) @@ -293,7 +293,7 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'name bias exists already'): y = Dummy(x.shape, parent=scope)(x) class Dummy(nn.Module): @@ -305,7 +305,7 @@ def __call__(self, x): bias = DummyModule(name='bias') return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'name bias exists already'): y = Dummy(x.shape, parent=scope)(x) class Dummy(nn.Module): @@ -317,7 +317,7 @@ def __call__(self, x): bias = self.param('bias', initializers.ones, self.xshape) return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'bias 
already'): y = Dummy(x.shape, parent=scope)(x) @@ -330,7 +330,7 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'In setup, assign names of Modules ' 'via self. and not using keyword argument name=""'): y = Dummy(x.shape, parent=scope)(x) @@ -344,7 +344,7 @@ def setup(self): def __call__(self, x): return x + self.bias x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'Name bias already in use'): y = Dummy(x.shape, parent=scope)(x) @@ -357,7 +357,7 @@ def setup(self): def __call__(self, x): return self.bias(x) x = jnp.array([1.]) - scope = Scope({}, {'param': rngkey}) + scope = Scope({}, {'params': rngkey}) with self.assertRaisesRegex(ValueError, 'bias exists already'): y = Dummy(x.shape, parent=scope)(x) @@ -428,7 +428,7 @@ def __call__(self, x): x = nn.relu(nn.Dense(width)(x)) return nn.Dense(self.widths[-1])(x) test = MLP(onp.array([3, 3], onp.int32)) - params = test.init({'param': random.PRNGKey(42)}, jnp.ones((3, 3))) + params = test.init({'params': random.PRNGKey(42)}, jnp.ones((3, 3))) _ = test.apply(params, jnp.ones((3, 3))) def test_get_local_methods(self): diff --git a/tests/linen/toplevel_test.py b/tests/linen/toplevel_test.py index 7455c00a7..e597ba6c0 100644 --- a/tests/linen/toplevel_test.py +++ b/tests/linen/toplevel_test.py @@ -47,11 +47,11 @@ class ModuleTopLevelTest(absltest.TestCase): # d = Dummy(parent=None).initialized() # def test_toplevel_initialized_with_rng(self): - # d = Dummy(parent=None).initialized(rngs={'param': random.PRNGKey(0)}) + # d = Dummy(parent=None).initialized(rngs={'params': random.PRNGKey(0)}) # self.assertEqual(d.variables.param.foo, 1) # def test_toplevel_initialized_frozen(self): - # d = Dummy(parent=None).initialized(rngs={'param': random.PRNGKey(0)}) + # d = 
Dummy(parent=None).initialized(rngs={'params': random.PRNGKey(0)}) # with self.assertRaisesRegex(BaseException, "Can't set value"): # d.variables.param.foo = 2 @@ -59,12 +59,12 @@ class ModuleTopLevelTest(absltest.TestCase): # d = Dummy(parent=None) # # initializing should make a copy and not have any effect # # on `d` itself. - # d_initialized = d.initialized(rngs={'param': random.PRNGKey(0)}) + # d_initialized = d.initialized(rngs={'params': random.PRNGKey(0)}) # # ... make sure that indeed `d` has no scope. # self.assertIsNone(d.scope) # def test_can_only_call_initialized_once(self): # d = Dummy(parent=None) - # d = d.initialized(rngs={'param': random.PRNGKey(0)}) + # d = d.initialized(rngs={'params': random.PRNGKey(0)}) # with self.assertRaises(BaseException): - # d.initialized(rngs={'param': random.PRNGKey(0)}) + # d.initialized(rngs={'params': random.PRNGKey(0)}) diff --git a/tests/nn_test.py b/tests/nn_test.py index bb7a33756..f34ccc667 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -168,7 +168,7 @@ def test_shared_module_called_in_other_frame(self): class SubModule(nn.Module): def apply(self): - self.param('param', (), initializers.zeros) + self.param('params', (), initializers.zeros) class UseSharedModule(nn.Module): @@ -184,7 +184,7 @@ def apply(self): _, params = TopLevel.init(random.PRNGKey(0)) self.assertEqual({ - 'shared': {'param': jnp.zeros(())}, + 'shared': {'params': jnp.zeros(())}, 'use_shared': {}, }, params)