diff --git a/docs/source/data/patch_based.rst b/docs/source/data/patch_based.rst index d8df9d582..2e69ad2fd 100644 --- a/docs/source/data/patch_based.rst +++ b/docs/source/data/patch_based.rst @@ -10,7 +10,7 @@ volumes for training and testing. .. toctree:: - :maxdepth: 2 + :maxdepth: 3 patch_training.rst patch_inference.rst diff --git a/docs/source/data/patch_training.rst b/docs/source/data/patch_training.rst index 44c81c3bc..e6edb27e8 100644 --- a/docs/source/data/patch_training.rst +++ b/docs/source/data/patch_training.rst @@ -1,11 +1,14 @@ Training ======== -Random samplers ---------------- +Patch samplers +-------------- + +Samplers are used to randomly extract patches from volumes. +They are called with a sample generated by an +:py:class:`~torchio.ImagesDataset` and return a Python generator that yields +cropped versions of the sample. -TorchIO includes grid, uniform and label patch samplers. There is also an -aggregator used for dense predictions. For more information about patch-based training, see `this NiftyNet tutorial `_. @@ -13,11 +16,24 @@ For more information about patch-based training, see .. currentmodule:: torchio.data -:class:`ImageSampler` +:class:`PatchSampler` ^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: ImageSampler - :members: +.. autoclass:: PatchSampler + :show-inheritance: + + +:class:`WeightedSampler` +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: WeightedSampler + :show-inheritance: + + +:class:`UniformSampler` +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: UniformSampler :show-inheritance: @@ -25,22 +41,12 @@ For more information about patch-based training, see ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: LabelSampler - :members: :show-inheritance: Queue ----- -In the following animation, :attr:`shuffle_subjects` is ``False`` -and :attr:`shuffle_patches` is ``True``. - -.. raw:: html - - - - - .. 
currentmodule:: torchio.data diff --git a/examples/example_heteromodal.py b/examples/example_heteromodal.py index 7c1d9ed6c..898451372 100644 --- a/examples/example_heteromodal.py +++ b/examples/example_heteromodal.py @@ -12,7 +12,7 @@ import torchio from torchio import Image, Subject, ImagesDataset, Queue -from torchio.data import ImageSampler +from torchio.data import UniformSampler def main(): # Define training and patches sampling parameters @@ -45,8 +45,7 @@ def main(): subjects_dataset, queue_length, samples_per_volume, - patch_size, - ImageSampler, + UniformSampler(patch_size), ) # This collate_fn is needed in the case of missing modalities diff --git a/tests/data/inference/__init__.py b/tests/data/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/test_grid_sampler.py b/tests/data/inference/test_grid_sampler.py similarity index 96% rename from tests/data/test_grid_sampler.py rename to tests/data/inference/test_grid_sampler.py index 261dec831..b0e4dbcd9 100644 --- a/tests/data/test_grid_sampler.py +++ b/tests/data/inference/test_grid_sampler.py @@ -1,7 +1,8 @@ #!/usr/bin/env python -from ..utils import TorchioTestCase from torchio.data import GridSampler +from ...utils import TorchioTestCase + class TestGridSampler(TorchioTestCase): """Tests for `GridSampler`.""" diff --git a/tests/data/sampler/__init__.py b/tests/data/sampler/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/sampler/test_label_sampler.py b/tests/data/sampler/test_label_sampler.py new file mode 100644 index 000000000..b9d398ae9 --- /dev/null +++ b/tests/data/sampler/test_label_sampler.py @@ -0,0 +1,26 @@ +import torch +import torchio +from torchio.data import LabelSampler +from ...utils import TorchioTestCase + + +class TestLabelSampler(TorchioTestCase): + """Tests for `LabelSampler` class.""" + + def test_label_sampler(self): + sampler = LabelSampler(5) + for patch in sampler(self.sample, num_patches=10): + 
patch_center = patch['label'][torchio.DATA][0, 2, 2, 2] + self.assertEqual(patch_center, 1) + + def test_label_probabilities(self): + labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, -1) + subject = torchio.Subject( + label=torchio.Image(tensor=labels, type=torchio.LABEL), + ) + sample = torchio.ImagesDataset([subject])[0] + probs_dict = {0: 0, 1: 50, 2: 25, 3: 25} + sampler = LabelSampler(5, 'label', label_probabilities=probs_dict) + probabilities = sampler.get_probability_map(sample) + fixture = torch.Tensor((0, 0, 2/12, 2/12, 3/12, 2/12, 0)) + assert torch.all(probabilities.squeeze().eq(fixture)) diff --git a/tests/data/sampler/test_weighted_sampler.py b/tests/data/sampler/test_weighted_sampler.py new file mode 100644 index 000000000..154d0025b --- /dev/null +++ b/tests/data/sampler/test_weighted_sampler.py @@ -0,0 +1,25 @@ +import torch +import torchio +from torchio.data import WeightedSampler +from ...utils import TorchioTestCase + + +class TestWeightedSampler(TorchioTestCase): + """Tests for `WeightedSampler` class.""" + + def test_weighted_sampler(self): + sample = self.get_sample((7, 7, 7)) + sampler = WeightedSampler(5, 'prob') + patch = next(iter(sampler(sample))) + self.assertEqual(tuple(patch['index_ini']), (1, 1, 1)) + + def get_sample(self, image_shape): + t1 = torch.rand(*image_shape) + prob = torch.zeros_like(t1) + prob[3, 3, 3] = 1 + subject = torchio.Subject( + t1=torchio.Image(tensor=t1), + prob=torchio.Image(tensor=prob), + ) + sample = torchio.ImagesDataset([subject])[0] + return sample diff --git a/tests/data/test_label_sampler.py b/tests/data/test_label_sampler.py deleted file mode 100644 index 801a27a76..000000000 --- a/tests/data/test_label_sampler.py +++ /dev/null @@ -1,13 +0,0 @@ -import tempfile -import unittest -from pathlib import Path -from ..utils import TorchioTestCase -from torchio.data import LabelSampler - - -class TestLabelSampler(TorchioTestCase): - """Tests for `LabelSampler` class.""" - - def 
test_label_sampler(self): - sampler = LabelSampler(self.sample, 5) - next(iter(sampler)) diff --git a/tests/data/test_queue.py b/tests/data/test_queue.py index 35aa9eadc..f5f077aba 100644 --- a/tests/data/test_queue.py +++ b/tests/data/test_queue.py @@ -1,5 +1,5 @@ from torch.utils.data import DataLoader -from torchio.data import ImageSampler +from torchio.data import UniformSampler from torchio import ImagesDataset, Queue, DATA from torchio.utils import create_dummy_dataset from ..utils import TorchioTestCase @@ -19,12 +19,13 @@ def setUp(self): def test_queue(self): subjects_dataset = ImagesDataset(self.subjects_list) + patch_size = 10 + sampler = UniformSampler(patch_size) queue_dataset = Queue( subjects_dataset, max_length=6, samples_per_volume=2, - patch_size=10, - sampler_class=ImageSampler, + sampler=sampler, num_workers=0, verbose=True, ) diff --git a/torchio/data/__init__.py b/torchio/data/__init__.py index 2c73fe180..968aa39cb 100644 --- a/torchio/data/__init__.py +++ b/torchio/data/__init__.py @@ -2,5 +2,5 @@ from .image import Image from .subject import Subject from .dataset import ImagesDataset -from .sampler import ImageSampler, LabelSampler from .inference import GridSampler, GridAggregator +from .sampler import PatchSampler, LabelSampler, WeightedSampler, UniformSampler diff --git a/torchio/data/queue.py b/torchio/data/queue.py index a3e13093e..adf94b443 100644 --- a/torchio/data/queue.py +++ b/torchio/data/queue.py @@ -1,12 +1,13 @@ import random import warnings -from typing import List, Iterator from itertools import islice +from typing import List, Iterator + from tqdm import trange from torch.utils.data import Dataset, DataLoader -from .. import TypeTuple + +from .sampler import PatchSampler from .dataset import ImagesDataset -from .sampler import ImageSampler class Queue(Dataset): @@ -17,16 +18,11 @@ class Queue(Dataset): :class:`~torchio.data.dataset.ImagesDataset`. max_length: Maximum number of patches that can be stored in the queue. 
Using a large number means that the queue needs to be filled less - often, but more RAM is needed to store the patches. + often, but more CPU memory is needed to store the patches. samples_per_volume: Number of patches to extract from each volume. A small number of patches ensures a large variability in the queue, but training will be slower. - patch_size: Tuple of integers :math:`(d, h, w)` to generate patches - of size :math:`d \times h \times w`. - If a single number :math:`n` is provided, - :math:`d = h = w = n`. - sampler_class: An instance of :class:`~torchio.data.datasetampler` used - to define the patches sampling strategy. + sampler: A sampler used to extract patches from the volumes. num_workers: Number of subprocesses to use for data loading (as in :class:`torch.utils.data.DataLoader`). ``0`` means that the data will be loaded in the main process. @@ -37,6 +33,16 @@ class Queue(Dataset): queue. verbose: If ``True``, some debugging messages are printed. + This sketch can be used to experiment and understand how the queue works. + In this case, :attr:`shuffle_subjects` is ``False`` + and :attr:`shuffle_patches` is ``True``. + + .. raw:: html + + + + + .. note:: :attr:`num_workers` refers to the number of workers used to load and transform the volumes. Multiprocessing is not needed to pop patches from the queue. @@ -50,7 +56,7 @@ class Queue(Dataset): ... max_length=300, ... samples_per_volume=10, ... patch_size=96, - ... sampler_class=torchio.sampler.ImageSampler, + ... sampler=, ... num_workers=4, ... shuffle_subjects=True, ... 
shuffle_patches=True, @@ -69,8 +75,7 @@ def __init__( subjects_dataset: ImagesDataset, max_length: int, samples_per_volume: int, - patch_size: TypeTuple, - sampler_class: ImageSampler, + sampler: PatchSampler, num_workers: int = 0, shuffle_subjects: bool = True, shuffle_patches: bool = True, @@ -81,8 +86,7 @@ def __init__( self.shuffle_subjects = shuffle_subjects self.shuffle_patches = shuffle_patches self.samples_per_volume = samples_per_volume - self.sampler_class = sampler_class - self.patch_size = patch_size + self.sampler = sampler self.num_workers = num_workers self.verbose = verbose self.subjects_iterable = self.get_subjects_iterable() @@ -130,8 +134,7 @@ def iterations_per_epoch(self) -> int: return self.num_subjects * self.samples_per_volume def fill(self) -> None: - assert self.sampler_class is not None - assert self.patch_size is not None + assert self.sampler is not None if self.max_length % self.samples_per_volume != 0: message = ( f'Queue length ({self.max_length})' @@ -153,9 +156,9 @@ def fill(self) -> None: iterable = range(num_subjects_for_queue) for _ in iterable: subject_sample = self.get_next_subject_sample() - sampler = self.sampler_class(subject_sample, self.patch_size) - samples = list(islice(sampler, self.samples_per_volume)) - self.patches_list.extend(samples) + iterable = self.sampler(subject_sample) + patches = list(islice(iterable, self.samples_per_volume)) + self.patches_list.extend(patches) if self.shuffle_patches: random.shuffle(self.patches_list) diff --git a/torchio/data/sampler/__init__.py b/torchio/data/sampler/__init__.py index 5f43f4a88..9c1b17504 100644 --- a/torchio/data/sampler/__init__.py +++ b/torchio/data/sampler/__init__.py @@ -1,2 +1,4 @@ -from .sampler import ImageSampler from .label import LabelSampler +from .sampler import PatchSampler +from .uniform import UniformSampler +from .weighted import WeightedSampler diff --git a/torchio/data/sampler/label.py b/torchio/data/sampler/label.py index edb8b6dd6..365d4e2d6 100644 
--- a/torchio/data/sampler/label.py +++ b/torchio/data/sampler/label.py @@ -1,51 +1,95 @@ -from .sampler import ImageSampler, crop -from ... import DATA, LABEL, TYPE -from ..subject import Subject +from typing import Dict, Optional +import torch -class LabelSampler(ImageSampler): - r"""Extract random patches containing labeled voxels. +from ...data.subject import Subject +from ...torchio import TypePatchSize, DATA, TYPE, LABEL +from .weighted import WeightedSampler - This iterable dataset yields patches that contain at least one voxel - without background. - It extracts the label data from the first image in the sample with type - :py:attr:`torchio.LABEL`. +class LabelSampler(WeightedSampler): + r"""Extract random patches with labeled voxels at their center. + + This sampler yields patches whose center value is greater than 0 + in the :py:attr:`label_name`. Args: - sample: Sample generated by a - :py:class:`~torchio.data.dataset.ImagesDataset`, from which image - patches will be extracted. - patch_size: Tuple of integers :math:`(d, h, w)` to generate patches - of size :math:`d \times h \times w`. - If a single number :math:`n` is provided, - :math:`d = h = w = n`. - - .. warning:: For now, this implementation is not efficient because it uses - brute force to look for foreground voxels. It the number of - non-background voxels is very small, this sampler will be slow. + patch_size: See :py:class:`~torchio.data.PatchSampler`. + label_name: Name of the label image in the sample that will be used to + generate the sampling probability map. If ``None``, the first image + of type :py:attr:`torchio.LABEL` found in the subject sample will be + used. + label_probabilities: Dictionary containing the probability that each + class will be sampled. Probabilities do not need to be normalized. 
+ For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a + sampler whose patch centers will have 50% probability of being + labeled as ``1``, 25% of being ``2`` and 25% of being ``3``. + If ``None``, the label map is binarized and the value is set to + ``{0: 0, 1: 1}``. + + Example: + >>> import torchio + >>> subject = torchio.datasets.Colin27() + >>> subject + Colin27(Keys: ('t1', 'head', 'brain'); images: 3) + >>> sample = torchio.ImagesDataset([subject])[0] + >>> sampler = torchio.data.LabelSampler(64, 'brain') + >>> generator = sampler(sample) + >>> for patch in generator: + ... print(patch.shape) + + If you want a specific number of patches from a volume, e.g. 10: + + >>> generator = sampler(sample, num_patches=10) + >>> for patch in generator: + ... print(patch.shape) + """ - @staticmethod - def get_first_label_image_dict(sample: Subject): - for image_dict in sample.get_images(intensity_only=False): - if image_dict[TYPE] == LABEL: - label_image_dict = image_dict - break + def __init__( + self, + patch_size: TypePatchSize, + label_name: Optional[str] = None, + label_probabilities: Optional[Dict[int, float]] = None, + ): + super().__init__(patch_size, probability_map=label_name) + self.label_probabilities_dict = label_probabilities + + def get_probability_map(self, sample: Subject) -> torch.Tensor: + if self.probability_map_name is None: + for image in sample.get_images(intensity_only=False): + if image[TYPE] == LABEL: + label_map_tensor = image[DATA] + break + elif self.probability_map_name in sample: + label_map_tensor = sample[self.probability_map_name][DATA] else: - raise ValueError('No images of type torchio.LABEL found in sample') - return label_image_dict - - def extract_patch(self): - has_label = False - label_image_data = self.get_first_label_image_dict(self.sample)[DATA] - while not has_label: - index_ini, index_fin = self.get_random_indices( - self.sample, self.patch_size) - patch_label = crop(label_image_data, index_ini, index_fin) - 
has_label = patch_label.sum() > 0 - cropped_sample = self.copy_and_crop( - self.sample, - index_ini, - index_fin, + message = ( + f'Image "{self.probability_map_name}"' + f' not found in subject sample: {sample}' + ) + raise KeyError(message) + if self.label_probabilities_dict is None: + return label_map_tensor > 0 + probability_map = self.get_probabilities_from_label_map( + label_map_tensor, + self.label_probabilities_dict, ) - return cropped_sample + return probability_map + + @staticmethod + def get_probabilities_from_label_map( + label_map: torch.Tensor, + label_probabilities_dict: Dict[int, float], + ) -> torch.Tensor: + """Create probability map according to label map probabilities.""" + probability_map = torch.zeros_like(label_map) + label_probs = torch.Tensor(list(label_probabilities_dict.values())) + normalized_probs = label_probs / label_probs.sum() + iterable = zip(label_probabilities_dict, normalized_probs) + for label, label_probability in iterable: + mask = label_map == label + label_size = mask.sum() + if not label_size: continue + prob_voxels = label_probability / label_size + probability_map[mask] = prob_voxels + return probability_map diff --git a/torchio/data/sampler/sampler.py b/torchio/data/sampler/sampler.py index 4572267a7..3fecb07cc 100644 --- a/torchio/data/sampler/sampler.py +++ b/torchio/data/sampler/sampler.py @@ -1,105 +1,57 @@ -import copy -from typing import Union, Sequence, Generator, Tuple +from typing import Tuple, Optional, Generator import numpy as np -import torch -from torch.utils.data import IterableDataset -from ...torchio import DATA +from ... import TypePatchSize +from ...data.subject import Subject from ...utils import to_tuple -from ..subject import Subject -class ImageSampler(IterableDataset): - r"""Extract random patches from a volume. +class PatchSampler: + r"""Base class for TorchIO samplers. 
Args: - sample: Sample generated by a - :py:class:`~torchio.data.dataset.ImagesDataset`, from which image - patches will be extracted. patch_size: Tuple of integers :math:`(d, h, w)` to generate patches of size :math:`d \times h \times w`. If a single number :math:`n` is provided, :math:`d = h = w = n`. """ - def __init__(self, sample: Subject, patch_size: Union[int, Sequence[int]]): - self.sample = sample - patch_size = to_tuple(patch_size, length=3) - self.patch_size = np.array(patch_size, dtype=np.uint16) - - def __iter__(self) -> Generator[Subject, None, None]: - while True: - yield self.extract_patch() - - def extract_patch(self) -> Subject: - index_ini, index_fin = self.get_random_indices( - self.sample, self.patch_size) - cropped_sample = self.copy_and_crop( - self.sample, - index_ini, - index_fin, - ) - return cropped_sample - - @staticmethod - def get_random_indices(sample: Subject, patch_size: Tuple[int, int, int]): - # Assume all images in sample have the same shape - sample.check_consistent_shape() - first_image_name = list(sample.keys())[0] - first_image_array = sample[first_image_name][DATA] - # first_image_array should have shape (1, H, W, D) - shape = np.array(first_image_array.shape[1:], dtype=np.uint16) - return get_random_indices_from_shape(shape, patch_size) - - @staticmethod - def copy_and_crop( + def __init__(self, patch_size: TypePatchSize): + patch_size_array = np.array(to_tuple(patch_size, length=3)) + if np.any(patch_size_array < 1): + message = ( + 'Patch dimensions must be positive integers,' + f' not {patch_size_array}' + ) + raise ValueError(message) + self.patch_size = patch_size_array.astype(np.uint16) + + def __call__( + self, sample: Subject, - index_ini: np.ndarray, - index_fin: np.ndarray, - ) -> dict: - cropped_sample = copy.deepcopy(sample) - iterable = sample.get_images_dict(intensity_only=False).items() - for image_name, image in iterable: - cropped_sample[image_name] = copy.deepcopy(image) - sample_image_dict = image - 
cropped_image_dict = cropped_sample[image_name] - cropped_image_dict[DATA] = crop( - sample_image_dict[DATA], index_ini, index_fin) - # torch doesn't like uint16 - cropped_sample['index_ini'] = index_ini.astype(int) - return cropped_sample - + num_patches: Optional[int] = None, + ) -> Generator[Subject, None, None]: + raise NotImplementedError -def crop( - image: Union[np.ndarray, torch.Tensor], - index_ini: np.ndarray, - index_fin: np.ndarray, - ) -> Union[np.ndarray, torch.Tensor]: - i_ini, j_ini, k_ini = index_ini - i_fin, j_fin, k_fin = index_fin - return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] + def get_probability_map(self, sample: Subject): + raise NotImplementedError + def extract_patch(self): + raise NotImplementedError -def get_random_indices_from_shape( - shape: Tuple[int, int, int], - patch_size: Tuple[int, int, int], - ) -> Tuple[np.ndarray, np.ndarray]: - shape_array = np.array(shape) - patch_size_array = np.array(patch_size) - max_index_ini = shape_array - patch_size_array - if (max_index_ini < 0).any(): - message = ( - f'Patch size {patch_size} must not be' - f' larger than image size {shape}' - ) - raise ValueError(message) - max_index_ini = max_index_ini.astype(np.uint16) - coordinates = [] - for max_coordinate in max_index_ini.tolist(): - if max_coordinate == 0: - coordinate = 0 - else: - coordinate = torch.randint(max_coordinate, size=(1,)).item() - coordinates.append(coordinate) - index_ini = np.array(coordinates, np.uint16) - index_fin = index_ini + patch_size_array - return index_ini, index_fin + @staticmethod + def get_crop_transform( + sample, + index_ini, + patch_size: TypePatchSize, + ): + from ...transforms.preprocessing.spatial.crop import Crop + shape = np.array(sample.spatial_shape, dtype=np.uint16) + index_ini = np.array(index_ini, dtype=np.uint16) + patch_size = np.array(patch_size, dtype=np.uint16) + index_fin = index_ini + patch_size + crop_ini = index_ini.tolist() + crop_fin = (shape - index_fin).tolist() + TypeBounds = 
Tuple[int, int, int, int, int, int] + start = () + cropping: TypeBounds = sum(zip(crop_ini, crop_fin), start) + return Crop(cropping) diff --git a/torchio/data/sampler/uniform.py b/torchio/data/sampler/uniform.py new file mode 100644 index 000000000..e47f3b746 --- /dev/null +++ b/torchio/data/sampler/uniform.py @@ -0,0 +1,17 @@ +import torch +from ...data.subject import Subject +from ...torchio import TypePatchSize +from .weighted import WeightedSampler + + +class UniformSampler(WeightedSampler): + """Randomly extract patches from a volume with uniform probability. + + Args: + patch_size: See :py:class:`~torchio.data.PatchSampler`. + """ + def __init__(self, patch_size: TypePatchSize): + super().__init__(patch_size) + + def get_probability_map(self, sample: Subject) -> torch.Tensor: + return torch.ones(sample.shape) diff --git a/torchio/data/sampler/weighted.py b/torchio/data/sampler/weighted.py new file mode 100644 index 000000000..46ae59373 --- /dev/null +++ b/torchio/data/sampler/weighted.py @@ -0,0 +1,267 @@ +from typing import Optional, Tuple, Generator + +import numpy as np + +import torch + +from ...torchio import TypePatchSize +from ..subject import Subject +from .sampler import PatchSampler + + + +class WeightedSampler(PatchSampler): + r"""Randomly extract patches from a volume given a probability map. + + The probability of sampling a patch centered on a specific voxel is the + value of that voxel in the probability map. The probabilities need not be + normalized. For example, voxels can have values 0, 1 and 5. Voxels with + value 0 will never be at the center of a patch. Voxels with value 5 will + have 5 times more chance of being at the center of a patch than voxels + with a value of 1. + + Args: + sample: Sample generated by a + :py:class:`~torchio.data.dataset.ImagesDataset`, from which image + patches will be extracted. + patch_size: See :py:class:`~torchio.data.PatchSampler`. 
+ probability_map: Name of the image in the sample that will be used + as a probability map. + + Raises: + RuntimeError: If the probability map is empty. + + Example: + >>> import torchio + >>> subject = torchio.Subject( + ... t1=torchio.Image('t1_mri.nii.gz', type=torchio.INTENSITY), + ... sampling_map=torchio.Image('sampling.nii.gz', type=torchio.SAMPLING_MAP), + ... ) + >>> sample = torchio.ImagesDataset([subject])[0] + >>> patch_size = 64 + >>> sampler = torchio.data.WeightedSampler(patch_size, probability_map='sampling_map') + >>> for patch in sampler(sample): + ... print(patch['index_ini']) + + .. note:: The index of the center of a patch with even size :math:`s` is + arbitrarily set to :math:`s/2`. This is an implementation detail that + will typically not make any difference in practice. + + .. note:: Values of the probability map near the border will be set to 0 as + the center of the patch cannot be at the border (unless the patch has + size 1 or 2 along that axis). + + """ + def __init__( + self, + patch_size: TypePatchSize, + probability_map: Optional[str] = None, + ): + super().__init__(patch_size) + self.probability_map_name = probability_map + self.cdf = None + self.sort_indices = None + + def __call__( + self, + sample: Subject, + num_patches: Optional[int] = None, + ) -> Generator[Subject, None, None]: + sample.check_consistent_shape() + if np.any(self.patch_size > sample.spatial_shape): + message = ( + f'Patch size {tuple(self.patch_size)} cannot be' + f' larger than image size {tuple(sample.spatial_shape)}' + ) + raise RuntimeError(message) + probability_map = self.get_probability_map(sample) + probability_map = self.process_probability_map(probability_map) + cdf, sort_indices = self.get_cumulative_distribution_function( + probability_map) + + patches_left = num_patches if num_patches is not None else True + while patches_left: + yield self.extract_patch(sample, probability_map, cdf, sort_indices) + if num_patches is not None: + patches_left -= 1 
+ + def get_probability_map(self, sample: Subject) -> torch.Tensor: + if self.probability_map_name in sample: + data = sample[self.probability_map_name].data + else: + message = ( + f'Image "{self.probability_map_name}"' + f' not found in subject sample: {sample}' + ) + raise KeyError(message) + if torch.any(data < 0): + message = ( + 'Negative values found' + f' in probability map "{self.probability_map_name}"' + ) + raise ValueError(message) + return data + + def process_probability_map( + self, + probability_map: torch.Tensor, + ) -> np.ndarray: + # Using float32 can create cdf with maximum very far from 1, e.g. 0.92! + data = probability_map[0].numpy().astype(np.float64) + assert data.ndim == 3 + self.clear_probability_borders(data, self.patch_size) + total = data.sum() + if total == 0: + message = ( + 'Empty probability map found' + f' ({self.probability_map_name})' + ) + raise RuntimeError(message) + data /= total # normalize probabilities + return data + + @staticmethod + def clear_probability_borders( + probability_map: np.ndarray, + patch_size: TypePatchSize, + ) -> None: + # Set probability to 0 on voxels that wouldn't possibly be sampled given + # the current patch size + # We will arbitrarily define the center of an array with even length + # using the // Python operator + # For example, the center of an array (3, 4) will be on (1, 2) + # + # Patch center + # . . . . . . . . + # . . . . -> . . x . + # . . . . . . . . + # + # + # Prob. map After preprocessing + # + # x x x x x x x . . . . . . . + # x x x x x x x . . x x x x . + # x x x x x x x --> . . x x x x . + # x x x x x x x --> . . x x x x . + # x x x x x x x . . x x x x . + # x x x x x x x . . . . . . . 
+ # + # The dots represent removed probabilities, x mark possible locations + crop_ini = patch_size // 2 + crop_fin = (patch_size - 1) // 2 + crop_i, crop_j, crop_k = crop_ini + probability_map[:crop_i, :, :] = 0 + probability_map[:, :crop_j, :] = 0 + probability_map[:, :, :crop_k] = 0 + + # The call tolist() is very important. Using np.uint16 as negative index + # will not work because e.g. -np.uint16(2) == 65534 + crop_i, crop_j, crop_k = crop_fin.tolist() + if crop_i: + probability_map[-crop_i:, :, :] = 0 + if crop_j: + probability_map[:, -crop_j:, :] = 0 + if crop_k: + probability_map[:, :, -crop_k:] = 0 + + @staticmethod + def get_cumulative_distribution_function( + probability_map: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """Return the CDF of a probability map. + + The cumulative distribution function (CDF) is computed as follows: + + 1. Flatten probability map + 2. Compute sorting indices + 3. Sort flattened map + 4. Compute cumulative sum + + For example, + if the probability map is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0], + the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4], + the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5], + and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0]. 
+ """ + flat_map = probability_map.flatten() + flat_map_normalized = flat_map / flat_map.sum() + # Get the sorting indices so that we can invert the sorting later on + sort_indices = np.argsort(flat_map_normalized) + flat_map_normalized_sorted = flat_map_normalized[sort_indices] + cdf = np.cumsum(flat_map_normalized_sorted) + return cdf, sort_indices + + def extract_patch( + self, + sample: Subject, + probability_map: np.ndarray, + cdf: np.ndarray, + sort_indices: np.ndarray, + ) -> Subject: + index_ini = self.get_random_index_ini(probability_map, cdf, sort_indices) + crop = self.get_crop_transform( + sample, + index_ini, + self.patch_size, + ) + cropped_sample = crop(sample) + cropped_sample['index_ini'] = index_ini.astype(int) + return cropped_sample + + def get_random_index_ini( + self, + probability_map: np.ndarray, + cdf: np.ndarray, + sort_indices: np.ndarray, + ) -> np.ndarray: + center = self.sample_probability_map(probability_map, cdf, sort_indices) + assert np.all(center >= 0) + # See self.clear_probability_borders + index_ini = center - self.patch_size // 2 + assert np.all(index_ini >= 0) + return index_ini + + def sample_probability_map( + self, + probability_map: np.ndarray, + cdf: np.ndarray, + sort_indices: np.ndarray, + ) -> np.ndarray: + """Inverse transform sampling. + + Example: + >>> probability_map = np.array( + ... ((0,0,1,1,5,2,1,1,0), + ... (2,2,2,2,2,2,2,2,2))) + >>> probability_map + array([[0, 0, 1, 1, 5, 2, 1, 1, 0], + [2, 2, 2, 2, 2, 2, 2, 2, 2]]) + >>> histogram = np.zeros_like(probability_map) + >>> for _ in range(100000): + ... histogram[sample_probability_map(probability_map)] += 1 + ... + >>> histogram + array([[ 0, 0, 3479, 3478, 17121, 7023, 3355, 3378, 0], + [ 6808, 6804, 6942, 6809, 6946, 6988, 7002, 6826, 7041]]) + + """ + # Get first value larger than random number + random_number = torch.rand(1).item() + # If probability map is float32, cdf.max() can be far from 1, e.g. 
0.92 + if random_number > cdf.max(): + cdf_index = -1 + else: # proceed as usual + cdf_index = np.argmax(random_number < cdf) + + random_location_index = sort_indices[cdf_index] + center = np.unravel_index( + random_location_index, + probability_map.shape + ) + + i, j, k = center + probability = probability_map[i, j, k] + assert probability > 0 + + center = np.array(center).astype(int) + return center diff --git a/torchio/torchio.py b/torchio/torchio.py index a6c2512fd..20b63e548 100644 --- a/torchio/torchio.py +++ b/torchio/torchio.py @@ -30,5 +30,6 @@ TypeTripletFloat = Tuple[float, float, float] TypeTuple = Union[int, TypeTripletInt] TypeRangeInt = Union[int, Tuple[int, int]] +TypePatchSize = Union[int, Tuple[int, int, int]] TypeRangeFloat = Union[float, Tuple[float, float]] TypeCallable = Callable[[torch.Tensor], torch.Tensor] diff --git a/torchio/transforms/augmentation/intensity/random_swap.py b/torchio/transforms/augmentation/intensity/random_swap.py index db87cfd01..c58790805 100644 --- a/torchio/transforms/augmentation/intensity/random_swap.py +++ b/torchio/transforms/augmentation/intensity/random_swap.py @@ -1,10 +1,9 @@ -from typing import Optional +from typing import Optional, Tuple, Union import torch import numpy as np from ....data.subject import Subject from ....utils import to_tuple from ....torchio import DATA, TypeTuple, TypeData -from ....data.sampler.sampler import get_random_indices_from_shape, crop from .. 
def crop(
        image: Union[np.ndarray, torch.Tensor],
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> Union[np.ndarray, torch.Tensor]:
    """Slice the last three dimensions of ``image``.

    Args:
        image: Array or tensor with at least three dimensions. Any leading
            dimensions are kept untouched.
        index_ini: Three start indices (inclusive), one per trailing dimension.
        index_fin: Three stop indices (exclusive), one per trailing dimension.

    Returns:
        The result of indexing ``image`` with the three slices.
    """
    i_ini, j_ini, k_ini = index_ini
    i_fin, j_fin, k_fin = index_fin
    return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]


def get_random_indices_from_shape(
        shape: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Pick a random patch location that fits inside a volume.

    Args:
        shape: Size of the volume along each dimension.
        patch_size: Size of the patch along each dimension.

    Returns:
        Tuple ``(index_ini, index_fin)`` with the first (inclusive) and last
        (exclusive) indices of the patch, where
        ``index_fin = index_ini + patch_size``.

    Raises:
        ValueError: If the patch is larger than the volume along any
            dimension.
    """
    shape_array = np.array(shape)
    patch_size_array = np.array(patch_size)
    max_index_ini = shape_array - patch_size_array
    if (max_index_ini < 0).any():
        message = (
            f'Patch size {patch_size} must not be'
            f' larger than image size {shape}'
        )
        raise ValueError(message)
    # NOTE(review): uint16 cannot represent values above 65535; the cast is
    # kept so that the dtype of the returned start index is unchanged.
    max_index_ini = max_index_ini.astype(np.uint16)
    # The upper bound of torch.randint is exclusive, so 1 is added to make
    # the last valid start index (shape - patch_size) reachable. When the
    # patch spans the whole dimension this draws from {0}.
    coordinates = [
        torch.randint(max_coordinate + 1, size=(1,)).item()
        for max_coordinate in max_index_ini.tolist()
    ]
    index_ini = np.array(coordinates, np.uint16)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin
INTENSITY: interpolation = Interpolation.NEAREST else: interpolation = self.interpolation diff --git a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py index 6649d1cea..5531e5e3e 100644 --- a/torchio/transforms/augmentation/spatial/random_elastic_deformation.py +++ b/torchio/transforms/augmentation/spatial/random_elastic_deformation.py @@ -6,7 +6,7 @@ import SimpleITK as sitk from ....data.subject import Subject from ....utils import to_tuple -from ....torchio import LABEL, DATA, AFFINE, TYPE +from ....torchio import INTENSITY, DATA, AFFINE, TYPE from .. import Interpolation, get_sitk_interpolator from .. import RandomTransform @@ -219,7 +219,7 @@ def apply_transform(self, sample: Subject) -> dict: self.num_locked_borders, ) for image in sample.get_images(intensity_only=False): - if image[TYPE] == LABEL: + if image[TYPE] != INTENSITY: interpolation = Interpolation.NEAREST else: interpolation = self.interpolation diff --git a/torchio/transforms/preprocessing/spatial/bounds_transform.py b/torchio/transforms/preprocessing/spatial/bounds_transform.py index c515b8197..911c100aa 100644 --- a/torchio/transforms/preprocessing/spatial/bounds_transform.py +++ b/torchio/transforms/preprocessing/spatial/bounds_transform.py @@ -19,7 +19,8 @@ class BoundsTransform(Transform): """Base class for transforms that change image bounds. Args: - bounds_parameters: + bounds_parameters: The meaning of this argument varies according to the + child class. p: Probability that this transform will be applied. 
""" @@ -47,7 +48,7 @@ def parse_bounds(bounds_parameters: TypeBounds) -> Tuple[int, ...]: if not isinstance(number, int) or number < 0: message = ( 'Bounds values must be integers greater or equal to zero,' - f' not "{bounds_parameters}"' + f' not "{bounds_parameters}" of type {type(number)}' ) raise ValueError(message) diff --git a/torchio/transforms/preprocessing/spatial/resample.py b/torchio/transforms/preprocessing/spatial/resample.py index 94f6ab399..958b802a1 100644 --- a/torchio/transforms/preprocessing/spatial/resample.py +++ b/torchio/transforms/preprocessing/spatial/resample.py @@ -9,7 +9,7 @@ from ....data.subject import Subject from ....data.image import Image -from ....torchio import LABEL, DATA, AFFINE, TYPE, INTENSITY +from ....torchio import DATA, AFFINE, TYPE, INTENSITY from ... import Interpolation from ... import Transform @@ -173,7 +173,7 @@ def apply_transform(self, sample: Subject) -> dict: continue # Choose interpolator - if image_dict[TYPE] == LABEL: + if image_dict[TYPE] != INTENSITY: interpolation_order = 0 # nearest neighbor else: interpolation_order = self.interpolation_order