In [None]:
%reload_ext autoreload
%autoreload 2

import sys
sys.path.append('/home/yuncong/Brain/pipeline_scripts')
from utilities2014 import *
import os

from scipy.spatial.distance import cdist, pdist, squareform
from joblib import Parallel, delayed
from skimage.color import gray2rgb

from skimage.measure import find_contours
from skimage.util import img_as_float

import matplotlib.pyplot as plt
%matplotlib inline

sys.path.append('/home/yuncong/project/opencv-2.4.9/release/lib/python2.7/site-packages')
import cv2

from networkx import from_dict_of_lists, Graph, adjacency_matrix, dfs_postorder_nodes
from networkx.algorithms import node_connected_component

# Directory configuration, published through environment variables before the
# DataManager is constructed.
os.environ.update({
    'GORDON_DATA_DIR': '/home/yuncong/project/DavidData2014tif/',
    'GORDON_REPO_DIR': '/home/yuncong/Brain',
    'GORDON_RESULT_DIR': '/home/yuncong/project/DavidData2014results/',
    'GORDON_LABELING_DIR': '/home/yuncong/project/DavidData2014labelings/',
})

# Data manager for stack RS141 at resolution x5, with the parameter sets used by this notebook.
dm = DataManager(stack='RS141', resol='x5', generate_hierarchy=False)
dm.set_gabor_params(gabor_params_id='blueNisslWide')
dm.set_segmentation_params(segm_params_id='blueNisslRegular')
dm.set_vq_params(vq_params_id='blueNissl')

# Work on slice 7.
dm.set_slice(7)
dm._load_image()

# Precomputed pipeline results for this slice.
texton_hists = dm.load_pipeline_result('texHist', 'npy')

segmentation = dm.load_pipeline_result('segmentation', 'npy')
# One label is excluded from the count -- presumably a -1 "background" label; TODO confirm.
n_superpixels = np.unique(segmentation).size - 1

textonmap = dm.load_pipeline_result('texMap', 'npy')
# Same convention as above: one distinct value is not counted as a texton.
n_texton = np.unique(textonmap).size - 1

neighbors = dm.load_pipeline_result('neighbors', 'npy')

# each item is (center_y, center_x, area, mean_intensity, ymin, xmin, ymax, xmax)
sp_properties = dm.load_pipeline_result('spProps', 'npy')

segmentation_vis = dm.load_pipeline_result('segmentationWithoutText', 'jpg')

In [None]:
# Pairwise distances between the texton histograms of all superpixels.
# Load the cached matrix if it is available; otherwise compute it in parallel and cache it.
try:
    sp_sp_dists = dm.load_pipeline_result('texHistPairwiseDist', 'npy')
except Exception:
    # `except Exception` rather than a bare `except`: interrupting the load
    # (KeyboardInterrupt) must not silently kick off the expensive recomputation below.
    def f(a):
        '''Distances (chi2 metric) from each histogram in `a` to every superpixel histogram.'''
        # NOTE: a commented-out variant of this line used metric=js instead of chi2.
        return cdist(a, texton_hists, metric=chi2)

    # Split the rows into 16 chunks, process them on 16 workers, and stack the row blocks back.
    sp_dists = Parallel(n_jobs=16)(delayed(f)(s) for s in np.array_split(texton_hists, 16))
    sp_sp_dists = np.vstack(sp_dists)

    dm.save_pipeline_result(sp_sp_dists, 'texHistPairwiseDist', 'npy')

# Pairwise Euclidean distances between superpixel centers (first two columns of sp_properties),
# in condensed form and as a square matrix.
center_dists = pdist(sp_properties[:, :2])
center_dist_matrix = squareform(center_dists)

# Adjacency graph over superpixels: node i is linked to every entry of neighbors[i].
neighbors_dict = dict(zip(np.arange(n_superpixels), [list(i) for i in neighbors]))
neighbor_graph = from_dict_of_lists(neighbors_dict)

In [None]:
def find_boundary_sps(clusters, neighbors, neighbor_graph, mode=None):
    '''
    Identify superpixels that are at the boundary of regions: surround set and frontier set

    For each cluster, the *surround* set holds the superpixels outside the cluster
    that touch it (excluding those whose neighbors all belong to the cluster), and
    the *frontier* set holds the cluster members that touch a surround superpixel.
    Each set is returned as a list ordered by a depth-first post-order traversal of
    the corresponding subgraph of ``neighbor_graph``.

    Parameters
    ----------
    clusters : list of integer lists
        One collection of superpixel indices per region. An empty cluster yields
        empty surround/frontier lists.
    neighbors : neighbor_list
        Indexable by superpixel index; ``neighbors[i]`` is the collection of indices
        adjacent to superpixel ``i``. The value -1 is ignored as a surround candidate
        (presumably it marks the image background -- TODO confirm).
    neighbor_graph :
        Graph over superpixels; must support ``subgraph(nodes)`` and traversal by
        ``dfs_postorder_nodes``.
    mode : {'surrounds', 'frontiers', 'both', None}
        Which sets to compute. ``None`` (the default) is treated as ``'both'``.

    Returns
    -------
    ``surrounds_sps`` if mode is 'surrounds', ``frontiers_sps`` if mode is
    'frontiers', otherwise the tuple ``(surrounds_sps, frontiers_sps)``. Every
    returned list has exactly one entry per input cluster.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the accepted values.
    '''

    # With mode=None nothing used to be computed for non-trivial clusters, while
    # empty entries were still appended for trivial ones, so the returned lists did
    # not line up with `clusters`. Treat the default as "compute both".
    if mode is None:
        mode = 'both'
    if mode not in ('surrounds', 'frontiers', 'both'):
        raise ValueError("mode must be 'surrounds', 'frontiers', 'both' or None; got %r" % (mode,))

    want_surrounds = mode in ('surrounds', 'both')
    want_frontiers = mode in ('frontiers', 'both')

    surrounds_sps = []
    frontiers_sps = []

    for cluster in clusters:

        # Set membership keeps the repeated "in cluster" tests cheap.
        members = set(cluster)

        # Everything adjacent to a member, minus the members themselves and the -1 marker.
        # set().union(...) also copes with an empty cluster (set.union(*[]) raises).
        touching = set().union(*[neighbors[c] for c in members])
        surrounds = set([i for i in touching if i not in members and i != -1])

        # Drop "holes": outside superpixels whose neighbors all lie inside the cluster.
        surrounds = set([i for i in surrounds if any([n not in members for n in neighbors[i]])])

        if len(surrounds) == 0:
            # Keep the outputs aligned with `clusters`.
            surrounds_sps.append([])
            frontiers_sps.append([])
            continue

        if want_surrounds:
            surrounds_subgraph = neighbor_graph.subgraph(surrounds)
            surrounds_sps.append(list(dfs_postorder_nodes(surrounds_subgraph)))

        if want_frontiers:
            # Cluster members adjacent to at least one surround superpixel.
            frontiers = set().union(*[neighbors[c] for c in surrounds]) & members
            frontiers_subgraph = neighbor_graph.subgraph(frontiers)
            frontiers_sps.append(list(dfs_postorder_nodes(frontiers_subgraph)))

    if mode == 'surrounds':
        return surrounds_sps
    elif mode == 'frontiers':
        return frontiers_sps
    else:
        return surrounds_sps, frontiers_sps

In [None]:
def compute_cluster_score(cluster, texton_hists=texton_hists, neighbors=neighbors):
    '''
    Score a cluster of superpixels by how distinct it is from its surroundings.

    The score is ``surround_dist - compactness + size_prior`` where

    * ``surround_dist`` is the smallest chi2 distance between the cluster's mean
      texton histogram and the histogram of any adjacent outside superpixel,
    * ``compactness`` is a penalty that is non-zero once
      ``(number of surround superpixels)**2 / cluster size`` exceeds 40,
    * ``size_prior`` is a bonus that increases with cluster size and saturates at 0.1.

    Parameters
    ----------
    cluster : collection of superpixel indices
    texton_hists : array indexed by superpixel index, one histogram per row
    neighbors : indexable by superpixel index, giving a set of adjacent indices

    Returns
    -------
    (score, surround_dist, interior_dist, compactness, surround_pval, interior_pval, size_prior)
        ``interior_dist``, ``surround_pval`` and ``interior_pval`` are always NaN;
        they are placeholders that keep the tuple layout fixed.

    Notes
    -----
    Uses the notebook-level ``neighbor_graph``. Raises if the cluster has no
    adjacent outside superpixel (the minimum below is taken over an empty set).
    '''

    member_idx = list(cluster)
    mean_hist = texton_hists[member_idx].mean(axis=0)

    # Outside superpixels touching the cluster; -1 is not a superpixel index.
    touching = set.union(*[neighbors[m] for m in cluster])
    outside_idx = [s for s in touching if s != -1 and s not in cluster]

    # Distance from the cluster's mean histogram to its most similar outside neighbor.
    dists_to_outside = cdist([mean_hist], texton_hists[outside_idx], chi2)
    surround_dist = np.squeeze(dists_to_outside).min()

    # Surround superpixels of this single cluster.
    boundary = find_boundary_sps([cluster], neighbors=neighbors, neighbor_graph=neighbor_graph,
                                 mode='surrounds')[0]

    n_members = len(cluster)
    boundary_ratio = len(boundary)**2/float(n_members)
    compactness = .001 * np.maximum(boundary_ratio-40,0)**2

    size_prior = .1 * (1-np.exp(-.8*n_members))

    score = surround_dist - compactness + size_prior

    # Slots 2, 4 and 5 (interior_dist, surround_pval, interior_pval) are not computed.
    return score, surround_dist, np.nan, compactness, np.nan, np.nan, size_prior

In [None]:
def grow_cluster3(seed, neighbors=neighbors, texton_hists=texton_hists, output=False, all_history=False):
    '''
    Greedily grow a cluster of superpixels from ``seed`` and keep its best-scoring stage.

    Starting from the seed, the candidate (a neighbor of the current cluster) whose
    texton histogram is closest in chi2 distance to the cluster's mean histogram is
    added, and the enlarged cluster is scored with ``compute_cluster_score``. Growth
    stops when no candidate is left or when the cluster holds more than
    ``int(n_superpixels * 0.03)`` superpixels (``n_superpixels`` is the notebook-level
    count). The returned cluster is the growth stage with the highest total score
    among the stages containing at least ``min_size`` (2) superpixels; if growth ended
    before reaching that size, the largest stage reached is used.

    Parameters
    ----------
    seed : superpixel index to start from
    neighbors : indexable by superpixel index; each entry must be a set of adjacent
        indices (-1 entries are ignored)
    texton_hists : array indexed by superpixel index, one histogram per row
    output : if True, print each addition and the selected cluster size
    all_history : if True, also return the full growth history

    Returns
    -------
    (final_cluster, final_score) or, if ``all_history`` is True,
    (final_cluster, final_score, added_sps, score_tuples)
        ``final_cluster`` : list of superpixel indices in the order they were added.
        ``final_score`` : total score of exactly that cluster.
        ``added_sps`` : every superpixel added during growth, in order.
        ``score_tuples`` : array with one row per addition: the selection heuristic
        followed by the tuple returned by ``compute_cluster_score``.
    '''

    curr_cluster = set([])

    candidate_scores = [0]
    candidate_sps = [seed]

    score_tuples = []
    added_sps = []

    iter_ind = 0

    # Upper bound on the cluster size, as a fraction of all superpixels.
    max_cluster_size = int(n_superpixels * 0.03)

    while len(candidate_sps) > 0:

        # Take the candidate most similar to the current cluster.
        best_ind = np.argmax(candidate_scores)

        heuristic = candidate_scores[best_ind]
        sp = candidate_sps[best_ind]

        del candidate_scores[best_ind]
        del candidate_sps[best_ind]

        if sp in curr_cluster:
            continue

        iter_ind += 1
        curr_cluster.add(sp)
        added_sps.append(sp)

        # Row layout: [heuristic, total score, remaining components of compute_cluster_score].
        score_tuples.append(np.r_[heuristic, compute_cluster_score(curr_cluster)])

        if output:
            print('iter %s add %s' % (iter_ind, sp))

        # New candidates: previous candidates plus the neighbors of the new member,
        # minus anything already in the cluster (-1 is not a superpixel index).
        candidate_sps = list((set(candidate_sps) | (neighbors[sp] - set([-1]))) - curr_cluster)

        # Similarity heuristic: negated chi2 distance to the cluster's mean histogram.
        h_avg = texton_hists[list(curr_cluster)].mean(axis=0)
        candidate_scores = [-chi2(h_avg, texton_hists[i]) for i in candidate_sps]

        if len(curr_cluster) > max_cluster_size:
            break

    score_tuples = np.array(score_tuples)

    min_size = 2
    scores = score_tuples[:,1]

    # scores[k] is the score of the cluster made of the first k+1 added superpixels,
    # so a cluster of size s is scored by scores[s-1]. Search the stages of size
    # >= min_size; if growth stopped earlier, fall back to the largest stage reached
    # instead of calling argmax on an empty slice.
    first_ind = min(min_size, len(scores)) - 1
    cutoff = np.argmax(scores[first_ind:]) + first_ind + 1

    if output:
        print('cutoff %s' % cutoff)

    final_cluster = added_sps[:cutoff]
    final_score = scores[cutoff - 1]

    if all_history:
        return list(final_cluster), final_score, added_sps, score_tuples
    else:
        return list(final_cluster), final_score

In [None]:
t = sp_properties[list(final_cluster), 4:]
rmin = int(t[:,0].min())
cmin = int(t[:,1].min())
rmax = int(t[:,2].max())
cmax = int(t[:,3].max())

In [None]:
t = sp_properties[added_sps, 4:]
rmin = int(t[:,0].min())
cmin = int(t[:,1].min())
rmax = int(t[:,2].max())
cmax = int(t[:,3].max())

In [None]:
margin = 50

In [None]:
final_cluster, final_score = grow_cluster3(3303, neighbors=neighbors, texton_hists=texton_hists, output=False)
# display(visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis, text=True))

vis = visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis,
                        text=False, highlight_seed=True)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()


vis = visualize_cluster(added_sps[:1], segmentation=segmentation, segmentation_vis=segmentation_vis, text=False)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()

vis = visualize_cluster(added_sps[:10], segmentation=segmentation, segmentation_vis=segmentation_vis, text=False)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()

vis = visualize_cluster(added_sps[:50], segmentation=segmentation, segmentation_vis=segmentation_vis, text=False)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()

vis = visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis, text=False)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()

vis = visualize_cluster(added_sps, segmentation=segmentation, segmentation_vis=segmentation_vis, text=False)
plt.imshow(vis[rmin-margin:rmax+margin, cmin-margin:cmax+margin])
plt.axis('off')
plt.show()

In [None]:
# Same cluster rendered through the DataManager's own visualizer.
vis = dm.visualize_cluster(final_cluster, text=False)
# Clamp the start of the crop at 0: a negative start index would wrap around
# to the far side of the image instead of stopping at the border.
plt.imshow(vis[max(rmin-margin, 0):rmax+margin, max(cmin-margin, 0):cmax+margin])
plt.axis('off')
plt.show()

In [None]:
# Grow and display one cluster per seed, reusing the crop window (rmin, cmin, rmax, cmax)
# and margin set by the earlier cells.
for seed in (3266, 3129, 2872):
    final_cluster, final_score = grow_cluster3(seed, neighbors=neighbors, texton_hists=texton_hists, output=False)
    vis = visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis,
                            text=False, highlight_seed=True)
    # Clamp the start of the crop at 0: a negative start index would wrap around
    # to the far side of the image instead of stopping at the border.
    plt.imshow(vis[max(rmin-margin, 0):rmax+margin, max(cmin-margin, 0):cmax+margin])
    plt.axis('off')
    plt.show()

In [None]:
# Crop window: bounding box of superpixel 1306 (columns 4: are ymin, xmin, ymax, xmax),
# padded by `margin` pixels on every side.
t = sp_properties[[1306], 4:]
rmin = int(t[:,0].min())
cmin = int(t[:,1].min())
rmax = int(t[:,2].max())
cmax = int(t[:,3].max())
margin = 300

# Grow and display one cluster per seed inside that window.
for seed in (1152, 1203, 1169, 1145):
    final_cluster, final_score = grow_cluster3(seed, neighbors=neighbors, texton_hists=texton_hists, output=False)
    vis = visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis,
                            text=False, highlight_seed=True)
    # Clamp the start of the crop at 0: with a 300-pixel margin the start index can go
    # negative, which would wrap around to the far side of the image.
    plt.imshow(vis[max(rmin-margin, 0):rmax+margin, max(cmin-margin, 0):cmax+margin])
    plt.axis('off')
    plt.show()

In [None]:
# Crop window: bounding box of superpixel 1051 (columns 4: are ymin, xmin, ymax, xmax).
t = sp_properties[[1051], 4:]
rmin = int(t[:,0].min())
cmin = int(t[:,1].min())
rmax = int(t[:,2].max())
cmax = int(t[:,3].max())
margin = 300

# Grow and display one cluster per seed inside that window.
# The top edge is padded by a fixed 200 pixels; the other three edges use `margin`.
for seed in (1204, 1041, 1501, 861):
    final_cluster, final_score = grow_cluster3(seed, neighbors=neighbors, texton_hists=texton_hists, output=False)
    vis = visualize_cluster(final_cluster, segmentation=segmentation, segmentation_vis=segmentation_vis,
                            text=False, highlight_seed=True)
    # Clamp the start of the crop at 0: a negative start index would wrap around
    # to the far side of the image instead of stopping at the border.
    plt.imshow(vis[max(rmin-200, 0):rmax+margin, max(cmin-margin, 0):cmax+margin])
    plt.axis('off')
    plt.show()

In [None]:
# Bar chart of the texton histogram of three individual superpixels, one figure each,
# drawn without axis ticks.
texton_bins = np.arange(n_texton)
for sp_ind in (1968, 2017, 2268):
    plt.bar(texton_bins, texton_hists[sp_ind])
    plt.xticks([])
    plt.yticks([])
    plt.show()