In [1]:
# Uncomment to run the notebook in Colab
# ! pip install -q "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install -q --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [2]:
# check available devices
import jax

In [3]:
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()



jax backend cpu


[CpuDevice(id=0)]

# 🎛 The 3-steps workflow 🎛

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eserie/wax-ml/blob/main/docs/notebooks/04_The_three_steps_workflow.ipynb)

Being able to execute a JAX function on a dataframe in a single step,
with a single line of code, thanks to the WAX-ML accessors, is already very useful.

WAX-ML's 1-step stream API works like this:
```python
<data-container>.stream(...).apply(...)
```

But this is not optimal because, under the hood, there are mainly three costly steps:
- (1) (synchronize | data tracing | encode): make the data "JAX ready"
- (2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.
- (3) (format): convert data back to pandas/xarray/numpy format.
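
To make the three steps concrete, here is a minimal sketch in plain JAX and NumPy, independent of WAX-ML. The function `expanding_mean` and the toy data are illustrative assumptions, not part of the WAX-ML API; the sketch only shows where each step starts and ends.

```python
import jax
import jax.numpy as jnp
import numpy as onp

# Step (1) encode: move host data (NumPy here) into JAX arrays.
raw = onp.arange(12, dtype="float32").reshape(4, 3)  # 4 time steps, 3 columns
x = jnp.asarray(raw)


# Step (2) compile and execute: jit a pure function and run it.
@jax.jit
def expanding_mean(x):
    # Running mean along the time axis (axis 0).
    counts = jnp.arange(1, x.shape[0] + 1, dtype=x.dtype)[:, None]
    return jnp.cumsum(x, axis=0) / counts


out = expanding_mean(x)

# Step (3) format: bring the result back to a host-side NumPy array,
# which could then be wrapped in a pandas or xarray container.
result = onp.asarray(out)
```

Keeping the three steps separate means that step (2) can be rerun many times without paying for steps (1) and (3) again.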

With the `wax.stream` primitives, it is quite easy to explicitly split the 1-step workflow
into a 3-step workflow.

This gives you full control over each step and lets you iterate on each one independently.

Iterating on step (2), the "calculation step", is particularly useful when
you are doing research.
You can then take full advantage of the JAX primitives, especially the `jit` primitive.

Let's illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.

## Imports

In [4]:
import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr

from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.external.eagerpy import convert_to_tensors
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll

register_wax_accessors()

## Performance on big dataframes

### Generate data

In [5]:
T = 1.0e5
N = 1000

In [6]:
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)

CPU times: user 2.46 s, sys: 211 ms, total: 2.67 s
Wall time: 2.67 s


### Pandas EWMA

In [7]:
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()

CPU times: user 2.77 s, sys: 444 ms, total: 3.21 s
Wall time: 3.21 s
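
For reference, with `adjust=True` (the pandas default) the exponentially weighted mean at each time step is a weighted average of all values seen so far, with weights `(1 - alpha) ** i` for the value observed `i` steps earlier. A minimal pure-Python sketch of that recurrence, ignoring missing values, could look like the following; `ewma_adjusted` is an illustrative helper, not a pandas or WAX-ML function.

```python
def ewma_adjusted(xs, alpha):
    """Adjusted EWMA: weighted mean with weights (1 - alpha) ** i on past values."""
    num, den, out = 0.0, 0.0, []
    for x in xs:
        # Decay the running weighted sum and the running sum of weights,
        # then add the new observation with weight 1.
        num = x + (1.0 - alpha) * num
        den = 1.0 + (1.0 - alpha) * den
        out.append(num / den)
    return out


values = ewma_adjusted([1.0, 2.0, 3.0], alpha=0.1)
```

The first output equals the first input, and later outputs are pulled toward the most recent values.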


### WAX-ML EWMA

In [8]:
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()

CPU times: user 2.06 s, sys: 390 ms, total: 2.45 s
Wall time: 2.4 s


It's a little faster, but not that much faster...

### WAX-ML EWMA (without format step)

Let's disable the final formatting step (the output is now in raw JAX format):

In [9]:
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()

CPU times: user 444 ms, sys: 245 ms, total: 689 ms
Wall time: 687 ms


In [10]:
type(df_ewma_wax_no_format)

jaxlib.xla_extension.DeviceArray

Let's check the device on which the calculation was performed (if you have a GPU available, this should be `GpuDevice`; otherwise it will be `CpuDevice`):

In [11]:
df_ewma_wax_no_format.device()

CpuDevice(id=0)

That's better! In fact, as shown below,
the final formatting step has a performance problem.
See WEP3 for a proposal to improve the formatting step.

### Generate data (in dataset format)

The WAX-ML `Stream` object works on datasets.
So let's transform the `DataFrame` into an xarray `Dataset`:

In [12]:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")

## Step (1) (synchronize | data tracing | encode)

In this step, WAX-ML does the following:
- "data tracing": prepare the indices for fast access in the JAX function `access_data`
- synchronize streams if there are several of them.
  This functionality has options: `freq`, `ffills`
- encode and convert data from NumPy to JAX: use encoders for `datetime64` and `string_`
  dtypes. Be aware that by default JAX works in float32
  (see [JAX's Common Gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) to work in float64).
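
As an illustration of the encoding idea, the sketch below turns `datetime64` values into integers that a numerical backend can handle, and decodes them afterwards. This is only an assumption about what such an encoder might do; WAX-ML's actual encoders may work differently.

```python
import numpy as onp

times = onp.array(
    ["1970-01-01T00:00:00", "1970-01-01T00:00:01"], dtype="datetime64[s]"
)

# Encode: represent the timestamps as integer seconds since the epoch.
encoded = times.astype("int64")

# Decode: restore the original dtype after the computation.
decoded = encoded.astype("datetime64[s]")
```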

The function `Stream.prepare` implements Step (1).
It prepares a function that wraps the input function with the actual data and indices
in a pair of pure functions (a Haiku `TransformedWithState` tuple).
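
The idea of wrapping data and a user function in a pair of pure functions can be sketched without Haiku. In the simplified example below, `InitApply` and `prepare` are hypothetical names used for illustration only; Haiku's real `init` and `apply` also carry `params` and an `rng` argument, which are omitted here.

```python
from collections import namedtuple

# Hypothetical stand-in for Haiku's (init, apply) pair.
InitApply = namedtuple("InitApply", ["init", "apply"])


def prepare(data, func):
    """Close over `data` and return pure init/apply functions indexed by step."""

    def init(step):
        # Initial state: here, simply a counter of processed steps.
        return {"count": 0}

    def apply(state, step):
        # Access the data for this step, apply the user function,
        # and return the output together with the updated state.
        output = func(data[step])
        return output, {"count": state["count"] + 1}

    return InitApply(init, apply)


transform = prepare([1.0, 2.0, 3.0], lambda x: 2.0 * x)
state = transform.init(0)
output, state = transform.apply(state, 1)
```

Because the state is passed in and returned explicitly, `apply` has no side effects, which is what makes such functions suitable for `jax.jit`.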

In [13]:
%%time
stream = dataframe.wax.stream()

CPU times: user 18 µs, sys: 2 µs, total: 20 µs
Wall time: 22.9 µs


Define our custom function to be applied on a dict of arrays
having the same structure as the original dataset:

In [14]:
def my_ewma_on_dataset(dataset):
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

In [17]:
transform_dataset, jxs = stream.prepare(dataset, my_ewma_on_dataset)

Let's define the initial parameters and state of the transformation we
will apply.

### Init params and state

In [18]:
from wax.unroll import init_params_state

In [19]:
rng = jax.random.PRNGKey(42)
params, state = init_params_state(transform_dataset, rng, jxs)

In [20]:
params

FlatMapping({'ewma': FlatMapping({'alpha': DeviceArray(0.1, dtype=float32)})})

In [21]:
assert state["ewma"]["count"].shape == (N,)
assert state["ewma"]["mean"].shape == (N,)

## Step (2) (compile | code tracing | execution)

In this step we:
- prepare a pure function (with
  [Haiku's transform mechanism](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-transforms))
  by defining a "transformation" function which:
    - accesses the data
    - applies another transformation, here: EWMA
- compile it with `jax.jit`
- perform code tracing and execution (the last line):
    - unroll the transformation on "steps" `xs` (a `np.arange` vector).
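
The unrolling over steps can be sketched in plain JAX with `jax.lax.scan`. The example below is a simplified illustration written for this purpose, not the implementation of `dynamic_unroll`: the names `step_fn` and `unroll` and the toy data are assumptions, and the state update is an adjusted EWMA recurrence without missing-value handling.

```python
import jax
import jax.numpy as jnp

data = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)  # 4 steps, 3 columns
alpha = 0.1


def step_fn(state, step):
    # Access the data for this step, then update the adjusted-EWMA state.
    x = data[step]
    num = x + (1.0 - alpha) * state["num"]
    den = 1.0 + (1.0 - alpha) * state["den"]
    return {"num": num, "den": den}, num / den


@jax.jit
def unroll(xs):
    # Scan step_fn over the step indices, carrying the state along.
    init = {"num": jnp.zeros(3), "den": jnp.zeros(3)}
    return jax.lax.scan(step_fn, init, xs)


xs = jnp.arange(data.shape[0])  # the "steps" vector
final_state, outputs = unroll(xs)
```

The first call traces and compiles `unroll`; later calls with inputs of the same shape and dtype reuse the compiled code.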

In [22]:
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

In [23]:
outputs.device()

CpuDevice(id=0)

Once it has been compiled and "traced" by JAX, the function is much faster to execute:

In [24]:
%%timeit
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

1.55 s ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 1.63 s, sys: 82.7 ms, total: 1.72 s
Wall time: 1.56 s


In the run shown here, this is about 2x faster than the pandas implementation (1.55 s versus 3.21 s)!

(This factor is obtained by measuring the execution with `%timeit`.
We don't know why, but when a code cell is executed only once, the execution time can vary a lot, and we have observed some executions with a speed-up of 100x.)
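
One possible contributor to such variability, which this notebook does not confirm, is JAX's asynchronous dispatch: a jitted call can return before the computation has finished, so a single timing may stop the clock too early. A minimal sketch of a timing that waits for the result is shown below; the function `compute` and the array sizes are illustrative assumptions.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    return jnp.cumsum(x, axis=0)


x = jnp.ones((1000, 100))
compute(x).block_until_ready()  # warm-up call: triggers tracing and compilation

start = time.perf_counter()
out = compute(x)  # may return before the computation finishes
out.block_until_ready()  # wait for the result before stopping the clock
elapsed = time.perf_counter() - start
```

Separating the warm-up call from the timed call also keeps compilation time out of the measurement.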

### Manually prepare the data and manage the device

To manage the device on which the computations take place,
we need even more control over the execution flow.
Instead of calling `stream.prepare` to build the `transform_dataset` function,
we can do it ourselves by:
- using the `stream.trace_dataset` function
- converting the NumPy data to JAX ourselves
- putting the data on the device we want.

In [28]:
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")

We explicitly put the data on the CPU (this is not needed if you only have CPUs):

In [29]:
from jax.tree_util import tree_leaves, tree_map

cpus = jax.devices("cpu")
jnp_data, jnp_index, jxs = tree_map(
    lambda x: jax.device_put(x, cpus[0]), (jnp_data, jnp_index, jxs)
)
print("data copied to CPU device.")

data copied to CPU device.


We now have "JAX-ready" data for later fast access.

Let's define the transformation that wraps the actual data and indices in a pair of
pure functions:

In [30]:
%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
    dataset = tree_access_data(jnp_data, jnp_index, step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

CPU times: user 149 µs, sys: 1 µs, total: 150 µs
Wall time: 156 µs


And we can call it as before:

In [31]:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 387 ms, sys: 240 ms, total: 627 ms
Wall time: 544 ms


In [32]:
outputs.device()

CpuDevice(id=0)

## Step (3) (format)

Let's come back to pandas/xarray:

In [33]:
%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)

CPU times: user 766 ms, sys: 8.93 ms, total: 775 ms
Wall time: 774 ms


It's quite slow (see the WEP3 enhancement proposal).

## GPU execution

Let's look at execution on a GPU:

In [34]:
try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
    GPU_AVAILABLE = True
except RuntimeError as err:
    print(err)
    GPU_AVAILABLE = False

Unknown backend gpu. Available: ['interpreter', 'cpu']


Let's check that our data is on the GPU (in the run shown here no GPU was available, so the outputs below report `CpuDevice`):

In [35]:
tree_leaves(jnp_data)[0].device()

CpuDevice(id=0)

In [36]:
tree_leaves(jnp_index)[0].device()

CpuDevice(id=0)

In [37]:
jxs.device()

CpuDevice(id=0)

In [38]:
%%time
if GPU_AVAILABLE:
    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 8.82 µs


Let's redefine our function `transform_dataset`, explicitly passing the `device` option to `jax.jit`:

In [39]:
%%time
if GPU_AVAILABLE:

    @hk.transform_with_state
    def transform_dataset(step):
        dataset = tree_access_data(jnp_data, jnp_index, step)
        return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])

    transform_dataset = type(transform_dataset)(
        transform_dataset.init, jax.jit(transform_dataset.apply, device=gpus[0])
    )

    rng = next(hk.PRNGSequence(42))
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 7.15 µs
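
Another way to steer where a computation runs is to commit its inputs to a device with `jax.device_put`; a jitted function called on committed inputs then runs on that device. The sketch below is a minimal illustration with a CPU fallback, using a toy function `scale` that is not part of this notebook's workflow.

```python
import jax
import jax.numpy as jnp

# Pick a GPU when one is visible to JAX, otherwise fall back to the CPU.
try:
    device = jax.devices("gpu")[0]
except RuntimeError:
    device = jax.devices("cpu")[0]


@jax.jit
def scale(x):
    return 2.0 * x


# Commit the input to the chosen device; the jitted call then runs there.
x = jax.device_put(jnp.arange(4.0), device)
y = scale(x)
```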


In [44]:
outputs.device()

CpuDevice(id=0)

In [45]:
%%timeit
if GPU_AVAILABLE:
    outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

17.5 ns ± 0.151 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
