# Optimizing AxisArray manipulations

I frequently have to remind myself of a few things when working with AxisArrays.

* Modifying a single field will affect all other references to this object, which could be in other branches of the pipeline.
* The `axes` field is a dict and modifying it will affect all other references to this `axes` dict.
* `deepcopy` is super slow and usually unnecessary because we often only modify 1 or 2 fields.

So I came up with a general pattern to transform an AxisArray object that does not use `deepcopy` yet breaks the links to other references on the manipulated fields. The general rules are as follows:

1. The outgoing message must (out_msg) be a new instance.
2. The out_msg `axes` must be a new dictionary if ANY value is to be modified.
3. Values in `axes` might be mutable references themselves, so modify-by-replacement.

I use this notebook to remind myself of the problem and to profile the different possible solutions.

In [1]:
import copy
from dataclasses import replace

import numpy as np
from ezmsg.util.messages.axisarray import AxisArray


## Generate inputs

We start by generating input messages. We also create a copy of each message's time axis .offset, so we can later check to see a change to the output message's offset affects the input message.

In [35]:
def create_inputs(freq_units="Hz"):
    n_chans = 20
    n_freqs = 100
    data_dur = 30.0
    fs = 1024.0

    n_samples = int(data_dur * fs)
    data = np.arange(n_samples * n_chans * n_freqs).reshape(n_samples, n_chans, n_freqs)
    n_msgs = int(data_dur / 2)

    offset = 0
    messages = []
    for arr in np.array_split(data, n_samples // n_msgs):
        messages.append(
            AxisArray(
                arr,
                dims=["time", "ch", "freq"],
                axes={
                    "time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset),
                    "freq": AxisArray.Axis(gain=1.0, offset=0.0, unit=freq_units)
                }
            )
        )
        offset += arr.shape[0] / fs
    print(f"Generated {len(messages)} messages")
    return messages


in_msgs = create_inputs()
# Make a copy of the offsets for future 
t0s = np.array([_.axes["time"].offset for _ in in_msgs])
assert all([t0 == ax_arr.axes["time"].offset for t0, ax_arr in zip(t0s, in_msgs)])

Generated 2048 messages


In [3]:
# Let's make sure that modifying an offset does not modify our baseline.
in_msgs[-1].axes["time"].offset = -11.11
assert t0s[-1] != in_msgs[-1].axes["time"].offset 
in_msgs[-1].axes["time"].offset = t0s[-1]  # Set it back

## Creating and manipulation a new AxisArray

### deepcopy

`copy.deepcopy` works as expected. Modifying the `.offset` of the new msg's "time" axis has no effect on the original message.
Unfortunately, deepcopy is very expensive, especially when we are only modifying a small subset of fields.

In [4]:
out_msg = copy.deepcopy(in_msgs[0])
assert out_msg.axes is not in_msgs[0].axes
out_msg.axes["time"].offset = -22.22
assert in_msgs[0].axes["time"].offset == t0s[0]
assert in_msgs[0].axes["time"].offset != out_msg.axes["time"].offset

# But deepcopy can be expensive
%timeit out_msgs_dc = [copy.deepcopy(_) for _ in in_msgs]

77.9 ms ± 5.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### (shallow) copy

As expected, a shallow copy creates a new object but all mutable fields are shared between the old and the new. As a result, any modification to the new fields will affect the old object.

As shallow copy is roughly equivalent to initializing a new object with all the fields of the old, we expect the same behaviour when creating outputs that way.

In [5]:
out_msg = copy.copy(in_msgs[0])
# Shares the objects underlying the fields.
assert out_msg.axes is in_msgs[0].axes
assert out_msg.data is in_msgs[0].data
# Modify out and see that input changes too.
out_msg.axes["time"].offset = -22.22
assert in_msgs[0].axes["time"].offset == out_msg.axes["time"].offset
assert in_msgs[0].axes["time"].offset != t0s[0]
# Reset the value.
in_msgs[0].axes["time"].offset = t0s[0]

In [6]:
# What if we create a new message and fill the fields with references?
out_msg = AxisArray(
    data=in_msgs[0].data,
    dims=in_msgs[0].dims,
    axes=in_msgs[0].axes
)
# Shares the objects underlying the fields.
assert out_msg.axes is in_msgs[0].axes
assert out_msg.axes["time"] is in_msgs[0].axes["time"]
assert out_msg.data is in_msgs[0].data
# Modify out and see that input changes too.
out_msg.axes["time"].offset = -33.33
assert in_msgs[0].axes["time"].offset == out_msg.axes["time"].offset
assert in_msgs[0].axes["time"].offset != t0s[0]
# Reset the value.
in_msgs[0].axes["time"].offset = t0s[0]

### New from dataclasses `replace`

It is not obvious how dataclasses `replace` makes the new object, so we demonstrate it here explicitly.
The output is a shallow copy of the input. Furthermore, the "replaced" fields are shallow copies of their argument values.

In [7]:
out_msg = replace(
    in_msgs[0],
    axes=in_msgs[0].axes
)
assert out_msg.data is in_msgs[0].data
assert out_msg.axes is in_msgs[0].axes
assert out_msg.axes["time"] is in_msgs[0].axes["time"]
out_msg.axes["time"].offset = -44.44
assert in_msgs[0].axes["time"].offset == out_msg.axes["time"].offset
assert in_msgs[0].axes["time"].offset != t0s[0]
# Reset
in_msgs[0].axes["time"].offset = t0s[0]

### New `axes` on init

If we create a new dictionary to pass to out_msg creation, then we can successfully separate the input axes dict and the output axes dict. HOWEVER, this does not break links between the individual fields!

In [8]:
out_msg = replace(
    in_msgs[0],
    axes={**in_msgs[0].axes}
)
assert out_msg.axes is not in_msgs[0].axes
# Yay, the dict is not the same object!
# However, the time axes are still linked.
assert out_msg.axes["time"] is in_msgs[0].axes["time"]
out_msg.axes["time"].offset = -55.55
assert in_msgs[0].axes["time"].offset != t0s[0]
# Reset
in_msgs[0].axes["time"].offset = t0s[0]

And just for completeness, we show that if we keep the same .axes dict and just create a new entry then we haven't solved the problem because the container is linked. We simply replaced the "time" value on both the output and the input.

In [9]:
out_msg = copy.copy(in_msgs[0])
out_msg.axes["time"] = replace(out_msg.axes["time"], offset=-66.66)
# dict remains the same, so we just changed one field in one dict shared across input and output.
assert out_msg.axes["time"] is in_msgs[0].axes["time"]
# Reset
in_msgs[0].axes["time"].offset = t0s[0]

### General solution

1. Return message is a new object, potentially initialised via shallow copy of input or template.
2. `.axes` is a new dict.
3. Field (e.g., "time") should be modified by replacement only.

Here is one example of how to achieve this.


In [63]:
out_msg = copy.copy(in_msgs[0])  # 1. New object
out_msg.axes = {  # New dict 
    **in_msgs[0].axes,  # Shallow copy of existing items
    "time": copy.copy(in_msgs[0].axes["time"])  # Overwrite time with a copy of the TimeAxis.
}
out_msg.axes["time"].offset = -77.77
assert out_msg.axes is not in_msgs[0].axes
assert out_msg.axes["time"] is not in_msgs[0].axes["time"]
assert in_msgs[0].axes["time"].offset == t0s[0]  # Input offset has not been affected

## Profiling modifications to AxisArray axis fields

What is the fastest way to return a message that is similar to the input with a transformed output and a modification to one of the axes entries?


### Profile creation of axes dict

There are several ways to make a new dict initialized with an old dict.

* `new_dict = old_dict.copy()`
* `new_dict = {**old_dict}`
* `new_dict = dict({**old_dict})`

The 3rd method is strictly slower than the 2nd so we will ignore it.

We also need to overwrite one of the fields. The 2nd method allows this in the same command.
Additionally, there's a question as to whether we should copy all the old dict fields or only the ones we aren't replacing.

Let's investigate.

First, we'll recreate our `in_msgs` but this time give the freq axis a very long string for units to make copying more expensive.


In [36]:
import string
import random

res = ''.join(random.choices(string.ascii_uppercase +
                             string.digits, k=1000))
in_msgs = create_inputs(freq_units=res)

Generated 2048 messages


Next we'll try several ways to make a new dict.
Note that in each case we can't assume the class of the axis info object because it might have been a customized subclass.

In [31]:
def new_dict_1(in_msg):
    new_dict = in_msg.axes.copy()
    ax_cls = type(new_dict["time"])
    new_dict["time"] = ax_cls(gain=new_dict["time"].gain, offset=-88.88)
    return new_dict

def new_dict_2(in_msg):
    new_dict = {**in_msg.axes}
    ax_cls = type(new_dict["time"])
    new_dict["time"] = ax_cls(gain=new_dict["time"].gain, offset=-88.88)
    return new_dict

def new_dict_3(in_msg):
    ax_cls = type(in_msg.axes["time"])
    return {
        **in_msg.axes,
        "time": ax_cls(gain=in_msg.axes["time"].gain, offset=-88.88)
    }

def new_dict_4(in_msg):
    ax_cls = type(in_msg.axes["time"])
    return {
        **{
            k: (v if k != "time" else ax_cls(gain=in_msg.axes["time"].gain, offset=-88.88))
            for k, v in in_msg.axes.items() if k != "time"
        }
    }


%timeit [new_dict_1(_) for _ in in_msgs]
%timeit [new_dict_2(_) for _ in in_msgs]
%timeit [new_dict_3(_) for _ in in_msgs]
%timeit [new_dict_4(_) for _ in in_msgs]


883 µs ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
882 µs ± 18.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
965 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
774 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


`new_dict_4` is faster maybe because it has fewer "copies", even though the copies are shallow. I might have thought that the check against `k != "time"` is expensive.
However, the 4th method is hard to read and is easy to get wrong.

The other methods are quite close to each other. Method 1 or 2 might be preferred over 4 for readibility.

Next we consider whether creation of the new TimeAxis field is better done with initialization or `replace`. The nice thing about `replace` is that we don't need to know the input class or the other fields.

In [32]:
def new_dict_2_replace(in_msg):
    new_dict = {**in_msg.axes}
    new_dict["time"] = replace(new_dict["time"], offset=-88.88)


%timeit [new_dict_2_replace(_) for _ in in_msgs]

2.03 ms ± 8.29 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Wow! `replace` was more than twice as slow as creating a new object.

### Profile creation of object

In [38]:
def new_axarr_1(in_msg):
    msg_cls = type(in_msg)
    return msg_cls(**in_msg.__dict__)


def new_axarr_2(in_msg):
    return replace(in_msg, axes=in_msg.axes)


%timeit [new_axarr_1(_) for _ in in_msgs]
%timeit [new_axarr_2(_) for _ in in_msgs]

1.16 ms ± 1.61 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.32 ms ± 10.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
def full_modify_1(in_msg):
    msg_cls = type(in_msg)
    out_msg = msg_cls(**in_msg.__dict__)
    ax_cls = type(in_msg.axes["time"])
    out_msg.axes = {
        k: (v if k != "time" else ax_cls(gain=in_msg.axes["time"].gain, offset=-88.88))
        for k, v in in_msg.axes.items() if k != "time"
    }
    return out_msg


def full_modify_2(in_msg):
    return replace(
            in_msg,
            axes={
                **in_msg.axes,
                "time": replace(
                    in_msg.axes["time"],
                    offset=-99.99
                )
            }
        )
    

%timeit [full_modify_1(_) for _ in in_msgs]
%timeit [full_modify_2(_) for _ in in_msgs]

1.91 ms ± 6.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.51 ms ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


0.9 vs 2.2 µs per message (2048 messages per loop).