diff --git a/docs/source/transforms/preprocessing.rst b/docs/source/transforms/preprocessing.rst index c26348f88..b7da83900 100644 --- a/docs/source/transforms/preprocessing.rst +++ b/docs/source/transforms/preprocessing.rst @@ -85,3 +85,29 @@ Spatial .. autoclass:: ToCanonical :show-inheritance: + +Label +--------- + +.. currentmodule:: torchio.transforms.preprocessing.label + + +:class:`RemapLabels` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RemapLabels + :show-inheritance: + + +:class:`RemoveLabels` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: RemoveLabels + :show-inheritance: + + +:class:`SequentialLabels` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: SequentialLabels + :show-inheritance: diff --git a/tests/transforms/label/__init__.py b/tests/transforms/label/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/transforms/label/test_remap_labels.py b/tests/transforms/label/test_remap_labels.py new file mode 100644 index 000000000..bdb0c213c --- /dev/null +++ b/tests/transforms/label/test_remap_labels.py @@ -0,0 +1,17 @@ +from torchio.transforms import RemapLabels +from ...utils import TorchioTestCase + + +class TestRemapLabels(TorchioTestCase): + """Tests for `RemapLabels`.""" + def test_remap(self): + remapping = {1: 2, 2: 1, 5: 10, 6: 11} + remap_labels = RemapLabels(remapping=remapping) + + subject = self.get_subject_with_labels(labels=remapping.keys()) + transformed = remap_labels(subject) + inverse_transformed = transformed.apply_inverse_transform() + + self.assertEqual(self.get_unique_labels(subject.label), set(remapping.keys())) + self.assertEqual(self.get_unique_labels(transformed.label), set(remapping.values())) + self.assertEqual(self.get_unique_labels(inverse_transformed.label), set(remapping.keys())) diff --git a/tests/transforms/label/test_remove_labels.py b/tests/transforms/label/test_remove_labels.py new file mode 100644 index 000000000..5089e151b --- /dev/null +++ b/tests/transforms/label/test_remove_labels.py @@ -0,0 +1,28 @@ +from torchio.transforms import RemoveLabels +from ...utils import TorchioTestCase + + +class TestRemoveLabels(TorchioTestCase): + """Tests for `RemoveLabels`.""" + def test_remove(self): + initial_labels = (1, 2, 3, 4, 5, 6, 7) + labels_to_remove = (1, 2, 5, 6) + remaining_labels = (3, 4, 7) + + remove_labels = RemoveLabels(labels_to_remove) + + subject = self.get_subject_with_labels(labels=initial_labels) + transformed = remove_labels(subject) + inverse_transformed = transformed.apply_inverse_transform(warn=False) + self.assertEqual( + self.get_unique_labels(subject.label), + set(initial_labels), + ) + self.assertEqual( + self.get_unique_labels(transformed.label), + set(remaining_labels), + ) + self.assertEqual( + self.get_unique_labels(inverse_transformed.label), + set(remaining_labels), + ) diff --git a/tests/transforms/label/test_sequential_labels.py b/tests/transforms/label/test_sequential_labels.py new file mode 100644 index 000000000..f760b11df --- /dev/null +++ b/tests/transforms/label/test_sequential_labels.py @@ -0,0 +1,19 @@ +from torchio.transforms import SequentialLabels +from ...utils import TorchioTestCase + + +class TestSequentialLabels(TorchioTestCase): + """Tests for `SequentialLabels`.""" + def test_sequential(self): + initial_labels = (2, 8, 9, 10, 15, 20, 100) + transformed_labels = (1, 2, 3, 4, 5, 6, 7) + + sequential_labels = SequentialLabels() + + subject = self.get_subject_with_labels(labels=initial_labels) + transformed = sequential_labels(subject) + inverse_transformed = transformed.apply_inverse_transform() + + self.assertEqual(self.get_unique_labels(subject.label), set(initial_labels)) + self.assertEqual(self.get_unique_labels(transformed.label), set(transformed_labels)) + self.assertEqual(self.get_unique_labels(inverse_transformed.label), set(initial_labels)) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 33a8ca8e1..dfb3df8f5 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -44,6 +44,9 @@ def get_transform(self, channels, is_3d=True, labels=True): tio.RandomAffine(): 3, elastic: 1, }), + tio.RemapLabels(remapping={1: 2, 2: 1, 3: 20, 4: 25}, masking_method='Left'), + tio.RemoveLabels([1, 3]), + tio.SequentialLabels(), tio.Pad(pad_args, padding_mode=3), tio.Crop(crop_args), ] @@ -121,7 +124,7 @@ def test_transforms_subject_4d(self): transformed = transform(subject) trsf_channels = len(transformed.t1.data) assert trsf_channels > 1, f'Lost channels in {transform.name}' - if transform.name != 'RandomLabelsToImage': + if transform.name not in ['RandomLabelsToImage', 'RemapLabels', 'RemoveLabels', 'SequentialLabels']: self.assertEqual( subject.shape[0], transformed.shape[0], diff --git a/tests/utils.py b/tests/utils.py index 8bf351477..7514105e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -119,6 +119,20 @@ def get_subject_with_partial_volume_label_map(self, components=1): ), ) + def get_subject_with_labels(self, labels): + return tio.Subject( + label=tio.LabelMap( + self.get_image_path( + 'label_multi', labels=labels + ) + ) + ) + + def get_unique_labels(self, label_map): + labels = torch.unique(label_map.data) + labels = {i.item() for i in labels if i != 0} + return labels + def tearDown(self): """Tear down test fixtures, if any.""" shutil.rmtree(self.dir) @@ -131,6 +145,7 @@ def get_image_path( self, stem, binary=False, + labels=None, shape=(10, 20, 30), spacing=(1, 1, 1), components=1, @@ -144,6 +159,14 @@ def get_image_path( data = (data > 0.5).astype(np.uint8) if not data.sum() and force_binary_foreground: data[..., 0] = 1 + elif labels is not None: + data = (data * (len(labels) + 1)).astype(np.uint8) + new_data = np.zeros_like(data) + for i, label in enumerate(labels): + new_data[data == (i + 1)] = label + if not (new_data == label).sum(): + new_data[..., i] = label + data = new_data elif self.flip_coin(): # cast some images data *= 100 dtype = np.uint8 if self.flip_coin() else np.uint16 @@ -171,7 +194,7 @@ def get_tests_data_dir(self): return Path(__file__).parent / 'image_data' def assertTensorNotEqual(self, *args, **kwargs): # noqa: N802 - message_kwarg = dict(msg=args[2]) if len(args) == 3 else {} + message_kwarg = {'msg': args[2]} if len(args) == 3 else {} with self.assertRaises(AssertionError, **message_kwarg): self.assertTensorEqual(*args, **kwargs) diff --git a/torchio/data/subject.py b/torchio/data/subject.py index 0233752e4..329279899 100644 --- a/torchio/data/subject.py +++ b/torchio/data/subject.py @@ -128,7 +128,9 @@ def get_inverse_transform(self, warn=True) -> 'Transform': return self.get_composed_history().inverse(warn=warn) def apply_inverse_transform(self, warn=True) -> 'Subject': - return self.get_inverse_transform(warn=warn)(self) + transformed = self.get_inverse_transform(warn=warn)(self) + transformed.clear_history() + return transformed def clear_history(self) -> None: self.applied_transforms = [] diff --git a/torchio/transforms/__init__.py b/torchio/transforms/__init__.py index 9ce0c9ee8..dfdafcce0 100644 --- a/torchio/transforms/__init__.py +++ b/torchio/transforms/__init__.py @@ -35,6 +35,9 @@ from .preprocessing import RescaleIntensity from .preprocessing import HistogramStandardization from .preprocessing.intensity.histogram_standardization import train as train_histogram +from .preprocessing.label.remap_labels import RemapLabels +from .preprocessing.label.sequential_labels import SequentialLabels +from .preprocessing.label.remove_labels import RemoveLabels __all__ = [ @@ -79,4 +82,7 @@ 'RescaleIntensity', 'CropOrPad', 'train_histogram', + 'RemapLabels', + 'SequentialLabels', + 'RemoveLabels', ] diff --git a/torchio/transforms/augmentation/composition.py b/torchio/transforms/augmentation/composition.py index 9e0d1a664..6b9456e10 100644 --- a/torchio/transforms/augmentation/composition.py +++ b/torchio/transforms/augmentation/composition.py @@ -69,7 +69,8 @@ def inverse(self, warn: bool = True) -> Transform: result = Compose(transforms) else: # return noop if no invertible transforms are found def result(x): return x # noqa: E704 - warnings.warn('No invertible transforms found', RuntimeWarning) + if warn: + warnings.warn('No invertible transforms found', RuntimeWarning) return result diff --git a/torchio/transforms/preprocessing/__init__.py b/torchio/transforms/preprocessing/__init__.py index b772884a6..ec10816a9 100644 --- a/torchio/transforms/preprocessing/__init__.py +++ b/torchio/transforms/preprocessing/__init__.py @@ -8,6 +8,10 @@ from .intensity.z_normalization import ZNormalization from .intensity.histogram_standardization import HistogramStandardization +from .label.remap_labels import RemapLabels +from .label.sequential_labels import SequentialLabels +from .label.remove_labels import RemoveLabels + __all__ = [ 'Pad', @@ -18,4 +22,7 @@ 'RescaleIntensity', 'ZNormalization', 'HistogramStandardization', + 'RemapLabels', + 'SequentialLabels', + 'RemoveLabels', ] diff --git a/torchio/transforms/preprocessing/intensity/normalization_transform.py b/torchio/transforms/preprocessing/intensity/normalization_transform.py index b9ffd48a2..33ae74abc 100644 --- a/torchio/transforms/preprocessing/intensity/normalization_transform.py +++ b/torchio/transforms/preprocessing/intensity/normalization_transform.py @@ -1,24 +1,25 @@ -from typing import Union import torch from ....data.subject import Subject -from ....typing import TypeCallable +from ....transforms.transform import Transform +from ....transforms.transform import TypeMaskingMethod from ... import IntensityTransform -TypeMaskingMethod = Union[str, TypeCallable, None] - - class NormalizationTransform(IntensityTransform): """Base class for intensity preprocessing transforms. Args: masking_method: Defines the mask used to compute the normalization statistics. It can be one of: - - ``None``: the mask image is all ones, i.e. all values in the image are used + - ``None``: the mask image is all ones, i.e. all values in the image are used. + + - A string: key to a :class:`torchio.LabelMap` in the subject which is used as a mask, + OR an anatomical label: ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``, + ``'Inferior'``, ``'Superior'`` which specifies a side of the mask volume to be ones. - - A string: the mask image is retrieved from the subject, which is expected the string as a key + - A function: the mask image is computed as a function of the intensity image. + The function must receive and return a :class:`torch.Tensor` - - A function: the mask image is computed as a function of the intensity image. The function must receive and return a :class:`torch.Tensor` **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. Example: @@ -39,32 +40,12 @@ def __init__( masking_method: TypeMaskingMethod = None, **kwargs ): - """ - masking_method is used to choose the values used for normalization. - It can be: - - A string: the mask will be retrieved from the subject - - A function: the mask will be computed using the function - - None: all values are used - """ super().__init__(**kwargs) - self.mask_name = None self.masking_method = masking_method - if masking_method is None: - self.masking_method = self.ones - elif callable(masking_method): - self.masking_method = masking_method - elif isinstance(masking_method, str): - self.mask_name = masking_method - - def get_mask(self, subject: Subject, tensor: torch.Tensor) -> torch.Tensor: - if self.mask_name is None: - return self.masking_method(tensor) - else: - return subject[self.mask_name].data.bool() def apply_transform(self, subject: Subject) -> Subject: for image_name, image in self.get_images_dict(subject).items(): - mask = self.get_mask(subject, image.data) + mask = Transform.get_mask(self.masking_method, subject, image.data) self.apply_normalization(subject, image_name, mask) return subject @@ -76,12 +57,3 @@ def apply_normalization( ) -> None: # There must be a nicer way of doing this raise NotImplementedError - - @staticmethod - def ones(tensor: torch.Tensor) -> torch.Tensor: - return torch.ones_like(tensor, dtype=torch.bool) - - @staticmethod - def mean(tensor: torch.Tensor) -> torch.Tensor: - mask = tensor > tensor.mean() - return mask diff --git a/torchio/transforms/preprocessing/label/__init__.py b/torchio/transforms/preprocessing/label/__init__.py new file mode 100644 index 000000000..ba4faee09 --- /dev/null +++ b/torchio/transforms/preprocessing/label/__init__.py @@ -0,0 +1,10 @@ +from .remap_labels import RemapLabels +from .sequential_labels import SequentialLabels +from .remove_labels import RemoveLabels + + +__all__ = [ + 'RemapLabels', + 'SequentialLabels', + 'RemoveLabels', +] diff --git a/torchio/transforms/preprocessing/label/remap_labels.py b/torchio/transforms/preprocessing/label/remap_labels.py new file mode 100644 index 000000000..b8373fe23 --- /dev/null +++ b/torchio/transforms/preprocessing/label/remap_labels.py @@ -0,0 +1,91 @@ +from typing import Dict + +from ....data import LabelMap +from ...transform import Transform, TypeMaskingMethod + + +class RemapLabels(Transform): + r"""Remap the integer ids of labels in a LabelMap. + + This transformation may not be invertible if two labels are combined by the + remapping. + A masking method can be used to correctly split the label into two during + the `inverse transformation `_ (see example). + + Args: + remapping: Dictionary that specifies how labels should be remapped. + The keys are the old label ids, and the corresponding values replace + them. + masking_method: Defines a mask for where the label remapping is applied. It can be one of: + + - ``None``: the mask image is all ones, i.e. all values in the image are used. + + - A string: key to a :class:`torchio.LabelMap` in the subject which is used as a mask, + OR an anatomical label: ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``, + ``'Inferior'``, ``'Superior'`` which specifies a side of the mask volume to be ones. + + - A function: the mask image is computed as a function of the intensity image. + The function must receive and return a :class:`torch.Tensor`. + + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + + Example: + >>> import torchio as tio + >>> # Target label map has the following labels: + >>> # {'left_ventricle': 1, 'right_ventricle': 2, 'left_caudate': 3, 'right_caudate': 4, + >>> # 'left_putamen': 5, 'right_putamen': 6, 'left_thalamus': 7, 'right_thalamus': 8} + >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7}) + >>> # Merge right side labels with left side labels + >>> transformed = transform(subject) + >>> # Undesired behavior: The inverse transform will remap ALL left side labels to right side labels + >>> # so the label map only has right side labels. + >>> inverse_transformed = transformed.apply_inverse_transform() + >>> # Here's the *right* way to do it with masking: + >>> transform = tio.RemapLabels({2:1, 4:3, 6:5, 8:7}, masking_method="Right") + >>> # Remap the labels on the right side only (no difference yet). + >>> transformed = transform(subject) + >>> # Apply the inverse on the right side only. The labels are correctly split into left/right. + >>> inverse_transformed = transformed.apply_inverse_transform() + """ + def __init__( + self, + remapping: Dict[int, int], + masking_method: TypeMaskingMethod = None, + **kwargs + ): + super().__init__(**kwargs) + self.kwargs = kwargs + self.remapping = remapping + self.masking_method = masking_method + self.args_names = ('remapping', 'masking_method',) + + def apply_transform(self, subject): + images = subject.get_images( + intensity_only=False, + include=self.include, + exclude=self.exclude, + ) + for image in images: + if not isinstance(image, LabelMap): + continue + + new_data = image.data.clone() + mask = Transform.get_mask(self.masking_method, subject, new_data) + for old_id, new_id in self.remapping.items(): + new_data[mask & (image.data == old_id)] = new_id + image.data = new_data + + return subject + + def is_invertible(self): + return True + + def inverse(self): + inverse_remapping = {v: k for k, v in self.remapping.items()} + inverse_transform = RemapLabels( + inverse_remapping, + masking_method=self.masking_method, + **self.kwargs, + ) + return inverse_transform diff --git a/torchio/transforms/preprocessing/label/remove_labels.py b/torchio/transforms/preprocessing/label/remove_labels.py new file mode 100644 index 000000000..1d4e05127 --- /dev/null +++ b/torchio/transforms/preprocessing/label/remove_labels.py @@ -0,0 +1,35 @@ +from typing import Sequence + +from ...transform import TypeMaskingMethod +from .remap_labels import RemapLabels + + +class RemoveLabels(RemapLabels): + r"""Remove labels from a label map by remapping them to the background label. + + This transformation is not `invertible `_. + + Args: + labels: A sequence of label integers that will be removed. + background_label: integer that specifies which label is considered to be + background (generally 0). + masking_method: See :class:`~torchio.RemapLabels`. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + """ + def __init__( + self, + labels: Sequence[int], + background_label: int = 0, + masking_method: TypeMaskingMethod = None, + **kwargs + ): + remapping = {label: background_label for label in labels} + super().__init__(remapping, masking_method, **kwargs) + self.labels = labels + self.background_label = background_label + self.masking_method = masking_method + self.args_names = ('labels', 'background_label', 'masking_method',) + + def is_invertible(self): + return False diff --git a/torchio/transforms/preprocessing/label/sequential_labels.py b/torchio/transforms/preprocessing/label/sequential_labels.py new file mode 100644 index 000000000..2c100a6b4 --- /dev/null +++ b/torchio/transforms/preprocessing/label/sequential_labels.py @@ -0,0 +1,53 @@ +import torch + +from ....data import LabelMap +from ...transform import Transform, TypeMaskingMethod +from .remap_labels import RemapLabels + + +class SequentialLabels(Transform): + r"""Remap the integer IDs of labels in a LabelMap to be sequential. + + For example, if a label map has 6 labels with IDs (3, 5, 9, 15, 16, 23), + then this will apply a :class:`~torchio.RemapLabels` transform with + ``remapping={3: 1, 5: 2, 9: 3, 15: 4, 16: 5, 23: 6}``. + This transformation is always `fully invertible `_. + + Args: + masking_method: See + :class:`~torchio.RemapLabels`. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + """ + def __init__( + self, + masking_method: TypeMaskingMethod = None, + **kwargs + ): + super().__init__(**kwargs) + self.masking_method = masking_method + self.args_names = [] + + def apply_transform(self, subject): + images_dict = subject.get_images_dict( + intensity_only=False, + include=self.include, + exclude=self.exclude, + ) + for name, image in images_dict.items(): + if not isinstance(image, LabelMap): + continue + + unique_labels = torch.unique(image.data) + remapping = { + unique_labels[i].item(): i + for i in range(1, len(unique_labels)) + } + transform = RemapLabels( + remapping=remapping, + masking_method=self.masking_method, + include=name, + ) + subject = transform(subject) + + return subject diff --git a/torchio/transforms/preprocessing/spatial/bounds_transform.py b/torchio/transforms/preprocessing/spatial/bounds_transform.py index f8f3d6b5a..881587067 100644 --- a/torchio/transforms/preprocessing/spatial/bounds_transform.py +++ b/torchio/transforms/preprocessing/spatial/bounds_transform.py @@ -1,17 +1,7 @@ -from typing import Union, Tuple -import numpy as np -from ....typing import TypeTripletInt +from ....transforms.transform import TypeBounds from ... import SpatialTransform -TypeSixBounds = Tuple[int, int, int, int, int, int] -TypeBounds = Union[ - int, - TypeTripletInt, - TypeSixBounds, -] - - class BoundsTransform(SpatialTransform): """Base class for transforms that change image bounds. @@ -30,32 +20,3 @@ def __init__( def is_invertible(self): return True - - @staticmethod - def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds: - try: - bounds_parameters = tuple(bounds_parameters) - except TypeError: - bounds_parameters = (bounds_parameters,) - - # Check that numbers are integers - for number in bounds_parameters: - if not isinstance(number, (int, np.integer)) or number < 0: - message = ( - 'Bounds values must be integers greater or equal to zero,' - f' not "{bounds_parameters}" of type {type(number)}' - ) - raise ValueError(message) - bounds_parameters = tuple(int(n) for n in bounds_parameters) - bounds_parameters_length = len(bounds_parameters) - if bounds_parameters_length == 6: - return bounds_parameters - if bounds_parameters_length == 1: - return 6 * bounds_parameters - if bounds_parameters_length == 3: - return tuple(np.repeat(bounds_parameters, 2).tolist()) - message = ( - 'Bounds parameter must be an integer or a tuple of' - f' 3 or 6 integers, not {bounds_parameters}' - ) - raise ValueError(message) diff --git a/torchio/transforms/preprocessing/spatial/crop_or_pad.py b/torchio/transforms/preprocessing/spatial/crop_or_pad.py index 471c88789..d1af4ac4a 100644 --- a/torchio/transforms/preprocessing/spatial/crop_or_pad.py +++ b/torchio/transforms/preprocessing/spatial/crop_or_pad.py @@ -5,7 +5,8 @@ from .pad import Pad from .crop import Crop -from .bounds_transform import BoundsTransform, TypeTripletInt, TypeSixBounds +from .bounds_transform import BoundsTransform +from ...transform import TypeTripletInt, TypeSixBounds from ....data.subject import Subject from ....utils import round_up diff --git a/torchio/transforms/transform.py b/torchio/transforms/transform.py index 40af6f96b..384217f52 100644 --- a/torchio/transforms/transform.py +++ b/torchio/transforms/transform.py @@ -12,10 +12,19 @@ from ..utils import to_tuple from ..data.subject import Subject from ..data.io import nib_to_sitk, sitk_to_nib -from ..typing import TypeData, TypeNumber, TypeKeys +from ..data.image import LabelMap +from ..typing import TypeData, TypeNumber, TypeKeys, TypeCallable, TypeTripletInt from .interpolation import Interpolation, get_sitk_interpolator from .data_parser import DataParser, TypeTransformInput +TypeSixBounds = Tuple[int, int, int, int, int, int] +TypeBounds = Union[ + int, + TypeTripletInt, + TypeSixBounds, +] +TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None] + class Transform(ABC): """Abstract class for all TorchIO transforms. @@ -121,11 +130,13 @@ def apply_transform(self, subject: Subject): def add_transform_to_subject_history(self, subject): from .augmentation import RandomTransform from . import Compose, OneOf, CropOrPad + from .preprocessing.label import SequentialLabels call_others = ( RandomTransform, Compose, OneOf, CropOrPad, + SequentialLabels, ) if not isinstance(self, call_others): subject.add_transform(self, self._get_reproducing_arguments()) @@ -311,7 +322,7 @@ def _get_reproducing_arguments(self): Return a dictionary with the arguments that would be necessary to reproduce the transform exactly. """ - reproducing_arguments = dict(include=self.include, exclude=self.exclude, copy=self.copy) + reproducing_arguments = {'include': self.include, 'exclude': self.exclude, 'copy': self.copy} reproducing_arguments.update({name: getattr(self, name) for name in self.args_names}) return reproducing_arguments @@ -337,3 +348,98 @@ def _use_seed(seed): @staticmethod def get_sitk_interpolator(interpolation: str) -> int: return get_sitk_interpolator(interpolation) + + @staticmethod + def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds: + try: + bounds_parameters = tuple(bounds_parameters) + except TypeError: + bounds_parameters = (bounds_parameters,) + + # Check that numbers are integers + for number in bounds_parameters: + if not isinstance(number, (int, np.integer)) or number < 0: + message = ( + 'Bounds values must be integers greater or equal to zero,' + f' not "{bounds_parameters}" of type {type(number)}' + ) + raise ValueError(message) + bounds_parameters = tuple(int(n) for n in bounds_parameters) + bounds_parameters_length = len(bounds_parameters) + if bounds_parameters_length == 6: + return bounds_parameters + if bounds_parameters_length == 1: + return 6 * bounds_parameters + if bounds_parameters_length == 3: + return tuple(np.repeat(bounds_parameters, 2).tolist()) + message = ( + 'Bounds parameter must be an integer or a tuple of' + f' 3 or 6 integers, not {bounds_parameters}' + ) + raise ValueError(message) + + @staticmethod + def ones(tensor: torch.Tensor) -> torch.Tensor: + return torch.ones_like(tensor, dtype=torch.bool) + + @staticmethod + def mean(tensor: torch.Tensor) -> torch.Tensor: + mask = tensor > tensor.mean() + return mask + + @staticmethod + def get_mask(masking_method: TypeMaskingMethod, subject: Subject, tensor: torch.Tensor) -> torch.Tensor: + if masking_method is None: + return Transform.ones(tensor) + elif callable(masking_method): + return masking_method(tensor) + elif type(masking_method) is str: + if masking_method in subject and isinstance(subject[masking_method], LabelMap): + return subject[masking_method].data.bool() + if masking_method.title() in ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'): + return Transform.get_mask_from_anatomical_label(masking_method.title(), tensor) + elif type(masking_method) in (tuple, list, int): + return Transform.get_mask_from_bounds(masking_method, tensor) + message = ( + 'Masking method parameter must be a function, a label map name,' + " an anatomical label: ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior')," + ' or a bounds parameter: (an int, tuple of 3 ints, or tuple of 6 ints)' + f' not {masking_method} of type {type(masking_method)}' + ) + raise ValueError(message) + + @staticmethod + def get_mask_from_anatomical_label(anatomical_label: str, tensor: torch.Tensor) -> torch.Tensor: + anatomical_label = anatomical_label.title() + if anatomical_label.title() not in ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'): + message = ( + "Anatomical label must be one of ('Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior')" + f' not {anatomical_label}' + ) + raise ValueError(message) + mask = torch.zeros_like(tensor, dtype=torch.bool) + _, width, height, depth = tensor.shape + if anatomical_label == 'Right': + mask[:, width // 2:] = True + elif anatomical_label == 'Left': + mask[:, :width // 2] = True + elif anatomical_label == 'Anterior': + mask[:, :, height // 2:] = True + elif anatomical_label == 'Posterior': + mask[:, :, :height // 2] = True + elif anatomical_label == 'Superior': + mask[:, :, :, depth // 2:] = True + elif anatomical_label == 'Inferior': + mask[:, :, :, :depth // 2] = True + return mask + + @staticmethod + def get_mask_from_bounds(bounds_parameters: TypeBounds, tensor: torch.Tensor) -> torch.Tensor: + bounds_parameters = Transform.parse_bounds(bounds_parameters) + low = bounds_parameters[::2] + high = bounds_parameters[1::2] + i0, j0, k0 = low + i1, j1, k1 = np.array(tensor.shape[1:]) - high + mask = torch.zeros_like(tensor, dtype=torch.bool) + mask[:, i0:i1, j0:j1, k0:k1] = True + return mask diff --git a/torchio/utils.py b/torchio/utils.py index 7dd043009..776d07b7d 100644 --- a/torchio/utils.py +++ b/torchio/utils.py @@ -213,5 +213,5 @@ def history_collate(batch: Sequence, collate_transforms=True): def get_subclasses(target_class: type) -> List[type]: subclasses = target_class.__subclasses__() - subclasses += sum([get_subclasses(cls) for cls in subclasses], []) + subclasses += sum((get_subclasses(cls) for cls in subclasses), []) return subclasses