Skip to content

Commit

Permalink
Merge pull request #60 from JonasHell/main
Browse files Browse the repository at this point in the history
Add sampler that checks if there are a minimum number of instances in the volume
  • Loading branch information
constantinpape committed Jun 3, 2022
2 parents 83c8d0f + 7d0d03b commit 4411d9e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
27 changes: 27 additions & 0 deletions torch_em/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,30 @@ def __call__(self, x, y=None):
return True
else:
return np.random.rand() > self.p_reject


class MinInstanceSampler:
def __init__(self, min_num_instances=2, p_reject=1.0):
self.min_num_instances = min_num_instances
self.p_reject = p_reject

def __call__(self, x, y):
uniques = np.unique(y)
if len(uniques) >= self.min_num_instances:
return True
else:
return np.random.rand() > self.p_reject


class MinTwoInstanceSampler:
# for the case of min_num_instances=2 this is roughly 10x faster
# than using MinInstanceSampler since np.unique is slow
def __init__(self, p_reject=1.0):
self.p_reject = p_reject

def __call__(self, x, y):
sample_value = y.flat[0]
if (y != sample_value).any():
return True
else:
return np.random.rand() > self.p_reject
6 changes: 4 additions & 2 deletions torch_em/shallow2deep/prepare_shallow2deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def _load_rf_segmentation_dataset(
raw_paths, raw_key, label_paths, label_key, patch_shape_min, patch_shape_max, **kwargs
):
rois = kwargs.pop("rois", None)
sampler = torch_em.data.MinForegroundSampler(min_fraction=0.01)
sampler = kwargs.pop("sampler", None)
sampler = sampler if sampler else torch_em.data.MinForegroundSampler(min_fraction=0.01)
if isinstance(raw_paths, str):
if rois is not None:
assert len(rois) == 3 and all(isinstance(roi, slice) for roi in rois)
Expand Down Expand Up @@ -299,6 +300,7 @@ def prepare_shallow2deep(
is_seg_dataset=None,
balance_labels=True,
filter_config=None,
sampler=None,
**rf_kwargs,
):
assert len(patch_shape_min) == len(patch_shape_max)
Expand All @@ -312,7 +314,7 @@ def prepare_shallow2deep(
ds = _load_rf_segmentation_dataset(raw_paths, raw_key, label_paths, label_key,
patch_shape_min, patch_shape_max,
raw_transform=raw_transform, label_transform=label_transform,
rois=rois, n_samples=n_forests)
rois=rois, n_samples=n_forests, sampler=sampler)
else:
ds = _load_rf_image_collection_dataset(raw_paths, raw_key, label_paths, label_key,
patch_shape_min, patch_shape_max, roi=rois,
Expand Down

0 comments on commit 4411d9e

Please sign in to comment.