Skip to content

Commit

Permalink
Merge pull request #306 from lsst/tickets/DM-43404
Browse files Browse the repository at this point in the history
DM-43404: Create diffim QA metrics
  • Loading branch information
isullivan committed Apr 5, 2024
2 parents 36da64c + fa2b451 commit 4770a20
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 11 deletions.
211 changes: 202 additions & 9 deletions python/lsst/ip/diffim/detectAndMeasure.py
Expand Up @@ -24,12 +24,15 @@
import lsst.afw.detection as afwDetection
import lsst.afw.table as afwTable
import lsst.daf.base as dafBase
import lsst.geom
from lsst.ip.diffim.utils import getPsfFwhm, angleMean
from lsst.meas.algorithms import SkyObjectsTask, SourceDetectionTask, SetPrimaryFlagsTask
from lsst.meas.base import ForcedMeasurementTask, ApplyApCorrTask, DetectorVisitIdGeneratorConfig
import lsst.meas.deblender
import lsst.meas.extensions.trailedSources # noqa: F401
import lsst.meas.extensions.shapeHSM
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
import lsst.utils
from lsst.utils.timer import timeMethod
Expand Down Expand Up @@ -80,6 +83,17 @@ class DetectAndMeasureConnections(pipeBase.PipelineTaskConnections,
storageClass="ExposureF",
name="{fakesType}{coaddName}Diff_differenceExp",
)
spatiallySampledMetrics = pipeBase.connectionTypes.Output(
doc="Summary metrics computed at randomized locations.",
dimensions=("instrument", "visit", "detector"),
storageClass="ArrowAstropy",
name="{fakesType}{coaddName}Diff_spatiallySampledMetrics",
)

def __init__(self, *, config=None):
super().__init__(config=config)
if not config.doWriteMetrics:
self.outputs.remove("spatiallySampledMetrics")


class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
Expand Down Expand Up @@ -156,8 +170,27 @@ class DetectAndMeasureConfig(pipeBase.PipelineTaskConfig,
"base_PixelFlags_flag_interpolatedCenterAll",
"base_PixelFlags_flag_badCenterAll",
"base_PixelFlags_flag_edgeCenterAll",
"base_PixelFlags_flag_saturatedCenterAll",
),
)
metricsMaskPlanes = lsst.pex.config.ListField(
dtype=str,
doc="List of mask planes to include in metrics",
default=('BAD', 'CLIPPED', 'CR', 'DETECTED', 'DETECTED_NEGATIVE', 'EDGE',
'INEXACT_PSF', 'INJECTED', 'INJECTED_TEMPLATE', 'INTRP', 'NOT_DEBLENDED',
'NO_DATA', 'REJECTED', 'SAT', 'SAT_TEMPLATE', 'SENSOR_EDGE', 'STREAK', 'SUSPECT',
'UNMASKEDNAN',
),
)
metricSources = pexConfig.ConfigurableField(
target=SkyObjectsTask,
doc="Generate QA metric sources",
)
doWriteMetrics = lsst.pex.config.Field(
dtype=bool,
default=True,
doc="Compute and write summary metrics."
)
idGenerator = DetectorVisitIdGeneratorConfig.make_field()

def setDefaults(self):
Expand Down Expand Up @@ -192,6 +225,7 @@ def setDefaults(self):
self.measurement.plugins["base_PixelFlags"].masksFpCenter = [
"STREAK", "INJECTED", "INJECTED_TEMPLATE"]
self.skySources.avoidMask = ["DETECTED", "DETECTED_NEGATIVE", "BAD", "NO_DATA", "EDGE"]
self.metricSources.avoidMask = ["NO_DATA", "EDGE"]


class DetectAndMeasureTask(lsst.pipe.base.PipelineTask):
Expand Down Expand Up @@ -247,6 +281,62 @@ def __init__(self, **kwargs):
for flag in self.config.badSourceFlags:
if flag not in self.schema:
raise pipeBase.InvalidQuantumError("Field %s not in schema" % flag)

if self.config.doWriteMetrics:
self.makeSubtask("metricSources")
self.metricSchema = afwTable.SourceTable.makeMinimalSchema()
self.metricSchema.addField(
"x", "F",
"X location of the metric evaluation.",
units="pixel")
self.metricSchema.addField(
"y", "F",
"Y location of the metric evaluation.",
units="pixel")
self.metricSources.skySourceKey = self.metricSchema.addField("sky_source", type="Flag",
doc="Metric evaluation objects.")
self.metricSchema.addField(
"source_density", "F",
"Density of diaSources at location.",
units="count/degree^2")
self.metricSchema.addField(
"dipole_density", "F",
"Density of dipoles at location.",
units="count/degree^2")
self.metricSchema.addField(
"dipole_direction", "F",
"Mean dipole orientation.",
units="radian")
self.metricSchema.addField(
"dipole_separation", "F",
"Mean dipole separation.",
units="pixel")
self.metricSchema.addField(
"template_value", "F",
"Median of template at location.",
units="nJy")
self.metricSchema.addField(
"science_value", "F",
"Median of science at location.",
units="nJy")
self.metricSchema.addField(
"diffim_value", "F",
"Median of diffim at location.",
units="nJy")
self.metricSchema.addField(
"science_psfSize", "F",
"Width of the science image PSF at location.",
units="pixel")
self.metricSchema.addField(
"template_psfSize", "F",
"Width of the template image PSF at location.",
units="pixel")
for maskPlane in self.config.metricsMaskPlanes:
self.metricSchema.addField(
"%s_mask_fraction"%maskPlane.lower(), "F",
"Fraction of pixels with %s mask"%maskPlane
)

# initialize InitOutputs
self.outputSchema = afwTable.SourceCatalog(self.schema)
self.outputSchema.getTable().setMetadata(self.algMetadata)
Expand Down Expand Up @@ -386,11 +476,14 @@ def processResults(self, science, matchedTemplate, difference, sources, idFactor
if self.config.doForcedMeasurement:
self.measureForcedSources(diaSources, science, difference.getWcs())

spatiallySampledMetrics = self.calculateMetrics(difference, diaSources, science, matchedTemplate,
idFactory)

measurementResults = pipeBase.Struct(
subtractedMeasuredExposure=difference,
diaSources=diaSources,
spatiallySampledMetrics=spatiallySampledMetrics,
)
self.calculateMetrics(difference)

return measurementResults

Expand Down Expand Up @@ -443,7 +536,7 @@ def deblend(footprints):
return sources, makeFootprints(positives), makeFootprints(negatives)

def _removeBadSources(self, diaSources):
"""Remove bad diaSources from the catalog.
"""Remove unphysical diaSources from the catalog.
Parameters
----------
Expand All @@ -456,19 +549,20 @@ def _removeBadSources(self, diaSources):
The updated catalog of detected sources, with any source that has a
flag in ``config.badSourceFlags`` set removed.
"""
nBadTotal = 0
selector = np.ones(len(diaSources), dtype=bool)
for flag in self.config.badSourceFlags:
flags = diaSources[flag]
nBad = np.count_nonzero(flags)
if nBad > 0:
self.log.info("Found and removed %d unphysical sources with flag %s.", nBad, flag)
self.log.debug("Found %d unphysical sources with flag %s.", nBad, flag)
selector &= ~flags
nBadTotal += nBad
nBadTotal = np.count_nonzero(~selector)
self.metadata.add("nRemovedBadFlaggedSources", nBadTotal)
self.log.info("Removed %d unphysical sources.", nBadTotal)
return diaSources[selector].copy(deep=True)

def addSkySources(self, diaSources, mask, seed):
def addSkySources(self, diaSources, mask, seed,
subtask=None):
"""Add sources in empty regions of the difference image
for measuring the background.
Expand All @@ -481,8 +575,10 @@ def addSkySources(self, diaSources, mask, seed):
seed : `int`
Seed value to initialize the random number generator.
"""
skySourceFootprints = self.skySources.run(mask=mask, seed=seed, catalog=diaSources)
self.metadata.add("nSkySources", len(skySourceFootprints))
if subtask is None:
subtask = self.skySources
skySourceFootprints = subtask.run(mask=mask, seed=seed, catalog=diaSources)
self.metadata.add(f"n_{subtask.getName()}", len(skySourceFootprints))

def measureDiaSources(self, diaSources, science, difference, matchedTemplate):
"""Use (matched) template and science image to constrain dipole fitting.
Expand Down Expand Up @@ -547,13 +643,28 @@ def measureForcedSources(self, diaSources, science, wcs):
for diaSource, forcedSource in zip(diaSources, forcedSources):
diaSource.assign(forcedSource, mapper)

def calculateMetrics(self, difference):
def calculateMetrics(self, difference, diaSources, science, matchedTemplate, idFactory):
"""Add image QA metrics to the Task metadata.
Parameters
----------
difference : `lsst.afw.image.Exposure`
The target image to calculate metrics for.
diaSources : `lsst.afw.table.SourceCatalog`
The catalog of detected sources.
science : `lsst.afw.image.Exposure`
The science image.
matchedTemplate : `lsst.afw.image.Exposure`
The reference image, warped and psf-matched to the science image.
idFactory : `lsst.afw.table.IdFactory`, optional
Generator object used to assign ids to detected sources in the
difference image.
Returns
-------
spatiallySampledMetrics : `lsst.afw.table.SourceCatalog`, or `None`
A catalog of randomized locations containing locally evaluated
metric results
"""
mask = difference.mask
badPix = (mask.array & mask.getPlaneBitMask(self.config.detection.excludeMaskPlanes)) > 0
Expand All @@ -567,6 +678,83 @@ def calculateMetrics(self, difference):
detNegPix &= badPix
self.metadata.add("nBadPixelsDetectedPositive", np.sum(detPosPix))
self.metadata.add("nBadPixelsDetectedNegative", np.sum(detNegPix))
metricsMaskPlanes = []
for maskPlane in self.config.metricsMaskPlanes:
try:
self.metadata.add("%s_mask_fraction"%maskPlane.lower(), evaluateMaskFraction(mask, maskPlane))
metricsMaskPlanes.append(maskPlane)
except InvalidParameterError:
self.metadata.add("%s_mask_fraction"%maskPlane.lower(), -1)
self.log.info("Unable to calculate metrics for mask plane %s: not in image"%maskPlane)

if self.config.doWriteMetrics:
spatiallySampledMetrics = afwTable.SourceCatalog(self.metricSchema)
spatiallySampledMetrics.getTable().setIdFactory(idFactory)
self.addSkySources(spatiallySampledMetrics, science.mask, difference.info.id,
subtask=self.metricSources)
for src in spatiallySampledMetrics:
self._evaluateLocalMetric(src, diaSources, science, matchedTemplate, difference,
metricsMaskPlanes=metricsMaskPlanes)

return spatiallySampledMetrics.asAstropy()

def _evaluateLocalMetric(self, src, diaSources, science, matchedTemplate, difference,
metricsMaskPlanes):
"""Calculate image quality metrics at spatially sampled locations.
Parameters
----------
src : `lsst.afw.table.SourceRecord`
The source record to be updated with metric calculations.
diaSources : `lsst.afw.table.SourceCatalog`
The catalog of detected sources.
science : `lsst.afw.image.Exposure`
The science image.
matchedTemplate : `lsst.afw.image.Exposure`
The reference image, warped and psf-matched to the science image.
difference : `lsst.afw.image.Exposure`
Result of subtracting template from the science image.
metricsMaskPlanes : `list` of `str`
Mask planes to calculate metrics from.
"""
bbox = src.getFootprint().getBBox()
pix = bbox.getCenter()
src.set('science_psfSize', getPsfFwhm(science.psf, position=pix))
src.set('template_psfSize', getPsfFwhm(matchedTemplate.psf, position=pix))

metricRegionSize = 100
bbox.grow(metricRegionSize)
bbox = bbox.clippedTo(science.getBBox())
nPix = bbox.getArea()
pixScale = science.wcs.getPixelScale()
area = nPix*pixScale.asDegrees()**2
peak = src.getFootprint().getPeaks()[0]
src.set('x', peak['i_x'])
src.set('y', peak['i_y'])
src.setCoord(science.wcs.pixelToSky(peak['i_x'], peak['i_y']))
selectSources = diaSources[bbox.contains(diaSources.getX(), diaSources.getY())]
if self.config.doSkySources:
selectSources = selectSources[~selectSources["sky_source"]]
sourceDensity = len(selectSources)/area
dipoleSources = selectSources[selectSources["ip_diffim_DipoleFit_flag_classification"]]
dipoleDensity = len(dipoleSources)/area
if dipoleSources:
meanDipoleOrientation = angleMean(dipoleSources["ip_diffim_DipoleFit_orientation"])
src.set('dipole_direction', meanDipoleOrientation)
meanDipoleSeparation = np.mean(dipoleSources["ip_diffim_DipoleFit_separation"])
src.set('dipole_separation', meanDipoleSeparation)
templateVal = np.median(matchedTemplate[bbox].image.array)
scienceVal = np.median(science[bbox].image.array)
diffimVal = np.median(difference[bbox].image.array)
src.set('source_density', sourceDensity)
src.set('dipole_density', dipoleDensity)
src.set('template_value', templateVal)
src.set('science_value', scienceVal)
src.set('diffim_value', diffimVal)
for maskPlane in metricsMaskPlanes:
src.set("%s_mask_fraction"%maskPlane.lower(),
evaluateMaskFraction(difference.mask[bbox], maskPlane)
)


class DetectAndMeasureScoreConnections(DetectAndMeasureConnections):
Expand Down Expand Up @@ -653,3 +841,8 @@ def run(self, science, matchedTemplate, difference, scoreExposure,

return self.processResults(science, matchedTemplate, difference, sources, idFactory,
positiveFootprints=positives, negativeFootprints=negatives)


def evaluateMaskFraction(mask, maskPlane):
nMaskSet = np.count_nonzero((mask.array & mask.getPlaneBitMask(maskPlane)))
return nMaskSet/mask.array.size
17 changes: 17 additions & 0 deletions python/lsst/ip/diffim/utils.py
Expand Up @@ -1323,3 +1323,20 @@ def computePSFNoiseEquivalentArea(psf):
psfImg = psf.computeImage(psf.getAveragePosition())
nea = 1./np.sum(psfImg.array**2)
return nea


def angleMean(angles):
"""Calculate the mean of an array of angles.
Parameters
----------
angles : `ndarray`
An array of angles, in degrees
Returns
-------
`lsst.geom.Angle`
The mean angle
"""
complexArray = [complex(np.cos(np.deg2rad(angle)), np.sin(np.deg2rad(angle))) for angle in angles]
return (geom.Angle(np.angle(np.mean(complexArray))))
5 changes: 3 additions & 2 deletions tests/test_detectAndMeasure.py
Expand Up @@ -93,7 +93,7 @@ def _check_values(self, values, minValue=None, maxValue=None):
if maxValue is not None:
self.assertTrue(np.all(values <= maxValue))

def _setup_detection(self, doSkySources=False, nSkySources=5, **kwargs):
def _setup_detection(self, doSkySources=False, nSkySources=5, doWriteMetrics=False, **kwargs):
"""Setup and configure the detection and measurement PipelineTask.
Parameters
Expand All @@ -114,6 +114,7 @@ def _setup_detection(self, doSkySources=False, nSkySources=5, **kwargs):
config.doSkySources = doSkySources
if doSkySources:
config.skySources.nSources = nSkySources
config.doWriteMetrics = doWriteMetrics
config.update(**kwargs)

# Make a realistic id generator so that output catalog ids are useful.
Expand Down Expand Up @@ -739,7 +740,7 @@ def test_sky_sources(self):
# Run detection and check the results
output = detectionTask.run(science, matchedTemplate, difference, score,
idFactory=self.idGenerator.make_table_id_factory())
nSkySourcesGenerated = detectionTask.metadata["nSkySources"]
nSkySourcesGenerated = detectionTask.metadata["n_skySources"]
skySources = output.diaSources[output.diaSources["sky_source"]]
self.assertEqual(len(skySources), nSkySourcesGenerated)
for skySource in skySources:
Expand Down

0 comments on commit 4770a20

Please sign in to comment.