Skip to content

Commit

Permalink
avoid infinite loops in the sampling of points in SpecifiedLocation
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Apr 9, 2024
1 parent 36225d4 commit f034e7d
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions gunpowder/nodes/specified_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,29 @@ class SpecifiedLocation(BatchFilter):
Default is None, which places the point in the center.
Chooses uniformly from [loc - jitter, loc + jitter] in each
direction.
attempt_factor (``int``):
If choosing randomly then given `n` points, sample
`attempt_factor * n` points at most before giving up and
throwing an error.
"""

def __init__(self, locations, choose_randomly=False, extra_data=None, jitter=None):
def __init__(
self,
locations,
choose_randomly=False,
extra_data=None,
jitter=None,
attempt_factor: int = 5,
):
self.coordinates = locations
self.choose_randomly = choose_randomly
self.jitter = jitter
self.loc_i = -1
self.upstream_spec = None
self.specified_shift = None
self.attempt_factor = attempt_factor

if extra_data is not None:
assert len(extra_data) == len(locations), (
Expand All @@ -79,13 +93,28 @@ def prepare(self, request):
request_center = total_roi.shape / 2 + total_roi.offset

self.specified_shift = self._get_next_shift(request_center, lcm_voxel_size)
loop_counter = 0
while not self.__check_shift(request):
logger.warning(
"Location %s (shift %s) skipped"
% (self.coordinates[self.loc_i], self.specified_shift)
)
self.specified_shift = self._get_next_shift(request_center, lcm_voxel_size)

loop_counter += 1
if loop_counter >= len(self.coordinates) * (
1 + (self.attempt_factor - 1) * int(self.choose_randomly)
):
if self.choose_randomly:
raise Exception(
f"Took {5*len(self.coordinates)} samples of {len(self.coordinates)} points "
"and did not find a suitible location"
)
else:
raise Exception(
"Looped through every possible location and None are valid"
)

# Set shift for all requests
for specs_type in [request.array_specs, request.graph_specs]:
for key, spec in specs_type.items():
Expand All @@ -106,9 +135,9 @@ def process(self, batch, request):
for array_key, spec in request.array_specs.items():
batch.arrays[array_key].spec.roi = spec.roi
if self.extra_data is not None:
batch.arrays[array_key].attrs[
"specified_location_extra_data"
] = self.extra_data[self.loc_i]
batch.arrays[array_key].attrs["specified_location_extra_data"] = (
self.extra_data[self.loc_i]
)

for graph_key, spec in request.graph_specs.items():
batch.points[graph_key].spec.roi = spec.roi
Expand Down

0 comments on commit f034e7d

Please sign in to comment.