Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions ffn/utils/decision_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,53 @@

from connectomics.common import bounding_box
from connectomics.segmentation import labels
from ffn.inference import segmentation as segmentation_lib
import numpy as np
import pandas as pd
from scipy import ndimage


def find_decision_points(seg: np.ndarray,
voxel_size: Sequence[float],
max_distance: Optional[float] = None,
subvol_box: Optional[bounding_box.BoundingBox] = None
) -> dict[tuple[int, int], tuple[float, np.ndarray]]:
def find_decision_points(
seg: np.ndarray,
voxel_size: Sequence[float],
max_distance: Optional[float] = None,
subvol_box: Optional[bounding_box.BoundingBox] = None,
optimize_sparse: bool = False,
sparse_noise_threshold: int = 0,
) -> dict[tuple[int, int], tuple[float, np.ndarray]]:
"""Identifies decision points in a segmentation subvolume.

Args:
seg: 3d uint64 ndarray of segmentation data
voxel_size: 3-tuple (xyz) defining the physical voxel size
max_distance: maximum distance between the segment and the decision point
(same units as voxel_size); if None, distances will not be limited
subvol_box: selector for a subvolume within `seg` within which
to search for decision points; the whole subvolume is always used
to compute the distance transform
subvol_box: selector for a subvolume within `seg` within which to search for
decision points; the whole subvolume is always used to compute the
distance transform
optimize_sparse: if True, first counts the number of segments in `seg` and
returns early if there are fewer than 2.
sparse_noise_threshold: if > 0 and `optimize_sparse` is True, ignores
components with voxel counts < this threshold when counting segments.

Returns:
dict from segment ID pairs to tuples of:
approximate physical distance from the segment to the decision point
(x, y, z) decision point
"""
if optimize_sparse:
_, counts = segmentation_lib.clean_up_and_count(
seg,
split_cc=False,
min_size=sparse_noise_threshold,
compute_id_map=False,
)

if counts is not None and len([k for k in counts.keys() if k > 0]) <= 1:
# If there are 0 or 1 unique segments (excluding background),
# they cannot possibly touch another segment.
return {}

# EDT is the Euclidean Distance Transform, specifying how far voxels added
# in 'expanded_seg' are from the seeds in 'seg'.
expanded_seg, edt = labels.watershed_expand(seg, voxel_size, max_distance)
Expand Down
33 changes: 33 additions & 0 deletions ffn/utils/tests/decision_point_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,39 @@ def test_find_decision_point(self):
self.assertIn((1, 2), points)
self.assertLen(points, 1)

def test_find_decision_point_optimize_sparse(self):
# 2 segments, but one is very small and should be filtered out
seg = np.zeros((100, 80, 60), dtype=np.uint64)
seg[:40, :, :] = 1
# 2 voxels of label 2
seg[60, 0, 0] = 2
seg[61, 0, 0] = 2

# Without optimization, 1 and 2 might connect if they are grown far enough.
points = decision_point.find_decision_points(seg, (1, 1, 1))
self.assertIn((1, 2), points)

# With optimization but threshold 0, they still connect (no size filtering)
points = decision_point.find_decision_points(
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=0
)
self.assertIn((1, 2), points)

# With optimization and threshold >= 2, label 2 is zeroed.
# We're left with label 1 (size > 2), so 1 segment -> returns empty
points = decision_point.find_decision_points(
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=3
)
self.assertEmpty(points)

# With just 1 segment from the start
seg = np.zeros((100, 80, 60), dtype=np.uint64)
seg[:40, :, :] = 1
points = decision_point.find_decision_points(
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=0
)
self.assertEmpty(points)


if __name__ == '__main__':
absltest.main()
Loading