Skip to content

Commit

Permalink
Merge 1eb5f2f into 08af12b
Browse files Browse the repository at this point in the history
  • Loading branch information
dvolgyes committed Sep 16, 2020
2 parents 08af12b + 1eb5f2f commit 8ef0d08
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
10 changes: 0 additions & 10 deletions tests/data/sampler/test_uniform_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,6 @@ def test_uniform_probabilities(self):
fixtures = torch.ones_like(probabilities)
assert torch.all(probabilities.eq(fixtures))

def test_processed_uniform_probabilities(self):
sampler = UniformSampler(5)
probabilities = sampler.get_probability_map(self.sample)
probabilities = sampler.process_probability_map(probabilities)
fixtures = np.zeros_like(probabilities)
# Other positions cannot be patch centers
fixtures[2:-2, 2:-2, 2:-2] = probabilities[2, 2, 2]
self.assertAlmostEqual(probabilities.sum(), 1)
assert np.equal(probabilities, fixtures).all()

def test_incosistent_shape(self):
# https://github.com/fepegar/torchio/issues/234#issuecomment-675029767
sample = torchio.Subject(
Expand Down
21 changes: 19 additions & 2 deletions torchio/data/sampler/uniform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch
from ...data.subject import Subject
from ...torchio import TypePatchSize
from .weighted import WeightedSampler
from .sampler import RandomSampler
from typing import Optional, Tuple, Generator
import numpy as np


class UniformSampler(WeightedSampler):
class UniformSampler(RandomSampler):
"""Randomly extract patches from a volume with uniform probability.
Args:
Expand All @@ -15,3 +17,18 @@ def __init__(self, patch_size: TypePatchSize):

def get_probability_map(self, sample: Subject) -> torch.Tensor:
return torch.ones(1, *sample.spatial_shape)

def __call__(self, sample: Subject) -> Generator[Subject, None, None]:

sample.check_consistent_spatial_shape()

if np.any(self.patch_size > sample.spatial_shape):
message = (
f'Patch size {tuple(self.patch_size)} cannot be'
f' larger than image size {tuple(sample.spatial_shape)}'
)
raise RuntimeError(message)

valid_range = sample.spatial_shape - self.patch_size
corners = np.asarray([torch.randint(x+1,(1,)).item() for x in valid_range])
yield self.extract_patch(sample, corners)

0 comments on commit 8ef0d08

Please sign in to comment.