# Lesson 3: Vertical and horizontal scaling

(All credit to J. Pivarski for the Jupyter notebooks and images; histogramming tools by L. Gray.)

This lesson is about making Python faster, in two dimensions:

<center>
<img src="../img/horizontal-and-vertical-scaling.svg" width="60%">
</center>

## Vertical scaling

Reminder: Python is slow.

<center>
<img src="../img/benchmark-games-2023.svg" width="90%">
</center>

We have already seen that NumPy (and Awkward Array) can circumvent Python's slowness by doing computationally intensive work in compiled code.

In [None]:
import numpy as np
import awkward as ak

events = ak.from_parquet("../data/SMHiggsToZZTo4L.parquet")[:100000]

<br>

In [None]:
a = np.random.uniform(5, 10, 1000000)
b = np.random.uniform(10, 20, 1000000)
c = np.random.uniform(-0.1, 0.1, 1000000)

In [None]:
%%timeit -r1 -n1

pz = [[muon.pt * np.sinh(muon.eta) for muon in event.muon] for event in events]

<br>

In [None]:
%%timeit -r1 -n1

pz = events.muon.pt * np.sinh(events.muon.eta)

Array-oriented programming connects Python with compiled code, but it's not the only way to do that.

<img src="../img/history-of-bindings-2.svg" width="100%">

In [None]:
import numba as nb

<br>

In [None]:
def quadratic_formula(a, b, c):
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)

In [None]:
%%timeit

output = quadratic_formula(a, b, c)

In [None]:
@nb.njit
def quadratic_formula_numba(a, b, c):
    output = np.empty(len(a), dtype=np.float64)
    for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c)):
        output[i] = (-b_i + np.sqrt(b_i**2 - 4*a_i*c_i)) / (2*a_i)
    return output

quadratic_formula_numba(a, b, c)  # first call triggers JIT compilation, so it is kept out of the timing below

<br>

In [None]:
%%timeit

quadratic_formula_numba(a, b, c)

The exercises are on vertical scaling, so do one now, before I talk about horizontal scaling.

## Horizontal scaling

<center>
<img src="../img/horizontal-and-vertical-scaling.svg" width="60%">
</center>

There are many ways to distribute a computation.

<br>

The traditional way is to use a batch queue, such as the LHC GRID. You can run Python scripts in a batch queue as easily as anything else.

<br>

This section will be about Dask, a popular way to distribute computations _within_ Python.

<center>
<img src="../img/logo-dask.svg" width="30%">
</center>

Dask is a library for describing a computation as a task graph.

In [None]:
import dask

<br>

Eager Python code:

In [None]:
def increment(i):
    return i + 1

def add(a, b):
    return a + b

a, b = 1, 12
c = increment(a)
d = increment(b)
output = add(c, d)

output

<br>

Lazy Python code:

In [None]:
@dask.delayed
def increment(i):
    return i + 1

@dask.delayed
def add(a, b):
    return a + b

a, b = 1, 12
c = increment(a)
d = increment(b)
output = add(c, d)

output

In [None]:
output.compute()

<br>

In [None]:
output.visualize()

In [None]:
import numpy as np
import dask.array as da

a = da.random.uniform(5, 10, 1000000)
b = da.random.uniform(10, 20, 1000000)
c = da.random.uniform(-0.1, 0.1, 1000000)

output = (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)
output

<br>

In [None]:
output.visualize()

A task graph is a delayed computation: all the instructions needed to perform the computation at a later time.

<br>

This separates the problem of "what to compute?" from "when/where/on what resources to compute it?"
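To make the idea concrete, here is a toy, pure-Python sketch (not Dask's implementation) of the increment/add example above written as a graph: a dict mapping keys either to plain values or to tuples `(function, *arguments)`, where an argument that names another key stands for that task's result. This mirrors the shape of Dask's low-level graph dictionaries; `toy_get` is a hypothetical single-threaded executor.

```python
from operator import add

def increment(i):
    return i + 1

# keys map to plain values or to (function, *arguments) tuples;
# string arguments that are also keys refer to other tasks' results
graph = {
    "a": 1,
    "b": 12,
    "c": (increment, "a"),
    "d": (increment, "b"),
    "output": (add, "c", "d"),
}

def toy_get(graph, key, cache=None):
    """Evaluate `key` by recursively evaluating its dependencies, once each."""
    if cache is None:
        cache = {}
    if key not in cache:
        task = graph[key]
        if isinstance(task, tuple) and callable(task[0]):
            func, *args = task
            values = [
                toy_get(graph, arg, cache) if isinstance(arg, str) and arg in graph else arg
                for arg in args
            ]
            cache[key] = func(*values)
        else:
            cache[key] = task
    return cache[key]

print(toy_get(graph, "output"))  # 15
```

Nothing runs until `toy_get` is called, and `"c"` and `"d"` do not depend on each other, so a smarter executor could run them at the same time, on different threads or machines, without changing `graph` at all.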

In [None]:
import time

@dask.delayed
def start():
    print("start")
    time.sleep(0.5)
    return 1

@dask.delayed
def concurrent(initial, i):
    time.sleep(np.random.uniform(0, 1))
    print(f"concurrent {i}", end="")
    time.sleep(np.random.uniform(0, 0.05))
    print()
    return initial + i**2

@dask.delayed
def combine(partial_results):
    time.sleep(0.5)
    print("combine")
    return sum(partial_results)

initial = start()
output = combine([concurrent(initial, i) for i in range(10)])

In [None]:
output.visualize()

Dask has three built-in schedulers (rarely used in production):

* `"synchronous"`: not parallel, intended for debugging
* `"threads"`: multiple threads in the same process, limited by the [Python GIL](https://realpython.com/python-gil/)
* `"processes"`: multiple Python processes; not affected by the GIL, but it has to start the processes and send data between them
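On the `"threads"` bullet: the GIL only serializes Python bytecode, so tasks that mostly wait (sleeping, file or network I/O) or run compiled code that releases the GIL (as many NumPy operations do) can still overlap on threads. The `concurrent` tasks above mostly sleep, so they fall in this category. A small standard-library sketch (`concurrent.futures`, not Dask; the 0.2 s sleep and 8 workers are arbitrary choices) shows the effect:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def sleepy_task(i):
    # time.sleep releases the GIL, so other threads can run meanwhile
    time.sleep(0.2)
    return i**2

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(sleepy_task, range(8)))  # map preserves input order
elapsed = time.perf_counter() - start

print(results)             # [0, 1, 4, 9, 16, 25, 36, 49]
print(f"{elapsed:.2f} s")  # roughly 0.2 s rather than 8 x 0.2 = 1.6 s
```

CPU-bound pure-Python work in the same pool would not speed up this way, which is the case the `"processes"` scheduler is for.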

<br>

In [None]:
with dask.config.set(scheduler="synchronous"):
    output.compute()

In a real system, we're more likely to use the Distributed library (or a third-party scheduler, such as Ray).

<img src="../img/distributed-overview.svg" width="100%">

Run in separate terminals:

```bash
dask-scheduler
```

and several of the following:

```bash
dask-worker --nthreads 1 127.0.0.1:8786
```

In [None]:
import dask.distributed

client = dask.distributed.Client("127.0.0.1:8786")
client

In [None]:
output.compute()

<br><br><br><br><br>

## Dask collections

You can build general computations with `@dask.delayed`, but there are some common patterns that we'd want to build all the time.

For instance, splitting a calculation on NumPy arrays into embarrassingly parallel parts:
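Doing this by hand with plain NumPy shows the pattern that `dask.array` automates: split every input the same way, run the same function on each chunk independently, and join the results. This sketch reuses the quadratic formula from the vertical-scaling section; the chunk count of 10 and the random seed are arbitrary choices.

```python
import numpy as np

def quadratic_formula(a, b, c):
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)

rng = np.random.default_rng(12345)
a = rng.uniform(5, 10, 1_000_000)
b = rng.uniform(10, 20, 1_000_000)
c = rng.uniform(-0.1, 0.1, 1_000_000)

# split every input the same way, so chunk i of a, b and c line up element by element
nchunks = 10
chunked_inputs = zip(
    np.array_split(a, nchunks), np.array_split(b, nchunks), np.array_split(c, nchunks)
)

# each chunk is an independent task: no chunk needs another chunk's result
partial_outputs = [quadratic_formula(a_i, b_i, c_i) for a_i, b_i, c_i in chunked_inputs]

output = np.concatenate(partial_outputs)
print(np.allclose(output, quadratic_formula(a, b, c)))  # True
```

Because each chunk depends only on its own slice of the inputs, the list comprehension could be replaced by tasks running on different workers.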

In [None]:
import os

os.getcwd()

In [None]:
import os
import h5py

dataset_hdf5 = h5py.File(os.getcwd() + "/../data/SMHiggsToZZTo4L.h5")

pt1 = da.from_array(dataset_hdf5["ee_mumu"]["e1"]["pt"], chunks=10000)
phi1 = da.from_array(dataset_hdf5["ee_mumu"]["e1"]["phi"], chunks=10000)
eta1 = da.from_array(dataset_hdf5["ee_mumu"]["e1"]["eta"], chunks=10000)
pt2 = da.from_array(dataset_hdf5["ee_mumu"]["e2"]["pt"], chunks=10000)
phi2 = da.from_array(dataset_hdf5["ee_mumu"]["e2"]["phi"], chunks=10000)
eta2 = da.from_array(dataset_hdf5["ee_mumu"]["e2"]["eta"], chunks=10000)

pt1

In [None]:
mass = np.sqrt(2*pt1*pt2*(np.cosh(eta1 - eta2) - np.cos(phi1 - phi2)))
mass
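
The expression in the cell above is the invariant mass of a pair of particles whose rest masses are neglected (an approximation that assumes the leptons' momenta are much larger than their masses). Writing each four-momentum as $E = p_T\cosh\eta$, $p_x = p_T\cos\phi$, $p_y = p_T\sin\phi$, $p_z = p_T\sinh\eta$, it follows from $m^2 = (E_1 + E_2)^2 - |\vec{p}_1 + \vec{p}_2|^2$:

$$
\begin{aligned}
m^2 &= 2\left(E_1 E_2 - \vec{p}_1\cdot\vec{p}_2\right) \\
    &= 2\,p_{T1}\,p_{T2}\left(\cosh\eta_1\cosh\eta_2 - \sinh\eta_1\sinh\eta_2 - \cos\phi_1\cos\phi_2 - \sin\phi_1\sin\phi_2\right) \\
    &= 2\,p_{T1}\,p_{T2}\left(\cosh(\eta_1 - \eta_2) - \cos(\phi_1 - \phi_2)\right)
\end{aligned}
$$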

In [None]:
mass.visualize()

In [None]:
from hist import Hist

Hist.new.Reg(120, 0, 120, name="dielectron mass").Double().fill(
    mass.compute()
).plot();

<br><br><br><br><br>

<img src="../img/dask-overview.svg" width="100%">

<br><br><br><br><br>

## The dask-awkward collection

In [None]:
import uproot
import awkward as ak

events = uproot.dask(os.getcwd() + "/../data/SMHiggsToZZTo4L.root")

# events = uproot.dask(
#     "https://pivarski-princeton.s3.amazonaws.com/cms-open-dimuons/Run2012B_DoubleMuParked.root",
#     step_size=1000,
# )

events

In [None]:
selected = events[ak.num(events.Electron_pt) == 2]
selected

In [None]:
pt1 = selected.Electron_pt[:, 0]
phi1 = selected.Electron_phi[:, 0]
eta1 = selected.Electron_eta[:, 0]
pt2 = selected.Electron_pt[:, 1]
phi2 = selected.Electron_phi[:, 1]
eta2 = selected.Electron_eta[:, 1]

pt1

In [None]:
import numpy as np
mass = np.sqrt(2*pt1*pt2*(np.cosh(eta1 - eta2) - np.cos(phi1 - phi2)))
mass

In [None]:
mass.visualize()

In [None]:
Hist.new.Reg(120, 0, 120, name="dielectron mass").Double().fill(
    mass.compute()
).plot();

In [None]:
import hist.dask as hda

mass_dask_hist = hda.Hist.new.Reg(120, 0, 120, name="dielectron mass").Double().fill(
    mass
)
mass_dask_hist

In [None]:
mass_dask_hist.compute().plot();

## Building Distributed Objects in Dask

An extremely common pattern in Dask is embarrassingly parallel processing, i.e. repeating the same computation on different chunks of data as they are loaded from a file.
This could be on a single machine, or on multiple machines in a computing cluster.

This becomes tricky when the operation itself becomes resource intensive in one way or another.
Histograms actually serve as a great demonstration of this problem, especially given the modern
tools we have (Hist, and similar) that let us produce many-dimensional histograms easily.

A billion-bin histogram is easy to create, and with weighted filling (two 64-bit floats per bin: the sum of weights and the sum of squared weights) it occupies 16 GB.
Histograms of this size have already been made for analyses (~8 GB for the CMS W-mass analysis), and 50 GB histograms are possible when dealing with many systematic variations.
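The 16 GB figure is just bins × values per bin × bytes per value. A small helper (hypothetical, not part of Hist) makes the arithmetic explicit, assuming weighted storage keeps a sum of weights and a sum of squared weights as 64-bit floats, and counting the one underflow and one overflow bin per axis that the code later in this notebook also uses:

```python
import math

def dense_histogram_bytes(bins_per_axis, weighted=True, bytes_per_value=8):
    """Bytes needed by a dense histogram with one underflow and one overflow bin per axis.

    Assumes weighted storage keeps two 64-bit floats per bin (sum of weights and
    sum of squared weights) and unweighted storage keeps one 64-bit count per bin.
    """
    nbins = math.prod(n + 2 for n in bins_per_axis)
    values_per_bin = 2 if weighted else 1
    return nbins * values_per_bin * bytes_per_value

# 1000 x 1000 x 1000 regular bins: about a billion bins
print(dense_histogram_bytes([1000, 1000, 1000]) / 1e9)  # about 16.1 (GB)
# a 50 x 50 x 50 weighted histogram, like the one built later in this notebook
print(dense_histogram_bytes([50, 50, 50]) / 1e6)        # about 2.2 (MB)
```

The demonstration below therefore fits easily on a laptop; the point is that the same code never needs the dense histogram to fit on one worker.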

Given that these histograms do not fit on a single typical batch slot, what can we do to make full use of
our modern software *without* requiring dedicated services or very custom code?

Dask provides all the necessary primitives to build software that scales, and we'll take a look at that!


The following is code that I'm developing towards a contribution to a histogramming library.
I thought it might be useful to see how these kinds of objects get made.

This is going to use a number of the concepts we introduced on Tuesday, in addition to the things we discussed today!

In [None]:
# first stop the scheduler and worker in the terminal
# then open a new browser tab at localhost:8787 to see
# the dashboard
from distributed import Client

client = Client()

A core idea we're going to need here is the ability to *partially* reduce the data. Instead of outputting one object that represents *all* of the data in a single array, we want multiple arrays that each contain a partial sum of the data.
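A plain-Python sketch of the same idea, assuming plain numbers stand in for the per-partition results and `sum` stands in for the combine step (the function and its names are illustrative, not part of Dask). For the parameters used later in this notebook (1000 input partitions, `split_every=8`, stopping at 25) it gives the same level widths, 1000 → 125 → 25, as the `TruncatedTreeReduction` class.

```python
import math

def truncated_tree_reduce(partials, npartitions_stop, split_every, combine=sum):
    """Reduce neighbouring partial results until npartitions_stop of them remain.

    Intermediate levels combine split_every neighbours per task; the last level
    splits what is left into exactly npartitions_stop balanced groups.
    Returns the partial results and the number of partitions at each level.
    """
    level = list(partials)
    widths = [len(level)]
    while len(level) > npartitions_stop:
        if math.ceil(len(level) / split_every) > npartitions_stop:
            # intermediate level: fixed-size groups of split_every
            level = [combine(level[i:i + split_every]) for i in range(0, len(level), split_every)]
        else:
            # final level: npartitions_stop groups whose sizes differ by at most one
            base, extra = divmod(len(level), npartitions_stop)
            sizes = [base + 1 if g < extra else base for g in range(npartitions_stop)]
            starts = [sum(sizes[:g]) for g in range(npartitions_stop)]
            level = [combine(level[s:s + n]) for s, n in zip(starts, sizes)]
        widths.append(len(level))
    return level, widths

# 1000 "partitions" holding the numbers 0..999
partial_sums, widths = truncated_tree_reduce(range(1000), npartitions_stop=25, split_every=8)
print(widths)             # [1000, 125, 25]
print(sum(partial_sums))  # 499500, the same total as sum(range(1000))
```

Stopping at 25 partial sums instead of 1 keeps the last step from funnelling everything into a single task, which is what lets the histogram below stay spread across workers.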

The class below creates a Dask task graph layer that does exactly this, and applies the functions needed to create the partial sums and add them together correctly!

In [None]:
from __future__ import annotations

from dask.layers import Layer

from typing import Any, Callable

import math
import toolz
import numpy as np

class TruncatedTreeReduction(Layer):
    """Truncated Tree Reduction

    A tree reduction that stops early: rather than reducing all input
    partitions to a single output, the final level of the tree produces
    ``npartitions_stop`` partial results.

    Parameters
    ----------
    name : str
        Name to use for the constructed layer.
    name_input : str
        Name of the input layer that is being reduced.
    npartitions_input : int
        Number of partitions in the input layer.
    npartitions_stop : int
        Number of partial results (output partitions) to stop at.
    concat_func : callable
        Function used by each tree node to reduce a list of inputs
        into a single output value. This function must accept only
        a list as its first positional argument.
    tree_node_func : callable
        Function used on the output of ``concat_func`` in each tree
        node. This function must accept the output of ``concat_func``
        as its first positional argument.
    finalize_func : callable, optional
        Function used in place of ``tree_node_func`` on the final tree
        node(s) to produce the final output for each split. By default,
        ``tree_node_func`` will be used.
    split_every : int, optional
        This argument specifies the maximum number of input nodes
        to be handled by any one task in the tree. Defaults to 32.
    nnewax_partitions : int, optional
        If given, each final node is split into this many outputs along
        a new axis, with keys ``(name, i, j)``. Must be given together
        with ``newax_partition_size``.
    newax_partition_size : int, optional
        Passed as ``size`` (together with ``split``) to ``finalize_func``
        for each new-axis output.
    tree_node_name : str, optional
        Name to use for intermediate tree-node tasks.
    annotations : dict, optional
        Dask annotations to attach to the layer.
    """

    name: str
    name_input: str
    npartitions_input: int
    npartitions_stop: int
    concat_func: Callable
    tree_node_func: Callable
    finalize_func: Callable | None
    split_every: int
    nnewax_partitions: int | None
    output_partitions: list[int] | list[tuple[int]]
    tree_node_name: str
    widths: list[int]
    height: int

    def __init__(
        self,
        name: str,
        name_input: str,
        npartitions_input: int, 
        npartitions_stop: int,
        concat_func: Callable,
        tree_node_func: Callable,
        finalize_func: Callable | None = None,
        split_every: int = 32,
        nnewax_partitions: int | None = None,
        newax_partition_size: int | None = None,
        tree_node_name: str | None = None,
        annotations: dict[str, Any] | None = None,
    ):
        super().__init__(annotations=annotations)

        # Parenthesized: the unparenthesized `a is None != b is None` is a chained
        # comparison that can never be True, so the check would never fire
        if (nnewax_partitions is None) != (newax_partition_size is None):
            raise ValueError("both new axis partitions and partition size must be defined!")

        self.name = name
        self.name_input = name_input
        self.npartitions_input = npartitions_input
        # Cannot produce more partial results than there are input partitions
        self.npartitions_stop = min(npartitions_stop, npartitions_input)
        self.concat_func = concat_func
        self.tree_node_func = tree_node_func
        self.finalize_func = finalize_func
        self.split_every = split_every
        self.nnewax_partitions = nnewax_partitions
        self.newax_partition_size = newax_partition_size
        self.tree_node_name = tree_node_name or "tree_node-" + self.name

        # Calculate tree widths and height
        # (Used to get output keys without materializing)
        # Add levels while there are more partitions than npartitions_stop, so that
        # no input is dropped and no task receives more than split_every inputs
        parts = self.npartitions_input
        self.widths = [parts]
        while parts > self.npartitions_stop:
            parts = math.ceil(parts / self.split_every)
            self.widths.append(int(parts))
        self.widths[-1] = self.npartitions_stop
        self.height = len(self.widths)

        npartitions = (
            self.npartitions_stop if not self.nnewax_partitions
            else self.npartitions_stop * self.nnewax_partitions
        )
        partitions_shape = (
            (self.npartitions_stop, ) if not self.nnewax_partitions
            else (self.npartitions_stop, self.nnewax_partitions)
        )
        
        self.output_partitions = (
            list(range(npartitions)) if not self.nnewax_partitions
            else list( (int(i), int(j)) for i, j in zip(*np.unravel_index(np.arange(npartitions), partitions_shape)) )
        )

    
    def _make_key(self, *name_parts, split=None):
        # Helper function to construct a key, appending
        # a "split" element when split is truthy
        return name_parts + (split,) if split else name_parts

    def _define_task(self, input_keys, split_and_size=None, final_task=False):
        # Define nested concatenation and func task
        if final_task and self.finalize_func: 
            outer_func = self.finalize_func
            if split_and_size is not None:
                split, size = split_and_size
                outer_func = lambda x: self.finalize_func(x, split=split, size=size)
        else:
            outer_func = self.tree_node_func
        return (toolz.pipe, input_keys, self.concat_func, outer_func)

    def _construct_graph(self):
        """Construct graph for a tree reduction."""

        dsk = {}
        if not self.output_partitions:
            return dsk

        if self.height >= 2:
            # Loop over reduction levels
            for depth in range(1, self.height):
                # Loop over reduction groups
                if depth == self.height - 1:
                    split_every_base, overlap = divmod(self.widths[depth-1], self.widths[depth])
                else:
                    split_every_base, overlap = self.split_every, 0
                for group in range(self.widths[depth]):
                    # Calculate inputs for the current group
                    split_every = split_every_base + 1 if group < overlap else split_every_base
                    p_max = self.widths[depth - 1]
                    lstart = split_every * group
                    if group >= overlap:
                        lstart += overlap
                    lstop = min(lstart + split_every, p_max)
                    if depth == 1:
                        # Input nodes are from input layer
                        input_keys = [
                            self._make_key(self.name_input, p)
                            for p in range(lstart, lstop)
                        ]
                    else:
                        # Input nodes are tree-reduction nodes
                        input_keys = [
                            self._make_key(
                                self.tree_node_name, p, depth - 1
                            )
                            for p in range(lstart, lstop)
                        ]

                    # Define task
                    if depth == self.height - 1:
                        # Final Node (Use fused `self.tree_finalize` task)
                        if self.nnewax_partitions:
                            for new_part in range(self.nnewax_partitions):
                                dsk[(self.name, group, new_part)] = self._define_task(
                                    input_keys, 
                                    split_and_size=(new_part, self.newax_partition_size), 
                                    final_task=True
                                )
                        else:
                            dsk[(self.name, group)] = self._define_task(
                                input_keys, final_task=True
                            )
                    else:
                        # Intermediate Node
                        dsk[
                            self._make_key(
                                self.tree_node_name, group, depth
                            )
                        ] = self._define_task(input_keys, final_task=False)
        else:
            # Deal with single-partition case
            for s in self.output_partitions:
                if isinstance(s, tuple):
                    input_keys = [self._make_key(self.name_input, s[0])]
                    dsk[(self.name, *s)] = self._define_task(
                        input_keys, 
                        split_and_size=(s[1], self.newax_partition_size), 
                        final_task=True
                    )
                else:
                    input_keys = [self._make_key(self.name_input, s)]
                    dsk[(self.name, s)] = self._define_task(input_keys, final_task=True)

        return dsk

    def __repr__(self):
        return "TruncatedTreeReduction<name='{}', input_name={}, nnewax_partitions={}>".format(
            self.name, self.name_input, self.nnewax_partitions
        )

    def _output_keys(self):
        return {(self.name, *s) if isinstance(s, tuple) else (self.name, s) for s in self.output_partitions}

    def get_output_keys(self):
        if hasattr(self, "_cached_output_keys"):
            return self._cached_output_keys
        else:
            output_keys = self._output_keys()
            self._cached_output_keys = output_keys
        return self._cached_output_keys

    def is_materialized(self):
        return hasattr(self, "_cached_dict")

    @property
    def _dict(self):
        """Materialize full dict representation"""
        if hasattr(self, "_cached_dict"):
            return self._cached_dict
        else:
            dsk = self._construct_graph()
            self._cached_dict = dsk
        return self._cached_dict

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        # Count the tasks _construct_graph creates, without materializing it:
        # intermediate nodes, plus final nodes (one per new-axis split, if any)
        nsplits = self.nnewax_partitions or 1
        if self.height >= 2:
            return sum(self.widths[1:-1]) + self.widths[-1] * nsplits
        return len(self.output_partitions)

    def _keys_to_output_partitions(self, keys):
        """Simple utility to convert keys to output partition indices."""
        splits = set()
        for key in keys:
            try:
                _name, _split = key
            except ValueError:
                continue
            if _name != self.name:
                continue
            splits.add(_split)
        return splits

    def _cull(self, output_partitions):
        return TruncatedTreeReduction(
            self.name,
            self.name_input,
            self.npartitions_input,
            self.npartitions_stop,
            self.concat_func,
            self.tree_node_func,
            finalize_func=self.finalize_func,
            split_every=self.split_every,
            nnewax_partitions=self.nnewax_partitions,
            newax_partition_size=self.newax_partition_size,
            tree_node_name=self.tree_node_name,
            annotations=self.annotations,
        )

    def cull(self, keys, all_keys):
        """Cull a TruncatedTreeReduction HighLevelGraph layer"""
        deps = {
            (self.name, 0): {
                (self.name_input, i) for i in range(self.npartitions_input)
            }
        }
        output_partitions = self._keys_to_output_partitions(keys)
        if output_partitions != set(self.output_partitions):
            culled_layer = self._cull(output_partitions)
            return culled_layer, deps
        else:
            return self, deps

    def mock(self) -> TruncatedTreeReduction:
        return TruncatedTreeReduction(
            name=self.name,
            name_input=self.name_input,
            npartitions_input=1,
            npartitions_stop=1,
            concat_func=self.concat_func,
            tree_node_func=self.tree_node_func,
            finalize_func=self.finalize_func,
            split_every=self.split_every,
            nnewax_partitions=1 if self.nnewax_partitions else self.nnewax_partitions,
            newax_partition_size=self.newax_partition_size,
            tree_node_name=self.tree_node_name,
        )

In [None]:
import dask.array as da

# make some data
ndata = 125_000_000
chunk_size = 125_000
fill_weighted = True
ndims = 3
data = da.random.normal(size=(ndata, ndims), chunks=(chunk_size, ndims))
weights = da.random.normal(loc=1, scale=0.05, size=(ndata, ), chunks=(chunk_size, ))

In [None]:
# make some bins
nbins = 50
bins = da.linspace(-5, 5, nbins+1)

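# per-axis binning as (lo, hi, nbins) tuples, the format bin_data expects
nd_bins = ndims*[(-5, 5, nbins)]
# number of regular bins, and total bins including one underflow and one overflow bin per axis
nd_nbins = math.prod(i_bin[2] for i_bin in nd_bins)
nd_minlength = math.prod(i_bin[2]+2 for i_bin in nd_bins)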

## Building a histogramming strategy

What we want is a histogram that never fully appears on any particular worker.
This means that we can't create the "dense" histogram and fill it, so we need another way to represent a filled histogram!

- Given an N-dimensional histogram, we can always figure out which bin a piece of data *should* be in.
- We can figure out how many (weighted) entries are in each bin using `np.unique`!
- Once we know the per-bin counts for each slice of our data, we can add those slices together, keeping track of the bin indices.

Since all of this is bookkeeping of bin indices, rather than passing around and filling a histogram object, we can avoid materializing the complete histogram on any particular worker!
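The bookkeeping in the bullet points above can be sketched with plain NumPy on a toy 2D example (the edges, sizes, seed and helper names are illustrative assumptions): each chunk becomes a short list of (bin index, count) pairs, two chunks are added by matching on bin index, and a dense array is built only at the end to check the result against `np.histogram2d`. The notebook's own `sparse_bincount` and `cheap_sparse_sum` do the same job with structured arrays and weights.

```python
import numpy as np

# toy 2D histogram: 4 regular bins per axis on [-2, 2), plus one underflow
# and one overflow bin per axis, i.e. 6 x 6 = 36 bins in total
edges = np.linspace(-2.0, 2.0, 5)
shape = (len(edges) + 1, len(edges) + 1)

def sparse_counts(points):
    """(flat bin indices, counts) for only the bins that received entries."""
    # side="right": 0 is underflow, 1..4 are regular bins, 5 is overflow
    per_axis = [np.searchsorted(edges, points[:, ax], side="right") for ax in range(2)]
    flat = np.ravel_multi_index(per_axis, shape)
    return np.unique(flat, return_counts=True)

def add_sparse(a, b):
    """Add two sparse (indices, counts) pairs by matching on bin index."""
    indices = np.concatenate([a[0], b[0]])
    counts = np.concatenate([a[1], b[1]])
    merged, inverse = np.unique(indices, return_inverse=True)
    total = np.zeros(merged.size, dtype=counts.dtype)
    np.add.at(total, inverse, counts)  # unbuffered add, so repeated bins accumulate
    return merged, total

rng = np.random.default_rng(42)
chunk1, chunk2 = rng.normal(size=(1000, 2)), rng.normal(size=(1000, 2))
indices, counts = add_sparse(sparse_counts(chunk1), sparse_counts(chunk2))

# only now build the dense histogram, to check the bookkeeping
dense = np.zeros(shape, dtype=np.int64)
dense.flat[indices] = counts
expected, _, _ = np.histogram2d(*np.concatenate([chunk1, chunk2]).T, bins=[edges, edges])
print(np.array_equal(dense[1:-1, 1:-1], expected))  # True
```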

In [None]:
# import everything we'll need to build our distributed histogram

import dask
import dask.array as da
import math
import numpy as np

from dask.array.core import Array as dArray
from dask.blockwise import blockwise, BlockIndex
from dask.highlevelgraph import HighLevelGraph

from functools import partial
import numba as nb

nb.config.THREADING_LAYER = 'threadsafe'

In [None]:
counts_dtype = np.dtype(
    [("indices", np.int64), ("counts", np.int64)]
)

weights_dtype = np.float64

weighted_dtype = np.dtype(
    [("indices", np.int64), ("sumw", weights_dtype), ("sumw2", weights_dtype)]
)

def _noop(x):
    return x

# for this simple example bins are in (lo, hi, count) format
# we can extend this later to include all the typical varieties of dense bins
def bin_data(data, bins, **_):
    # treat 1D input as a single column, shape (n, 1)
    to_bin = data if len(data.shape) > 1 else data[:, None]
    shape = []
    binned_data = []
    for ax, (lo, hi, nbins_ax) in enumerate(bins):
        bins_ax = np.linspace(lo, hi, nbins_ax + 1)
        # one underflow (index 0) and one overflow (index nbins_ax + 1) bin per axis
        shape.append(nbins_ax + 2)
        binned_data.append(np.searchsorted(bins_ax, to_bin[:, ax], side="right"))
    # flatten the per-axis bin indices into one global bin index per entry
    return np.ravel_multi_index(binned_data, tuple(shape))

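def cheap_sparse_sum(arrays):
    """Add a list of sparse bin-count arrays (structured arrays with an
    "indices" field plus "counts", or "sumw" and "sumw2"), matching entries
    on bin index. Note: pops from (and so empties) the input list."""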
    out = arrays.pop()
    locs = out["indices"]
    is_weighted = False
    if "counts" in out.dtype.names:
        counts = out["counts"]
    else:
        is_weighted = True
        sumw_w2 = np.stack((out["sumw"], out["sumw2"]), axis=1)
        
    while len(arrays):
        to_add = arrays.pop()
        ilocs = to_add["indices"]
        if is_weighted:
            isumw_w2 = np.stack((to_add["sumw"], to_add["sumw2"]), axis=1)
        else:
            icounts = to_add["counts"]            
        catted_coords = np.r_[locs, ilocs]
        nlocs, inverse = np.unique(catted_coords, return_inverse=True)
        if is_weighted:
            updated_sumw_w2 = np.zeros((nlocs.size, 2), dtype=sumw_w2.dtype)
            updated_sumw_w2[inverse[:locs.size]] += sumw_w2
            updated_sumw_w2[inverse[locs.size:]] += isumw_w2
            sumw_w2 = updated_sumw_w2
        else:
            updated_counts = np.zeros_like(nlocs, dtype=counts.dtype)
            updated_counts[inverse[:locs.size]] += counts
            updated_counts[inverse[locs.size:]] += icounts
            counts = updated_counts
        locs = nlocs
    out = np.empty(
        len(locs), 
        dtype=(
            weighted_dtype if is_weighted
            else counts_dtype
        )
    )
    out["indices"] = locs
    if is_weighted:
        out["sumw"], out["sumw2"] = sumw_w2[:,0], sumw_w2[:, 1]
    else:
        out["counts"] = counts
    return out


@nb.jit(
    [
        nb.void(nb.float32[:], nb.float32[:], nb.int64[:], nb.float32[:]), #float histogram, signed bins
        nb.void(nb.float64[:], nb.float64[:], nb.int64[:], nb.float64[:]), #double histogram, signed bins
        nb.void(nb.float32[:], nb.float32[:], nb.uint64[:], nb.float32[:]), #float histogram, unsigned bins
        nb.void(nb.float64[:], nb.float64[:], nb.uint64[:], nb.float64[:]), #double histogram, unsigned bins
    ],
    nopython=True, 
    parallel=False, 
    nogil=True
)
def fill_sumw(sumw, sumw2, inverse, weights):
    for pos in range(inverse.size):
        idx = inverse[pos]
        weight = weights[pos]
        sumw[idx] += weight
        sumw2[idx] += weight**2


def sparse_bincount(bins, weights=None, **_):
    if weights is None:
        locs, counts = np.unique(bins, return_counts=True)

        out = np.empty(locs.shape, dtype=[("indices", locs.dtype), ("counts", counts.dtype)])
        out["indices"] = locs
        out["counts"] = counts
        
        return out

    locs, inverse = np.unique(bins, return_inverse=True)
    sumw, sumw2 = np.zeros(locs.shape, dtype=weights.dtype), np.zeros(locs.shape, dtype=weights.dtype)
    fill_sumw(sumw, sumw2, inverse, weights)

    out = np.empty(locs.size, dtype=[("indices", locs.dtype), ("sumw", sumw.dtype), ("sumw2", sumw2.dtype)])
    out["indices"] = locs
    out["sumw"] = sumw
    out["sumw2"] = sumw2
    
    return out


@nb.jit(
    [
        nb.void(nb.int64[:], nb.int64[:], nb.int64[:], nb.int64, nb.int64), #integer histogram, signed bins
        nb.void(nb.uint64[:], nb.int64[:], nb.uint64[:], nb.int64, nb.int64), #unsigned histogram, signed bins
        nb.void(nb.float32[:], nb.int64[:], nb.float32[:], nb.int64, nb.int64), #float histogram, signed bins
        nb.void(nb.float64[:], nb.int64[:], nb.float64[:], nb.int64, nb.int64), #double histogram, signed bins
        nb.void(nb.int64[:], nb.uint64[:], nb.int64[:], nb.uint64, nb.uint64), #integer histogram, unsigned bins
        nb.void(nb.uint64[:], nb.uint64[:], nb.uint64[:], nb.uint64, nb.uint64), #unsigned histogram, unsigned bins
        nb.void(nb.float32[:], nb.uint64[:], nb.float32[:], nb.uint64, nb.uint64), #float histogram, unsigned bins
        nb.void(nb.float64[:], nb.uint64[:], nb.float64[:], nb.uint64, nb.uint64), #double histogram, unsigned bins
    ],
    nopython=True, 
    parallel=False, 
    nogil=True,
)
def fill_slice(out, idxs, data, step, size):
    for pos in range(idxs.size):
        idx = idxs[pos]
        if (idx < step + size) and (idx >= step):
            out[idx - step] += data[pos]

@nb.jit(
    [
        nb.void(nb.float32[:], nb.float32[:], nb.int64[:], nb.float32[:], nb.float32[:], nb.int64, nb.int64), #float histogram, signed bins
        nb.void(nb.float64[:], nb.float64[:], nb.int64[:], nb.float64[:], nb.float64[:], nb.int64, nb.int64), #double histogram, signed bins
        nb.void(nb.float32[:], nb.float32[:], nb.uint64[:], nb.float32[:], nb.float32[:], nb.uint64, nb.uint64), #float histogram, unsigned bins
        nb.void(nb.float64[:], nb.float64[:], nb.uint64[:], nb.float64[:], nb.float64[:], nb.uint64, nb.uint64), #double histogram, unsigned bins
    ],
    nopython=True, 
    parallel=False, 
    nogil=True,
)
def fill_slice_weighted(sumw_out, sumw2_out, idxs, sumw, sumw2, step, size):
    for pos in range(idxs.size):
        idx = idxs[pos]
        if (idx < step + size) and (idx >= step):
            sumw_out[idx - step] += sumw[pos]
            sumw2_out[idx - step] += sumw2[pos]


def sparse_to_dense_chunked_bincount(bincounts, index, size, maxbin):
    nb.config.THREADING_LAYER = 'threadsafe'
    # index is this block's (input partition, histogram slice) position;
    # this task fills the dense global bins [step, step_up)
    step = index[1]*size
    step_up = min(step + size, maxbin)
    if isinstance(bincounts, list):
        # several sparse partial sums: add them up before densifying
        bincounts = cheap_sparse_sum(bincounts)
    is_weighted = ("sumw" in bincounts.dtype.names)
    dense_out = (
        np.zeros((step_up - step, 2), dtype=bincounts["sumw"].dtype) if is_weighted
        else np.zeros(step_up - step, dtype=np.int64)
    )
    if is_weighted:
        fill_slice_weighted(
            dense_out[:,0], dense_out[:,1], 
            bincounts["indices"], bincounts["sumw"], bincounts["sumw2"], 
            step, size
        )
    else:
        fill_slice(dense_out, bincounts["indices"], bincounts["counts"], step, size)
    # leading length-1 axis: one block per input partition, summed over later
    return dense_out[None]

In [None]:
# the easiest part: bin the data
binned_sorted = data.map_blocks(bin_data, nd_bins, drop_axis=[1], dtype=np.int64)

In [None]:
# turn the binned data into sparse bin counts
bincounts = binned_sorted.map_blocks(
    sparse_bincount,
    weights if fill_weighted else None,
    chunks=(np.nan,),
    dtype=weighted_dtype if fill_weighted else counts_dtype,
)

In [None]:
# here's how we create the on-cluster partitioned histogram
name = "dense_chunked_bincount"
hist_partitions = 10
n_hist_slices = 25 # number of partial sums the tree reduction stops at (npartitions_stop)
split_every = 8
hist_partition_size = (nd_minlength)//hist_partitions + 1

red_name = "sparse_reduced_bincount"
sparse_hist_reduction_layer = TruncatedTreeReduction(
    name=red_name,
    name_input=bincounts.name,
    npartitions_input=bincounts.npartitions,
    npartitions_stop=n_hist_slices,
    concat_func=cheap_sparse_sum,
    tree_node_func=_noop,
    #finalize_func=sparse_to_dense_chunked_bincount,
    split_every=split_every,
    #nnewax_partitions=hist_partitions,
    #newax_partition_size=hist_partition_size,
)

red_output_parts = len(sparse_hist_reduction_layer.get_output_keys())
red_hlg = HighLevelGraph.from_collections(red_name, sparse_hist_reduction_layer, dependencies=(bincounts, ))
red_hist = dArray(
    red_hlg, 
    red_name,
    chunks=[red_output_parts*(np.nan, )],
    dtype=weighted_dtype if fill_weighted else counts_dtype,
)

The final custom step is to take each partition of summed bin counts and turn that into *slices* of the dense histogram and then sum those slices together.

Here we use a piece of Dask called "blockwise", which allows us to define exactly how to apply a function across the partitions of an input array.
It even allows us to make new axes and provides facilities for knowing which partition of data you're working on.
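The slicing arithmetic can be sketched with plain NumPy (names and numbers are illustrative): slice `k` of the dense histogram covers global bins `[k*size, min((k+1)*size, maxbin))`, as in `sparse_to_dense_chunked_bincount`; each (partial sum, slice) pair becomes one independent dense piece, and summing the pieces over the partial sums plays the role of `dist_hist_sliced.sum(axis=0)`. No single piece is ever larger than one slice.

```python
import numpy as np

def dense_slice(indices, counts, k, slice_size, total_bins):
    """Dense counts for global bins [k*slice_size, min((k+1)*slice_size, total_bins))."""
    start = k * slice_size
    stop = min(start + slice_size, total_bins)
    out = np.zeros(stop - start, dtype=counts.dtype)
    mask = (indices >= start) & (indices < stop)
    # bin indices are unique within one sparse partial sum, so plain += is safe
    out[indices[mask] - start] += counts[mask]
    return out

total_bins, n_slices = 10, 3
slice_size = total_bins // n_slices + 1  # same rounding as hist_partition_size

# two sparse partial sums (bin indices, counts), e.g. outputs of the tree reduction
partials = [
    (np.array([0, 3, 7]), np.array([2, 1, 5])),
    (np.array([3, 8, 9]), np.array([4, 1, 1])),
]

# one piece per (partial, slice) pair; pieces of the same slice are then summed
slices = [
    sum(dense_slice(idx, cnt, k, slice_size, total_bins) for idx, cnt in partials)
    for k in range(n_slices)
]
print([s.tolist() for s in slices])     # [[2, 0, 0, 5], [0, 0, 0, 5], [1, 1]]
print(np.concatenate(slices).tolist())  # [2, 0, 0, 5, 0, 0, 0, 5, 1, 1]
```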

In [None]:
shuffle_layer = blockwise(
    sparse_to_dense_chunked_bincount, 
    name, "ikj" if fill_weighted else "ik", 
    red_hist.name, "i",
    BlockIndex((red_output_parts, hist_partitions)), "ik",
    hist_partition_size, None,
    nd_minlength, None,
    numblocks = {red_hist.name: (red_output_parts,)},
    new_axes = (
        {"k": tuple(hist_partitions*[hist_partition_size]), "j": tuple(hist_partitions*[2])} if fill_weighted
        else {"k": tuple(hist_partitions*[hist_partition_size])}
    ),
)

dist_hist_hlg = HighLevelGraph.from_collections(name, shuffle_layer, dependencies=(red_hist, ))
dist_hist_sliced = dArray(
    dist_hist_hlg, 
    name, 
    chunks=(
        (1, hist_partition_size, 2) if fill_weighted
        else (1, hist_partition_size)
    ),
    shape=(
        (red_hist.npartitions, nd_minlength, 2) if fill_weighted
        else (red_hist.npartitions, nd_minlength)
    ),        
    dtype=(
        weights.dtype if fill_weighted
        else np.int64
    ),
)

dist_hist = dist_hist_sliced.sum(axis=0)

# check it against da.histogramdd
check_hist, _ = da.histogramdd(data, weights=weights if fill_weighted else None, bins=ndims*(nbins,), range=ndims*[(-5,5)])

In [None]:
dist_hist.visualize(optimize_graph=True)

In [None]:
%%time
check = check_hist.compute()

In [None]:
%%time
new_hist = dist_hist.compute()

In [None]:
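# drop the underflow/overflow bins (index 0 and nbins+1 on every axis) before comparing
dense_shape = ndims*(nbins + 2,)
inner = ndims*(slice(1, nbins + 1),)
if fill_weighted:
    # last axis: 0 is the sum of weights, 1 the sum of squared weights
    outcome = np.allclose(new_hist.reshape(dense_shape + (2,))[inner + (0,)], check)
else:
    outcome = np.array_equal(new_hist.reshape(dense_shape)[inner], check)
outcome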