In [1]:
from dask.distributed import Client, LocalCluster

from matplotlib import pyplot as plt
from itkwidgets import view

import dask.array.image
import numpy as np
import itk
import dask.array as da
import dask

import os
import time

## Specify the distributed computing resources

Dask can run on HPC systems (via MPI or job schedulers such as Slurm), on cloud-based clusters, or on a set of machines connected over SSH.

As a first step, we specify how to connect to our distributed computing resources.

Option 1 below creates a local cluster that uses the cores of this machine; Option 2, which is the one enabled in this notebook, connects to a dask-mpi cluster on NERSC Cori.

In [2]:
# Option 1: local cluster
#
# To run on this machine instead of NERSC, uncomment the two lines below and
# skip the Option 2 cell. With processes=False the workers run as threads in
# this kernel's process; memory_limit applies to each worker.

# local_cluster = LocalCluster(n_workers=8, processes=False, memory_limit='4G')
# client = Client(local_cluster)

In [3]:
# Option 2: NERSC Cori Slurm dask-mpi cluster
#
# See SC20_pyHPC/nersc/README.md
#
# Connect to an already-running scheduler through the scheduler file stored in
# $SCRATCH (presumably written when the cluster is launched per the README above
# -- confirm there). Raises KeyError if SCRATCH is not set, e.g. when not running
# at NERSC; use Option 1 in that case.
scheduler_file = os.path.join(os.environ["SCRATCH"], "scheduler.json")

# Route the dashboard link through the JupyterHub proxy so it opens from the
# notebook. dask.config.set is the supported way to change configuration;
# called without `with`, the change stays in effect for the rest of the session.
dask.config.set({"distributed.dashboard.link": "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"})

client = Client(scheduler_file=scheduler_file)

In [4]:
# Display a summary of the connected cluster: scheduler address, dashboard
# link, and worker/core/memory totals.
client

0,1
Client  Scheduler: tcp://10.128.0.20:33623  Dashboard: /user/thewtex/cori-shared-node-cpu/proxy/10.128.0.20:33473/status,Cluster  Workers: 99  Cores: 99  Memory: 0 B


In [5]:
# An example image processing pipeline with a parameter
def my_processing_pipeline(image_chunk, radius=2):
    """Denoise ``image_chunk`` with ITK's median filter.

    ``itk`` is imported inside the function so the import runs wherever the
    function itself runs, including on Dask workers.

    Parameters
    ----------
    image_chunk
        Image data accepted by ``itk.median_image_filter``. This notebook
        passes NumPy arrays (the whole volume or Dask blocks) and xarray
        DataArray blocks.
    radius : int, optional
        Neighborhood radius of the median filter. Default 2.

    Returns
    -------
    The filtered image as returned by ``itk.median_image_filter``
    (presumably the same container type as the input -- TODO confirm for
    the installed ITK version).
    """
    import itk

    return itk.median_image_filter(image_chunk, radius=radius)

In [6]:
# Input volume; the relative path resolves against the kernel's working directory.
input_filepath = "../data/bead_pack.tif"

In [7]:
# Read the volume with ITK and keep it as a NumPy array for the cells below.
image = np.asarray(itk.imread(input_filepath))

## Option 0: Non-distributed

In [8]:
# Baseline: filter the whole volume in this process, without Dask.
# time.perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() follows the system clock and can jump.
start = time.perf_counter()

denoised = my_processing_pipeline(image, radius=2)

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

3.7748758792877197 seconds


## Option 1: Dask Client.submit

### Submit the processing pipeline and image as a task

In [9]:
# Send the function and the whole image to a worker as a single task and wait
# for the result. `image` is already a NumPy array, so no conversion is needed.
# Passing a large array directly embeds it in the task graph, which triggers
# Dask's "Large object ... detected" warning recommending client.scatter.
start = time.perf_counter()

denoised = client.submit(my_processing_pipeline, image).result()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

Large object of size 8.00 MB detected in task graph: 
  (array([[[112, 107, 112, ..., 113, 117, 117],
     ...  dtype=uint8),)
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good


3.8913466930389404 seconds


In [10]:
# Chunk the data into a Dask array. For the (200, 200, 200) bead_pack volume,
# chunks of (5, 200, 200) split the first axis into 200 / 5 = 40 blocks.
chunks = (5, 200, 200)

chunked_data = da.from_array(image, chunks=chunks)
chunked_data

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,200.00 kB
Shape,"(200, 200, 200)","(5, 200, 200)"
Count,41 Tasks,40 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 200.00 kB Shape (200, 200, 200) (5, 200, 200) Count 41 Tasks 40 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,200.00 kB
Shape,"(200, 200, 200)","(5, 200, 200)"
Count,41 Tasks,40 Chunks
Type,uint8,numpy.ndarray


### Scatter the data across the workers

In [11]:
# Scatter the NumPy image to every worker ahead of time (broadcast=True), so the
# submitted task references data already held on the workers instead of carrying
# it inside the task graph. Scatter the in-memory array, not the Dask array:
# scatter serializes a Dask collection as a lazy graph, and the task would then
# receive an unevaluated Dask array rather than a NumPy array.
scattered_image = client.scatter(image, broadcast=True)

# Time only the computation; the one-off scatter above is excluded.
start = time.perf_counter()

denoised = client.submit(my_processing_pipeline, scattered_image).result()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

4.177775144577026 seconds


distributed.client - ERROR - Failed to reconnect to scheduler after 10.00 seconds, closing client
_GatheringFuture exception was never retrieved
future: <_GatheringFuture finished exception=CancelledError()>
concurrent.futures._base.CancelledError


## Option 2: Dask Array map_blocks

In [12]:
# Apply the pipeline to each chunk independently. Because every block is
# filtered in isolation, voxels within `radius` of a chunk boundary do not see
# the neighbouring chunk's data; the map_overlap option below addresses this.
start = time.perf_counter()

denoised = da.map_blocks(my_processing_pipeline,
                         chunked_data,
                         radius=2,
                         dtype=chunked_data.dtype)
denoised = denoised.compute()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

7.0909483432769775 seconds


In [13]:
# Interactively inspect the denoised volume (at this point, the map_blocks result).
view(denoised)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageUC3; pr…

## Option 3: Dask Array map_overlap

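Extend each chunk with a halo of neighboring voxels (`depth`), process the extended chunk, then trim the halo from the result. Voxels near chunk boundaries then see the same neighborhood they would in the non-distributed version, which removes the seams that `map_blocks` can leave.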

In [14]:
# Filter with a halo: each chunk is extended by `depth` voxels from its
# neighbours, filtered, and the halo is trimmed away again (trim=True). The halo
# must be at least the median-filter radius for chunk-boundary voxels to see
# their full neighbourhood, so both are driven by one value.
filter_radius = 2

start = time.perf_counter()

denoised = da.map_overlap(my_processing_pipeline,
                          chunked_data,
                          depth=filter_radius,
                          trim=True,
                          radius=filter_radius,
                          dtype=chunked_data.dtype)
denoised = denoised.compute()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

8.25371789932251 seconds


## Option 4: Xarray DataArray map_blocks

In [15]:
image = itk.imread(input_filepath)
da = itk.xarray_from_image(image)
da

In [16]:
da = da.chunk(chunks)

In [17]:
start = time.time()

denoised = da.map_blocks(my_processing_pipeline,
              kwargs = { 'radius': 2 } )
denoised = denoised.compute()

elapsed = time.time() - start
print(elapsed, 'seconds')

5.544623851776123 seconds
