In [12]:
# Uncomment to run the notebook in Colab
# ! pip install "wax-ml[complete]@git+https://github.com/eserie/wax-ml.git"
# ! pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
# check available devices
import jax

In [77]:
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))
jax.devices()

jax backend cpu


[CpuDevice(id=0)]

# 🎛 The 3-steps workflow 🎛

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eserie/wax-ml/blob/main/docs/notebooks/04_The_three_steps_workflow.ipynb)

Thanks to the WAX-ML accessors, you can already run a JAX function on a dataframe
in a single step, with a single line of code, which is very useful.

WAX-ML's 1-step stream API works like this:
```python
<data-container>.stream(...).apply(...)
```

However, this is not optimal because, under the hood, it hides three costly steps:
- (1) (synchronize | data tracing | encode): make the data "JAX ready"
- (2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.
- (3) (format): convert data back to pandas/xarray/numpy format.

With the `wax.stream` primitives, it is quite easy to explicitly split the 1-step workflow
into a 3-step workflow.

This gives you full control over each step and lets you iterate on each one.

Iterating on step (2), the "calculation step", is especially useful when
you are doing research.
You can then take full advantage of the JAX primitives, especially the `jit` primitive.
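
As a warm-up, the three steps can be sketched with plain JAX and pandas, without any WAX-ML primitive. This is only an illustration under simplifying assumptions: a small toy dataframe, a single stream (so no synchronization), and an unadjusted EWMA written by hand with `jax.lax.scan`. The name `ewma_scan` is hypothetical and is not part of WAX-ML.

```python
import jax
import jax.numpy as jnp
import numpy as onp
import pandas as pd

# Small toy dataframe (the notebook itself uses 1e6 x 1000 values).
index = pd.date_range("1970", periods=100, freq="s")
df = pd.DataFrame(onp.random.normal(size=(100, 3)), index=index)

# Step (1): encode. Move the values to a JAX array (float32 here).
values = jnp.asarray(df.values, dtype=jnp.float32)


# Step (2): compile and execute. Hypothetical hand-written unadjusted EWMA.
@jax.jit
def ewma_scan(xs, alpha=0.1):
    def step(mean, x):
        mean = (1.0 - alpha) * mean + alpha * x
        return mean, mean

    # Starting from xs[0] makes the first output equal to the first input.
    _, means = jax.lax.scan(step, xs[0], xs)
    return means


outputs = ewma_scan(values)

# Step (3): format. Rebuild a dataframe with the original labels.
result = pd.DataFrame(onp.asarray(outputs), index=df.index, columns=df.columns)
```

Keeping the three steps separate means step (2) can be re-run or modified without paying again for encoding and formatting.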

Let's illustrate how to reimplement WAX-ML EWMA yourself with the WAX-ML 3-step workflow.

## Imports

In [1]:
from functools import partial

import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr
from eagerpy import convert_to_tensors

from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll

register_wax_accessors()

## Performance on big dataframes

### Generate data

In [2]:
T = 1.0e6
N = 1000

In [3]:
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)

CPU times: user 24.7 s, sys: 2.41 s, total: 27.1 s
Wall time: 27.1 s


### Pandas EWMA

In [4]:
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()

CPU times: user 47.1 s, sys: 5.99 s, total: 53.1 s
Wall time: 53.3 s
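
For reference, pandas uses `adjust=True` by default, so `ewm(alpha=...).mean()` is a weighted average with weights `(1 - alpha)**i` normalized by their sum. The loop below is a minimal sketch of that formula for a 1-D array without missing values; `ewma_adjusted` is a hypothetical helper written for illustration, and it ignores the NaN handling that pandas performs.

```python
import numpy as onp
import pandas as pd


def ewma_adjusted(x, alpha):
    """Reference loop for the adjusted EWMA of a 1-D array without NaNs."""
    num, den = 0.0, 0.0
    out = onp.empty(len(x))
    for t, value in enumerate(x):
        # Running weighted sum and running sum of weights.
        num = value + (1.0 - alpha) * num
        den = 1.0 + (1.0 - alpha) * den
        out[t] = num / den
    return out


x = onp.array([1.0, 2.0, 3.0, 4.0])
reference = ewma_adjusted(x, alpha=0.1)
expected = pd.Series(x).ewm(alpha=0.1).mean().to_numpy()
```

The recursion carries only two scalars of state per column, which is what makes the computation a natural fit for a compiled, step-by-step unroll.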


### WAX-ML EWMA

In [5]:
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()



CPU times: user 22.2 s, sys: 9.27 s, total: 31.5 s
Wall time: 36.9 s


It's a little faster, but not that much faster...

### WAX-ML EWMA (without format step)

Let's disable the final formatting step (the output is now in raw JAX format):

In [6]:
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()

CPU times: user 4.79 s, sys: 5.48 s, total: 10.3 s
Wall time: 15.2 s


That's better! As shown below, the final formatting step is a performance bottleneck.
See WEP3 for a proposal to improve the formatting step.

### Generate data (in dataset format)

The WAX-ML `Stream` object works on datasets, so we'll move from a dataframe to a dataset.

In [7]:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")

## Step (1) (synchronize | data tracing | encode)

In this step, WAX-ML:
- performs "data tracing": prepares the indices for fast access in the JAX function `access_data`
- synchronizes the streams if there are several of them.
  This functionality has options: `freq`, `ffills`
- encodes and converts data from numpy to JAX, using encoders for `datetime64` and `string_`
  dtypes. Be aware that JAX works in float32 by default
  (see [JAX's Common Gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) to work in float64).
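
To make the encoding idea concrete, here is a minimal numpy sketch of what an encoder/decoder pair for such dtypes could look like. The helper names are hypothetical and this is not the WAX-ML encoder API; it only illustrates turning `datetime64` and string values into numeric codes that a numerical library can carry, and back.

```python
import numpy as onp


# Hypothetical helpers for illustration; not the WAX-ML encoder API.
def encode_datetime64(values):
    """Encode datetime64 values as int64 nanoseconds since the epoch."""
    return values.astype("datetime64[ns]").astype("int64")


def decode_datetime64(codes):
    """Invert encode_datetime64."""
    return codes.astype("int64").astype("datetime64[ns]")


def encode_strings(values):
    """Encode strings as integer codes plus a lookup table."""
    table, codes = onp.unique(values, return_inverse=True)
    return codes, table


def decode_strings(codes, table):
    """Invert encode_strings."""
    return table[codes]


dates = onp.array(
    ["1970-01-01T00:00:01", "1970-01-01T00:00:02"], dtype="datetime64[s]"
)
labels = onp.array(["b", "a", "b"])

date_codes = encode_datetime64(dates)
label_codes, label_table = encode_strings(labels)
```

Note that nanosecond timestamps do not fit in 32-bit integers, so with JAX's default 32-bit mode an encoder would need either a coarser unit or 64-bit mode enabled.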

In [8]:
%%time
stream = dataframe.wax.stream()
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")

  lax._check_user_dtype_supported(dtype, "asarray")


CPU times: user 4.19 s, sys: 2.91 s, total: 7.1 s
Wall time: 2.55 s


We have now "JAX-ready" data for later fast access.

## Step (2) (compile | code tracing | execution)

In this step we:
- prepare a pure function (with
  [Haiku's transform mechanism](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-transforms))
  by defining a "transformation" function which:
    - accesses the data
    - applies another transformation, here: EWMA
- compile it with `jax.jit`
- perform code tracing and execution (the last line):
    - unroll the transformation on "steps" `xs` (a `np.arange` vector).
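
The difference between code tracing and execution can be observed with plain JAX. In the sketch below, a Python side effect inside a jitted function runs only while JAX traces it, so a second call with inputs of the same shape and dtype reuses the compiled code. The function `unroll_ewma` and the list `trace_log` are hypothetical names used for illustration; the WAX-ML `dynamic_unroll` helper is not involved.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by a Python side effect that only runs during tracing


@jax.jit
def unroll_ewma(xs):
    # Executed by Python at trace time only, not on later cached calls.
    trace_log.append("traced")

    def step(mean, x):
        mean = 0.9 * mean + 0.1 * x
        return mean, mean

    # Unroll the step function over the leading axis of xs.
    _, means = jax.lax.scan(step, jnp.zeros_like(xs[0]), xs)
    return means


xs = jnp.ones((1000, 4))
first = unroll_ewma(xs)   # traces and compiles, then executes
second = unroll_ewma(xs)  # same shape and dtype: reuses the compiled code
print(f"number of traces: {len(trace_log)}")
```

This is why the first call of a compiled function includes a one-off cost that later calls do not pay.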

In [26]:
%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
    dataset = partial(tree_access_data, jnp_data, jnp_index)(step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])


rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 2.97 s, sys: 4.39 s, total: 7.36 s
Wall time: 9.94 s


Once it has been compiled and "traced" by JAX, the function is much faster to execute:

In [27]:
%%timeit
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

15.9 s ± 119 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 387 ms, sys: 741 ms, total: 1.13 s
Wall time: 367 ms


This is 3x faster than the pandas implementation!

(The 3x factor is obtained by measuring the execution with `%%timeit`.
We don't know why, but when a code cell is executed once at a time, the execution time can vary a lot, and we sometimes observe executions with a speed-up of 100x.)
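
One possible contributor to such run-to-run variation, stated here as an assumption rather than a diagnosis of this notebook, is that JAX can dispatch computations asynchronously: a call may return before the result is ready, so a timer can stop too early. A common precaution is to warm up the function once and to call `block_until_ready()` on the result inside the timed region. The function `cumulative_mean` below is a hypothetical stand-in for the unrolled transformation.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def cumulative_mean(xs):
    # Running mean along the time axis; stands in for a heavier computation.
    counts = jnp.arange(1, xs.shape[0] + 1)[:, None]
    return jnp.cumsum(xs, axis=0) / counts


xs = jnp.ones((10_000, 100))
cumulative_mean(xs).block_until_ready()  # warm-up: trace and compile once

start = time.perf_counter()
result = cumulative_mean(xs).block_until_ready()  # wait for the actual result
elapsed = time.perf_counter() - start
print(f"steady-state wall time: {elapsed:.4f} s")
```

When the output is a pytree rather than a single array, the same precaution applies to each of its leaves.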

In [31]:
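53.3 * 1000 / 367  # speed-up from the single %%time run (about 145x, not displayed)
53.3 / 15.9  # speed-up from the %%timeit measurement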

3.352201257861635

## Step (3) (format)

Let's come back to pandas/xarray:

In [11]:
%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)

CPU times: user 7.42 s, sys: 5.46 s, total: 12.9 s
Wall time: 6.51 s


It's quite slow (see WEP3 enhancement proposal).
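
Conceptually, the format step re-attaches the original labels to the raw array of outputs. For the simple two-dimensional case this can be sketched with pandas alone; `to_dataframe` is a hypothetical helper for illustration and is not a replacement for `format_dataframe`, which also handles dataset coordinates and dimensions.

```python
import numpy as onp
import pandas as pd


def to_dataframe(values, index, columns):
    """Wrap a 2-D array of outputs in a dataframe with the given labels."""
    return pd.DataFrame(onp.asarray(values), index=index, columns=columns)


index = pd.date_range("1970", periods=5, freq="s")
columns = pd.RangeIndex(3)
raw = onp.arange(15, dtype="float32").reshape(5, 3)  # stands in for JAX outputs
frame = to_dataframe(raw, index, columns)
```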

## GPU execution

Let's look at execution on a GPU.

In [39]:
import jax
from jax.tree_util import tree_leaves, tree_map  # used in the cells below

In [40]:
cpus = jax.devices("cpu")

In [79]:
try:
    gpus = jax.devices("gpu")
    jnp_data, jnp_index, jxs = tree_map(
        lambda x: jax.device_put(x, gpus[0]), (jnp_data, jnp_index, jxs)
    )
    print("data copied to GPU device.")
except RuntimeError as err:
    print(err)

Unknown backend gpu. Available: ['interpreter', 'cpu']
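
When no GPU is visible, as in the output above, `jax.devices("gpu")` raises a `RuntimeError`. A small helper can make the fallback explicit so that later cells always have a valid device to work with. This is a minimal sketch; `pick_device` is a hypothetical name, not a JAX or WAX-ML function.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def pick_device():
    """Return the first GPU if JAX can see one, otherwise the first CPU."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        return jax.devices("cpu")[0]


device = pick_device()
data = {"dataarray": jnp.arange(6.0).reshape(3, 2)}

# Copy every leaf of the pytree to the selected device.
data_on_device = jax.tree_util.tree_map(lambda x: jax.device_put(x, device), data)
print(f"running on {device.platform}")
```

With such a helper, the same notebook cells can run unchanged on machines with or without a GPU.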


Let's check that our data is on the GPUs:

In [80]:
tree_leaves(jnp_data)[0].device()

CpuDevice(id=0)

In [81]:
tree_leaves(jnp_index)[0].device()

CpuDevice(id=0)

In [82]:
jxs.device()

CpuDevice(id=0)

In [83]:
%%time
rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 700 ms, sys: 1.47 s, total: 2.17 s
Wall time: 592 ms


Let's redefine our function `transform_dataset`, explicitly passing the `device` option to `jax.jit`.

In [84]:
%%time


@hk.transform_with_state
def transform_dataset(step):
    dataset = partial(tree_access_data, jnp_data, jnp_index)(step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])


transform_dataset = type(transform_dataset)(
    transform_dataset.init, jax.jit(transform_dataset.apply, device=gpus[0])
)

rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)

CPU times: user 647 ms, sys: 1.37 s, total: 2.02 s
Wall time: 562 ms


In [None]:
%%timeit
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, jxs)