Hotfix for case where there are no centroids detected in inference.
talmo committed Jan 21, 2020
1 parent 26ac829 commit 96d7308
Showing 2 changed files with 16 additions and 3 deletions.
15 changes: 12 additions & 3 deletions sleap/nn/inference.py
@@ -100,6 +100,7 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size):
"""Runs the inference components of pipeline for a chunk."""

if "centroid" in self.policies:
# Detect centroids and pull out region proposals.
centroid_predictor = self.policies["centroid"]
region_proposal_extractor = self.policies["region"]

@@ -114,21 +115,26 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size):
                region_proposal.RegionProposalSet.from_uncropped(images=img_chunk)
            ]

        # for region_ind in range(len(region_proposal_sets)):
        #     region_proposal_sets[region_ind].sample_inds += chunk_ind * chunk_size

        if "paf" not in self.policies:
            # If we don't have PAFs, we must be doing topdown or single-instance.
            if "topdown" in self.policies:
                topdown_peak_finder = self.policies["topdown"]
            else:
                topdown_peak_finder = self.policies["confmap"]

            if len(region_proposal_sets) == 0:
                # No region proposals were found, so just return an empty result.
                return defaultdict(list)

            # Only one scale of region proposals in topdown inference.
            rps = region_proposal_sets[0]

            # Find peaks in each sample of the region proposal set.
            sample_peak_pts, sample_peak_vals = topdown_peak_finder.predict_rps(rps)
            sample_peak_pts = sample_peak_pts.to_tensor().numpy()
            sample_peak_vals = sample_peak_vals.to_tensor().numpy()

            # Gather instances across all region proposals for this chunk.
            predicted_instances_chunk = topdown.make_sample_grouped_predicted_instances(
                sample_peak_pts,
                sample_peak_vals,
@@ -140,16 +146,19 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size):
            confmap_peak_finder = self.policies["confmap"]
            paf_grouper = self.policies["paf"]

            # Find multi-instance peaks in each region proposal set.
            region_peak_sets = []
            for rps in region_proposal_sets:
                region_peaks = confmap_peak_finder.predict_rps(rps)
                region_peak_sets.append(region_peaks)

            # Group peaks into instances using PAFs.
            region_instance_sets = []
            for rps, region_peaks in zip(region_proposal_sets, region_peak_sets):
                region_instances = paf_grouper.predict_rps(rps, region_peaks)
                region_instance_sets.append(region_instances)

            # Gather instances across all region proposals for this chunk.
            predicted_instances_chunk = defaultdict(list)
            for region_instance_set in region_instance_sets:
                for sample, region_instances in region_instance_set.items():
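
Note: the new guard returns an empty defaultdict(list), which downstream code can iterate like any non-empty per-chunk result. A minimal sketch of why that is safe, assuming a hypothetical merge step (merge_chunks and the sample keys are invented for illustration and are not SLEAP's actual chunk-gathering code):

from collections import defaultdict

def merge_chunks(chunk_results):
    """Combine per-chunk {sample_ind: [instances]} dicts into a single dict."""
    merged = defaultdict(list)
    for chunk_result in chunk_results:
        for sample_ind, instances in chunk_result.items():
            merged[sample_ind].extend(instances)
    return merged

# A chunk with no detected centroids contributes nothing, but it does not
# break the merge: iterating an empty defaultdict yields no items.
empty_chunk = defaultdict(list)  # what predict_chunk now returns in the empty case
other_chunk = {0: ["instance_a"], 1: ["instance_b"]}
print(dict(merge_chunks([empty_chunk, other_chunk])))
# {0: ['instance_a'], 1: ['instance_b']}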
4 changes: 4 additions & 0 deletions sleap/nn/region_proposal.py
@@ -467,6 +467,10 @@ def extract_region_proposal_sets(
            sample_inds = size_grouped_sample_inds[box_size]
            bboxes = size_grouped_bboxes[box_size]

            if len(bboxes) == 0:
                # Skip this region proposal set if no bboxes were found.
                continue

            # Extract image patches for all regions in the set.
            patches = extract_patches(
                imgs, tf.cast(bboxes, tf.float32), tf.cast(sample_inds, tf.int32)
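
A simplified sketch of the effect of this guard: when bounding boxes are grouped by box size, a group with zero boxes is now skipped before any patch extraction is attempted. The dict and variable names below are invented for illustration and are not the actual extract_region_proposal_sets internals.

import numpy as np

# Size-grouped boxes as a plain dict; one group is empty because no
# centroids of that size were detected.
size_grouped_bboxes = {
    64: np.zeros((0, 4), dtype="float32"),
    128: np.array([[10.0, 10.0, 74.0, 74.0]], dtype="float32"),
}

for box_size, bboxes in size_grouped_bboxes.items():
    if len(bboxes) == 0:
        # Nothing to crop at this box size, so no region proposal set is built.
        continue
    print(f"extracting {len(bboxes)} patch(es) at box size {box_size}")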
