### Processing a lot of images using dask

This introduction to dask shows how to perform the basic preprocessing operations common in the analysis of spatiotemporal imaging data.

See also this example from the master of dask: http://matthewrocklin.com/blog/work/2017/01/17/dask-images.

`dask.array` docs: http://dask.pydata.org/en/latest/array.html  
`dask.distributed` docs: http://distributed.readthedocs.io/en/latest/ 

In [None]:
import warnings
warnings.filterwarnings('ignore')
from dask.distributed import Client, LocalCluster
import dask.array as da
from h5py import File

In [None]:
def make_fake_h5_data(dims, directory=None):
    """
    Save a synthetic n-dimensional dataset as multiple .h5 files, one file per index
    along the first axis. Returns a list of the filenames created.
    """

    from numpy.random import randn
    from h5py import File
    from tempfile import mkdtemp
    from os.path import join

    if directory is None:
        directory = mkdtemp()

    fname_spec = 'ind_{0:05d}.h5'
    fnames = [join(directory, fname_spec.format(ind)) for ind in range(dims[0])]

    for fn in fnames:
        data = randn(*dims[1:])
        # open in write mode explicitly: recent h5py versions open files read-only by default
        with File(fn, 'w') as f:
            f['default'] = data

    return fnames

Here we make some fake 4D data and save them to disk. Even though the data are synthetic, I will refer to the first axis as "time" and the latter axes as "space". Our data starts out in the typical raw layout for 4D microscopy data: one file per timepoint, each containing a 3D volume.

In [None]:
# 10 timepoints, each a 100 x 100 x 100 volume
dims = (10, 100, 100, 100)
fnames = make_fake_h5_data(dims)

When we make a dask array, we need to specify how the data (and therefore the computations on them) will be divided up. This is done with the `chunks` argument to `dask.array.from_array()`. If `chunks` is a tuple of ints, e.g. `(1, 3, 3, 3)`, our data will be logically divided into sub-arrays each with the size `(1, 3, 3, 3)`; the last chunk along an axis is smaller when the axis length is not a multiple of the chunk size.

If the first axis of our data is time, this chunking scheme (one timepoint per chunk) is natural for spatial operations on each image, like spatial filtering. However, for timeseries operations like baseline normalization, we will need a different chunking arrangement.
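A quick way to see how a `chunks` tuple divides an array is to inspect the resulting `.chunks` and `.numblocks` attributes. This sketch uses a small stand-in 4D array whose spatial axes (length 7) are not multiples of the chunk size (3), so the last chunk along each spatial axis is smaller:

```python
import numpy as np
import dask.array as da

# small stand-in for 4D (time, z, y, x) data; the sizes here are arbitrary
x = da.from_array(np.zeros((2, 7, 7, 7)), chunks=(1, 3, 3, 3))

print(x.chunks)     # ((1, 1), (3, 3, 1), (3, 3, 1), (3, 3, 1))
print(x.numblocks)  # (2, 3, 3, 3)
```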

In [None]:
# each file holds one 3D volume; make the whole volume a single chunk
chunks = dims[1:]

The `dask.array.from_array()` function takes any object that behaves like a numpy array, i.e. one that has `shape`, `ndim` and `dtype` attributes and supports numpy-style slicing. In this example the data are stored on disk as `.h5` files, but `from_array()` would work just as well on an object that gave a numpy-style interface to a `.tif` file or a raw binary file.

In [None]:
# loop through fnames, making a list of dask arrays, one for each file.
# The files must stay open: dask reads from them lazily, only when a result is computed.
tmp = []
for fn in fnames:
    tmp.append(da.from_array(File(fn, 'r')['default'], chunks=chunks))

# stack the dask arrays along a new first (time) axis to form a single 4D array.
data = da.stack(tmp)

# the previous code is equivalent to this one-liner:
data = da.stack([da.from_array(File(fn, 'r')['default'], chunks=chunks) for fn in fnames])
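For the raw binary case mentioned above, here is a minimal sketch of such an object. `LazyRawVolume` is a hypothetical class, not part of dask; it assumes the file holds a C-ordered array of known shape and dtype, and it uses `numpy.memmap` so that each slice only reads the requested region:

```python
import os
import tempfile
import numpy as np
import dask.array as da

class LazyRawVolume:
    """Hypothetical numpy-style view of a raw binary file with a known shape and dtype."""

    def __init__(self, path, shape, dtype):
        self.path = path
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)
        self.ndim = len(self.shape)

    def __getitem__(self, index):
        # memory-map the file and copy out only the requested region
        mm = np.memmap(self.path, dtype=self.dtype, mode='r', shape=self.shape)
        return np.array(mm[index])

# write a small raw file to read back (stands in for real acquisition data)
original = np.arange(2 * 3 * 4, dtype='float32').reshape(2, 3, 4)
path = os.path.join(tempfile.mkdtemp(), 'volume.raw')
original.tofile(path)

vol = da.from_array(LazyRawVolume(path, original.shape, original.dtype), chunks=(1, 3, 4))
second_timepoint = vol[1].compute()
```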

Like Spark RDDs, most operations on dask arrays are lazy: indexing a dask array just returns another dask array, without reading any data:

In [None]:
data[0,0,0,0]

To get our data out, we need to tell dask explicitly to compute a result, using the array's `.compute()` method:

In [None]:
data[0,0,0,0].compute()
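One way to convince yourself that indexing is lazy is to wrap an array in an object that counts how often it is read. `CountingArray` below is a hypothetical helper written only for this demonstration. Because dask may probe the source once while building the array (to infer metadata), the counter is reset after construction:

```python
import numpy as np
import dask.array as da

class CountingArray:
    """Hypothetical array-like that records how many times it is sliced."""

    def __init__(self, values):
        self._values = np.asarray(values)
        self.shape = self._values.shape
        self.dtype = self._values.dtype
        self.ndim = self._values.ndim
        self.reads = 0

    def __getitem__(self, index):
        self.reads += 1
        return self._values[index]

source = CountingArray(np.arange(16.0).reshape(4, 4))
x = da.from_array(source, chunks=(2, 2))
source.reads = 0  # ignore any metadata probing done by from_array

y = x[0, 0] + 1          # builds a task graph; nothing is read yet
reads_before = source.reads
value = y.compute()      # now dask actually slices `source`
reads_after = source.reads
print(reads_before, value)  # 0 1.0
```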

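Because our data spans multiple files, the speed of array indexing depends on which axes the index spans: indexing a single timepoint reads from one file, while slicing along a spatial axis reads a little from every file: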

In [None]:
%%time
# This reads from a single file
tmp = data[0].compute()

In [None]:
%%time
# This reads a little from all files, so it's slower
tmp = data[:,:,:,:10].compute()
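The difference comes from how many chunks each selection touches. A quick check on a small, all-zeros stand-in with the same layout as `data` (one chunk per timepoint), using the `npartitions` attribute of the sliced arrays:

```python
import dask.array as da

# same layout as `data` (one chunk per timepoint / file), but smaller and all zeros
x = da.zeros((10, 20, 20, 20), chunks=(1, 20, 20, 20))

print(x[0].npartitions)            # 1: a single timepoint lives in one chunk
print(x[:, :, :, :5].npartitions)  # 10: a spatial slab needs a piece of every chunk
```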

`.compute()` uses local resources (threads, by default) to process data. But we ultimately want to scale our analysis to cluster-level computing resources. This notebook does not require a compute cluster, but dask lets us run local tasks as if we had one by creating a `dask.distributed.Client` object connected to a `dask.distributed.LocalCluster`. See https://github.com/dask/dask-drmaa for an example of this same interface applied to a compute cluster.

Once we create our `Client` object, it becomes the default scheduler, so `.compute()` will implicitly use the resources associated with our `Client`.

In [None]:
from dask.distributed import Client, LocalCluster

# configure our local cluster to use 1 worker and threads instead of processes. This is not optimized for 
# performance, just for demonstration.
lc = LocalCluster(n_workers=1, processes=False)
client = Client(lc)

In [None]:
# the client has a link to a dashboard for tracking the progress of your jobs. This dashboard is much 
# more exciting than the spark status page.
client
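When you are finished with a cluster, it is good practice to close it. Below is a sketch with a separate, throwaway `LocalCluster` and `Client` used as context managers, so both shut down automatically on exit. `set_as_default=False` keeps this demo client from replacing the notebook's `client`, so the work is submitted explicitly with `Client.compute()`, which returns a future; `dashboard_address=None` disables the dashboard for this cluster, and the array size is arbitrary:

```python
import dask.array as da
from dask.distributed import Client, LocalCluster

# a throwaway single-worker, thread-based cluster with no dashboard
with LocalCluster(n_workers=1, processes=False, dashboard_address=None) as demo_cluster:
    # set_as_default=False: leave the notebook's default client untouched
    with Client(demo_cluster, set_as_default=False) as demo_client:
        x = da.ones((1000, 1000), chunks=(250, 250))
        future = demo_client.compute(x.sum())   # schedule the work on this cluster
        total = future.result()                 # block until the result is ready
        n_workers = len(demo_client.scheduler_info()["workers"])

print(total, n_workers)  # 1000000.0 1
```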

We can look at the `chunks` property to see how our data is arranged:

In [None]:
data.chunks

With this chunking scheme, there is 1 chunk per timepoint, and each chunk is an entire 3D image.

Now suppose we want to apply an image filter to every image in `data`. Recall that the chunking scheme of `data` is one 3D image per chunk. So, to map a function over each chunk in our array, we can use the array's `map_blocks()` method. This is analogous to `rdd.mapPartitions()` in Spark: the function receives a whole block at a time, not individual elements.

In [None]:
from scipy.ndimage import median_filter

# each chunk / block is 4D, even though its first axis has length 1,
# so we need a 4D image filter that does nothing along the first axis
filter_size = (1, 3, 3, 3)
data_filtered = data.map_blocks(lambda v: median_filter(v, size=filter_size), dtype=data.dtype)
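As a sanity check on this approach, the sketch below (on a small random stand-in array) passes `median_filter` and its `size` keyword straight to `map_blocks`, which forwards extra keyword arguments to the function, and compares the result with filtering each 3D volume separately. Because each chunk holds a whole volume and the filter has size 1 along time, no chunk needs data from its neighbours; if the spatial axes were also split into chunks, `dask.array.map_overlap` would be needed to avoid artifacts at chunk borders.

```python
import numpy as np
import dask.array as da
from scipy.ndimage import median_filter

rng = np.random.default_rng(0)
volumes = rng.standard_normal((4, 8, 8, 8))        # (time, z, y, x), arbitrary sizes
x = da.from_array(volumes, chunks=(1, 8, 8, 8))     # one volume per chunk

# extra keyword arguments (here `size`) are passed through to median_filter
filtered = x.map_blocks(median_filter, size=(1, 3, 3, 3), dtype=x.dtype).compute()

# reference result: filter each 3D volume on its own
expected = np.stack([median_filter(v, size=3) for v in volumes])
print(np.allclose(filtered, expected))  # True
```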

Now let's do simple rigid (translation-only) registration on `data`. This is kind of annoying to do in dask, and there are probably easier ways to do it, but I haven't played around with this much. With synthetic random data this "registration" is meaningless, but the algorithm doesn't know that.
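To make the Fourier-based estimate less of a black box, here is a minimal numpy-only sketch of phase correlation, swapped in for the scikit-image routine used in the next cell. It is integer-pixel only and assumes the images differ by a circular shift; `phase_correlation_shift` and `block_shift` are hypothetical helpers. In the `map_blocks` call, each `(1, z, y, x)` block is reduced to one `(1, 3)` row of shifts, so axes 2 and 3 are dropped and the output chunks are declared as `(1, 3)`.

```python
import numpy as np
import dask.array as da

def phase_correlation_shift(moving, reference):
    """Integer shift that, applied to `moving`, aligns it with `reference` (circular shifts assumed)."""
    cross_power = np.fft.fftn(reference) * np.conj(np.fft.fftn(moving))
    cross_power /= np.abs(cross_power) + 1e-12       # keep only the phase
    correlation = np.abs(np.fft.ifftn(cross_power))  # sharp peak at the shift
    peak = np.array(np.unravel_index(np.argmax(correlation), correlation.shape))
    size = np.array(correlation.shape)
    wrapped = peak > size // 2
    peak[wrapped] -= size[wrapped]                   # indices past the midpoint are negative shifts
    return peak.astype(float)

def block_shift(block, reference):
    # block has shape (1, z, y, x); return shape (1, 3) so it becomes one row of the result
    return phase_correlation_shift(block[0], reference)[np.newaxis, :]

rng = np.random.default_rng(0)
reference = rng.standard_normal((16, 16, 16))
applied = [(0, 0, 0), (2, -3, 1), (-1, 4, 0)]       # circular shifts used to fake "motion"
stack = np.stack([np.roll(reference, s, axis=(0, 1, 2)) for s in applied])
x = da.from_array(stack, chunks=(1, 16, 16, 16))

estimated = x.map_blocks(block_shift, reference=reference, dtype=float,
                         drop_axis=(2, 3), chunks=(1, 3)).compute()
print(estimated)  # each row undoes the corresponding entry of `applied`
```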

In [None]:
# define a reference image:
ref = data[0].compute()

# define a function that takes an image and estimates how to align it with the reference
# here we do simple fourier-based estimation of translations
# FYI map_blocks is not really designed for this kind of thing, but it works. This API might change.
def reg(im, reference=ref):
    from skimage.feature.register_translation import register_translation
    from numpy import squeeze
    # im will have shape (1,z,y,x), so we use squeeze to make it 3D
    
    # this part is annoying: we have to make the result explicitly 2D 
    shifts = register_translation(squeeze(im), reference)[0].reshape(1,-1)
    return shifts

# Because we are doing something a little funny with our data (turning a 4D array into 3 numbers) 
# we need to be more explicit in map_blocks about how the data shape will change, so we specify the new
# chunks and the axes that will disappear. I don't like this very much and will be happy to find something
# simpler.
shifts = data.map_blocks(reg, dtype='float', drop_axis=(2,3), chunks=(1,1)).compute()

# once we have the shifts, let's apply them to the filtered data.
# First I define a function for shifting each image by the correct shifts. This function uses a special
# keyword argument, block_id, that tells map_blocks to supply a block_id tuple for each block, like
# a key in spark. Without the block_id value it would not be possible to apply the correct shift value.
def shifter(im, block_id=None):
    from scipy.ndimage import shift
    time_index = block_id[0]
    # im is 4D with shape (1, z, y, x), so prepend a zero shift for the first (time) axis
    return shift(im, (0, *shifts[time_index]))

# now we map that shifting function to each image in data_filtered using map_blocks.
data_shifted = data_filtered.map_blocks(shifter, dtype='float')
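To check the `block_id` mechanics and the leading zero shift in isolation, the following sketch (small random stand-in data, hypothetical `per_time_shifts` values) shifts each 4D block with `scipy.ndimage.shift` and compares against shifting each 3D volume directly. It uses `order=1` (linear interpolation) only to keep the comparison simple; the notebook code uses the default cubic spline.

```python
import numpy as np
import dask.array as da
from scipy.ndimage import shift

rng = np.random.default_rng(1)
volumes = rng.standard_normal((3, 10, 10, 10))           # (time, z, y, x)
per_time_shifts = np.array([[0.0, 0.0, 0.0],
                            [1.0, -2.0, 0.0],
                            [0.5, 0.0, 1.5]])             # hypothetical (z, y, x) shifts

def shift_block(block, block_id=None):
    # with one timepoint per chunk, block_id[0] is the time index of this block
    t = block_id[0]
    # block is (1, z, y, x): a zero shift on the first axis leaves time untouched
    return shift(block, (0.0, *per_time_shifts[t]), order=1)

x = da.from_array(volumes, chunks=(1, 10, 10, 10))
shifted = x.map_blocks(shift_block, dtype=x.dtype).compute()

expected = np.stack([shift(v, s, order=1) for v, s in zip(volumes, per_time_shifts)])
print(np.allclose(shifted, expected))  # True
```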

Now that our data is registered, we can consider doing some kind of timeseries operation on it, e.g. standardization with a z-score. This needs a different chunking scheme: each chunk must contain entire timeseries (i.e. a single chunk along the time axis), and we can evenly divide the spatial axes into the remaining chunks. There is overhead associated with each chunk, so we do not want a single timeseries per chunk, e.g. `chunks=(num_timepoints, 1, 1, 1)`. To reduce that overhead, we can do something like `chunks=(num_timepoints, dims[1]//10, dims[2]//10, dims[3]//10)`. The official recommendation is 10-100 MB per chunk, so use that as a guideline.

In [None]:
new_chunks = (dims[0], dims[1] // 10, dims[2] // 10, dims[3] // 10)
data_rechunked = data_shifted.rechunk(chunks=new_chunks)
data_rechunked.chunks
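To see what the rechunk does, and to check chunk sizes against the 10-100 MB guideline, this sketch builds a lazy all-zeros array with the same shape and time-first chunking as `data` (no real data involved) and rechunks it. In `rechunk`, `-1` along an axis means one chunk spanning the whole axis. With these demo dimensions each new chunk is only about 0.08 MB, far below the guideline; with real data you would pick larger spatial blocks.

```python
import numpy as np
import dask.array as da

demo_dims = (10, 100, 100, 100)                   # (time, z, y, x), as in the notebook
x = da.zeros(demo_dims, chunks=(1,) + demo_dims[1:])  # one timepoint per chunk

# -1 keeps the whole time axis in a single chunk; split each spatial axis into 10 pieces
y = x.rechunk((-1, demo_dims[1] // 10, demo_dims[2] // 10, demo_dims[3] // 10))

print(y.numblocks)                                # (1, 10, 10, 10)
chunk_mb = np.prod(y.chunksize) * y.dtype.itemsize / 1e6
print(f"{chunk_mb:.2f} MB per chunk")             # 0.08 MB per chunk
```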

Now our rechunked array has blocks that each contain entire timeseries, so we can use `map_blocks` to map a function like `scipy.stats.zscore` over them.

In [None]:
def normalize(v):
    from scipy.stats import zscore
    # remember that v will be 4D, so tell zscore which axis to work on
    return zscore(v, axis=0)

data_zscored = data_rechunked.map_blocks(normalize)
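Why the rechunk matters: `zscore` inside `map_blocks` only sees one block, so it is only correct when every block holds complete timeseries. The sketch below (small random stand-in data) compares three versions against `scipy.stats.zscore` on the full in-memory array: `map_blocks` with full-time chunks, `map_blocks` with time split into two chunks, and an alternative written with dask's own `mean`/`std` reductions, which does not depend on chunking. The reduction version is an alternative, not what the notebook uses.

```python
import numpy as np
import dask.array as da
from scipy.stats import zscore

rng = np.random.default_rng(2)
arr = rng.standard_normal((10, 8, 8, 8))     # (time, z, y, x)
expected = zscore(arr, axis=0)

# whole timeseries in every chunk: per-block zscore matches the global result
full_time = da.from_array(arr, chunks=(10, 4, 4, 4))
ok = full_time.map_blocks(zscore, axis=0, dtype=arr.dtype).compute()

# time split across two chunks: each half is normalized on its own, which is wrong
split_time = da.from_array(arr, chunks=(5, 4, 4, 4))
wrong = split_time.map_blocks(zscore, axis=0, dtype=arr.dtype).compute()

# chunking-independent alternative using dask reductions (ddof=0, like zscore's default)
via_reductions = ((split_time - split_time.mean(axis=0)) / split_time.std(axis=0)).compute()

print(np.allclose(ok, expected), np.allclose(wrong, expected), np.allclose(via_reductions, expected))
# True False True
```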

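Now the data are preprocessed (spatially filtered, motion-corrected, and z-scored along time) and ready for whatever else you need to do. Remember that `data_zscored` is still lazy: nothing is read or computed until you call `.compute()` or write the result somewhere.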
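One way to materialize a lazy result is `Array.to_hdf5`, which computes the graph and writes the output into an HDF5 dataset block by block. This sketch uses a small stand-in pipeline rather than `data_zscored`; the file name and dataset path are arbitrary:

```python
import os
import tempfile
import numpy as np
import dask.array as da
from h5py import File

rng = np.random.default_rng(3)
arr = rng.standard_normal((10, 20, 20, 20))
x = da.from_array(arr, chunks=(10, 10, 10, 10))
result = (x - x.mean(axis=0)) / x.std(axis=0)         # stand-in for a lazy pipeline

path = os.path.join(tempfile.mkdtemp(), 'zscored.h5')  # arbitrary output location
result.to_hdf5(path, '/zscored')                      # runs the computation and writes to disk

with File(path, 'r') as f:
    on_disk = f['zscored'][...]
print(on_disk.shape)  # (10, 20, 20, 20)
```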