diff --git a/prose/blocks/detection.py b/prose/blocks/detection.py index 86685e7b..9b47ac56 100644 --- a/prose/blocks/detection.py +++ b/prose/blocks/detection.py @@ -49,9 +49,8 @@ def __init__( self.min_area = min_area self.minor_length = minor_length - def clean(self, sources, *args): + def clean(self, sources): peaks = np.array([s.peak for s in sources]) - coords = np.array([s.coords for s in sources]) _sources = sources.copy() if len(sources) > 0: @@ -61,10 +60,22 @@ def clean(self, sources, *args): if self.n is not None: _sources = _sources[0 : self.n] if self.min_separation: - idxs = clean_stars_positions( - coords, tolerance=self.min_separation, output_id=True - )[1] - _sources = _sources[idxs] + final_sources = sources.copy() + + for s in final_sources: + s.keep = True + + for i, s in enumerate(final_sources): + if final_sources[i].keep: + distances = np.linalg.norm( + s.coords - final_sources.coords, axis=1 + ) + distances[i] == np.nan + idxs = np.flatnonzero(distances < self.min_separation) + for j in idxs[idxs > i]: + final_sources[int(j)].keep = False + + _sources = sources[np.array([s.keep for s in final_sources])] for i, s in enumerate(_sources): s.i = i @@ -187,7 +198,9 @@ def run(self, image): idxs = np.flatnonzero([r.euler_number == 1 for r in regions]) regions = [regions[i] for i in idxs] - sources = np.array([PointSource.from_region(region) for region in regions]) + sources = Sources( + np.array([PointSource.from_region(region) for region in regions]) + ) image.sources = Sources(self.clean(sources), source_type="PointSource") @property diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 36ee2568..ef3555bd 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -38,10 +38,18 @@ def test_psf_blocks(block): block().run(image_psf) -def test_detection_min_separation(): +@pytest.mark.parametrize("d", [10, 50, 80, 100]) +def test_detection_min_separation(d): from prose.blocks.detection import PointSourceDetection - PointSourceDetection(min_separation=10.0)(image) + PointSourceDetection(min_separation=d).run(image) + + distances = np.linalg.norm( + image.sources.coords - image.sources.coords[:, None], axis=-1 + ) + distances = np.where(np.eye(distances.shape[0]).astype(bool), np.nan, distances) + distances = np.nanmin(distances, 0) + np.testing.assert_allclose(distances > d, True) def test_Trim():