diff --git a/tests/data/sampler/test_uniform_sampler.py b/tests/data/sampler/test_uniform_sampler.py index d375d6ea0..54982133b 100644 --- a/tests/data/sampler/test_uniform_sampler.py +++ b/tests/data/sampler/test_uniform_sampler.py @@ -14,16 +14,6 @@ def test_uniform_probabilities(self): fixtures = torch.ones_like(probabilities) assert torch.all(probabilities.eq(fixtures)) - def test_processed_uniform_probabilities(self): - sampler = UniformSampler(5) - probabilities = sampler.get_probability_map(self.sample) - probabilities = sampler.process_probability_map(probabilities) - fixtures = np.zeros_like(probabilities) - # Other positions cannot be patch centers - fixtures[2:-2, 2:-2, 2:-2] = probabilities[2, 2, 2] - self.assertAlmostEqual(probabilities.sum(), 1) - assert np.equal(probabilities, fixtures).all() - def test_incosistent_shape(self): # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767 sample = torchio.Subject( diff --git a/torchio/data/sampler/uniform.py b/torchio/data/sampler/uniform.py index b115c466a..50a8257ff 100644 --- a/torchio/data/sampler/uniform.py +++ b/torchio/data/sampler/uniform.py @@ -1,10 +1,12 @@ import torch from ...data.subject import Subject from ...torchio import TypePatchSize -from .weighted import WeightedSampler +from .sampler import RandomSampler +from typing import Optional, Tuple, Generator +import numpy as np -class UniformSampler(WeightedSampler): +class UniformSampler(RandomSampler): """Randomly extract patches from a volume with uniform probability. Args: @@ -15,3 +17,18 @@ def __init__(self, patch_size: TypePatchSize): def get_probability_map(self, sample: Subject) -> torch.Tensor: return torch.ones(1, *sample.spatial_shape) + + def __call__(self, sample: Subject) -> Generator[Subject, None, None]: + + sample.check_consistent_spatial_shape() + + if np.any(self.patch_size > sample.spatial_shape): + message = ( + f'Patch size {tuple(self.patch_size)} cannot be' + f' larger than image size {tuple(sample.spatial_shape)}' + ) + raise RuntimeError(message) + + valid_range = sample.spatial_shape - self.patch_size + corners = np.asarray([torch.randint(x+1,(1,)).item() for x in valid_range]) + yield self.extract_patch(sample, corners)