# Instance Segmentation of a Simulated 3D Bead-Packing XCT Volume — Distributed Computation

In [31]:
import os
import time

import itk
import numpy as np

from matplotlib import pyplot as plt
from itkwidgets import view, label_statistics
import itkwidgets

from dask.distributed import Client, LocalCluster
import dask.array as da
import dask
import dask_image.ndmeasure
import zarr

In [2]:
# Option 1: local cluster

# local_cluster = LocalCluster(n_workers=8, processes=False, memory_limit='4G')
# client = Client(local_cluster)
# client

In [3]:
# Option 2: NERSC Cori Slurm dask-mpi cluster
#
# See SC20_pyHPC/nersc/README.md

# Connect to the dask-mpi scheduler whose connection file lives in $SCRATCH.
scheduler_file = os.path.join(os.environ["SCRATCH"], "scheduler.json")
# Route dashboard links through the JupyterHub proxy. `dask.config.set`
# (called without `with`) is the supported way to change a setting globally,
# rather than mutating the nested `dask.config.config` dict by hand.
dask.config.set({"distributed.dashboard.link":
                 "{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status"})

client = Client(scheduler_file=scheduler_file)
# client

In [4]:
# Read the bead-pack volume from a TIFF file into an ITK image.
file_name = '../data/bead_pack.tif'
beads = itk.imread(file_name)

In [5]:
# Wrap the ITK image's voxels as a dask array. No chunks are given; for this
# 8 MB volume the cell output shows the result is a single chunk.
beads_da = da.from_array(np.asarray(beads))
beads_da

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,8.00 MB
Shape,"(200, 200, 200)","(200, 200, 200)"
Count,1 Tasks,1 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 8.00 MB Shape (200, 200, 200) (200, 200, 200) Count 1 Tasks 1 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,8.00 MB
Shape,"(200, 200, 200)","(200, 200, 200)"
Count,1 Tasks,1 Chunks
Type,uint8,numpy.ndarray


In [14]:
# Number of slices along axis 0 per chunk; also halved further down to build a
# staggered chunk layout.
slice_chunk_size = 100

In [7]:
# Split the volume into slabs of `slice_chunk_size` slices along axis 0.
# `-1` keeps axes 1 and 2 unchunked, so this no longer hard-codes the 200x200
# in-plane size or silently diverges from `slice_chunk_size`.
beads_da = beads_da.rechunk((slice_chunk_size, -1, -1))
beads_da

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,43 Tasks,2 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 4.00 MB Shape (200, 200, 200) (100, 200, 200) Count 43 Tasks 2 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,43 Tasks,2 Chunks
Type,uint8,numpy.ndarray


In [8]:
# Store the slab-chunked volume as Zarr, replacing any existing store. Reading
# it back (next cell) yields a much smaller task graph (3 tasks vs 43 in the
# cell outputs) than the rechunked in-memory array.
zarr_path = '../data/bead_pack.zarr'

beads_da.to_zarr(zarr_path, overwrite=True, compute=True)

In [9]:
# Reload the volume lazily from the Zarr store. Note: this rebinds `beads` from
# the ITK image to a dask array chunked (100, 200, 200) per the cell output.
beads = da.from_zarr(zarr_path)
beads

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,3 Tasks,2 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 4.00 MB Shape (200, 200, 200) (100, 200, 200) Count 3 Tasks 2 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,3 Tasks,2 Chunks
Type,uint8,numpy.ndarray


In [53]:
def processing_pipeline(image_chunk,
                        *,
                        unsharp_sigma=2.0,
                        unsharp_threshold=10.0,
                        unsharp_amount=5.0,
                        opening_radius=3,
                        watershed_level=7.0):
    """Segment one 3D image block into watershed instance labels.

    Used as the per-block function of ``da.map_blocks``. Steps: unsharp-mask
    sharpening, intensity inversion, Otsu thresholding, binary opening with a
    ball structuring element, hole filling, a signed Maurer distance map, and
    a morphological watershed on that distance map. Watershed labels outside
    the filled foreground are masked to 0.

    Parameters
    ----------
    image_chunk : array-like
        A 3D image block (a NumPy block when called through ``map_blocks``).
    unsharp_sigma, unsharp_threshold, unsharp_amount : float
        ``sigma``, ``threshold`` and ``amount`` of the unsharp-mask filter.
    opening_radius : int
        Radius of the ball structuring element used for the binary opening.
    watershed_level : float
        ``level`` passed to ITK's morphological watershed filter.

    Returns
    -------
    The masked watershed label image for this block. No label offset is
    applied, so label values produced for different blocks can collide.
    """
    # Imported locally because dask may execute this on a remote worker.
    import itk
    import numpy as np

    sharpened = itk.unsharp_mask_image_filter(image_chunk,
                                              sigma=unsharp_sigma,
                                              threshold=unsharp_threshold,
                                              amount=unsharp_amount)
    invert = itk.invert_intensity_image_filter(sharpened)
    thresholded = itk.otsu_threshold_image_filter(invert)

    # 3D ball structuring element; the opening removes foreground structures
    # smaller than the ball.
    StructuringElementType = itk.FlatStructuringElement[3]
    structuring_element = StructuringElementType.Ball(opening_radius, True)
    opened = itk.binary_morphological_opening_image_filter(thresholded,
                                                           kernel=structuring_element)
    filled = itk.binary_fillhole_image_filter(opened, fully_connected=True)

    # Distances are negative inside the foreground, so the deepest interior
    # points become the minima flooded by the watershed.
    distance = itk.signed_maurer_distance_map_image_filter(filled, inside_is_positive=False)

    watershed = itk.morphological_watershed_image_filter(distance, level=watershed_level)
    # Restrict labels to the filled foreground; the mask is cast to int16.
    filled_cast = filled.astype(np.int16)
    watershed_mask = itk.mask_image_filter(watershed, filled_cast)

    return watershed_mask

In [54]:
# Time the block-wise segmentation of `beads` and gather the result locally.
# `perf_counter` is a monotonic, high-resolution clock meant for durations.
start = time.perf_counter()

segmented = da.map_blocks(processing_pipeline,
                          beads,
                          dtype=np.int16)
segmented = segmented.compute()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

9.770127773284912 seconds


In [56]:
# Show the volume with the block-wise watershed labels overlaid.
view(beads, label_image=segmented, ui_collapsed=True)

Viewer(geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], rendered_image=<itk.itkImageP…

In [80]:
# Re-run the same timed segmentation (the cell output shows a similar time).
# `perf_counter` is a monotonic, high-resolution clock meant for durations.
start = time.perf_counter()

segmented = da.map_blocks(processing_pipeline,
                          beads,
                          dtype=np.int16)
segmented = segmented.compute()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

9.039737462997437 seconds


In [81]:
# Derive a staggered partition along axis 0: the first slab is shortened to
# half a slab and the cut-off remainder is appended at the end. With equal
# slabs of `slice_chunk_size`, every new boundary lands mid-way through an
# original chunk. Axes 1 and 2 keep their chunking.
chunk_shift = int(slice_chunk_size / 2)
original_chunk, *inner_chunks = beads.chunks[0]
slice_chunks = [chunk_shift, *inner_chunks, original_chunk - chunk_shift]
chunks = [tuple(slice_chunks), *beads.chunks[1:]]
chunks

[(50, 100, 50), (200,), (200,)]

In [82]:
# Rechunk the volume with the staggered axis-0 boundaries (50, 100, 50).
beads_rechunk = beads.rechunk(chunks)
beads_rechunk

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,10 Tasks,3 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 4.00 MB Shape (200, 200, 200) (100, 200, 200) Count 10 Tasks 3 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,10 Tasks,3 Chunks
Type,uint8,numpy.ndarray


In [83]:
# Time the block-wise segmentation on the staggered chunk layout.
# `perf_counter` is a monotonic, high-resolution clock meant for durations.
start = time.perf_counter()

segmented_rechunk = da.map_blocks(processing_pipeline,
                                  beads_rechunk,
                                  dtype=np.int16)
segmented_rechunk = segmented_rechunk.compute()

elapsed = time.perf_counter() - start
print(elapsed, 'seconds')

8.736785173416138 seconds


In [84]:
# Show the staggered-chunk segmentation for comparison.
view(beads_rechunk, label_image=segmented_rechunk, ui_collapsed=True)

Viewer(geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], rendered_image=<itk.itkImageP…

In [86]:
# `segmented` is already an in-memory NumPy array (computed above), so take
# its sorted distinct label values directly instead of building and running
# a one-chunk dask graph.
labels = np.unique(segmented)
labels

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144], dtype=int16)

In [103]:
# Binary foreground mask (1 wherever a non-zero label is present), chunked
# like `beads`.
mask = da.from_array((segmented > 0).astype(np.uint8), chunks=beads.chunks)
mask

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,3 Tasks,2 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 8.00 MB 4.00 MB Shape (200, 200, 200) (100, 200, 200) Count 3 Tasks 2 Chunks Type uint8 numpy.ndarray",200  200  200,

Unnamed: 0,Array,Chunk
Bytes,8.00 MB,4.00 MB
Shape,"(200, 200, 200)","(100, 200, 200)"
Count,3 Tasks,2 Chunks
Type,uint8,numpy.ndarray


In [105]:
# Voxel count per label as the sum of the foreground mask over each label's
# region; label 0 sums to 0 because the mask is 0 there.
voxel_counts_s = dask_image.ndmeasure.sum(mask, segmented, labels).compute()
voxel_counts_s

array([    0., 26087., 23369., 21858., 32335., 21818., 32213., 21988.,
       30703., 20289., 35008., 26299., 20673., 12931., 26033., 30273.,
       30393., 37943., 28180., 39115., 26308., 22613., 45499., 32437.,
       18155., 20498., 13813., 38670., 47265., 37708., 48669., 33479.,
       47296., 46691., 51315., 47275., 49412., 50048., 48166., 42826.,
       45699., 26127., 34969., 35248., 25885., 38064., 27937., 50194.,
       44945., 48997., 37682., 38459., 46929., 38546., 46565., 46616.,
       46875., 36703., 27726., 46324., 46539., 42702., 47626., 41524.,
       46396., 20747., 38973., 29569., 39282., 46530., 33769., 37490.,
       56717., 28201., 31825., 30522., 46146., 49372., 44097., 45790.,
       46223., 26032., 37542., 49417., 45476., 33372., 40836., 40779.,
       44447., 46246., 46668., 35127., 29100., 46893., 43154., 46404.,
       21849., 35929., 46848., 46845., 25266., 36096., 37536., 47684.,
       49251., 50515., 11927., 50816., 47309., 48359., 28707., 34585.,
      

In [106]:
# Voxel count per label via dask_image's `area`; per the cell output this also
# counts the background label 0.
voxel_counts = dask_image.ndmeasure.area(beads, segmented, labels).compute()
voxel_counts

array([3055178,   26087,   23369,   21858,   32335,   21818,   32213,
         21988,   30703,   20289,   35008,   26299,   20673,   12931,
         26033,   30273,   30393,   37943,   28180,   39115,   26308,
         22613,   45499,   32437,   18155,   20498,   13813,   38670,
         47265,   37708,   48669,   33479,   47296,   46691,   51315,
         47275,   49412,   50048,   48166,   42826,   45699,   26127,
         34969,   35248,   25885,   38064,   27937,   50194,   44945,
         48997,   37682,   38459,   46929,   38546,   46565,   46616,
         46875,   36703,   27726,   46324,   46539,   42702,   47626,
         41524,   46396,   20747,   38973,   29569,   39282,   46530,
         33769,   37490,   56717,   28201,   31825,   30522,   46146,
         49372,   44097,   45790,   46223,   26032,   37542,   49417,
         45476,   33372,   40836,   40779,   44447,   46246,   46668,
         35127,   29100,   46893,   43154,   46404,   21849,   35929,
         46848,   46

In [107]:
# Sanity check: both counting methods agree on every bead label; only the
# background entry (label 0) differs.
voxel_counts - voxel_counts_s

array([3055178.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,       0.,
             0.,       0.,       0.,       0.,       0.,      

In [125]:
def to_voxel_count_image(image_chunk):
    """Segment one block and paint every label with its voxel count.

    Runs ``processing_pipeline`` on ``image_chunk`` and returns a uint16 array
    of the segmentation's shape in which each voxel of label ``k != 0`` holds
    the number of voxels carrying label ``k`` within this block. Background
    (label 0) stays 0. Counts larger than 65535 saturate at 65535 instead of
    wrapping around.

    Counts are computed locally with NumPy in one vectorised pass. A lazy
    ``dask_image`` measurement here would build a nested dask graph inside the
    worker task and force one ``compute`` per label on assignment.
    """
    segmented = np.asarray(processing_pipeline(image_chunk))
    labels, inverse, counts = np.unique(segmented,
                                        return_inverse=True,
                                        return_counts=True)
    # Background keeps 0 rather than its (large) voxel count.
    counts[labels == 0] = 0
    # Output dtype is uint16 (as declared by the map_blocks callers).
    counts = np.minimum(counts, np.iinfo(np.uint16).max).astype(np.uint16)
    # `inverse` indexes each voxel's label in `labels`; reshape covers both
    # the flat and the input-shaped forms returned by different NumPy versions.
    return counts[inverse].reshape(segmented.shape)

# Apply segment-and-count to every block of `beads` and gather the result.
voxel_count_image = da.map_blocks(to_voxel_count_image,
              beads,
              dtype=np.uint16)

voxel_count_image = voxel_count_image.compute()

In [131]:
# Same voxel-count image on the staggered chunk layout, presumably so a bead
# cut by a seam in one layout is intact in the other.
voxel_count_image_rechunk = da.map_blocks(to_voxel_count_image,
              beads_rechunk,
              dtype=np.uint16)

voxel_count_image_rechunk = voxel_count_image_rechunk.compute()

In [132]:
# Combine the two layouts voxel-wise, keeping the larger count. Both inputs
# are NumPy arrays (computed above), so NumPy's ufunc is applied directly.
voxel_count_max = np.maximum(voxel_count_image, voxel_count_image_rechunk)

In [163]:
# Number of distinct label values (including background 0) in the gathered
# segmentation. `segmented` is already a NumPy array, so no dask graph is
# needed. NOTE(review): used as a per-block label bound; that holds only if
# each block's labels are contiguous from 0 — confirm.
chunk_labels = np.unique(segmented)
max_chunk_labels = len(chunk_labels)
max_chunk_labels

145

In [None]:
def unique_labels(segmented,
                  segmented_rechunk,
                  voxel_count_image,
                  voxel_count_image_rechunk,
                  max_chunk_labels,
                  block_info=None):
    """Work-in-progress block function for making labels unique across blocks.

    TODO: the label-merging logic is not written yet. For now this is a
    diagnostic that fills the block with its chunk index along axis 0, so the
    block layout seen by ``map_blocks`` can be inspected.
    ``segmented_rechunk``, the voxel-count images and ``max_chunk_labels`` are
    accepted for that future logic but are not used yet.

    Returns a uint16 array shaped like ``segmented``. The array is all zeros
    when called without ``block_info``.
    """
    result = np.zeros(np.shape(segmented), dtype=np.uint16)
    if block_info is None:
        return result
    # block_info[0] describes the block of the first array argument.
    result[...] = block_info[0]["chunk-location"][0]
    return result

# Run the work-in-progress `unique_labels` step block-by-block. The four
# in-memory label/count images are wrapped as dask arrays with the chunking
# they were computed with; `map_blocks` aligns differing chunks before calling
# the block function. The result gets its own name so it does not overwrite
# `voxel_count_image`.
unique_label_image = da.map_blocks(unique_labels,
                                   da.from_array(segmented, beads.chunks),
                                   da.from_array(segmented_rechunk, beads_rechunk.chunks),
                                   da.from_array(voxel_count_image, beads.chunks),
                                   da.from_array(voxel_count_image_rechunk, beads_rechunk.chunks),
                                   max_chunk_labels,
                                   dtype=np.uint16)

unique_label_image = unique_label_image.compute()

In [155]:
# Connected components of the voxels whose combined count equals 23132.
# NOTE(review): 23132 looks like a hand-picked count for one particular bead
# — confirm.
connected, num_features = dask_image.ndmeasure.label(voxel_count_max==23132)
connected = connected.astype(np.uint16)

In [158]:
# Show the selected connected component(s).
view(connected)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageUS3; pr…

In [157]:
# Overlay the combined voxel-count image on the volume as a label map.
view(beads, label_image=voxel_count_max, 
      label_image_blend=0.7,
      rotate=True,
      shadow=False,
      gradient_opacity=0.5)

Viewer(geometries=[], gradient_opacity=0.5, interpolation=False, label_image_blend=0.7, point_sets=[], rendere…

In [149]:
# Load the reference segmentation (presumably ground truth, per the `_GT`
# file-name suffix) for visual comparison.
gt = itk.imread('../data/segmented_bead_pack_GT.tif')

In [150]:
# Show the reference segmentation.
view(gt)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageUS3; pr…