# Setting up our environment

In [30]:
# Install the newest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git

  Building wheel for flax (setup.py) ... done


In [31]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

from jax.config import config
config.enable_omnistaging() # Linen requires enabling omnistaging

# Linear regression with Flax

In the previous *JAX for the impatient* notebook, we finished with a linear regression example. Linear regression can also be written as a single dense neural network layer, which we show below so that the two approaches can be compared.

A dense layer is a layer that has a kernel parameter $W\in\mathcal{M}_{m,n}(\mathbb{R})$, where $m$ is the number of output features of the model and $n$ is the size of the input, and a bias parameter $b\in\mathbb{R}^m$. The dense layer returns $Wx+b$ from an input $x\in\mathbb{R}^n$.
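
To make the formula concrete, here is a minimal sketch of a dense layer in plain JAX, without Flax. The function name `dense` and the dimensions are illustrative assumptions. It also checks that the row-vector form `x @ W.T + b` gives the same result as $Wx+b$.

```python
import jax.numpy as jnp
from jax import random

def dense(W, b, x):
    """Return W x + b for a single input vector x (illustrative helper)."""
    return jnp.dot(W, x) + b

m, n = 5, 10  # output features, input size (arbitrary choices)
k_w, k_b, k_x = random.split(random.PRNGKey(0), 3)
W = random.normal(k_w, (m, n))
b = random.normal(k_b, (m,))
x = random.normal(k_x, (n,))

y = dense(W, b, x)

# Same computation in row-vector form, with a kernel of shape (n, m).
y_row = jnp.dot(x, W.T) + b
print(y.shape, jnp.allclose(y, y_row, atol=1e-5))
```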

This dense layer is already provided by Flax in the `flax.linen` module (here imported as `nn`).

In [32]:
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

Layers (and models in general, we'll use that word from now on) are subclasses of the `linen.Module` class.

## Model parameters & initialization

Contrary to what you might expect coming from some other frameworks, parameters are not stored with the models themselves. No parameters are initialized at model creation; you generate them by calling the `init` function manually with a PRNGKey and a dummy input.

In [33]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
init_params = model.init(key2, x) # Initialization call
jax.tree_map(lambda x: x.shape, init_params) # Checking output shapes

FrozenDict({'params': {'bias': (5,), 'kernel': (10, 5)}})

*Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of all tensors we talk about.*

Here we can see the result is what we expect: bias and kernel parameters of the correct size. What happens under the hood is:

*   The dummy input variable `x` is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel.
*   The random PRNG key is used (and needed!) to trigger the initialization functions (those have default values provided by the module here, but could have been replaced in the `model = nn.Dense()` call using the `kernel_init` and `bias_init` keyword arguments).
* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take `(PRNG Key, shape, dtype)` as arguments and return an Array of shape `shape`.
* The `init` function returns the initialized set of parameters (you can also get the output of the evaluation on the dummy input with the same syntax by using the `init_with_output` method instead of `init`).

As we mentioned, the parameters are never stored inside the model itself. To evaluate the model with a given set of parameters, we call the `apply` method, providing it with the parameters to use as well as the input:

In [34]:
model.apply(init_params, x) # Evaluating model at point x with params init_params

DeviceArray([-0.7358944,  1.3583755, -0.7976872,  0.8168598,  0.6297793],            dtype=float32)

## Gradient descent

If you jumped here directly without going through the JAX part, here is the linear regression formulation we're going to use: from a set of data points $\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}$, we try to find a set of parameters $W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:
$$\mathcal{L}(W,b)=\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2$$
(Note: this is a rough explanation. In theory we should be minimizing the expectation of the loss, but for the sake of simplicity we consider only the sampled loss here.)
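
Differentiating this loss by hand gives $\nabla_W\mathcal{L}=-\frac{1}{k}\sum_{i=1}^{k}(y_i-f_{W,b}(x_i))\,x_i^\top$ and $\nabla_b\mathcal{L}=-\frac{1}{k}\sum_{i=1}^{k}(y_i-f_{W,b}(x_i))$. The sketch below checks these expressions against `jax.grad` on random data. The helper `loss_fn` and the data are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp
from jax import random

def loss_fn(W, b, xs, ys):
    # Mean over samples of 1/2 * ||y_i - (W x_i + b)||^2.
    preds = xs @ W.T + b
    return jnp.mean(0.5 * jnp.sum((ys - preds) ** 2, axis=-1))

kx, ky, kw = random.split(random.PRNGKey(0), 3)
xs = random.normal(kx, (20, 10))  # k=20 samples, n=10
ys = random.normal(ky, (20, 5))   # m=5
W = random.normal(kw, (5, 10))
b = jnp.zeros((5,))

# Gradients from automatic differentiation.
grad_W, grad_b = jax.grad(loss_fn, argnums=(0, 1))(W, b, xs, ys)

# Gradients from the closed-form expressions.
residual = ys - (xs @ W.T + b)                  # shape (k, m)
analytic_W = -(residual.T @ xs) / xs.shape[0]   # shape (m, n)
analytic_b = -jnp.mean(residual, axis=0)        # shape (m,)

print(jnp.allclose(grad_W, analytic_W, atol=1e-3),
      jnp.allclose(grad_b, analytic_b, atol=1e-3))
```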

Here, we see that the tuple $(W,b)$ matches the parameters of the Dense layer. We'll perform gradient descent using those. Let's first generate the fake data we'll use.


In [35]:
# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (ydim, xdim))
b = random.normal(k2, (ydim,))
true_params = freeze({'params': {'bias': b, 'kernel': W.T}})

# Generate samples with additional noise
ksample, knoise = random.split(k1)
x_samples = random.normal(ksample, (nsamples, xdim))
y_samples = jax.vmap(lambda x:jnp.dot(W,x)+b)(x_samples) + 0.1*random.normal(knoise,(nsamples, ydim))
print("x shape:", x_samples.shape, "; y shape:", y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


Now let's generate the loss function (mean squared error) with that data.

In [36]:
def make_mse(x_batched,y_batched):
  def mse(params):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x,y):
      pred = model.apply(params, x)
      return jnp.inner(y-pred,y-pred)/2.0
    # We vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
  return jax.jit(mse) # And finally we jit the result.

# Get the sampled loss
loss = make_mse(x_samples, y_samples)

Finally, we perform the gradient descent.

In [37]:
params = init_params

alpha = 0.3 # Gradient step size
print('Loss for "true" W,b: ', loss(true_params))
grad_fn = jax.value_and_grad(loss)

for i in range(101):
  # We perform one gradient update
  loss_val, grad = grad_fn(params)
  params = jax.tree_multimap(lambda old,grad: old-alpha*grad, params, grad)
  if (i%10==0):
    print("Loss step {}: ".format(i), loss_val)

Loss for "true" W,b:  0.023639774
Loss step 0:  33.644146
Loss step 10:  0.54844475
Loss step 20:  0.1386285
Loss step 30:  0.05103702
Loss step 40:  0.02445298
Loss step 50:  0.015832197
Loss step 60:  0.01298588
Loss step 70:  0.01204048
Loss step 80:  0.011725718
Loss step 90:  0.011620826
Loss step 100:  0.011585844
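
The update rule used in the loop is a plain pytree operation: each parameter leaf is moved against its gradient. Newer JAX releases removed `jax.tree_multimap`, and `jax.tree_util.tree_map` accepts several trees instead. The sketch below shows the same step with that function on hand-written values. The helper name `sgd_step` and the toy trees are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, step_size):
    # Apply old - step_size * grad leaf by leaf; both trees must share
    # the same structure.
    return jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

# Toy parameter and gradient trees with the same layout as a Dense layer.
params = {'params': {'kernel': jnp.ones((2, 3)), 'bias': jnp.zeros((3,))}}
grads = {'params': {'kernel': jnp.full((2, 3), 2.0), 'bias': jnp.ones((3,))}}

new_params = sgd_step(params, grads, 0.5)
print(new_params)
```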


## Built-in optimization API

Flax provides an optimization package in `flax.optim` to make your life easier when training models. The process is:

1.   You choose an optimization method (e.g. `optim.GradientDescent`, `optim.Adam`)
2.   From the previous optimization method, you create a wrapper around the parameters you're going to optimize for with the `create` method. Your parameters are accessible through the `target` field.
3. You compute the gradients of your loss with `jax.value_and_grad()`.
4. At every iteration, you compute the gradients at the current point, then use the `apply_gradient()` method on the optimizer to return a new optimizer with updated parameters.
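
To illustrate the pattern in these four steps without relying on a particular Flax version, here is a toy stand-in written in plain JAX. `ToyOptimizer` is a hypothetical class invented for this sketch and is not Flax's implementation. It only mimics the idea of an immutable wrapper that exposes a `target` field and returns a new optimizer from `apply_gradient`.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

class ToyOptimizer(NamedTuple):
    """Hypothetical immutable wrapper around parameters (not flax.optim)."""
    target: Any
    learning_rate: float

    def apply_gradient(self, grads):
        # Plain gradient descent; returns a NEW optimizer, nothing is mutated.
        new_target = jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, self.target, grads)
        return self._replace(target=new_target)

def loss(params):
    # Toy quadratic loss with its minimum at w = 3.
    return jnp.sum((params['w'] - 3.0) ** 2)

optimizer = ToyOptimizer(target={'w': jnp.zeros(())}, learning_rate=0.1)
grad_fn = jax.value_and_grad(loss)

for _ in range(100):
    loss_val, grads = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)

print(optimizer.target['w'])
```

Keeping the wrapper immutable means every training step is a pure function of the previous optimizer and the gradients, which fits well with `jax.jit`.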



In [38]:
from flax import optim
optimizer_def = optim.GradientDescent(learning_rate=alpha) # Choose the method
optimizer = optimizer_def.create(init_params) # Create the wrapping optimizer with initial parameters
grad_fn = jax.value_and_grad(loss)

In [39]:
for i in range(101):
  loss_val, grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad) # Return the updated optimizer with parameters.
  if (i%10==0):
    print("Loss step {}: ".format(i), loss_val)

Loss step 0:  33.644146
Loss step 10:  0.54844475
Loss step 20:  0.1386285
Loss step 30:  0.05103702
Loss step 40:  0.02445298
Loss step 50:  0.015832197
Loss step 60:  0.01298588
Loss step 70:  0.01204048
Loss step 80:  0.011725718
Loss step 90:  0.011620826
Loss step 100:  0.011585844


## Serializing the result

Now that we're happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that.

In [40]:
from flax import serialization
bytes_output = serialization.to_bytes(optimizer.target)
dict_output = serialization.to_state_dict(optimizer.target)
print('Dict output')
print(dict_output)
print("Bytes output")
print(bytes_output)

Dict output
{'params': {'bias': DeviceArray([-1.4546485, -2.024733 ,  2.0802515,  1.2192372, -0.9952478],            dtype=float32), 'kernel': DeviceArray([[ 1.0113373 , -1.178727  ,  0.17029275, -0.38268968,
               1.0006864 ],
             [ 0.15007997,  0.29265866,  1.3902837 ,  1.7733499 ,
              -0.6942389 ],
             [ 0.01514493,  1.4338304 , -1.395063  ,  1.0604348 ,
               1.0923973 ],
             [-0.91484565,  0.12237039,  0.42970654, -0.59060466,
              -1.8361856 ],
             [ 0.32762164, -1.3771925 , -2.1653433 ,  1.0377196 ,
              -0.44736037],
             [ 1.734053  , -1.1486411 ,  0.5843343 ,  0.9914667 ,
              -0.64042985],
             [ 0.9330395 , -0.19898592,  0.8064922 , -1.2266037 ,
               0.4537833 ],
             [ 1.164198  ,  0.03010925,  0.34369385,  0.3231217 ,
              -1.1663771 ],
             [ 1.1287069 ,  1.3840592 ,  0.55565697,  0.7966899 ,
              -0.7405516 ],
           

**TODO: how do you keep the structure around ? Regenerated from model init ? What about filling another pattern that looks "alike" ?**

To load the model back, you'll need to use a template of the model parameter structure, like the one you would get from the model initialization. Here we use the previously generated `init_params` as the template. Note that this produces a new variable structure and does not mutate the template in place.

In [41]:
serialization.from_bytes(init_params, bytes_output)

FrozenDict({'params': FrozenDict({'kernel': array([[ 1.0113373 , -1.178727  ,  0.17029275, -0.38268968,  1.0006864 ],
       [ 0.15007997,  0.29265866,  1.3902837 ,  1.7733499 , -0.6942389 ],
       [ 0.01514493,  1.4338304 , -1.395063  ,  1.0604348 ,  1.0923973 ],
       [-0.91484565,  0.12237039,  0.42970654, -0.59060466, -1.8361856 ],
       [ 0.32762164, -1.3771925 , -2.1653433 ,  1.0377196 , -0.44736037],
       [ 1.734053  , -1.1486411 ,  0.5843343 ,  0.9914667 , -0.64042985],
       [ 0.9330395 , -0.19898592,  0.8064922 , -1.2266037 ,  0.4537833 ],
       [ 1.164198  ,  0.03010925,  0.34369385,  0.3231217 , -1.1663771 ],
       [ 1.1287069 ,  1.3840592 ,  0.55565697,  0.7966899 , -0.7405516 ],
       [-0.10766877,  0.07589053,  0.94861233, -1.16194   ,  0.17127322]],
      dtype=float32), 'bias': array([-1.4546485, -2.024733 ,  2.0802515,  1.2192372, -0.9952478],
      dtype=float32)})})

# Defining your own models

Flax allows you to define your own models, which will usually be more complicated than a linear regression. In this section, we'll show you how to build simple models. To do so, you'll need to create subclasses of the base `nn.Module` class.


## Module basics

The base abstraction for models is the `nn.Module` class, and every type of predefined layer in Flax (like the previous `Dense`) is a subclass of `nn.Module`. Let's start by defining a simple but custom multi-layer perceptron, i.e. a sequence of Dense layers interleaved with calls to a non-linear activation function.

In [42]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024090e-04  2.7864411e-05  2.4478839e-04  8.1344356e-04
  -1.0110775e-03]]


As we can see, a `nn.Module` subclass is made of:

*   A collection of data fields (`nn.Module` subclasses are Python dataclasses) - here we only have the `features` field of type `Sequence[int]`.
*   A `setup()` method that is called at the end of `__post_init__`, where you can register the submodules, variables and parameters you will need in your model.
*   A `__call__` function that returns the output of the model from a given input.
*   The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one `layers_n` sub dict per layer, and each of those contain the parameters of the associated Dense layer. The layout is very explicit and fits your mental model of what you would expect.

Since we have a very simple model here, we could have used an alternative (but equivalent) way of declaring the submodules inline in `__call__`, using the `@nn.compact` decorator like so:


In [43]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024090e-04  2.7864411e-05  2.4478839e-04  8.1344356e-04
  -1.0110775e-03]]


There are however a few differences you should be aware of between the two declaration modes:

*   In the `setup` method, you don't have access to the model inputs, so you can't rely on shape inference; you need to declare explicitly (often as a dataclass field) any shapes that would otherwise be inferred at run time.
*   In `setup`, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).
*   **TODO: Anything else? Like pros for using setup?**



## Module parameters

In the previous MLP example, we relied only on predefined layers and operators (`Dense`, `relu`). Let's imagine that Flax didn't provide a Dense layer and you wanted to write it on your own. Here is what it would look like using the `@nn.compact` way to declare a new module:

In [44]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameters:\n', init_variables)
print('output:\n', y)

initialized parameters:
 FrozenDict({'params': {'kernel': DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
             [ 0.05673932,  0.9909285 , -0.63536596],
             [ 0.76134115, -0.3250529 , -0.6522163 ],
             [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})
output:
 [[ 0.5035518   1.8548559  -0.4270196 ]
 [ 0.0279097   0.5589246  -0.43061775]
 [ 0.35471284  1.5741     -0.3286552 ]
 [ 0.5264864   1.2928858   0.10089308]]


Here, we see how to both declare and assign a parameter to the model using the `self.param` method. It takes as input `(name, init_fn, *init_args)`:

*   `name` is simply the name of the parameter that will end up in the parameter structure.
*   `init_fn` is a function with input `(PRNGKey, *init_args)` that returns an Array.
*   `init_args` are the arguments to provide to the initialization function.

Such params can also be declared in the `setup` method, but as mentioned above, shape inference is not available there, because Flax uses lazy initialization at the first call site.

In [45]:
# What happens when you use the mutable parameter? TODO
model.apply(init_variables, x, mutable=['params'])

(DeviceArray([[ 0.5035518 ,  1.8548559 , -0.4270196 ],
              [ 0.0279097 ,  0.5589246 , -0.43061775],
              [ 0.35471284,  1.5741    , -0.3286552 ],
              [ 0.5264864 ,  1.2928858 ,  0.10089308]], dtype=float32),
 FrozenDict({'params': {'kernel': DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
              [ 0.05673932,  0.9909285 , -0.63536596],
              [ 0.76134115, -0.3250529 , -0.6522163 ],
              [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}}))

## Variables and collections of variables

As we've seen so far, working with models means working with:

*   A subclass of `nn.Module`;
*   A pytree of parameters for the model (typically from `model.init()`);

However, this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the `variable` method.

Let's start with a (useless) model that keeps track of the number of samples it has seen.









In [62]:
class Counter(nn.Module):
  @nn.compact
  def __call__(self, x):
    offset = self.param('offset', lambda rng,n: n, 3) # Dummy parameter as example
    # easy pattern to detect if we're initializing
    is_initialized = self.has_variable('counter', 'count')
    counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
    if is_initialized:
      counter.value += x
    return counter.value + offset


key1, key2 = random.split(random.PRNGKey(0), 2)
x = 1.0
model = Counter()
init_variables = model.init(key1, x)
print('initialized variables:\n', init_variables)

y, updated_state = model.apply(init_variables, x, mutable=['counter'])

print('updated variables:\n', updated_state)
print('output:\n', y)

initialized variables:
 FrozenDict({'params': {'offset': 3}, 'counter': {'count': DeviceArray(0, dtype=int32)}})
updated variables:
 FrozenDict({'counter': {'count': DeviceArray(1., dtype=float32)}})
output:
 4.0


Here, `updated_state` contains only the state variables that were mutated by the model while it was applied to the data. To obtain the full set of variables for the next call, we merge it with the model parameters. The following pattern assumes `updated_variables` already holds such a merged structure from a previous step:

In [63]:
y, updated_state = model.apply(updated_variables, 1.0, mutable=['counter'])
updated_variables = freeze({'params': init_variables['params'], **updated_state})
print('updated variables:\n', updated_variables)
print('output:\n', y)

updated variables:
 FrozenDict({'params': {'offset': 3}, 'counter': {'count': DeviceArray(4., dtype=float32)}})
output:
 7.0


Here, we should note that even though this example can look a little dumb, it's not very far from how batch normalization works in practice: you keep a running average of statistics related to your data.