Skip to content

Commit

Permalink
Distance label trafos (#177)
Browse files Browse the repository at this point in the history
Implement per object distance transformations
  • Loading branch information
constantinpape committed Dec 5, 2023
1 parent 4ba7bd2 commit 4925ec5
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 43 deletions.
44 changes: 41 additions & 3 deletions test/transform/test_label_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def affs_brute_force_with_mask(seg, offsets, mask_bg_transition=True):
class TestLabelTransforms(unittest.TestCase):
def get_labels(self, with_zero):
shape = (64, 64)
# shape = (6, 6)
labels = np.random.randint(1, 6, size=shape).astype("uint64")
if with_zero:
bg_prob = 0.25
Expand Down Expand Up @@ -132,14 +131,14 @@ def test_distance_transform(self):
self.assertTrue((tnew >= 0).all())
self.assertTrue((tnew <= 5).all())

trafo = DistanceTransform(normalize=False, vector_distances=True)
trafo = DistanceTransform(normalize=False, directed_distances=True)
tnew = trafo(target)
self.assertEqual(tnew.shape, (3,) + target.shape)
distances, vector_distances = tnew[0], tnew[1:]
abs_dist = np.linalg.norm(vector_distances, axis=0)
self.assertTrue(np.allclose(distances, abs_dist))

trafo = DistanceTransform(normalize=True, vector_distances=True)
trafo = DistanceTransform(normalize=True, directed_distances=True)
tnew = trafo(target)
self.assertEqual(tnew.shape, (3,) + target.shape)
self.assertTrue((tnew >= -1).all())
Expand Down Expand Up @@ -169,6 +168,45 @@ def test_distance_transform_empty_labels(self):
tnew = trafo(target)
self.assertTrue(np.allclose(tnew, 1.0))

def test_per_object_distance_transform(self):
from torch_em.transform.label import PerObjectDistanceTransform
from skimage.data import binary_blobs
from skimage.measure import label

labels = label(binary_blobs(256, volume_fraction=0.25))

trafo = PerObjectDistanceTransform(
distances=True, boundary_distances=False, directed_distances=False, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (1,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=False, boundary_distances=True, directed_distances=False, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (1,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=False, boundary_distances=False, directed_distances=True, foreground=False,
)
result = trafo(labels)
self.assertEqual(result.shape, (2,) + labels.shape)
self.assertGreaterEqual(result.min(), -1)
self.assertLessEqual(result.max(), 1)

trafo = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True,
)
result = trafo(labels)
self.assertEqual(result.shape, (3,) + labels.shape)
self.assertGreaterEqual(result.min(), 0)
self.assertLessEqual(result.max(), 1)


if __name__ == "__main__":
unittest.main()
229 changes: 189 additions & 40 deletions torch_em/transform/label.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import skimage.measure
import skimage.segmentation
from scipy.ndimage import distance_transform_edt
import vigra

from ..util import ensure_array, ensure_spatial_array

Expand Down Expand Up @@ -192,25 +192,32 @@ def __call__(self, labels):


class DistanceTransform:
"""Compute distances to foreground.
"""
eps = 1e-7

def __init__(
self,
distances=True, vector_distances=False,
normalize=True, max_distance=None,
foreground_id=1, invert=False, func=None
distances=True,
directed_distances=False,
normalize=True,
max_distance=None,
foreground_id=1,
invert=False,
func=None
):
if sum((distances, vector_distances)) == 0:
raise ValueError("At least one of 'distances' or 'vector_distances' must be set to 'True'")
self.vector_distances = vector_distances
if sum((distances, directed_distances)) == 0:
raise ValueError("At least one of 'distances' or 'directed_distances' must be set to 'True'")
self.directed_distances = directed_distances
self.distances = distances
self.normalize = normalize
self.max_distance = max_distance
self.foreground_id = foreground_id
self.invert = invert
self.func = func

def _compute_distances(self, distances):
def _compute_distances(self, directed_distances):
distances = np.linalg.norm(directed_distances, axis=0)
if self.max_distance is not None:
distances = np.clip(distances, 0, self.max_distance)
if self.normalize:
Expand All @@ -221,54 +228,196 @@ def _compute_distances(self, distances):
distances = self.func(distances)
return distances

def _compute_vector_distances(self, indices):
coordinates = np.indices(indices.shape[1:]).astype("float32")
vector_distances = indices - coordinates
def _compute_directed_distances(self, directed_distances):
if self.max_distance is not None:
vector_distances = np.clip(vector_distances, -self.max_distance, self.max_distance)
directed_distances = np.clip(directed_distances, -self.max_distance, self.max_distance)
if self.normalize:
vector_distances /= (np.abs(vector_distances).max(axis=(1, 2), keepdims=True) + self.eps)
directed_distances /= (np.abs(directed_distances).max(axis=(1, 2), keepdims=True) + self.eps)
if self.invert:
vector_distances = vector_distances.max(axis=(1, 2), keepdims=True) - vector_distances
directed_distances = directed_distances.max(axis=(1, 2), keepdims=True) - directed_distances
if self.func is not None:
vector_distances = self.func(vector_distances)
return vector_distances
directed_distances = self.func(directed_distances)
return directed_distances

def _get_distances_for_empty_labels(self, labels):
shape = labels.shape
fill_value = 0.0 if self.invert else np.linalg.norm(list(shape))
if self.distances and self.vector_distances:
data = (np.full(shape, fill_value), np.full((labels.ndim,) + shape, fill_value))
elif self.distances:
data = np.full(shape, fill_value)
elif self.vector_distances:
data = np.full((labels.ndim,) + shape, fill_value)
else:
raise RuntimeError
fill_value = 0.0 if self.invert else np.sqrt(np.linalg.norm(list(shape)) ** 2 / 2)
data = np.full((labels.ndim,) + shape, fill_value)
return data

def __call__(self, labels):
distance_mask = labels != self.foreground_id
distance_mask = (labels == self.foreground_id).astype("uint32")
# the distances are not computed corrected if they are all zero
# so this case needs to be handled separately
if distance_mask.sum() == distance_mask.size:
data = self._get_distances_for_empty_labels(labels)
if distance_mask.sum() == 0:
directed_distances = self._get_distances_for_empty_labels(labels)
else:
data = distance_transform_edt(distance_mask,
return_distances=self.distances,
return_indices=self.vector_distances)
ndim = distance_mask.ndim
to_channel_first = (ndim,) + tuple(range(ndim))
directed_distances = vigra.filters.vectorDistanceTransform(distance_mask).transpose(to_channel_first)

if self.distances:
distances = data[0] if self.vector_distances else data
distances = self._compute_distances(distances)
distances = self._compute_distances(directed_distances)

if self.vector_distances:
indices = data[1] if self.distances else data
vector_distances = self._compute_vector_distances(indices)
if self.directed_distances:
directed_distances = self._compute_directed_distances(directed_distances)

if self.distances and self.vector_distances:
return np.concatenate((distances[None], vector_distances), axis=0)
if self.distances and self.directed_distances:
return np.concatenate((distances[None], directed_distances), axis=0)
if self.distances:
return distances
if self.vector_distances:
return vector_distances
if self.directed_distances:
return directed_distances


class PerObjectDistanceTransform:
"""Compute normalized distances per object in a segmentation.
"""
eps = 1e-7

def __init__(
self,
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
apply_label=True,
correct_centers=True,
min_size=0,
distance_fill_value=1.0,
):
if sum([distances, directed_distances, boundary_distances]) == 0:
raise ValueError("At least one of distances or directed distances has to be passed.")
self.distances = distances
self.boundary_distances = boundary_distances
self.directed_distances = directed_distances
self.foreground = foreground

self.apply_label = apply_label
self.correct_centers = correct_centers
self.min_size = min_size
self.distance_fill_value = distance_fill_value

def compute_normalized_object_distances(self, mask, boundaries, bb, center, distances):
# Crop the mask and generate array with the correct center.
cropped_mask = mask[bb]
cropped_center = tuple(ce - b.start for ce, b in zip(center, bb))

# The centroid might not be inside of the object.
# In this case we correct the center by taking the maximum of the distance to the boundary.
# Note: the centroid is still the best estimate for the center, as long as it's in the object.
correct_center = not cropped_mask[cropped_center]

# Compute the boundary distances if necessary.
# (Either if we need to correct the center, or compute the boundary distances anyways.)
if correct_center or self.boundary_distances:
# Crop the boundary mask and compute the boundary distances.
cropped_boundary_mask = boundaries[bb]
boundary_distances = vigra.filters.distanceTransform(cropped_boundary_mask)
boundary_distances[~cropped_mask] = 0
max_dist_point = np.unravel_index(np.argmax(boundary_distances), boundary_distances.shape)

# Set the crop center to the max dist point
if correct_center:
# Find the center (= maximal distance from the boundaries).
cropped_center = max_dist_point

cropped_center_mask = np.zeros_like(cropped_mask, dtype="uint32")
cropped_center_mask[cropped_center] = 1

# Compute the directed distances,
if self.distances or self.directed_distances:
this_distances = vigra.filters.vectorDistanceTransform(cropped_center_mask)
else:
this_distances = None

# Keep only the specified distances:
if self.distances and self.directed_distances: # all distances
# Compute the undirected ditacnes from directed distances and concatenate,
undir = np.linalg.norm(this_distances, axis=-1, keepdims=True)
this_distances = np.concatenate([undir, this_distances], axis=-1)

elif self.distances: # only undirected distances
# Compute the undirected distances from directed distances and keep only them.
this_distances = np.linalg.norm(this_distances, axis=-1, keepdims=True)

elif self.directed_distances: # only directed distances
pass # We don't have to do anything becasue the directed distances are already computed.

# Add an extra channel for the boundary distances if specified.
if self.boundary_distances:
boundary_distances = (boundary_distances[max_dist_point] - boundary_distances)[..., None]
if this_distances is None:
this_distances = boundary_distances
else:
this_distances = np.concatenate([this_distances, boundary_distances], axis=-1)

# Set distances outside of the mask to zero.
this_distances[~cropped_mask] = 0

# Normalize the distances.
spatial_axes = tuple(range(mask.ndim))
this_distances /= (np.abs(this_distances).max(axis=spatial_axes, keepdims=True) + self.eps)

# Set the distance values in the global result.
distances[bb][cropped_mask] = this_distances[cropped_mask]

return distances

def __call__(self, labels):
# Apply label (connected components) if specified.
if self.apply_label:
labels = skimage.measure.label(labels).astype("uint32")
else: # Otherwise just relabel the segmentation.
labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

# Filter out small objects if min_size is specified.
if self.min_size > 0:
ids, sizes = np.unique(labels, return_counts=True)
discard_ids = ids[sizes < self.min_size]
labels[np.isin(labels, discard_ids)] = 0
labels = vigra.analysis.relabelConsecutive(labels)[0].astype("uint32")

# Compute the boundaries. They will be used to determine the most central point,
# and if 'self.boundary_distances is True' to add the boundary distances.
boundaries = skimage.segmentation.find_boundaries(labels, mode="inner").astype("uint32")

# Compute region properties to derive bounding boxes and centers.
ndim = labels.ndim
props = skimage.measure.regionprops(labels)
bounding_boxes = {
prop.label: tuple(slice(prop.bbox[i], prop.bbox[i + ndim]) for i in range(ndim))
for prop in props
}

# Compute the object centers from centroids.
centers = {prop.label: np.round(prop.centroid).astype("int") for prop in props}

# Compute how many distance channels we have.
n_channels = 0
if self.distances: # We need one channel for the overall distances.
n_channels += 1
if self.boundary_distances: # We need one channel for the boundary distances.
n_channels += 1
if self.directed_distances: # And ndim channels for directed distances.
n_channels += ndim

# Compute the per object distances.
distances = np.full(labels.shape + (n_channels,), self.distance_fill_value, dtype="float32")
for prop in props:
label_id = prop.label
mask = labels == label_id
distances = self.compute_normalized_object_distances(
mask, boundaries, bounding_boxes[label_id], centers[label_id], distances
)

# Bring the distance channel to the first dimension.
to_channel_first = (ndim,) + tuple(range(ndim))
distances = distances.transpose(to_channel_first)

# Add the foreground mask as first channel if specified.
if self.foreground:
binary_labels = (labels > 0).astype("float32")
distances = np.concatenate([binary_labels[None], distances], axis=0)

return distances
Loading

0 comments on commit 4925ec5

Please sign in to comment.