In [None]:
%matplotlib inline

In [None]:
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
import scipy.stats
import skimage.exposure
import skimage.feature
import skimage.filters
import skimage.io
import skimage.measure
import skimage.morphology
import skimage.restoration

In [None]:
# Display slices from a volume
def display(volume, cmap="gray"):
    _, axes = plt.subplots(nrows=5, ncols=6, figsize=(16, 14))
    
    vmin = volume.min()
    
    vmax = volume.max()
    
    for ax, plane in zip(axes.flatten(), volume[::2]):
        ax.imshow(
            plane,
            cmap=cmap,
            vmax=vmax,
            vmin=vmin
        )
        
        ax.set_xticks([])
        
        ax.set_yticks([])

# Display slices from a labeled volume
def display_labels(labels, cmap_name="viridis"):
    # Copy the registered colormap so set_bad() does not modify the shared one
    cmap = matplotlib.colormaps[cmap_name].copy()
    
    masked = np.ma.masked_where(labels == 0, labels)
    
    cmap.set_bad(color="black")
    
    display(masked, cmap)

# Three-dimensional image processing

This tutorial aims to highlight some of the three-dimensional image processing functionality available in `skimage` by segmenting a cellular image provided by the Allen Institute for Cell Science. This tutorial is an adaptation of Emmanuelle Gouillart's [Segmentation of 3-D tomography images with Python and scikit-image](http://emmanuelle.github.io/segmentation-of-3-d-tomography-images-with-python-and-scikit-image.html). 💖

`skimage` expects three-dimensional data to conform to `(plane, row, column[, channels])`. We will be working with a three-dimensional grayscale image, so the `channels` dimension is omitted.

This three-dimensional image is composed of many two-dimensional images captured at different focal depths. Neighboring planes are therefore further apart than neighboring rows or columns within a plane. We'll keep track of this with an additional variable, `spacing`, which gives the distance between pixels along each axis `(plane, row, column)`.
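
As a minimal sketch of why `spacing` matters, the hypothetical helper `physical_distance` below (not part of `skimage`) converts an offset in voxel indices into a distance in the units of `spacing`, using the values from the next cell.

```python
import numpy as np

# Voxel spacing along (plane, row, column), copied from the next cell.
spacing = (0.29, 0.01625, 0.01625)


def physical_distance(index_a, index_b, spacing):
    """Hypothetical helper: Euclidean distance between two voxel indices,
    with each axis scaled by its spacing."""
    offset = np.subtract(index_a, index_b)
    return float(np.linalg.norm(offset * np.asarray(spacing)))


# One step along the plane axis versus one step along the row axis.
plane_step = physical_distance((1, 0, 0), (0, 0, 0), spacing)
row_step = physical_distance((0, 1, 0), (0, 0, 0), spacing)

print(plane_step, row_step)
print("anisotropy:", plane_step / row_step)
```

With these values, one step between planes spans roughly 18 times the distance of one step between rows, which is why later cells scale size thresholds by the spacing.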

In [None]:
data = skimage.io.imread("../images/cells.tif")

spacing = (0.29, 0.01625, 0.01625)

print("shape: {}".format(data.shape))

print("dtype: {}".format(data.dtype))

display(data)

We'll start by rescaling the image to the range `(0.0, 1.0)`, the standard range `skimage` expects for floating point images.

Experimental images are often affected by salt and pepper noise. A few bright artifacts can decrease the relative intensity of the pixels of interest. Clipping the darkest and brightest 0.5% of pixels will increase the overall contrast of the image.
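
A NumPy-only sketch of the same clip-and-rescale idea, using `np.percentile` in place of `scipy.stats.scoreatpercentile` and plain arithmetic in place of `skimage.exposure.rescale_intensity`. The volume is synthetic and `clip_and_rescale` is a hypothetical helper, not part of any library.

```python
import numpy as np


def clip_and_rescale(volume, low=0.5, high=99.5):
    """Hypothetical helper: clip intensities to the [low, high] percentiles,
    then map the clipped range linearly onto [0, 1]."""
    vmin, vmax = np.percentile(volume, (low, high))
    clipped = np.clip(volume.astype(np.float32), vmin, vmax)
    return (clipped - vmin) / (vmax - vmin)


# Synthetic stand-in for the microscopy volume: dim values plus a few hot pixels.
rng = np.random.default_rng(0)
volume = rng.integers(100, 200, size=(4, 32, 32)).astype(np.uint16)
volume[0, 0, :5] = 60000  # bright "salt" artifacts, far fewer than 0.5% of voxels

rescaled = clip_and_rescale(volume)
print(rescaled.min(), rescaled.max(), np.median(rescaled))
```

Without clipping, the five hot pixels would set the top of the range and squeeze every other voxel into a sliver near zero; with clipping they simply saturate at 1.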

In [None]:
vmin, vmax = scipy.stats.scoreatpercentile(data, (0.5, 99.5))

rescaled = skimage.exposure.rescale_intensity(
    data, 
    in_range=(vmin, vmax), 
    out_range=np.float32
).astype(np.float32)

display(rescaled)

The contrast of the image has improved after clipping and rescaling. Next, we'll apply a denoising operation to reduce the amount of salt and pepper noise in the image. `skimage.restoration.denoise_bilateral` is an edge-preserving denoising filter. It has not been adapted for three-dimensional data, but we can still use it by applying the operation plane-wise.
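
Applying a 2-D filter plane-wise means each plane is processed independently, so no smoothing crosses between planes. A minimal sketch with a hypothetical `apply_planewise` helper, using `scipy.ndimage.median_filter` as a stand-in 2-D filter (not the bilateral filter used in this notebook):

```python
import numpy as np
import scipy.ndimage as ndi


def apply_planewise(func, volume, **kwargs):
    """Hypothetical helper: apply a 2-D function to each plane of a 3-D volume."""
    output = np.empty_like(volume)
    for index, plane in enumerate(volume):
        output[index] = func(plane, **kwargs)
    return output


rng = np.random.default_rng(1)
volume = rng.random((3, 16, 16)).astype(np.float32)

# A 2-D median filter stands in for denoise_bilateral here.
filtered = apply_planewise(ndi.median_filter, volume, size=3)

# Each output plane equals filtering that plane on its own.
assert np.allclose(filtered[1], ndi.median_filter(volume[1], size=3))
```

One caveat of this design: `np.empty_like` gives the output the input's dtype, so a filter returning floats from an integer volume would be truncated. The notebook sidesteps this by converting to `float32` before denoising.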

In [None]:
denoised = np.empty_like(rescaled)

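# Denoise each 2-D plane independently; channel_axis=None marks grayscale input
# (skimage versions before 0.19 used multichannel=False instead)
for index, plane in enumerate(rescaled):
    denoised[index] = skimage.restoration.denoise_bilateral(
        plane,
        channel_axis=None
    )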

display(denoised)

In [None]:
_, (a, b) = plt.subplots(nrows=1, ncols=2, figsize=(16, 4))

a.hist(data.flatten(), bins=32)
a.set_title("Original")

b.hist(denoised.flatten(), bins=32)
b.set_title("Denoised");

Comparing the pixel intensity histograms of the original and denoised data reveals a more clearly bimodal distribution in the denoised image. The two modes correspond to background and foreground pixels, respectively. `skimage.filters.threshold_li` computes a threshold value separating foreground pixels from background pixels. We'll use the threshold value to create a binary image for segmentation.
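
To give a feel for what `threshold_li` computes, here is a simplified sketch of Li's iterative minimum cross-entropy method: the threshold is repeatedly replaced by the logarithmic mean of the current foreground and background means until it stops changing. `threshold_li_sketch` is hypothetical; it omits the library's edge-case handling and is not its actual code. The bimodal data is synthetic.

```python
import numpy as np


def threshold_li_sketch(values, max_iter=100):
    """Hypothetical, simplified Li threshold: iterate t <- logmean(mean_back, mean_fore)."""
    values = np.asarray(values, dtype=np.float64).ravel()
    offset = values.min()
    values = values - offset  # the logarithms below need non-negative values
    tolerance = np.min(np.diff(np.unique(values))) / 2
    t_curr, t_next = -2 * tolerance, values.mean()  # start from the mean
    for _ in range(max_iter):
        if abs(t_next - t_curr) <= tolerance:
            break
        t_curr = t_next
        foreground = values > t_curr
        mean_fore = values[foreground].mean()
        mean_back = values[~foreground].mean()
        if mean_back == 0:
            break
        # Logarithmic mean of the two class means
        t_next = (mean_back - mean_fore) / (np.log(mean_back) - np.log(mean_fore))
    return t_next + offset


# Synthetic bimodal intensities: background near 0.2, foreground near 0.8.
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(0.2, 0.05, 5000), rng.normal(0.8, 0.05, 5000)])
truth = np.concatenate([np.zeros(5000, dtype=bool), np.ones(5000, dtype=bool)])

threshold = threshold_li_sketch(data)
accuracy = np.mean((data > threshold) == truth)
print(threshold, accuracy)
```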

In [None]:
threshold = skimage.filters.threshold_li(denoised)

binary = denoised >= threshold

_, (a, b) = plt.subplots(nrows=1, ncols=2, figsize=(16, 4))

a.hist(denoised.flatten(), bins=32)
a.axvline(threshold, c="r")
a.set_title("Threshold = {:0.3f}".format(threshold))

b.imshow(binary[32], cmap="gray")
b.set_title("Thresholded image (plane = 32)");

display(binary)

The binary image has two undesirable features: darker regions of the cell interiors were identified as background pixels, and brighter pixels in the noisy planes near the top and bottom of the image were identified as foreground pixels.

We can fill holes using the `skimage.morphology.remove_small_holes` function. Likewise, unwanted objects can be removed using `skimage.morphology.remove_small_objects`. Each function takes a size threshold, in voxels: holes or objects smaller than the threshold are filled or removed. An easy approximation of this size is the number of voxels in the smallest cube that encloses a hole or object to remove.
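
Both functions rest on the same idea: label connected regions, count their sizes, and drop the small ones. A rough sketch of that idea with `scipy.ndimage.label` and `np.bincount` on a toy 2-D image, using a hypothetical `remove_small_components` helper (the `skimage` functions have more options than this):

```python
import numpy as np
import scipy.ndimage as ndi


def remove_small_components(binary, min_size):
    """Hypothetical helper: drop connected components smaller than min_size pixels."""
    labels, _ = ndi.label(binary)        # face connectivity by default
    sizes = np.bincount(labels.ravel())  # sizes[k] = pixel count of label k
    keep = sizes >= min_size
    keep[0] = False                      # label 0 is the background
    return keep[labels]


binary = np.zeros((10, 10), dtype=bool)
binary[1:6, 1:6] = True  # a 25-pixel object...
binary[3, 3] = False     # ...with a one-pixel hole
binary[8, 8] = True      # and a one-pixel speck

# Filling small holes = removing small components of the inverted image.
filled = ~remove_small_components(~binary, min_size=5)
cleaned = remove_small_components(filled, min_size=5)
print(cleaned.sum())  # 25
```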

In [None]:
normalized_spacing = tuple(np.divide(spacing[1], spacing))

print("normalized spacing: {}".format(normalized_spacing))
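
The size expression in the next two cells counts the voxels in a cube whose edge is `a` pixels long in the row and column directions. Because planes are further apart, the same physical edge spans only `a * spacing[1] / spacing[0]` planes (about 1.7 for `a = 30`). A small sketch of the arithmetic, using `np.prod` (the spelling that remains in NumPy 2):

```python
import numpy as np

spacing = (0.29, 0.01625, 0.01625)
normalized_spacing = np.divide(spacing[1], spacing)  # about (0.056, 1.0, 1.0)

a = 30  # cube edge length, in row/column pixels
min_size = np.prod(np.multiply(a, normalized_spacing))
print(round(float(min_size), 1))  # 1512.9
```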

In [None]:
a = 30

# area_threshold (min_size in older skimage versions) is a volume in voxels
removed_small_holes = skimage.morphology.remove_small_holes(
    binary,
    area_threshold=np.prod(np.multiply(a, normalized_spacing))
)

display(removed_small_holes)

In [None]:
a = 30

removed_small_objects = skimage.morphology.remove_small_objects(
    removed_small_holes,
    min_size=np.prod(np.multiply(a, normalized_spacing))
)

display(removed_small_objects)

This binary image is a good segmentation of the original image. We could apply `skimage.measure.label` to assign a unique label to each connected region. However, touching objects form a single connected region and therefore receive the same label, as the cell below shows. A better segmentation would assign different labels to regions that appear separate in the original image.
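
The merging behaviour can be reproduced on a toy 2-D image. This sketch uses `scipy.ndimage.label`, which performs similar connected-component labeling (its default connectivity is face-only, whereas `skimage.measure.label` defaults to full connectivity; the result here is the same either way).

```python
import numpy as np
import scipy.ndimage as ndi

binary = np.zeros((10, 20), dtype=bool)
binary[1:5, 1:5] = True    # object A
binary[1:5, 5:9] = True    # object B, touching A along a column edge
binary[6:9, 12:16] = True  # object C, separate from A and B

labels, count = ndi.label(binary)
print(count)                         # 2: A and B merge into one region
print(labels[2, 2] == labels[2, 6])  # True
```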

In [None]:
labels = skimage.measure.label(removed_small_objects)

_, (a, b, c) = plt.subplots(nrows=1, ncols=3, figsize=(16, 4))

a.imshow(rescaled[30, :100, 125:], cmap="gray")
a.set_title("Rescaled")

b.imshow(labels[30, :100, 125:])
b.set_title("Labels")

c.imshow(labels[30, :100, 125:] == 8, cmap="gray")
c.set_title("Labels = 8");

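Watershed segmentation can separate touching objects, an operation referred to as declumping. Watershed treats an image as a landscape and floods its low-intensity basins, starting from a set of markers, until each flooded basin meets its neighbors; the lines where basins meet become object boundaries. For declumping, the markers are generated from a distance image, in which each foreground pixel holds its distance to the nearest background pixel. Points furthest from an edge are local maxima of the distance image and should be identified as markers. Below, the watershed floods the negated intensity image, `-rescaled`, so bright nuclei form the basins.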
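
A 2-D sketch of distance-based declumping on two overlapping disks, using only SciPy: `scipy.ndimage.distance_transform_edt` for the distance image, a maximum filter in place of `peak_local_max` for the markers, and `scipy.ndimage.watershed_ift` in place of `skimage.segmentation.watershed`. Unlike the notebook cell, this sketch floods the negated distance image; the disk sizes and the 11 x 11 window are illustrative choices.

```python
import numpy as np
import scipy.ndimage as ndi

# Two overlapping disks: a 2-D stand-in for two touching nuclei.
rows, cols = np.mgrid[:40, :50]
mask = ((rows - 20) ** 2 + (cols - 15) ** 2 <= 100) | (
    (rows - 20) ** 2 + (cols - 32) ** 2 <= 100
)

# Distance from each foreground pixel to the nearest background pixel.
distance = ndi.distance_transform_edt(mask)

# Markers: pixels equal to the maximum of their 11 x 11 neighborhood that lie
# well inside an object. Each connected group of such pixels gets one label.
is_peak = (distance == ndi.maximum_filter(distance, size=11)) & (
    distance > 0.5 * distance.max()
)
markers, n_markers = ndi.label(is_peak)
print(n_markers)  # 2

# watershed_ift needs an unsigned 8- or 16-bit landscape, so rescale the
# negated distance image to 0..1000 and cast it before flooding.
landscape = distance.max() - distance
landscape = np.round(landscape / landscape.max() * 1000).astype(np.uint16)
labels = ndi.watershed_ift(landscape, markers.astype(np.int32))
labels[~mask] = 0  # keep labels only inside the objects
```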

In [None]:
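import skimage.segmentation

distance = ndi.distance_transform_edt(removed_small_objects)

# Local maxima of the distance image, searched within 15x15x15 windows
peak_coords = skimage.feature.peak_local_max(
    distance,
    footprint=np.ones((15, 15, 15), dtype=bool),
    labels=skimage.measure.label(removed_small_objects)
)

# peak_local_max returns an (N, 3) array of coordinates; convert it to a boolean image
peak_local_max = np.zeros(distance.shape, dtype=bool)
peak_local_max[tuple(peak_coords.T)] = True

markers = skimage.measure.label(peak_local_max)

# Flood the negated intensity image from the markers, restricted to the foreground
labels = skimage.segmentation.watershed(
    -rescaled,
    markers,
    mask=removed_small_objects
)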

display_labels(labels)

The declumping step has over-segmented a few regions of interest. We can inspect the `distance` image and plot the markers to understand where the basins for the watershed are being defined.

In [None]:
_, axes = plt.subplots(nrows=2, ncols=6, figsize=(16, 5))

vmin = distance.min()

vmax = distance.max()

for index, ax in enumerate(axes.flatten()):
    ax.imshow(
        -distance[30 + index],
        cmap="gray",
        vmin=-vmax,
        vmax=-vmin
    )
    
    peaks = np.nonzero(peak_local_max[30 + index])
    
    ax.plot(peaks[1], peaks[0], "r.")

    ax.set_xticks([])
    
    ax.set_yticks([])

The objects we're trying to identify aren't perfectly smooth. Natural variations in shape, such as figure eights, are falsely declumped. Additionally, objects with holes or gaps near their ends were over-segmented. This could be due to nonuniformity in the distance image, or to searching small regions for markers in `skimage.feature.peak_local_max`.
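
The effect of the search window on marker count can be seen in one dimension. This sketch uses `scipy.ndimage.maximum_filter1d` in place of `peak_local_max`; the profile is made up to mimic a cross-section through a lumpy object with one main peak and a shoulder.

```python
import numpy as np
import scipy.ndimage as ndi

# A made-up 1-D distance profile: a main peak (5) with a shoulder (3) beside it.
profile = np.array([0, 1, 2, 3, 2, 3, 4, 5, 4, 3, 2, 1, 0], dtype=float)


def peak_positions(signal, size):
    """Indices whose value equals the maximum of their size-wide neighborhood."""
    is_peak = (signal == ndi.maximum_filter1d(signal, size=size)) & (signal > 0)
    return np.flatnonzero(is_peak)


print(peak_positions(profile, size=3))  # [3 7]: the shoulder becomes a second marker
print(peak_positions(profile, size=9))  # [7]: a wider window keeps one marker
```

A small window lets every local bump seed its own basin, which is the over-segmentation described above; a window that is too wide risks merging genuinely separate objects instead.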

## Challenge problems

### Improve the segmentation

A few objects were over-segmented in the declumping step. Try to improve the segmentation so that each object receives a single, unique label. You can try:

1. generating a smoother image by modifying the `win_size` parameter in `skimage.restoration.denoise_bilateral`, or trying another filter. Many filters are available in `skimage.filters` and `skimage.filters.rank`.
1. adjusting the threshold value by trying another threshold algorithm, such as `skimage.filters.threshold_otsu`, or entering one manually.
1. generating different markers by changing the size of the `footprint` passed to `skimage.feature.peak_local_max`. Alternatively, try another distance function or limit the planes on which markers can be placed.
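
For the "another distance function" suggestion, one option is to measure distances in physical units rather than voxel steps: `scipy.ndimage.distance_transform_edt` accepts a `sampling` argument giving the spacing along each axis, which the distance cell above does not pass. A small sketch on a synthetic slab, with spacing expressed relative to the row/column pixel size:

```python
import numpy as np
import scipy.ndimage as ndi

spacing = (0.29, 0.01625, 0.01625)
# Spacing relative to the row/column pixel size: planes are ~17.8 pixels apart.
sampling = np.divide(spacing, spacing[1])

# A synthetic slab that is 3 planes thick and 7 x 7 pixels wide.
volume = np.zeros((5, 9, 9), dtype=bool)
volume[1:4, 1:8, 1:8] = True

isotropic = ndi.distance_transform_edt(volume)
anisotropic = ndi.distance_transform_edt(volume, sampling=sampling)

print(isotropic[2, 4, 4])    # 2.0: nearest background is 2 planes away
print(anisotropic[2, 4, 4])  # 4.0: nearest background is 4 pixels away within the plane
```

Without `sampling`, the thin stack direction dominates the distance image even though, physically, those planes are far apart; this can shift where `peak_local_max` places markers.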

### Bonus challenge: segment the membrane channel

If segmenting the nuclei was too easy, try segmenting the accompanying membrane channel. You can load the membrane image with `skimage.io.imread`:

```python
membrane = skimage.io.imread("../images/cells_membrane.tif")
```

Hint: There should be one nucleus per membrane object.
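
One way to use the hint is as a sanity check on a membrane segmentation: each membrane object should overlap exactly one nucleus label. A small NumPy sketch with a hypothetical `nuclei_per_object` helper and made-up 2-D label images:

```python
import numpy as np


def nuclei_per_object(membrane_labels, nucleus_labels):
    """Hypothetical check: count distinct nucleus labels overlapping each membrane object."""
    counts = {}
    for label in np.unique(membrane_labels):
        if label == 0:
            continue  # skip background
        inside = nucleus_labels[membrane_labels == label]
        counts[int(label)] = int(np.count_nonzero(np.unique(inside)))
    return counts


membrane_labels = np.array([
    [1, 1, 1, 2, 2, 2],
    [1, 1, 1, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
])
nucleus_labels = np.array([
    [0, 5, 0, 0, 7, 8],
    [0, 5, 0, 0, 7, 8],
    [0, 0, 0, 0, 0, 0],
])
print(nuclei_per_object(membrane_labels, nucleus_labels))  # {1: 1, 2: 2}
```

Here membrane object 2 overlaps two nuclei, which would suggest it should be split.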