-
Notifications
You must be signed in to change notification settings - Fork 231
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add weighted sampler Add some docs for weighted sampler Rename probability map argument Add weighted sampler to docs Add features to samplers Add abstract method get_probability_map() Move tests Add features, tests and docs for samplers Use crop transform to extract patches Add comment to bounds transform Add type hint for samplers Fix TypeError for Python <3.8
- Loading branch information
Showing
21 changed files
with
502 additions
and
199 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
3 changes: 2 additions & 1 deletion
3
tests/data/test_grid_sampler.py → tests/data/inference/test_grid_sampler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from torchio import DATA | ||
from torchio.data import LabelSampler | ||
from ...utils import TorchioTestCase | ||
|
||
|
||
class TestLabelSampler(TorchioTestCase): | ||
"""Tests for `LabelSampler` class.""" | ||
|
||
def test_label_sampler(self): | ||
sampler = LabelSampler(5, 'label') | ||
for patch in sampler(self.sample, num_patches=10): | ||
self.assertEqual(patch['label'][DATA][0, 2, 2, 2], 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch | ||
import torchio | ||
from torchio.data import WeightedSampler | ||
from ...utils import TorchioTestCase | ||
|
||
|
||
class TestWeightedSampler(TorchioTestCase): | ||
"""Tests for `WeightedSampler` class.""" | ||
|
||
def test_label_sampler(self): | ||
sample = self.get_sample((7, 7, 7)) | ||
sampler = WeightedSampler(5, 'prob') | ||
patch = next(iter(sampler(sample))) | ||
self.assertEqual(tuple(patch['index_ini']), (1, 1, 1)) | ||
|
||
def get_sample(self, image_shape): | ||
t1 = torch.rand(*image_shape) | ||
prob = torch.zeros_like(t1) | ||
prob[3, 3, 3] = 1 | ||
subject = torchio.Subject( | ||
t1=torchio.Image(tensor=t1), | ||
prob=torchio.Image(tensor=prob), | ||
) | ||
sample = torchio.ImagesDataset([subject])[0] | ||
return sample |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from .sampler import ImageSampler | ||
from .label import LabelSampler | ||
from .sampler import PatchSampler | ||
from .uniform import UniformSampler | ||
from .weighted import WeightedSampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,51 +1,49 @@ | ||
from .sampler import ImageSampler, crop | ||
from ... import DATA, LABEL, TYPE | ||
from ..subject import Subject | ||
import torch | ||
from ...data.subject import Subject | ||
from ...torchio import TypePatchSize | ||
from .weighted import WeightedSampler | ||
|
||
|
||
class LabelSampler(ImageSampler): | ||
r"""Extract random patches containing labeled voxels. | ||
class LabelSampler(WeightedSampler): | ||
r"""Extract random patches with labeled voxels at their center. | ||
This iterable dataset yields patches that contain at least one voxel | ||
without background. | ||
It extracts the label data from the first image in the sample with type | ||
:py:attr:`torchio.LABEL`. | ||
This sampler yields patches whose center value is greater than 0 | ||
in the :py:attr:`label_name`. | ||
Args: | ||
sample: Sample generated by a | ||
:py:class:`~torchio.data.dataset.ImagesDataset`, from which image | ||
patches will be extracted. | ||
patch_size: Tuple of integers :math:`(d, h, w)` to generate patches | ||
of size :math:`d \times h \times w`. | ||
If a single number :math:`n` is provided, | ||
:math:`d = h = w = n`. | ||
.. warning:: For now, this implementation is not efficient because it uses | ||
brute force to look for foreground voxels. It the number of | ||
non-background voxels is very small, this sampler will be slow. | ||
patch_size: See :py:class:`~torchio.data.PatchSampler`. | ||
label_name: Name of the label image in the sample that will be used to | ||
generate the sampling probability map. | ||
Example: | ||
>>> import torchio | ||
>>> subject = torchio.datasets.Colin27() | ||
>>> subject | ||
Colin27(Keys: ('t1', 'head', 'brain'); images: 3) | ||
>>> sample = torchio.ImagesDataset([subject])[0] | ||
>>> sampler = torchio.data.LabelSampler(64, 'brain') | ||
>>> generator = sampler(sample) | ||
>>> for patch in generator: | ||
... print(patch.shape) | ||
If you want a specific number of patches from a volume, e.g. 10: | ||
>>> generator = sampler(sample, num_patches=10) | ||
>>> for patch in iterator: | ||
... print(patch.shape) | ||
""" | ||
@staticmethod | ||
def get_first_label_image_dict(sample: Subject): | ||
for image_dict in sample.get_images(intensity_only=False): | ||
if image_dict[TYPE] == LABEL: | ||
label_image_dict = image_dict | ||
break | ||
def __init__(self, patch_size: TypePatchSize, label_name: str): | ||
super().__init__(patch_size, probability_map=label_name) | ||
|
||
def get_probability_map(self, sample: Subject) -> torch.Tensor: | ||
"""Return binarized image for sampling.""" | ||
if self.probability_map_name in sample: | ||
data = sample[self.probability_map_name].data > 0.5 | ||
else: | ||
raise ValueError('No images of type torchio.LABEL found in sample') | ||
return label_image_dict | ||
|
||
def extract_patch(self): | ||
has_label = False | ||
label_image_data = self.get_first_label_image_dict(self.sample)[DATA] | ||
while not has_label: | ||
index_ini, index_fin = self.get_random_indices( | ||
self.sample, self.patch_size) | ||
patch_label = crop(label_image_data, index_ini, index_fin) | ||
has_label = patch_label.sum() > 0 | ||
cropped_sample = self.copy_and_crop( | ||
self.sample, | ||
index_ini, | ||
index_fin, | ||
) | ||
return cropped_sample | ||
message = ( | ||
f'Image "{self.probability_map_name}"' | ||
f' not found in subject sample: {sample}' | ||
) | ||
raise KeyError(message) | ||
return data |
Oops, something went wrong.