From db00fc9f9f40abaf96db72b68da7bf715d3a5f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20V=C3=B6lgyes?= Date: Thu, 10 Sep 2020 14:31:42 +0200 Subject: [PATCH 1/2] Faster implementation for UniformSampler --- torchio/data/sampler/uniform.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/torchio/data/sampler/uniform.py b/torchio/data/sampler/uniform.py index b115c466a..619d78e45 100644 --- a/torchio/data/sampler/uniform.py +++ b/torchio/data/sampler/uniform.py @@ -1,10 +1,12 @@ import torch from ...data.subject import Subject from ...torchio import TypePatchSize -from .weighted import WeightedSampler +from .sampler import RandomSampler +from typing import Optional, Tuple, Generator +import numpy as np -class UniformSampler(WeightedSampler): +class UniformSampler(RandomSampler): """Randomly extract patches from a volume with uniform probability. Args: @@ -15,3 +17,18 @@ def __init__(self, patch_size: TypePatchSize): def get_probability_map(self, sample: Subject) -> torch.Tensor: return torch.ones(1, *sample.spatial_shape) + + def __call__(self, sample: Subject) -> Generator[Subject, None, None]: + + sample.check_consistent_spatial_shape() + + if np.any(self.patch_size > sample.spatial_shape): + message = ( + f'Patch size {tuple(self.patch_size)} cannot be' + f' larger than image size {tuple(sample.spatial_shape)}' + ) + raise RuntimeError(message) + + valid_range = sample.spatial_shape - self.patch_size + corners = np.random.randint(valid_range + 1) + yield self.extract_patch(sample, corners) From 1eb5f2f029d0068f0c24007aceec190a54a3081c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20V=C3=B6lgyes?= Date: Sat, 12 Sep 2020 12:05:17 +0200 Subject: [PATCH 2/2] Replacing numpy.randint with torch.randint in the UniformSampler --- tests/data/sampler/test_uniform_sampler.py | 10 ---------- torchio/data/sampler/uniform.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/data/sampler/test_uniform_sampler.py b/tests/data/sampler/test_uniform_sampler.py index d375d6ea0..54982133b 100644 --- a/tests/data/sampler/test_uniform_sampler.py +++ b/tests/data/sampler/test_uniform_sampler.py @@ -14,16 +14,6 @@ def test_uniform_probabilities(self): fixtures = torch.ones_like(probabilities) assert torch.all(probabilities.eq(fixtures)) - def test_processed_uniform_probabilities(self): - sampler = UniformSampler(5) - probabilities = sampler.get_probability_map(self.sample) - probabilities = sampler.process_probability_map(probabilities) - fixtures = np.zeros_like(probabilities) - # Other positions cannot be patch centers - fixtures[2:-2, 2:-2, 2:-2] = probabilities[2, 2, 2] - self.assertAlmostEqual(probabilities.sum(), 1) - assert np.equal(probabilities, fixtures).all() - def test_incosistent_shape(self): # https://github.com/fepegar/torchio/issues/234#issuecomment-675029767 sample = torchio.Subject( diff --git a/torchio/data/sampler/uniform.py b/torchio/data/sampler/uniform.py index 619d78e45..50a8257ff 100644 --- a/torchio/data/sampler/uniform.py +++ b/torchio/data/sampler/uniform.py @@ -30,5 +30,5 @@ def __call__(self, sample: Subject) -> Generator[Subject, None, None]: raise RuntimeError(message) valid_range = sample.spatial_shape - self.patch_size - corners = np.random.randint(valid_range + 1) + corners = np.asarray([torch.randint(x+1,(1,)).item() for x in valid_range]) yield self.extract_patch(sample, corners)