# Flax introduction

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import jax
from jax import numpy as jnp, random, jit, lax

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
tf.config.set_visible_devices([], 'GPU')

import flax
from flax import nn, optim

## Intro to Jax

Jax is a numerical computation library which aims to replicate the numpy api.

A few important things to know about Jax:
1. It is functional. This means no in-place ops and sliced assignments
2. Jax can execute computaitons on CPUs, GPUs, and TPUs.
3. functions using the jax.numpy api can be traced for automatic transformations

    a. **jit**: compile a function using XLA enabling fast execution

    b. **grad**: take the gradient of a function

    c. **vmap**: adds a batch dimension to a function

    d. **pmap**: split a computation across devices based on the first dimension of each input argument.


## Neural Networks in Jax

Before we dive into Flax what a typical neural networks component looks like when written in "native" Jax.

We decompose a learnable linear layer into two parts: a initializer function which uses a jax PRNGKey to generate a random kernel and bias and the apply function which computes the linear transformation using a set of parameters and some inputs.

In [2]:
def dense_init(rng, in_features, out_features,
               kernel_init=jax.nn.initializers.lecun_normal(),
               bias_init=jax.nn.initializers.zeros):
  k1, k2 = random.split(rng)
  # init functions take a PRNGKey and a shape tuple and return ndarrays.
  kernel = kernel_init(k1, (in_features, out_features))
  bias = bias_init(k2, (out_features,))
  return kernel, bias

def dense_apply(params, inputs):
  kernel, bias = params
  return jnp.dot(inputs, kernel) + bias

Functional Programming without abstractions naturally results into somewhat verbose but very explicit code.

Note how the random number generators and parameters are passed on explicitly to functions.
Jax has no concept of variables so we cannot hide the parameters in variables somewhere.
Similairly, there is no global random number generator which updates an internal seed.

In [3]:
params = dense_init(random.PRNGKey(0), in_features=2, out_features=4)
print(params)

(DeviceArray([[ 0.56129605,  0.98355806, -1.2807567 , -1.0105268 ],
             [ 0.05945718, -0.10787111, -0.6869378 , -0.09312173]],            dtype=float32), DeviceArray([0., 0., 0., 0.], dtype=float32))


Once we generated a set of parameters it is easy enough to apply them to some inputs.

In [4]:
x = jnp.ones((1, 2))
dense_apply(params, x)

DeviceArray([[ 0.6207532 ,  0.87568694, -1.9676945 , -1.1036485 ]], dtype=float32)

Because everything is functional we can use the functional transformations that jax provides to do useful things like taking gradients to optimize the model.

In [5]:
def loss_fn(params, x):
  y = dense_apply(params, x)
  return jnp.mean(y ** 2)
grad_fn = jax.grad(loss_fn) # by default jax.grad takes the gradient w.r.t. the first argument
grad_fn(params, x)

(DeviceArray([[ 0.3103766 ,  0.43784347, -0.98384726, -0.5518243 ],
              [ 0.3103766 ,  0.43784347, -0.98384726, -0.5518243 ]],            dtype=float32),
 DeviceArray([ 0.3103766 ,  0.43784347, -0.98384726, -0.5518243 ], dtype=float32))

## Flax Modules

The core of Flax is the Module abstraction.
Modules allow you to write parameterized functions just as if you were writing a normal numpy function with Jax.
The module api allows you to declare parameters and use them directly with the Jax api's.

A few things to know about Modules:
1. A Module is creates by defining a subclass of `flax.nn.Module` and implementing the apply method.
2. parameters are declared using `self.param(name, shape, init_func)` and return an initialized parameter value.
3. `Dense.init(rng, ...)` and `Dense.call(params, ...)` behave identically to the `dense_init` and `dense_apply` implemented earlier.

Now let's try to do redefine the dense layer using Flax Modules.

In [6]:
class Dense(nn.Module):
  """A learned linear transformation."""
  def apply(self, x, features,
            kernel_init=jax.nn.initializers.lecun_normal(),  # init functions are of the form (PrngKey, shape) => init_value
            bias_init=jax.nn.initializers.zeros):
    in_features = x.shape[-1]
    kernel_shape = (in_features, features)
    kernel = self.param('kernel', kernel_shape, kernel_init)
    bias = self.param('bias', (features,), bias_init)
    return jnp.dot(x, kernel) + bias

In [7]:
y, params = Dense.init(random.PRNGKey(0), x, features=4)
print(params)

{'kernel': DeviceArray([[ 0.04612674, -0.41086802,  0.29959238,  0.65631014],
             [-0.8772928 ,  0.31151664, -0.06712601,  0.7201758 ]],            dtype=float32), 'bias': DeviceArray([0., 0., 0., 0.], dtype=float32)}


In [8]:
Dense.call(params, x, features=4)

DeviceArray([[-0.8311661 , -0.09935138,  0.23246637,  1.376486  ]], dtype=float32)

Note how we must specify the number of features in both init and call.
Modules will often have certain parameters that are fixed for each call to init and call.

We can use `Module.partial` to apply these arguments. partial takes keyword arguments and returns a new module for which the given arguments are already applied.
It can be thought of as the equivalant of `functools.partial` for Modules. 

In [9]:
model_def = Dense.partial(features=4) # Module + hyper parameters = model defintion
_, params = model_def.init(random.PRNGKey(0), x)
model_def.call(params, x)

DeviceArray([[-0.8311661 , -0.09935138,  0.23246637,  1.376486  ]], dtype=float32)

### Composition

Modules can be composed to form more complex Modules.

Within a Module's apply function other modules behave just like functions.


In [10]:
# same as flax.nn.relu
def relu(x):
  return jnp.maximum(0., x)

class MLP(nn.Module):
  """Multi Layer Perceptron."""
  
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):

    z = Dense(x, hidden_features)
    h = activation_fn(z)
    y = Dense(h, output_features)
    return y

model_def = MLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
y, params = model_def.init(random.PRNGKey(0), x)
print(y)

[[0.5629978  0.51556504 0.05943576 0.68959564]]


`jax.tree_map` allows you to apply a function to each leave of a pytree.
A pytree can consist of (nested) lists, tuples, dicts and other types that contain arrays.

We use the tree_map util to reveal the parameter structure of the MLP model

Here we can see that composing modules results in a structure of nested dictionaries.

In [11]:
jax.tree_map(np.shape, params)

{'0': {'bias': (8,), 'kernel': (2, 8)}, '1': {'bias': (4,), 'kernel': (8, 4)}}

#### Module name
By default Flax will use integers as keys for the parameters of sub modules. By passing the name argument we can control the parameter structure and make it more meaningful.

In [12]:
class NamedMLP(nn.Module):
  def apply(self, x,
            hidden_features,
            output_features,
            activation_fn):

    z = Dense(x, hidden_features, name='hidden')
    h = activation_fn(z)
    y = Dense(h, output_features, name='out')
    return y

model_def = NamedMLP.partial(hidden_features=8, output_features=4, activation_fn=relu)
_, params = model_def.init(random.PRNGKey(0), x)
jax.tree_map(np.shape, params)

{'hidden': {'bias': (8,), 'kernel': (2, 8)},
 'out': {'bias': (4,), 'kernel': (8, 4)}}

### Parameter sharing

Sometimes a module should be applied to multiple inputs with one set of parameters.
We can make a Module for which parameters are shared between calls using `Module.shared`.
Just like with `Module.partial` we can pass keyword arguments that are fixed for each call to the module.

In [13]:
class SimpleRNN(nn.Module):
  def apply(self, x, iterations=3):
    dense = Dense.shared(
        features=x.shape[-1],
        kernel_init=jax.nn.initializers.orthogonal(),
        name='cell')
    ys = []
    for i in range(iterations):
      x = dense(x)
      ys.append(x)
    return ys

we call the Dense layer named 'cell' 3 times but only one set of parameters shows up in the parameter structure due to weight sharing.

In [14]:
ys, params = SimpleRNN.init(random.PRNGKey(0), x)
print(ys)
jax.tree_map(np.shape, params)

[DeviceArray([[-0.8121684,  1.1577489]], dtype=float32), DeviceArray([[-1.2806695, -0.5999045]], dtype=float32), DeviceArray([[ 0.36959398, -1.3650641 ]], dtype=float32)]


{'cell': {'bias': (2,), 'kernel': (2, 2)}}

### Shape inference

Previously we initialized the model by passing in some inputs.
This is useful because it allows for Modules which automatically infer the shape of parameters based on inputs. It can also help catch errors in the model early, in the initialization phase of a program.

Nontheless, `Module.init` includes some unnecesary overhead because typically we are not interested in the actual output of the model during initialization. Therefore, we can use Jax build in lazy evaluation to get the benefits of shape inference without doing any unnecesary compute.

`Module.init_by_shape` returns only the shape and dtype of outputs but still creates fully initialized parameters. If you want to use initializers that (indirectly) depend on the values (not shape) of the inputs you should keep using `Module.init`.

In [15]:
input_spec = [((1, 2), jnp.float32)]
out_spec, params = SimpleRNN.init_by_shape(random.PRNGKey(0), input_spec)
print(ys)
jax.tree_map(np.shape, params)

[DeviceArray([[-0.8121684,  1.1577489]], dtype=float32), DeviceArray([[-1.2806695, -0.5999045]], dtype=float32), DeviceArray([[ 0.36959398, -1.3650641 ]], dtype=float32)]


{'cell': {'bias': (2,), 'kernel': (2, 2)}}

In [16]:
print(out_spec) # the outputs only define the shape and dtype
print([spec.shape for spec in out_spec])

[ShapeDtypeStruct(shape=(1, 2), dtype=float32), ShapeDtypeStruct(shape=(1, 2), dtype=float32), ShapeDtypeStruct(shape=(1, 2), dtype=float32)]
[(1, 2), (1, 2), (1, 2)]


### Model

Module makes it easy to keep track of parameters inside a Model but so far it still required explicilty keeping track of parameter structure and the init & apply functions.

Model is a thin abstraction around a Module and a set of parameter.
A Model instance is callable and functional (eg. changing parameters requires a new model instance).

Using `Module.create` or `Module.create_by_shape` will create a newly initialized set of parameters and wrap them in a Model instance. 

In [0]:
x = jnp.ones((1, 2))
ys, model = Dense.partial(features=4).create(random.PRNGKey(0), x)
jax.tree_map(np.shape, model.params)

In [0]:
model(x)

In [0]:
model.params

Parameters can be updated using the `Model.replace` method

In [0]:
biased_model = model.replace(params={'kernel': model.params['kernel'], 'bias': model.params['bias'] + 1.})
biased_model.params

Model is registerd as a container object which means that it can be passed to Jax transformations and `jax.tree_map`.

For example we can take gradients w.r.t. a model object. The returned Model object will contain the gradients corresponding to each parameter.

In [0]:
def loss_fn(model):
  y = model(x)
  return jnp.mean(y ** 2)

model_grad = jax.grad(loss_fn)(model)
model_grad.params

### State

Flax allows stateful operations to happen within a limited scope.

stateful modules are defined using the `Module.state` api. It returns a state object that has a property value that can be assigned to.

A typical use stateful module is BatchNorm which maintains a moving average of batch statistics (mean, variance).
During training the moving averages are updated such that they can be used during test time.



In [0]:
# simplified version of nn.BatchNorm
class BatchNorm(nn.Module):
  def apply(self, x, red_axis=0, eps=1e-5,
            momentum=0.99, training=False,
            gamma_init=nn.initializers.ones,
            beta_init=nn.initializers.zeros):

    # compute the moments of the input
    mean = x.mean(red_axis, keepdims=True)
    var = jnp.square(x - mean).mean(red_axis, keepdims=True)

    # define the state variables
    ra_mean = self.state('mean', mean.shape, nn.initializers.zeros)
    ra_var = self.state('var', var.shape, nn.initializers.ones)

    if not self.is_initializing():  # during init we ignore the moving averages completly
      if training:
        # during training the moving averages are updated
        alpha = 1. - momentum
        ra_mean.value += alpha * (mean - ra_mean.value)
        ra_var.value += alpha * (var - ra_var.value)
      else:
        # if we are not training we use the moving averages
        mean = ra_mean.value
        var = ra_var.value

    # standardize the input
    y = (x - mean) / jnp.sqrt(var + eps)

    # learn the scale and bias of the output
    gamma = self.param('gamma', mean.shape, gamma_init)
    beta = self.param('beta', mean.shape, beta_init)
    return gamma * y + beta

Stateful modules require special care when used. The `nn.stateful` context manager defines a scope in which stateful operations are allowed. Outside of this scope the state becomes immutable.

The state is stored in a `nn.Collection` object which internally stores the state as a dictionary.

`nn.stateful` takes a Collection containing the current state and returns a new Collection that contains the updated state. By default a new Collection will be created.

When using `nn.stateful(state, mutable=False)` the state can be read but any updates will raise an error. This is often useful during test time to garantuee test data does not affect the model.

In [0]:
class MyModel(nn.Module):

  def apply(self, x, training=False):
    x = Dense(x, features=4)
    x = BatchNorm(x, training=training, momentum=0., name='batch_norm')
    return x

dist_a = lambda rng, shape: random.normal(rng, shape) * jnp.array([[1., 3.]])

x_a = dist_a(random.PRNGKey(1), (1024, 2))
print('std. deviation of input:', x_a.std(0))

with nn.stateful() as init_state:
  y, params = MyModel.init(random.PRNGKey(2), x_a)
print('std. deviation of output (init):', y.std(0))

with nn.stateful(init_state) as new_state:
  y = MyModel.call(params, x_a, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_a, training=False)
print('std. deviation of output (testing):', y.std(0))

The state can be inspected using `Collection.as_dict()`.

Each Module has a path like key into the Collection (eg. '/some_module/nested_module/dense').

In [0]:
init_state.as_dict()

In [0]:
new_state.as_dict()

The stateful mechanism forces the user to be explicit about stateful operations.

One motivating example for this approach is to enforce that state is not updated at test time.

Another benefit is that it is easier to replace the state when necessary.
For example let say we want to apply this model on a second input distribution (b) with different statistics.


In [0]:
dist_b = lambda rng, shape: random.normal(rng, shape) * jnp.array([[2., 5.]])

x_b = dist_b(random.PRNGKey(1), (1024, 2))

with nn.stateful(new_state, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print(y.std(0)) # this will not be properly normalized!

We can solve the skew in statistics by creating a seperate state for this alternative input distribution.

In [0]:
with nn.stateful(init_state) as state_b:
  y = MyModel.call(params, x_b, training=True)
print('std. deviation of output (training):', y.std(0))

with nn.stateful(state_b, mutable=False):
  y = MyModel.call(params, x_b, training=False)
print('std. deviation of output (testing):', y.std(0))

## Optimizer

In [0]:
rng = random.PRNGKey(0)
rng, key1, key2 = random.split(rng, 3)
n = 30
x = jnp.linspace(-5., 5.)
X = random.uniform(key1, (n,), minval=-5., maxval=5.)
f = lambda x: 2. * x
Y = f(X) + random.normal(key2, (n,))
plt.plot(x, f(x))
plt.scatter(X, Y)
plt.show()

In [0]:
class LinearRegression(nn.Module):
  def apply(self, x):
    return nn.Dense(x[..., None], features=1)[..., 0]

rng, key = random.split(rng)
_, model = LinearRegression.create(key, X)

plt.plot(x, f(x))
plt.plot(x, model(x))
plt.scatter(X, Y)
plt.show()

In [0]:
optimizer_def = optim.Momentum(learning_rate=0.1)
optimizer = optimizer_def.create(model)

train_steps = 100

def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  optimizer, loss = optimizer.optimize(loss_fn)
print('mean square error:', loss)

trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()

In [0]:
loss, grads = optimizer.compute_gradients(loss_fn)
new_optimizer = optimizer.apply_gradient(grads)
new_optimizer.target.params

## Advanced features

### Selective Optimization

In [0]:
slope_opt_def = optim.Momentum(learning_rate=0.1)
bias_opt_def = optim.Momentum(learning_rate=0.1, weight_decay=10.)
slope_traversal = optim.ModelParamTraversal(lambda name, param: 'kernel' in name)
bias_traversal = optim.ModelParamTraversal(lambda name, param: 'bias' in name)
optimizer_def = optim.MultiOptimizer((slope_traversal, slope_opt_def), (bias_traversal, bias_opt_def))

optimizer = optimizer_def.create(model)

train_steps = 100

def loss_fn(model):
  Y_hat = model(X)
  return jnp.square(Y - Y_hat).mean()

for i in range(train_steps):
  optimizer, loss = optimizer.optimize(loss_fn)
print('mean square error:', loss)

trained_model = optimizer.target
print(trained_model.params)
plt.plot(x, f(x))
plt.plot(x, trained_model(x))
plt.scatter(X, Y)
plt.show()

### Multi method modules

In [0]:
class MultiMethodModule(nn.Module):

  def apply(self, x):
    kernel = self.param('kernel', (), lambda _, shape: jnp.full(shape, 2.))
    return x * kernel

  @nn.module_method
  def decode(self, x):
    kernel = self.get_param('kernel')
    return x * kernel

x = 2. ** jnp.arange(5)
y, model = MultiMethodModule.create(random.PRNGKey(0), x)
print(x[1:], y[:-1])
model.decode(1.)

In [0]:
def body_fn(x, _):
  y = model.decode(x)
  return y, y

lax.scan(body_fn, 1., (), length=4)[1]

### Transforming sub module parameters

In [0]:
def add_scale(module):
  class ScaleWrapper(nn.Module):
    """Add a learnable scale to the kernel of a module."""

    def apply(self, *args, **kwargs):
      def init_fn(rng, _):
        _, params = module.init(rng, *args, **kwargs)
        # here we could change the initial parameters of the wrapped module
        return params
      params = self.param('params', None, init_fn)
      # here change transform parameters every call
      assert 'kernel' in params
      kernel = params['kernel']
      features = kernel.shape[-1]
      scale = self.param('scale', (features,), nn.initializers.ones)
      scaled_kernel = kernel * scale
      scaled_params = params.copy()
      scaled_params['kernel'] = scaled_kernel

      return module.call(scaled_params, *args, **kwargs)
  return ScaleWrapper

x = jnp.ones((1, 2))
model_def = add_scale(Dense).partial(features=4)
y, params = model_def.init(random.PRNGKey(0), x)
params

### Collection

## Gotchas