Skip to content

Commit

Permalink
Refactor samplers
Browse files Browse the repository at this point in the history
Add weighted sampler

Add some docs for weighted sampler

Rename probability map argument

Add weighted sampler to docs

Add features to samplers

Add abstract method get_probability_map()

Move tests

Add features, tests and docs for samplers

Use crop transform to extract patches

Add comment to bounds transform

Add type hint for samplers

Fix TypeError for Python <3.8
  • Loading branch information
fepegar committed May 31, 2020
1 parent 4133e10 commit d5a917d
Show file tree
Hide file tree
Showing 21 changed files with 502 additions and 199 deletions.
40 changes: 23 additions & 17 deletions docs/source/data/patch_training.rst
Original file line number Diff line number Diff line change
@@ -1,46 +1,52 @@
Training
========

Random samplers
---------------
Patch samplers
--------------

Samplers are used to randomly extract patches from volumes.
They are called with a sample generated by an
:py:class:`~torchio.ImagesDataset` and return a Python generator that yields
cropped versions of the sample.

TorchIO includes grid, uniform and label patch samplers. There is also an
aggregator used for dense predictions.
For more information about patch-based training, see
`this NiftyNet tutorial <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_.


.. currentmodule:: torchio.data


:class:`ImageSampler`
:class:`PatchSampler`
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ImageSampler
:members:
.. autoclass:: PatchSampler
:show-inheritance:


:class:`WeightedSampler`
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: WeightedSampler
:show-inheritance:


:class:`UniformSampler`
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: UniformSampler
:show-inheritance:


:class:`LabelSampler`
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LabelSampler
:members:
:show-inheritance:


Queue
-----

In the following animation, :attr:`shuffle_subjects` is ``False``
and :attr:`shuffle_patches` is ``True``.

.. raw:: html

<embed>
<iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
</embed>


.. currentmodule:: torchio.data

Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python

from ..utils import TorchioTestCase
from torchio.data import GridSampler
from ...utils import TorchioTestCase


class TestGridSampler(TorchioTestCase):
"""Tests for `GridSampler`."""
Expand Down
Empty file added tests/data/sampler/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions tests/data/sampler/test_label_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from torchio import DATA
from torchio.data import LabelSampler
from ...utils import TorchioTestCase


class TestLabelSampler(TorchioTestCase):
"""Tests for `LabelSampler` class."""

def test_label_sampler(self):
sampler = LabelSampler(5, 'label')
for patch in sampler(self.sample, num_patches=10):
self.assertEqual(patch['label'][DATA][0, 2, 2, 2], 1)
25 changes: 25 additions & 0 deletions tests/data/sampler/test_weighted_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import torchio
from torchio.data import WeightedSampler
from ...utils import TorchioTestCase


class TestWeightedSampler(TorchioTestCase):
"""Tests for `WeightedSampler` class."""

def test_label_sampler(self):
sample = self.get_sample((7, 7, 7))
sampler = WeightedSampler(5, 'prob')
patch = next(iter(sampler(sample)))
self.assertEqual(tuple(patch['index_ini']), (1, 1, 1))

def get_sample(self, image_shape):
t1 = torch.rand(*image_shape)
prob = torch.zeros_like(t1)
prob[3, 3, 3] = 1
subject = torchio.Subject(
t1=torchio.Image(tensor=t1),
prob=torchio.Image(tensor=prob),
)
sample = torchio.ImagesDataset([subject])[0]
return sample
13 changes: 0 additions & 13 deletions tests/data/test_label_sampler.py

This file was deleted.

7 changes: 4 additions & 3 deletions tests/data/test_queue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch.utils.data import DataLoader
from torchio.data import ImageSampler
from torchio.data import UniformSampler
from torchio import ImagesDataset, Queue, DATA
from torchio.utils import create_dummy_dataset
from ..utils import TorchioTestCase
Expand All @@ -19,12 +19,13 @@ def setUp(self):

def test_queue(self):
subjects_dataset = ImagesDataset(self.subjects_list)
patch_size = 10
sampler = UniformSampler(patch_size)
queue_dataset = Queue(
subjects_dataset,
max_length=6,
samples_per_volume=2,
patch_size=10,
sampler_class=ImageSampler,
sampler=sampler,
num_workers=0,
verbose=True,
)
Expand Down
2 changes: 1 addition & 1 deletion torchio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from .image import Image
from .subject import Subject
from .dataset import ImagesDataset
from .sampler import ImageSampler, LabelSampler
from .inference import GridSampler, GridAggregator
from .sampler import PatchSampler, LabelSampler, WeightedSampler, UniformSampler
43 changes: 23 additions & 20 deletions torchio/data/queue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import random
import warnings
from typing import List, Iterator
from itertools import islice
from typing import List, Iterator

from tqdm import trange
from torch.utils.data import Dataset, DataLoader
from .. import TypeTuple

from .sampler import PatchSampler
from .dataset import ImagesDataset
from .sampler import ImageSampler


class Queue(Dataset):
Expand All @@ -17,16 +18,11 @@ class Queue(Dataset):
:class:`~torchio.data.dataset.ImagesDataset`.
max_length: Maximum number of patches that can be stored in the queue.
Using a large number means that the queue needs to be filled less
often, but more RAM is needed to store the patches.
often, but more CPU memory is needed to store the patches.
samples_per_volume: Number of patches to extract from each volume.
A small number of patches ensures a large variability in the queue,
but training will be slower.
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided,
:math:`d = h = w = n`.
sampler_class: An instance of :class:`~torchio.data.datasetampler` used
to define the patches sampling strategy.
sampler: A sampler used to extract patches from the volumes.
num_workers: Number of subprocesses to use for data loading
(as in :class:`torch.utils.data.DataLoader`).
``0`` means that the data will be loaded in the main process.
Expand All @@ -37,6 +33,16 @@ class Queue(Dataset):
queue.
verbose: If ``True``, some debugging messages are printed.
This sketch can be used to experiment and understand how the queue works.
In this case, :attr:`shuffle_subjects` is ``False``
and :attr:`shuffle_patches` is ``True``.
.. raw:: html
<embed>
<iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
</embed>
.. note:: :attr:`num_workers` refers to the number of workers used to
load and transform the volumes. Multiprocessing is not needed to pop
patches from the queue.
Expand All @@ -50,7 +56,7 @@ class Queue(Dataset):
... max_length=300,
... samples_per_volume=10,
... patch_size=96,
... sampler_class=torchio.sampler.ImageSampler,
... sampler=,
... num_workers=4,
... shuffle_subjects=True,
... shuffle_patches=True,
Expand All @@ -69,8 +75,7 @@ def __init__(
subjects_dataset: ImagesDataset,
max_length: int,
samples_per_volume: int,
patch_size: TypeTuple,
sampler_class: ImageSampler,
sampler: PatchSampler,
num_workers: int = 0,
shuffle_subjects: bool = True,
shuffle_patches: bool = True,
Expand All @@ -81,8 +86,7 @@ def __init__(
self.shuffle_subjects = shuffle_subjects
self.shuffle_patches = shuffle_patches
self.samples_per_volume = samples_per_volume
self.sampler_class = sampler_class
self.patch_size = patch_size
self.sampler = sampler
self.num_workers = num_workers
self.verbose = verbose
self.subjects_iterable = self.get_subjects_iterable()
Expand Down Expand Up @@ -130,8 +134,7 @@ def iterations_per_epoch(self) -> int:
return self.num_subjects * self.samples_per_volume

def fill(self) -> None:
assert self.sampler_class is not None
assert self.patch_size is not None
assert self.sampler is not None
if self.max_length % self.samples_per_volume != 0:
message = (
f'Queue length ({self.max_length})'
Expand All @@ -153,9 +156,9 @@ def fill(self) -> None:
iterable = range(num_subjects_for_queue)
for _ in iterable:
subject_sample = self.get_next_subject_sample()
sampler = self.sampler_class(subject_sample, self.patch_size)
samples = list(islice(sampler, self.samples_per_volume))
self.patches_list.extend(samples)
iterable = self.sampler(subject_sample)
patches = list(islice(iterable, self.samples_per_volume))
self.patches_list.extend(patches)
if self.shuffle_patches:
random.shuffle(self.patches_list)

Expand Down
4 changes: 3 additions & 1 deletion torchio/data/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .sampler import ImageSampler
from .label import LabelSampler
from .sampler import PatchSampler
from .uniform import UniformSampler
from .weighted import WeightedSampler
86 changes: 42 additions & 44 deletions torchio/data/sampler/label.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,49 @@
from .sampler import ImageSampler, crop
from ... import DATA, LABEL, TYPE
from ..subject import Subject
import torch
from ...data.subject import Subject
from ...torchio import TypePatchSize
from .weighted import WeightedSampler


class LabelSampler(ImageSampler):
r"""Extract random patches containing labeled voxels.
class LabelSampler(WeightedSampler):
r"""Extract random patches with labeled voxels at their center.
This iterable dataset yields patches that contain at least one voxel
without background.
It extracts the label data from the first image in the sample with type
:py:attr:`torchio.LABEL`.
This sampler yields patches whose center value is greater than 0
in the :py:attr:`label_name`.
Args:
sample: Sample generated by a
:py:class:`~torchio.data.dataset.ImagesDataset`, from which image
patches will be extracted.
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided,
:math:`d = h = w = n`.
.. warning:: For now, this implementation is not efficient because it uses
brute force to look for foreground voxels. It the number of
non-background voxels is very small, this sampler will be slow.
patch_size: See :py:class:`~torchio.data.PatchSampler`.
label_name: Name of the label image in the sample that will be used to
generate the sampling probability map.
Example:
>>> import torchio
>>> subject = torchio.datasets.Colin27()
>>> subject
Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
>>> sample = torchio.ImagesDataset([subject])[0]
>>> sampler = torchio.data.LabelSampler(64, 'brain')
>>> generator = sampler(sample)
>>> for patch in generator:
... print(patch.shape)
If you want a specific number of patches from a volume, e.g. 10:
>>> generator = sampler(sample, num_patches=10)
>>> for patch in iterator:
... print(patch.shape)
"""
@staticmethod
def get_first_label_image_dict(sample: Subject):
for image_dict in sample.get_images(intensity_only=False):
if image_dict[TYPE] == LABEL:
label_image_dict = image_dict
break
def __init__(self, patch_size: TypePatchSize, label_name: str):
super().__init__(patch_size, probability_map=label_name)

def get_probability_map(self, sample: Subject) -> torch.Tensor:
"""Return binarized image for sampling."""
if self.probability_map_name in sample:
data = sample[self.probability_map_name].data > 0.5
else:
raise ValueError('No images of type torchio.LABEL found in sample')
return label_image_dict

def extract_patch(self):
has_label = False
label_image_data = self.get_first_label_image_dict(self.sample)[DATA]
while not has_label:
index_ini, index_fin = self.get_random_indices(
self.sample, self.patch_size)
patch_label = crop(label_image_data, index_ini, index_fin)
has_label = patch_label.sum() > 0
cropped_sample = self.copy_and_crop(
self.sample,
index_ini,
index_fin,
)
return cropped_sample
message = (
f'Image "{self.probability_map_name}"'
f' not found in subject sample: {sample}'
)
raise KeyError(message)
return data
Loading

0 comments on commit d5a917d

Please sign in to comment.