Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tickets/DM-27208 #28

Merged
merged 8 commits into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
188 changes: 101 additions & 87 deletions python/lsst/meas/extensions/scarlet/scarletDeblendTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from functools import partial

import numpy as np
import scarlet
from scarlet.psf import PSF, gaussian
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet_extensions.initialization.source import initAllSources
from scarlet.initialization import initAllSources

import lsst.log
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.math as afwMath
import lsst.afw.geom as afwGeom
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image.utils
Expand Down Expand Up @@ -84,19 +81,6 @@ def _getPsfFwhm(psf):
return psf.computeShape().getDeterminantRadius() * 2.35


def _estimateRMS(exposure, statsMask):
    """Estimate the standard dev. of an image

    Compute the RMS of ``exposure`` as ``sqrt(mean**2 + stdev**2)`` of its
    variance plane, ignoring pixels flagged by the ``statsMask`` mask planes.
    """
    maskedImage = exposure.getMaskedImage()
    ctrl = afwMath.StatisticsControl()
    # Exclude pixels whose mask bits match any of the requested planes.
    ctrl.setAndMask(maskedImage.getMask().getPlaneBitMask(statsMask))
    stats = afwMath.makeStatistics(
        maskedImage.variance, maskedImage.mask, afwMath.STDEV | afwMath.MEAN, ctrl)
    mean = stats.getValue(afwMath.MEAN)
    stdev = stats.getValue(afwMath.STDEV)
    return np.sqrt(mean**2 + stdev**2)


def _computePsfImage(self, position=None):
"""Get a multiband PSF image
The PSF Kernel Image is computed for each band
Expand Down Expand Up @@ -197,15 +181,17 @@ def deblend(mExposure, footprint, config):
weights = 1/mExposure.variance[:, bbox].array
else:
weights = np.ones_like(images)
badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
mask = mExposure.mask[:, bbox].array & badPixels
weights[mask > 0] = 0

# Mask out the pixels outside the footprint
mask = getFootprintMask(footprint, mExposure)
weights *= ~mask

psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)

psfShape = (config.modelPsfSize, config.modelPsfSize)
model_psf = PSF(partial(gaussian, sigma=config.modelPsfSigma), shape=(None,)+psfShape)
psfs = ImagePSF(psfs)
model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

frame = Frame(images.shape, psfs=model_psf, channels=mExposure.filters)
observation = Observation(images, psfs=psfs, weights=weights, channels=mExposure.filters)
Expand All @@ -218,6 +204,9 @@ def deblend(mExposure, footprint, config):
maxComponents = 1
elif config.sourceModel == "double":
maxComponents = 2
elif config.sourceModel == "compact":
raise NotImplementedError("CompactSource initialization has not yet been ported"
"to the stack version of scarlet")
elif config.sourceModel == "point":
maxComponents = 0
elif config.sourceModel == "fit":
Expand Down Expand Up @@ -247,8 +236,14 @@ def deblend(mExposure, footprint, config):
)

# Attach the peak to all of the initialized sources
for k, src in enumerate(sources):
src.detectedPeak = footprint.peaks[k]
srcIndex = 0
for k, center in enumerate(centers):
if k not in skipped:
# This is just to make sure that there isn't a coding bug
assert np.all(sources[srcIndex].center == center)
# Store the record for the peak with the appropriate source
sources[srcIndex].detectedPeak = footprint.peaks[k]
srcIndex += 1

# Create the blend and attempt to optimize it
blend = Blend(sources, observation)
Expand Down Expand Up @@ -313,14 +308,16 @@ class ScarletDeblendConfig(pexConfig.Config):
dtype=bool, default=True,
doc="Whether or not to save the SEDs and templates")
processSingles = pexConfig.Field(
dtype=bool, default=False,
dtype=bool, default=True,
doc="Whether or not to process isolated sources in the deblender")
sourceModel = pexConfig.Field(
dtype=str, default="single",
doc=("How to determine which model to use for sources, from\n"
"- 'single': use a single component for all sources\n"
"- 'double': use a bulge disk model for all sources\n"
"- 'point: use a point-source model for all sources\n"
"- 'compact': use a single component model, initialzed with a point source morphology, "
" for all sources\n"
"- 'point': use a point-source model for all sources\n"
"- 'fit: use a PSF fitting model to determine the number of components (not yet implemented)")
)
downgrade = pexConfig.Field(
Expand Down Expand Up @@ -369,11 +366,9 @@ class ScarletDeblendConfig(pexConfig.Config):
dtype=str, default="NOT_DEBLENDED", optional=True,
doc="Mask name for footprints not deblended, or None")
catchFailures = pexConfig.Field(
dtype=bool, default=False,
dtype=bool, default=True,
doc=("If True, catch exceptions thrown by the deblender, log them, "
"and set a flag on the parent, instead of letting them propagate up"))
propagateAllPeaks = pexConfig.Field(dtype=bool, default=False,
doc=('Guarantee that all peaks produce a child source.'))


class ScarletDeblendTask(pipeBase.Task):
Expand Down Expand Up @@ -464,24 +459,33 @@ def _addSchemaKeys(self, schema):
doc='Name of error if the blend failed')
self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
doc="Deblender skipped this source")
self.modelCenter = afwTable.Point2DKey.addFields(schema, name="deblend_peak_center",
doc="Center used to apply constraints in scarlet",
unit="pixel")
self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
doc="Center used to apply constraints in scarlet",
unit="pixel")
self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
doc="ID of the peak in the parent footprint. "
"This is not unique, but the combination of 'parent'"
"and 'peakId' should be for all child sources. "
"Top level blends with no parents have 'peakId=0'")
self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
doc="The instFlux at the peak position of deblended mode")
self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=20,
doc="The type of model used, for example "
"MultiExtendedSource, SingleExtendedSource, PointSource")
self.edgeFluxFlagKey = schema.addField("deblend_edgeFluxFlag", type="Flag",
doc="Source has flux on the edge of the image")
self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
doc="Flux measurement from scarlet")
self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
doc="Number of initial peaks in the blend. "
"This includes peaks that may have been culled "
"during deblending or failed to deblend")
self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
doc="Same as deblend_n_peaks, but the number of peaks "
"in the parent footprint")
self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
doc="Flux measurement from scarlet")
self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
doc="Final logL, used to identify regressions in scarlet.")

# self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
# (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
# )
Expand Down Expand Up @@ -592,10 +596,16 @@ def deblend(self, mExposure, sources):
self._skipParent(src, mask)
self.log.trace('Parent %i: skipping masked footprint', int(src.getId()))
continue
if len(peaks) > self.config.maxNumberOfPeaks:
if self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't matter much, but you could store this value before the loop rather than checking it on every iteration, or simply set maxNumberOfPeaks to np.inf when it is zero.

src.set(self.tooManyPeaksKey, True)
msg = 'Parent {0}: Too many peaks, using the first {1} peaks'
self.log.trace(msg.format(int(src.getId()), self.config.maxNumberOfPeaks))
self._skipParent(src, mExposure.mask)
msg = 'Parent {0}: Too many peaks, skipping blend'
self.log.trace(msg.format(int(src.getId())))
# Unlike meas_deblender, in scarlet we skip the entire blend
# if the number of peaks exceeds max peaks, since neglecting
# to model any peaks often results in catastrophic failure
# of scarlet to generate models for the brighter sources.
continue

nparents += 1
self.log.trace('Parent %i: deblending %i peaks', int(src.getId()), len(peaks))
Expand Down Expand Up @@ -624,7 +634,7 @@ def deblend(self, mExposure, sources):
src.set(self.iterKey, e.iterations)
elif not isinstance(e, IncompleteDataError):
blendError = "UnknownError"

self._skipParent(src, mExposure.mask)
if self.config.catchFailures:
# Make it easy to find UnknownErrors in the log file
self.log.warn("UnknownError")
Expand All @@ -635,13 +645,8 @@ def deblend(self, mExposure, sources):

self.log.warn("Unable to deblend source %d: %s" % (src.getId(), blendError))
src.set(self.deblendFailedKey, True)
src.set(self.runtimeKey, 0)
src.set(self.deblendErrorKey, blendError)
bbox = foot.getBBox()
src.set(self.modelCenter, Point2D(bbox.getMinX(), bbox.getMinY()))
# We want to store the total number of initial peaks,
# even if some of them fail
src.set(self.nPeaksKey, len(foot.peaks))
self._skipParent(src, mExposure.mask)
continue

# Add the merged source as a parent in the catalog for each band
Expand All @@ -652,18 +657,7 @@ def deblend(self, mExposure, sources):
templateParents[f].set(self.nPeaksKey, len(foot.peaks))
templateParents[f].set(self.runtimeKey, runtime)
templateParents[f].set(self.iterKey, len(blend.loss))
# TODO: When DM-26603 is merged observation has a "log_norm"
# property that performs the following calculation,
# so this code block can be removed
observation = blend.observations[0]
_weights = observation.weights
_images = observation.images
log_sigma = np.zeros(_weights.shape, dtype=_weights.dtype)
cuts = _weights > 0
log_sigma[cuts] = np.log(1/_weights[cuts])
log_norm = np.prod(_images.shape)/2 * np.log(2*np.pi)+np.sum(log_sigma)/2
# end temporary code block
logL = blend.loss[-1]-log_norm
logL = blend.loss[-1]-blend.observations[0].log_norm
templateParents[f].set(self.scarletLogLKey, logL)

# Add each source to the catalogs in each band
Expand All @@ -673,21 +667,8 @@ def deblend(self, mExposure, sources):
# Skip any sources with no flux or that scarlet skipped because
# it could not initialize
if k in skipped:
if not self.config.propagateAllPeaks:
# We don't care
continue
# We need to preserve the peak: make sure we have enough
# info to create a minimal child src
msg = "Peak at {0} failed deblending. Using minimal default info for child."
self.log.trace(msg.format(src.getFootprint().peaks[k]))
# copy the full footprint and strip out extra peaks
foot = afwDet.Footprint(src.getFootprint())
peakList = foot.getPeaks()
peakList.clear()
peakList.append(src.peaks[k])
zeroMimg = afwImage.MaskedImageF(foot.getBBox())
heavy = afwDet.makeHeavyFootprint(foot, zeroMimg)
models = afwDet.MultibandFootprint(mExposure.filters, [heavy]*len(mExposure.filters))
# No need to propagate anything
continue
else:
src.set(self.deblendSkippedKey, False)
models = modelToHeavy(source, filters, xy0=bbox.getMin(),
Expand Down Expand Up @@ -773,13 +754,30 @@ def _skipParent(self, source, masks):
"""
fp = source.getFootprint()
source.set(self.deblendSkippedKey, True)
source.set(self.nChildKey, len(fp.getPeaks())) # It would have this many if we deblended them all
if self.config.notDeblendedMask:
for mask in masks:
mask.addMaskPlane(self.config.notDeblendedMask)
fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

def _addChild(self, parentId, sources, heavy, scarlet_source, blend_converged, xy0, flux):
# The deblender didn't run on this source, so it has zero runtime
source.set(self.runtimeKey, 0)
# Set the center of the parent
bbox = fp.getBBox()
centerX = int(bbox.getMinX()+bbox.getWidth()/2)
centerY = int(bbox.getMinY()+bbox.getHeight()/2)
source.set(self.peakCenter, Point2I(centerX, centerY))
# There are no deblended children, so nChild = 0
source.set(self.nChildKey, 0)
# But we also want to know how many peaks that we would have
# deblended if the parent wasn't skipped.
source.set(self.nPeaksKey, len(fp.peaks))
# The blend was skipped, so it didn't take any iterations
source.set(self.iterKey, 0)
# Top level parents are not a detected peak, so they have no peakId
source.set(self.peakIdKey, 0)
# Top level parents also have no parentNPeaks
source.set(self.parentNPeaksKey, 0)

def _addChild(self, parentId, sources, heavy, scarletSource, blend_converged, xy0, flux):
"""Add a child to a catalog

This creates a new child in the source catalog,
Expand All @@ -792,29 +790,45 @@ def _addChild(self, parentId, sources, heavy, scarlet_source, blend_converged, x
src.assign(heavy.getPeaks()[0], self.peakSchemaMapper)
src.setParent(parentId)
src.setFootprint(heavy)
src.set(self.psfKey, False)
# Set the psf key based on whether or not the source was
# deblended using the PointSource model.
# This key is not that useful anymore since we now keep track of
# `modelType`, but we continue to propagate it in case code downstream
# is expecting it.
src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
src.set(self.runtimeKey, 0)
src.set(self.blendConvergenceFailedFlagKey, not blend_converged)

# Set the position of the peak from the parent footprint
# This will make it easier to match the same source across
# deblenders and across observations, where the peak
# position is unlikely to change unless enough time passes
# for a source to move on the sky.
peak = scarletSource.detectedPeak
src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
src.set(self.peakIdKey, peak["id"])

# The children have a single peak
src.set(self.nPeaksKey, 1)

# Store the flux at the center of the model and the total
# scarlet flux measurement.
morph = afwDet.multiband.heavyFootprintToImage(heavy).image.array

# Set the flux at the center of the model (for SNR)
try:
cy, cx = scarlet_source.center
cy, cx = scarletSource.center
cy = np.max([np.min([int(np.round(cy)), morph.shape[0]-1]), 0])
cx = np.max([np.min([int(np.round(cx)), morph.shape[1]-1]), 0])
src.set(self.modelCenterFlux, morph[cy, cx])
except AttributeError:
msg = "Did not recognize coordinates for source type of `{0}`, "
msg += "could not write coordinates or center flux. "
msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
logger.warning(msg.format(type(scarlet_source)))
return src
xmin, ymin = xy0
src.set(self.modelCenter, Point2D(cx+xmin, cy+ymin))
logger.warning(msg.format(type(scarletSource)))

# Store the flux at the center of the model and the total
# scarlet flux measurement.
morph = afwDet.multiband.heavyFootprintToImage(heavy).image.array
cy = np.max([np.min([int(np.round(cy)), morph.shape[0]-1]), 0])
cx = np.max([np.min([int(np.round(cx)), morph.shape[1]-1]), 0])
src.set(self.modelCenterFlux, morph[cy, cx])
src.set(self.modelTypeKey, scarlet_source.__class__.__name__)
src.set(self.edgeFluxFlagKey, scarlet_source.isEdge)
src.set(self.modelTypeKey, scarletSource.__class__.__name__)
src.set(self.edgeFluxFlagKey, scarletSource.isEdge)
# Include the source flux in the model space in the catalog.
# This uses the narrower model PSF, which ensures that all sources
# not located on an edge have all of their flux included in the
Expand Down
5 changes: 3 additions & 2 deletions python/lsst/meas/extensions/scarlet/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np
from scarlet.bbox import Box
from scarlet.frame import Frame

from lsst.geom import Point2I
import lsst.log
Expand Down Expand Up @@ -64,15 +65,15 @@ def modelToHeavy(source, filters, xy0=Point2I(), observation=None, dtype=np.floa
# to take the intersection of two boxes.

# Get the PSF size and radii to grow the box
py, px = observation.frame.psf.shape[1:]
py, px = observation.frame.psf.get_model().shape[1:]
dh = py // 2
dw = px // 2
shape = (source.bbox.shape[0], source.bbox.shape[1] + py, source.bbox.shape[2] + px)
origin = (source.bbox.origin[0], source.bbox.origin[1] - dh, source.bbox.origin[2] - dw)
# Create the larger box to fit the model + PSf
bbox = Box(shape, origin=origin)
# Only use the portion of the convolved model that fits in the image
overlap = bbox & source.frame.bbox
overlap = Frame(bbox & source.frame.bbox, source.frame.channels, psfs=source.frame.psf)
# Load the full multiband model in the larger box
model = source.model_to_frame(overlap)
# Convolve the model with the PSF in each band
Expand Down