Skip to content

Commit

Permalink
Use first label in sample if None is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Jun 2, 2020
1 parent 5231c6e commit ce7f273
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/data/sampler/test_label_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class TestLabelSampler(TorchioTestCase):
"""Tests for `LabelSampler` class."""

def test_label_sampler(self):
sampler = LabelSampler(5, 'label')
sampler = LabelSampler(5)
for patch in sampler(self.sample, num_patches=10):
patch_center = patch['label'][torchio.DATA][0, 2, 2, 2]
self.assertEqual(patch_center, 1)
Expand Down
18 changes: 13 additions & 5 deletions torchio/data/sampler/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from ...data.subject import Subject
from ...torchio import TypePatchSize, DATA
from ...torchio import TypePatchSize, DATA, TYPE, LABEL
from .weighted import WeightedSampler


Expand All @@ -16,7 +16,9 @@ class LabelSampler(WeightedSampler):
Args:
patch_size: See :py:class:`~torchio.data.PatchSampler`.
label_name: Name of the label image in the sample that will be used to
generate the sampling probability map.
generate the sampling probability map. If ``None``, the first image
of type :py:attr:`torchio.LABEL` found in the subject sample will be
used.
label_probabilities: Dictionary containing the probability that each
class will be sampled. Probabilities do not need to be normalized.
For example, a value of ``{0: 0, 1: 2, 2: 1, 3: 1}`` will create a
Expand Down Expand Up @@ -46,20 +48,26 @@ class will be sampled. Probabilities do not need to be normalized.
def __init__(
self,
patch_size: TypePatchSize,
label_name: str,
label_name: Optional[str] = None,
label_probabilities: Optional[Dict[int, float]] = None,
):
super().__init__(patch_size, probability_map=label_name)
self.label_probabilities_dict = label_probabilities

def get_probability_map(self, sample: Subject) -> torch.Tensor:
if self.probability_map_name not in sample:
if self.probability_map_name is None:
for image in sample.get_images(intensity_only=False):
if image[TYPE] == LABEL:
label_map_tensor = image[DATA]
break
elif self.probability_map_name in sample:
label_map_tensor = sample[self.probability_map_name][DATA]
else:
message = (
f'Image "{self.probability_map_name}"'
f' not found in subject sample: {sample}'
)
raise KeyError(message)
label_map_tensor = sample[self.probability_map_name][DATA]
if self.label_probabilities_dict is None:
return label_map_tensor > 0
probability_map = self.get_probabilities_from_label_map(
Expand Down

0 comments on commit ce7f273

Please sign in to comment.