Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-17426: Improve full-visit sky subtraction #74

Merged
merged 5 commits into from
Apr 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
148 changes: 146 additions & 2 deletions python/lsst/pipe/drivers/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import numpy
import itertools
from scipy.ndimage import gaussian_filter

import lsst.afw.math as afwMath
import lsst.afw.image as afwImage
import lsst.afw.geom as afwGeom
import lsst.afw.cameraGeom as afwCameraGeom
import lsst.meas.algorithms as measAlg
import lsst.afw.table as afwTable

from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField
from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField, ConfigurableField
from lsst.pipe.base import Task


Expand Down Expand Up @@ -471,6 +474,8 @@ class FocalPlaneBackgroundConfig(Config):
"NONE": "No background estimation is to be attempted",
},
)
doSmooth = Field(dtype=bool, default=False, doc="Do smoothing?")
smoothScale = Field(dtype=float, default=2.0, doc="Smoothing scale, as a multiple of the bin size")
binning = Field(dtype=int, default=64, doc="Binning to use for CCD background model (pixels)")


Expand Down Expand Up @@ -724,5 +729,144 @@ def getStatsImage(self):
values /= self._numbers
thresh = self.config.minFrac*self.config.xSize*self.config.ySize
isBad = self._numbers.getArray() < thresh
interpolateBadPixels(values.getArray(), isBad, self.config.interpolation)
if self.config.doSmooth:
array = values.getArray()
array[:] = smoothArray(array, isBad, self.config.smoothScale)
isBad = numpy.isnan(values.array)
if numpy.any(isBad):
interpolateBadPixels(values.getArray(), isBad, self.config.interpolation)
return values


class MaskObjectsConfig(Config):
"""Configuration for MaskObjectsTask"""
nIter = Field(dtype=int, default=3, doc="Number of iterations")
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask,
doc="Background subtraction")
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Source detection")
detectSigma = Field(dtype=float, default=5.0, doc="Detection threshold (standard deviations)")
doInterpolate = Field(dtype=bool, default=True, doc="Interpolate when removing objects?")
interpolate = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc="Interpolation")

def setDefaults(self):
self.detection.reEstimateBackground = False
self.detection.doTempLocalBackground = False
self.detection.doTempWideBackground = False
self.detection.thresholdValue = 2.5
self.subtractBackground.binSize = 1024
self.subtractBackground.useApprox = False
self.interpolate.binSize = 256
self.interpolate.useApprox = False

def validate(self):
if (self.detection.reEstimateBackground or
self.detection.doTempLocalBackground or
self.detection.doTempWideBackground):
raise RuntimeError("Incorrect settings for object masking: reEstimateBackground, "
"doTempLocalBackground and doTempWideBackground must be False")


class MaskObjectsTask(Task):
"""Iterative masking of objects on an Exposure

This task makes more exhaustive object mask by iteratively doing detection
and background-subtraction. The purpose of this task is to get true
background removing faint tails of large objects. This is useful to get a
clean sky estimate from relatively small number of visits.

We deliberately use the specified ``detectSigma`` instead of the PSF,
in order to better pick up the faint wings of objects.
"""
ConfigClass = MaskObjectsConfig

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())
self.makeSubtask("interpolate")
self.makeSubtask("subtractBackground")

def run(self, exposure, maskPlanes=None):
"""Mask objects on Exposure

Objects are found and removed.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure on which to mask objects.
maskPlanes : iterable of `str`, optional
List of mask planes to remove.
"""
self.findObjects(exposure)
self.removeObjects(exposure, maskPlanes)

def findObjects(self, exposure):
"""Iteratively find objects on an exposure

Objects are masked with the ``DETECTED`` mask plane.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure on which to mask objects.
"""
for _ in range(self.config.nIter):
bg = self.subtractBackground.run(exposure).background
self.detection.detectFootprints(exposure, sigma=self.config.detectSigma, clearMask=True)
exposure.maskedImage += bg.getImage()

def removeObjects(self, exposure, maskPlanes=None):
"""Remove objects from exposure

We interpolate over using a background model if ``doInterpolate`` is
set; otherwise we simply replace everything with the median.

Parameters
----------
exposure : `lsst.afw.image.Exposure`
Exposure on which to mask objects.
maskPlanes : iterable of `str`, optional
List of mask planes to remove. ``DETECTED`` will be added as well.
"""
image = exposure.image
mask = exposure.mask
maskVal = mask.getPlaneBitMask("DETECTED")
if maskPlanes is not None:
maskVal |= mask.getPlaneBitMask(maskPlanes)
isBad = mask.array & maskVal > 0

if self.config.doInterpolate:
smooth = self.interpolate.fitBackground(exposure.maskedImage)
replace = smooth.getImageF().array[isBad]
mask.array &= ~mask.getPlaneBitMask(["DETECTED"])
else:
replace = numpy.median(image.array[~isBad])
image.array[isBad] = replace


def smoothArray(array, bad, sigma):
"""Gaussian-smooth an array while ignoring bad pixels

It's not sufficient to set the bad pixels to zero, as then they're treated
as if they are zero, rather than being ignored altogether. We need to apply
a correction to that image that removes the effect of the bad pixels.

Parameters
----------
array : `numpy.ndarray` of floating-point
Array to smooth.
bad : `numpy.ndarray` of `bool`
Flag array indicating bad pixels.
sigma : `float`
Gaussian sigma.

Returns
-------
convolved : `numpy.ndarray`
Smoothed image.
"""
convolved = gaussian_filter(numpy.where(bad, 0.0, array), sigma, mode="constant", cval=0.0)
numerator = gaussian_filter(numpy.ones_like(array), sigma, mode="constant", cval=0.0)
denominator = gaussian_filter(numpy.where(bad, 0.0, 1.0), sigma, mode="constant", cval=0.0)
return convolved*numerator/denominator
33 changes: 6 additions & 27 deletions python/lsst/pipe/drivers/constructCalibs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, NODE
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import (SkyMeasurementTask, FocalPlaneBackground,
FocalPlaneBackgroundConfig, MaskObjectsTask)
from lsst.pipe.drivers.visualizeVisit import makeCameraImage

from .checksum import checksum
Expand Down Expand Up @@ -1171,10 +1172,9 @@ def processSingle(self, sensorRef):

class SkyConfig(CalibConfig):
"""Configuration for sky frame construction"""
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
detectSigma = Field(dtype=float, default=2.0, doc="Detection PSF gaussian sigma")
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask,
doc="Regular-scale background configuration, for object detection")
maskObjects = ConfigurableField(target=MaskObjectsTask,
doc="Configuration for masking objects aggressively")
largeScaleBackground = ConfigField(dtype=FocalPlaneBackgroundConfig,
doc="Large-scale background configuration")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
Expand All @@ -1201,8 +1201,7 @@ class SkyTask(CalibTask):

def __init__(self, *args, **kwargs):
CalibTask.__init__(self, *args, **kwargs)
self.makeSubtask("detection")
self.makeSubtask("subtractBackground")
self.makeSubtask("maskObjects")
self.makeSubtask("sky")

def scatterProcess(self, pool, ccdIdLists):
Expand Down Expand Up @@ -1286,27 +1285,7 @@ def processSingleBackground(self, dataRef):
return dataRef.get("postISRCCD")
exposure = CalibTask.processSingle(self, dataRef)

# Detect sources. Requires us to remove the background; we'll restore it later.
bgTemp = self.subtractBackground.run(exposure).background
footprints = self.detection.detectFootprints(exposure, sigma=self.config.detectSigma)
image = exposure.getMaskedImage()
if footprints.background is not None:
image += footprints.background.getImage()

# Mask high pixels
variance = image.getVariance()
noise = np.sqrt(np.median(variance.getArray()))
isHigh = image.getImage().getArray() > self.config.maskThresh*noise
image.getMask().getArray()[isHigh] |= image.getMask().getPlaneBitMask("DETECTED")

# Restore the background: it's what we want!
image += bgTemp.getImage()

# Set detected/bad pixels to background to ensure they don't corrupt the background
maskVal = image.getMask().getPlaneBitMask(self.config.mask)
isBad = image.getMask().getArray() & maskVal > 0
bgLevel = np.median(image.getImage().getArray()[~isBad])
image.getImage().getArray()[isBad] = bgLevel
self.maskObjects.run(exposure, self.config.mask)
dataRef.put(exposure, "postISRCCD")
return exposure

Expand Down
88 changes: 51 additions & 37 deletions python/lsst/pipe/drivers/skyCorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import lsst.afw.math as afwMath
import lsst.afw.image as afwImage
import lsst.afw.table as afwTable
import lsst.meas.algorithms as measAlg

from lsst.pipe.base import ArgumentParser, Struct
from lsst.pex.config import Config, Field, ConfigurableField, ConfigField
from lsst.ctrl.pool.pool import Pool
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import (SkyMeasurementTask, FocalPlaneBackground,
FocalPlaneBackgroundConfig, MaskObjectsTask)
import lsst.pipe.drivers.visualizeVisit as visualizeVisit

DEBUG = False # Debugging outputs?
Expand Down Expand Up @@ -38,21 +37,22 @@ def makeCameraImage(camera, exposures, filename=None, binning=8):
class SkyCorrectionConfig(Config):
"""Configuration for SkyCorrectionTask"""
bgModel = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Background model")
bgModel2 = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="2nd Background model")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
doDetection = Field(dtype=bool, default=True, doc="Detect sources (to find good sky)?")
detectSigma = Field(dtype=float, default=5.0, doc="Detection PSF gaussian sigma")
maskObjects = ConfigurableField(target=MaskObjectsTask, doc="Mask Objects")
doMaskObjects = Field(dtype=bool, default=True, doc="Mask objects to find good sky?")
doBgModel = Field(dtype=bool, default=True, doc="Do background model subtraction?")
doBgModel2 = Field(dtype=bool, default=True, doc="Do cleanup background model subtraction?")
doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?")
binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images")

def setDefaults(self):
Config.setDefaults(self)
self.detection.reEstimateBackground = False
self.detection.thresholdPolarity = "both"
self.detection.doTempLocalBackground = False
self.detection.thresholdType = "pixel_stdev"
self.detection.thresholdValue = 3.0
self.bgModel2.doSmooth = True
self.bgModel2.minFrac = 0.5
self.bgModel2.xSize = 256
self.bgModel2.ySize = 256
self.bgModel2.smoothScale = 1.0


class SkyCorrectionTask(BatchPoolTask):
Expand All @@ -62,9 +62,8 @@ class SkyCorrectionTask(BatchPoolTask):

def __init__(self, *args, **kwargs):
BatchPoolTask.__init__(self, *args, **kwargs)
self.makeSubtask("maskObjects")
self.makeSubtask("sky")
# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())

@classmethod
def _makeArgumentParser(cls, *args, **kwargs):
Expand Down Expand Up @@ -100,7 +99,10 @@ def runDataRef(self, expRef):
algorithms. We optionally apply:

1. A large-scale background model.
This step removes very-large-scale sky such as moonlight.
2. A sky frame.
3. A medium-scale background model.
This step removes residual sky (This is smooth on the focal plane).

Only the master node executes this method. The data is held on
the slave nodes, which do all the hard work.
Expand Down Expand Up @@ -131,21 +133,7 @@ def runDataRef(self, expRef):
makeCameraImage(camera, exposures, "mask" + extension)

if self.config.doBgModel:
bgModel = FocalPlaneBackground.fromCamera(self.config.bgModel, camera)
data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList]
bgModelList = pool.mapToPrevious(self.accumulateModel, data)
for ii, bg in enumerate(bgModelList):
self.log.info("Background %d: %d pixels", ii, bg._numbers.getArray().sum())
bgModel.merge(bg)

if DEBUG:
bgModel.getStatsImage().writeFits("bgModel" + extension)
bgImages = pool.mapToPrevious(self.realiseModel, dataIdList, bgModel)
makeCameraImage(camera, bgImages, "bgModelCamera" + extension)

exposures = pool.mapToPrevious(self.subtractModel, dataIdList, bgModel)
if DEBUG:
makeCameraImage(camera, exposures, "modelsub" + extension)
exposures = self.focalPlaneBackground(camera, pool, dataIdList, self.config.bgModel)

if self.config.doSky:
measScales = pool.mapToPrevious(self.measureSkyFrame, dataIdList)
Expand All @@ -157,12 +145,44 @@ def runDataRef(self, expRef):
calibs = pool.mapToPrevious(self.collectSky, dataIdList)
makeCameraImage(camera, calibs, "sky" + extension)

if self.config.doBgModel2:
exposures = self.focalPlaneBackground(camera, pool, dataIdList, self.config.bgModel2)

# Persist camera-level image of calexp
image = makeCameraImage(camera, exposures)
expRef.put(image, "calexp_camera")

pool.mapToPrevious(self.write, dataIdList)

def focalPlaneBackground(self, camera, pool, dataIdList, config):
"""Perform full focal-plane background subtraction

This method runs on the master node.

Parameters
----------
camera : `lsst.afw.cameraGeom.Camera`
Camera description.
pool : `lsst.ctrl.pool.Pool`
Process pool.
dataIdList : iterable of `dict`
List of data identifiers for the CCDs.
config : `lsst.pipe.drivers.background.FocalPlaneBackgroundConfig`
Configuration to use for background subtraction.

Returns
-------
exposures : `list` of `lsst.afw.image.Image`
List of binned images, for creating focal plane image.
"""
bgModel = FocalPlaneBackground.fromCamera(config, camera)
data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList]
bgModelList = pool.mapToPrevious(self.accumulateModel, data)
for ii, bg in enumerate(bgModelList):
self.log.info("Background %d: %d pixels", ii, bg._numbers.array.sum())
bgModel.merge(bg)
return pool.mapToPrevious(self.subtractModel, dataIdList, bgModel)

def loadImage(self, cache, dataId):
"""Load original image and restore the sky

Expand All @@ -185,15 +205,6 @@ def loadImage(self, cache, dataId):
bgOld = cache.butler.get("calexpBackground", dataId, immediate=True)
image = cache.exposure.getMaskedImage()

if self.config.doDetection:
# We deliberately use the specified 'detectSigma' instead of the PSF, in order to better pick up
# the faint wings of objects.
results = self.detection.detectFootprints(cache.exposure, doSmooth=True,
sigma=self.config.detectSigma, clearMask=True)
if hasattr(results, "background") and results.background:
# Restore any background that was removed during detection
image += results.background.getImage()

# We're removing the old background, so change the sense of all its components
for bgData in bgOld:
statsImage = bgData[0].getStatsImage()
Expand All @@ -204,6 +215,9 @@ def loadImage(self, cache, dataId):
for bgData in bgOld:
cache.bgList.append(bgData)

if self.config.doMaskObjects:
self.maskObjects.findObjects(cache.exposure)

return self.collect(cache)

def measureSkyFrame(self, cache, dataId):
Expand Down