### Slide: Problem Statement

In many general-purpose implementations of reinforcement learning (RL) algorithms
it is necessary to work with **structured runtime data**.

Common use cases include:

> stacking or slicing **homogeneous data** stored in a deeply **nested object**

> recording and retrieving pieces of **structured data** from some storage

<br>

### Slide: Nested Objects

A nested object is defined _recursively_ as a __container__ of
_nested objects_ or a __non-container__ leaf (_scalar_).

```text
NestedObject = Container[NestedObject] | Any
Container = dict[key=Any] | list | tuple | namedtuple
```

* `namedtuple` is a _tuple_ that behaves essentially like a _mapping_ with predefined _string_ keys
* numpy **arrays** and torch or tensorflow **tensors** are treated as **leaves** (scalars).
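Below is a minimal stdlib sketch of this recursive definition. It assumes that only `dict`, `list` and `tuple` count as containers; `namedtuple` is covered because it is a `tuple` subclass. `is_leaf` and `flatten` are hypothetical helpers and are not part of any library discussed here. For simplicity, dicts are walked in insertion order, whereas dm-tree and jax sort the keys.

```python
from collections import namedtuple


def is_leaf(obj):
    # containers per the definition: dict, list, tuple (namedtuple is a
    # tuple subclass); everything else, e.g. numbers, strings, numpy
    # arrays or tensors, is treated as a leaf
    return not isinstance(obj, (dict, list, tuple))


def flatten(obj):
    """Collect the leaves of a nested object in depth-first order."""
    if is_leaf(obj):
        return [obj]
    # dicts are walked in insertion order (a simplifying assumption)
    children = obj.values() if isinstance(obj, dict) else obj
    leaves = []
    for child in children:
        leaves.extend(flatten(child))
    return leaves


Point = namedtuple("Point", "x y")
example = {"a": [1, (2, 3)], "b": Point(x=4, y={"c": 5})}
print(flatten(example))  # [1, 2, 3, 4, 5]
```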

<img src="assets/nested-object.png"
     alt="an example dict with a list, a tuple and another dict inside"
     title="an example of a nested object">

<img src="assets/nested-rules.png"
     alt="a dict, a tuple and a list are different nested objects. dicts with different sets of keys and lists or tuples of different size are unequal nested objects"
     title="Constraints on the structure of nested objects">
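The constraints pictured above can be written as a short recursive check. `same_structure` is a hypothetical helper and a strict sketch under our reading of the rules:

* container types must match exactly;
* dicts must share their key sets;
* lists and tuples must share their length;
* leaf values are never compared.

```python
def same_structure(a, b):
    """Strictly compare the layouts of two nested objects (a sketch)."""
    a_leaf = not isinstance(a, (dict, list, tuple))
    b_leaf = not isinstance(b, (dict, list, tuple))
    if a_leaf or b_leaf:
        # a leaf matches any other leaf, but never a container
        return a_leaf and b_leaf

    # dict, list and tuple (and distinct namedtuple classes) never match
    if type(a) is not type(b):
        return False

    if isinstance(a, dict):
        # dicts must have the same set of keys (order is irrelevant)
        return a.keys() == b.keys() and all(
            same_structure(a[k], b[k]) for k in a
        )

    # lists and tuples must have the same number of elements
    return len(a) == len(b) and all(
        same_structure(x, y) for x, y in zip(a, b)
    )


print(same_structure({"a": [1, 2]}, {"a": [3, 4]}))  # True
print(same_structure([1, 2], (1, 2)))  # False
```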

Other names are `structure` (dm-tree, tensorflow) or `pytree` (jax):

> `pytree` containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves

Or, in [dm-tree](https://github.com/deepmind/tree/blob/master/docs/api.rst):
```text
Structure = Union[
    Any,
    Sequence['Structure'],
    Mapping[Any, 'Structure'],
    'AnyNamedTuple',
]
```

<br>

### Slide: Solutions

<img src="assets/how-it-works.png"
     alt="The recursive application of a callable f to the structures s-1, ..., s-n"
     title="Applying a callable to nested data structures">
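The recursion in the figure can be sketched in a few lines of plain Python. `apply_sketch` is hypothetical and for illustration only: it mirrors the idea, not the code of any of the libraries compared below. It performs no structural checks, so `zip` silently truncates sequences of unequal length.

```python
from collections import namedtuple


def apply_sketch(func, *objects):
    """Call `func` on corresponding leaves of nested objects s_1, ..., s_n."""
    first = objects[0]  # the first object dictates the output layout
    if isinstance(first, dict):
        return {k: apply_sketch(func, *(o[k] for o in objects)) for k in first}

    if isinstance(first, (list, tuple)):
        children = [apply_sketch(func, *items) for items in zip(*objects)]
        if hasattr(first, "_fields"):  # namedtuples take positional fields
            return type(first)(*children)
        return type(first)(children)

    # the first object is a leaf: make one n-ary call f(v_1, ..., v_n)
    return func(*objects)


Pair = namedtuple("Pair", "lo hi")
s1 = {"a": [1, 2], "b": Pair(3, 4)}
s2 = {"a": [10, 20], "b": Pair(30, 40)}
print(apply_sketch(lambda x, y: x + y, s1, s2))
# {'a': [11, 22], 'b': Pair(lo=33, hi=44)}
```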

Desirable properties
* **reuse and generality**: remain oblivious to the structure

* **speed**: do it many-Many-MANY times with little overhead

* **simplicity**: intuitive interface and support of built-in containers

<br>

### Slide: Solution 1 -- [dm-tree](https://github.com/deepmind/tree.git)

In [None]:
import tree as dm_tree

+ `abc.*` friendly
+ fast c/c++ implementation (pybind11)
+ standalone library

```python
dm_tree.map_structure(
    #  A callable that accepts as many arguments as there are structures.
    func,
    
    # Arbitrarily nested structures of the same layout.
    *structures,

    # check if the types of iterables within the structures match
    check_types=True,
)
```

Maps `func` through given structures.

In [None]:
# help(dm_tree.map_structure)

<br>

### Slide: Solution 2 -- [nest](https://www.tensorflow.org/api_docs/python/tf/nest) from tensorflow

In [None]:
import tensorflow.nest as tf_nest

- not `abc.*` friendly
+ fast c/c++ implementation
- ships with tensorflow

```python
tf_nest.map_structure(
    # A callable that accepts as many arguments as there are structures
    func,

    # scalar, or tuple or dict or list of constructed scalars
    #  and/or other tuples/lists, or scalars
    *structure,

    # check if the types of iterables within the structures match
    check_types=True,

    # expand composite tensors such as `tf.sparse.SparseTensor`
    #  and `tf.RaggedTensor` into their component tensors
    expand_composites=False,
)
```
Applies `func(x[0], x[1], ...)` where `x[i]` is an entry in `structure[i]`.

All structures in `structure` must have the same arity, and the return value
will contain results with the same structure layout.

In [None]:
# help(tf_nest.map_structure)

<br>

### Slide: Solution 3 -- [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) from JAX

In [None]:
import jax.tree_util as pytree

+ supports arbitrary [extensions](https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), although not `abc.*` friendly
+ fast xla/c/c++ implementation (jaxlib)
- part of jax

```python
pytree.tree_map(
    # callable to be applied at the corresponding leaves of the pytrees
    f,

    # pytrees to be mapped over, with each leaf providing a positional argument to ``f``.
    *pytrees,
)
```

Maps a multi-input function over pytree args to produce a new pytree.
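One way to organize such a map is the following:

1. flatten every tree into its list of leaves;
2. map over the aligned leaves;
3. rebuild the output from the first tree's container skeleton.

To our understanding, this is the approach behind `jax.tree_util.tree_map`. The stdlib sketch below illustrates the strategy. `tree_flatten`, `tree_unflatten` and `tree_map` here are hypothetical re-implementations. The real JAX functions use a `PyTreeDef` object rather than a skeleton with `None` placeholders, and they treat `None` as an empty container.

```python
def _children(node):
    # sub-objects of a container, or None if `node` is a leaf
    if isinstance(node, dict):
        return list(node.values())
    if isinstance(node, (list, tuple)):
        return list(node)
    return None


def _rebuild(node, children):
    # a new container of the same kind as `node` holding `children`
    if isinstance(node, dict):
        return dict(zip(node.keys(), children))
    if hasattr(node, "_fields"):  # namedtuple
        return type(node)(*children)
    return type(node)(children)


def tree_flatten(tree):
    """Return (leaves, skeleton); the skeleton keeps only the containers."""
    kids = _children(tree)
    if kids is None:
        return [tree], None  # None marks a leaf slot in the skeleton
    leaves, subs = [], []
    for kid in kids:
        sub_leaves, sub_skeleton = tree_flatten(kid)
        leaves.extend(sub_leaves)
        subs.append(sub_skeleton)
    return leaves, _rebuild(tree, subs)


def tree_unflatten(skeleton, leaves):
    """Refill the leaf slots of `skeleton` with `leaves`, in order."""
    it = iter(leaves)

    def fill(node):
        kids = _children(node)
        if kids is None:
            return next(it)
        return _rebuild(node, [fill(kid) for kid in kids])

    return fill(skeleton)


def tree_map(f, tree, *rest):
    """Flatten all trees, map `f` over aligned leaves, rebuild `tree`'s layout."""
    leaves, skeleton = tree_flatten(tree)
    others = [tree_flatten(r)[0] for r in rest]
    # NOTE: this sketch does not check that all skeletons agree
    return tree_unflatten(skeleton, [f(*xs) for xs in zip(leaves, *others)])


print(tree_map(lambda x, y: x * y, {"a": [1, 2], "b": 3}, {"a": [4, 5], "b": 6}))
# {'a': [4, 10], 'b': 18}
```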

In [None]:
# help(pytree.tree_map)

<br>

### Slide: Solution 4 -- [plyr](https://github.com/ivannz/plyr.git)

In [None]:
import plyr

- not `abc.*` friendly
+ fast c/c++ implementation (c-python)
+ standalone library

```python
plyr.apply(
    # A callable that accepts as many arguments as there are objects
    func,
    
    # nested objects that supply arguments for `func`
    *objects,

    # (optional) perform structural safety checks
    _safe=True,

    # (optional) whether to make n-ary calls `func(  v_1, ..., v_n  )`,
    #  or unary calls with args packed in a tuple `func( (v_1, ..., v_n) )`
    _star=True,
    
    # (optional) a callable to finalize a rebuilt nested container
    _finalizer=None,

    # keyword arguments passed to each call of `func`
    **kwargs,
)
```

Compute the function using the data from the nested objects.
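The `_star` switch only changes how the leaves reach `func`. The stdlib sketch below shows the two calling conventions and the forwarding of `**kwargs`, under the semantics described in the comments above. `apply_star` is a hypothetical stand-in, not plyr's implementation, and it skips both the `_safe` checks and the `_finalizer` hook.

```python
def apply_star(func, *objects, _star=True, **kwargs):
    """Map `func` over nested objects with n-ary or packed-tuple leaf calls."""
    first = objects[0]
    if isinstance(first, dict):
        return {
            k: apply_star(func, *(o[k] for o in objects), _star=_star, **kwargs)
            for k in first
        }

    if isinstance(first, (list, tuple)):
        kids = [apply_star(func, *xs, _star=_star, **kwargs) for xs in zip(*objects)]
        if hasattr(first, "_fields"):  # namedtuple
            return type(first)(*kids)
        return type(first)(kids)

    if _star:
        return func(*objects, **kwargs)  # func(v_1, ..., v_n, **kwargs)
    return func(objects, **kwargs)  # func((v_1, ..., v_n), **kwargs)


a, b = [1, 2], [10, 20]
print(apply_star(lambda x, y, scale=1: (x + y) * scale, a, b, scale=2))  # [22, 44]
print(apply_star(max, a, b, _star=False))  # [10, 20]
```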

In [None]:
# help(plyr.apply)

A version of `apply` with leaf-broadcasting semantics for _ragged_ nested objects:
> any leaf object is **broadcast deeper into the hierarchy** of the remaining nested structures

```python
plyr.ragged(
    # A callable that accepts as many arguments as there are objects
    func,
    
    # nested objects that supply arguments for `func`
    *objects,

    # (optional) whether to make n-ary calls `func(  v_1, ..., v_n  )`,
    #  or unary calls with args packed in a tuple `func( (v_1, ..., v_n) )`
    _star=True,
    
    # (optional) a callable to finalize a rebuilt nested container
    _finalizer=None,

    # keyword arguments passed to each call of `func`
    **kwargs,
)
```

<img src="assets/ragged-broadcast.png"
     alt="Leaves are broadcasted deeper into the hierarchy"
     title="Broadcasting leaves through the hierarchy">
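Below is a minimal stdlib sketch of the broadcasting rule. It assumes that at every level the first container among the arguments serves as the template, and that every leaf argument is passed unchanged to each of that container's children. `ragged_sketch` is hypothetical, not plyr's implementation. It omits `_star`, `_finalizer`, `**kwargs` and any check that the containers agree. On the inputs used in the Examples section, it yields the result shown in the comment.

```python
def is_leaf(obj):
    return not isinstance(obj, (dict, list, tuple))


def ragged_sketch(func, *objects):
    """Map `func` over nested objects, broadcasting leaves into containers."""
    containers = [o for o in objects if not is_leaf(o)]
    if not containers:
        return func(*objects)  # every argument is a leaf: call `func`

    main = containers[0]  # template for the rebuilt container
    if isinstance(main, dict):
        return {
            # leaves are reused as-is for every key (broadcasting)
            k: ragged_sketch(func, *(o if is_leaf(o) else o[k] for o in objects))
            for k in main
        }

    children = [
        ragged_sketch(func, *(o if is_leaf(o) else o[i] for o in objects))
        for i in range(len(main))
    ]
    if hasattr(main, "_fields"):  # namedtuple
        return type(main)(*children)
    return type(main)(children)


print(ragged_sketch(
    lambda a, b: a + b,
    {"k1": ("a", "b"), "k2": "c"},
    {"k1": "x", "k2": ["y", "z"]},
))
# {'k1': ('ax', 'bx'), 'k2': ['cy', 'cz']}
```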

In [None]:
# help(plyr.ragged)

<br>

### Slide: Speed Benchmark

Stacking or summing vectors in $\mathbb{R}^d$ with $d=100$:
* `16` structures, each being a `16`-element container of `np.array`

Nested stacking speed measurements made on `4.2 GHz Intel Core i7 32 GB 2400 MHz DDR4`
($101$ replications, $250\times$ each).

<div>
<style scoped>
    .dataframe tbody tr th:only-of-type {
        vertical-align: middle;
    }

    .dataframe tbody tr th {
        vertical-align: top;
    }

    .dataframe thead th {
        text-align: right;
    }
</style>
<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>dm-tree</th>
      <th>plyr.apply</th>
      <th>plyr.ragged</th>
      <th>pytree</th>
      <th>tf-nest</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>dict-stack</th>
      <td>3.347 ± 0.18</td>
      <td>2.456 ± 0.06</td>
      <td>2.482 ± 0.09</td>
      <td>2.704 ± 0.04</td>
      <td>3.741 ± 0.17</td>
    </tr>
    <tr>
      <th>dict-sum</th>
      <td>1.857 ± 0.07</td>
      <td>1.040 ± 0.02</td>
      <td>1.050 ± 0.02</td>
      <td>1.282 ± 0.02</td>
      <td>2.272 ± 0.04</td>
    </tr>
    <tr>
      <th>list-stack</th>
      <td>3.083 ± 0.05</td>
      <td>2.389 ± 0.05</td>
      <td>2.392 ± 0.04</td>
      <td>2.616 ± 0.04</td>
      <td>3.572 ± 0.07</td>
    </tr>
    <tr>
      <th>list-sum</th>
      <td>1.654 ± 0.03</td>
      <td>1.018 ± 0.03</td>
      <td>1.000 ± 0.02</td>
      <td>1.220 ± 0.04</td>
      <td>2.144 ± 0.03</td>
    </tr>
  </tbody>
</table>
</div>

* Reported values are `median ± IQR` of the per-call wall time, in units of ${10}^{-4}$ sec.

<br>

## Trunk

### Code for benchmarking

In [None]:
appliers = {
    "plyr.apply": "from plyr import apply as apply",
    "pytree": "from jax.tree_util import tree_map as apply",
    "tf-nest": "from tensorflow.nest import map_structure as apply",
    "dm-tree": "from tree import map_structure as apply",

    # leaf-broadcasting `apply`
    "plyr.ragged": "from plyr import ragged as apply",
}

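A code generator for nested objects of the given `shape`: every entry except the last adds one level of `kind` containers, and the outermost level is always a tuple of `n_args` such structures.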

In [None]:
def generate(setup, *shape, kind='list'):
    if kind == 'dict':
        line = "_x{:02d} = {{j: _x{:02d} for j in range({})}}\n"

    elif kind == 'tuple':
        line = "_x{:02d} = (_x{:02d},) * {}\n"

    elif kind == 'list':
        line = "_x{:02d} = [_x{:02d}] * {}\n"

    else:
        raise ValueError(kind)

    # generate a deeply nested structured data
    code = "_x00 = data\n"
    *lower, n_args = shape
    for lv, d in enumerate(lower, 1):
        code += line.format(lv, lv-1, d)

    # always generate a tuple at the top
    lv = len(shape)
    code += f"_x{lv:02d} = (_x{lv-1:02d},) * {n_args}\n"

    return setup + ("\n\n" if setup else "") + code, f"_x{len(shape):02d}"

The following measures the runtime speed with `timeit`.
It returns the average per-call wall time of the given statement, executed
`n_per_loop` times in each of `n_loops` independent replications.

Roughly
```python
from time import monotonic as clock

n_loops, n_per_loop = 25, 50
timings = []

...  # setup
for i in range(n_loops):
    beg = clock()
    for j in range(n_per_loop):
        ...  # statement

    timings.append((clock() - beg) / n_per_loop)
```
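The standard library's `timeit.repeat` implements this loop. For each replication it runs `setup` once (untimed), times `number` executions of `stmt`, and returns one total per replication. Below is a small stdlib-only illustration of the same per-call normalization. The toy statement is chosen purely for demonstration.

```python
from time import monotonic
from timeit import repeat

n_loops, n_per_loop = 5, 100

# one total wall time per replication, each covering `number` executions
totals = repeat(
    stmt="sum(data)",
    setup="data = list(range(100))",
    number=n_per_loop,
    repeat=n_loops,
    timer=monotonic,
)

# average time per execution, the per-call figure the loop above records
per_call = [total / n_per_loop for total in totals]
print(len(per_call))  # 5
```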

In [None]:
import numpy

def timeit(
    setup,
    *shape,
    kind='list',
    n_loops=25,
    n_per_loop=50,
):
    from timeit import repeat
    from time import monotonic

    setup, var = generate(setup, *shape, kind=kind)
    return numpy.array(
        repeat(
            stmt=f"apply(func, *{var})",
            setup=setup,
            number=n_per_loop,
            repeat=n_loops,
            globals={},
            timer=monotonic,
        )
    ) / n_per_loop

The common setup code

In [None]:
setup = """
import numpy
from numpy import stack

{{apply}}

func = {func}
data = numpy.random.randn(100)
"""

Define the test cases

In [None]:
shape = 16, 16

tests = {
    "list-stack": (
        setup.format(func="lambda *a: stack(a)"),
        shape,
        'list',
    ),
    "dict-stack": (
        setup.format(func="lambda *a: stack(a)"),
        shape,
        'dict',
    ),
    "list-sum": (
        setup.format(func="lambda *a: sum(a)"),
        shape,
        'list',
    ),
    "dict-sum": (
        setup.format(func="lambda *a: sum(a)"),
        shape,
        'dict',
    ),
}

In [None]:
n_loops, n_per_loop = 101, 250

timings = {}
for test, (code, shape, kind) in tests.items():
    for name, apply in appliers.items():
        timings[test, name] = timeit(
            code.format(apply=apply),
            *shape, kind=kind, n_loops=n_loops, n_per_loop=n_per_loop,
        )

Collect into a dataframe

In [None]:
import pandas as pd

columns, quantiles = zip(
    ('q_1', 0.25,),
    ('med', 0.50,),
    ('q_3', 0.75,),
)

results = pd.DataFrame({
    name: dict(zip(columns, numpy.quantile(data, quantiles, axis=-1)))
    for name, data in timings.items()
}).T

Rescale the timings and transform them into the measurement ± spread format

In [None]:
exp = int(numpy.ceil(-numpy.log10(results['med'])).min())

df_pm_speed = (results * 10**exp).apply(
    lambda r: f"{r.med:0.3f} ± {r.q_3 - r.q_1:0.2f}",
    axis='columns',
).unstack(1)
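The cell above reduces each series of replications to its quartiles and prints `median ± IQR` after rescaling by $10^{exp}$. Below is a stdlib-only sketch of the same arithmetic for a single series. `median_pm_iqr` is a hypothetical helper and the sample values are made up. Unlike the cell above, it picks `exp` from this one median rather than the minimum over all tests. `statistics.quantiles(..., method="inclusive")` interpolates linearly between order statistics, as `numpy.quantile` does by default.

```python
import math
from statistics import quantiles


def median_pm_iqr(samples):
    """Return (exp, 'median ± IQR') with values rescaled by 10**exp."""
    # quartiles via linear interpolation between order statistics
    q_1, med, q_3 = quantiles(samples, n=4, method="inclusive")

    # smallest integer exp such that med * 10**exp >= 1
    exp = math.ceil(-math.log10(med))
    scale = 10 ** exp
    return exp, f"{med * scale:0.3f} ± {(q_3 - q_1) * scale:0.2f}"


# made-up per-call timings in seconds
exp, text = median_pm_iqr([1e-4, 2e-4, 3e-4, 4e-4, 5e-4])
print(exp, text)  # 4 3.000 ± 2.00
```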

Display the results

In [None]:
caption = rf"Nested stacking speed measurements ($\times {{10}}^{{-{exp}}}$-sec.)"
latex = df_pm_speed.to_latex(caption=caption)

print(caption)
df_pm_speed.style.highlight_min(axis=1)

<br>

In [None]:
print(df_pm_speed.to_html(notebook=True))

<br>

In [None]:
# halt a "Run All" execution here; the cells below are interactive examples
assert False

<br>

### Examples

In [None]:
import numpy as np

In [None]:
a = np.random.randn(10)
aa = [a] * 512
aaa = [aa] * 512

In [None]:
res_dm_tree = dm_tree.map_structure(
    lambda *a: sum(a), *aaa,
)

In [None]:
res_tf_nest = tf_nest.map_structure(
    lambda *a: sum(a), *aaa,
)

for a, b in zip(res_tf_nest, res_dm_tree):
    assert np.allclose(a, b)

In [None]:
res_pytree = pytree.tree_map(
    lambda *a: sum(a), *aaa,
)

for a, b in zip(res_pytree, res_dm_tree):
    assert np.allclose(a, b)

In [None]:
res_plyr = plyr.apply(
    lambda *a: sum(a), *aaa,
)

for a, b in zip(res_plyr, res_dm_tree):
    assert np.allclose(a, b)

In [None]:
plyr.ragged(
    lambda a, b: a + b,
    {'k1': ('a', 'b'), 'k2': 'c'},
    {'k1': 'x', 'k2': ['y', 'z']},
)

<br>

### Setup

```bash
conda update -n base -c defaults conda

# create the testing env
conda create -n plyr-compare \
"python=3.9" \
conda-forge::llvm-openmp \
conda-forge::compilers \
pip \
setuptools \
cython \
numba \
numpy \
mkl \
scipy \
pytest

conda install -n plyr-compare \
matplotlib \
jupyter \
pandas \
tqdm

conda activate plyr-compare

# installing the libs
pip install --upgrade "jax[cpu]" dm-tree tensorflow python-plyr
```

<br>