Skip to content

Commit

Permalink
Add label remapping transforms (#402)
Browse files Browse the repository at this point in the history
* Fix KeyError when retrieving history for a custom Transform.

* Preserve include and exclude lists for inverse transforms.

* Fix KeyError when retrieving history for a custom Transform.

Preserve include and exclude lists for inverse transforms.

* Add label transformations.

* Add minor edits

Co-authored-by: Fernando <fepegar@gmail.com>
  • Loading branch information
efirdc and fepegar committed Dec 29, 2020
1 parent b9ac52d commit 6aebda0
Show file tree
Hide file tree
Showing 20 changed files with 447 additions and 86 deletions.
26 changes: 26 additions & 0 deletions docs/source/transforms/preprocessing.rst
Expand Up @@ -85,3 +85,29 @@ Spatial

.. autoclass:: ToCanonical
:show-inheritance:

Label
---------

.. currentmodule:: torchio.transforms.preprocessing.label


:class:`RemapLabels`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RemapLabels
:show-inheritance:


:class:`RemoveLabels`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RemoveLabels
:show-inheritance:


:class:`SequentialLabels`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SequentialLabels
:show-inheritance:
Empty file.
17 changes: 17 additions & 0 deletions tests/transforms/label/test_remap_labels.py
@@ -0,0 +1,17 @@
from torchio.transforms import RemapLabels
from ...utils import TorchioTestCase


class TestRemapLabels(TorchioTestCase):
"""Tests for `RemapLabels`."""
def test_remap(self):
remapping = {1: 2, 2: 1, 5: 10, 6: 11}
remap_labels = RemapLabels(remapping=remapping)

subject = self.get_subject_with_labels(labels=remapping.keys())
transformed = remap_labels(subject)
inverse_transformed = transformed.apply_inverse_transform()

self.assertEqual(self.get_unique_labels(subject.label), set(remapping.keys()))
self.assertEqual(self.get_unique_labels(transformed.label), set(remapping.values()))
self.assertEqual(self.get_unique_labels(inverse_transformed.label), set(remapping.keys()))
28 changes: 28 additions & 0 deletions tests/transforms/label/test_remove_labels.py
@@ -0,0 +1,28 @@
from torchio.transforms import RemoveLabels
from ...utils import TorchioTestCase


class TestRemoveLabels(TorchioTestCase):
"""Tests for `RemoveLabels`."""
def test_remove(self):
initial_labels = (1, 2, 3, 4, 5, 6, 7)
labels_to_remove = (1, 2, 5, 6)
remaining_labels = (3, 4, 7)

remove_labels = RemoveLabels(labels_to_remove)

subject = self.get_subject_with_labels(labels=initial_labels)
transformed = remove_labels(subject)
inverse_transformed = transformed.apply_inverse_transform(warn=False)
self.assertEqual(
self.get_unique_labels(subject.label),
set(initial_labels),
)
self.assertEqual(
self.get_unique_labels(transformed.label),
set(remaining_labels),
)
self.assertEqual(
self.get_unique_labels(inverse_transformed.label),
set(remaining_labels),
)
19 changes: 19 additions & 0 deletions tests/transforms/label/test_sequential_labels.py
@@ -0,0 +1,19 @@
from torchio.transforms import SequentialLabels
from ...utils import TorchioTestCase


class TestSequentialLabels(TorchioTestCase):
"""Tests for `SequentialLabels`."""
def test_sequential(self):
initial_labels = (2, 8, 9, 10, 15, 20, 100)
transformed_labels = (1, 2, 3, 4, 5, 6, 7)

sequential_labels = SequentialLabels()

subject = self.get_subject_with_labels(labels=initial_labels)
transformed = sequential_labels(subject)
inverse_transformed = transformed.apply_inverse_transform()

self.assertEqual(self.get_unique_labels(subject.label), set(initial_labels))
self.assertEqual(self.get_unique_labels(transformed.label), set(transformed_labels))
self.assertEqual(self.get_unique_labels(inverse_transformed.label), set(initial_labels))
5 changes: 4 additions & 1 deletion tests/transforms/test_transforms.py
Expand Up @@ -44,6 +44,9 @@ def get_transform(self, channels, is_3d=True, labels=True):
tio.RandomAffine(): 3,
elastic: 1,
}),
tio.RemapLabels(remapping={1: 2, 2: 1, 3: 20, 4: 25}, masking_method='Left'),
tio.RemoveLabels([1, 3]),
tio.SequentialLabels(),
tio.Pad(pad_args, padding_mode=3),
tio.Crop(crop_args),
]
Expand Down Expand Up @@ -121,7 +124,7 @@ def test_transforms_subject_4d(self):
transformed = transform(subject)
trsf_channels = len(transformed.t1.data)
assert trsf_channels > 1, f'Lost channels in {transform.name}'
if transform.name != 'RandomLabelsToImage':
if transform.name not in ['RandomLabelsToImage', 'RemapLabels', 'RemoveLabels', 'SequentialLabels']:
self.assertEqual(
subject.shape[0],
transformed.shape[0],
Expand Down
25 changes: 24 additions & 1 deletion tests/utils.py
Expand Up @@ -119,6 +119,20 @@ def get_subject_with_partial_volume_label_map(self, components=1):
),
)

def get_subject_with_labels(self, labels):
return tio.Subject(
label=tio.LabelMap(
self.get_image_path(
'label_multi', labels=labels
)
)
)

def get_unique_labels(self, label_map):
labels = torch.unique(label_map.data)
labels = {i.item() for i in labels if i != 0}
return labels

def tearDown(self):
"""Tear down test fixtures, if any."""
shutil.rmtree(self.dir)
Expand All @@ -131,6 +145,7 @@ def get_image_path(
self,
stem,
binary=False,
labels=None,
shape=(10, 20, 30),
spacing=(1, 1, 1),
components=1,
Expand All @@ -144,6 +159,14 @@ def get_image_path(
data = (data > 0.5).astype(np.uint8)
if not data.sum() and force_binary_foreground:
data[..., 0] = 1
elif labels is not None:
data = (data * (len(labels) + 1)).astype(np.uint8)
new_data = np.zeros_like(data)
for i, label in enumerate(labels):
new_data[data == (i + 1)] = label
if not (new_data == label).sum():
new_data[..., i] = label
data = new_data
elif self.flip_coin(): # cast some images
data *= 100
dtype = np.uint8 if self.flip_coin() else np.uint16
Expand Down Expand Up @@ -171,7 +194,7 @@ def get_tests_data_dir(self):
return Path(__file__).parent / 'image_data'

def assertTensorNotEqual(self, *args, **kwargs): # noqa: N802
message_kwarg = dict(msg=args[2]) if len(args) == 3 else {}
message_kwarg = {'msg': args[2]} if len(args) == 3 else {}
with self.assertRaises(AssertionError, **message_kwarg):
self.assertTensorEqual(*args, **kwargs)

Expand Down
4 changes: 3 additions & 1 deletion torchio/data/subject.py
Expand Up @@ -128,7 +128,9 @@ def get_inverse_transform(self, warn=True) -> 'Transform':
return self.get_composed_history().inverse(warn=warn)

def apply_inverse_transform(self, warn=True) -> 'Subject':
return self.get_inverse_transform(warn=warn)(self)
transformed = self.get_inverse_transform(warn=warn)(self)
transformed.clear_history()
return transformed

def clear_history(self) -> None:
self.applied_transforms = []
Expand Down
6 changes: 6 additions & 0 deletions torchio/transforms/__init__.py
Expand Up @@ -35,6 +35,9 @@
from .preprocessing import RescaleIntensity
from .preprocessing import HistogramStandardization
from .preprocessing.intensity.histogram_standardization import train as train_histogram
from .preprocessing.label.remap_labels import RemapLabels
from .preprocessing.label.sequential_labels import SequentialLabels
from .preprocessing.label.remove_labels import RemoveLabels


__all__ = [
Expand Down Expand Up @@ -79,4 +82,7 @@
'RescaleIntensity',
'CropOrPad',
'train_histogram',
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
]
3 changes: 2 additions & 1 deletion torchio/transforms/augmentation/composition.py
Expand Up @@ -69,7 +69,8 @@ def inverse(self, warn: bool = True) -> Transform:
result = Compose(transforms)
else: # return noop if no invertible transforms are found
def result(x): return x # noqa: E704
warnings.warn('No invertible transforms found', RuntimeWarning)
if warn:
warnings.warn('No invertible transforms found', RuntimeWarning)
return result


Expand Down
7 changes: 7 additions & 0 deletions torchio/transforms/preprocessing/__init__.py
Expand Up @@ -8,6 +8,10 @@
from .intensity.z_normalization import ZNormalization
from .intensity.histogram_standardization import HistogramStandardization

from .label.remap_labels import RemapLabels
from .label.sequential_labels import SequentialLabels
from .label.remove_labels import RemoveLabels


__all__ = [
'Pad',
Expand All @@ -18,4 +22,7 @@
'RescaleIntensity',
'ZNormalization',
'HistogramStandardization',
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
]
@@ -1,24 +1,25 @@
from typing import Union
import torch
from ....data.subject import Subject
from ....typing import TypeCallable
from ....transforms.transform import Transform
from ....transforms.transform import TypeMaskingMethod
from ... import IntensityTransform


TypeMaskingMethod = Union[str, TypeCallable, None]


class NormalizationTransform(IntensityTransform):
"""Base class for intensity preprocessing transforms.
Args:
masking_method: Defines the mask used to compute the normalization statistics. It can be one of:
- ``None``: the mask image is all ones, i.e. all values in the image are used
- ``None``: the mask image is all ones, i.e. all values in the image are used.
- A string: key to a :class:`torchio.LabelMap` in the subject which is used as a mask,
OR an anatomical label: ``'Left'``, ``'Right'``, ``'Anterior'``, ``'Posterior'``,
``'Inferior'``, ``'Superior'`` which specifies a side of the mask volume to be ones.
- A string: the mask image is retrieved from the subject, which is expected the string as a key
- A function: the mask image is computed as a function of the intensity image.
The function must receive and return a :class:`torch.Tensor`
- A function: the mask image is computed as a function of the intensity image. The function must receive and return a :class:`torch.Tensor`
**kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments.
Example:
Expand All @@ -39,32 +40,12 @@ def __init__(
masking_method: TypeMaskingMethod = None,
**kwargs
):
"""
masking_method is used to choose the values used for normalization.
It can be:
- A string: the mask will be retrieved from the subject
- A function: the mask will be computed using the function
- None: all values are used
"""
super().__init__(**kwargs)
self.mask_name = None
self.masking_method = masking_method
if masking_method is None:
self.masking_method = self.ones
elif callable(masking_method):
self.masking_method = masking_method
elif isinstance(masking_method, str):
self.mask_name = masking_method

def get_mask(self, subject: Subject, tensor: torch.Tensor) -> torch.Tensor:
if self.mask_name is None:
return self.masking_method(tensor)
else:
return subject[self.mask_name].data.bool()

def apply_transform(self, subject: Subject) -> Subject:
for image_name, image in self.get_images_dict(subject).items():
mask = self.get_mask(subject, image.data)
mask = Transform.get_mask(self.masking_method, subject, image.data)
self.apply_normalization(subject, image_name, mask)
return subject

Expand All @@ -76,12 +57,3 @@ def apply_normalization(
) -> None:
# There must be a nicer way of doing this
raise NotImplementedError

@staticmethod
def ones(tensor: torch.Tensor) -> torch.Tensor:
return torch.ones_like(tensor, dtype=torch.bool)

@staticmethod
def mean(tensor: torch.Tensor) -> torch.Tensor:
mask = tensor > tensor.mean()
return mask
10 changes: 10 additions & 0 deletions torchio/transforms/preprocessing/label/__init__.py
@@ -0,0 +1,10 @@
from .remap_labels import RemapLabels
from .sequential_labels import SequentialLabels
from .remove_labels import RemoveLabels


__all__ = [
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
]

0 comments on commit 6aebda0

Please sign in to comment.