Skip to content

Commit

Permalink
Merge 5231c6e into fca4b74
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jun 1, 2020
2 parents fca4b74 + 5231c6e commit d7a50ea
Show file tree
Hide file tree
Showing 23 changed files with 561 additions and 202 deletions.
2 changes: 1 addition & 1 deletion docs/source/data/patch_based.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ volumes for training and testing.


.. toctree::
:maxdepth: 2
:maxdepth: 3

patch_training.rst
patch_inference.rst
40 changes: 23 additions & 17 deletions docs/source/data/patch_training.rst
Original file line number Diff line number Diff line change
@@ -1,46 +1,52 @@
Training
========

Random samplers
---------------
Patch samplers
--------------

Samplers are used to randomly extract patches from volumes.
They are called with a sample generated by an
:py:class:`~torchio.ImagesDataset` and return a Python generator that yields
cropped versions of the sample.

TorchIO includes grid, uniform and label patch samplers. There is also an
aggregator used for dense predictions.
For more information about patch-based training, see
`this NiftyNet tutorial <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_.


.. currentmodule:: torchio.data


:class:`ImageSampler`
:class:`PatchSampler`
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ImageSampler
:members:
.. autoclass:: PatchSampler
:show-inheritance:


:class:`WeightedSampler`
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: WeightedSampler
:show-inheritance:


:class:`UniformSampler`
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: UniformSampler
:show-inheritance:


:class:`LabelSampler`
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LabelSampler
:members:
:show-inheritance:


Queue
-----

In the following animation, :attr:`shuffle_subjects` is ``False``
and :attr:`shuffle_patches` is ``True``.

.. raw:: html

<embed>
<iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
</embed>


.. currentmodule:: torchio.data

Expand Down
5 changes: 2 additions & 3 deletions examples/example_heteromodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torchio
from torchio import Image, Subject, ImagesDataset, Queue
from torchio.data import ImageSampler
from torchio.data import UniformSampler

def main():
# Define training and patches sampling parameters
Expand Down Expand Up @@ -45,8 +45,7 @@ def main():
subjects_dataset,
queue_length,
samples_per_volume,
patch_size,
ImageSampler,
UniformSampler(patch_size),
)

# This collate_fn is needed in the case of missing modalities
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python

from ..utils import TorchioTestCase
from torchio.data import GridSampler
from ...utils import TorchioTestCase


class TestGridSampler(TorchioTestCase):
"""Tests for `GridSampler`."""
Expand Down
Empty file added tests/data/sampler/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/data/sampler/test_label_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torchio
from torchio.data import LabelSampler
from ...utils import TorchioTestCase


class TestLabelSampler(TorchioTestCase):
"""Tests for `LabelSampler` class."""

def test_label_sampler(self):
sampler = LabelSampler(5, 'label')
for patch in sampler(self.sample, num_patches=10):
patch_center = patch['label'][torchio.DATA][0, 2, 2, 2]
self.assertEqual(patch_center, 1)

def test_label_probabilities(self):
labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, -1)
subject = torchio.Subject(
label=torchio.Image(tensor=labels, type=torchio.LABEL),
)
sample = torchio.ImagesDataset([subject])[0]
probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
probabilities = sampler.get_probability_map(sample)
fixture = torch.Tensor((0, 0, 2/12, 2/12, 3/12, 2/12, 0))
assert torch.all(probabilities.squeeze().eq(fixture))
25 changes: 25 additions & 0 deletions tests/data/sampler/test_weighted_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import torchio
from torchio.data import WeightedSampler
from ...utils import TorchioTestCase


class TestWeightedSampler(TorchioTestCase):
"""Tests for `WeightedSampler` class."""

def test_weighted_sampler(self):
sample = self.get_sample((7, 7, 7))
sampler = WeightedSampler(5, 'prob')
patch = next(iter(sampler(sample)))
self.assertEqual(tuple(patch['index_ini']), (1, 1, 1))

def get_sample(self, image_shape):
t1 = torch.rand(*image_shape)
prob = torch.zeros_like(t1)
prob[3, 3, 3] = 1
subject = torchio.Subject(
t1=torchio.Image(tensor=t1),
prob=torchio.Image(tensor=prob),
)
sample = torchio.ImagesDataset([subject])[0]
return sample
13 changes: 0 additions & 13 deletions tests/data/test_label_sampler.py

This file was deleted.

7 changes: 4 additions & 3 deletions tests/data/test_queue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch.utils.data import DataLoader
from torchio.data import ImageSampler
from torchio.data import UniformSampler
from torchio import ImagesDataset, Queue, DATA
from torchio.utils import create_dummy_dataset
from ..utils import TorchioTestCase
Expand All @@ -19,12 +19,13 @@ def setUp(self):

def test_queue(self):
subjects_dataset = ImagesDataset(self.subjects_list)
patch_size = 10
sampler = UniformSampler(patch_size)
queue_dataset = Queue(
subjects_dataset,
max_length=6,
samples_per_volume=2,
patch_size=10,
sampler_class=ImageSampler,
sampler=sampler,
num_workers=0,
verbose=True,
)
Expand Down
2 changes: 1 addition & 1 deletion torchio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from .image import Image
from .subject import Subject
from .dataset import ImagesDataset
from .sampler import ImageSampler, LabelSampler
from .inference import GridSampler, GridAggregator
from .sampler import PatchSampler, LabelSampler, WeightedSampler, UniformSampler
43 changes: 23 additions & 20 deletions torchio/data/queue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import random
import warnings
from typing import List, Iterator
from itertools import islice
from typing import List, Iterator

from tqdm import trange
from torch.utils.data import Dataset, DataLoader
from .. import TypeTuple

from .sampler import PatchSampler
from .dataset import ImagesDataset
from .sampler import ImageSampler


class Queue(Dataset):
Expand All @@ -17,16 +18,11 @@ class Queue(Dataset):
:class:`~torchio.data.dataset.ImagesDataset`.
max_length: Maximum number of patches that can be stored in the queue.
Using a large number means that the queue needs to be filled less
often, but more RAM is needed to store the patches.
often, but more CPU memory is needed to store the patches.
samples_per_volume: Number of patches to extract from each volume.
A small number of patches ensures a large variability in the queue,
but training will be slower.
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
of size :math:`d \times h \times w`.
If a single number :math:`n` is provided,
:math:`d = h = w = n`.
sampler_class: An instance of :class:`~torchio.data.datasetampler` used
to define the patches sampling strategy.
sampler: A sampler used to extract patches from the volumes.
num_workers: Number of subprocesses to use for data loading
(as in :class:`torch.utils.data.DataLoader`).
``0`` means that the data will be loaded in the main process.
Expand All @@ -37,6 +33,16 @@ class Queue(Dataset):
queue.
verbose: If ``True``, some debugging messages are printed.
This sketch can be used to experiment and understand how the queue works.
In this case, :attr:`shuffle_subjects` is ``False``
and :attr:`shuffle_patches` is ``True``.
.. raw:: html
<embed>
<iframe style="width: 640px; height: 360px; overflow: hidden;" scrolling="no" frameborder="0" src="https://editor.p5js.org/embed/DZwjZzkkV"></iframe>
</embed>
.. note:: :attr:`num_workers` refers to the number of workers used to
load and transform the volumes. Multiprocessing is not needed to pop
patches from the queue.
Expand All @@ -50,7 +56,7 @@ class Queue(Dataset):
... max_length=300,
... samples_per_volume=10,
... patch_size=96,
... sampler_class=torchio.sampler.ImageSampler,
... sampler=,
... num_workers=4,
... shuffle_subjects=True,
... shuffle_patches=True,
Expand All @@ -69,8 +75,7 @@ def __init__(
subjects_dataset: ImagesDataset,
max_length: int,
samples_per_volume: int,
patch_size: TypeTuple,
sampler_class: ImageSampler,
sampler: PatchSampler,
num_workers: int = 0,
shuffle_subjects: bool = True,
shuffle_patches: bool = True,
Expand All @@ -81,8 +86,7 @@ def __init__(
self.shuffle_subjects = shuffle_subjects
self.shuffle_patches = shuffle_patches
self.samples_per_volume = samples_per_volume
self.sampler_class = sampler_class
self.patch_size = patch_size
self.sampler = sampler
self.num_workers = num_workers
self.verbose = verbose
self.subjects_iterable = self.get_subjects_iterable()
Expand Down Expand Up @@ -130,8 +134,7 @@ def iterations_per_epoch(self) -> int:
return self.num_subjects * self.samples_per_volume

def fill(self) -> None:
assert self.sampler_class is not None
assert self.patch_size is not None
assert self.sampler is not None
if self.max_length % self.samples_per_volume != 0:
message = (
f'Queue length ({self.max_length})'
Expand All @@ -153,9 +156,9 @@ def fill(self) -> None:
iterable = range(num_subjects_for_queue)
for _ in iterable:
subject_sample = self.get_next_subject_sample()
sampler = self.sampler_class(subject_sample, self.patch_size)
samples = list(islice(sampler, self.samples_per_volume))
self.patches_list.extend(samples)
iterable = self.sampler(subject_sample)
patches = list(islice(iterable, self.samples_per_volume))
self.patches_list.extend(patches)
if self.shuffle_patches:
random.shuffle(self.patches_list)

Expand Down
4 changes: 3 additions & 1 deletion torchio/data/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .sampler import ImageSampler
from .label import LabelSampler
from .sampler import PatchSampler
from .uniform import UniformSampler
from .weighted import WeightedSampler

0 comments on commit d7a50ea

Please sign in to comment.