# Applications of Streaming Spanning Tree Algorithm to Segmentation

In this notebook we show three applications to image segmentation of the Streaming Spanning Tree algorithm introduced in the previous notebook. In particular, we show how to use it to
- compute $\lambda$-quasi-flat-zones,
- compute a marker-based segmentation,
- compute the $(\alpha, \omega)$-constrained connectivity.

Because our algorithm splits the minimum spanning tree obtained at time $t$ into a stable and an unstable part, we can start processing the stable edges right away and obtain, at each step, a partial segmentation of the complete image. This avoids loading the whole graph into memory and therefore makes it possible to process larger images.
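The block-wise feeding of the image can be pictured with a toy example. The following is a minimal sketch, not the SST implementation: `horizontal_blocks` and `stick_horizontal` are hypothetical helpers, and we assume that consecutive blocks share one column (as the `num_overlapping=1` argument of `stick_two_images` in the cells below suggests), so that no edge between two blocks is lost.

```python
import numpy as np


def horizontal_blocks(img, block_width, overlap=1):
    """Yield consecutive vertical strips of `img`, from left to right.

    Hypothetical helper: consecutive strips share `overlap` columns so that
    the edges crossing the border between two strips are not lost.
    """
    if block_width <= overlap:
        raise ValueError("block_width must be larger than overlap")
    nc = img.shape[1]
    start = 0
    while True:
        stop = min(start + block_width, nc)
        yield img[:, start:stop]
        if stop == nc:
            break
        # step back so that the next strip repeats the last `overlap` columns
        start = stop - overlap


def stick_horizontal(left, right, overlap=1):
    """Glue two strips together, keeping a single copy of the shared columns."""
    return np.concatenate((left, right[:, overlap:]), axis=1)


img = np.arange(14).reshape(2, 7)
blocks = list(horizontal_blocks(img, block_width=3))
print([b.shape for b in blocks])  # [(2, 3), (2, 3), (2, 3)]

# rebuilding the image step by step, as the streaming loops below do
seen = blocks[0]
for b in blocks[1:]:
    seen = stick_horizontal(seen, b)
print(np.array_equal(seen, img))  # True
```

Only one block (plus the part of the result already computed) needs to be handled at a time, which is what keeps the memory footprint bounded.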

We have written a method for each of the cases listed above. At each iteration $t$, each method yields a partial segmentation encoded as a $2 \times m_t$ array $\sigma_t$: the first row lists the ids of the nodes/pixels of the image that can already be assigned a label, and the second row lists the labels assigned to those nodes.
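To make the encoding concrete, here is a minimal sketch that turns such an array into a label image. The helper `decode_partial_segmentation` is hypothetical; it assumes, as `plot_segmentation` below does, that pixel ids are flattened in column-major (`order='F'`) order and that label 0 marks pixels that are not labelled yet.

```python
import numpy as np


def decode_partial_segmentation(sigma, nr, nc, order='F'):
    """Build an nr x nc label map from a 2 x m partial segmentation.

    sigma[0] holds flat pixel ids, sigma[1] the corresponding labels.
    Pixels that are not labelled yet keep the value 0.
    """
    seg = np.zeros(nr * nc, dtype=np.int64)
    seg[sigma[0]] = sigma[1]
    return seg.reshape((nr, nc), order=order)


# toy 2 x 3 image: at step t only pixels 0, 1 and 4 are stable
sigma_t = np.array([[0, 1, 4],
                    [7, 7, 3]])
print(decode_partial_segmentation(sigma_t, 2, 3))
# [[7 0 3]
#  [7 0 0]]

# the next step labels the remaining pixels; the arrays are concatenated column-wise
sigma_next = np.array([[2, 3, 5],
                       [7, 7, 3]])
sigma_all = np.concatenate((sigma_t, sigma_next), axis=1)
print(decode_partial_segmentation(sigma_all, 2, 3))
# [[7 7 3]
#  [7 7 3]]
```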

In [None]:
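import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

try:
    from scipy.datasets import ascent  # SciPy >= 1.10 (requires the pooch package)
except ImportError:
    from scipy.misc import ascent  # older SciPy releases


def imread(path):
    """Read an image file into a NumPy array (scipy.misc.imread was removed in SciPy 1.2)."""
    return np.array(Image.open(path))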

import loadlib

from SST.utils import *
from SST.streaming.streaming_generators import HorizontalStreaming
from SST.applications.segmentation.streaming_segmentation import quasi_flat_zone_streaming

plt.rcParams['figure.figsize'] = (12.0, 12.0)
%load_ext autoreload

%matplotlib inline

In [None]:
%autoreload 2

### $\lambda$-quasi-flat-zones

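In the following we show how to compute the $\lambda$-quasi-flat-zones of a streaming image. Recall that the $\lambda$-quasi-flat-zones are the connected components of the graph in which only the edges of weight at most $\lambda$ are kept; equivalently, they are obtained by removing from the minimum spanning tree all the edges of weight greater than $\lambda$. To compute them we implemented a method called ```quasi_flat_zone_streaming```, which, at each time $t = 0, \ldots, T$, returns a $2 \times m_t$ array whose first row contains the ids of the nodes/pixels that can be assigned a label and whose second row contains the assigned labels. Hence, at each step this method yields a partial segmentation of the complete image.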
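As an illustration of the definition (not of `quasi_flat_zone_streaming` itself), the following sketch computes the $\lambda$-quasi-flat-zones of a small grayscale image in one go with Kruskal's algorithm and union-find. It assumes a 4-adjacency graph with edge weights $|I(p) - I(q)|$ and column-major pixel ids; all function names are hypothetical. The result uses the same $2 \times m$ encoding, with labels starting from 1.

```python
import numpy as np


def grid_edges(img):
    """4-adjacency edges of a 2-D grayscale image, weighted by |I(p) - I(q)|.

    Pixel ids follow column-major ('F') order.
    """
    nr, nc = img.shape
    ids = np.arange(nr * nc).reshape((nr, nc), order='F')
    flat = img.ravel(order='F').astype(np.int64)  # avoid uint8 wrap-around
    # vertical neighbours, then horizontal neighbours
    u = np.concatenate((ids[:-1, :].ravel(), ids[:, :-1].ravel()))
    v = np.concatenate((ids[1:, :].ravel(), ids[:, 1:].ravel()))
    return u, v, np.abs(flat[u] - flat[v])


def find(parent, x):
    """Union-find root of x, with path halving."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def quasi_flat_zones(img, lam):
    """Return a 2 x n array: pixel ids and their lambda-quasi-flat-zone label."""
    u, v, w = grid_edges(img)
    parent = np.arange(img.size)
    # Kruskal: scanning edges by increasing weight builds the MST; stopping
    # as soon as the weight exceeds lam leaves exactly the lam-quasi-flat-zones
    for k in np.argsort(w, kind='stable'):
        if w[k] > lam:
            break
        ru, rv = find(parent, u[k]), find(parent, v[k])
        if ru != rv:
            parent[rv] = ru
    roots = np.array([find(parent, p) for p in range(img.size)])
    # relabel the roots as 1, 2, ... (0 is kept for "not labelled")
    _, labels = np.unique(roots, return_inverse=True)
    return np.vstack((np.arange(img.size), labels + 1))


img = np.array([[0, 1, 9],
                [0, 2, 9]], dtype=np.uint8)
print(quasi_flat_zones(img, lam=1)[1])  # [1 1 1 1 2 2]
print(quasi_flat_zones(img, lam=0)[1])  # [1 1 2 3 4 4]
```

In the streaming version only the part of this computation that involves stable edges can be carried out at time $t$, which is why the labels arrive in chunks.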

In [None]:
# loading image
# im = ascent()
# img = im[300:400,:500]
test_img = imread('RGB_US-CA-SanDiego_2010_05_03.tif')
print(test_img.shape)
# img = test_img[2000:2500, :1498]
img = test_img[2000:2500,500:1998]

plt.imshow(img)
plt.show()

In [None]:
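# Color table used to save the label maps: MAP lists every 24-bit RGB color
# (2**24 rows, sorted lexicographically), so MAP[0] is black.
u8 = np.arange(256, dtype=np.uint8)
MAP = np.array(np.meshgrid(u8, u8, u8)).T.reshape(-1, 3)
ind = np.lexsort((MAP[:, 2], MAP[:, 1], MAP[:, 0]))
MAP = MAP[ind]

# Reproducible random permutation of the colors, so that consecutive labels
# get unrelated colors; label 0 (not labelled yet) stays black.
shuffle_colors = np.arange(1, len(MAP))
np.random.seed(1)
np.random.shuffle(shuffle_colors)
shuffle_colors = np.concatenate(([0], shuffle_colors))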

In [None]:
def plot_segmentation(orig_img, labels, order='F', title=''):
    """Display the partial segmentation coded by `labels` for the image `orig_img`.

    `labels` is a 2 x m array: the first row holds pixel ids (flattened with
    `order`), the second row the label of each of those pixels. Pixels that
    are not labelled yet are drawn in black. If `title` is not empty, the
    label map is also saved to `title + '.png'`, with one color per label.
    """
    from PIL import Image

    nr, nc = np.atleast_3d(orig_img).shape[:2]
    # building the segmentation from the array of labels (0 = not labelled yet)
    seg = np.zeros(nr*nc, dtype=np.int_)
    seg[labels[0]] = labels[1]
    seg_img = seg.reshape((nr, nc), order=order)

    # label_image comes from SST.utils
    res_img = label_image(orig_img, seg, order=order)
    res_img[seg_img == 0] = 0
    if title != '':
        im = Image.fromarray(MAP[shuffle_colors[seg_img]].astype(np.uint8))
        im.save(title + '.png')

    plt.figure()
    plt.imshow(res_img)
    plt.axis('off')
    plt.show()



In [None]:
# creating the stream: the image is fed to the algorithm in blocks of 500 x 500 pixels
gen = HorizontalStreaming(img)
stream = gen.generate_stream(block_shape=(500,500))
# stream = gen.generate_stream(block_shape=(4723,2347))

threshold = 10  # the lambda of the lambda-quasi-flat-zones
for n, (labels, i) in enumerate(quasi_flat_zone_streaming(stream, threshold, return_img=True)):
    # labels: newly stable pixels and their labels; i: image block returned because return_img=True
    print("Number of new stable nodes at iteration {} is {}".format(n, labels.shape[1]))
    if n > 0:
        # glue the new block to the image seen so far and accumulate the labels
        i = stick_two_images(old_i, i, num_overlapping=1, direction='H')
        labels = np.concatenate((old_labels, labels), axis=1)
    plot_segmentation(i, labels, order='F', title='qfz_' + str(n))
    old_i = i
    old_labels = labels

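### Marker-based segmentation by MST

In the following we show an example of marker-based segmentation by MST, using a method called ```marker_flooding_streaming```. Such a segmentation is usually defined through a minimum spanning forest rooted at the markers: every pixel receives the label of the marker whose tree contains it.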
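The following non-streaming sketch shows one way to obtain such a minimum spanning forest: a variant of Kruskal's algorithm that scans the edges by increasing weight and discards every edge that would join two trees already holding different markers. It is not `marker_flooding_streaming`; the 4-adjacency graph, the $|I(p) - I(q)|$ weights, the dict format for the markers and all function names are assumptions made for illustration.

```python
import numpy as np


def grid_edges(img):
    """4-adjacency edges of a 2-D grayscale image, weighted by |I(p) - I(q)|.

    Pixel ids follow column-major ('F') order.
    """
    nr, nc = img.shape
    ids = np.arange(nr * nc).reshape((nr, nc), order='F')
    flat = img.ravel(order='F').astype(np.int64)  # avoid uint8 wrap-around
    u = np.concatenate((ids[:-1, :].ravel(), ids[:, :-1].ravel()))
    v = np.concatenate((ids[1:, :].ravel(), ids[:, 1:].ravel()))
    return u, v, np.abs(flat[u] - flat[v])


def find(parent, x):
    """Union-find root of x, with path halving."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def marker_msf_segmentation(img, markers):
    """Label every pixel with the label of a marker, via a minimum spanning forest.

    markers maps a pixel id (column-major order) to a positive label.
    Returns a 2 x n array: pixel ids and labels (0 only if no marker is reachable).
    """
    u, v, w = grid_edges(img)
    n = img.size
    parent = np.arange(n)
    root_label = np.zeros(n, dtype=np.int64)  # label carried by each tree root
    for p, lab in markers.items():
        root_label[p] = lab
    for k in np.argsort(w, kind='stable'):
        ru, rv = find(parent, u[k]), find(parent, v[k])
        if ru == rv:
            continue
        lu, lv = root_label[ru], root_label[rv]
        if lu and lv and lu != lv:
            continue  # the edge would join two different markers: it is cut
        parent[rv] = ru
        root_label[ru] = lu if lu else lv  # the merged tree keeps the marker label
    roots = np.array([find(parent, p) for p in range(n)])
    return np.vstack((np.arange(n), root_label[roots]))


img = np.array([[0, 1, 9, 10],
                [0, 2, 8, 10]], dtype=np.uint8)
# pixel 0 is the top-left corner, pixel 7 the bottom-right one (column-major ids)
print(marker_msf_segmentation(img, {0: 1, 7: 2})[1])  # [1 1 1 1 2 2 2 2]
```

The two dark columns end up in the tree of the first marker and the two bright ones in the tree of the second, since the cheapest edges are always merged first.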

In [None]:
from SST.applications.segmentation.streaming_segmentation import marker_flooding_streaming

In [None]:
# # loading image
# im = ascent()
# img = im[300:400,:500]

# image shape
nr, nc = img.shape[:2]

plt.imshow(img)
plt.show()

In [None]:
# generating random markers: flat pixel ids in column-major ('F') order,
# as in the plot below (drawn with replacement, so ids may repeat)
np.random.seed(1)
num_markers = 100

markers = np.random.randint(0, nr*nc, size=num_markers)
markers.sort()
print(markers)
# plotting the positions of the markers
mark_img = np.zeros((nr*nc), dtype=np.uint8)
mark_img[markers] = 255
mark_img = mark_img.reshape((nr, nc), order='F')
plt.figure(figsize=(16,5))
plt.imshow(mark_img)
plt.show()

In [None]:
# instantiating the image generator; the markers are passed along with the stream
gen = HorizontalStreaming(img)
stream = gen.generate_stream(block_shape=(500,500), markers=markers)

for n, (labels, i) in enumerate(marker_flooding_streaming(stream, return_img=True)):
    print("Number of new stable nodes at iteration {} is {}".format(n, labels.shape[1]))
    if n > 0:
        # glue the new block to the image seen so far and accumulate the labels
        i = stick_two_images(old_i, i, num_overlapping=1, direction='H')
        labels = np.concatenate((old_labels, labels), axis=1)
    plot_segmentation(i, labels, order='F', title='marker_based_' + str(n))
    print("Number of unique labels: ", len(np.unique(labels[1])))
    old_i = i
    old_labels = labels

### $(\alpha,\omega)$-constrained connectivity

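In the following we show an example of $(\alpha,\omega)$-constrained connectivity, using a method called ```alpha_omega_cc_streaming```. Recall that the $\alpha'$-connected component of a pixel $p$ is the set of pixels reachable from $p$ through paths whose adjacent pixels differ by at most $\alpha'$, and that the $(\alpha,\omega)$-connected component of $p$ is the largest $\alpha'$-connected component of $p$, with $\alpha' \le \alpha$, whose range (difference between its maximum and minimum intensity) is at most $\omega$.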
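The definition can be checked on small images with the naive, non-streaming sketch below. It is not `alpha_omega_cc_streaming`: for every distinct edge weight $\alpha' \le \alpha$ it grows the $\alpha'$-connected components with union-find and keeps, for each pixel, the largest one whose range is at most $\omega$. The 4-adjacency graph, the $|I(p) - I(q)|$ weights and all function names are assumptions, and the cost grows with the number of levels, so it is only meant for toy inputs.

```python
import numpy as np


def grid_edges(img):
    """4-adjacency edges (column-major ids), their weights and the flat image."""
    nr, nc = img.shape
    ids = np.arange(nr * nc).reshape((nr, nc), order='F')
    flat = img.ravel(order='F').astype(np.int64)  # avoid uint8 wrap-around
    u = np.concatenate((ids[:-1, :].ravel(), ids[:, :-1].ravel()))
    v = np.concatenate((ids[1:, :].ravel(), ids[:, 1:].ravel()))
    return u, v, np.abs(flat[u] - flat[v]), flat


def find(parent, x):
    """Union-find root of x, with path halving."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def alpha_omega_cc(img, alpha, omega):
    """Return a 2 x n array: pixel ids and their (alpha, omega)-CC label (from 1).

    Assumes omega >= 0, so that every pixel is labelled at level 0.
    """
    u, v, w, flat = grid_edges(img)
    n = img.size
    parent = np.arange(n)
    result = np.zeros(n, dtype=np.int64)
    order = np.argsort(w, kind='stable')
    # the alpha'-CCs only change at edge weights, so these levels are enough
    levels = np.unique(np.concatenate(([0], w[w <= alpha])))
    pos = 0
    for a in levels:
        # add every edge of weight <= a: the trees are now the a-CCs
        while pos < len(order) and w[order[pos]] <= a:
            k = order[pos]
            ru, rv = find(parent, u[k]), find(parent, v[k])
            if ru != rv:
                parent[rv] = ru
            pos += 1
        roots = np.array([find(parent, p) for p in range(n)])
        # range (max - min intensity) of every a-CC
        lo = np.full(n, np.iinfo(np.int64).max, dtype=np.int64)
        hi = np.full(n, np.iinfo(np.int64).min, dtype=np.int64)
        np.minimum.at(lo, roots, flat)
        np.maximum.at(hi, roots, flat)
        ok = hi[roots] - lo[roots] <= omega
        # the a-CCs grow with a, so overwriting keeps the largest valid one
        result[ok] = roots[ok]
    _, labels = np.unique(result, return_inverse=True)
    return np.vstack((np.arange(n), labels + 1))


img = np.array([[0, 2, 4, 20]], dtype=np.uint8)
print(alpha_omega_cc(img, alpha=5, omega=3)[1])  # [1 2 3 4]
print(alpha_omega_cc(img, alpha=5, omega=4)[1])  # [1 1 1 2]
```

With $\omega = 3$ the 2-connected component $\{0, 2, 4\}$ has range 4 and is rejected, so its pixels fall back to their 0-connected components; with $\omega = 4$ it is accepted.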

In [None]:
from SST.applications.segmentation.streaming_segmentation import alpha_omega_cc_streaming

In [None]:
# loading image
im = ascent()
img = im[300:400,:500]

# image shape
nr, nc = img.shape[:2]

plt.imshow(img)
plt.show()

In [None]:
gen = HorizontalStreaming(img)
stream = gen.generate_stream(block_shape=(100,100), return_map=True)

alpha = 10  # maximum local variation alpha
omega = 50  # maximum range (global variation) omega

for n, (labels, i) in enumerate(alpha_omega_cc_streaming(stream, alpha=alpha, omega=omega, return_img=True)):
    print("Number of new stable nodes at iteration {} is {}".format(n, labels.shape[1]))
    if n > 0:
        # glue the new block to the image seen so far and accumulate the labels
        i = stick_two_images(old_i, i, num_overlapping=1, direction='H')
        labels = np.concatenate((old_labels, labels), axis=1)
    plot_segmentation(i, labels, order='F')
    old_i = i
    old_labels = labels