Implement scarlet lite
fred3m committed Feb 9, 2022
1 parent a1fafd7 commit 4ef56e3
Showing 3 changed files with 319 additions and 18 deletions.
257 changes: 241 additions & 16 deletions python/lsst/meas/extensions/scarlet/scarletDeblendTask.py
@@ -19,13 +19,16 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from functools import partial
import logging
import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.detect import get_detect_wavelets
from scarlet.initialization import init_all_sources
from scarlet import lite

import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
@@ -37,7 +40,7 @@
import lsst.afw.table as afwTable
from lsst.utils.timer import timeMethod

from .source import bboxToScarletBox, modelToHeavy, liteModelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
@@ -328,6 +331,126 @@ def deblend(mExposure, footprint, config):
return blend, skipped, spectrumInit


def deblend_lite(mExposure, footprint, config, wavelets=None):
"""Deblend a parent footprint using the lite version of scarlet

Parameters
----------
mExposure : `lsst.afw.image.MultibandExposure`
- The multiband exposure containing the image,
mask, and variance data
footprint : `lsst.afw.detection.Footprint`
- The footprint of the parent to deblend
config : `ScarletDeblendConfig`
- Configuration of the deblending task
wavelets : `numpy.ndarray`, optional
- Pre-computed wavelet coefficients for the full exposure,
only used when `morphImage` is "wavelet"

Returns
-------
blend : `scarlet.lite.LiteBlend`
- The blend that was fit
skipped : `list` of `scarlet.lite.LiteSource`
- The sources that could not be initialized and were skipped
spectrumInit : `bool`
- Whether or not the spectra were initialized with least squares
"""
# Extract coordinates from each MultiColorPeak
bbox = footprint.getBBox()

# Create the data array from the masked images
images = mExposure.image[:, bbox].array
variance = mExposure.variance[:, bbox].array

# Use the inverse variance as the weights
if config.useWeights:
weights = 1/mExposure.variance[:, bbox].array
else:
weights = np.ones_like(images)
badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
mask = mExposure.mask[:, bbox].array & badPixels
weights[mask > 0] = 0
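# Note: pixels with zero variance would produce infinite weights above;
# such pixels are assumed to be flagged in the bad pixel mask and
# zeroed out here.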

# Mask out the pixels outside the footprint
mask = getFootprintMask(footprint, mExposure)
weights *= ~mask

psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
modelPsf = lite.integrated_circular_gaussian(sigma=config.modelPsfSigma)
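# `psfs` is the effective PSF in each band, evaluated at the parent's
# position, while `modelPsf` is the narrower PSF of the partially
# deconvolved model space; the observation matches the model to the data
# by convolving with the difference kernel between the two.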

observation = lite.LiteObservation(
images=images,
variance=variance,
weights=weights,
psfs=psfs,
model_psf=modelPsf[None, :, :],
convolution_mode=config.convolutionType,
)

# Convert the centers to pixel coordinates
xmin = bbox.getMinX()
ymin = bbox.getMinY()
peaks = [
peak for peak in footprint.peaks
if not isPseudoSource(peak, config.pseudoColumns)
]
centers = [
np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
for peak in peaks
]
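# Note that scarlet expects (y, x) coordinates relative to the bbox
# origin, and that pseudo sources (such as sky objects) are not deblended.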

# Initialize the sources
if config.morphImage == "chi2":
sources = lite.init_all_sources_main(
observation,
centers,
min_snr=config.minSNR,
thresh=config.morphThresh,
)
elif config.morphImage == "wavelet":
_bbox = bboxToScarletBox(len(mExposure.filters), bbox, bbox.getMin())
_wavelets = wavelets[(slice(None), *_bbox[1:].slices)]
sources = lite.init_all_sources_wavelets(
observation,
centers,
use_psf=False,
wavelets=_wavelets,
min_snr=config.minSNR,
)
else:
raise ValueError("morphImage must be either 'chi2' or 'wavelet'.")

# Set the optimizer
if config.optimizer == "adaprox":
parameterization = partial(
lite.init_adaprox_component,
bg_thresh=config.backgroundThresh,
max_prox_iter=config.maxProxIter,
)
elif config.optimizer == "fista":
parameterization = partial(
lite.init_fista_component,
bg_thresh=config.backgroundThresh,
)
else:
raise ValueError("Unrecognized optimizer. Must be either 'adaprox' or 'fista'.")
sources = lite.parameterize_sources(sources, observation, parameterization)
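# `partial` pre-binds the config options so that `parameterize_sources`
# can call each initializer with a uniform signature when it attaches
# optimizer state to each source's components.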

# Attach the peak to each of the initialized sources
for k, center in enumerate(centers):
# This is just to make sure that there isn't a coding bug
if len(sources[k].components) > 0 and np.any(sources[k].center != center):
raise ValueError(f"Misaligned center, expected {center} but got {sources[k].center}")
# Store the record for the peak with the appropriate source
sources[k].detectedPeak = peaks[k]

blend = lite.LiteBlend(sources, observation)

# Initialize each source with its best fit spectrum.
# This significantly cuts down on the number of iterations
# that the optimizer needs and usually results in a better fit,
# but using least squares on a very large blend causes memory issues.
# This is typically the most expensive operation in deblending, memory-wise.
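# The `maxSpectrumCutoff` test below scales as (number of sources) x
# (pixels per band in the parent bounding box), which tracks the size
# of the least squares problem.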
spectrumInit = False
if config.setSpectra:
if config.maxSpectrumCutoff <= 0 or len(centers) * bbox.getArea() < config.maxSpectrumCutoff:
spectrumInit = True
blend.fit_spectra()

# Set the sources that could not be initialized and were skipped
skipped = [src for src in sources if src.is_null]

blend.fit(max_iter=config.maxIter, e_rel=config.relativeError, min_iter=config.minIter)

return blend, skipped, spectrumInit
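
# A minimal usage sketch for the function above (hypothetical caller code,
# not part of this module; `mExposure`, `footprint`, and `config` are
# assumed to be constructed elsewhere):
#
#     blend, skipped, spectrumInit = deblend_lite(mExposure, footprint, config)
#     for src in blend.sources:
#         if src.is_null:
#             continue  # scarlet could not initialize this source
#         model = src.get_model()  # (bands, height, width) model of the source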


class ScarletDeblendConfig(pexConfig.Config):
"""MultibandDeblendConfig
@@ -339,16 +462,72 @@ class ScarletDeblendConfig(pexConfig.Config):
- Other: Parameters that don't fit into the above categories
"""
# Stopping Criteria
minIter = pexConfig.Field(dtype=int, default=1,
doc="Minimum number of iterations before the optimizer is allowed to stop.")
maxIter = pexConfig.Field(dtype=int, default=300,
doc=("Maximum number of iterations to deblend a single parent"))
relativeError = pexConfig.Field(dtype=float, default=1e-2,
doc=("Change in the loss function between iterations to exit fitter. "
"Typically this is `1e-2` if measurements will be made on the "
"flux re-distributed models and `1e-4` when making measurements "
"on the models themselves."))

# Constraints
morphThresh = pexConfig.Field(dtype=float, default=1,
doc="Fraction of background RMS a pixel must have "
"to be included in the initial morphology")
# Lite Parameters
# All of these parameters (except version) are only valid if version='lite'
version = pexConfig.ChoiceField(
dtype=str,
default="lite",
allowed={
"scarlet": "main scarlet version (likely to be deprecated soon)",
"lite": "Optimized version of scarlet for survey data from a single instrument",
},
doc="The version of scarlet to use.",
)
optimizer = pexConfig.ChoiceField(
dtype=str,
default="adaprox",
allowed={
"adaprox": "Proximal ADAM optimization",
"fista": "Accelerated proximal gradient method",
},
doc="The optimizer to use for fitting parameters and is only used when version='lite'",
)
morphImage = pexConfig.ChoiceField(
dtype=str,
default="chi2",
allowed={
"chi2": "Initialize sources on a chi^2 image made from all available bands",
"wavelet": "Initialize sources using a wavelet decomposition of the chi^2 image",
},
doc="The type of image to use for initializing the morphology. "
"Must be either 'chi2' or 'wavelet'. "
)
backgroundThresh = pexConfig.Field(
dtype=float,
default=0.25,
doc="Fraction of background to use for a sparsity threshold. "
"This prevents sources from growing unrealistically outside "
"the parent footprint while still modeling flux correctly "
"for bright sources."
)
maxProxIter = pexConfig.Field(
dtype=int,
default=1,
doc="Maximum number of proximal operator iterations inside of each "
"iteration of the optimizer. "
"This config field is only used if version='lite' and optimizer='adaprox'."
)
waveletScales = pexConfig.Field(
dtype=int,
default=5,
doc="Number of wavelet scales to use for wavelet initialization. "
"This field is only used when `version`='lite' and `morphImage`='wavelet'."
)
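
# A hypothetical override of the lite-specific fields above (values are
# illustrative only, not recommendations):
#
#     config = ScarletDeblendConfig()
#     config.version = "lite"
#     config.optimizer = "adaprox"
#     config.backgroundThresh = 0.5
#     config.morphImage = "wavelet"
#     config.waveletScales = 4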

# Other scarlet parameters
useWeights = pexConfig.Field(
dtype=bool, default=True,
@@ -384,7 +563,8 @@ class ScarletDeblendConfig(pexConfig.Config):
"- 'compact': use a single component model, initialzed with a point source morphology, "
" for all sources\n"
"- 'point': use a point-source model for all sources\n"
"- 'fit: use a PSF fitting model to determine the number of components (not yet implemented)")
"- 'fit: use a PSF fitting model to determine the number of components (not yet implemented)"),
deprecated="This field will be deprecated when the default for `version` is changed to `lite`.",
)
setSpectra = pexConfig.Field(
dtype=bool, default=True,
@@ -695,6 +875,14 @@ def deblend(self, mExposure, catalog):
self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
nextLogTime = time.time() + self.config.loggingInterval

# Create a set of wavelet coefficients if using wavelet initialization
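# The coefficients are generated once for the full exposure and sliced
# to each parent footprint inside `deblend_lite`.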
if self.config.version == "lite" and self.config.morphImage == "wavelet":
images = mExposure.image.array
variance = mExposure.variance.array
wavelets = get_detect_wavelets(images, variance, scales=self.config.waveletScales)
else:
wavelets = None

# Add the NOT_DEBLENDED mask to the mask plane in each band
if self.config.notDeblendedMask:
for mask in mExposure.mask:
@@ -708,6 +896,11 @@ def deblend(self, mExposure, catalog):
"fluxes": [],
"centerFluxes": [],
}
weightedColumns = {
"heavies": [],
"fluxes": [],
"centerFluxes": [],
}
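# The columns above mirror `multibandColumns` but hold the flux
# re-distributed (conserved flux) models; they are only populated
# when version='lite'.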
for parentIndex in range(nParents):
parent = catalog[parentIndex]
foot = parent.getFootprint()
@@ -771,17 +964,16 @@ def deblend(self, mExposure, catalog):
try:
t0 = time.time()
# Build the parameter lists with the same ordering
if self.config.version == "scarlet":
blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
elif self.config.version == "lite":
blend, skipped, spectrumInit = deblend_lite(mExposure, foot, self.config, wavelets)
tf = time.time()
runtime = (tf-t0)*1000
converged = _checkBlendConvergence(blend, self.config.relativeError)

scarletSources = [src for src in blend.sources]
nChild = len(scarletSources)
if self.config.version == "scarlet":
# Re-insert place holders for skipped sources
# to propagate them in the catalog so
# that the peaks stay consistent
for k in skipped:
scarletSources.insert(k, None)
# Catch all errors and filter out the ones that we know about
except Exception as e:
blendError = type(e).__name__
@@ -807,7 +999,10 @@ def deblend(self, mExposure, catalog):
continue

# Update the parent record with the deblending results
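# Main scarlet stores the loss, so its log likelihood is recovered by
# negating the loss and adding back the observation's log normalization;
# scarlet lite tracks the log likelihood directly.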
if self.config.version == "scarlet":
logL = -blend.loss[-1]+blend.observations[0].log_norm
elif self.config.version == "lite":
logL = blend.loss[-1]
self._updateParentRecord(
parent=parent,
nPeaks=len(peaks),
@@ -823,11 +1018,24 @@ def deblend(self, mExposure, catalog):
for k, scarletSource in enumerate(scarletSources):
# Skip any sources with no flux or that scarlet skipped because
# they could not be initialized
if k in skipped or (self.config.version == "lite" and scarletSource.is_null):
# No need to propagate anything
continue
parent.set(self.deblendSkippedKey, False)
if self.config.version == "lite":
mHeavy = liteModelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
weightedHeavy = liteModelToHeavy(
scarletSource, mExposure, blend, xy0=bbox.getMin(), useFlux=True)
weightedColumns["heavies"].append(weightedHeavy)
flux = scarletSource.get_model(use_flux=True).sum(axis=(1, 2))
weightedColumns["fluxes"].append({
filters[fidx]: _flux
for fidx, _flux in enumerate(flux)
})
centerFlux = self._getCenterFlux(weightedHeavy, scarletSource, xy0=bbox.getMin())
weightedColumns["centerFluxes"].append(centerFlux)
else:
mHeavy = modelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
multibandColumns["heavies"].append(mHeavy)
flux = scarlet.measure.flux(scarletSource)
multibandColumns["fluxes"].append({
@@ -862,14 +1070,15 @@ def deblend(self, mExposure, catalog):
msg = f"Added {len(catalog)-nParents} new sources, but have "
msg += ", ".join([
f"{len(value)} {key}"
for key, value in multibandColumns.items()
])
raise RuntimeError(msg)
# Make a copy of the catalog in each band and update the footprints
catalogs = {}
for f in filters:
_catalog = afwTable.SourceCatalog(catalog.table.clone())
_catalog.extend(catalog, deep=True)

# Update the footprints and columns that are different
# for each filter
for sourceIndex, source in enumerate(_catalog[nParents:]):
@@ -878,6 +1087,22 @@ def deblend(self, mExposure, catalog):
source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
catalogs[f] = _catalog

weightedCatalogs = {}
if self.config.version == "lite":
# Also create a catalog by reweighting the flux
for f in filters:
_catalog = afwTable.SourceCatalog(catalog.table.clone())
_catalog.extend(catalog, deep=True)

# Update the footprints and columns that are different
# for each filter
for sourceIndex, source in enumerate(_catalog[nParents:]):
source.setFootprint(weightedColumns["heavies"][sourceIndex][f])
source.set(self.scarletFluxKey, weightedColumns["fluxes"][sourceIndex][f])
source.set(self.modelCenterFlux, weightedColumns["centerFluxes"][sourceIndex][f])
weightedCatalogs[f] = _catalog

# Update the mExposure mask with the footprint of skipped parents
if self.config.notDeblendedMask:
for mask in mExposure.mask:
Expand All @@ -888,7 +1113,7 @@ def deblend(self, mExposure, catalog):
self.log.info("Deblender results: of %d parent sources, %d were deblended, "
"creating %d children, for a total of %d sources",
nParents, nDeblendedParents, nChildren, len(catalog))
return catalogs, weightedCatalogs

def _isLargeFootprint(self, footprint):
"""Returns whether a Footprint is large
