diff --git a/test/transform/test_label_transforms.py b/test/transform/test_label_transforms.py index 8fbd6154..8cfedb8d 100644 --- a/test/transform/test_label_transforms.py +++ b/test/transform/test_label_transforms.py @@ -116,6 +116,37 @@ def test_affinities_with_ignore_transition(self): self.assertTrue(np.allclose(affs, expected_affs)) self.assertTrue(np.allclose(mask, expected_mask)) + def test_distance_transform(self): + from torch_em.transform.label import DistanceTransform + target = np.random.rand(128, 128) > 0.95 + + trafo = DistanceTransform(normalize=True, max_distance=None) + tnew = trafo(target) + self.assertFalse(np.allclose(tnew, 0)) + self.assertTrue((tnew >= 0).all()) + self.assertTrue((tnew <= 1).all()) + + trafo = DistanceTransform(normalize=False, max_distance=5) + tnew = trafo(target) + self.assertFalse(np.allclose(tnew, 0)) + self.assertTrue((tnew >= 0).all()) + self.assertTrue((tnew <= 5).all()) + + trafo = DistanceTransform(normalize=False, vector_distances=True) + tnew = trafo(target) + self.assertEqual(tnew.shape, (3,) + target.shape) + distances, vector_distances = tnew[0], tnew[1:] + abs_dist = np.linalg.norm(vector_distances, axis=0) + self.assertTrue(np.allclose(distances, abs_dist)) + + trafo = DistanceTransform(normalize=True, vector_distances=True) + tnew = trafo(target) + self.assertEqual(tnew.shape, (3,) + target.shape) + distances, vector_distances = tnew[0], tnew[1:] + + self.assertTrue((tnew >= -1).all()) + self.assertTrue((tnew <= 1).all()) + if __name__ == '__main__': unittest.main() diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py index 877d8c72..78ebd086 100644 --- a/torch_em/transform/label.py +++ b/torch_em/transform/label.py @@ -1,6 +1,7 @@ import numpy as np import skimage.measure import skimage.segmentation +from scipy.ndimage import distance_transform_edt from ..util import ensure_array, ensure_spatial_array @@ -157,3 +158,64 @@ def __call__(self, labels): for i, class_id in enumerate(class_ids): one_hot[i][labels == class_id] = 1.0 return one_hot + + +class DistanceTransform: + def __init__( + self, + distances=True, vector_distances=False, + normalize=True, max_distance=None, + foreground_id=1, invert=False, func=None + ): + if sum((distances, vector_distances)) == 0: + raise ValueError("At least one of 'distances' or 'vector_distances' must be set to 'True'") + self.vector_distances = vector_distances + self.distances = distances + self.normalize = normalize + self.max_distance = max_distance + self.foreground_id = foreground_id + self.invert = invert + self.func = func + + def _compute_distances(self, distances): + if self.max_distance is not None: + distances = np.clip(distances, 0, self.max_distance) + if self.normalize: + distances /= distances.max() + if self.invert: + distances = distances.max() - distances + if self.func is not None: + distances = self.func(distances) + return distances + + def _compute_vector_distances(self, indices): + coordinates = np.indices(indices.shape[1:]).astype("float32") + vector_distances = indices - coordinates + if self.max_distance is not None: + vector_distances = np.clip(vector_distances, -self.max_distance, self.max_distance) + if self.normalize: + vector_distances /= vector_distances.max(axis=(1, 2), keepdims=True) + if self.invert: + vector_distances = vector_distances.max(axis=(1, 2), keepdims=True) - vector_distances + if self.func is not None: + vector_distances = self.func(vector_distances) + return vector_distances + + def __call__(self, labels): + data = distance_transform_edt(labels != self.foreground_id, + return_distances=self.distances, + return_indices=self.vector_distances) + if self.distances: + distances = data[0] if self.vector_distances else data + distances = self._compute_distances(distances) + + if self.vector_distances: + indices = data[1] if self.distances else data + vector_distances = self._compute_vector_distances(indices) + + if self.distances and self.vector_distances: + return np.concatenate((distances[None], vector_distances), axis=0) + if self.distances: + return distances + if self.vector_distances: + return vector_distances