Skip to content

Commit

Permalink
Remove merge sources from deblender
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Apr 26, 2022
1 parent 6b85f9f commit a2e5dfe
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 149 deletions.
31 changes: 13 additions & 18 deletions python/lsst/meas/extensions/scarlet/scarletDeblendTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from lsst.utils.logging import PeriodicLogger
from lsst.utils.timer import timeMethod

from .source import bboxToScarletBox, modelToHeavy, liteModelToHeavy, mergeDeblendedSources
from .source import bboxToScarletBox, modelToHeavy, liteModelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
Expand Down Expand Up @@ -315,7 +315,6 @@ def deblend(mExposure, footprint, config):

# Create the blend and attempt to optimize it
blend = Blend(sources, observation)
blend.mask = mask
try:
blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
except ArithmeticError:
Expand Down Expand Up @@ -430,10 +429,9 @@ def deblend_lite(mExposure, footprint, config, wavelets=None):
if len(sources[k].components) > 0 and np.any(sources[k].center != center):
raise ValueError("Misaligned center, expected {center} but got {sources[k].center}")
# Store the record for the peak with the appropriate source
sources[k].detectedPeaks = [footprint.peaks[k]]
sources[k].detectedPeak = footprint.peaks[k]

blend = lite.LiteBlend(sources, observation)
blend.mask = mask

# Initialize each source with its best fit spectrum
# This significantly cuts down on the number of iterations
Expand Down Expand Up @@ -587,7 +585,7 @@ class ScarletDeblendConfig(pexConfig.Config):
maskLimits = pexConfig.DictField(
keytype=str,
itemtype=float,
default={"SAT": 0},
default={},
doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
"Sources violating this limit will not be deblended. "
"If the fraction is `0` then the limit is a single pixel."),
Expand Down Expand Up @@ -628,12 +626,6 @@ class ScarletDeblendConfig(pexConfig.Config):
"a high density of sources from running out of memory. "
"If `maxSpectrumCutoff == -1` then there is no cutoff.")
)
# Blend quality fields
mergeShredded = pexConfig.Field(
dtype=bool, default=True,
doc="Whether or not to merge sources together that are likely have been shredded or "
"deblended incorrectly."
)
# Failure modes
fallback = pexConfig.Field(
dtype=bool, default=True,
Expand Down Expand Up @@ -957,10 +949,8 @@ def deblend(self, mExposure, catalog):
tf = time.monotonic()
runtime = (tf-t0)*1000
converged = _checkBlendConvergence(blend, self.config.relativeError)
# Merge sources that are likely to be part of the same source,
# or are too close to be deblended properly
# Store the number of components in the blend
if self.config.version == "lite":
blend.sources = mergeDeblendedSources(blend, ~blend.mask, factor=0.5)
nComponents = len(blend.components)
else:
nComponents = 0
Expand Down Expand Up @@ -1129,7 +1119,7 @@ def _isLargeFootprint(self, footprint):
return True
return False

def _isMasked(self, footprint, mMask):
def _isMasked(self, footprint, mExposure):
"""Returns whether the footprint violates the mask limits
Parameters
Expand All @@ -1147,10 +1137,10 @@ def _isMasked(self, footprint, mMask):
`self.config.maskLimits`.
"""
bbox = footprint.getBBox()
mask = np.bitwise_or.reduce(mMask[:, bbox].array, axis=0)
mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
size = float(footprint.getArea())
for maskName, limit in self.config.maskLimits.items():
maskVal = mMask.getPlaneBitMask(maskName)
maskVal = mExposure.mask.getPlaneBitMask(maskName)
_mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
# spanset of masked pixels
maskedSpan = footprint.spans.intersect(_mask, maskVal)
Expand Down Expand Up @@ -1232,7 +1222,7 @@ def _isSkipped(self, parent, mExposure):
# The footprint is above the maximum footprint size limit
skipKey = self.tooBigKey
skipMessage = f"Parent {parent.getId()}: skipping large footprint"
elif self._isMasked(footprint, mExposure.mask):
elif self._isMasked(footprint, mExposure):
# The footprint exceeds the maximum number of masked pixels
skipKey = self.maskedKey
skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
Expand Down Expand Up @@ -1287,6 +1277,11 @@ def _updateParentRecord(self, parent, nPeaks, nChild, nComponents,
Number of children deblended from the parent.
This may differ from `nPeaks` if some of the peaks
were culled and have no deblended model.
nComponents : `int`
Total number of components in the parent.
This is usually different than the number of children,
since it is common for a single source to have multiple
components.
runtime : `float`
Total runtime for deblending.
iterations : `int`
Expand Down
136 changes: 5 additions & 131 deletions python/lsst/meas/extensions/scarlet/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
import logging

import numpy as np
from scarlet.bbox import Box, overlapped_slices
from scarlet.detect_pybind11 import get_footprints
from scarlet.detect import bounds_to_bbox
from scarlet.lite import LiteSource
from scarlet.bbox import Box

from lsst.geom import Point2I, Box2I, Extent2I
from lsst.afw.geom import SpanSet
Expand Down Expand Up @@ -87,134 +84,12 @@ def bboxToScarletBox(nBands, bbox, xy0=Point2I()):
return Box((nBands, bbox.getHeight(), bbox.getWidth()), origin)


def scarletFootprintsToArray(footprints, shape):
"""Convert scarlet footprints to an integer array.
Given a set of footprints, create an image array that
contains the index of each footprint as the value
for all of its pixels in the mask. Because the
footprints do not overlap (by definition), we can store the
footprint indices as integers as opposed to bit values,
which would limit the number of sources per blend.
Parameters
----------
footprints : `list` of `scarlet.Footprint`
The scarlet footprints to insert into
the array.
shape : `tuple` of `int`
The shape of the image that contains all
of the footprints. Typically this is the
shape of the blend.
"""
mask = np.zeros(shape, dtype=int)
for idx, fp in enumerate(footprints):
bbox = bounds_to_bbox(fp.bounds)
mask[bbox.slices] += fp.footprint * (idx+1)
return mask


def mergeDeblendedSources(blend, footprintArr, factor=0.5):
"""Merge sources that are contained in the same footprint
In order to prevent large galaxies from being shredded or
very tight blends (that are unlikely to be deblended correctly)
from making it into the output catalog,
sources with peaks contained in the same footprint above the
noise level of the observations are grouped together to
be listed in the catalog as a single object and flagged as
a compound source.
This works because the model exists in a partially
deconvolved image space where there is rarely any flux overlap
above the noise level between objects not physically interacting.
Parameters
----------
blend: `~scarlet.LiteBlend`
The blend containing the observations and models
for all of the sources.
footprintArr: `numpy.ndarray`
The array that contains the pixels contained in the
stack footprint for the entire blend.
factor: `float`
The factor to multiply the noise by in order to set the
detection threshold.
Returns
-------
new_sources: `list`
A list of peaks for each source in the output catalog.
Each element of `soures` is a list of indices that
gives the index for each source in `blend.sources`
that is to be merged into a single source model.
"""
# Get the deconvolved model
model = blend.get_model() * footprintArr

# Get the pixels above the noise
noise = np.max(blend.observation.noise_rms) * factor
template = model > noise
template = np.sum(template, axis=0) > len(blend.observation.images) - 1

# Get the merged footprints
footprints = get_footprints(np.sum(model, axis=0)*template, 1, 4, 0, False)
peakIndices = [[] for i in range(len(footprints))]
footprints = scarletFootprintsToArray(footprints, model.shape[1:])

# Combine peaks that are contained in the same merged footprint
for k, src in enumerate(blend.sources):
if src.is_null:
continue
idx = footprints[src.center]

if idx == 0:
peakIndices.append([k])
else:
peakIndices[idx-1].append(k)
sourceIndices = [peak for peak in peakIndices if len(peak) > 0]

# Order the sources from brightest to faintest
flux = [np.sum([np.sum(blend.sources[peak].get_model()) for peak in peaks]) for peaks in sourceIndices]
indices = np.argsort(flux)

# Create the merged sources by combining the appropriate components
newSources = []
for idx in indices:
peakIndicess = sourceIndices[idx]
components = []
boxes = []
detectedPeaks = []
for peak in peakIndicess:
src = blend.sources[peak]
components.extend(src.components)
if hasattr(src, "flux"):
boxes.append(src.flux_box)
detectedPeaks.extend(src.detectedPeaks)
src = LiteSource(components, src.dtype)
src.detectedPeaks = detectedPeaks
if len(boxes) > 0:
flux_box = boxes[0]
for bbox in boxes[1:]:
flux_box |= bbox
flux_img = np.zeros(flux_box.shape, dtype=model.dtype)
for peak in peakIndicess:
_src = blend.sources[peak]
slices = overlapped_slices(flux_box, _src.flux_box)
flux_img[slices[0]] += _src.flux
src.flux = flux_img
src.flux_box = flux_box

newSources.append(src)
return newSources


def modelToHeavy(source, mExposure, blend, xy0=Point2I(), dtype=np.float32):
"""Convert a scarlet model to a `MultibandFootprint`.
Parameters
----------
source : `scarlet.Component`
source : `scarlet.Source`
The source to convert to a `HeavyFootprint`.
mExposure : `lsst.image.MultibandExposure`
The multiband exposure containing the image,
Expand Down Expand Up @@ -278,7 +153,7 @@ def liteModelToHeavy(source, mExposure, blend, xy0=Point2I(), dtype=np.float32,
"""Convert a scarlet model to a `MultibandFootprint`.
Parameters
----------
source : `scarlet.Component`
source : `scarlet.Source`
The source to convert to a `HeavyFootprint`.
mExposure : `lsst.image.MultibandExposure`
The multiband exposure containing the image,
Expand Down Expand Up @@ -339,9 +214,8 @@ def liteModelToHeavy(source, mExposure, blend, xy0=Point2I(), dtype=np.float32,
spans = SpanSet.fromMask(valid)

# Add the location of the peaks to the peak catalog
peakCat = PeakCatalog(source.detectedPeaks[0].table)
for detectedPeak in source.detectedPeaks:
peakCat.append(detectedPeak)
peakCat = PeakCatalog(source.detectedPeak.table)
peakCat.append(source.detectedPeak)

# Create the MultibandHeavyFootprint
foot = Footprint(spans)
Expand Down

0 comments on commit a2e5dfe

Please sign in to comment.