diff --git a/test/transform/test_label_transforms.py b/test/transform/test_label_transforms.py index 059120e5..f9a39473 100644 --- a/test/transform/test_label_transforms.py +++ b/test/transform/test_label_transforms.py @@ -58,7 +58,6 @@ def affs_brute_force_with_mask(seg, offsets, mask_bg_transition=True): class TestLabelTransforms(unittest.TestCase): def get_labels(self, with_zero): shape = (64, 64) - # shape = (6, 6) labels = np.random.randint(1, 6, size=shape).astype("uint64") if with_zero: bg_prob = 0.25 @@ -132,14 +131,14 @@ def test_distance_transform(self): self.assertTrue((tnew >= 0).all()) self.assertTrue((tnew <= 5).all()) - trafo = DistanceTransform(normalize=False, vector_distances=True) + trafo = DistanceTransform(normalize=False, directed_distances=True) tnew = trafo(target) self.assertEqual(tnew.shape, (3,) + target.shape) distances, vector_distances = tnew[0], tnew[1:] abs_dist = np.linalg.norm(vector_distances, axis=0) self.assertTrue(np.allclose(distances, abs_dist)) - trafo = DistanceTransform(normalize=True, vector_distances=True) + trafo = DistanceTransform(normalize=True, directed_distances=True) tnew = trafo(target) self.assertEqual(tnew.shape, (3,) + target.shape) self.assertTrue((tnew >= -1).all()) @@ -169,6 +168,45 @@ def test_distance_transform_empty_labels(self): tnew = trafo(target) self.assertTrue(np.allclose(tnew, 1.0)) + def test_per_object_distance_transform(self): + from torch_em.transform.label import PerObjectDistanceTransform + from skimage.data import binary_blobs + from skimage.measure import label + + labels = label(binary_blobs(256, volume_fraction=0.25)) + + trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=False, directed_distances=False, foreground=False, + ) + result = trafo(labels) + self.assertEqual(result.shape, (1,) + labels.shape) + self.assertGreaterEqual(result.min(), 0) + self.assertLessEqual(result.max(), 1) + + trafo = PerObjectDistanceTransform( + distances=False, 
boundary_distances=True, directed_distances=False, foreground=False, + ) + result = trafo(labels) + self.assertEqual(result.shape, (1,) + labels.shape) + self.assertGreaterEqual(result.min(), 0) + self.assertLessEqual(result.max(), 1) + + trafo = PerObjectDistanceTransform( + distances=False, boundary_distances=False, directed_distances=True, foreground=False, + ) + result = trafo(labels) + self.assertEqual(result.shape, (2,) + labels.shape) + self.assertGreaterEqual(result.min(), -1) + self.assertLessEqual(result.max(), 1) + + trafo = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, + ) + result = trafo(labels) + self.assertEqual(result.shape, (3,) + labels.shape) + self.assertGreaterEqual(result.min(), 0) + self.assertLessEqual(result.max(), 1) + if __name__ == "__main__": unittest.main() diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py index 42de2d7e..63dfdf0b 100644 --- a/torch_em/transform/label.py +++ b/torch_em/transform/label.py @@ -1,7 +1,7 @@ import numpy as np import skimage.measure import skimage.segmentation -from scipy.ndimage import distance_transform_edt +import vigra from ..util import ensure_array, ensure_spatial_array @@ -192,17 +192,23 @@ def __call__(self, labels): class DistanceTransform: + """Compute distances to foreground. 
+ """ eps = 1e-7 def __init__( self, - distances=True, vector_distances=False, - normalize=True, max_distance=None, - foreground_id=1, invert=False, func=None + distances=True, + directed_distances=False, + normalize=True, + max_distance=None, + foreground_id=1, + invert=False, + func=None ): - if sum((distances, vector_distances)) == 0: - raise ValueError("At least one of 'distances' or 'vector_distances' must be set to 'True'") - self.vector_distances = vector_distances + if sum((distances, directed_distances)) == 0: + raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'") + self.directed_distances = directed_distances self.distances = distances self.normalize = normalize self.max_distance = max_distance @@ -210,7 +216,8 @@ def __init__( self.invert = invert self.func = func - def _compute_distances(self, distances): + def _compute_distances(self, directed_distances): + distances = np.linalg.norm(directed_distances, axis=0) if self.max_distance is not None: distances = np.clip(distances, 0, self.max_distance) if self.normalize: @@ -221,54 +228,196 @@ def _compute_distances(self, distances): distances = self.func(distances) return distances - def _compute_vector_distances(self, indices): - coordinates = np.indices(indices.shape[1:]).astype("float32") - vector_distances = indices - coordinates + def _compute_directed_distances(self, directed_distances): if self.max_distance is not None: - vector_distances = np.clip(vector_distances, -self.max_distance, self.max_distance) + directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance) if self.normalize: - vector_distances /= (np.abs(vector_distances).max(axis=(1, 2), keepdims=True) + self.eps) + directed_distances /= (np.abs(directed_distances).max(axis=(1, 2), keepdims=True) + self.eps) if self.invert: - vector_distances = vector_distances.max(axis=(1, 2), keepdims=True) - vector_distances + directed_distances = directed_distances.max(axis=(1, 2), 
keepdims=True) - directed_distances if self.func is not None: - vector_distances = self.func(vector_distances) - return vector_distances + directed_distances = self.func(directed_distances) + return directed_distances def _get_distances_for_empty_labels(self, labels): shape = labels.shape - fill_value = 0.0 if self.invert else np.linalg.norm(list(shape)) - if self.distances and self.vector_distances: - data = (np.full(shape, fill_value), np.full((labels.ndim,) + shape, fill_value)) - elif self.distances: - data = np.full(shape, fill_value) - elif self.vector_distances: - data = np.full((labels.ndim,) + shape, fill_value) - else: - raise RuntimeError + fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2) + data = np.full((labels.ndim,) + shape, fill_value) return data def __call__(self, labels): - distance_mask = labels != self.foreground_id + distance_mask = (labels == self.foreground_id).astype("uint32") # the distances are not computed corrected if they are all zero # so this case needs to be handled separately - if distance_mask.sum() == distance_mask.size: - data = self._get_distances_for_empty_labels(labels) + if distance_mask.sum() == 0: + directed_distances = self._get_distances_for_empty_labels(labels) else: - data = distance_transform_edt(distance_mask, - return_distances=self.distances, - return_indices=self.vector_distances) + ndim = distance_mask.ndim + to_channel_first = (ndim,) + tuple(range(ndim)) + directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first) if self.distances: - distances = data[0] if self.vector_distances else data - distances = self._compute_distances(distances) + distances = self._compute_distances(directed_distances) - if self.vector_distances: - indices = data[1] if self.distances else data - vector_distances = self._compute_vector_distances(indices) + if self.directed_distances: + directed_distances = self._compute_directed_distances(directed_distances) - 
if self.distances and self.vector_distances: - return np.concatenate((distances[None], vector_distances), axis=0) + if self.distances and self.directed_distances: + return np.concatenate((distances[None], directed_distances), axis=0) if self.distances: return distances - if self.vector_distances: - return vector_distances + if self.directed_distances: + return directed_distances + + +class PerObjectDistanceTransform: + """Compute normalized distances per object in a segmentation. + """ + eps = 1e-7 + + def __init__( + self, + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + apply_label=True, + correct_centers=True, + min_size=0, + distance_fill_value=1.0, + ): + if sum([distances, directed_distances, boundary_distances]) == 0: + raise ValueError("At least one of distances or directed distances has to be passed.") + self.distances = distances + self.boundary_distances = boundary_distances + self.directed_distances = directed_distances + self.foreground = foreground + + self.apply_label = apply_label + self.correct_centers = correct_centers + self.min_size = min_size + self.distance_fill_value = distance_fill_value + + def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances): + # Crop the mask and generate array with the correct center. + cropped_mask = mask[bb] + cropped_center = tuple(ce - b.start for ce, b in zip(center, bb)) + + # The centroid might not be inside of the object. + # In this case we correct the center by taking the maximum of the distance to the boundary. + # Note: the centroid is still the best estimate for the center, as long as it's in the object. + correct_center = not cropped_mask[cropped_center] + + # Compute the boundary distances if necessary. + # (Either if we need to correct the center, or compute the boundary distances anyways.) + if correct_center or self.boundary_distances: + # Crop the boundary mask and compute the boundary distances. 
+            cropped_boundary_mask = boundaries[bb]
+            boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
+            boundary_distances[~cropped_mask] = 0
+            max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)
+
+        # Set the crop center to the max dist point.
+        if correct_center:
+            # Find the center (= maximal distance from the boundaries).
+            cropped_center = max_dist_point
+
+        cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
+        cropped_center_mask[cropped_center] = 1
+
+        # Compute the directed distances.
+        if self.distances or self.directed_distances:
+            this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
+        else:
+            this_distances = None
+
+        # Keep only the specified distances:
+        if self.distances and self.directed_distances:  # all distances
+            # Compute the undirected distances from the directed distances and concatenate them.
+            undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
+            this_distances = np.concatenate([undir, this_distances], axis=-1)
+
+        elif self.distances:  # only undirected distances
+            # Compute the undirected distances from directed distances and keep only them.
+            this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)
+
+        elif self.directed_distances:  # only directed distances
+            pass  # We don't have to do anything because the directed distances are already computed.
+
+        # Add an extra channel for the boundary distances if specified.
+        if self.boundary_distances:
+            boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
+            if this_distances is None:
+                this_distances = boundary_distances
+            else:
+                this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)
+
+        # Set distances outside of the mask to zero.
+        this_distances[~cropped_mask] = 0
+
+        # Normalize the distances.
+ spatial_axes = tuple(range(mask.ndim)) + this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps) + + # Set the distance values in the global result. + distances[bb][cropped_mask] = this_distances[cropped_mask] + + return distances + + def __call__(self, labels): + # Apply label (connected components) if specified. + if self.apply_label: + labels = skimage.measure.label(labels).astype("uint32") + else: # Otherwise just relabel the segmentation. + labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32") + + # Filter out small objects if min_size is specified. + if self.min_size > 0: + ids, sizes = np.unique(labels, return_counts=True) + discard_ids = ids[sizes < self.min_size] + labels[np.isin(labels, discard_ids)] = 0 + labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32") + + # Compute the boundaries. They will be used to determine the most central point, + # and if 'self.boundary_distances is True' to add the boundary distances. + boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32") + + # Compute region properties to derive bounding boxes and centers. + ndim = labels.ndim + props = skimage.measure.regionprops(labels) + bounding_boxes = { + prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim)) + for prop in props + } + + # Compute the object centers from centroids. + centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props} + + # Compute how many distance channels we have. + n_channels = 0 + if self.distances: # We need one channel for the overall distances. + n_channels += 1 + if self.boundary_distances: # We need one channel for the boundary distances. + n_channels += 1 + if self.directed_distances: # And ndim channels for directed distances. + n_channels += ndim + + # Compute the per object distances. 
+        distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
+        for prop in props:
+            label_id = prop.label
+            mask = labels == label_id
+            distances = self.compute_normalized_object_distances(
+                mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
+            )
+
+        # Bring the distance channel to the first dimension.
+        to_channel_first = (ndim,) + tuple(range(ndim))
+        distances = distances.transpose(to_channel_first)
+
+        # Add the foreground mask as first channel if specified.
+        if self.foreground:
+            binary_labels = (labels > 0).astype("float32")
+            distances = np.concatenate([binary_labels[None], distances], axis=0)
+
+        return distances
diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py
index 2327a60d..a87215bd 100644
--- a/torch_em/util/segmentation.py
+++ b/torch_em/util/segmentation.py
@@ -19,6 +19,9 @@
 
 # could also refactor this into elf
 def size_filter(seg, min_size, hmap=None, with_background=False):
+    if min_size == 0:
+        return seg
+
     if hmap is None:
         ids, sizes = np.unique(seg, return_counts=True)
         bg_ids = ids[sizes < min_size]
@@ -113,3 +116,45 @@ def watershed_from_maxima(boundaries, foreground, min_distance, min_size=250, si
     seg = watershed(boundaries, markers=seeds, mask=foreground)
     seg = size_filter(seg, min_size)
     return seg
+
+
+def watershed_from_center_and_boundary_distances(
+    center_distances,
+    boundary_distances,
+    foreground_map,
+    center_distance_threshold=0.5,
+    boundary_distance_threshold=0.9,
+    foreground_threshold=0.5,
+    min_size=0,
+):
+    """Seeded watershed based on distance predictions to object center and boundaries.
+
+    The seeds are computed by finding connected components where both distance predictions are below their respective thresholds.
+
+    Args:
+        center_distances [np.ndarray] - Distance prediction to the object center.
+        boundary_distances [np.ndarray] - Inverted distance prediction to object boundaries.
+        foreground_map [np.ndarray] - Prediction for foreground probabilities.
+        center_distance_threshold [float] - Center distance predictions below this value will be
+            used to find seeds (intersected with thresholded boundary distance predictions).
+        boundary_distance_threshold [float] - Boundary distance predictions below this value will be
+            used to find seeds (intersected with thresholded center distance predictions).
+        foreground_threshold [float] - Foreground predictions above this value will be used as foreground mask.
+        min_size [int] - Minimal object size in the segmentation result.
+
+    Returns:
+        np.ndarray - The instance segmentation.
+    """
+    fg = foreground_map > foreground_threshold
+
+    marker_map = np.logical_and(
+        center_distances < center_distance_threshold,
+        boundary_distances < boundary_distance_threshold
+    )
+    marker_map[~fg] = 0
+    markers = label(marker_map)
+
+    seg = watershed(boundary_distances, markers=markers, mask=fg)
+    seg = size_filter(seg, min_size)
+
+    return seg