## `yt`, `unyt` and `dask`!

In this notebook, we present an initial attempt at creating a `dask` Custom Collection for `unyt` arrays. As of `yt4.0`, the `yt.units` functionality [was extracted into the `unyt` package](https://yt-project.org/docs/dev/yt4differences.html?highlight=ytarray#yt-units-is-now-a-wrapper-for-unyt). The base [`unyt` arrays](https://unyt.readthedocs.io/en/stable/modules/unyt.array.html) are used throughout `yt` for tracking and converting units, so in order to fully leverage `dask` within `yt`, it would be helpful to have `unyt` arrays with `dask` functionality.

In general, `dask` builds task graphs that operate on `collections`. Some commonly used `collections` are `dask.array`, `dask.dataframe` and `dask.bag`. Each of these builds on the general `DaskMethodsMixin` class, and the dask documentation provides an overview of how to write [custom collections](https://docs.dask.org/en/latest/custom-collections.html#example-dask-collection). So what we want is a new collection, `unyt_dask_array`, that can be used for building and executing dask task graphs in parallel while preserving `unyt_array` functionality.

For the present problem, both `unyt` and `dask.array` wrap numpy methods, so the question is how to cleverly wrap and subclass in order to leverage the existing wrapped methods. The `dask` numpy wrappers already handle the distribution and reduction operations between chunks, so the simpler place to start is by subclassing `dask.array.Array` and adding `unyt` functionality alongside. Tracking units does not have to happen on each chunk, so in the following we track the units and a cumulative unit conversion factor separately within our new `unyt_dask_array` class and only apply the final conversion when returning the result of a `compute`.
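Deferring the conversion works because converting, say, metres to kilometres is multiplication by a positive constant $c$, which commutes with the reductions used in this notebook. Writing $x_{k,i}$ for element $i$ of chunk $k$:

$$
\min_{k}\min_{i}\,(c\,x_{k,i}) = c\,\min_{k}\min_{i}\,x_{k,i}, \qquad
\sum_{k}\sum_{i} c\,x_{k,i} = c\sum_{k}\sum_{i} x_{k,i},
$$

and likewise for $\max$ (the $\min$ and $\max$ identities require $c > 0$). Not every operation behaves this way: the variance picks up a factor of $c^2$, and conversions with an offset, such as degrees Celsius to kelvin, are not a pure rescaling, so a single multiplicative factor would not be enough for those.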

So let's just dump some code:

```python
from unyt.array import unyt_array

from dask.array.core import Array, finalize
import numpy as np

def unyt_from_dask(dask_array,
                   units=None,
                   registry=None,
                   dtype=None,
                   bypass_validation=False,
                   input_units=None,
                   name=None):
    # rebuild the existing dask array as a unyt_dask_array from its
    # constructor arguments, then attach the units
    (cls, args) = dask_array.__reduce__()
    da = unyt_dask_array(*args)
    da._attach_units(units, registry, dtype, bypass_validation, input_units, name)
    return da

def finalize_unyt(results, unit_name, factor):
    # the function to call for the __dask_postcompute__ hook.
    return unyt_array(finalize(results) * factor, unit_name)

class unyt_dask_array(Array):
    def __init__(self, dask_graph, name, chunks, dtype=None, meta=None, shape=None):
        self.units = None
        self.unyt_name = None
        self.dask_name = name
        self.factor = 1.

    def _attach_units(self, units=None,
                      registry=None,
                      dtype=None,
                      bypass_validation=False,
                      input_units=None,
                      name=None):
        # a hidden single-element unyt_array used only to track units
        x_np = np.array([1.])
        self._unyt_array = unyt_array(x_np, units, registry, dtype, bypass_validation, input_units, name)
        self.units = self._unyt_array.units
        self.unyt_name = self._unyt_array.name

    def to(self, units, equivalence=None, **kwargs):
        # tracks any time units are converted with a running conversion factor
        # that gets applied after calling dask methods
        init_val = self._unyt_array.value[0]
        self._unyt_array = self._unyt_array.to(units, equivalence, **kwargs)
        self.factor = self.factor * self._unyt_array.value[0] / init_val
        self.units = units
        self.unyt_name = self._unyt_array.name

    def min(self, axis=None, keepdims=False, split_every=None, out=None):
        result = np.array(super().min(axis, keepdims, split_every, out))
        return unyt_array(result * self.factor, self.units)

    def max(self, axis=None, keepdims=False, split_every=None, out=None):
        result = np.array(super().max(axis, keepdims, split_every, out))
        return unyt_array(result * self.factor, self.units)

    def __dask_postcompute__(self):
        # a dask hook to catch after .compute(), see
        # https://docs.dask.org/en/latest/custom-collections.html#example-dask-collection
        # but it does not catch all computes?
        return finalize_unyt, (self.units, self.factor)
```
Above, we define three objects: the `unyt_from_dask` function, the `finalize_unyt` function and the `unyt_dask_array` class. Let's focus on the new class first.

In our new subclass, `unyt_dask_array(Array)`, `Array` is the core array class of `dask`. This class defines only a `__new__` constructor (no `__init__`), so the `__init__` here simply accepts the same arguments that get sent to `__new__`:


```python
def __init__(self, dask_graph, name, chunks, dtype=None, meta=None, shape=None):
    self.units = None
    self.unyt_name = None
    self.dask_name = name
    self.factor = 1.
```

These are exactly the arguments needed by the base `Array.__new__` constructor. When we instantiate `unyt_dask_array`, Python first calls the inherited `Array.__new__` with those arguments and then calls `unyt_dask_array.__init__()` with the same arguments.
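The call order can be checked with plain Python. The sketch below uses two toy classes (the names `NewOnlyBase`, `InitSub` and `call_log` are hypothetical, not part of `dask`) where the base class does all its setup in `__new__`, as `Array` does, and the subclass adds attributes in `__init__`:

```python
call_log = []

class NewOnlyBase:
    # Toy stand-in for dask.array.Array: all setup happens in __new__; no __init__.
    def __new__(cls, name, chunks):
        call_log.append(("NewOnlyBase.__new__", cls.__name__))
        obj = super().__new__(cls)
        obj.name = name
        obj.chunks = chunks
        return obj

class InitSub(NewOnlyBase):
    # Same signature as NewOnlyBase.__new__: Python passes the constructor
    # arguments to both __new__ and __init__.
    def __init__(self, name, chunks):
        call_log.append(("InitSub.__init__", self.name))
        self.units = None
        self.factor = 1.0

s = InitSub("x", (2, 2))
print(call_log)
# [('NewOnlyBase.__new__', 'InitSub'), ('InitSub.__init__', 'x')]
```

Note that `__new__` receives the subclass (`cls` is `InitSub`), so the object it returns is already an `InitSub` by the time `__init__` runs.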

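These arguments are all related to the details of how dask constructs its graphs and chunks, but we want to be able to instantiate our `unyt_dask_array` more simply. Thus, the convenience function `unyt_from_dask` builds our new `Array` subclass from an existing `dask` array: `dask_array.__reduce__()` returns the constructor arguments of the existing array (its graph, name, chunks, dtype and meta), which we pass straight to `unyt_dask_array` before attaching units. This way, we do not need to know the details of how `dask` works.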
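The same "rebuild a subclass from `__reduce__`" pattern can be sketched without `dask` installed. Here `PicklableBase`, `UnitSub` and `sub_from_base` are hypothetical toy names; `PicklableBase` stands in for `Array` and `sub_from_base` mirrors `unyt_from_dask`:

```python
class PicklableBase:
    # Toy stand-in for dask.array.Array: constructed entirely in __new__.
    def __new__(cls, graph, name, chunks):
        obj = super().__new__(cls)
        obj.graph, obj.name, obj.chunks = graph, name, chunks
        return obj

    def __reduce__(self):
        # Pickle protocol: a callable plus the args that rebuild the object.
        return (PicklableBase, (self.graph, self.name, self.chunks))

class UnitSub(PicklableBase):
    def __init__(self, graph, name, chunks):
        self.units = None

    def attach_units(self, units):
        self.units = units

def sub_from_base(base_obj, units):
    # Mirrors unyt_from_dask: ignore the class from __reduce__, reuse its args.
    _cls, args = base_obj.__reduce__()
    new_obj = UnitSub(*args)
    new_obj.attach_units(units)
    return new_obj

b = PicklableBase({"x-0": 1.0}, "x", ((2,), (2,)))
u = sub_from_base(b, "m")
print(type(u).__name__, u.name, u.units)  # UnitSub x m
```

Because only the arguments are reused, the new object shares the original graph rather than copying it, so the conversion is cheap.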

So, for example, we can do: 

```python
import numpy as np
import unyt
import dask.array as da

x = da.random.random((10000, 10000), chunks=(1000, 1000))
x_da = unyt_from_dask(x, unyt.m)
x_da
```

|       | Array          | Chunk         |
|-------|----------------|---------------|
| Bytes | 800.00 MB      | 8.00 MB       |
| Shape | (10000, 10000) | (1000, 1000)  |
| Count | 100 Tasks      | 100 Chunks    |
| Type  | float64        | numpy.ndarray |


which behaves as a `dask` array. For example, if we load everything into memory:

```python
x_da.compute()
```

```
unyt_array([[0.84873849, 0.49570262, 0.84942177, ..., 0.76667334,
             0.88406054, 0.06651547],
            [0.23614276, 0.2691459 , 0.7563452 , ..., 0.82625261,
             0.79109846, 0.09371664],
            [0.85376455, 0.64940279, 0.13865977, ..., 0.06566836,
             0.46294204, 0.79582689],
            ...,
            [0.67687777, 0.02797173, 0.58873372, ..., 0.58259256,
             0.69760201, 0.88294463],
            [0.47230504, 0.9908476 , 0.09871617, ..., 0.58093535,
             0.31079139, 0.81819524],
            [0.30721092, 0.38728382, 0.51681618, ..., 0.55783229,
             0.66753928, 0.17387946]], 'm')
```

we get a `unyt_array`! This happens because the `unyt_dask_array.__dask_postcompute__()` hook tells dask to call the `finalize_unyt` function after we execute `compute`. This function simply calls the `finalize` function used by the core dask `Array`, scales the resulting numpy array by `factor` and initializes a `unyt_array` with it:

```python
def finalize_unyt(results, unit_name, factor):
    # the function to call for the __dask_postcompute__ hook.
    return unyt_array(finalize(results) * factor, unit_name)
```

The `factor` is a cumulative unit conversion factor, [described below](#tracking-units).
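Per the dask custom-collection documentation linked above, `__dask_postcompute__` returns a `(finalize, extra_args)` pair, and dask calls `finalize(results, *extra_args)` once the chunk results are computed. The toy below imitates only that call sequence in pure Python; `ToyCollection`, `toy_compute` and `finalize_with_units` are hypothetical names, and `toy_compute` is not a real scheduler:

```python
def finalize_with_units(results, unit_name, factor):
    # Plays the role of finalize_unyt: combine chunk results, scale once, tag units.
    flat = [value for chunk in results for value in chunk]
    return ([value * factor for value in flat], unit_name)

class ToyCollection:
    def __init__(self, chunks, units, factor=1.0):
        self.chunks = chunks
        self.units = units
        self.factor = factor

    def __dask_postcompute__(self):
        # Same shape as unyt_dask_array.__dask_postcompute__
        return finalize_with_units, (self.units, self.factor)

def toy_compute(collection):
    # "Compute" each chunk, then hand the results to the collection's
    # finalize function together with its extra arguments.
    results = [list(chunk) for chunk in collection.chunks]
    finalize, extra_args = collection.__dask_postcompute__()
    return finalize(results, *extra_args)

c = ToyCollection([(1.0, 2.0), (3.0,)], "cm", factor=100.0)
print(toy_compute(c))  # ([100.0, 200.0, 300.0], 'cm')
```

The factor is applied once, to the combined result, rather than inside each chunk's task.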

One caveat is that the `__dask_postcompute__` hook does not always fire. For example:

```python
x_da.sum()
```

|       | Array     | Chunk         |
|-------|-----------|---------------|
| Bytes | 8 B       | 8 B           |
| Shape | ()        | ()            |
| Count | 239 Tasks | 1 Chunks      |
| Type  | float64   | numpy.ndarray |


```python
x_da.sum().compute()
```

```
50002711.294516616
```

Here the hook does not fire, and we just get a bare number rather than a `unyt_quantity`. We do still have the units:

```python
x_da.units
```

```
m
```

but ideally we'd be returning a `unyt_array` or `unyt_quantity` here. 
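A likely explanation, stated here as an assumption rather than something verified against every dask version, is that dask's reductions build their result with the base `Array` constructor rather than with `type(self)`, so `x_da.sum()` is a plain `dask.array.Array` carrying the default `__dask_postcompute__` (checking `type(x_da.sum())` would confirm this). The effect is easy to reproduce with toy classes; `Base` and `WithUnits` are hypothetical names:

```python
class Base:
    def __init__(self, values):
        self.values = list(values)

    def total(self):
        # Builds the result with the *base* class, so subclasses are dropped
        return Base([sum(self.values)])

    def total_preserving(self):
        # Builds the result with type(self), so the subclass survives
        return type(self)([sum(self.values)])

class WithUnits(Base):
    units = "m"  # class attribute, for brevity

w = WithUnits([1.0, 2.0, 3.0])
print(type(w.total()).__name__)             # Base
print(type(w.total_preserving()).__name__)  # WithUnits
```

Even `type(self)` would not carry per-instance state such as `units` and `factor`, which is why the `min`/`max` overrides in this notebook re-attach them by hand.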

To catch these operations, we can override the corresponding methods in our subclass. For example, the `unyt_dask_array` class has the following `min` and `max` methods, which call the corresponding superclass methods, convert the result to a standard `ndarray` (calling `np.array` on a dask array triggers the computation) and then return a `unyt_array`:

```python
    def min(self, axis=None, keepdims=False, split_every=None, out=None):
        result = np.array(super().min(axis, keepdims, split_every, out))
        return unyt_array(result*self.factor, self.units)

    def max(self, axis=None, keepdims=False, split_every=None, out=None):
        result = np.array(super().max(axis, keepdims, split_every, out))
        return unyt_array(result*self.factor, self.units)
```    

So when we do:

```python
x_da.min()
```

```
unyt_array(1.56775319e-08, 'm')
```

we get our standard `unyt_array`. 


### tracking units 

We're also using a bit of trickery to deal with unit conversions in this class within the `unyt_dask_array.to` method, copied here:

```python
def to(self, units, equivalence=None, **kwargs):
    # tracks any time units are converted with a running conversion factor
    # that gets applied after calling dask methods
    init_val = self._unyt_array.value[0]
    self._unyt_array = self._unyt_array.to(units, equivalence, **kwargs)
    self.factor = self.factor * self._unyt_array.value[0] / init_val
    self.units = units
    self.unyt_name = self._unyt_array.name
```

so within our `unyt_dask_array`, we initialize a hidden `unyt_array` with a single value of `1`. Any time a conversion occurs, we apply the conversion to this `self._unyt_array` and update a cumulative conversion factor. Whenever we return calculations from `dask` into memory, we simply multiply by this conversion factor and attach the appropriate units. Note that, unlike `unyt_array.to`, this `to` modifies the object in place and returns `None`, which is why the conversion cells below produce no output. There is likely a more elegant solution here, but the basic idea is to track the units separately from the dask functionality, since the dask operations on each chunk are generally independent of the unit scaling.
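Stripped of `dask` and `unyt`, the bookkeeping in `to` amounts to the sketch below. The `TO_METERS` table and the `LazyLength` class are hypothetical stand-ins (a real implementation would ask `unyt` for the conversion), and a plain list plays the role of the chunked data:

```python
import math

# Hypothetical conversion table standing in for unyt's unit registry
TO_METERS = {"m": 1.0, "cm": 1e-2, "um": 1e-6, "nm": 1e-9, "km": 1e3}

class LazyLength:
    def __init__(self, data, units):
        self.data = data      # never touched by unit conversions
        self.units = units
        self._hidden = 1.0    # the single hidden value, starts at 1 in `units`
        self.factor = 1.0     # cumulative conversion factor

    def to(self, units):
        # Convert only the hidden value, then update the running factor
        # (mirrors unyt_dask_array.to: modifies in place, returns None).
        init_val = self._hidden
        self._hidden = self._hidden * TO_METERS[self.units] / TO_METERS[units]
        self.factor = self.factor * self._hidden / init_val
        self.units = units

    def min(self):
        # Reduce the raw data first, then apply the factor once
        return (min(self.data) * self.factor, self.units)

x = LazyLength([2.0, 0.5, 3.0], "m")
for u in ("cm", "um", "km"):
    x.to(u)
value, units = x.min()
print(units, math.isclose(value, 0.5e-3))  # km True
```

Because the hidden value starts at 1, the running product telescopes and `factor` always equals the hidden value itself; the ratio form simply mirrors the notebook's code.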

So here's an example conversion:

```python
x_da.to(unyt.km)
```

```python
x_da.units
```

```
km
```

```python
x_da.min()
```

```
unyt_array(1.56775319e-11, 'km')
```

```python
x_da.to(unyt.nanometer)
```

```python
x_da.min()
```

```
unyt_array(15.67753194, 'nm')
```

from which we see our units changing appropriately. This approach is nice because when we convert multiple times: 


```python
x_da.to(unyt.cm)
x_da.to(unyt.micrometer)
x_da.to(unyt.km)
```

we are only operating on the hidden `_unyt_array` to track what the final conversion factor should be. In this case, we end up back in kilometers, the units of our first conversion, and recover the same result as before:

```python
x_da.min()
```

```
unyt_array(1.56775319e-11, 'km')
```

The difficulty with this method is that we want to avoid manually overriding the many methods in the base `Array` collection; there is likely a clever way to automatically wrap all the methods. Furthermore, we are not using `unyt`'s architecture much at all here, and things may become more complicated, for example, when multiplying two `unyt_dask_array` objects together. But this seems a good start!
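One possible way to avoid writing each override by hand is to generate the wrappers in a loop and attach them with `setattr`. This is sketched here on toy classes rather than on `dask.array.Array`, so `ChunkedArray`, `Quantity`, `UnitChunkedArray` and `_with_units` are all hypothetical:

```python
from dataclasses import dataclass

@dataclass
class Quantity:
    value: float
    units: str

class ChunkedArray:
    """Toy stand-in for dask.array.Array; reductions return bare numbers."""
    def __init__(self, chunks):
        self.chunks = chunks

    def min(self):
        return min(min(c) for c in self.chunks)

    def max(self):
        return max(max(c) for c in self.chunks)

    def sum(self):
        return sum(sum(c) for c in self.chunks)

class UnitChunkedArray(ChunkedArray):
    def __init__(self, chunks, units, factor=1.0):
        super().__init__(chunks)
        self.units = units
        self.factor = factor

def _with_units(method_name):
    # Build a method that calls the base-class reduction, then applies the
    # cumulative factor and attaches units to the (small) reduced result.
    def wrapper(self, *args, **kwargs):
        result = getattr(ChunkedArray, method_name)(self, *args, **kwargs)
        return Quantity(result * self.factor, self.units)
    wrapper.__name__ = method_name
    return wrapper

# Only reductions that commute with a positive rescaling are wrapped.
for _name in ("min", "max", "sum"):
    setattr(UnitChunkedArray, _name, _with_units(_name))

a = UnitChunkedArray([[1.0, 2.0], [3.0, 4.0]], "cm", factor=100.0)
print(a.min(), a.max(), a.sum())
```

The list of names has to be curated: `var`, for instance, scales with the square of the factor, and methods that return new lazy arrays rather than reduced values would instead need to return a new unit-aware array with the units and factor copied over.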

### parallel!

Now, because our `unyt_dask_array` class is built on a `dask` collection, it works with dask's parallel schedulers, including `dask.distributed`.

So let's spin up a client:

```python
from dask.distributed import Client
client = Client(threads_per_worker=2, n_workers=2)
```

```python
client
```

```
Client:  Scheduler: tcp://127.0.0.1:33059  Dashboard: http://127.0.0.1:8787/status
Cluster: Workers: 2  Cores: 4  Memory: 33.51 GB
```


and re-instantiate our arrays:

```python
x = da.random.random((10000, 10000), chunks=(1000, 1000))
x_da = unyt_from_dask(x, unyt.m)
x_da
```

|       | Array          | Chunk         |
|-------|----------------|---------------|
| Bytes | 800.00 MB      | 8.00 MB       |
| Shape | (10000, 10000) | (1000, 1000)  |
| Count | 100 Tasks      | 100 Chunks    |
| Type  | float64        | numpy.ndarray |


Now when we compute properties, `dask` processes the chunks independently. So when we take a min:

```python
x_da.min()
```

```
unyt_array(1.47933353e-08, 'm')
```

on our `dask` dashboard, we can see the distributed tasks complete via the Task Graph:


![TaskStream](resources/unyt_dask_taskgraph.png)


In the above example for `min` and `max`, we are converting the result from the superclass call to a standard `ndarray` as the results here will generally be small enough to be held in memory, even when returning an array using the `axis` argument:

```python
x_da.max(axis=0)
```

```
unyt_array([0.99986894, 0.99992408, 0.99952803, ..., 0.99995384,
            0.99989531, 0.99997846], 'm')
```

As noted above, when we convert units, we don't actually touch the dask array chunks but track a cumulative conversion factor using a hidden `_unyt_array` with a single value. So when we convert:

```python
x_da.to(unyt.cm)
x_da.to(unyt.micrometer)
```

nothing gets distributed to the dask chunks. The conversion factor is only applied to the reduced result, **after** the reduction step:

```python
x_da.max(axis=0)
```

```
unyt_array([999868.94269632, 999924.07710443, 999528.02896929, ...,
            999953.84159969, 999895.31102362, 999978.45862321], 'μm')
```

So in general, this method shows a fairly straightforward approach for adding dask support to `unyt`, which can in turn be leveraged by `yt`. The main work remaining is to devise a more clever way to wrap the `dask.array` methods to attach the final units, in order to avoid manually overriding each method.