# What are Jax and Haiku?

[Jax](https://github.com/google/jax) is a library for "composable transformations of Python+NumPy programs". It includes helpful functions such as `grad`, which computes the gradient of a function, and `jit`, which compiles functions just-in-time with XLA for use on GPU/TPU, and much more. It is primarily targeted at high-performance machine learning research.
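
As a small illustration of these two transformations, here is a minimal sketch. The function `f` and the variable names are made up for this example and are not used elsewhere in the notebook.

```python
import jax


def f(x):
    # A simple scalar function: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x


# grad returns a new function that evaluates the derivative df/dx = 2x + 3
df_dx = jax.grad(f)

# jit returns a version of f that is compiled with XLA on first call
fast_f = jax.jit(f)

grad_at_2 = df_dx(2.0)    # analytically 2*2 + 3 = 7
value_at_2 = fast_f(2.0)  # analytically 4 + 6 = 10
```

Both transformations return ordinary Python callables, so they can be composed, e.g. `jax.jit(jax.grad(f))`.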

[Haiku](https://github.com/deepmind/dm-haiku/) is a neural network library for JAX that adds helpful model/layers/... classes, while allowing full access to JAX's pure function transformations.

# Why Dask + Jax/Haiku?

Whilst Jax offers many high-performance and parallel capabilities, e.g. `pmap`, Dask supports a large number of backends (including [several types of distributed clusters](https://blog.dask.org/2020/07/23/current-state-of-distributed-dask-clusters)) and also integrates very well with much of the PyData stack. Thus in this notebook we explore how we might integrate these libraries.

# Example: Learning the sine function with a neural network

### Setup

Firstly, let's set up a Dask distributed client and install/import the libraries we need.

In [None]:
from dask.distributed import Client, as_completed
client = Client()
client

In [None]:
!pip install hvplot jax jaxlib dm-haiku

In [None]:
import math
from typing import Tuple

import dask
import dask.dataframe as dd
import haiku as hk
import holoviews as hv
import hvplot.pandas
import jax
import jax.numpy as jnp
import pandas as pd
import numpy as np
from dask_ml.preprocessing import StandardScaler
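# Note: `optix` has since moved out of `jax.experimental` into the separate `optax` package,
# so this import only works with the older JAX releases this notebook was written against.
from jax.experimental import optix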

Here we create the example data (representing the sine wave/function), and convert it into a Pandas dataframe.

In [None]:
num_points = 50
x = np.linspace(start=0., stop=2*np.pi, num=num_points)
y = np.sin(x)

df_all = pd.DataFrame({
    "x": x.flatten(),
    "y": y.flatten(),
})

Randomly split the data into train/validation datasets.

In [None]:
train_frac = 0.8
df_train = df_all.sample(frac=train_frac, random_state=42)
df_validation = df_all.drop(labels=df_train.index)

Finally, we convert the Pandas dataframes into Dask dataframes. Here we also define the batch size for use when training the model, which corresponds to the (maximum) number of data points in each partition of a Dask dataframe; the last partition may hold fewer.

In [None]:
batch_size = 32  # i.e. the number of data samples passed to the neural network model

ddf_train = dd.from_pandas(df_train, chunksize=batch_size)
ddf_validation = dd.from_pandas(df_validation, chunksize=batch_size)

Visualising the data confirms our example data does indeed correspond to the sine function:

In [None]:
(df_train.hvplot.scatter(x="x", y="y", label='Training data') * 
 df_validation.hvplot.scatter(x="x", y="y", label='Validation data')).opts(title="Dataset")

### Define model and initialize ready for training

To help the training process, let's standardize `x` (the input or "feature") so that it has roughly mean 0 and standard deviation 1. Here we use `dask_ml`.

In [None]:
scaler = StandardScaler()
ddf_train["scaled_x"] = scaler.fit_transform(ddf_train[["x"]]).x
ddf_validation["scaled_x"] = scaler.transform(ddf_validation[["x"]]).x
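
To make the standardization step concrete, here is a minimal NumPy sketch of the same idea: the mean and standard deviation are estimated on the training data only and then reused for the validation data. The arrays are made-up toy values, and the sketch uses NumPy's default population standard deviation, which is an assumption and may differ slightly from the convention `dask_ml` applies to dataframes.

```python
import numpy as np

# Toy stand-ins for the training and validation features (hypothetical values)
x_train = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
x_validation = np.array([1.5, 2.5])

# "fit": estimate the statistics on the training data only
train_mean = x_train.mean()
train_std = x_train.std()  # population standard deviation (ddof=0)

# "transform": apply the same statistics to both datasets
scaled_train = (x_train - train_mean) / train_std
scaled_validation = (x_validation - train_mean) / train_std
```

Reusing the training statistics for the validation data keeps both datasets on the same scale without leaking validation information into preprocessing.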

Define and initialize a Haiku neural network.

In [None]:
def net_function(x: jnp.ndarray) -> jnp.ndarray:
    net = hk.Sequential([
      hk.Linear(50), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(50), jax.nn.relu,
      hk.Linear(1),
    ])
    pred = net(x)
    return pred

def initialize_net_function():
    net_transform = hk.transform(net_function)
    rng = jax.random.PRNGKey(42)
    num_features = 1
    example_x = jnp.array(np.random.random([batch_size, num_features]))
    params = net_transform.init(rng, example_x)
    return net_transform, params

net_transform, params = initialize_net_function()

Define a mean squared error loss function.

In [None]:
@jax.jit
def loss_function(params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:    
    def mean_squared_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
        loss = jnp.average((y_true - y_pred) ** 2)
        return loss
    
    y_pred: jnp.ndarray = net_transform.apply(params, x)
    loss_value: jnp.ndarray = mean_squared_error(y_true, y_pred)
    return loss_value
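
One thing to watch with a loss like this is array shape. The sketch below, using made-up toy values, shows that the error is computed as intended when `y_true` and `y_pred` are both column vectors, and that a flattened prediction would instead silently broadcast to a square matrix. This is presumably why the training code later selects `df[["y"]]` (two-dimensional) rather than `df["y"]`, though that reading is an assumption.

```python
import jax.numpy as jnp

# Toy targets and predictions as column vectors, shape (3, 1)
y_true = jnp.array([[0.0], [1.0], [2.0]])
y_pred = jnp.array([[0.0], [1.0], [4.0]])

# Matching shapes: element-wise errors are 0, 0 and 2, so the MSE is 4/3
mse_matching = jnp.average((y_true - y_pred) ** 2)

# Mismatched shapes: (3, 1) minus (3,) broadcasts to a (3, 3) matrix,
# so averaging it would no longer give the mean squared error
diff_broadcast = y_true - y_pred.flatten()
```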

Define and initialise an optimizer.

In [None]:
optimizer: optix.InitUpdate = optix.adam(learning_rate=1e-3)
opt_state: optix.OptState = optimizer.init(params)

Define a predict function, and create a wrapper for use with Dask dataframe `.map_partitions()`.

In [None]:
@jax.jit
def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray:
    return net_transform.apply(params, x)


def dask_predict_wrapper(df: pd.DataFrame, params: hk.Params) -> jnp.ndarray:
    scaled_x = jnp.array(df[["scaled_x"]].values)
    return predict(params, scaled_x).flatten()

Let's predict before training our model to see how well it performs (as expected, it performs badly):

In [None]:
ddf_validation["y_pred_no_training"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)

(df_train.hvplot.scatter(x="x", y="y", label='Training data') * 
 ddf_validation.compute().hvplot.scatter(x="x", y="y_pred_no_training", label='Predicted validation data')).opts(title="No training")

### Training

There are a number of ways we can use Dask to help with deep learning/neural network training. For this example we implement the following two approaches, both of which target the situation where the data does not fit in RAM/memory:

1. Train model using Dask as a lazy loader of data.
2. Data-parallel training of deep learning models: Compute gradients in parallel.

Other common use cases for distributed deep learning training include when you are CPU bound, e.g. distributed hyperparameter optimization or distributed training of an ensemble of models. These are omitted here and are left as an exercise for the reader.

#### CASE 1 - Train model using Dask as a lazy loader of data

If we have data that is larger than the RAM we have available, we can use Dask to load and train on it batch-by-batch, one at a time.

Let's first define an update function (the main learning function that updates the model parameters).

In [None]:
@jax.jit
def update(params: hk.Params, opt_state: optix.OptState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[hk.Params, optix.OptState]:
    grads = jax.grad(loss_function)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state

Next we define our training loop.

In [None]:
for epoch_number in range(200):  # num epochs
    for ddf_one_partition in ddf_train.partitions:  # for each batch
        df_one_partition = ddf_one_partition.compute()
        scaled_x = jnp.array(df_one_partition[["scaled_x"]].values)
        y = jnp.array(df_one_partition[["y"]].values)
        params, opt_state = update(params, opt_state, scaled_x, y)
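
The same batch-by-batch pattern can be sketched without Dask or Haiku. The following is a simplified stand-in, not the notebook's method: it fits a linear model with plain gradient descent instead of Adam, uses in-memory slices in place of Dask partitions, and records the loss per epoch so that convergence can be checked. All names (`loss_fn`, `sgd_update`, `LEARNING_RATE`, `loss_history`) are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, chosen for this toy problem


def loss_fn(params, x, y):
    # Mean squared error of a linear model y = w * x + b
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)


@jax.jit
def sgd_update(params, x, y):
    # One gradient-descent step on a single batch
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


# Toy data: 64 points on the line y = 2x + 1, split into batches of 16
x = jnp.linspace(-1.0, 1.0, 64)
y = 2.0 * x + 1.0
batches = [(x[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]

params = (jnp.array(0.0), jnp.array(0.0))
initial_loss = loss_fn(params, x, y)

loss_history = []
for epoch in range(50):
    for x_batch, y_batch in batches:  # one batch at a time, as in the loop above
        params = sgd_update(params, x_batch, y_batch)
    loss_history.append(float(loss_fn(params, x, y)))

final_loss = loss_history[-1]
```

Tracking a loss history like this is one way to decide whether a fixed number of epochs is enough.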

Let's now predict after training our model, and visualise the prediction to see how well it performs.

In [None]:
ddf_validation["y_pred_with_training_CASE1"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)

(df_train.hvplot.scatter(x="x", y="y", label='Training data') * 
 ddf_validation.compute().hvplot.scatter(x="x", y="y_pred_with_training_CASE1", label='Predicted validation data')).opts(title="With training CASE 1")

#### CASE 2 - Data-parallel training of deep learning models: Compute gradients in parallel

First let's reset the model parameters and optimizer state (so that we don't benefit from the training in Case 1).

In [None]:
_, params = initialize_net_function()
opt_state = optimizer.init(params)

In Case 1, we load data only when it's needed (i.e. when we are ready to train on that particular batch of data), preventing potential RAM/memory problems. However, we are not benefiting from the parallel capabilities of Dask to train on multiple batches at the same time.

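For this case, we compute the gradients of the loss function in parallel and pull these back to the client, where the optimizer updates the model. Note that all gradients within an epoch are computed with the parameters from the start of that epoch, and are then applied one after another as they arrive.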

In [None]:
@jax.jit
def compute_grads(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> hk.Params:  # grads mirror the structure of params
    grads = jax.grad(loss_function)(params, x, y)
    return grads

def dask_compute_grads_one_partition_wrapper(ddf: dd.DataFrame, params: hk.Params) -> hk.Params:
    scaled_x = jnp.array(ddf[["scaled_x"]].values)
    y = jnp.array(ddf[["y"]].values)
    grads = compute_grads(params, scaled_x, y)
    return grads

The training loop now looks like:

In [None]:
for epoch_number in range(200):  # num epochs
    futures = []
    for ddf_one_partition in ddf_train.partitions:
        # Compute the gradients in parallel
        futures.append(client.submit(dask_compute_grads_one_partition_wrapper, ddf_one_partition, params))
    
    for future, grads in as_completed(futures, with_results=True):
        # Bring the gradients back to the client, and update the model with the optimizer on the client
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optix.apply_updates(params, updates)
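
A common alternative to applying each partition's gradients one after another is to average them and perform a single optimizer update per round. The sketch below is not what the loop above does; it only illustrates, on a toy linear model with made-up names, that averaging the gradients of equal-sized batches reproduces the gradient of the full dataset. With unequal batch sizes the average would need to be weighted by the number of rows in each batch.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Mean squared error of a toy linear model
    return jnp.mean((params["w"] * x + params["b"] - y) ** 2)


params = {"w": jnp.array(0.5), "b": jnp.array(-0.2)}
x = jnp.linspace(0.0, 1.0, 8)
y = jnp.sin(x)

# Per-batch gradients, as each worker would compute them (two batches of 4)
batch_grads = [
    jax.grad(loss_fn)(params, x[i:i + 4], y[i:i + 4]) for i in (0, 4)
]

# Average the gradients leaf by leaf across batches
mean_grads = jax.tree_util.tree_map(lambda *g: sum(g) / len(g), *batch_grads)

# Gradient over the full dataset, for comparison
full_grads = jax.grad(loss_fn)(params, x, y)
```

The averaged gradients could then be passed to a single `optimizer.update` call per round instead of one call per partition.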

Let's check our predictions.

In [None]:
ddf_validation["y_pred_with_training_CASE2"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)

(df_train.hvplot.scatter(x="x", y="y", label='Training data') * 
 ddf_validation.compute().hvplot.scatter(x="x", y="y_pred_with_training_CASE2", label='Predicted validation data')).opts(title="With training CASE 2")