# 🎛 The 3-step workflow 🎛

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/eserie/wax-ml/blob/main/docs/notebooks/04_The_three_steps_workflow.ipynb)

Thanks to WAX accessors, you can already run a JAX function on a dataframe in a single step
and with a single line of code.

WAX's 1-step stream API works like this:
```python
.stream(...).apply(...)
```

But this is not optimal because, under the hood, there are mainly three costly steps:
- (1) (synchronize | data tracing | encode): make the data "JAX ready"
- (2) (compile | code tracing | execution): compile and optimize a function for XLA, execute it.
- (3) (format): convert data back to pandas/xarray/numpy format.

With the `wax.stream` primitives, it is quite easy to explicitly split the 1-step workflow
into a 3-step workflow.

This will allow the user to have full control over each step and iterate on each one.
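
As a rough illustration of this split, here is a minimal sketch using only pandas and NumPy. The function names `encode`, `compute` and `format_output` are hypothetical and are not part of the WAX API, and the calculation is a placeholder (a running mean). They only mirror steps (1) to (3) so that each one can be timed and rerun on its own.

```python
import numpy as np
import pandas as pd


def encode(df):
    """Step (1): pull raw arrays out of the dataframe (hypothetical helper)."""
    return df.to_numpy(dtype=np.float32), df.index.to_numpy(), df.columns


def compute(values):
    """Step (2): the calculation; a placeholder running mean along time."""
    counts = np.arange(1, len(values) + 1, dtype=values.dtype)[:, None]
    return np.cumsum(values, axis=0) / counts


def format_output(values, index, columns):
    """Step (3): wrap the raw result back into a dataframe."""
    return pd.DataFrame(values, index=index, columns=columns)


df = pd.DataFrame(
    np.random.normal(size=(100, 3)),
    index=pd.date_range("1970", periods=100, freq="s"),
)
values, index, columns = encode(df)  # done once
result = compute(values)  # the step you iterate on
df_out = format_output(result, index, columns)  # done once at the end
```

With this layout, changing the calculation only requires rerunning `compute`; the encoding and formatting costs are not paid again.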

It is actually very useful to iterate on step (2), the "calculation step" when
you are doing research.
You can then take full advantage of the JAX primitives, especially the `jit` primitive.

Let's illustrate how to reimplement WAX EWMA yourself with the WAX 3-step workflow.

## Imports

In [1]:
from functools import partial

import haiku as hk
import numpy as onp
import pandas as pd
import xarray as xr
from eagerpy import convert_to_tensors

from wax.accessors import register_wax_accessors
from wax.compile import jit_init_apply
from wax.format import format_dataframe
from wax.modules import EWMA
from wax.stream import tree_access_data
from wax.unroll import dynamic_unroll

register_wax_accessors()

## Performance on big dataframes

### Generate data

In [None]:
T = 1.0e6
N = 1000

In [2]:
%%time
T, N = map(int, (T, N))
dataframe = pd.DataFrame(
    onp.random.normal(size=(T, N)), index=pd.date_range("1970", periods=T, freq="s")
)

CPU times: user 24.8 s, sys: 2.55 s, total: 27.3 s
Wall time: 27.4 s


### Pandas EWMA

In [3]:
%%time
df_ewma_pandas = dataframe.ewm(alpha=1.0 / 10.0).mean()

CPU times: user 47.1 s, sys: 7.02 s, total: 54.1 s
Wall time: 54.6 s
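
For reference, with its default `adjust=True`, pandas computes at each time step a weighted average of all observations so far, where the observation `i` steps back gets the weight `(1 - alpha)**i`. A minimal NumPy sketch of that recursion follows; the function name `ewma_adjusted` is ours, and missing values are not handled.

```python
import numpy as np


def ewma_adjusted(x, alpha):
    """Adjusted EWMA along axis 0, assuming there are no missing values."""
    x = np.asarray(x, dtype=np.float64)
    out = np.empty_like(x)
    num = np.zeros(x.shape[1:])  # running weighted sum of observations
    den = 0.0  # running sum of weights
    for t in range(x.shape[0]):
        num = (1.0 - alpha) * num + x[t]
        den = (1.0 - alpha) * den + 1.0
        out[t] = num / den
    return out


x = np.array([[1.0], [2.0], [3.0]])
y = ewma_adjusted(x, alpha=0.5)
# Last value: (3 + 0.5 * 2 + 0.25 * 1) / (1 + 0.5 + 0.25) = 4.25 / 1.75
```

The state carried from one step to the next is just the pair `(num, den)`, which is what makes this computation easy to express as a step function unrolled over time.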


### WAX EWMA

In [4]:
%%time
df_ewma_wax = dataframe.wax.ewm(alpha=1.0 / 10.0).mean()



CPU times: user 22.7 s, sys: 11.5 s, total: 34.2 s
Wall time: 39.9 s


It's a little faster, but not that much faster...

### WAX EWMA (without format step)

Let's disable the final formatting step (the output is now in raw JAX format):

In [5]:
%%time
df_ewma_wax_no_format = dataframe.wax.ewm(alpha=1.0 / 10.0, format_outputs=False).mean()

CPU times: user 4.63 s, sys: 5.48 s, total: 10.1 s
Wall time: 11.7 s


That's better! In fact, as shown below,
the final formatting step has a performance problem.
See WEP3 for a proposal to improve the formatting step.

### Generate data (in dataset format)

The WAX `Stream` object works on datasets, so we'll move from a dataframe to a dataset.

In [6]:
dataset = xr.DataArray(dataframe).to_dataset(name="dataarray")

## Step (1) (synchronize | data tracing | encode)

In this step, WAX does the following:
- "data tracing": prepare the indices for fast access in the JAX function `access_data`
- synchronize the streams if there are several of them.
  This functionality has options: `freq`, `ffills`
- encode and convert data from numpy to JAX: use encoders for `datetime64` and `string_`
  dtypes. Be aware that by default JAX works in float32
  (see [JAX's Common Gotchas](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) to work in float64).
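
To make the encoding idea concrete, here is a minimal NumPy sketch. It assumes that a datetime index is encoded as integer nanoseconds since the epoch and that floats are downcast to float32. It illustrates the principle only and is not WAX's actual encoder; `encode_datetime` and `decode_datetime` are hypothetical names.

```python
import numpy as np


def encode_datetime(index):
    """Encode datetime64 values as int64 nanoseconds since the epoch."""
    return index.astype("datetime64[ns]").astype(np.int64)


def decode_datetime(encoded):
    """Inverse of encode_datetime."""
    return encoded.astype("datetime64[ns]")


index = np.arange(
    "1970-01-01T00:00:00", "1970-01-01T00:00:05", dtype="datetime64[s]"
)
encoded = encode_datetime(index)  # plain integers, usable as numeric arrays
decoded = decode_datetime(encoded)  # round trip back to datetimes

# Downcasting float64 to float32 keeps only about 7 significant digits.
value64 = np.float64(1.000000001)
value32 = np.float32(value64)
```

The round trip through integers is lossless, whereas the float32 downcast is not: here `value32` is exactly `1.0`.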

In [7]:
%%time
stream = dataframe.wax.stream()
np_data, np_index, xs = stream.trace_dataset(dataset)
jnp_data, jnp_index, jxs = convert_to_tensors((np_data, np_index, xs), "jax")

  lax._check_user_dtype_supported(dtype, "asarray")


CPU times: user 4.41 s, sys: 3.04 s, total: 7.45 s
Wall time: 2.71 s


We now have "JAX-ready" data for later fast access.

## Step (2) (compile | code tracing | execution)

In this step we:
- prepare a pure function (with
  [Haiku's transform mechanism](https://dm-haiku.readthedocs.io/en/latest/api.html#haiku-transforms)):
  define a "transformation" function which:
    - accesses the data
    - applies another transformation, here: EWMA

- compile it with `jax.jit`
- perform code tracing and execution (the last line):
    - unroll the transformation on "steps" `xs` (a `np.arange` vector).

In [8]:
%%time
@jit_init_apply
@hk.transform_with_state
def transform_dataset(step):
    dataset = partial(tree_access_data, jnp_data, jnp_index)(step)
    return EWMA(alpha=1.0 / 10.0, adjust=True)(dataset["dataarray"])


rng = next(hk.PRNGSequence(42))
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, xs)

CPU times: user 16 s, sys: 4.97 s, total: 21 s
Wall time: 26.8 s


Once it has been compiled and "traced" by JAX, the function is much faster to execute:

In [9]:
%%time
outputs, state = dynamic_unroll(transform_dataset, None, None, rng, False, xs)

CPU times: user 16.7 s, sys: 3.76 s, total: 20.5 s
Wall time: 17.6 s


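Depending on the run, this second call can be between 10x and 40x faster than the first one (timings may vary) and up to 130x faster than the pandas implementation. In the run recorded above the gain is more modest: 17.6 s of wall time, versus 26.8 s for the first call and 54.6 s for pandas.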
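
For intuition about what happens in step (2), the same pattern (access the data at a given step, update a state, unroll over all steps) can be sketched in plain JAX with `jax.lax.scan`. This is a simplified stand-in written under our own assumptions, not the implementation of `dynamic_unroll` or of the `EWMA` module; `ewma_unroll` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ewma_unroll(data, xs, alpha=0.1):
    """Adjusted EWMA of `data` along axis 0, unrolled over the steps `xs`."""

    def step_fn(state, step):
        num, den = state
        x = data[step]  # data access by step index
        num = (1.0 - alpha) * num + x  # running weighted sum of observations
        den = (1.0 - alpha) * den + 1.0  # running sum of weights
        return (num, den), num / den

    init = (jnp.zeros(data.shape[1:], data.dtype), jnp.zeros((), data.dtype))
    _, outputs = jax.lax.scan(step_fn, init, xs)
    return outputs


data = jnp.arange(12.0).reshape(4, 3)
xs = jnp.arange(data.shape[0])
outputs = ewma_unroll(data, xs)  # first call: traced and compiled
outputs = ewma_unroll(data, xs)  # later calls reuse the compiled function
```

The first call pays the tracing and compilation cost; calls with inputs of the same shapes and dtypes then reuse the compiled computation.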

## Step (3) (format)

Let's convert the result back to pandas/xarray:

In [10]:
%%time
y = format_dataframe(
    dataset.coords, onp.array(outputs), format_dims=dataset.dataarray.dims
)

CPU times: user 16.7 s, sys: 2.91 s, total: 19.7 s
Wall time: 18.7 s


This step is quite slow (see the WEP3 enhancement proposal).