<a href="https://colab.research.google.com/github/gordicaleksa/get-started-with-JAX/blob/main/Tutorial_4_Flax_Zero2Hero_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Warming up with Flax

This notebook relies heavily on the official Flax documentation (https://flax.readthedocs.io/en/latest/), with some additional code, modifications, comments, and notes.

### Flax basics

In [None]:
# Install Flax
!pip install --upgrade -q git+https://github.com/google/flax.git

In [7]:
import jax
from jax import lax, random, numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn  # nn is also used in PyTorch and in older Flax API

import optax  # JAX optimizers

from typing import Any, Callable, Sequence, Optional
import functools
import numpy as np

In [10]:
# Instantiate a single fully-connected layer with 5 output features.
# The parameters themselves are created later, by the `model.init` call.
model = nn.Dense(features=5)

# Dense as well as all other layers inherit from Module class (same as in PyTorch...)
print(nn.Dense.__bases__)

# TODO: check the source code: https://github.com/google/flax/blob/main/flax/linen/linear.py

(<class 'flax.linen.module.Module'>,)


In [11]:
seed = 0
# Independent keys: one for the dummy input, one for parameter initialization.
key1, key2 = random.split(random.PRNGKey(seed))

x = random.normal(key1, (10,))  # Dummy input
params = model.init(key2, x)  # Initialization call
# Checking output shapes. `jax.tree_util.tree_map` is used because the
# top-level `jax.tree_map` alias is deprecated and absent from recent JAX.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# Note1: automatic shape inference (the kernel's input dim comes from `x`)
# Note2: immutable structure
# Note3: init_with_output



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [None]:
# This is how you run prediction in Flax: the parameters are passed
# explicitly to `apply`, together with the input.
model.apply(params, x)

In [None]:
# Calling the module object directly (instead of going through `model.apply`)
# is expected to fail here; the AttributeError is caught and its message printed.
try:
    y = model(x) # Raises an error
except AttributeError as e:
    print(e)

### Small example - linear regression

In [None]:
# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b.
# Each random draw gets its own subkey: a JAX PRNG key must be consumed only
# once, so the sampling/noise keys are NOT derived from the key that drew W.
key = random.PRNGKey(0)
k1, k2, ksample, knoise = random.split(key, 4)
W = random.normal(k1, (xdim, ydim))
b = random.normal(k2, (ydim,))
# Same nesting ('params' -> 'bias'/'kernel') as the variables of nn.Dense.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise
x_samples = random.normal(ksample, (nsamples, xdim))
y_samples = jnp.dot(x_samples, W) + b
y_samples += 0.1 * random.normal(knoise, (nsamples, ydim))  # Adding noise
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

In [None]:
def make_mse_func(x_batched, y_batched):
  """Build a jitted loss over a fixed batch.

  The returned function takes only `params`; `x_batched` / `y_batched` are
  captured in the closure. It evaluates the global `model` on every sample
  and averages the per-sample half squared error over the batch axis.
  """
  def mse(params):
    # Half squared error for one (x, y) pair.
    def sample_loss(x, y):
      residual = y - model.apply(params, x)
      return jnp.inner(residual, residual) / 2.0

    # vmap maps the single-pair loss over the leading (sample) axis.
    per_sample = jax.vmap(sample_loss)(x_batched, y_batched)
    return jnp.mean(per_sample, axis=0)

  # Compile once; the batch is baked into the traced function.
  return jax.jit(mse)

# Get the sampled loss: a jitted function of `params` only, with the
# batch (x_samples, y_samples) captured inside it.
loss = make_mse_func(x_samples, y_samples)

In [None]:
alpha = 0.3  # Gradient step size
print('Loss for "true" W,b: ', loss(true_params))
# One call returns both the loss value and its gradient w.r.t. the params.
grad_fn = jax.value_and_grad(loss)

for i in range(101):
  # We perform one gradient update
  loss_val, grads = grad_fn(params)
  # Leaf-wise SGD step. `jax.tree_multimap` no longer exists in JAX;
  # `tree_map` accepts several pytrees with matching structure.
  params = jax.tree_util.tree_map(lambda p, g: p - alpha * g, params, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Doing the same thing with dedicated optimizers!

Enter DeepMind's optax!

In [None]:
# Plain SGD from optax, reusing the step size `alpha`.
tx = optax.sgd(learning_rate=alpha)
# The optimizer state is built from the parameter pytree.
opt_state = tx.init(params)
# Returns (loss value, gradients w.r.t. params) in a single call.
loss_grad_fn = jax.value_and_grad(loss)

In [None]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params)
  # optax splits a step in two: first compute the updates (and the new
  # optimizer state) from the gradients...
  updates, opt_state = tx.update(grads, opt_state)
  # ...then apply those updates to the parameters.
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

# After training you can save/load the model using flax's serialization module

### Create custom models

In [None]:
class ExplicitMLP(nn.Module):
  """MLP whose Dense layers are declared explicitly in `setup`.

  One Dense layer is created per entry of `features`; ReLU is applied after
  every layer except the last one.
  """
  features: Sequence[int]  # data field (Module is Python's dataclass)

  def setup(self):  # stands in for __init__, which the dataclass machinery owns
    # Lists (and dicts) of submodules assigned to an attribute are registered
    # automatically. A single submodule would simply be:
    #   self.layer1 = nn.Dense(feat1)
    self.layers = [nn.Dense(width) for width in self.features]

  def __call__(self, inputs):
    last_idx = len(self.layers) - 1
    x = inputs
    for idx, layer in enumerate(self.layers):
      x = layer(x)
      # No non-linearity after the output layer.
      if idx < last_idx:
        x = nn.relu(x)
    return x

# Independent keys: one for the dummy input, one for parameter initialization.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

# todo: use @nn.compact instead

Going deeper

In [None]:
class SimpleDense(nn.Module):
  """Hand-written dense layer: `inputs @ kernel + bias`.

  The kernel shape is derived from the last axis of the input on first call.
  """
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    # We could also declare kernel/bias in setup() fn
    in_features = inputs.shape[-1]
    # self.param(name, init_fn, *init_args); the RNG is passed implicitly.
    kernel = self.param('kernel', self.kernel_init, (in_features, self.features))

    # Contract the last axis of `inputs` with axis 0 of `kernel`; no batch dims.
    contracting_dims = ((inputs.ndim - 1,), (0,))
    batch_dims = ((), ())
    # TODO Why not jnp.dot?
    projected = lax.dot_general(inputs, kernel, (contracting_dims, batch_dims))

    bias = self.param('bias', self.bias_init, (self.features,))
    return projected + bias

# Independent keys: one for the dummy input, one for parameter initialization.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)


# Print the call signature of the initializer function returned by lecun_normal()
from inspect import signature
print(signature(nn.initializers.lecun_normal()))

Introducing state

In [None]:
class BiasAdderWithRunningMean(nn.Module):
  """Subtracts a running mean of the inputs and adds a learnable bias.

  State lives in two collections: the running mean in 'batch_stats' (updated
  only when the variable already existed before the call, i.e. not during
  init) and the bias in 'params'. Both have shape `x.shape[1:]`.
  """
  decay: float = 0.99  # weight kept on the previous running mean per update

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree;
    # must be evaluated BEFORE self.variable creates the entry below.
    is_initialized = self.has_variable('batch_stats', 'mean')
    # self.variable returns a reference hence .value
    ra_mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    # self.param will by default add this variable to 'params' collection
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      # Reduce over the batch axis without keepdims so the stored mean keeps
      # the shape it was initialized with (x.shape[1:]); keeping the reduced
      # axis would silently change the state's shape after the first update.
      batch_mean = jnp.mean(x, axis=0)
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * batch_mean

    # Broadcasting over the batch axis gives the same output either way.
    return x - ra_mean.value + bias


key1, key2 = random.split(random.PRNGKey(0), 2)  # NOTE: key2 is not used in this cell
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
# With `mutable=[...]`, apply returns a pair: (output, mutated collections).
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)

In [None]:
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  # `flax.core.pop` returns (remaining collections, popped value) for both
  # FrozenDict and plain-dict variables; the `.pop` method only does so for
  # FrozenDict, and a plain dict would return just the value.
  old_state, params = flax.core.pop(variables, 'params')
  # Carry the unchanged params forward together with the fresh state.
  variables = freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part

Adding the optimizer (maybe delete the above cell)

In [None]:
def update_step(tx, apply_fn, x, opt_state, params, state):
  """Run one optimization step on a model that also carries mutable state.

  Args:
    tx: optimizer; its `update(grads, opt_state)` is called here.
    apply_fn: model apply function, called with the merged variables, `x`
      and `mutable=` set to every collection name in `state`.
    x: input batch; the loss is the summed squared difference to the output.
    opt_state: current optimizer state.
    params: current parameters (differentiated).
    state: current non-parameter collections (not differentiated).

  Returns:
    `(opt_state, params, state)` after the step.
  """
  mutable_collections = list(state.keys())

  def loss(params):
    variables = {'params': params, **state}
    y, updated_state = apply_fn(variables, x, mutable=mutable_collections)
    # The mutated state rides along as the auxiliary output.
    return ((x - y) ** 2).sum(), updated_state

  loss_and_grad = jax.value_and_grad(loss, has_aux=True)
  (_, new_state), grads = loss_and_grad(params)
  updates, new_opt_state = tx.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_opt_state, new_params, new_state

x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
# Split the variables into trainable params and the remaining state.
# `flax.core.pop` returns (rest, popped) for both FrozenDict and plain-dict
# variables; the `.pop` method does so only for FrozenDict.
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
  opt_state, params, state = update_step(tx, model.apply, x, opt_state, params, state)
  print('Updated state: ', state)

# todo: flax.training.train_state.TrainState
# todo: BatchNorm ('batch_stats' collection added) see source code

More comprehensive, contrived example

In [None]:
class Block(nn.Module):
  """Dense -> Dropout(0.5) -> BatchNorm, switched by a `training` flag.

  With `training=False`, dropout is deterministic and BatchNorm uses its
  running averages.
  """
  features: int
  training: bool

  @nn.compact
  def __call__(self, inputs):
    inference = not self.training
    hidden = nn.Dense(self.features)(inputs)
    hidden = nn.Dropout(rate=0.5)(hidden, deterministic=inference)
    return nn.BatchNorm(use_running_average=inference)(hidden)

# One key each for: dummy input, param init, dropout at init, dropout at apply.
key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3,4,4))

model = Block(features=3, training=True)

init_variables = model.init({'params': key2, 'dropout': key3}, x)
# `flax.core.pop` returns (rest, popped) without modifying `init_variables`,
# for both FrozenDict and plain-dict variables. The `.pop` method of a plain
# dict would instead return only the value and remove 'params' in place.
_, init_params = flax.core.pop(init_variables, 'params')

# When calling `apply` with mutable kinds, returns a pair of output,
# mutated_variables.
y, mutated_variables = model.apply(
    init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])

# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
updated_variables = freeze(dict(params=init_params,
                                **mutated_variables))

print('updated variables:\n', updated_variables)
# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
print('initialized variable shapes:\n',
      jax.tree_util.tree_map(jnp.shape, init_variables))
print('output:\n', y)

# Let's run these model variables during "evaluation":
eval_model = Block(features=3, training=False)
y = eval_model.apply(updated_variables, x)  # Nothing mutable; single return value.
print('eval output:\n', y)

# check out remat if you have memory expensive computation

In [None]:
# capture_intermediates
# check out https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html
# for more options

# todo: add this to some of the above examples instead
@jax.jit
def init(key, x):
  """Return the initial variables of a fresh `CNN`, built from PRNG `key` and sample input `x`."""
  return CNN().init(key, x)

@jax.jit
def predict(variables, x):
  """Run `CNN` on `x` and check every captured intermediate for finiteness.

  Returns:
    A pair `(y, fin)`: the model output, and a pytree with the structure of
    the captured 'intermediates' collection whose leaves are scalar booleans
    (True when all elements of that intermediate are finite).
  """
  # capture_intermediates=True stores intermediate outputs in the
  # 'intermediates' collection, which therefore has to be mutable.
  y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
  intermediates = state['intermediates']
  # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
  fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
  return y, fin

# NOTE(review): `CNN` and `batch` are not defined in the visible part of this
# notebook; they must be provided before this cell can run.
variables = init(jax.random.PRNGKey(0), batch)
y, is_finite = predict(variables, batch)
# `jax.tree_util.tree_leaves` replaces the deprecated top-level
# `jax.tree_leaves` alias.
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non finite intermediate detected!"

# Filters receive (module, method_name) and decide what gets captured.
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"

# Capture only the outputs of Dense submodules.
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']

### Full MNIST example

In [None]:
# 1) annotated
# 2) interactive
# 3) source code itself
# combine the above 3 here

### Bonus: going through real ImageNet CNN example

In [None]:
# jump to https://github.com/google/flax/tree/main/examples/imagenet

The HuggingFace Flax examples and the community-week projects are good resources:

1) https://github.com/huggingface/transformers/tree/master/examples/flax

2) https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects

The source code is also your friend, since the library is still evolving.