Skip to content

Commit

Permalink
Merge 4a82804 into 05bac5e
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jan 15, 2021
2 parents 05bac5e + 4a82804 commit 7cb8859
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 22 deletions.
10 changes: 8 additions & 2 deletions docs/source/transforms/preprocessing.rst
Expand Up @@ -89,11 +89,10 @@ Spatial
.. autoclass:: ToCanonical
:show-inheritance:


Label
---------

.. currentmodule:: torchio.transforms.preprocessing.label


:class:`RemapLabels`
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -114,3 +113,10 @@ Label

.. autoclass:: SequentialLabels
:show-inheritance:


:class:`OneHot`
~~~~~~~~~~~~~~~

.. autoclass:: OneHot
:show-inheritance:
10 changes: 6 additions & 4 deletions torchio/transforms/__init__.py
Expand Up @@ -36,9 +36,10 @@
from .preprocessing import EnsureShapeMultiple
from .preprocessing import HistogramStandardization
from .preprocessing.intensity.histogram_standardization import train_histogram
from .preprocessing.label.remap_labels import RemapLabels
from .preprocessing.label.sequential_labels import SequentialLabels
from .preprocessing.label.remove_labels import RemoveLabels
from .preprocessing import OneHot
from .preprocessing import RemapLabels
from .preprocessing import RemoveLabels
from .preprocessing import SequentialLabels


__all__ = [
Expand Down Expand Up @@ -84,7 +85,8 @@
'CropOrPad',
'EnsureShapeMultiple',
'train_histogram',
'OneHot',
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
'SequentialLabels',
]
6 changes: 4 additions & 2 deletions torchio/transforms/preprocessing/__init__.py
Expand Up @@ -9,9 +9,10 @@
from .intensity.z_normalization import ZNormalization
from .intensity.histogram_standardization import HistogramStandardization

from .label.one_hot import OneHot
from .label.remap_labels import RemapLabels
from .label.sequential_labels import SequentialLabels
from .label.remove_labels import RemoveLabels
from .label.sequential_labels import SequentialLabels


__all__ = [
Expand All @@ -24,7 +25,8 @@
'ZNormalization',
'RescaleIntensity',
'HistogramStandardization',
'OneHot',
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
'SequentialLabels',
]
10 changes: 0 additions & 10 deletions torchio/transforms/preprocessing/label/__init__.py
@@ -1,10 +0,0 @@
from .remap_labels import RemapLabels
from .sequential_labels import SequentialLabels
from .remove_labels import RemoveLabels


__all__ = [
'RemapLabels',
'SequentialLabels',
'RemoveLabels',
]
30 changes: 30 additions & 0 deletions torchio/transforms/preprocessing/label/one_hot.py
@@ -0,0 +1,30 @@
import torch.nn.functional as F # noqa: N812

from .label_transform import LabelTransform


class OneHot(LabelTransform):
r"""Reencode label maps using one-hot encoding.
Args:
num_classes: See :func:`~torch.nn.functional.one_hot`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
def __init__(
self,
num_classes: int = -1,
**kwargs
):
super().__init__(**kwargs)
self.num_classes = num_classes
self.args_names = []

def apply_transform(self, subject):
for image in self.get_images(subject):
assert image.data.ndim == 4 and image.data.shape[0] == 1
data = image.data.squeeze()
num_classes = -1 if self.num_classes is None else self.num_classes
one_hot = F.one_hot(data.long(), num_classes=num_classes)
image.set_data(one_hot.permute(3, 0, 1, 2).type(data.type()))
return subject
2 changes: 1 addition & 1 deletion torchio/transforms/preprocessing/label/remove_labels.py
Expand Up @@ -13,7 +13,7 @@ class RemoveLabels(RemapLabels):
labels: A sequence of label integers that will be removed.
background_label: integer that specifies which label is considered to
be background (generally 0).
masking_method: See :class:`~torchio.RemapLabels`.
masking_method: See :class:`~torchio.transforms.RemapLabels`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
Expand Down
3 changes: 1 addition & 2 deletions torchio/transforms/preprocessing/label/sequential_labels.py
Expand Up @@ -14,8 +14,7 @@ class SequentialLabels(LabelTransform):
This transformation is always `fully invertible <invertibility>`_.
Args:
masking_method: See
:class:`~torchio.RemapLabels`.
masking_method: See :class:`~torchio.transforms.RemapLabels`.
**kwargs: See :class:`~torchio.transforms.Transform` for additional
keyword arguments.
"""
Expand Down
2 changes: 1 addition & 1 deletion torchio/transforms/transform.py
Expand Up @@ -157,7 +157,7 @@ def apply_transform(self, subject: Subject) -> Subject:
def add_transform_to_subject_history(self, subject):
from .augmentation import RandomTransform
from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
from .preprocessing.label import SequentialLabels
from .preprocessing import SequentialLabels
call_others = (
RandomTransform,
Compose,
Expand Down

0 comments on commit 7cb8859

Please sign in to comment.