## Leveraging dask to perform chunk-wise operations across stacks of images

The goal of this notebook is to produce a minimal working example (MWE) of a chunk-wise median across a stack of data. The idea is to combine a collection of 2D arrays into a single 3D array stack, chunked only along the two image axes so that every chunk contains the full length of the axis we compute the median across.
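A minimal sketch of the idea, assuming a small random stack of four 6 x 6 "images" (the sizes and the `stack`/`dstack` names are illustrative only). Because axis 0 is never split, each chunk holds every value needed for its pixels, so a per-chunk median along that axis is already the exact answer.

```python
import numpy as np
import dask.array as da

# Illustrative stack of 4 small "images", each 6 x 6 pixels
stack = np.random.default_rng(0).normal(size=(4, 6, 6))

# Chunk only the image axes; axis 0 (the stack axis) stays whole in every chunk
dstack = da.from_array(stack, chunks=(4, 3, 3))

# Median along axis 0 inside each chunk; drop_axis=0 tells dask the output is 2D
med = da.map_blocks(np.median, dstack, axis=0, drop_axis=0).compute()

print(med.shape)                                   # (6, 6)
print(np.allclose(med, np.median(stack, axis=0)))  # True
```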

In [None]:
%load_ext memory_profiler
import time
import warnings
warnings.filterwarnings('ignore') # sshhhhh

import dask.array as da
import numpy as np


### Step 1) Create a toy array to work with
Let's make a data cube with dimensions (9, 3, 3). We start from 81 evenly spaced values from 1 to 81 (floats, generated with <code>np.linspace</code>) and reshape them to (9, 3, 3). The cell below also defines the helper functions used throughout the notebook.
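As a quick worked example of that layout (the `values`/`cube` names are just for illustration), reshaping in NumPy's default C order fills the last axis fastest. Each `cube[i]` is one 3 x 3 "image", and a single pixel read through the stack steps by 9. Because the values grow linearly along axis 0 and there is an odd number (9) of frames, the median across the stack is simply the middle frame, `cube[4]`, which makes the toy data easy to check by eye.

```python
import numpy as np

# 81 evenly spaced values 1.0, 2.0, ..., 81.0 (np.linspace returns floats)
values = np.linspace(1, 81, 81)
cube = values.reshape((9, 3, 3))

# One pixel through the stack: 1, 10, 19, ..., 73 (a step of 9 = 3 * 3)
pixel = cube[:, 0, 0]

# Linear values along axis 0 with 9 frames -> the median is the middle frame
median_frame = np.median(cube, axis=0)
print(np.array_equal(median_frame, cube[4]))  # True
```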

In [None]:
def compare(dask_ans, numpy_ans):
    # Element-wise check; summing the differences could hide errors that cancel out
    if np.allclose(dask_ans, numpy_ans):
        print('PASS')
    else:
        print('FAIL')

def duration(st, et, convert=True):
    # Elapsed time in minutes (convert=True) or seconds (convert=False)
    deltat = et - st
    units = 'seconds'
    if convert:
        deltat /= 60
        units = 'minutes'
    print(f"Time to compute median: {deltat:.2f} {units}")
    return deltat, units
        
def da_median(da_array):
    st = time.time()
    da_med = da.map_blocks(np.median, da_array, axis=0, drop_axis=0).compute()
    et = time.time()
    runtime = duration(st, et)
    return da_med, runtime

def np_median(np_array):
    st = time.time()
    np_med = np.median(np_array, axis=0)
    et = time.time()
    runtime = duration(st, et)
    return np_med, runtime

def npstats(np_array):
    # Print the in-memory size of a NumPy array in kB, MB, or GB
    nbytes = np_array.nbytes
    units = 'kB'
    conversion = 1e3
    if 1e6 <= nbytes < 1e9:
        units = 'MB'
        conversion = 1e6
    elif nbytes >= 1e9:
        units = 'GB'
        conversion = 1e9

    array_mem_size = nbytes / conversion
    print(f"Size of numpy array {array_mem_size:.2f} {units}")
    return array_mem_size
    
    
def make_array(use_dask=True, shape=(9, 3, 3), chunksize=(9, 1, 1)):
    # Fill the requested shape with evenly spaced floats 1, 2, ..., nsamp
    nsamp = int(np.prod(shape))
    final = np.linspace(1, nsamp, nsamp).reshape(shape)

    if use_dask:
        # Wrap the NumPy array as a chunked dask array
        final = da.from_array(final, chunks=chunksize)

    return final

Use the helper function <code>make_array</code> to generate <code>numpy</code> and <code>dask</code> arrays with the default size and shape.

In [None]:
np_array = make_array(use_dask=False)
da_array = make_array(use_dask=True)

<code>dask</code> arrays are evaluated lazily, whereas <code>numpy</code> arrays are always held in memory. Each operation applied to a <code>dask</code> array is only recorded as a task; nothing is computed until the result is explicitly requested, for example with <code>.compute()</code>.
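A small sketch of this behaviour on a made-up 3 x 4 array (not the notebook's data): building an expression returns another lazy `dask` array, and only `.compute()` produces an in-memory `numpy` result.

```python
import numpy as np
import dask.array as da

x = da.from_array(np.arange(12.0).reshape(3, 4), chunks=(3, 2))

# Building an expression only records tasks; nothing is evaluated yet
y = (x + 1).sum(axis=0)
print(type(y))        # <class 'dask.array.core.Array'>

# .compute() executes the task graph and returns an in-memory NumPy array
result = y.compute()
print(type(result))   # <class 'numpy.ndarray'>
print(result)         # [15. 18. 21. 24.]
```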

In [None]:
a = da_array[:, 1, 0]
b = np_array[:, 2, 2]
c = da_array[:, 2, 2].compute()
d = da_array[:, 0, 0]
e = da_array[:, 0, :2]
print(a, '\n', b, '\n', c)

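<code>dask</code> also provides two very handy visualizations. The first is an HTML representation of the array object you have created, complete with _very_ useful metadata such as the shape, chunk size, and memory footprint. The second is a static image of the object's task graph, produced by <code>.visualize()</code> (this requires graphviz to be installed).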
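The same metadata can also be read programmatically from array attributes, which is handy outside a notebook. This sketch rebuilds the default toy cube directly rather than calling `make_array`, so it runs on its own:

```python
import numpy as np
import dask.array as da

# Same values and chunking as make_array() with its default arguments
x = da.from_array(np.linspace(1, 81, 81).reshape(9, 3, 3), chunks=(9, 1, 1))

print(x.chunks)      # ((9,), (1, 1, 1), (1, 1, 1))
print(x.numblocks)   # (1, 3, 3) -> 9 chunks in total
print(x.chunksize)   # (9, 1, 1)
print(x.nbytes)      # 648 (81 float64 values x 8 bytes)
```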

In [None]:
da_array

In [None]:
da_array[0].compute()

In [None]:
npstats(np_array)

In [None]:
a.compute()

In [None]:
a.visualize()

In [None]:
d.visualize()

In [None]:
e.visualize()

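We can use the <code>map_blocks</code> function to apply any function we like to each chunk of the array. Here we apply <code>np.median</code> along axis 0. Because every chunk spans the full length of that axis, the per-chunk medians are the exact medians, and <code>drop_axis=0</code> tells <code>dask</code> that the output no longer has that axis.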
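To illustrate the "any function" point, here is a sketch with a hypothetical user-defined function, `stack_std` (not part of the notebook), that returns the standard deviation through the stack. The output chunks are the input chunks with axis 0 removed, so each (9, 1, 1) chunk yields one (1, 1) output block.

```python
import numpy as np
import dask.array as da

def stack_std(block, axis=0):
    # Hypothetical example function: spread of each pixel through the stack
    return block.std(axis=axis)

stack = np.linspace(1, 81, 81).reshape(9, 3, 3)
dstack = da.from_array(stack, chunks=(9, 1, 1))

# Extra keyword arguments (axis=0) are forwarded to stack_std for every block
out = da.map_blocks(stack_std, dstack, axis=0, drop_axis=0)
print(out.chunks)    # ((1, 1, 1), (1, 1, 1))

result = out.compute()
print(np.allclose(result, np.std(stack, axis=0)))  # True
```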

In [None]:
da_med = da.map_blocks(np.median, da_array, axis=0, drop_axis=0)
print(da_med)

In [None]:
da_med.compute()

In [None]:
da_med.visualize()

#### <code>dask</code> results:

In [None]:
%%timeit
_ = da_med.compute()

<hr>

#### <code>numpy</code> results:

In [None]:
%%timeit
_ = np.median(np_array, axis=0)

<hr>

Note that for a very small dataset, <code>dask</code> performs much worse than <code>numpy</code> because of its scheduling overhead. When we increase the array size to mimic a stack of 5 ACS images (here 5 x 1024 x 1024 pixels), the chunked method with <code>dask</code> scales much better than <code>numpy</code>.
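Part of that overhead is per-task bookkeeping, and the number of tasks follows from the chunk sizes: each axis contributes ceil(axis length / chunk length) chunks. A worked check for the (5, 70, 300) chunking used below, built on a lazy `da.zeros` array so no data is allocated:

```python
import math
import dask.array as da

shape = (5, 1024, 1024)
chunks = (5, 70, 300)

# Chunks per axis = ceil(axis_length / chunk_length)
expected = tuple(math.ceil(n / c) for n, c in zip(shape, chunks))
print(expected)          # (1, 15, 4) -> 60 chunks

x = da.zeros(shape, chunks=chunks)  # lazy: nothing is allocated yet
print(x.numblocks)       # (1, 15, 4)

# Edge chunks are smaller: 1024 - 14 * 70 = 44 rows and 1024 - 3 * 300 = 124 columns
print(x.chunks[1][-1], x.chunks[2][-1])  # 44 124
```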

In [None]:
np_example = make_array(use_dask=False, shape=(5, 1024, 1024))

In [None]:
da_example1 = make_array(
    use_dask=True, 
    shape=(5, 1024, 1024),
    chunksize=(5, 70, 300)
)

In [None]:
da_example = make_array(
    use_dask=True, 
    shape=(5, 1024, 1024),
    chunksize=(5, 70, 300)
)

In [None]:
a = (da_example1[0] - da_example[0]).sum()

In [None]:
a.visualize(optimize_graph=True)

In [None]:
a.compute(optimize_graph=True)

In [None]:
da_example

In [None]:
npstats(np_example)

In [None]:
from dask.diagnostics import ResourceProfiler


In [None]:
with ResourceProfiler(dt=0.5) as rprof:
    np_med, np_time = np_median(np_example)

In [None]:
rprof.visualize(filename='np_array.html')

In [None]:
with ResourceProfiler(dt=0.5) as da_rprof:
    da_med, da_time = da_median(da_example)

In [None]:
da_rprof.visualize(filename='da_array.html')

In [None]:
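def run_numpy_test(narrays_list, frame_shape=(1024, 1024)):
    # Time np.median for stacks containing different numbers of frames.
    # frame_shape and the narrays_list argument are assumptions for this sketch.
    datadict = {'narrays': [], 'runtime': []}
    for narrays in narrays_list:
        _test_array = make_array(use_dask=False, shape=(narrays, *frame_shape))
        med, (runtime, units) = np_median(_test_array)
        datadict['narrays'].append(narrays)
        datadict['runtime'].append(runtime)
    return datadict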

In [None]:
a = da_example.rechunk('auto')

### Implementing with the reductions module
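Before writing a median with <code>da.reduction</code>, it helps to see why the chunk layout matters. The sketch below (random illustrative data, axis 0 deliberately split into three chunks) uses `da.reduction` with `np.sum`, which is decomposable: summing per-chunk sums gives the exact total. A median is not decomposable, since a median of per-chunk medians is generally not the true median. `dask`'s built-in `da.median` avoids this by first rechunking the reduced axis into a single chunk.

```python
import numpy as np
import dask.array as da

stack = np.random.default_rng(1).normal(size=(6, 4, 4))
x = da.from_array(stack, chunks=(2, 2, 2))  # axis 0 split into three chunks

# chunk=np.sum runs on every block, aggregate=np.sum combines the partial sums
total = da.reduction(x, np.sum, np.sum, axis=0, dtype=x.dtype)
print(np.allclose(total.compute(), stack.sum(axis=0)))  # True

# Built-in median: rechunks axis 0 to one chunk, then applies np.median
med = da.median(x, axis=0)
print(np.allclose(med.compute(), np.median(stack, axis=0)))  # True
```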

In [None]:
from functools import wraps

In [None]:
da.reduction

In [None]:
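# example_chunked is not defined earlier in this notebook; as an assumption, use
# the chunked toy cube, whose (9, 1, 1) chunks give (1, 1) output chunks
example_chunked = da_array
da.map_blocks(np.median, example_chunked, axis=0, chunks=(1, 1), drop_axis=0)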

In [None]:
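def _chunk_median(a, axis=None, keepdims=False, **kwargs):
    # Block-level median used by da.reduction for both the per-chunk step and the
    # final aggregation; extra keywords dask may pass (e.g. computing_meta) are ignored.
    # Renamed from da_median so the timing helper defined earlier is not shadowed.
    # NOTE: a median of medians is only exact when the reduced axis is not split
    # across chunks, e.g. chunks=(9, 1, 1) when reducing over axis=0.
    return np.median(a, axis=axis, keepdims=keepdims)

In [None]:
@wraps(np.median)
def median(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # np.median returns floating-point results even for integer input
    if dtype is not None:
        dt = dtype
    else:
        dt = np.median(np.ones(1, dtype=a.dtype)).dtype

    result = da.reduction(
        a,
        _chunk_median,  # chunk: applied to every block
        _chunk_median,  # aggregate: applied to the concatenated block results
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
    return result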

In [None]:
example_chunked