In [3]:
%reload_ext autoreload
%autoreload 2

import numpy as np

from joblib import Parallel, delayed

import matplotlib.pyplot as plt
%matplotlib inline

import os, sys
import time

sys.path.append(os.path.join(os.environ['GORDON_REPO_DIR'], 'utilities'))
from utilities2015 import *

In [4]:
# Handle for section 139 of stack MD594 with segmentation parameter set 'tSLIC200'.
# NOTE(review): DataManager presumably comes from the star-import of utilities2015 -- confirm.
dm = DataManager(stack='MD594', section=139, segm_params_id='tSLIC200')

# Stored 'texMap' pipeline result; later cells index it as a 2-D array of texton labels.
textonmap = dm.load_pipeline_result('texMap')
# Number of texton classes; assumes labels are 0-based integers -- TODO confirm.
n_texton = textonmap.max() + 1

In [5]:
# Load the stored 'histogramsPixelNormalized' result. The cells below recompute and
# re-save the same result; NOTE(review): they carry no execution count in this export,
# so this run appears to have used the stored copy -- confirm.
histograms_normalized = dm.load_pipeline_result('histogramsPixelNormalized')

In [None]:
# Side length (in pixels) of the square window over which texton labels are pooled.
window_size = 201
# Floor division keeps the half-size an int so it stays usable in slice bounds;
# plain `/` yields a float under Python 3 or `from __future__ import division`.
window_halfsize = (window_size - 1) // 2

# One boolean mask per texton class, cropped to the dm.[xy]min..[xy]max box grown by
# window_halfsize on every side.
# NOTE(review): if dm.ymin or dm.xmin is smaller than window_halfsize the slice start
# goes negative and numpy silently wraps it -- confirm the box is far enough from the border.
single_channel_maps = [textonmap[dm.ymin-window_halfsize : dm.ymax+1+window_halfsize, 
                                 dm.xmin-window_halfsize : dm.xmax+1+window_halfsize] == c
                       for c in range(n_texton)]

In [None]:
sys.stderr.write('computing histogram for each pixel\n')
t = time.time()

from skimage.transform import integral_image


def compute_integral_image(m):
    """Return the integral image of `m` with one extra leading row and column of zeros.

    The zero border matters: it lets a window sum be read off as a combination of
    four corner values without special-casing windows that touch row 0 or column 0.
    """
    summed_area = integral_image(m)
    zero_border = ((1, 0), (1, 0))  # one row before, one column before, nothing after
    return np.pad(summed_area, zero_border, mode='constant', constant_values=0)


# One padded integral image per texton mask, computed with 4 joblib jobs and stacked
# along the last axis.
int_imgs = np.dstack(
    Parallel(n_jobs=4)(
        delayed(compute_integral_image)(m) for m in single_channel_maps
    )
)

sys.stderr.write('done in %.2f seconds\n' % (time.time() - t))

In [None]:
# Window sums from the zero-padded integral images. For each window of side window_size:
#   sum = bottom-right + top-left - bottom-left - top-right
histograms = int_imgs[window_size:, window_size:] + \
            int_imgs[:-window_size, :-window_size] - \
            int_imgs[window_size:, :-window_size] - \
            int_imgs[:-window_size, window_size:]

# Normalise along the last (texton) axis so each pixel's histogram sums to 1.
# The builtin `float` replaces the `np.float` alias, which was deprecated in NumPy 1.20
# and removed in NumPy 1.24; the resulting dtype (float64) is the same.
# NOTE(review): a pixel whose window counts no texton at all would divide by zero here.
histograms_normalized = histograms/histograms.sum(axis=-1)[...,None].astype(float)

In [None]:
# Store the recomputed histograms under the same result name that the load cell
# near the top of the notebook reads back.
dm.save_pipeline_result(histograms_normalized, 'histogramsPixelNormalized')

In [None]:
# Drop the large intermediates; none of these three names is read again in the visible cells.
del single_channel_maps, histograms, int_imgs

In [None]:
# Overview figure: row 0 shows a 1000x1000 crop of the image and of the texton map,
# the rows below show one histogram channel per panel, 7 panels per row.
# NOTE(review): the 3x7 grid only has room for 14 channels; more textons would index
# past the last row.
# NOTE(review): dm.image_rgb_jpg is read here, but the cell that calls dm._load_image
# appears further down the notebook -- confirm the image is loaded before this runs.
fig, axes = plt.subplots(3,7,figsize=(20,10));
for ax in axes.flat:
    ax.axis('off');
    
axes[0,0].imshow(dm.image_rgb_jpg[dm.ymin:dm.ymin+1000, dm.xmin:dm.xmin+1000]);
axes[0,1].imshow(textonmap[dm.ymin:dm.ymin+1000, dm.xmin:dm.xmin+1000]);
for i in range(n_texton):
    # divmod keeps both grid coordinates integers; `i/7` is a float (and an invalid
    # index) under Python 3 true division.
    grid_row, grid_col = divmod(i, 7)
    channel_ax = axes[1 + grid_row, grid_col]
    channel_ax.matshow(histograms_normalized[:1000, :1000, i]);
    channel_ax.set_title('histogram channel %d' % i);

In [6]:
from itertools import chain
from operator import itemgetter, attrgetter
from skimage.measure import regionprops
from skimage.segmentation import slic, mark_boundaries, relabel_sequential
from skimage.util import img_as_ubyte, pad
import cv2

# w_spatial = 0.001

# spacing = 200
# sp_ys, sp_xs = np.mgrid[0:dm.h:spacing, 0:dm.w:spacing]
# sp_texhists = histograms_normalized[sp_ys.flat, sp_xs.flat]
# centroids = np.c_[sp_xs.flat, sp_ys.flat, sp_texhists]
# n_superpixels = centroids.shape[0]

In [7]:
def compute_distance_to_centroids(centroids_xy, centroids_texture, spacing, w_spatial, hist_map, h, w,
                                 ymins, ymaxs, xmins, xmaxs, window_spatial_distances):
    """Compute, for each centroid, its distance to every pixel of its search window.

    For centroid ``ci`` the window is
    ``hist_map[ymins[ci]:ymaxs[ci]+1, xmins[ci]:xmaxs[ci]+1]`` and the distance is
    ``w_spatial * spatial_distance + chi2s(centroid_texture, pixel_histogram)``.

    Parameters
    ----------
    centroids_xy : array, one (x, y) row per centroid; truncated to int here.
    centroids_texture : sequence with one texture histogram per centroid.
    spacing : int, half-size of the (unclipped) search window.
    w_spatial : float, weight of the spatial term.
    hist_map : array whose first two axes are (row, col) and last axis is the histogram.
    h, w : int, image height and width used to clip the window at the borders.
    ymins, ymaxs, xmins, xmaxs : per-centroid inclusive window bounds, already clipped
        to the image by the caller (must agree with ``centroid +/- spacing`` clipped to
        ``[0, h-1]`` / ``[0, w-1]``, otherwise the two terms have different lengths).
    window_spatial_distances : (2*spacing+1, 2*spacing+1) array of distances from the
        window centre; index ``[spacing, spacing]`` is the centroid itself.

    Returns
    -------
    list with one entry per centroid: the flattened (row-major) window distances.
    NOTE(review): assumes ``chi2s`` returns one value per window pixel -- it is a
    project helper not visible here; confirm.

    Raises
    ------
    Re-raises whatever the distance sum raises (typically a shape mismatch between the
    spatial and texture terms) after logging both windows to stderr.
    """

    ds = []

    for ci in range(len(centroids_xy)):

        ymin, ymax = ymins[ci], ymaxs[ci]
        xmin, xmax = xmins[ci], xmaxs[ci]

        # Builtin int instead of the np.int alias (removed in NumPy 1.24).
        cx, cy = centroids_xy[ci].astype(int)

        # Crop the precomputed distance window the same way the image window was clipped:
        # index `spacing` is the centroid, so a window cut at the image border starts
        # (or ends) that many entries away from it.
        crop_window_x_min = spacing-cx if cx-spacing < 0 else 0
        crop_window_y_min = spacing-cy if cy-spacing < 0 else 0
        crop_window_x_max = 2*spacing - (cx+spacing - (w - 1)) if cx+spacing > w - 1 else 2*spacing
        crop_window_y_max = 2*spacing - (cy+spacing - (h - 1)) if cy+spacing > h - 1 else 2*spacing

        spatial_ds = window_spatial_distances[crop_window_y_min:crop_window_y_max+1,
                                              crop_window_x_min:crop_window_x_max+1].reshape((-1,))

        texture_ds = chi2s([centroids_texture[ci]],
                           hist_map[ymin:ymax+1, xmin:xmax+1].reshape((-1, hist_map.shape[-1])))

        try:
            combined = w_spatial * spatial_ds + texture_ds
        except Exception:
            # Log the image window (1) and the distance-window crop (2) to make a shape
            # mismatch diagnosable, then let the error propagate.
            sys.stderr.write('1, %d,%d,%d,%d; 2, %d,%d,%d,%d\n'%(xmin,ymin,xmax,ymax,
                                                                crop_window_x_min,crop_window_y_min,
                                                                 crop_window_x_max,crop_window_y_max))
            raise

        ds.append(combined)

    return ds

# def compute_new_centroids(sps, assignments, hist_map):
    
#     centroids = np.zeros((len(sps), 2 + hist_map.shape[-1]))
#     counts = np.zeros(len(sps))
#     for r, ss in enumerate(assignments):
#         centroids[ss, 1] += r
#         centroids[ss, 0] += ss
# #         for c, s in enumerate(ss):
# #         centroids[s,:2] += (c,r)
#         centroids[ss, 2:] += hist_map[r,c]
#         counts[ss] += 1
            
#     centroids = centroids / counts
            
# #     for i, sp_i in enumerate(sps):        
# #         rs, cs = np.where(assignments == sp_i)
# #         centroids[i] = np.r_[np.c_[rs, cs].mean(axis=0)[::-1], hist_map[rs, cs].mean(axis=0)]
#     return centroids


# def compute_new_centroids2(sp_coords, hist_map):

#     centroids = np.empty((len(sp_coords), 2+hist_map.shape[-1]))
#     for i, coords in enumerate(sp_coords):    
#         rs = coords[:,0]
#         cs = coords[:,1]
#         centroids[i, 0] = cs.mean()
#         centroids[i, 1] = rs.mean()
#         centroids[i, 2:] = hist_map[rs, cs].mean(axis=0)
    
#     return centroids

# def compute_new_centroids_texture(sp_coords, hist_map):
#     centroids_textures = [hist_map[coords[:,0], coords[:,1]].mean(axis=0) for coords in sp_coords]    
#     return centroids_textures

In [None]:
from skimage.segmentation import mark_boundaries
# Load the 'rgb-jpg' version of the section image.
# NOTE(review): _load_image is a private DataManager method; presumably it populates
# dm.image_rgb_jpg, which other cells read -- confirm against utilities2015.
dm._load_image(versions=['rgb-jpg'])

In [18]:
def slic_texture(hist_map, spacing=200, w_spatial=0.001, max_iter=5):
    """Segment an image into superpixels by SLIC-style clustering on texture histograms.

    Seeds are placed on a regular grid with step `spacing`. Each iteration
    (1) computes, for every centroid, ``w_spatial * euclidean_distance + chi2s(texture)``
        to all pixels within `spacing` of it (via `compute_distance_to_centroids`),
    (2) assigns every pixel to its nearest centroid,
    (3) saves a boundary visualisation through the notebook-global `dm`, and
    (4) moves every centroid to the mean position / mean histogram of its pixels.

    Parameters
    ----------
    hist_map : (h, w, n_texton) array of per-pixel texture histograms.
    spacing : int, seed grid step and half-size of each centroid's search window.
    w_spatial : float, weight of the spatial distance relative to the texture distance.
    max_iter : int, number of assign/update iterations.

    Returns
    -------
    (h, w) int16 array of superpixel indices from the last iteration; -1 marks pixels
    that were never assigned (all of them when ``max_iter`` is 0).

    Notes
    -----
    Relies on notebook globals: `dm` (its image crop is passed to mark_boundaries
    together with the label image, so presumably it must have the same height and
    width as `hist_map`), `compute_distance_to_centroids`, `Parallel`/`delayed`,
    `regionprops`, `mark_boundaries` and `img_as_ubyte`.
    Labels are stored as int16, so more than 32767 superpixels would overflow.
    """

    h, w, n_texton = hist_map.shape

    sp_ys, sp_xs = np.mgrid[0:h:spacing, 0:w:spacing]

    n_superpixels = len(sp_ys.flat)

    centroids_textures = hist_map[0:h:spacing, 0:w:spacing].reshape((-1, n_texton))
    centroids_xy = np.c_[sp_xs.flat, sp_ys.flat]

    # Distance of every offset in the (2*spacing+1)^2 window to the window centre;
    # shared by all centroids and cropped per centroid at the image border.
    ys, xs = np.mgrid[-spacing:spacing+1, -spacing:spacing+1]
    window_spatial_distances = np.sqrt(ys**2 + xs**2)

    # Centroids are handed to the workers in about 128 chunks. Floor division keeps the
    # chunk size an int, and max(1, ...) avoids a zero step when there are fewer than
    # 128 superpixels (which previously made np.arange fail).
    chunk_size = max(1, n_superpixels // 128)

    # Defined up front so that max_iter=0 returns an all-unassigned map instead of
    # hitting an unbound local at the return statement.
    assignments = -1 * np.ones((h, w), np.int16)

    for iter_i in range(max_iter):

        # Single pre-formatted argument prints identically on Python 2 and Python 3.
        print('iteration %d' % iter_i)

        cx = centroids_xy[:, 0].astype(int)
        cy = centroids_xy[:, 1].astype(int)
        window_ymins = np.maximum(0, cy - spacing)
        window_xmins = np.maximum(0, cx - spacing)
        window_ymaxs = np.minimum(h-1, cy + spacing)
        window_xmaxs = np.minimum(w-1, cx + spacing)

        assignments = -1 * np.ones((h, w), np.int16)
        # NOTE(review): float16 keeps this map small but has only ~3 significant digits,
        # which is coarse relative to the default w_spatial -- confirm this is acceptable.
        distances = np.inf * np.ones((h, w), np.float16)

        sys.stderr.write('%d superpixels\n'%n_superpixels)

        a = time.time()

        sys.stderr.write('compute distance\n')

        res = Parallel(n_jobs=16)(delayed(compute_distance_to_centroids)(centroids_xy[si:si+chunk_size],
                                                                         centroids_textures[si:si+chunk_size],
                                                                         spacing=spacing, w_spatial=w_spatial,
                                                                         hist_map=hist_map, h=h, w=w,
                                                                         ymins=window_ymins[si:si+chunk_size],
                                                                         ymaxs=window_ymaxs[si:si+chunk_size],
                                                                         xmins=window_xmins[si:si+chunk_size],
                                                                         xmaxs=window_xmaxs[si:si+chunk_size],
                                                window_spatial_distances=window_spatial_distances)
                                    for si in range(0, n_superpixels, chunk_size))

        sys.stderr.write('done in %.2f seconds\n' % (time.time() - a))

        a = time.time()

        sys.stderr.write('aggregate\n')

        # Chunks come back in submission order, so flattening them enumerates the
        # centroids in their original order.
        for sp_i, new_ds in enumerate(chain(*res)):

            ymin = window_ymins[sp_i]
            xmin = window_xmins[sp_i]
            ymax = window_ymaxs[sp_i]
            xmax = window_xmaxs[sp_i]

            q = new_ds.reshape((ymax+1-ymin, xmax+1-xmin))
            # Pixels of this window that are closer to this centroid than to any earlier one.
            s = q < distances[ymin:ymax+1, xmin:xmax+1]

            distances[ymin:ymax+1, xmin:xmax+1][s] = q[s]
            assignments[ymin:ymax+1, xmin:xmax+1][s] = sp_i

        del res

        sys.stderr.write('done in %.2f seconds\n' % (time.time() - a))


        img_superpixelized = mark_boundaries(dm.image_rgb_jpg[dm.ymin:dm.ymax+1, dm.xmin:dm.xmax+1],
                                             assignments, color=(1,0,0))
        img_superpixelized = img_as_ubyte(img_superpixelized)
        dm.save_pipeline_result(img_superpixelized, 'segmentationWithoutTextIter%d'%iter_i)


        sys.stderr.write('update assignment\n')
        t = time.time()

        # regionprops ignores label 0, hence the +1 shift (unassigned pixels, -1, become 0).
        # It also omits labels that own no pixel, so each region is mapped back through
        # its label rather than its position in the list; a superpixel that lost all its
        # pixels keeps its previous centroid instead of shifting every later index.
        props = regionprops(assignments+1)

        centroids_xy_new = np.array(centroids_xy, dtype=float)
        centroids_textures_new = np.array(centroids_textures, dtype=float)

        for prop in props:
            sp_i = prop.label - 1
            coords = prop.coords
            # regionprops centroids are (row, col); centroids_xy stores (x, y).
            centroids_xy_new[sp_i] = prop.centroid[::-1]
            centroids_textures_new[sp_i] = hist_map[coords[:,0], coords[:,1]].mean(axis=0)

        n_empty = n_superpixels - len(props)
        if n_empty > 0:
            sys.stderr.write('%d empty superpixels keep their previous centroid\n' % n_empty)

        sys.stderr.write('total centroid location change = %d\n' %
                         np.sum(np.abs(centroids_xy_new - centroids_xy)))

        centroids_xy = centroids_xy_new
        centroids_textures = centroids_textures_new

        sys.stderr.write('done in %.2f seconds\n' % (time.time() - t))

    return assignments

In [19]:
# Run the texture SLIC for 10 iterations on a 200-pixel seed grid. The result is the
# per-pixel superpixel index map of the last iteration; each iteration also saves a
# 'segmentationWithoutTextIter<i>' visualisation through dm (see the log below).
assignments = slic_texture(histograms_normalized, max_iter=10, spacing=200)

iteration 0


1953 superpixels
compute distance
done in 19.01 seconds
aggregate
done in 4.20 seconds
update assignment
total centroid location change = 38244


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter0.jpg
iteration

done in 10.46 seconds
1953 superpixels


 1


compute distance
done in 19.18 seconds
aggregate
done in 4.26 seconds
update assignment
total centroid location change = 19260


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter1.jpg
iteration

done in 11.23 seconds
1953 superpixels


 2


compute distance
done in 23.48 seconds
aggregate
done in 4.38 seconds
update assignment
total centroid location change = 14949


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter2.jpg
iteration

done in 10.73 seconds
1953 superpixels


 3


compute distance
done in 22.13 seconds
aggregate
done in 4.33 seconds
update assignment
total centroid location change = 12586


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter3.jpg
iteration

done in 10.65 seconds
1953 superpixels


 4


compute distance
done in 21.31 seconds
aggregate
done in 4.46 seconds
update assignment
total centroid location change = 10911


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter4.jpg
iteration

done in 10.82 seconds
1953 superpixels


 5


compute distance
done in 21.34 seconds
aggregate
done in 4.39 seconds
update assignment
total centroid location change = 9708


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter5.jpg
iteration

done in 10.75 seconds
1953 superpixels


 6


compute distance
done in 21.71 seconds
aggregate
done in 4.45 seconds
update assignment
total centroid location change = 8810


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter6.jpg
iteration

done in 10.84 seconds
1953 superpixels


 7


compute distance
done in 21.36 seconds
aggregate
done in 4.43 seconds
update assignment
total centroid location change = 8024


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter7.jpg
iteration

done in 10.83 seconds
1953 superpixels


 8


compute distance
done in 21.66 seconds
aggregate
done in 4.42 seconds
update assignment
total centroid location change = 7292


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter8.jpg
iteration

done in 11.03 seconds
1953 superpixels


 9


compute distance
done in 22.34 seconds
aggregate
done in 4.43 seconds
update assignment
total centroid location change = 6715


saved /oasis/projects/nsf/csd395/yuncong/CSHL_data_results/MD594/0139/MD594_0139_lossless_segm-tSLIC200_segmentationWithoutTextIter9.jpg


done in 11.09 seconds


In [None]:
# Outline the superpixels in red on the dm.[xy]min..[xy]max crop of the image and show it inline.
# NOTE(review): `display` is presumably the IPython builtin or a utilities2015 helper -- confirm.
img_superpixelized = mark_boundaries(dm.image_rgb_jpg[dm.ymin:dm.ymax+1, dm.xmin:dm.xmax+1], 
                                     assignments, color=(1,0,0))
display(img_superpixelized)

In [None]:
# Notebook-level alias used by the step-by-step scratch cells below, which mirror the
# body of slic_texture outside the function.
hist_map = histograms_normalized

In [None]:
# Notebook-level parameters for the scratch cells below (same values as the
# slic_texture defaults, but a single iteration).
h, w = hist_map.shape[:2]
max_iter = 1
spacing=200
w_spatial=0.001

from itertools import chain
from operator import itemgetter

# Seed centroids on a regular grid with step `spacing`.
sp_ys, sp_xs = np.mgrid[0:h:spacing, 0:w:spacing]
# Texture histogram found at each seed position.
sp_texhists = hist_map[sp_ys.flat, sp_xs.flat]
# One row per centroid: (x, y, texture histogram values...).
centroids = np.c_[sp_xs.flat, sp_ys.flat, sp_texhists]
n_superpixels = centroids.shape[0]

In [None]:
# Parenthesised single-argument form prints the same text on Python 2 and is valid on Python 3.
print(n_superpixels)

In [None]:
# Scratch version of one assignment pass of slic_texture, run at notebook level on the
# first batch of centroids only (note the `break` at the end of the loop).
#
# The call below is aligned with the current signature of compute_distance_to_centroids
# (separate xy / texture arguments, per-centroid window bounds and a precomputed spatial
# distance window); the previous call omitted these required arguments. The aggregation
# now uses the same +/- spacing windows the function works on, instead of +/- 2*spacing,
# so the returned distance vectors match the pixels they are written to.

# for iter_i in range(max_iter):

iter_i = 0

print('iteration %d' % iter_i)

assignments = -1 * np.ones((h, w), np.int16)
distances = np.inf * np.ones((h, w), np.float16)

# `centroids` packs (x, y, texture histogram) per row.
centroids_xy = centroids[:, :2]
centroids_textures = centroids[:, 2:]

# Per-centroid search windows, clipped to the image (inclusive bounds).
centroid_xs = centroids_xy[:, 0].astype(int)
centroid_ys = centroids_xy[:, 1].astype(int)
window_ymins = np.maximum(0, centroid_ys - spacing)
window_xmins = np.maximum(0, centroid_xs - spacing)
window_ymaxs = np.minimum(h-1, centroid_ys + spacing)
window_xmaxs = np.minimum(w-1, centroid_xs + spacing)

# Distance of every window offset to the window centre.
offset_ys, offset_xs = np.mgrid[-spacing:spacing+1, -spacing:spacing+1]
window_spatial_distances = np.sqrt(offset_ys**2 + offset_xs**2)

sys.stderr.write('compute_distance_to_centroids\n')

batch_size = 100

for i in range(0, n_superpixels, batch_size):

    sys.stderr.write('compute_distance_to_centroids\n')
    t = time.time()

    # Indices of the centroids in this batch, split into (up to) 16 contiguous parts;
    # empty parts are harmless because the worker returns an empty list for them.
    batch_indices = np.arange(i, min(i + batch_size, n_superpixels))

    res = Parallel(n_jobs=16)(delayed(compute_distance_to_centroids)(centroids_xy[part],
                                                                     centroids_textures[part],
                                                                     spacing=spacing,
                                                                     w_spatial=w_spatial,
                                                                     hist_map=hist_map,
                                                                     h=h, w=w,
                                                                     ymins=window_ymins[part],
                                                                     ymaxs=window_ymaxs[part],
                                                                     xmins=window_xmins[part],
                                                                     xmaxs=window_xmaxs[part],
                                                window_spatial_distances=window_spatial_distances)
                              for part in np.array_split(batch_indices, 16))

    print(time.time() - t)


    t = time.time()

    new_dists = list(chain(*res))

    # Parts come back in submission order, so new_dists lines up with batch_indices.
    for sp_i, nds in zip(batch_indices, new_dists):

        ymin = window_ymins[sp_i]
        xmin = window_xmins[sp_i]
        ymax = window_ymaxs[sp_i]
        xmax = window_xmaxs[sp_i]

        window_ds = nds.reshape((ymax+1-ymin, xmax+1-xmin))
        # Pixels of this window that are closer to this centroid than to any earlier one.
        closer = window_ds < distances[ymin:ymax+1, xmin:xmax+1]

        distances[ymin:ymax+1, xmin:xmax+1][closer] = window_ds[closer]
        assignments[ymin:ymax+1, xmin:xmax+1][closer] = sp_i

    del res
    del new_dists

    print(time.time() - t)

    # Only the first batch is processed in this scratch cell.
    break


# sys.stderr.write('done in %.2f seconds\n' % (time.time() - t))

In [None]:
# Scratch version of the visualise + centroid-update step of slic_texture.
img_superpixelized = mark_boundaries(dm.image_rgb_jpg[dm.ymin:dm.ymax+1, dm.xmin:dm.xmax+1], 
                                     assignments, color=(1,0,0))
img_superpixelized = img_as_ubyte(img_superpixelized)
#     dm.save_pipeline_result(img_superpixelized, 'segmentationWithoutTextIter%d'%iter_i)

sys.stderr.write('update assignment\n')
# Start a fresh timer so the duration reported below covers this step only.
t = time.time()

# compute_new_centroids survives only as commented-out code, so calling it raised a
# NameError. The update is done with regionprops instead, the same way slic_texture does:
# label 0 is ignored by regionprops, hence the +1 shift, and each region is mapped back
# through its label so superpixels that own no pixel simply keep their previous centroid.
centroids_new = np.array(centroids, dtype=float)
for prop in regionprops(assignments + 1):
    sp_i = prop.label - 1
    coords = prop.coords
    # regionprops centroids are (row, col); `centroids` stores (x, y) in its first two columns.
    centroids_new[sp_i, :2] = prop.centroid[::-1]
    centroids_new[sp_i, 2:] = hist_map[coords[:, 0], coords[:, 1]].mean(axis=0)

print('total centroid location change %s' % np.sum(np.abs(centroids_new[:,:2] - centroids[:,:2])))

centroids = centroids_new

sys.stderr.write('done in %.2f seconds\n' % (time.time() - t))

In [None]:
# Load the 'rgb-jpg' image version again; this repeats the call made in an earlier cell.
# NOTE(review): presumably kept so the cell below can run on its own -- confirm it is still needed.
dm._load_image(versions=['rgb-jpg'])

In [None]:
from skimage.segmentation import mark_boundaries
# `q` was passed as the label image here, but `q` only exists as a local variable inside
# slic_texture and is never defined at notebook level, so the cell failed on a fresh
# kernel. The superpixel label image is `assignments`.
viz = mark_boundaries(dm.image_rgb_jpg[dm.ymin:dm.ymax+1, dm.xmin:dm.xmax+1], assignments, color=(1,0,0))
display(viz)