# Before you begin

Please <font color='red'>**MAKE A COPY**</font> of this colab to make sure your progress is saved.



# Practical ML with JAX

In this tutorial, we'll see how to use JAX to train a neural network.

In [None]:
!pip install jax numpy pandas matplotlib optax flax -U -qq

**Note:** If you get a "std::bad_cast" error, click on "Runtime > Restart session and run all".

In [None]:
from pprint import pprint

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from jax import tree_util

## Basic JAX

In [None]:
def f(x):
  """Return x squared plus five (works on Python scalars and arrays alike)."""
  squared = x**2
  return squared + 5

f(2.0)  # 2² + 5 = 9

In [None]:
f(jnp.asarray([6.0, 9.0]))  # element-wise: [6² + 5, 9² + 5] = [41, 86]

In [None]:
# jax.grad turns f into a new function that computes df/dx.
df = jax.grad(f)
df(2.0)  # d(x² + 5)/dx = 2x, so this evaluates to 4

## Gradient descent

In [None]:
# We want to find x that minimizes "loss_fn".
def loss_fn(x):
  """Toy objective x², whose minimum is at x = 0."""
  return x**2


# Initial guess for x
x = jnp.array(2.0)
print("Initial x: ", x)

# Gradient of loss_fn according to x
gradient_fn = jax.grad(loss_fn)

# Gradient descent algorithm: x <- x - learning_rate * d(loss)/dx
learning_rate = 0.2
num_iterations = 10
for _ in range(num_iterations):
  loss = loss_fn(x)
  gradient = gradient_fn(x)
  # Print the point the loss and gradient were evaluated at *before* taking
  # the step, so each row (x, loss, gradient) is self-consistent.
  print(f"x:{x:.5f} loss:{loss:.5f} gradient:{gradient:.5f}")
  x = x - learning_rate * gradient

print("Final x: ", x)

## Gradient descent with other algorithms

The classical gradient descent algorithm can be slow to converge when there are many dimensions (not really the case here). When using gradient descent on neural networks, it is usually better to use optimizers such as Adam, AdaGrad, RMSProp, or Momentum ([see details](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)).

Optax has a large collection of already implemented optimizers. Here is an example with Adam.



In [None]:
# Initial guess for x
x = jnp.array(2.0)
print("Initial x: ", x)

# Gradient descent algorithm with Adam. (optax.adamw would additionally apply
# weight decay, which is not what this example is meant to show.)
optimizer = optax.adam(learning_rate)
optimizer_state = optimizer.init(x)

for _ in range(num_iterations):
  loss = loss_fn(x)
  gradient = gradient_fn(x)
  # Print the point the loss/gradient belong to before taking the step.
  print(f"x:{x:.5f} loss:{loss:.5f} gradient:{gradient:.5f}")

  # The optimizer turns the raw gradient into an update; apply_updates adds
  # that update to x.
  updates, optimizer_state = optimizer.update(gradient, optimizer_state, params=x)
  x = optax.apply_updates(x, updates)

print("Final x: ", x)

## A simple MLP

In [None]:
def mlp(params, x):  # The model prediction
  """Two-layer perceptron: a ReLU hidden layer, then a linear read-out.

  Args:
    params: dict holding the weights 'w1' and 'w2'.
    x: input batch; its last axis must match the first axis of 'w1'.

  Returns:
    The raw (pre-sigmoid) logits.
  """
  x = jax.nn.relu(jnp.dot(x, params['w1']))
  x = jnp.dot(x, params['w2'])
  return x


def get_batches(batches=100, batch_size=200, dims=10):  # Gen some synthetic data
  """Yield `batches` random (x, y) pairs for a toy binary classification task.

  Each x has shape (batch_size, dims) with standard-normal entries; y[i] is 1
  when the features of row i sum to a positive number and 0 otherwise.
  """
  for _ in range(batches):  # the batch index itself is not needed
    x = np.random.randn(batch_size, dims)
    y = (np.sum(x, axis=1) > 0).astype(int)
    yield x, y


def loss_fn(params, x, y):  # Objective
  """Mean sigmoid binary cross-entropy of the MLP logits against labels y."""
  logits = mlp(params, x)
  return optax.sigmoid_binary_cross_entropy(logits, y).mean()


# Function value and its derivative (w.r.t. `params`, the first argument) together.
loss_and_grads_fn = jax.value_and_grad(loss_fn)


@jax.jit
def train_step(params, x, y, lr=0.1):
  """One step of plain gradient descent; returns (new_params, loss)."""
  loss, grads = loss_and_grads_fn(params, x, y)
  # Update every leaf of the params dict: p <- p - lr * dL/dp.
  new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
  return new_params, loss


# Initial model parameter values
key = jax.random.PRNGKey(1)
key, key_w1, key_w2 = jax.random.split(key, 3)
params = {
    'w1': jax.random.normal(key_w1, (10, 10)),
    # Shapes must be tuples: `(10)` is just the int 10, `(10,)` is a 1-tuple.
    'w2': jax.random.normal(key_w2, (10,)),
}

# Training loop
for epoch in range(10):
  for batch_x, batch_y in get_batches():
    params, loss = train_step(params, batch_x, batch_y)
  # `loss` holds the value computed on the last batch of the epoch.
  print(f'Epoch: {epoch} Loss: {loss:.4f}')

## Dictionaries and trees

In JAX, parameters are typically stored as (nested) dictionaries of JAX arrays. Such nested structures are called *pytrees*, and many JAX functions work on and expect them.

Learn more with the [Working with pytrees](https://jax.readthedocs.io/en/latest/working-with-pytrees.html) tutorial.

In [None]:
# A nested pytree: leaves are JAX arrays, and "c" is itself a dict.
t = dict(
    a=jnp.array(1.),
    b=jnp.array([1., 2., 3.]),
    c=dict(
        d=jnp.array(1.),
        e=jnp.array(2.),
    ),
)
pprint(t)

In [None]:
# Add 1 to every leaf of the tree; the nesting structure is preserved.
pprint(jax.tree.map(lambda leaf: leaf + 1, t))

In [None]:
# A scalar function of the whole tree (its gradient is computed next).
def f(v):
  """Mean of a * b * (d + e), reading the leaves from the nested dict v."""
  c_sum = v["c"]["d"] + v["c"]["e"]
  return (v["a"] * v["b"] * c_sum).mean()
f(t)

In [None]:
jax.grad(f, argnums=0)(t)  # gradient w.r.t. every leaf, same tree structure as t

## Jitting

Instead of executing operations one after the other in Python, "jitting" a function traces all its instructions and compiles them together (much like you would compile a program in C++).

Compilation takes a little bit of time, and is less flexible than execution in Python, but the compiled function is much faster (especially when using hardware accelerators).

The Python body of a jitted function is only executed once (while JAX traces and compiles it); later calls with arguments of the same type run the compiled function directly. So a `print` in the function will only be called once. To print something each time the function is executed (great for debugging), use `jax.debug.print` instead.

Jitting / compilation is specialized to the argument types (and shapes). If you call your function with arguments of other types or shapes, the function will be re-compiled.

Learn more [here](https://jax.readthedocs.io/en/latest/jit-compilation.html).

In [None]:
@jax.jit
def f(x):
  # Regular Python print: only runs while JAX traces (compiles) f.
  print("Compiling function for", x.dtype)
  # Staged print: runs on every call of the compiled function.
  jax.debug.print("Running with x={x}", x=x)
  doubled = 2 * x
  return doubled + x


# Calling function with a 32 bits integer input.
f(1)
f(2)
f(3)

# Calling function with a 32 bits float input.
# The function will be re-compiled.
f(1.0)
f(2.0)
f(3.0)

## Data placement

In [None]:
# List the devices available on your machine.
# Note: By default, Google Colab only has a CPU. In the settings, you can
# request a TPU for free (if some are available).
jax.devices()

By default, all the computation is done on a GPU / TPU if one is available; otherwise, it is done on the CPU.
This is why JAX is sometimes slower than NumPy on some operations: moving data to and from the device has a cost.



In [None]:
a = jnp.array([1,2,3])
print("a is currently stored on", a.device)

In [None]:
# Move `a` onto the last listed device (unchanged if only one device exists).
target_device = jax.devices()[-1]
a = jax.device_put(a, device=target_device)
print("a is now stored on", a.device)

## JAX and Numpy

`jax.numpy` (generally aliased as `jnp`) mostly follows the same API as NumPy.

A `jnp.array` is a JAX array, which is different from a `numpy.array`.

It is possible to convert arrays JAX <-> Numpy.

For operations with a lot of computation (e.g., multiplying matrices), JAX arrays are generally faster. For operations with a lot of data transfer (e.g., preparing a dataset), NumPy arrays are generally faster.

In [None]:
# A plain NumPy array (lives in host memory).
a = np.asarray([1, 2, 3])
type(a)

In [None]:
# A JAX array (placed on the default device).
b = jnp.asarray([1, 2, 3])
type(b)

In [None]:
# Converting a numpy array into a jax array: jnp.asarray produces a JAX array
# from the NumPy array `a`.
type(jnp.asarray(a))

In [None]:
# Converting a jax array into a numpy array (the result is a numpy.ndarray).
type(np.asarray(b))

## Flax modules can help track weights

In JAX, weights are typically stored in (nested) dictionaries (see above).
Flax modules can help create and manage those weights.

In [None]:
class MyModel(nn.Module):
  """Two stacked Dense layers, both `output_dim` wide, with no activation."""
  output_dim: int = 2

  @nn.compact
  def __call__(self, x) -> jax.Array:
    # Each nn.Dense owns a "kernel" matrix [input_dim, output_dim] and a
    # "bias" vector [output_dim], created the first time the layer is applied.
    first = nn.Dense(features=self.output_dim, name="my_layer_1")
    second = nn.Dense(features=self.output_dim, name="my_layer_2")

    # Apply the layers on the data (the print fires on every execution).
    jax.debug.print("Running with x={x}", x=x)
    hidden = first(x)
    return second(hidden)

# Initialize the model
model = MyModel()
x = jnp.array([1., 2., 3.])
# Create the dictionary of weights by running the model once on x.
model_state = model.init(jax.random.PRNGKey(0), x)

model_state

You can call the model with the weights.

In [None]:
# Run the model's forward pass on x using the weights in model_state.
model.apply(model_state, x)

## Putting it all together to train an MLP

In [None]:
# This is our model
class MLP(nn.Module):
  """ReLU multi-layer perceptron that outputs one logit per example.

  Attributes:
    num_layers: number of hidden Dense + ReLU layers.
    layer_size: number of units in each hidden layer.
    output_dim: units in the final Dense layer; must be 1 for the squeeze.
  """
  num_layers: int = 2
  layer_size: int = 10
  output_dim: int = 1

  @nn.compact
  def __call__(self, x) -> jax.Array:
    for i in range(self.num_layers):
      x = nn.Dense(features=self.layer_size, name=f'layer_{i}')(x)
      x = nn.relu(x)
    x = nn.Dense(features=self.output_dim, name='final_layer')(x)
    # Drop the size-1 output axis so logits line up with labels of shape
    # (batch,). Squeezing the *last* axis also works for one unbatched example.
    x = jnp.squeeze(x, axis=-1)
    return x

# Initialize model and its weights
model = MLP()
key, model_init_key = jax.random.split(key, 2)
x_sample, _ = next(iter(get_batches()))
model_state = model.init(model_init_key, x_sample)["params"]

# Print the shape of each weight of the model
print("model_state:\n", jax.tree.map(lambda w: w.shape, model_state))

# Print the internal layers of the model, their size and number of flops.
print("Model structure:\n", model.tabulate(
          model_init_key,
          x_sample,
          compute_flops=True,
          compute_vjp_flops=True,
      ))

# Initialize the optimizer (we use AdamW)
optimizer = optax.adamw(0.1)
optimizer_state = optimizer.init(model_state)


@jax.jit
def train_step(model_state, optimizer_state, x, y):
  """One AdamW step; returns (new_model_state, new_optimizer_state, loss)."""

  def loss_fn(model_state, x, y):
    logits = model.apply({"params": model_state}, x)
    return optax.sigmoid_binary_cross_entropy(logits, y).mean()

  # Gradient descent
  loss, grads = jax.value_and_grad(loss_fn)(model_state, x, y)
  # AdamW needs the current params to compute its weight-decay term.
  updates, new_optimizer_state = optimizer.update(grads, optimizer_state, params=model_state)
  new_model_state = optax.apply_updates(model_state, updates)

  return new_model_state, new_optimizer_state, loss


# Training loop
for epoch in range(10):
  for batch_x, batch_y in get_batches():
    model_state, optimizer_state, loss = train_step(model_state, optimizer_state, batch_x, batch_y)
  # `loss` holds the value computed on the last batch of the epoch.
  print(f'Epoch: {epoch} Loss: {loss:.4f}')