# Large - Multi-Channel Timeseries with Dynamic Data Access


---

## Overview

<div class="admonition alert alert-info">
    <p class="admonition-title" style="font-weight:bold"> Visit the Index Page </p>
    This workflow example is part of a set of related workflows. If you haven't already, visit the <a href="/index.html">index</a> page for an introduction and guidance on choosing the appropriate workflow.
</div>

The intended use-case for this workflow is to browse and annotate multi-channel timeseries data from an [electrophysiological](https://en.wikipedia.org/wiki/Electrophysiology) recording session.

Compared to other approaches in this set of workflows, this workflow focuses on 'large' datasets, which we define as datasets that do not comfortably fit into the available RAM.

When the entire dataset cannot be loaded into memory, we have to consider which approaches scale best. The approach we demonstrate here is one of the most common in the bio-imaging community and is based on multi-resolution data structures.

We will create a derived dataset that includes a multi-resolution pyramid (incrementally downsampled versions of a large dataset), and then use a dynamic accessor to access the appropriate resolution based on viewport and screen parameters.
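
To make the idea concrete before touching real data, here is a minimal NumPy-only sketch of a multi-resolution pyramid. It is an illustration under simplifying assumptions: it uses naive strided decimation rather than the downsampler used later in this workflow, the signal is synthetic, and `build_pyramid` and `pick_level` are hypothetical helper names that do not exist in any of the libraries used here.

```python
import numpy as np


def build_pyramid(signal, factors):
    """Return one decimated copy of `signal` per factor (naive striding)."""
    return {level: signal[::factor] for level, factor in enumerate(factors)}


def pick_level(pyramid, target_points):
    """Pick the level whose length is closest to the number of points we can draw."""
    sizes = np.array([len(pyramid[level]) for level in sorted(pyramid)])
    return int(np.argmin(np.abs(sizes - target_points)))


signal = np.sin(np.linspace(0, 100, 100_000))  # synthetic stand-in for one channel
pyramid = build_pyramid(signal, factors=[1, 10, 100, 1000])

# A plot roughly 1000 pixels wide only needs about 1000 samples.
level = pick_level(pyramid, target_points=1000)
print(level, len(pyramid[level]))
```

The key point is that the viewer never needs more samples than it has pixels to draw them, so the full-resolution level is only touched when the user zooms in far enough.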

## Prerequisites and Resources

| Topic | Type | Notes |
| --- | --- | --- |
| [Intro and Guidance](./index.ipynb) | Prerequisite | Background |
| [Time Range Annotation](./time_range_annotation.ipynb) | Next Step | Display and edit time ranges |
| [Smaller Dataset Workflow](./small_multi-chan-ts.ipynb) | Alternative | Use NumPy |
| [Medium Dataset Workflow](./medium_multi-chan-ts.ipynb) | Alternative | Use Pandas and downsampling |

---

## Preprocessing the data

### Imports and Configuration

We start by importing the libraries necessary to preprocess the data, notably:

- `tsdownsample` for downsampling data
- `ndpyramid` for creating a multi-resolution pyramid
- `datatree` for opening and reading datatrees

In [None]:
import os

import dask.array as da
import datatree as dt
import h5py
import numpy as np
import xarray as xr
from ndpyramid import pyramid_create
from tsdownsample import MinMaxLTTBDownsampler

DATA_DIR = os.path.expanduser("~/repos/czi/allensdk_cache/session_715093703")
PYRAMID_PATH = os.path.join(DATA_DIR, "pyramid_neuropix_10s.zarr")
OVERWRITE = False

### Serialize into XArray

We use `h5py` to open the HDF5 file. Because `xarray` interoperates with many modern data wrangling libraries, we serialize pieces of the data into an `xr.DataArray` and convert it to an `xr.Dataset`. We also wrap the data in a `dask` array so that it is loaded lazily, i.e. values are not read from disk until they are needed.


In [None]:
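def serialize_to_xarray(f, data_key, dims):
    # Map each dimension name to the coordinate values stored at its HDF5 path.
    coords = {dim: f[path][:] for dim, path in dims.items()}
    data = f[data_key]
    # Wrap the HDF5 dataset lazily, with one dask chunk per channel.
    ds = xr.DataArray(
        da.from_array(data, name="data", chunks=(data.shape[0], 1)),
        dims=list(dims),
        coords=coords,
        name="data",  # explicit name so `to_dataset()` yields a "data" variable
    ).to_dataset()
    return ds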


h5py_path = os.path.join(DATA_DIR, "probe_810755797_lfp.nwb")
f = h5py.File(h5py_path, "r")

ts_ds = serialize_to_xarray(
    f,
    "acquisition/probe_810755797_lfp_data/data",
    {
        "time": "acquisition/probe_810755797_lfp_data/timestamps",
        "channel": "acquisition/probe_810755797_lfp_data/electrodes",
    },
).isel(channel=slice(10))

ts_ds

### Create a DataTree

Now that we have an `xr.Dataset`, we can perform computations on it in a vectorized and parallelized manner with `xr.apply_ufunc`.

Combining it with `ndpyramid.pyramid_create` produces a data tree whose levels contain the data downsampled by various factors.

In [None]:
# Define the factors for downsampling, that scale with the number of channels.
FACTORS = list(np.array([1, 2, 4, 8, 16, 32, 64, 128, 256]) ** (len(ts_ds["channel"]) // 4))


def _help_downsample(data, time, n_out):
    """
    Helper function for downsampling and returning as a specific format.
    """
    indices = MinMaxLTTBDownsampler().downsample(time, data, n_out=n_out)
    return data[indices], indices


def apply_downsample(ts_ds, factor, dims):
    """
    Apply downsampling to a time series dataset.
    """
    dim = dims[0]
    n_out = len(ts_ds["data"]) // factor
    print(f"Downsampling by factor {factor} for a size of {n_out}.")
    ts_ds_downsampled, indices = xr.apply_ufunc(
        _help_downsample,
        ts_ds["data"],
        ts_ds[dim],
        kwargs=dict(n_out=n_out),
        input_core_dims=[[dim], [dim]],
        output_core_dims=[[dim], ["indices"]],
        exclude_dims=set((dim,)),
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs=dict(output_sizes={dim: n_out, "indices": n_out}),
    )
    ts_ds_downsampled[dim] = ts_ds[dim].isel(time=indices.values[0])
    return ts_ds_downsampled.rename("data")


if not os.path.exists(PYRAMID_PATH) or OVERWRITE:
    ts_dt = pyramid_create(
        ts_ds,
        factors=FACTORS,
        dims=["time"],
        func=apply_downsample,
        type_label="pick",
        method_label="pyramid_downsample",
    )
    display(ts_dt)
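
To give some intuition for why a min-max style downsampler is preferred over plain striding, here is a simplified NumPy sketch. It is a stand-in for illustration only and is not the MinMaxLTTB algorithm used above: it just keeps the minimum and maximum sample of each bin. The function name `minmax_downsample_indices` and the synthetic signal are assumptions made for this example.

```python
import numpy as np


def minmax_downsample_indices(values, n_out):
    """Return sorted indices of the min and max sample in each of n_out // 2 bins."""
    n_bins = n_out // 2
    # Bin edges split the signal into n_bins nearly equal contiguous chunks.
    edges = np.linspace(0, len(values), n_bins + 1).astype(int)
    picked = []
    for start, stop in zip(edges[:-1], edges[1:]):
        chunk = values[start:stop]
        picked.append(start + int(np.argmin(chunk)))
        picked.append(start + int(np.argmax(chunk)))
    # np.unique sorts the indices and drops duplicates (e.g. constant chunks).
    return np.unique(picked)


rng = np.random.default_rng(0)
values = rng.normal(size=10_000)
values[1234] = 50.0  # an artificial spike that plain striding could miss

indices = minmax_downsample_indices(values, n_out=100)
downsampled = values[indices]
print(len(indices), downsampled.max())
```

Because every bin contributes its extremes, short spikes survive the downsampling, which matters when the downsampled trace is meant to look like the original at a given zoom level.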

### Persist and Re-open

A `dt.DataTree` mirrors much of the `xr.Dataset` API, so we can easily export it to Zarr.

In [None]:
if not os.path.exists(PYRAMID_PATH) or OVERWRITE:
    ts_dt.to_zarr(PYRAMID_PATH, mode="w")

We can read it back in just as easily; just be sure to specify the correct engine.

In [None]:
ts_dt = dt.open_datatree(PYRAMID_PATH, engine="zarr")

ts_dt

## Plotting

### Imports and Configuration

We now import the libraries necessary for interactively utilizing the datatree / pyramid we just created, notably:

- `holoviews`, using the `bokeh` backend, to build interactive plots
- `panel` to create widgets and a dashboard
- `scipy` for calculating a z-score of the data


In [None]:
import holoviews as hv
import panel as pn
import datatree as dt

from bokeh.models.tools import HoverTool, WheelZoomTool
from holoviews.operation.datashader import rasterize
from holoviews.plotting.links import RangeToolLink
from scipy.stats import zscore

pn.extension()
hv.extension("bokeh")

### Prepare the Data

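Here, we re-open the pyramid and extract some metadata from it: the number of levels, the time coordinate of the coarsest level, and the channel labels.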

In [None]:
def _extract_ds(ts_dt, level, channel):
    """
    Helper function to extract a dataset at a specific level and channel.
    """
    ds = ts_dt[str(level)].sel(channel=channel).ds
    return ds


ts_dt = dt.open_datatree(PYRAMID_PATH, engine="zarr")

num_levels = len(ts_dt) - 1
sel_group = f"{num_levels}"
time_da = _extract_ds(ts_dt, sel_group, 0)["time"]

channels = ts_dt[sel_group].ds["channel"].values
num_channels = len(channels)

### Create Dynamic Plot

Here we define a `rescale` function that reruns when the axes' ranges (`RangeXY`) or the size of a plot (`PlotSize`) changes.

Based on the changes and thresholds, a new plot is created using a new subset of the datatree. 

In [None]:
X_PADDING = 0.2  # buffer around x so if user zooms out, data is still visible


def rescale(x_range, y_range, width, scale, height):
    # fix edge cases when streams are initialized
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()
    if y_range is None:
        y_range = 0, num_channels
    x_padding = (x_range[1] - x_range[0]) * X_PADDING
    time_slice = slice(x_range[0] - x_padding, x_range[1] + x_padding)

    # calculate the appropriate zoom level and size
    if width is None or height is None:
        zoom_level = num_levels - 1
        size = time_da.size
    else:
        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))
        size = sizes[zoom_level]

    # re-plot the data
    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        hover = HoverTool(
            tooltips=[
                ("Channel", str(channel)),
                ("Time", "$x s"),
                ("Amplitude", "$y µV"),
            ]
        )
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=time_slice).load()
        curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=1,
            default_tools=["pan", "reset", WheelZoomTool(), hover],
        )
        curves *= curve

    # update the title
    title = (
        f"level {zoom_level} ({x_range[0]:.2f}s - {x_range[1]:.2f}s) "
        f"(WxH: {width}x{height}) (length: {size})"
    )
    curves = curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        title=title,
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        framewise=True,
        axiswise=True,
    )
    return curves


range_stream = hv.streams.RangeXY()
size_stream = hv.streams.PlotSize()
dmap = hv.DynamicMap(rescale, streams=[size_stream, range_stream])

dmap
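
The level-selection logic inside `rescale` can be studied in isolation. The sketch below reproduces only that step with plain NumPy arrays standing in for the pyramid levels. The helper name `choose_zoom_level`, the 1000 Hz sampling rate, and the stride factors are assumptions made for this example, not properties of the dataset above.

```python
import numpy as np

X_PADDING = 0.2  # same fraction as in the cell above


def choose_zoom_level(level_times, x_range, width):
    """Return (zoom_level, size) for the level whose visible sample count is closest to `width`."""
    padding = (x_range[1] - x_range[0]) * X_PADDING
    lo, hi = x_range[0] - padding, x_range[1] + padding
    # Count how many samples of each level fall inside the padded viewport.
    sizes = [int(np.count_nonzero((t >= lo) & (t <= hi))) for t in level_times]
    zoom_level = int(np.argmin(np.abs(np.array(sizes) - width)))
    return zoom_level, sizes[zoom_level]


# Hypothetical pyramid: 100 s sampled at 1000 Hz, then strided by 1, 10 and 100.
full_time = np.arange(100_000) / 1000.0
level_times = [full_time[::factor] for factor in (1, 10, 100)]

print(choose_zoom_level(level_times, x_range=(0.0, 100.0), width=1200))
print(choose_zoom_level(level_times, x_range=(40.0, 41.0), width=1200))
```

With these assumed numbers, viewing the full extent resolves to the coarsest level, while a one-second window resolves to the finest one, which is the behavior the dynamic plot relies on.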

### Associate a Minimap

Lastly, we can link a minimap to the main plot to allow for easier navigation.

In [None]:
data = ts_dt[sel_group].ds["data"].values
y_positions = range(num_channels)
yticks = [(i, ich) for i, ich in enumerate(channels)]
z_data = zscore(data, axis=1)

minimap = rasterize(
    hv.Image((time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (uV)")
).opts(
    cnorm='eq_hist',
    cmap="RdBu_r",
    xlabel="",
    yticks=[yticks[0], yticks[-1]],
    toolbar="disable",
    height=120,
    responsive=True,
    alpha=0.8,
)

tool_link = RangeToolLink(
    minimap,
    dmap,
    axes=["x", "y"],
    boundsx=(0, time_da.max().item() // 2),
    boundsy=(0, len(channels) // 2),
)

app = (dmap + minimap).cols(1)
app
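
The minimap z-scores each channel along time before rendering it as an image, so channels with different offsets and amplitudes can share one colormap. As a rough illustration, the following NumPy sketch computes the same kind of per-row standardization on synthetic data. The helper `zscore_rows` is a hypothetical name for this example; it uses the population standard deviation (`ddof=0`), which is also the default of `scipy.stats.zscore`.

```python
import numpy as np


def zscore_rows(data):
    """Standardize each row (channel) to zero mean and unit variance along time."""
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)  # population std (ddof=0)
    return (data - mean) / std


# Two synthetic channels with very different offsets and amplitudes.
t = np.linspace(0, 1, 500)
data = np.vstack(
    [
        100 + 5 * np.sin(2 * np.pi * 3 * t),
        -20 + 0.1 * np.cos(2 * np.pi * 7 * t),
    ]
)

z = zscore_rows(data)
print(z.mean(axis=1), z.std(axis=1))
```

After standardization both rows are centered on zero with unit spread, so neither channel dominates the color range of the image.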

### Add a Widget

Currently, the minimap uses only the coarsest level of the datatree. We can create a widget to control the level of granularity the minimap shows!

In [None]:
input_group = pn.widgets.Select(value=f"/{sel_group}", options=list(ts_dt.groups[1:]))


def update_minimap(group):
    data = ts_dt[group].ds["data"].values
    y_positions = range(num_channels)
    z_data = zscore(data, axis=1)
    time_da = _extract_ds(ts_dt, group, 0)["time"]

    minimap = hv.Image(
        (time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (uV)"
    )
    return minimap


yticks = [(i, ich) for i, ich in enumerate(channels)]
minimap = rasterize(
    hv.DynamicMap(pn.bind(update_minimap, input_group.param.value)).opts(
        cnorm="eq_hist",
        cmap="RdBu_r",
        xlabel="",
        yticks=[yticks[0], yticks[-1]],
        toolbar="disable",
        height=120,
        responsive=True,
        alpha=0.8,
    )
)

app = pn.Column(input_group, (dmap + minimap).cols(1))
app

## *Optional:* Standalone App

Using HoloViz Panel, we can also mark this application as servable so that it can be launched in a browser window, outside of a Jupyter notebook (templates do not work in notebooks at the time of writing).

In [None]:
pn.template.FastListTemplate(main=[app]).servable();  # semi-colon to prevent it from showing output in a notebook

## Summary

## Full app for easy copy/pasting

In [None]:
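import os

import datatree as dt
import holoviews as hv
import numpy as np  # used by `rescale` for the zoom-level selection
import panel as pn

from bokeh.models.tools import HoverTool, WheelZoomTool
from holoviews.operation.datashader import rasterize
from holoviews.plotting.links import RangeToolLink
from scipy.stats import zscore

pn.extension()
hv.extension("bokeh")

# Same pyramid location as in the preprocessing section above.
DATA_DIR = os.path.expanduser("~/repos/czi/allensdk_cache/session_715093703")
PYRAMID_PATH = os.path.join(DATA_DIR, "pyramid_neuropix_10s.zarr")

X_PADDING = 0.2  # buffer around x so if user zooms out, data is still visible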


def rescale(x_range, y_range, width, scale, height):
    # fix edge cases when streams are initialized
    if x_range is None:
        x_range = time_da.min().item(), time_da.max().item()
    if y_range is None:
        y_range = 0, num_channels
    x_padding = (x_range[1] - x_range[0]) * X_PADDING
    time_slice = slice(x_range[0] - x_padding, x_range[1] + x_padding)

    # calculate the appropriate zoom level and size
    if width is None or height is None:
        zoom_level = num_levels - 1
        size = time_da.size
    else:
        sizes = [
            _extract_ds(ts_dt, zoom_level, 0)["time"].sel(time=time_slice).size
            for zoom_level in range(num_levels)
        ]
        zoom_level = np.argmin(np.abs(np.array(sizes) - width))
        size = sizes[zoom_level]

    # re-plot the data
    curves = hv.Overlay(kdims="Channel")
    for channel in channels:
        hover = HoverTool(
            tooltips=[
                ("Channel", str(channel)),
                ("Time", "$x s"),
                ("Amplitude", "$y µV"),
            ]
        )
        sub_ds = _extract_ds(ts_dt, zoom_level, channel).sel(time=time_slice).load()
        curve = hv.Curve(sub_ds, ["time"], ["data"], label=f"ch{channel}").opts(
            color="black",
            line_width=1,
            subcoordinate_y=True,
            subcoordinate_scale=1,
            default_tools=["pan", "reset", WheelZoomTool(), hover],
        )
        curves *= curve

    # update the title
    title = (
        f"level {zoom_level} ({x_range[0]:.2f}s - {x_range[1]:.2f}s) "
        f"(WxH: {width}x{height}) (length: {size})"
    )
    curves = curves.opts(
        xlabel="Time (s)",
        ylabel="Channel",
        title=title,
        show_legend=False,
        padding=0,
        aspect=1.5,
        responsive=True,
        framewise=True,
        axiswise=True,
    )
    return curves


def _extract_ds(ts_dt, level, channel):
    """
    Helper function to extract a dataset at a specific level and channel.
    """
    ds = ts_dt[str(level)].sel(channel=channel).ds
    return ds


def update_minimap(group):
    data = ts_dt[group].ds["data"].values
    y_positions = range(num_channels)
    z_data = zscore(data, axis=1)
    time_da = _extract_ds(ts_dt, group, 0)["time"]

    minimap = hv.Image(
        (time_da, y_positions, z_data), ["Time (s)", "Channel"], "Amplitude (uV)"
    )
    return minimap


ts_dt = dt.open_datatree(PYRAMID_PATH, engine="zarr")

num_levels = len(ts_dt) - 1
sel_group = f"{num_levels}"
time_da = _extract_ds(ts_dt, sel_group, 0)["time"]

channels = ts_dt[sel_group].ds["channel"].values
num_channels = len(channels)

range_stream = hv.streams.RangeXY()
size_stream = hv.streams.PlotSize()
dmap = hv.DynamicMap(rescale, streams=[size_stream, range_stream])

input_group = pn.widgets.Select(value=f"/{sel_group}", options=list(ts_dt.groups[1:]))
yticks = [(i, ich) for i, ich in enumerate(channels)]
minimap = rasterize(
    hv.DynamicMap(pn.bind(update_minimap, input_group.param.value)).opts(
        cnorm="eq_hist",
        cmap="RdBu_r",
        xlabel="",
        yticks=[yticks[0], yticks[-1]],
        toolbar="disable",
        height=120,
        responsive=True,
        alpha=0.8,
    )
)

app = pn.Column(input_group, (dmap + minimap).cols(1))
app
