Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/flax.linen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ Module
------------------------

.. autoclass:: Module
:members: setup, variable, param, apply, init, init_with_output, make_rng, variables, Variable, __setattr__
:members: setup, variable, param, bind, apply, init, init_with_output, make_rng, variables, Variable, __setattr__

Init/Apply
------------------------

.. currentmodule:: flax.linen
.. autofunction:: apply
.. autofunction:: init
.. autofunction:: init_with_output

Variables
----------------------
Expand Down
17 changes: 9 additions & 8 deletions examples/vae/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ def loss_fn(params):

@jax.jit
def eval(params, images, z, z_rng):
recon_images, mean, logvar = model().apply({'params': params}, images, z_rng)
def eval_model(vae):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why you ported this example to use nn.apply?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not essential but I think it's nice to have at least one usage in the examples. Here we avoid passing custom methods and a double apply

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes I now fully agree that this is a good idea. I always disliked the method argument to apply... (And if we really want to we /could/ get rid of it now... But not sure)

recon_images, mean, logvar = vae(images, z_rng)
comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1)])

comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1),
recon_images[:8].reshape(-1, 28, 28, 1)])
generate_images = vae.generate(z)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)
return metrics, comparison, generate_images

generate_images = model().apply({'params': params}, z, method=VAE.generate)
generate_images = generate_images.reshape(-1, 28, 28, 1)
metrics = compute_metrics(recon_images, images, mean, logvar)

return metrics, comparison, generate_images
return nn.apply(eval_model, model())({'params': params})


def prepare_image(x):
Expand Down
2 changes: 1 addition & 1 deletion flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
from .axes_scan import broadcast
from .frozen_dict import FrozenDict, freeze, unfreeze
from .tracers import current_trace, trace_level, check_trace_level
from .scope import Scope, Array, apply, init
from .scope import Scope, Array, apply, init, bind
from .lift import scan, vmap, jit
45 changes: 33 additions & 12 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,14 @@ def invalidate(self):
"""Invalidates the Scope."""
self._invalid = True

def variables(self) -> Collection:
def mutable_variables(self) -> VariableDict:
"""Returns an immutable copy of the mutable variables belonging to this Scope."""
self._populate_collections()
xs = {k: v for k, v in self._variables.items()
if in_filter(self.mutable, k)}
return freeze(xs)

def variables(self) -> VariableDict:
"""Returns an immutable copy of the variables belonging to this Scope."""
self._populate_collections()
return freeze(self._variables)
Expand Down Expand Up @@ -576,6 +583,29 @@ def _unfreeze_variables(variables, mutable):
return new_variables


def bind(variables: VariableDict,
rngs: Optional[RNGSequences] = None,
mutable: CollectionFilter = False):
"""Bind variables and rngs to a new ``Scope``.

bind provides a ``Scope`` instance without transforming a function
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ''bind'' (i.e. backticks)

with ``apply``. This is particulary useful for debugging and
interactive use cases like notebooks where a function would limit
the ability split up code into different cells.

a ``Scope`` instance is a stateful object. Note that idiomatic JAX is functional
and therefore a ``Scope` does not mix well well with vanilla JAX APIs. Therefore,
we recommend using ``apply`` when code should be reusable and compatible
across the JAX software ecosystem.
"""
if not _is_valid_variables(variables):
raise errors.ApplyScopeInvalidVariablesError()
if rngs is not None and not _is_valid_rngs(rngs):
raise errors.ApplyScopeInvalidRngsError()
new_variables = _unfreeze_variables(variables, mutable)
return Scope(new_variables, rngs=rngs, mutable=mutable)


def apply(fn: Callable[..., Any],
mutable: CollectionFilter = False) -> Callable[..., Any]:
"""Functionalize a `Scope` function.
Expand All @@ -593,19 +623,10 @@ def wrapper(variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
**kwargs) -> Union[Any, Tuple[Any, VariableDict]]:

if not _is_valid_variables(variables):
raise errors.ApplyScopeInvalidVariablesError()
if rngs is not None and not _is_valid_rngs(rngs):
raise errors.ApplyScopeInvalidRngsError()
new_variables = _unfreeze_variables(variables, mutable)
with Scope(new_variables, rngs=rngs, mutable=mutable).temporary() as root:
with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
y = fn(root, *args, **kwargs)
if mutable is not False:
mutated_variables = {k: v
for k, v in new_variables.items()
if in_filter(mutable, k)}
return y, freeze(mutated_variables)
return y, root.mutable_variables()
else:
return y

Expand Down
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
make_causal_mask, combine_masks)
from ..core import broadcast
from .linear import Conv, ConvTranspose, Dense, DenseGeneral, Embed
from .module import Module, compact, enable_named_call, disable_named_call, Variable
from .module import Module, compact, enable_named_call, disable_named_call, Variable, init, init_with_output, apply
from .normalization import BatchNorm, GroupNorm, LayerNorm
from .pooling import avg_pool, max_pool
from .recurrent import GRUCell, LSTMCell, ConvLSTM, OptimizedLSTMCell
Expand Down
Loading