In [None]:
! pip install -q flax

In [1]:
import flax.linen as nn
import jax
import jax.numpy as jnp

# Using stateful modules

## Contents

1. Introduction
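  * Show that an unbound module cannot be called directly and does not expose the submodules defined in ``setup``.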
1. Using ``bind`` to create stateful modules:
  * Show that you can use ``bind`` on a top-level module and call it directly, without having to use ``apply``.
1. Accessing submodules:
  * Show that you can access submodules defined in ``setup``.
  * Show how to use ``bind`` + ``unbind`` to extract a submodule and its variables.
1. Example of using ``bind`` for debugging / inspecting the computation:
  * Show how to use ``bind`` to manually run a ``Sequential`` module layer by layer.
1. Edge cases of using ``bind``:
  * JIT problems: show that you should not pass a bound module as a closure capture to a JIT-compiled function.
  * Memory leak: show an example of how ``bind`` might cause a memory leak, and give some tips on how to avoid it.

## Introduction

In [3]:
class AutoEncoder(nn.Module):
  """Toy encoder/decoder pair whose submodules are declared in ``setup``.

  Because the submodules are created in ``setup``, they are only reachable as
  attributes inside ``init``/``apply`` or on a module produced by ``bind``.
  """

  def setup(self):
    # NOTE: the decoder outputs 4 features while the input below has 5, so this
    # toy model does not literally reconstruct its input.
    self.encoder = nn.Dense(features=8)
    self.decoder = nn.Dense(features=4)

  def __call__(self, x):
    latent = self.encoder(x)
    reconstruction = self.decoder(latent)
    return reconstruction


x = jnp.ones((1, 5))
module = AutoEncoder()
variable = module.init(jax.random.PRNGKey(0), x)

In [3]:
# Calling the unbound `module` directly is expected to raise: its variables only
# exist inside `init`/`apply`, or on a module produced by `bind`.
try:
    module(x)
except Exception as e:
    # Show the error instead of silently swallowing it; the error is what this
    # cell demonstrates.
    print(f"{type(e).__name__}: {e}")

In [4]:
# On an unbound module, `setup` has not run, so `encoder` is not an attribute
# yet. Flax raises an AttributeError that explains this.
try:
    encoder = module.encoder
except AttributeError as err:
    print(err)

"AutoEncoder" object has no attribute "encoder". If "encoder" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.


## Using `bind` to create stateful modules

In [5]:
# `bind` returns a module attached to the variables produced by `init`. The
# result can be called, and its `setup` submodules accessed, without `apply`.
bound_module = module.bind(variable)

In [8]:
# A bound module is called like a plain function; no `apply(variables, ...)` needed.
bound_module(x)

Array([[-0.47444162, -1.2290186 ,  1.1803787 ,  2.7584562 ]], dtype=float32)

In [10]:
# Works! Unlike on the unbound `module`, the submodules defined in `setup` are
# accessible as attributes of the bound module.
bound_encoder = bound_module.encoder

## Using `unbind` to create stateless modules

In [None]:
# `unbind` goes the other way from `bind`. It returns a stateless (unbound)
# module together with the variables the bound module was carrying.
encoder, encoder_vars = bound_encoder.unbind()

## Extracting submodules and their variables

In [None]:
# Extract a submodule and its variables in one expression, starting from the
# unbound `module`: bind it, select the submodule, then unbind that submodule.
encoder, encoder_vars = (
    module
    .bind(variable)
    .encoder
    .unbind()
)

## Example: using `bind` for interactive computation

In [4]:
# NOTE: this rebinds `module` (previously the AutoEncoder) to a small MLP. `x`
# is the input array defined in the AutoEncoder cell.
module = nn.Sequential([nn.Dense(features=8), nn.relu, nn.Dense(features=4)])
variables = module.init(jax.random.PRNGKey(0), x)

In [11]:
def get_name(layer):
    """Return a short display name for a layer of an ``nn.Sequential``.

    Flax modules are named after their class. Other callables (e.g. ``nn.relu``)
    use their ``__name__``. Callables without one (such as ``functools.partial``
    objects) fall back to their type name instead of raising AttributeError.
    """
    if isinstance(layer, nn.Module):
        return type(layer).__name__
    return getattr(layer, '__name__', type(layer).__name__)


bound_module = module.bind(variables)
# NOTE(review): `_try_setup` is a private Flax method and may change between
# versions. Presumably it is needed so the entries of `layers` are set up on the
# bound module before they are called one at a time below. Confirm against the
# Flax version in use.
bound_module._try_setup()

# Run the Sequential manually, one layer at a time, printing each layer's
# input and output shapes.
inputs = x
for layer in bound_module.layers:
    outputs = layer(inputs)
    print(f'{get_name(layer)}: {inputs.shape} -> {outputs.shape}')
    inputs = outputs

Dense: (1, 5) -> (1, 8)
relu: (1, 8) -> (1, 8)
Dense: (1, 8) -> (1, 4)
