Skip to content

Commit

Permalink
[block] fix min_separation in detection (once for all?)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrcia committed Mar 15, 2023
1 parent f0debf2 commit 83f55e8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
27 changes: 20 additions & 7 deletions prose/blocks/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def __init__(
self.min_area = min_area
self.minor_length = minor_length

def clean(self, sources, *args):
def clean(self, sources):
peaks = np.array([s.peak for s in sources])
coords = np.array([s.coords for s in sources])
_sources = sources.copy()

if len(sources) > 0:
Expand All @@ -61,10 +60,22 @@ def clean(self, sources, *args):
if self.n is not None:
_sources = _sources[0 : self.n]
if self.min_separation:
idxs = clean_stars_positions(
coords, tolerance=self.min_separation, output_id=True
)[1]
_sources = _sources[idxs]
final_sources = sources.copy()

for s in final_sources:
s.keep = True

for i, s in enumerate(final_sources):
if final_sources[i].keep:
distances = np.linalg.norm(
s.coords - final_sources.coords, axis=1
)
distances[i] == np.nan
idxs = np.flatnonzero(distances < self.min_separation)
for j in idxs[idxs > i]:
final_sources[int(j)].keep = False

_sources = sources[np.array([s.keep for s in final_sources])]

for i, s in enumerate(_sources):
s.i = i
Expand Down Expand Up @@ -187,7 +198,9 @@ def run(self, image):
idxs = np.flatnonzero([r.euler_number == 1 for r in regions])
regions = [regions[i] for i in idxs]

sources = np.array([PointSource.from_region(region) for region in regions])
sources = Sources(
np.array([PointSource.from_region(region) for region in regions])
)
image.sources = Sources(self.clean(sources), source_type="PointSource")

@property
Expand Down
12 changes: 10 additions & 2 deletions tests/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ def test_psf_blocks(block):
block().run(image_psf)


def test_detection_min_separation():
@pytest.mark.parametrize("d", [10, 50, 80, 100])
def test_detection_min_separation(d):
from prose.blocks.detection import PointSourceDetection

PointSourceDetection(min_separation=10.0)(image)
PointSourceDetection(min_separation=d).run(image)

distances = np.linalg.norm(
image.sources.coords - image.sources.coords[:, None], axis=-1
)
distances = np.where(np.eye(distances.shape[0]).astype(bool), np.nan, distances)
distances = np.nanmin(distances, 0)
np.testing.assert_allclose(distances > d, True)


def test_Trim():
Expand Down

0 comments on commit 83f55e8

Please sign in to comment.