# Computer Vision (Winter 2021/22)

#### Practice Session 7: Segmentation

November 30th, 2021

Axel Schaffland, Ulf Krumnack

Institute of Cognitive Science
University of Osnabrück

## Today's Session

* Questions from last sheet
    * [Sheet 1 Assignment 1b](../../sheets/sheet01/sheet01.Rmd): Prove linearity and homogeneity
        * What exactly are we doing?
    * [Sheet 2 Assignment 3c](../../sheets/sheet02/sheet02.Rmd): Adaptive Histogram Equalisation
        * Why are the algorithms in the solution more efficient? 
* New sheet
    * Region Growing
    * K-Means
* Canny Edge Detection
* Mean Shift Segmentation

## Zero Crossing

In [None]:
from IPython.display import IFrame
IFrame("../../resources/segmentation/zeroCrossingComplete.pdf", width=900, height=800)

## K-means clustering
For background information watch Lecture 11 of the Information Theory, Pattern Recognition, and Neural Networks course by the late David MacKay: http://www.inference.org.uk/itprnn_lectures/



In [1]:
from IPython.display import IFrame
IFrame("../../resources/segmentation/kmeans.pdf", width=900, height=800)

Questions:
* Is spatial information relevant?
* Local or global solution?
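For the questions above, it can help to look at a minimal sketch of the standard k-means iteration (Lloyd's algorithm) applied to pixel colors only. The function `kmeans_colors`, its parameters and the random initialization from `k` pixels are illustrative assumptions, not the sheet's reference solution:

```python
import numpy as np

def kmeans_colors(img, k, n_iter=20, seed=0):
    """Cluster the pixel colors of `img` into k groups (Lloyd's algorithm).

    Returns a label image and the k cluster centers (colors).
    """
    rng = np.random.default_rng(seed)
    # treat every pixel as a point in color space, ignoring its position
    data = img.reshape(-1, img.shape[-1] if img.ndim == 3 else 1).astype(float)
    # initialize the centers with the colors of k randomly chosen pixels
    centers = data[rng.choice(len(data), size=k, replace=False)]
    for _ in range(n_iter):
        # assignment step: index of the nearest center for each pixel
        dists = np.linalg.norm(data[:, None, :] - centers[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # update step: move each (non-empty) cluster center to the mean of its pixels
        for j in range(k):
            if np.any(labels == j):
                centers[j] = data[labels == j].mean(axis=0)
    return labels.reshape(img.shape[:2]), centers
```

Because pixel positions are ignored here, the pixels of one cluster need not form a connected region, and the result can depend on the initial centers.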

## Region Growing

Recursive:
```
floodfill(img, mask, (x,y), color, region)
    if pixel not visited:
        mark pixel visited in mask
        if pixel value is below threshold distance from region color:
            add pixel to region
            # check that these pixels are inside the image
            floodfill(pixel left of current pixel)
            floodfill(pixel right of current pixel)
            floodfill(pixel above current pixel)
            floodfill(pixel below current pixel)
```
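In Python, the recursive version can exceed the default recursion limit (1000 frames) on real images, since the recursion depth can grow with the region size. A minimal sketch of the same idea with an explicit stack, assuming a grayscale image, 4-connected neighbors and the seed pixel's value as the (fixed) region color; the function name `region_grow` and its parameters are hypothetical:

```python
import numpy as np

def region_grow(img, seed, threshold):
    """Return a boolean mask of the region grown from `seed` = (row, col).

    A pixel joins the region if its value differs from the seed value by
    less than `threshold` (4-connected neighbors only).
    """
    rows, cols = img.shape
    region_color = float(img[seed])
    visited = np.zeros(img.shape, dtype=bool)
    region = np.zeros(img.shape, dtype=bool)
    stack = [seed]
    while stack:
        r, c = stack.pop()
        # skip pixels outside the image or already visited
        if not (0 <= r < rows and 0 <= c < cols) or visited[r, c]:
            continue
        visited[r, c] = True
        if abs(float(img[r, c]) - region_color) < threshold:
            region[r, c] = True
            # push the left, right, upper and lower neighbors
            stack.extend([(r, c - 1), (r, c + 1), (r - 1, c), (r + 1, c)])
    return region
```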

# Mean Shift Segmentation

## Comparison of k-means clustering and mean shift

k-means | mean shift
:---|:---
parametric| non-parametric
density as a superposition of simpler distributions | smooth distribution
find location (and shape) of simple distributions | find peaks and regions corresponding to peaks
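The mode-seeking idea of the right column can be seen on toy 1D data. A minimal sketch with a flat kernel (every point within `radius` gets the same weight), i.e. the same kind of neighborhood average the `MSS` class below computes in 5D; the function name, data and parameter values are made up for illustration:

```python
import numpy as np

def mean_shift_1d(data, start, radius, epsilon=1e-6, max_iter=100):
    """Move `start` uphill in density until the mean shift step is below epsilon."""
    x = float(start)
    for _ in range(max_iter):
        neighbors = data[np.abs(data - x) < radius]
        new_x = neighbors.mean()  # center of gravity of the window
        if abs(new_x - x) < epsilon:
            break
        x = new_x
    return x

# two clusters of values around 1 and 5
data = np.array([0.8, 0.9, 1.0, 1.1, 1.2, 4.8, 4.9, 5.0, 5.1, 5.2])
modes = [mean_shift_1d(data, s, radius=1.0) for s in data]
```

Every start point ends at the peak of its own cluster, and the number of clusters is not fixed in advance, which is the non-parametric aspect listed in the table.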


## Literature

* [Richard Szeliski: Computer Vision Algorithms and Applications, Chapter 5.3.2](https://szeliski.org/Book/1stEdition.htm)

# Implementing mean shift segmentation

## Implementation

In [None]:
from skimage.color import rgb2gray
from imageio import imread
import numpy as np

class MSS:
    """A class for performing mean shift segmentation and clustering.
    """

    # sigma_c: the color coefficient
    sigma_c = .1  # .05

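    # epsilon: minimal change during a mean shift step required for continuing the process
    epsilon = .1

    # radius: the maximal distance for points in the 5D spatial-range representation
    # space to be considered neighbors
    radius = 1/10

    # max_dist: maximal distance between (normalized) filtered pixels in 5D space
    # to be joined into the same segment
    max_dist = .5

    # min_elem: minimal number of pixels per segment (smaller segments are merged)
    min_elem = 10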
    
    def __init__(self, img, sigma_c=None, epsilon=None, radius=None):
        """Initialize an MSS object
        
        Arguments
        ---------
        img: np.ndarray
        sigma_c: float
        epsilon: float
        radius: float
        """
        
        self.sigma_c = MSS.sigma_c if sigma_c is None else sigma_c
        self.epsilon = MSS.epsilon if epsilon is None else epsilon
        self.radius = MSS.radius if radius is None else radius
        
        # self.img: (a copy of) the image to be processed
        self.img = img.copy()
        self.size_x = self.img.shape[0]
        self.size_y = self.img.shape[1]
        self.size_z = 3 if self.img.ndim == 3 else 1

        # spatial coordinates of each pixel in pixel units (row and column index), so `radius`
        # is measured in pixels. (The default radius of 1/10 only makes sense if the
        # coordinates are scaled to [0,1]x[0,1] instead, e.g. with np.linspace.)
        yy, xx = np.indices(self.img.shape[:2])

        # self.d: the 5D spatial-range representation: an array of shape (pixels, 5) storing for
        # each pixel a 5-tuple (x, y, color), consisting of the spatial coordinates (x, y) and
        # the 3D color value (e.g., RGB) scaled by sigma_c.
        self.d = np.empty((self.img.shape[0]*self.img.shape[1], 2 + self.size_z))
        self.d[:, 0] = xx.flatten()
        self.d[:, 1] = yy.flatten()
        self.d[:, 2:] = self.img.reshape(self.img.shape[0]*self.img.shape[1],self.size_z) * self.sigma_c

        # self.g_star: the 5D output spatial-range representation resulting from mean shift filtering
        self.g_star = np.zeros_like(self.d)
  
    def d2img(self):
        """Compute image back from 5D space.
        
        Result
        ------
        image: np.ndarray
            An image representing the current state of this MSS object.
        """
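        # ignore the spatial coordinates and return only the color (all but the first two
        # columns of the spatial-range representation), undoing the sigma_c scaling
        img = self.g_star[:,2:].reshape(self.size_x, self.size_y, self.size_z) / self.sigma_c

        # normalize to [0,1] for display; squeeze drops the channel axis of grayscale images
        return np.squeeze(img / img.max())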

    def dist(self, p1, p2, axis=None):
        """Compute (euclidean) distance between points in spatial-range representation.
        """
        return np.linalg.norm(p1 - p2, axis=axis)

    def compute_mean_shift_step(self, point):
        """Perform one mean shift step by computing the center of gravity of the neighborhood
        of the given point in the 5D spatial-range representation space.
        """
        # compute 5D neighbors, a list of pixels that are within a self.radius from point
        d_dist = self.dist(self.d, point, axis=1)
        neighbors = self.d[d_dist<self.radius,:]

        # compute the 5D gravicenter of all neighbors
        return np.sum(neighbors, axis=0)/neighbors.shape[0]

    def ms_filter(self, progress = lambda x: x):
        """Perform mean shift filtering as explained in CV-07, slide 100:
        Mean shift filtering finds the local density maxima and assigns to each
        pixel p = (s,c)^T the color c* of the closest local density maximum, i.e., the
        image g is transformed to a filtered version g* by replacing p with p*.

        Arguments
        ---------
        progress: function
            A function to update a progress bar to be displayed during computation
            (this may be useful, as computation can take some time).
        """
        # loop over all pixels in 5D spatial-range representation
        for index, point in enumerate(progress(self.d)):

            # remember the original 5D coordinates of the pixel
            point_old = point.copy()
            delta = self.epsilon

            # move point in 5D space as long as the step size is above epsilon
            while delta >= self.epsilon:

                point = self.compute_mean_shift_step(point_old)
                delta = self.dist(point, point_old)
                point_old = point

            # store the result for this pixel
            self.g_star[index] = point

        return self.d2img()
    

    def mss_segment(self, max_dist=None, min_elem=None):
        """
        
        Arguments
        ---------
        max_dist: float
            Maximal distance between pixels in 5D space to be joined
            into a segment.
        min_elem: int
            Minimal number of pixels per segment. Smaller segments
            will be merged with the 
        """

        # use default parameter values if no values are provided
        if not min_elem: min_elem = MSS.min_elem
        if not max_dist: max_dist = MSS.max_dist

        # normalize so spatial and color are of same scale
        max_g = self.g_star.max(axis=0)
        g_star_norm = self.g_star.copy()/max_g

        # Now perform segmentation by assigning labels to pixels.
        # Initially, all labels are set to -1.
        label_count = 0
        labels = np.zeros((len(g_star_norm))) - 1

        # loop over all pixels
        for i in range(len(g_star_norm)):
            if labels[i] == -1:  # pixel has no label yet
                # create a new segment
                labels[i] = label_count
                # compute distances of unlabeled pixels to the current pixel
                dists = np.zeros_like(labels) + max_dist
                dists[labels==-1] = np.linalg.norm(g_star_norm[labels==-1,:] - g_star_norm[i,:], axis=1)
                # combine pixels within the given radius (max_dist) to the new segment
                labels[dists < max_dist] = label_count
                label_count += 1

        # remove small segments
        # compute centers in joint domain
        centers = np.zeros((label_count, 2 + self.size_z), dtype=float)
        for i in range(label_count):
            mask = (labels==i)
            centers[i] = np.sum(g_star_norm[mask], axis=0)/np.sum(mask)

        # replace small segments with segments with smallest distance between segment centers
        for i in range(label_count):
            if np.sum(labels==i) < min_elem:
                center = centers[i,:].copy()
                centers[i,:] = np.inf

                dist = np.linalg.norm(centers - center, axis=1)
                labels[labels==i] = np.argmin(dist)         

        # transform back to image space
        labels_im_space = labels.reshape(self.size_x, self.size_y)

        # compute mean color per segment
        labels_colored = np.atleast_3d(np.zeros_like(self.img, dtype=float))
        for i in range(label_count):
            mask = (labels_im_space==i)
            if np.sum(mask):
                labels_colored[mask,:] = centers[i,2:]

        labels_colored *= max_g[2:]/self.sigma_c
        return np.squeeze(labels_colored).astype(np.uint8)

## Demo 1

Mean shift filtering and segmentation:

In [None]:
# Install the `tqdm` package to display a progress bar (optional);
# `-y` skips the confirmation prompt, which cannot be answered in a notebook cell
!conda install -y -c conda-forge tqdm

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os

try:
    from tqdm import tqdm as progress
except ModuleNotFoundError:
    def progress(x):
        return x

filename_image = os.path.join('images', 'peppers.png')
# filename_image = os.path.join('images', 'burano.jpg')
img = imread(filename_image)
img = img[170:270,100:200]
#img = img[500:700:2, 500:700:2]


sigma_c = .05
epsilon = 1
radius = 6
min_elem = 400

mss = MSS(img, sigma_c, epsilon, radius)

# perform mean shift filtering
seg = mss.ms_filter(progress=progress)

# perform mean shift segmentation
labels = mss.mss_segment(min_elem=min_elem)


fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(36,12))
plt.gray()
ax1.imshow(img)
ax2.imshow(seg)
ax3.imshow(labels)
plt.show()

## Demo 2

Mean shift filtering with different values of $\sigma_c$ (compare CV-6, slide 104):
* larger $\sigma_c$: distances in color space are more important, so pixel colors should change only marginally during mean shift filtering
* smaller $\sigma_c$: distances in color space are less important, so pixels may migrate into other "color regions" during mean shift filtering
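Because the color values are multiplied by $\sigma_c$ before distances are computed in the 5D space, $\sigma_c$ decides whether two pixels count as neighbors. A small worked example, assuming two pixels 3 pixels apart whose color differs by 200 in one channel, and the radius $R = 8$ used for Filters 1 and 2 below; `joint_distance` is a hypothetical helper:

```python
import numpy as np

def joint_distance(spatial_dist, color_diff, sigma_c):
    """Euclidean distance in the spatial-range space, colors scaled by sigma_c."""
    return np.hypot(spatial_dist, color_diff * sigma_c)

radius = 8
d_large = joint_distance(3, 200, sigma_c=.04)   # sqrt(3**2 + 8**2)   ~ 8.54
d_small = joint_distance(3, 200, sigma_c=.004)  # sqrt(3**2 + 0.8**2) ~ 3.10
print(d_large < radius, d_small < radius)  # False True
```

With the larger $\sigma_c$ the two pixels are not neighbors, so the color difference stops the mean shift; with the smaller one they are, and a pixel can drift into the other color region.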

In [None]:
filename_image = os.path.join('images', 'flower-color.png')
img = imread(filename_image)[:,:,:3]

sigma_c_1 = .04
radius_1 = 8
epsilon_1 = 1

sigma_c_2 = .004
radius_2 = 8
epsilon_2 = 1

sigma_c_3 = .05
radius_3 = 6
epsilon_3 = 1

# perform mean shift filtering
mss1 = MSS(img, sigma_c_1, epsilon_1, radius_1)
filter1 = mss1.ms_filter(progress=progress)

mss2 = MSS(img, sigma_c_2, epsilon_2, radius_2)
filter2 = mss2.ms_filter(progress=progress)

mss3 = MSS(img, sigma_c_3, epsilon_3, radius_3)
filter3 = mss3.ms_filter(progress=progress)

# output the resulting images
fig, ax = plt.subplots(2,2, figsize=(16,10))
plt.gray()
ax[0,0].imshow(img); ax[0,0].set_title("Input")
ax[0,1].imshow(filter1); ax[0,1].set_title(rf"Filter 1: $\sigma_c$={sigma_c_1:.4f}, $R$={float(radius_1):.1f}")
ax[1,0].imshow(filter2); ax[1,0].set_title(rf"Filter 2: $\sigma_c$={sigma_c_2:.4f}, $R$={float(radius_2):.1f}")
ax[1,1].imshow(filter3); ax[1,1].set_title(rf"Filter 3: $\sigma_c$={sigma_c_3:.4f}, $R$={float(radius_3):.1f}")
plt.show()

### Mean shift segmentation

Perform mean shift segmentation:

In [None]:
# minimal number of pixels per segment:
min_elem = 50

# do the labeling
labels1 = mss1.mss_segment(min_elem=min_elem)
labels2 = mss2.mss_segment(min_elem=min_elem)
labels3 = mss3.mss_segment(min_elem=min_elem)

In [None]:
# output the resulting segmentation
fig, ax = plt.subplots(2,2, figsize=(16,12))
plt.gray()
ax[0,0].imshow(img); ax[0,0].set_title("Input")
ax[0,1].imshow(labels1)
ax[0,1].set_title(f"Filter 1: $\sigma_c$={sigma_c_1:.4f}, segments={len(np.unique(labels1))}")
ax[1,0].imshow(labels2)
ax[1,0].set_title(f"Filter 2: $\sigma_c$={sigma_c_2:.4f}, segments={len(np.unique(labels2))}")
ax[1,1].imshow(labels3)
ax[1,1].set_title(f"Filter 2: $\sigma_c$={sigma_c_3:.4f}, segments={len(np.unique(labels3))}")
plt.show()

# Canny Edge Detection

The following interactive Matplotlib cell lets you explore how the parameters $t_1$ (low threshold), $t_2$ (high threshold) and $\sigma$ (width of the Gaussian smoothing) affect the result of the Canny edge detector.
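$t_1$ and $t_2$ are used in the hysteresis step: pixels whose gradient magnitude is above $t_2$ are accepted as edges, and pixels above $t_1$ are kept only if they are connected to an accepted pixel. A minimal NumPy sketch of just this step on a given gradient magnitude, assuming 8-connectivity and $t_1 \le t_2$; the helper names are hypothetical, and `feature.canny` (which also performs smoothing, gradient computation and non-maximum suppression) may differ in details:

```python
import numpy as np

def dilate8(mask):
    """Grow a boolean mask by one pixel in all 8 directions."""
    padded = np.pad(mask, 1)  # pads with False
    out = np.zeros_like(mask)
    h, w = mask.shape
    for dr in (-1, 0, 1):
        for dc in (-1, 0, 1):
            out |= padded[1 + dr:1 + dr + h, 1 + dc:1 + dc + w]
    return out

def hysteresis(magnitude, t1, t2):
    weak = magnitude > t1    # candidate edge pixels
    edges = magnitude > t2   # strong pixels are always edges
    while True:
        # add weak pixels that touch the current edges, until nothing changes
        grown = dilate8(edges) & weak
        if np.array_equal(grown, edges):
            return edges
        edges = grown
```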

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
from imageio import imread
from skimage import feature
from skimage.color import rgb2gray


sigma = 3
t1 = .1 # .1 * max of image
t2 = .2 # .2 * max of image

img = rgb2gray(imread('../../resources/segmentation/images/murano.jpg')/255.0)

def update():
    edges = feature.canny(img, sigma=sigma, low_threshold=t1, high_threshold=t2)
    ax2.imshow(edges, cmap='gray')
    fig.canvas.draw_idle()  # redraw the figure after a slider change
    
    
def update_s(val):
    global sigma
    sigma = val
    update()
    
    
def update_t1(val):
    global t1
    t1 = val
    update()
    
    
def update_t2(val):
    global t2
    t2 = val
    update() 

    
fig, (ax1, ax2) = plt.subplots(1,2, figsize=(8,8))
plt.subplots_adjust(left=0.25, bottom=0.25)
ax_s = plt.axes([0.25, 0.1, 0.65, 0.03])
ax_t1 = plt.axes([0.25, 0.3, 0.65, 0.03])
ax_t2 = plt.axes([0.25, 0.2, 0.65, 0.03])

s_s = Slider(ax_s, 'sigma', 0, 11, valinit=3, valstep=1)
s_s.on_changed(update_s)
s_t1 = Slider(ax_t1, 't_1', 0, 1, valinit=.1, valstep=.05)
s_t1.on_changed(update_t1)
s_t2 = Slider(ax_t2, 't_2', 0, 1, valinit=.2, valstep=.05)
s_t2.on_changed(update_t2)


ax1.imshow(img, cmap='gray')
update()
plt.show()

# Matplotlib Animations

Animations may help to visualize and understand the effects of parameters.

## "Poor man's animation"

Idea: regularly update the figure:
1. setup the figure
2. provide a function to update the figure
3. create a loop to run the animation

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import imageio

img = imageio.imread('imageio:camera.png').astype(np.float32) / 255
rows, columns = img.shape[:2]

# plot the image
fig = plt.figure(figsize=(10, 5))
plt.subplot(1,2,1); 
plt.imshow(img, cmap='gray')
mpl_line, = plt.plot([], [])

# plot the row
ax2 = plt.subplot(1,2,2)
ax2.set_ylim([0,1])
ax2.set_xlim([0, columns])
mpl_plot, = plt.plot([], [], 'b', label='image row')
plt.legend(loc='upper right')

plt.show()

In [None]:
def show_row(row):
    mpl_line.set_data([0, columns-1], [row, row])
    mpl_plot.set_data(np.arange(columns), img[row, :])
    fig.canvas.draw()  # has to be called explicitly!

for row in range(len(img)):
    show_row(row)

## The `matplotlib.animation` module

Matplotlib provides the [`animation` module](https://matplotlib.org/stable/api/animation_api.html) to create and work with animations.

To create an animation, different `Animation` classes can be used.
* `TimedAnimation`: The `matplotlib.animation.TimedAnimation` creates an animation by displaying new frames at regular time intervals.
* `FuncAnimation`: The `matplotlib.animation.FuncAnimation` is a subclass of `TimedAnimation`. It creates an animation by repeatedly calling a function that updates a figure.
* `ArtistAnimation`: The `matplotlib.animation.ArtistAnimation` is a subclass of `TimedAnimation`. It creates an animation from a fixed sequence of Matplotlib artists.

### The `FuncAnimation` class

Besides the figure, the `FuncAnimation` takes the following main arguments:
* `func`: the function to be called to create the next frame.
* `frames`: a number, an iterable or a generator function. Each frame value (for a number `n`: `0, ..., n-1`) is passed as the first argument to `func`.
* `interval`: delay between frames in milliseconds (default: 200, i.e. 5 frames per second). A value of 40 yields 25 frames per second.

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import imageio

img = imageio.imread('imageio:camera.png').astype(np.float32) / 255
rows, columns = img.shape[:2]

# show the image
fig = plt.figure(figsize=(10, 5))
plt.subplot(1,2,1); 
plt.imshow(img, cmap='gray')
mpl_line, = plt.plot([], [])

# plot the row
ax2 = plt.subplot(1,2,2)
ax2.set_ylim([0,1])
ax2.set_xlim([0,columns])
mpl_plot, = plt.plot([], [], 'b', label='image row')
plt.legend(loc='upper right')

# animation function. This is called sequentially
def animate(i):
    mpl_line.set_data([0, columns-1], [i, i])
    mpl_plot.set_data(np.arange(columns), img[i, :])
    return [mpl_line, mpl_plot]

# call the animator. blit=True means only re-draw the parts that have changed.
anim = animation.FuncAnimation(fig, animate, frames=len(img), interval=20, repeat=False, blit=True)

fig.show()

### Alternative: `ArtistAnimation`

* create animation using a fixed set of Artist objects
* for each frame a collection of Artist objects is given
* only those artists are made visible on the corresponding frame, other artists are made invisible

In [None]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import imageio

img = imageio.imread('imageio:camera.png').astype(np.float32) / 255
rows, columns = img.shape[:2]

fig = plt.figure(figsize=(10, 5))
# show the image
ax1 = plt.subplot(1,2,1)
plt.imshow(img, cmap='gray')

# plot the row
ax2 = plt.subplot(1,2,2)
ax2.set_ylim([0,1])

# create a list of artists collections
frames = []
for row in range(rows):
    # one line on the image and the corresponding row profile per frame
    mpl_line, = ax1.plot([0, columns-1], [row, row], 'b')
    mpl_plot, = ax2.plot(np.arange(columns), img[row, :], 'b', label='image row')
    frames.append([mpl_line, mpl_plot])

anim = animation.ArtistAnimation(fig, frames, interval=20, repeat=False, blit=True)

fig.show()

### Showing animations in the notebook: `to_html5_video()`

In [None]:
from IPython.display import HTML
HTML(anim.to_html5_video())  # this may need some time!

### Storing animations: `anim.save`

* animation can be stored in different formats
* various parameters can be passed to adapt the output (see [documentation](https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.Animation.html#matplotlib.animation.Animation.save))
* different writers can be used (e.g. FFmpeg for MP4 or Pillow for GIF); some, like FFmpeg, may have to be installed separately
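If FFmpeg is not available, Matplotlib's `PillowWriter` can store an animation as an animated GIF. A minimal self-contained sketch, assuming a moving sine curve as stand-in content and a temporary output directory; the name `anim_gif` is chosen so that it does not overwrite `anim` from the cells above:

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

fig, ax = plt.subplots(figsize=(4, 3))
x = np.linspace(0, 2 * np.pi, 100)
line, = ax.plot(x, np.sin(x))

def animate(i):
    # frames=10 below passes i = 0, 1, ..., 9; shift the sine curve accordingly
    line.set_ydata(np.sin(x + 2 * np.pi * i / 10))
    return [line]

anim_gif = animation.FuncAnimation(fig, animate, frames=10, interval=50)

# write the frames to a GIF file with the Pillow writer (20 frames per second)
path = os.path.join(tempfile.mkdtemp(), "anim.gif")
anim_gif.save(path, writer=animation.PillowWriter(fps=20))
plt.close(fig)
```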

In [None]:
anim.save('anim.mp4', fps=20)  # this may need some time!

In [None]:
!ls

In [None]:
from IPython.display import Video
Video("anim.mp4")

# Efficiency: Python vs Numpy


The (average) runtime of a statement can be measured with the `%timeit` magic command:

In [None]:
L = range(2000)

%timeit [i**2 for i in L]

The `%timeit` magic (based on Python's `timeit` module):
* loops: calls the command repeatedly and measures the (average) execution time
* the number of loops is chosen automatically according to the execution time
* multiple runs: `%timeit` usually repeats this measurement several times
* the parameters `-r` (number of runs) and `-n` (loops per run) can be used to control these values
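Outside IPython, the standard-library `timeit` module offers the same kind of measurement. A minimal sketch that mirrors `%timeit -r 3 -n 20`; note that `timeit.repeat` returns the total time of each run, so it is divided by `number` to get the time per loop:

```python
import timeit

L = range(2000)

# 3 runs with 20 loops each, like `%timeit -r 3 -n 20`
totals = timeit.repeat("[i**2 for i in L]", globals={"L": L}, repeat=3, number=20)
per_loop = [t / 20 for t in totals]
print(f"best of 3: {min(per_loop) * 1e6:.1f} µs per loop")
```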

In [None]:
L = range(2000)
%timeit -r 3 -n 20 [i**2 for i in L]

Using a different function (here `math.pow` instead of `**`) can change the runtime:

In [None]:
import math
%timeit -r 3 -n 1000 [math.pow(i,2) for i in L]

## Comparison of pure Python and Numpy

In [None]:
import numpy as np

L = range(2000)
a = np.arange(2000)

%timeit [i*2 for i in L]
%timeit a*2

Using numpy operators is usually **much** faster!

Example: Matrix sum

In [None]:
import numpy as np

def sum1(a, b, result):
    for i in np.ndindex(a.shape):
        result[i] = a[i] + b[i]

def sum2(a, b, result):
    result[:] = a + b

n = 1000
a = np.random.rand(n,n)
b = np.random.rand(*a.shape)
result = np.ndarray(a.shape)

%timeit sum1(a, b, result)
%timeit sum2(a, b, result)

Try to use **vectorized computation** (avoid loops)!
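A common vectorization pattern is broadcasting: inserting axes with `None` lets NumPy combine every row of one array with every row of another without a Python loop. A minimal sketch computing all pairwise Euclidean distances between two point sets (similar to the neighbor search in the `MSS` class above); both function names are hypothetical:

```python
import numpy as np

def pairwise_dist_loop(p, q):
    result = np.empty((len(p), len(q)))
    for i in range(len(p)):
        for j in range(len(q)):
            result[i, j] = np.sqrt(np.sum((p[i] - q[j]) ** 2))
    return result

def pairwise_dist_vec(p, q):
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d)
    diff = p[:, None, :] - q[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))

p = np.random.rand(50, 5)
q = np.random.rand(60, 5)
assert np.allclose(pairwise_dist_loop(p, q), pairwise_dist_vec(p, q))
```

The vectorized version trades memory for speed: the intermediate array has `n * m * d` entries.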

**Task:** Find out what the following functions do. Then try to vectorize them.

In [None]:
import numpy as np

def func1bad(a, b, result):
    for i in np.ndindex(a.shape):
        d = a[i]**2 + b[i]**2
        result[i] = np.sqrt(d)

def func1good(a, b, result):
    # BEGIN SOLUTION
    result[:] = np.sqrt(a**2 + b**2)
    # END SOLUTION
    
n = 1000
a = np.random.rand(n,n)
b = np.random.rand(*a.shape)
result = np.ndarray(a.shape)

%timeit func1bad(a, b, result)
%timeit func1good(a, b, result)

In [None]:
import numpy as np

def func2bad(img, threshold):
    for i in np.ndindex(img.shape):
        if img[i] < threshold:
            img[i] = 0

def func2good(img, threshold):
    # BEGIN SOLUTION
    img[img<threshold] = 0
    # END SOLUTION

n = 1000
threshold = .7
img1 = np.random.rand(n,n)
img2 = img1.copy()

%timeit func2bad(img1, threshold)
%timeit func2good(img2, threshold)

In [None]:
import numpy as np

def func3bad(a, b, result):
    assert a.shape[1] == b.shape[0]
    for i,j in np.ndindex(result.shape):
        result[i,j] = 0
        for k in np.arange(a.shape[1]):
            result[i,j] += a[i,k] * b[k,j]

def func3good(a, b, result):
    # BEGIN SOLUTION
    result[:] = a @ b
    # END SOLUTION

n = 100
a = np.random.rand(n,n)
b = np.random.rand(*a.shape)
result = np.ndarray(a.shape)

%timeit func3bad(a, b, result)
%timeit func3good(a, b, result)

In [None]:
import numpy as np

def func4(edges, result):
    for row,col in np.ndindex(result.shape):
        result[row,col] = False
        if (((edges[row,col] > 0) and (edges[row,col+1] < 0)) or
            ((edges[row,col] < 0) and (edges[row,col+1] > 0))):
            result[row,col] = True
        elif (((edges[row,col] > 0) and (edges[row+1,col] < 0)) or
            ((edges[row,col] < 0) and (edges[row+1,col] > 0))):
            result[row,col] = True

def func4vector(edges, result):
    # BEGIN SOLUTION
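    # a sign change between two neighbors means that their product is negative;
    # like func4, compare each pixel with its right and its lower neighbor
    result[:] = (edges[:-1, :-1] * edges[:-1, 1:] < 0) | (edges[:-1, :-1] * edges[1:, :-1] < 0)

def func4vector_bad(edges, result):
    # The following is correct as well, but it is not faster: it evaluates eight
    # comparisons and is usually slower than func4vector.
    result[:] = (np.logical_and(edges[:-1, :-1] < 0., edges[:-1, 1:] > 0.) |
                 np.logical_and(edges[:-1, :-1] > 0., edges[:-1, 1:] < 0.) |
                 np.logical_and(edges[:-1, :-1] < 0., edges[1:, :-1] > 0.) |
                 np.logical_and(edges[:-1, :-1] > 0., edges[1:, :-1] < 0.))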
    # END SOLUTION

n = 1000
edges = np.random.randn(n,n)
result = np.ndarray((edges.shape[0]-1,edges.shape[1]-1), dtype=bool)

%timeit func4(edges, result)
%timeit func4vector(edges, result)