Skip to content

Commit

Permalink
Make measureAndNormalize a BrightStarStamp method
Browse files Browse the repository at this point in the history
  • Loading branch information
MorganSchmitz committed Jan 26, 2021
1 parent 620bf3f commit 1e39c11
Showing 1 changed file with 38 additions and 74 deletions.
112 changes: 38 additions & 74 deletions python/lsst/pipe/tasks/processBrightStars.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,17 @@

import numpy as np
import astropy.units as u
from operator import ior
from functools import reduce

from lsst import geom
from lsst.afw import math as afwMath
from lsst.afw import image as afwImage
from lsst.afw import geom as afwGeom
from lsst.afw import cameraGeom as cg
from lsst.afw.geom import transformFactory as tFactory
import lsst.pex.config as pexConfig
from lsst.pipe import base as pipeBase
from lsst.pipe.base import connectionTypes as cT
from lsst.meas.algorithms.loadIndexedReferenceObjects import LoadIndexedReferenceObjectsTask
from lsst.meas.algorithms import ReferenceObjectLoader
from lsst.meas.algorithms import brightStarStamps as bSS


Expand All @@ -50,6 +48,14 @@ class ProcessBrightStarsConnections(pipeBase.PipelineTaskConnections, dimensions
storageClass="ExposureF",
dimensions=("visit", "detector")
)
refCat = cT.PrerequisiteInput(
doc="Reference catalog that contains bright star positions",
name="gaia_dr2_20200414",
storageClass="SimpleCatalog",
dimensions=("skypix",),
multiple=True,
deferLoad=True
)
brightStarStamps = cT.Output(
doc="Set of preprocessed postage stamps, each centered on a single bright star.",
name="brightStarStamps",
Expand Down Expand Up @@ -198,6 +204,8 @@ def extractStamps(self, inputExposure, refObjLoader=None):
refObjLoader = self.refObjLoader
starIms = []
pixCenters = []
GMags = []
ids = []
wcs = inputExposure.getWcs()
# select stars within input exposure from refcat
withinCalexp = refObjLoader.loadPixelBox(inputExposure.getBBox(), wcs, filterName="phot_g_mean")
Expand All @@ -207,10 +215,10 @@ def extractStamps(self, inputExposure, refObjLoader=None):
GFluxes = np.array(refCat['phot_g_mean_flux'])
bright = GFluxes > fluxLimit
# convert to AB magnitudes
GMags = [((gFlux*u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]]
ids = refCat.columns.extract("id", where=bright)["id"]
allGMags = [((gFlux*u.nJy).to(u.ABmag)).to_value() for gFlux in GFluxes[bright]]
allIds = refCat.columns.extract("id", where=bright)["id"]
selectedColumns = refCat.columns.extract('coord_ra', 'coord_dec', where=bright)
for ra, dec in zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"]):
for j, (ra, dec) in enumerate(zip(selectedColumns["coord_ra"], selectedColumns["coord_dec"])):
sp = geom.SpherePoint(ra, dec, geom.radians)
cpix = wcs.skyToPixel(sp)
# TODO: DM-25894 keep objects on or slightly beyond CCD edge
Expand All @@ -220,6 +228,8 @@ def extractStamps(self, inputExposure, refObjLoader=None):
and cpix[1] < inputExposure.getDimensions()[1] - self.config.stampSize[1]/2):
starIms.append(inputExposure.getCutout(sp, geom.Extent2I(self.config.stampSize)))
pixCenters.append(cpix)
GMags.append(allGMags[j])
ids.append(allIds[j])
return pipeBase.Struct(starIms=starIms,
pixCenters=pixCenters,
GMags=GMags,
Expand Down Expand Up @@ -298,57 +308,6 @@ def warpStamps(self, stamps, pixCenters):
warpedStars.append(destImage.clone())
return warpedStars

def measureAndNormalize(self, warpedStamps):
"""Compute "annularFlux", the integrated flux within an annulus
around each object's center, and normalize them.
Since the center of bright stars are saturated and/or heavily affected
by ghosts, we measure their flux in an annulus with a large enough
inner radius to avoid the most severe ghosts and contain enough
non-saturated pixels.
Parameters
----------
warpedStamps : `collections.abc.Sequence`
[`afwImage.exposure.exposure.ExposureF`]
Image cutouts centered on a single object and warped to the same
arbirtary grid.
Returns
-------
annularFluxes : `list` [`float`]
"""
innerRadius, outerRadius = self.config.annularFluxRadii
# Create SpanSet of annulus
outerCircle = afwGeom.SpanSet.fromShape(outerRadius, afwGeom.Stencil.CIRCLE, offset=self.modelCenter)
innerCircle = afwGeom.SpanSet.fromShape(innerRadius, afwGeom.Stencil.CIRCLE, offset=self.modelCenter)
annulus = outerCircle.intersectNot(innerCircle)
# annularFlux statistic set-up, excluding mask planes
statsControl = afwMath.StatisticsControl()
statsControl.setNumSigmaClip(self.config.numSigmaClip)
statsControl.setNumIter(self.config.numIter)
annularFluxes = []
for image in warpedStamps:
# create image with the same pixel values within annulus, NO_DATA
# elsewhere
maskPlaneDict = image.getMask().getMaskPlaneDict()
annulusImage = afwImage.MaskedImageF(image.getDimensions(), planeDict=maskPlaneDict)
annulusMask = annulusImage.mask
annulusMask.array[:] = maskPlaneDict['NO_DATA']
annulus.copyMaskedImage(image, annulusImage)
# set mask planes to be ignored
badMasks = self.config.badMaskPlanes
andMask = reduce(ior, (annulusMask.getPlaneBitMask(bm) for bm in badMasks))
statsControl.setAndMask(andMask)
# compute annularFlux
statsFlags = afwMath.stringToStatisticsProperty(self.config.annularFluxStatistic)
annulusStat = afwMath.makeStatistics(annulusImage, statsFlags, statsControl)
annularFlux = annulusStat.getValue()
annularFluxes.append(annularFlux)
# normalize stamps
image.image.array /= annularFlux
return annularFluxes

@pipeBase.timeMethod
def run(self, inputExposure, refObjLoader=None, dataId=None):
"""Identify bright stars within an exposure using a reference catalog,
Expand All @@ -362,7 +321,7 @@ def run(self, inputExposure, refObjLoader=None, dataId=None):
The image from which bright star stamps should be extracted.
refObjLoader : `LoadIndexedReferenceObjectsTask`, optional
Loader to find objects within a reference catalog.
dataId : `dict`
dataId : `dict` or `lsst.daf.butler.DataCoordinate`
The dataId of the exposure (and detector) bright stars should be
extracted from.
Expand All @@ -380,16 +339,26 @@ def run(self, inputExposure, refObjLoader=None, dataId=None):
self.log.info("Applying warp to %i star stamps from exposure %s",
len(extractedStamps.starIms), dataId)
warpedStars = self.warpStamps(extractedStamps.starIms, extractedStamps.pixCenters)
brightStarList = [bSS.BrightStarStamp(stamp_im=warp,
gaiaGMag=extractedStamps.GMags[j],
gaiaId=extractedStamps.gaiaIds[j])
for j, warp in enumerate(warpedStars)]
# Compute annularFlux and normalize
self.log.info("Computing annular flux and normalizing %i bright stars from exposure %s",
len(warpedStars), dataId)
fluxes = self.measureAndNormalize(warpedStars)
brightStarList = [bSS.BrightStarStamp(starStamp=warp,
gaiaGMag=extractedStamps.GMags[j],
gaiaId=extractedStamps.gaiaIds[j],
annularFlux=fluxes[j])
for j, warp in enumerate(warpedStars)]
brightStarStamps = bSS.BrightStarStamps(brightStarList, *self.config.annularFluxRadii)
# annularFlux statistic set-up, excluding mask planes
statsControl = afwMath.StatisticsControl()
statsControl.setNumSigmaClip(self.config.numSigmaClip)
statsControl.setNumIter(self.config.numIter)
innerRadius, outerRadius = self.config.annularFluxRadii
statsFlag = afwMath.stringToStatisticsProperty(self.config.annularFluxStatistic)
brightStarStamps = bSS.BrightStarStamps.initAndNormalize(brightStarList,
innerRadius=innerRadius,
outerRadius=outerRadius,
imCenter=self.modelCenter,
statsControl=statsControl,
statsFlag=statsFlag,
badMaskPlanes=self.config.badMaskPlanes)
return pipeBase.Struct(brightStarStamps=brightStarStamps)

def runDataRef(self, dataRef):
Expand All @@ -410,14 +379,9 @@ def runDataRef(self, dataRef):
def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
inputs['dataId'] = str(butlerQC.quantum.dataId)
# TODO (DM-27262): remove workaround and load refcat in gen3
self.log.info("Gaia refcat is not yet available in gen3; as a temporary fix, "
"reading it in from a gen2 butler")
from lsst.meas.algorithms.loadIndexedReferenceObjects import LoadIndexedReferenceObjectsTask
from lsst.daf.persistence import Butler
refcatConfig = LoadIndexedReferenceObjectsTask.ConfigClass()
refcatConfig.ref_dataset_name = 'gaia_dr2_20200414'
gen2butler = Butler('/datasets/hsc/repo/rerun/RC/w_2020_03/DM-23121_obj/')
refObjLoader = LoadIndexedReferenceObjectsTask(gen2butler, config=refcatConfig)
refObjLoader = ReferenceObjectLoader(dataIds=[ref.datasetRef.dataId
for ref in inputRefs.refCat],
refCats=inputs.pop("refCat"),
config=self.config.refObjLoader)
output = self.run(**inputs, refObjLoader=refObjLoader)
butlerQC.put(output, outputRefs)

0 comments on commit 1e39c11

Please sign in to comment.