From ba5be714b6d51081cb17741cfb52bed90818a5e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Sun, 17 Jan 2021 23:30:53 +0000 Subject: [PATCH] Add OneHot transform (#420) * Add OneHot transform * Fix links in docstrings --- docs/source/transforms/preprocessing.rst | 10 +++++-- torchio/transforms/__init__.py | 10 ++++--- torchio/transforms/preprocessing/__init__.py | 6 ++-- .../preprocessing/label/__init__.py | 10 ------- .../transforms/preprocessing/label/one_hot.py | 30 +++++++++++++++++++ .../preprocessing/label/remove_labels.py | 2 +- .../preprocessing/label/sequential_labels.py | 3 +- torchio/transforms/transform.py | 2 +- 8 files changed, 51 insertions(+), 22 deletions(-) create mode 100644 torchio/transforms/preprocessing/label/one_hot.py diff --git a/docs/source/transforms/preprocessing.rst b/docs/source/transforms/preprocessing.rst index b34362dd9..2229b474b 100644 --- a/docs/source/transforms/preprocessing.rst +++ b/docs/source/transforms/preprocessing.rst @@ -89,11 +89,10 @@ Spatial .. autoclass:: ToCanonical :show-inheritance: + Label --------- -.. currentmodule:: torchio.transforms.preprocessing.label - :class:`RemapLabels` ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -114,3 +113,10 @@ Label .. autoclass:: SequentialLabels :show-inheritance: + + +:class:`OneHot` +~~~~~~~~~~~~~~~ + +.. autoclass:: OneHot + :show-inheritance: diff --git a/torchio/transforms/__init__.py b/torchio/transforms/__init__.py index 367a7670c..b89f6f2d8 100644 --- a/torchio/transforms/__init__.py +++ b/torchio/transforms/__init__.py @@ -36,9 +36,10 @@ from .preprocessing import EnsureShapeMultiple from .preprocessing import HistogramStandardization from .preprocessing.intensity.histogram_standardization import train_histogram -from .preprocessing.label.remap_labels import RemapLabels -from .preprocessing.label.sequential_labels import SequentialLabels -from .preprocessing.label.remove_labels import RemoveLabels +from .preprocessing import OneHot +from .preprocessing import RemapLabels +from .preprocessing import RemoveLabels +from .preprocessing import SequentialLabels __all__ = [ @@ -84,7 +85,8 @@ 'CropOrPad', 'EnsureShapeMultiple', 'train_histogram', + 'OneHot', 'RemapLabels', - 'SequentialLabels', 'RemoveLabels', + 'SequentialLabels', ] diff --git a/torchio/transforms/preprocessing/__init__.py b/torchio/transforms/preprocessing/__init__.py index 2c3ef9de5..dce8b7ad8 100644 --- a/torchio/transforms/preprocessing/__init__.py +++ b/torchio/transforms/preprocessing/__init__.py @@ -9,9 +9,10 @@ from .intensity.z_normalization import ZNormalization from .intensity.histogram_standardization import HistogramStandardization +from .label.one_hot import OneHot from .label.remap_labels import RemapLabels -from .label.sequential_labels import SequentialLabels from .label.remove_labels import RemoveLabels +from .label.sequential_labels import SequentialLabels __all__ = [ @@ -24,7 +25,8 @@ 'ZNormalization', 'RescaleIntensity', 'HistogramStandardization', + 'OneHot', 'RemapLabels', - 'SequentialLabels', 'RemoveLabels', + 'SequentialLabels', ] diff --git a/torchio/transforms/preprocessing/label/__init__.py b/torchio/transforms/preprocessing/label/__init__.py index ba4faee09..e69de29bb 100644 --- a/torchio/transforms/preprocessing/label/__init__.py +++ b/torchio/transforms/preprocessing/label/__init__.py @@ -1,10 +0,0 @@ -from .remap_labels import RemapLabels -from .sequential_labels import SequentialLabels -from .remove_labels import RemoveLabels - - -__all__ = [ - 'RemapLabels', - 'SequentialLabels', - 'RemoveLabels', -] diff --git a/torchio/transforms/preprocessing/label/one_hot.py b/torchio/transforms/preprocessing/label/one_hot.py new file mode 100644 index 000000000..af941fb65 --- /dev/null +++ b/torchio/transforms/preprocessing/label/one_hot.py @@ -0,0 +1,30 @@ +import torch.nn.functional as F # noqa: N812 + +from .label_transform import LabelTransform + + +class OneHot(LabelTransform): + r"""Reencode label maps using one-hot encoding. + + Args: + num_classes: See :func:`~torch.nn.functional.one_hot`. + **kwargs: See :class:`~torchio.transforms.Transform` for additional + keyword arguments. + """ + def __init__( + self, + num_classes: int = -1, + **kwargs + ): + super().__init__(**kwargs) + self.num_classes = num_classes + self.args_names = [] + + def apply_transform(self, subject): + for image in self.get_images(subject): + assert image.data.ndim == 4 and image.data.shape[0] == 1 + data = image.data.squeeze() + num_classes = -1 if self.num_classes is None else self.num_classes + one_hot = F.one_hot(data.long(), num_classes=num_classes) + image.set_data(one_hot.permute(3, 0, 1, 2).type(data.type())) + return subject diff --git a/torchio/transforms/preprocessing/label/remove_labels.py b/torchio/transforms/preprocessing/label/remove_labels.py index ccc5a0fdd..704275398 100644 --- a/torchio/transforms/preprocessing/label/remove_labels.py +++ b/torchio/transforms/preprocessing/label/remove_labels.py @@ -13,7 +13,7 @@ class RemoveLabels(RemapLabels): labels: A sequence of label integers that will be removed. background_label: integer that specifies which label is considered to be background (generally 0). - masking_method: See :class:`~torchio.RemapLabels`. + masking_method: See :class:`~torchio.transforms.RemapLabels`. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. """ diff --git a/torchio/transforms/preprocessing/label/sequential_labels.py b/torchio/transforms/preprocessing/label/sequential_labels.py index 906767eff..3b8b4f1ef 100644 --- a/torchio/transforms/preprocessing/label/sequential_labels.py +++ b/torchio/transforms/preprocessing/label/sequential_labels.py @@ -14,8 +14,7 @@ class SequentialLabels(LabelTransform): This transformation is always `fully invertible `_. Args: - masking_method: See - :class:`~torchio.RemapLabels`. + masking_method: See :class:`~torchio.transforms.RemapLabels`. **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments. """ diff --git a/torchio/transforms/transform.py b/torchio/transforms/transform.py index 533e3bcf7..054cffab6 100644 --- a/torchio/transforms/transform.py +++ b/torchio/transforms/transform.py @@ -148,7 +148,7 @@ def apply_transform(self, subject: Subject) -> Subject: def add_transform_to_subject_history(self, subject): from .augmentation import RandomTransform from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple - from .preprocessing.label import SequentialLabels + from .preprocessing import SequentialLabels call_others = ( RandomTransform, Compose,