diff --git a/gala/agglo.py b/gala/agglo.py index 10d963a..69a33c8 100644 --- a/gala/agglo.py +++ b/gala/agglo.py @@ -5,7 +5,7 @@ import random import logging import json -import collections +import functools from copy import deepcopy # libraries @@ -123,18 +123,47 @@ def conditional_countdown(seq, start=1, pred=bool): # Merge priority functions # ############################ + +def batchify(func): + """Convert classical (g, n1, n2) -> f policy to batch (g, [e]) -> [f] + + This is meant for policies that wouldn't gain much from batch evaluation + or that aren't used very much. + + Parameters + ---------- + func : function + A merge priority function with signature (g, n1, n2) -> f. + + Returns + ------- + batch_func : function + A batch merge priority function with signature (g, [(n1, n2)]) -> [f]. + """ + def batch_func(g, edges): + result = [] + for n1, n2 in edges: + result.append(func(g, n1, n2)) + return result + return batch_func + + +@batchify def oriented_boundary_mean(g, n1, n2): return mean(g.oriented_probabilities_r[g.boundary(n1, n2)]) +@batchify def boundary_mean(g, n1, n2): return mean(g.probabilities_r[g.boundary(n1, n2)]) +@batchify def boundary_median(g, n1, n2): return median(g.probabilities_r[g.boundary(n1, n2)]) +@batchify def approximate_boundary_mean(g, n1, n2): """Return the boundary mean as computed by a MomentsFeatureManager. @@ -144,76 +173,127 @@ def approximate_boundary_mean(g, n1, n2): def make_ladder(priority_function, threshold, strictness=1): - def ladder_function(g, n1, n2): - s1 = g.node[n1]['size'] - s2 = g.node[n2]['size'] - ladder_condition = \ - (s1 < threshold and not g.at_volume_boundary(n1)) or \ - (s2 < threshold and not g.at_volume_boundary(n2)) - if strictness >= 2: - ladder_condition &= ((s1 < threshold) != (s2 < threshold)) - if strictness >= 3: - ladder_condition &= len(g.boundary(n1, n2)) > 2 - - if ladder_condition: - return priority_function(g, n1, n2) - else: - return inf + """Convert priority function to merge small segments first. + + Small segments tend to mess with other segmentation metrics, so we + merge them early so that more sophisticated function can work on big + segments. This is particularly useful for bad fragment generation + methods that generate lots of tiny fragments. + + Parameters + ---------- + priority_function : function (g, [e]) -> [f] + The merge priority function to convert. + threshold : int or float + The minimum size to be considered for merging. + strictness : int in {1, 2, 3} + How hard to check for segment size: + - 1: only merge small nodes that are not at the volume boundary. + - 2: only merge small nodes not at the volume boundary, *but not + to each other.* + - 3: conditions 1 and 2 but also ensure that the boundary shared + between segments is bigger than 2 voxels. + + Returns + ------- + ladder_priority_function : function (g, [e]) -> [f] + Same as priority function but only for small segments, otherwise + returns infinity. + """ + def ladder_function(g, edges): + edges = np.array(edges) + pass_ladder = np.empty(len(edges), dtype=bool) + for i, (n1, n2) in enumerate(edges): + s1 = g.node[n1]['size'] + s2 = g.node[n2]['size'] + ladder_condition = \ + (s1 < threshold and not g.at_volume_boundary(n1)) or \ + (s2 < threshold and not g.at_volume_boundary(n2)) + if strictness >= 2: + ladder_condition &= ((s1 < threshold) != (s2 < threshold)) + if strictness >= 3: + ladder_condition &= len(g.boundary(n1, n2)) > 2 + pass_ladder[i] = ladder_condition + priority = np.empty(len(edges), dtype=float) + priority[pass_ladder] = priority_function(g, edges[pass_ladder]) + priority[~pass_ladder] = np.inf + return priority return ladder_function def no_mito_merge(priority_function): - def predict(g, n1, n2): - frozen = (n1 in g.frozen_nodes or - n2 in g.frozen_nodes or - (n1, n2) in g.frozen_edges) - if frozen: - return np.inf - else: - return priority_function(g, n1, n2) - return predict + """Convert priority function to avoid merging mitochondria. + Mitochondria are super annoying in segmentation. This uses pre- + -computed mitochondrion labels for the segments to avoid merging + anything that looks like a mitochondrion, in the beginning. These + can be dealt with later when the bulk of the segmentation is + correct. -def mito_merge(): - def predict(g, n1, n2): - if n1 in g.frozen_nodes and n2 in g.frozen_nodes: - return np.inf - elif (n1, n2) in g.frozen_edges: - return np.inf - elif n1 not in g.frozen_nodes and n2 not in g.frozen_nodes: - return np.inf - else: - if n1 in g.frozen_nodes: - mito = n1 - cyto = n2 - else: - mito = n2 - cyto = n1 - if g.node[mito]['size'] > g.node[cyto]['size']: - return np.inf - else: - return 1.0 - (float(len(g.boundary(mito, cyto)))/ - sum([len(g.boundary(mito, x)) for x in g.neighbors(mito)])) + Parameters + ---------- + priority_function : function (g, [e]) -> [f] + The merge priority function to convert. + + Returns + ------- + mito_priority_function : function (g, [e]) -> [f] + Same as priority function, but avoids merging frozen nodes/edges. + Freezing can be defined using any property, not just mitochondria! + + See Also + -------- + mito_merge + """ + def predict(g, edges): + priorities = priority_function(g, edges) + for i, (n1, n2) in enumerate(edges): + frozen = (n1 in g.frozen_nodes or + n2 in g.frozen_nodes or + (n1, n2) in g.frozen_edges) + if frozen: + priorities[i] = np.inf + return priorities return predict -def classifier_probability(feature_extractor, classifier): - def predict(g, n1, n2): - if n1 == g.boundary_body or n2 == g.boundary_body: - return inf - features = np.atleast_2d(feature_extractor(g, n1, n2)) - try: - prediction = classifier.predict_proba(features) - prediction_arr = np.array(prediction, copy=False) - if prediction_arr.ndim > 2: - prediction_arr = prediction_arr[0] - try: - prediction = prediction_arr[0][1] - except (TypeError, IndexError): - prediction = prediction_arr[0] - except AttributeError: - prediction = classifier.predict(features)[0] - return prediction +@batchify +def mito_merge(g, n1, n2): + """Simple priority funct to merge segments previously labeled as mito.""" + if n1 in g.frozen_nodes and n2 in g.frozen_nodes: + return np.inf + elif (n1, n2) in g.frozen_edges: + return np.inf + elif n1 not in g.frozen_nodes and n2 not in g.frozen_nodes: + return np.inf + else: + if n1 in g.frozen_nodes: + mito = n1 + cyto = n2 + else: + mito = n2 + cyto = n1 + if g.node[mito]['size'] > g.node[cyto]['size']: + return np.inf + else: + return 1.0 - (float(len(g.boundary(mito, cyto)))/ + sum([len(g.boundary(mito, x)) for x in g.neighbors(mito)])) + + +def classifier_probability(feature_map, classifier): + def predict(g, edges): + edges = np.atleast_2d(edges) + boundary = np.sum(edges == g.boundary_body, axis=1).astype(bool) + result = np.empty(len(edges)) + result[boundary] = np.inf + features = np.atleast_2d([feature_map(g, n1, n2) + for n1, n2 in edges[~boundary]]) + if features.size > 0: + prediction = classifier.predict_proba(features)[:, 1] + else: + prediction = np.array([]) + result[~boundary] = prediction + return result return predict @@ -222,31 +302,34 @@ def ordered_priority(edges): n = len(edges) for i, (n1, n2) in enumerate(edges): score = float(i)/n - d[(n1,n2)] = score - d[(n2,n1)] = score - def ord(g, n1, n2): - return d.get((n1,n2), inf) + d[(n1, n2)] = score + d[(n2, n1)] = score + + def ord(g, edges): + return [d.get(e, inf) for e in edges] return ord -def expected_change_vi(feature_extractor, classifier, alpha=1.0, beta=1.0): - prob_func = classifier_probability(feature_extractor, classifier) - def predict(g, n1, n2): - p = prob_func(g, n1, n2) # Prediction from the classifier +def expected_change_vi(feature_map, classifier, alpha=1.0, beta=1.0): + prob_func = classifier_probability(feature_map, classifier) + def predict(g, edges): + p = prob_func(g, edges) # Prediction from the classifier # Calculate change in VI if n1 and n2 should not be merged - v = compute_local_vi_change( - g.node[n1]['size'], g.node[n2]['size'], g.volume_size - ) + n1_sizes = np.fromiter((g.node[n1]['size'] for n1, n2 in edges), + dtype=float, count=len(edges)) + n2_sizes = np.fromiter((g.node[n2]['size'] for n1, n2 in edges), + dtype=float, count=len(edges)) + v = compute_local_vi_change(n1_sizes, n2_sizes, g.volume_size) # Return expected change - return (p*alpha*v + (1.0-p)*(-beta*v)) + return p*alpha*v - (1-p)*beta*v return predict def compute_local_vi_change(s1, s2, n): """Compute change in VI if we merge disjoint sizes s1,s2 in a volume n.""" - py1 = float(s1)/n - py2 = float(s2)/n - py = py1+py2 + py1 = s1 / n + py2 = s2 / n + py = py1 + py2 return -(py1*np.log2(py1) + py2*np.log2(py2) - py*np.log2(py)) @@ -261,20 +344,22 @@ def compute_true_delta_vi(ctable, n1, n2): 2*(p3g_log_p3g - p1g_log_p1g - p2g_log_p2g) -def expected_change_rand(feature_extractor, classifier, alpha=1.0, beta=1.0): - prob_func = classifier_probability(feature_extractor, classifier) - def predict(g, n1, n2): - p = float(prob_func(g, n1, n2)) # Prediction from the classifier - v = compute_local_rand_change( - g.node[n1]['size'], g.node[n2]['size'], g.volume_size - ) - return p*v*alpha + (1.0-p)*(-beta*v) +def expected_change_rand(feature_map, classifier, alpha=1.0, beta=1.0): + prob_func = classifier_probability(feature_map, classifier) + def predict(g, edges): + p = prob_func(g, edges) # Prediction from the classifier + n1_sizes = np.fromiter((g.node[n1]['size'] for n1, n2 in edges), + dtype=float, count=len(edges)) + n2_sizes = np.fromiter((g.node[n2]['size'] for n1, n2 in edges), + dtype=float, count=len(edges)) + v = compute_local_rand_change(n1_sizes, n2_sizes, g.volume_size) + return p*v*alpha - (1-p)*beta*v return predict def compute_local_rand_change(s1, s2, n): """Compute change in rand if we merge disjoint sizes s1,s2 in volume n.""" - return float(s1*s2)/nchoosek(n,2) + return s1 * s2 / nchoosek(n, 2) def compute_true_delta_rand(ctable, n1, n2, n): @@ -292,20 +377,23 @@ def compute_true_delta_rand(ctable, n1, n2, n): return (2 * delta_sxy - delta_sx) / nchoosek(n, 2) -def boundary_mean_ladder(g, n1, n2, threshold, strictness=1): +def boundary_mean_ladder(g, edges, threshold, strictness=1): f = make_ladder(boundary_mean, threshold, strictness) - return f(g, n1, n2) + return f(g, edges) -def boundary_mean_plus_sem(g, n1, n2, alpha=-6): - bvals = g.probabilities_r[g.boundary(n1, n2)] - return mean(bvals) + alpha*sem(bvals) +def boundary_mean_plus_sem(g, edges, alpha=-6): + bvals = [g.probabilities_r[g.boundary(n1, n2)] for n1, n2 in edges] + means = np.fromiter(map(mean, bvals), dtype=float, count=len(edges)) + sems = np.fromiter(map(sem, bvals), dtype=float, count=len(edges)) + return means + alpha*sems -def random_priority(g, n1, n2): - if n1 == g.boundary_body or n2 == g.boundary_body: - return inf - return random.random() +def random_priority(g, edges): + edges = np.atleast_2d(edges) + result = np.random.rand(len(edges)) + result[np.sum(edges == g.boundary_body, axis=1).astype(bool)] = np.inf + return result class Rag(Graph): @@ -423,6 +511,7 @@ def __init__(self, watershed=array([], label_dtype), for n1, n2 in self.edges(): if isfrozenedge(self, n1, n2): self.frozen_edges.add((n1,n2)) + self.update_unchanged_edges = update_unchanged_edges if update_unchanged_edges: self.move_edge = self.merge_edge_properties @@ -835,9 +924,13 @@ def build_merge_queue(self): are merged, affected edges can be invalidated and reinserted in the queue with a new priority. """ + edges = self.real_edges() + if edges: + weights = self.merge_priority_function(self, edges) + else: + weights = [] queue_items = [] - for l1, l2 in self.real_edges_iter(): - w = self.merge_priority_function(self,l1,l2) + for w, (l1, l2) in zip(weights, edges): qitem = [w, True, l1, l2] queue_items.append(qitem) self[l1][l2]['qlink'] = qitem @@ -1413,17 +1506,22 @@ def merge_nodes(self, n1, n2, merge_priority=0.0): common_neighbors = np.intersect1d(self.neighbors(n1), self.neighbors(n2), assume_unique=True) + edges_to_update = [] for n in common_neighbors: self.merge_edge_properties((n2, n), (n1, n)) + edges_to_update.append((n1, n)) new_neighbors = np.setdiff1d(self.neighbors(n2), np.concatenate((common_neighbors, [n1])), assume_unique=True) for n in new_neighbors: self.move_edge((n2, n), (n1, n)) + if self.update_unchanged_edges: + edges_to_update.append((n1, n)) try: self.merge_queue.invalidate(self[n1][n2]['qlink']) except KeyError: # no edge or no queue link pass + self.update_merge_queue(edges_to_update) node_id = self.tree.merge(n1, n2, w) self.remove_node(n2) self.rename_node(n1, node_id) @@ -1566,10 +1664,9 @@ def merge_edge_properties(self, src, dst): self.merge_queue.invalidate(self[w][x]['qlink']) except KeyError: pass - self.update_merge_queue(u, v) - def update_merge_queue(self, u, v): + def update_merge_queue(self, edges): """Update the merge queue item for edge (u, v). Add new by default. Parameters @@ -1581,16 +1678,16 @@ def update_merge_queue(self, u, v): ------- None """ - if self.boundary_body in [u, v]: - return - if 'qlink' in self[u][v]: - self.merge_queue.invalidate(self[u][v]['qlink']) - if not self.merge_queue.is_null_queue: - w = self.merge_priority_function(self,u,v) - new_qitem = [w, True, u, v] - self[u][v]['qlink'] = new_qitem - self[u][v]['weight'] = w - self.merge_queue.push(new_qitem) + edges = [e for e in edges if self.boundary_body not in e] + if not self.merge_queue.is_null_queue and edges: + weights = self.merge_priority_function(self, edges) + for w, (u, v) in zip(weights, edges): + if 'qlink' in self[u][v]: + self.merge_queue.invalidate(self[u][v]['qlink']) + new_qitem = [w, True, u, v] + self[u][v]['qlink'] = new_qitem + self[u][v]['weight'] = w + self.merge_queue.push(new_qitem) def get_segmentation(self, threshold=None): @@ -1948,7 +2045,7 @@ def compute_W(self, merge_priority_function, sigma=255.0*20, nodes=None): i, j = nodes2ind[u], nodes2ind[v] except KeyError: continue - w = merge_priority_function(self,u,v) + w = merge_priority_function(self, ((u, v))) W[i,j] = W[j,i] = np.exp(-w**2/sigma) return W diff --git a/tests/test_agglo.py b/tests/test_agglo.py index 032eed1..c6e7a9a 100644 --- a/tests/test_agglo.py +++ b/tests/test_agglo.py @@ -27,16 +27,18 @@ def test_2_connectivity(): p = np.array([[1., 0.], [0., 1.]]) ws = np.array([[1, 2], [3, 4]], np.uint32) g = agglo.Rag(ws, p, connectivity=2, use_slow=True) - assert_equal(agglo.boundary_mean(g, 1, 2), 0.5) - assert_equal(agglo.boundary_mean(g, 1, 4), 1.0) + assert_equal(agglo.boundary_mean(g, [[1, 2]]), [0.5]) + assert_equal(agglo.boundary_mean(g, [[1, 4]]), [1.0]) + assert_equal(agglo.boundary_mean(g, [[1, 2], [1, 4]]), [0.5, 1.0]) def test_float_watershed(): """Ensure float arrays passed as watersheds don't crash everything.""" p = np.array([[1., 0.], [0., 1.]]) ws = np.array([[1, 2], [3, 4]], np.float32) g = agglo.Rag(ws, p, connectivity=2, use_slow=True) - assert_equal(agglo.boundary_mean(g, 1, 2), 0.5) - assert_equal(agglo.boundary_mean(g, 1, 4), 1.0) + assert_equal(agglo.boundary_mean(g, [[1, 2]])[0], 0.5) + assert_equal(agglo.boundary_mean(g, [[1, 4]])[0], 1.0) + assert_equal(agglo.boundary_mean(g, [[1, 2], [1, 4]]), [0.5, 1.0]) def test_empty_rag(): @@ -57,7 +59,8 @@ def test_agglomeration(): def test_ladder_agglomeration(): i = 2 g = agglo.Rag(wss[i], probs[i], agglo.boundary_mean, - normalize_probabilities=True, use_slow=True) + normalize_probabilities=True, use_slow=True, + update_unchanged_edges=True) g.agglomerate_ladder(3) g.agglomerate(0.51) assert_allclose(ev.vi(g.get_segmentation(), results[i]), 0.0, @@ -81,7 +84,7 @@ def frozen(g, i): normalize_probabilities=True, isfrozennode=frozen, use_slow=True) g.agglomerate(0.15) - g.merge_priority_function = agglo.mito_merge() + g.merge_priority_function = agglo.mito_merge g.rebuild_merge_queue() g.agglomerate(1.0) assert_allclose(ev.vi(g.get_segmentation(), results[i]), 0.0, @@ -162,16 +165,16 @@ def dummy_data(): def test_manual_agglo_fast_rag(dummy_data): frag, gt, g = dummy_data - assert agglo.boundary_mean(g, 6, 7) == 0.8 - assert agglo.boundary_mean(g, 6, 10) == 0.8 + assert agglo.boundary_mean(g, [[6, 7]])[0] == 0.8 + assert agglo.boundary_mean(g, [[6, 10]])[0] == 0.8 original_ids_0 = [g[u][v]['boundary-ids'] for u, v in [(5, 9), (6, 10)]] original_ids_1 = [g[u][v]['boundary-ids'] for u, v in [(7, 11), (8, 12)]] original_ids_2 = [g[u][v]['boundary-ids'] for u, v in [(2, 3), (6, 7)]] g.merge_subgraph([1, 2, 5, 6]) # results in node ID 20 - assert agglo.boundary_mean(g, 20, 10) == 0.8 + assert agglo.boundary_mean(g, [[20, 10]])[0] == 0.8 g.merge_subgraph(range(9, 17)) assert g[20][27]['boundary-ids'] == set.union(*original_ids_0) - assert np.allclose(agglo.boundary_mean(g, 20, 27), 0.8, atol=0.02) + assert np.allclose(agglo.boundary_mean(g, [[20, 27]])[0], 0.8, atol=0.02) g.merge_subgraph([3, 4, 7, 8]) assert g[27][30]['boundary-ids'] == set.union(*original_ids_1) g.merge_nodes(27, 30)