In [None]:
%matplotlib inline

# Tutorial 3: Segmentation

This tutorial was adapted from:
- https://github.com/scikit-image/skimage-tutorials/blob/master/lectures/4_segmentation.ipynb
- https://scikit-image.org/docs/stable/_downloads/plot_marked_watershed.ipynb
- https://scikit-image.org/docs/stable/_downloads/plot_rag_merge.ipynb

--------------

## Separating an image into one or more regions of interest.

Most people have seen Photoshop or a similar graphics editor used to take a person from one image and place them into another.  The first step in doing this is *identifying where that person is in the source image*.

In popular culture, the Terminator's vision segments humans:

<img src="../img/terminator-vision.png" width="700px"/>

### Segmentation contains two major sub-fields

* **Supervised** segmentation: Some prior knowledge, possibly from human input, guides the algorithm.  Supervised algorithms currently available in scikit-image include
  * Thresholding algorithms which require user input (`skimage.filters.threshold_*`)
  * `skimage.segmentation.random_walker`
  * `skimage.segmentation.active_contour`
  * `skimage.segmentation.watershed`
* **Unsupervised** segmentation: No prior knowledge is required.  These algorithms attempt to subdivide an image into meaningful regions automatically.  The user may still be able to tweak settings such as the number of regions.
  * Thresholding algorithms which require no user input
  * `skimage.segmentation.slic`
  * `skimage.segmentation.chan_vese`
  * `skimage.segmentation.felzenszwalb`
  * `skimage.segmentation.quickshift`


First, some standard imports and a helper function to display our results:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from skimage import io, data, draw, color, filters, segmentation as seg

def image_show(image, nrows=1, ncols=1, cmap='gray', **kwargs):
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 16))
    # Honor the `cmap` argument and forward extra keyword arguments to imshow
    ax.imshow(image, cmap=cmap, **kwargs)
    ax.axis('off')
    return fig, ax

## Thresholding

In some images, global or local contrast may be sufficient to separate regions of interest.  Simply choosing all pixels above or below a certain *threshold* may be sufficient to segment such an image.
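
As a minimal sketch of the idea, the snippet below thresholds a tiny synthetic array with plain NumPy. The array contents and the threshold value of 128 are assumptions chosen for illustration, not values from any real image.

```python
import numpy as np

# Hypothetical 4x4 grayscale image: dark background (20) with a bright 2x2 square (200).
image = np.full((4, 4), 20, dtype=np.uint8)
image[1:3, 1:3] = 200

threshold = 128  # assumed value, picked by eye for this toy image

# The comparison yields a boolean mask: True where a pixel is brighter than the threshold.
mask = image > threshold

print(mask.sum())  # number of pixels selected as foreground
```

The resulting boolean mask is the segmentation: it can be displayed directly or used to index the original array.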

Let's try this on an image of a scanned page of text.

In [None]:
text = data.page()

image_show(text);

### Histograms

A histogram plots how often values within each range (bin) occur against the values themselves.  It is a powerful tool for getting to know your data - or for deciding where you would like to threshold.
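
To make the binning concrete, here is a small sketch using `np.histogram` on a handful of made-up pixel values (the values and the choice of four bins are assumptions for illustration).

```python
import numpy as np

# Hypothetical 8-bit pixel values, flattened the way `image.ravel()` would return them.
pixels = np.array([10, 12, 15, 200, 210, 220, 230, 250], dtype=np.uint8)

# Four equal-width bins covering 0 to 256.
counts, bin_edges = np.histogram(pixels, bins=4, range=(0, 256))

print(counts)     # how many pixels fall in each bin
print(bin_edges)  # the boundaries between bins
```

The dark and bright values land in the first and last bins, leaving the middle bins empty. A gap like this in the histogram is a natural place to put a threshold.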

In [None]:
fig, ax = plt.subplots(1, 1)
ax.hist(text.ravel(), bins=32, range=[0, 256])
ax.set_xlim(0, 256);

### Experimentation: supervised thresholding

Try simple NumPy methods and a few different thresholds on this image.  Because *we* are setting the threshold, this is *supervised* segmentation.

In [None]:
text_segmented = text > 120  # example starting value (an assumption); try a few others

image_show(text_segmented);

Not ideal results!  The shadow on the left creates problems; no single global value really fits.
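
The effect of uneven illumination can be reproduced on a synthetic scanline. The sketch below is an illustration under stated assumptions (a linear brightness ramp, "ink" that is always 50 grey levels darker than its surroundings, a 21-pixel window and an offset of 10). It compares the best possible global threshold with a threshold computed from each pixel's local mean; this is the idea behind local (adaptive) thresholding, not scikit-image's implementation.

```python
import numpy as np

# Hypothetical scanline: the background brightens from 60 to 200 (a "shadow" on the left),
# and "ink" pixels are always 50 grey levels darker than the local background.
n = 100
background = np.linspace(60, 200, n)
ink = (np.arange(n) % 10) < 3           # ground truth: 3 ink pixels in every 10
row = background - 50 * ink

# Global threshold: try every integer value and keep the smallest error count.
global_errors = min(int(np.sum((row < t) != ink)) for t in range(256))

# Local threshold: compare each pixel with the mean of its window, minus an offset.
window, offset = 21, 10                 # assumed values for this toy signal
padded = np.pad(row, window // 2, mode='edge')
local_mean = np.convolve(padded, np.ones(window) / window, mode='valid')
local_mask = row < (local_mean - offset)
local_errors = int(np.sum(local_mask != ink))

print(global_errors, local_errors)
```

Even the best global value misclassifies some pixels, because ink on the bright side is lighter than background on the dark side. The local comparison adapts to the ramp and recovers the ink exactly in this toy case.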

What if we don't want to set the threshold every time?  Several published methods inspect the histogram and choose a suitable threshold without user input.  These are unsupervised.

### Experimentation: unsupervised thresholding

Here we will experiment with a number of automatic thresholding methods available in scikit-image.  Because these require no input beyond the image itself, this is *unsupervised* segmentation.

These functions generally return the threshold value(s), rather than applying them to the image directly.
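
As an illustration of the "return a value, then apply it yourself" pattern, here is a simplified NumPy sketch of Otsu's criterion, which picks the split that maximises the variance between the two resulting classes. The helper name `otsu_threshold` is hypothetical and the fixed 0 to 256 range is an assumption for 8-bit data; in practice scikit-image provides `filters.threshold_otsu`.

```python
import numpy as np

def otsu_threshold(values, nbins=256):
    """Return the bin centre that maximises between-class variance (simplified sketch)."""
    counts, edges = np.histogram(values, bins=nbins, range=(0, 256))
    centers = (edges[:-1] + edges[1:]) / 2
    weighted = counts * centers
    weight_low = np.cumsum(counts)                # pixels in bins 0..i
    weight_high = np.cumsum(counts[::-1])[::-1]   # pixels in bins i..end
    mean_low = np.cumsum(weighted) / np.maximum(weight_low, 1)
    mean_high = (np.cumsum(weighted[::-1]) / np.maximum(weight_high[::-1], 1))[::-1]
    # Between-class variance for a split between bin i and bin i + 1.
    variance = weight_low[:-1] * weight_high[1:] * (mean_low[:-1] - mean_high[1:]) ** 2
    return centers[np.argmax(variance)]

# Hypothetical bimodal data: 100 dark pixels and 100 bright pixels.
values = np.concatenate([np.full(100, 50), np.full(100, 200)])

threshold = otsu_threshold(values)   # the function only returns a number...
mask = values > threshold            # ...the comparison is what segments the data

print(threshold, mask.sum())
```

The function returns a single scalar; the segmentation only appears once that scalar is compared against the pixel values.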

In [None]:
filters.try_all_threshold(text, figsize=(20, 10));

In [None]:
text_threshold = filters.threshold_local(text, 51, offset=10)
image_show(text < text_threshold);

## Supervised segmentation

Thresholding can be useful, but it is rather basic, and its reliance on a high-contrast image often limits its utility.  For doing more fun things - like removing part of an image - we need more advanced tools.

### Watershed

In [None]:
from skimage import morphology
from scipy import ndimage as ndi

In [None]:
image = data.camera()

# denoise image
denoised = filters.rank.median(image, morphology.disk(2))

In [None]:
# find continuous regions (low gradient -
# below 10 for this image) --> markers
# disk(5) is used here to get a smoother image
markers = filters.rank.gradient(denoised, morphology.disk(5)) < 10

# Turn binary mask into labels: 0 is background, all nonzero values turned into unique labels
markers, _ = ndi.label(markers)

In [None]:
# local gradient (disk(2) is used to keep edges thin)
gradient = filters.rank.gradient(denoised, morphology.disk(2))

In [None]:
# process the watershed
labels = seg.watershed(gradient, markers)

In [None]:
fig, axes = plt.subplots(ncols=4, figsize=(20, 10))
for img, ax in zip([image, gradient, markers, labels], axes):
    ax.imshow(img, cmap=plt.cm.nipy_spectral)
    ax.axis('off')

## Unsupervised segmentation

Sometimes, human input is not possible or feasible - or, perhaps your images are so large that it is not feasible to consider all pixels simultaneously.  Unsupervised segmentation can then break the image down into several sub-regions, so instead of millions of pixels you have tens to hundreds of regions.

### SLIC

There are many analogies to machine learning in unsupervised segmentation.  Our first example directly uses a common machine learning algorithm under the hood - K-Means.
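
To show how K-Means can produce superpixels, the sketch below clusters the pixels of a small grayscale image on combined (row, column, intensity) features, starting from centres on a regular grid. It is a deliberately simplified illustration: the helper `kmeans_superpixels` and its parameters are hypothetical, and it omits parts of the real SLIC algorithm (which works in a color space and limits each centre's search to a local window).

```python
import numpy as np

def kmeans_superpixels(image, n_segments=4, compactness=0.1, n_iter=10):
    """Cluster a 2-D grayscale image on (row, col, intensity) with plain K-Means."""
    h, w = image.shape
    rows, cols = np.mgrid[0:h, 0:w]
    # Scale the coordinates so `compactness` trades off position against intensity.
    features = np.stack([compactness * rows.ravel(),
                         compactness * cols.ravel(),
                         image.ravel().astype(float)], axis=1)
    # Initialise the centres on a regular grid of pixels.
    grid = int(np.sqrt(n_segments))
    seed_rows = ((np.arange(grid) + 0.5) * h / grid).astype(int)
    seed_cols = ((np.arange(grid) + 0.5) * w / grid).astype(int)
    centers = np.array([features[r * w + c] for r in seed_rows for c in seed_cols])
    for _ in range(n_iter):
        # Assignment step: each pixel joins the nearest centre in feature space.
        dist = np.linalg.norm(features[:, None, :] - centers[None, :, :], axis=2)
        labels = dist.argmin(axis=1)
        # Update step: each centre moves to the mean of its assigned pixels.
        for k in range(len(centers)):
            if np.any(labels == k):
                centers[k] = features[labels == k].mean(axis=0)
    return labels.reshape(h, w)

# Hypothetical test image: dark left half, bright right half.
toy = np.zeros((8, 8))
toy[:, 4:] = 1.0
toy_labels = kmeans_superpixels(toy)

print(toy_labels)
```

Because intensity is part of the feature vector, no cluster straddles the dark/bright edge, while the spatial terms keep each cluster compact. Raising `compactness` makes the regions more regular at the cost of following edges less closely.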

In [None]:
# SLIC works in color
coffee = data.coffee()
plt.imshow(coffee)

In [None]:
coffee_slic = seg.slic(coffee)

In [None]:
# label2rgb replaces each discrete label with the average interior color
coffee_out = color.label2rgb(coffee_slic, coffee, kind='avg')
coffee_out = seg.mark_boundaries(coffee_out, coffee_slic, color=(0, 0, 0))
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(coffee_out)

## RAG Merging

This example constructs a Region Adjacency Graph (RAG) and progressively merges regions that are similar in color. Merging two adjacent regions produces a new region with all the pixels from the merged regions. Regions are merged until no highly similar region pairs remain.
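
The merging loop can be sketched without any graph library. In the snippet below, regions are described by a mean color and a pixel count, adjacency is a list of edges, and the closest adjacent pair is merged repeatedly until no pair is closer than the threshold. The helper `merge_similar_regions` and the toy data are hypothetical, and this is not how scikit-image implements the procedure.

```python
import numpy as np

def merge_similar_regions(mean_color, pixel_count, edges, thresh):
    """Greedily merge adjacent regions whose mean colors differ by less than `thresh`.

    Returns a dict mapping each original region id to the id of the region it ended up in.
    """
    total = {r: np.asarray(c, dtype=float) * pixel_count[r] for r, c in mean_color.items()}
    count = dict(pixel_count)
    adjacency = {r: set() for r in mean_color}
    for a, b in edges:
        adjacency[a].add(b)
        adjacency[b].add(a)
    owner = {r: r for r in mean_color}

    def distance(a, b):
        return float(np.linalg.norm(total[a] / count[a] - total[b] / count[b]))

    while True:
        pairs = [(distance(a, b), a, b) for a in adjacency for b in adjacency[a] if a < b]
        if not pairs:
            break
        d, src, dst = min(pairs)
        if d >= thresh:
            break  # no highly similar neighbours remain
        # Merge `src` into `dst`: pool color totals and pixel counts, rewire neighbours.
        total[dst] += total.pop(src)
        count[dst] += count.pop(src)
        for n in adjacency.pop(src):
            adjacency[n].discard(src)
            if n != dst:
                adjacency[n].add(dst)
                adjacency[dst].add(n)
        for r, o in owner.items():
            if o == src:
                owner[r] = dst
    return owner

# Hypothetical chain of four regions: two greys next to two reds.
colors = {0: (10, 10, 10), 1: (12, 12, 12), 2: (200, 0, 0), 3: (205, 0, 0)}
counts = {0: 100, 1: 100, 2: 100, 3: 100}
owner = merge_similar_regions(colors, counts, edges=[(0, 1), (1, 2), (2, 3)], thresh=35)

print(owner)
```

Pooling the color totals and pixel counts gives the merged region a pixel-weighted mean color, so later comparisons against its neighbours use the color of the whole merged region rather than of either part.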

In [None]:
# If this import fails on a recent scikit-image, try `from skimage import graph` instead.
from skimage.future import graph

In [None]:
def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.

    The method expects that the mean color of `dst` is already computed.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """

    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    diff = np.linalg.norm(diff)
    return {'weight': diff}

def merge_mean_color(graph, src, dst):
    """Callback called before merging two nodes of a mean color distance graph.

    This method computes the mean color of `dst`.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    """
    graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
    graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
    graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
                                      graph.nodes[dst]['pixel count'])

In [None]:
g = graph.rag_mean_color(coffee, coffee_slic)

In [None]:
labels2 = graph.merge_hierarchical(coffee_slic, g, thresh=35, rag_copy=False,
                                   in_place_merge=True,
                                   merge_func=merge_mean_color,
                                   weight_func=_weight_mean_color)

out = color.label2rgb(labels2, coffee, kind='avg')
out = seg.mark_boundaries(out, labels2, (0, 0, 0))
fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(out)