Skip to content

Commit

Permalink
Merge pull request #74 from tobiasmaier/bugfixes
Browse files Browse the repository at this point in the history
Bugfixes
  • Loading branch information
jni committed Jan 13, 2016
2 parents 47118ab + cb0bd74 commit c8dd73d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
16 changes: 8 additions & 8 deletions gala/agglo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from .dtypes import label_dtype


def contingency_table(a, b):
ct = ev_contingency_table(a, b)
def contingency_table(a, b, ignore_seg=[0], ignore_gt=[0]):
ct = ev_contingency_table(a, b, ignore_seg, ignore_gt)
nx, ny = ct.shape
ctout = np.zeros((2*nx + 1, ny), ct.dtype)
ct.todense(out=ctout[:nx, :])
Expand Down Expand Up @@ -721,7 +721,7 @@ def set_ground_truth(self, gt=None):
gt_ignore = [0, gtm] if (gt==0).any() else [gtm]
seg_ignore = [0, self.boundary_body] if \
(self.watershed==0).any() else [self.boundary_body]
self.gt = morpho.pad(gt, gt_ignore)
self.gt = morpho.pad(gt, gtm)
self.rig = contingency_table(self.watershed, self.gt,
ignore_seg=seg_ignore,
ignore_gt=gt_ignore)
Expand Down Expand Up @@ -1194,7 +1194,7 @@ def learn_edge(self, edge, ctables, assignments, feature_map):
for ctable in ctables]
]
labels = [np.sign(mean(cont_label)) for cont_label in cont_labels]
if any(map(isnan, labels)) or any([label == 0 for l in labels]):
if any(map(isnan, labels)) or any([label == 0 for label in labels]):
logging.debug('NaN or 0 labels found. ' +
' '.join(map(str, [labels, (n1, n2)])))
labels = [1 if i==0 or isnan(i) or n1 in self.frozen_nodes or
Expand Down Expand Up @@ -1768,9 +1768,9 @@ def split_vi(self, gt=None):
if self.gt is None and gt is None:
return array([0,0])
elif self.gt is not None:
return split_vi(None, None, self.rig)
return split_vi(self.rig)
else:
return split_vi(self.get_segmentation(), gt, None, [0], [0])
return split_vi(self.get_segmentation(), gt, [0], [0])


def boundary_indices(self, n1, n2):
Expand Down Expand Up @@ -1919,12 +1919,12 @@ def is_mito(g, n, channel=2, threshold=0.5):

def best_possible_segmentation(ws, gt):
"""Build the best possible segmentation given a superpixel map."""
cnt = contingency_table(ws, gt)
ws = Rag(ws)
cnt = contingency_table(ws.get_segmentation(), gt)
assignment = cnt == cnt.max(axis=1)[:,newaxis]
hard_assignment = where(assignment.sum(axis=1) > 1)[0]
# currently ignoring hard assignment nodes
assignment[hard_assignment,:] = 0
ws = Rag(ws)
for gt_node in range(1,cnt.shape[1]):
ws.merge_subgraph(where(assignment[:,gt_node])[0])
return ws.get_segmentation()
2 changes: 2 additions & 0 deletions gala/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def get_classifier(name='random forest', *args, **kwargs):
is_naive_bayes = name.find('naive') > -1
is_logistic = name.startswith('logis')
if vigra_available and is_random_forest:
if 'random_state' in kwargs:
del kwargs['random_state']
return VigraRandomForest(*args, **kwargs)
elif is_random_forest:
return DefaultRandomForest(*args, **kwargs)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_agglo.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,33 @@ def test_thin_fragment_agglo2():
g = agglo2.Rag(labels)
assert (1, 3) not in g.graph.edges()


def test_best_possible_segmentation():
ws = np.array([[2,3],[4,5]], np.int32)
gt = np.array([[1,2],[1,2]], np.int32)
best = agglo.best_possible_segmentation(ws, gt)
assert np.all(best[0,:] == best[1,:])


def test_set_ground_truth():
labels = [[1, 0, 2],
[1, 0, 2],
[1, 0, 2]]
g = agglo.Rag(np.array(labels))
g.set_ground_truth(np.array(labels))


def test_split_vi():
labels = [[1, 0, 2],
[1, 0, 2],
[1, 0, 2]]
g = agglo.Rag(np.array(labels))
vi0 = g.split_vi(np.array(labels))
g.set_ground_truth(np.array(labels))
vi1 = g.split_vi()
assert np.all(vi0 == vi1)


if __name__ == '__main__':
from numpy import testing
testing.run_module_suite()
Expand Down

0 comments on commit c8dd73d

Please sign in to comment.