Skip to content

Commit

Permalink
Final touches before the speed-up
Browse files Browse the repository at this point in the history
  • Loading branch information
enourbakhsh committed Dec 11, 2023
1 parent 14efe64 commit cecf0d3
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 13 deletions.
36 changes: 25 additions & 11 deletions python/lsst/meas/extensions/shapeHSM/_hsm_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,18 @@ def measure(self, record, exposure):
psfSigma = exposure.getPsf().computeShape(center).getTraceRadius()

# Turn bounding box corners into GalSim bounds.
# ipdb> psfImage.getBBox()
# Box2I(corner=Point2I(0, 0), dimensions=Extent2I(51, 51))
# ipdb> bbox
# Box2I(corner=Point2I(5678, 9876), dimensions=Extent2I(1290, 1274))
xmin, xmax = bbox.getMinX(), bbox.getMaxX()
ymin, ymax = bbox.getMinY(), bbox.getMaxY()
bounds = galsim.bounds.BoundsI(xmin, xmax, ymin, ymax)

# Each GalSim image below will match whatever dtype the input array is.
# NOTE: PSF is already restricted to a small image, so no bounds for
# the PSF is expected.
image = galsim.Image(exposure.image[bbox].array, bounds=bounds, copy=False)
psf = galsim.Image(psfImage.array, copy=False)
# FIXME: no bounds in the line above? apparently no bounds in the C++ version

# Get the `lsst.meas.base` mask for bad pixels.
subMask = exposure.mask[bbox]
# FIXME: is the above PARENT? we want subMask(*afwMask, bbox, afw::image::PARENT);
# where afwMask=exposure.mask
badpix = subMask.array.copy() # Copy it since badpix gets modified.
bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
badpix &= bitValue
Expand All @@ -211,10 +206,7 @@ def measure(self, record, exposure):
# (here int32).
badpix = galsim.Image(badpix, bounds=bounds, copy=False)

# FIXME: dummyMask not used later on?! (it was in C++!)
# dummyMask = galsim.ImageI(bounds=bounds)
# dummyMask.image <- 1 # use something to make all pixel values 1

# Get the statistics control object for sky variance estimation.
sctrl = afwMath.StatisticsControl()
sctrl.setAndMask(bitValue)

Expand Down Expand Up @@ -285,6 +277,11 @@ def setDefaults(self):
super().setDefaults()
self.shearType = "BJ"

def validate(self):
if self.shearType != "BJ":
raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'BJ'.")
super().validate()


@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
Expand All @@ -302,6 +299,11 @@ def setDefaults(self):
super().setDefaults()
self.shearType = "LINEAR"

def validate(self):
if self.shearType != "LINEAR":
raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'LINEAR'.")
super().validate()


@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
Expand All @@ -319,6 +321,11 @@ def setDefaults(self):
super().setDefaults()
self.shearType = "KSB"

def validate(self):
if self.shearType != "KSB":
raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'KSB'.")
super().validate()


@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
Expand All @@ -336,6 +343,13 @@ def setDefaults(self):
super().setDefaults()
self.shearType = "REGAUSS"

def validate(self):
if self.shearType != "REGAUSS":
raise pexConfig.FieldValidationError(
self.shearType, self, "shearType should be set to 'REGAUSS'."
)
super().validate()


@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
Expand Down
37 changes: 35 additions & 2 deletions tests/test_hsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,31 @@ def runMeasurement(self, algorithmName, imageid, x, y, v):
source.setFootprint(afwDetection.Footprint(afwGeom.SpanSet(exposure.getBBox(afwImage.PARENT))))
plugin.measure(source, exposure)

return source
# Get the trace radius of the PSF and GalSim images to use in the
# EstimateShear call.
psfSigma = exposure.getPsf().computeShape(center).getTraceRadius()
bbox = source.getFootprint().getBBox()
bounds = galsim.bounds.BoundsI(bbox.getMinX(), bbox.getMaxX(), bbox.getMinY(), bbox.getMaxY())
image = galsim.Image(exposure.image[bbox].array, bounds=bounds, copy=False)
psf = galsim.Image(psfImg.array, copy=False)

# Retrieve the measurement type that Galsim outputs after estimation.
# NOTE: not passing weight, badpix, and sky_var, as the objective here
# is solely to deduce the meas_type for this setup.
postEstimationMeasType = galsim.hsm.EstimateShear(
gal_image=image,
PSF_image=psf,
shear_est=control.shearType,
recompute_flux="FIT",
guess_sig_gal=2.5 * psfSigma,
guess_sig_PSF=psfSigma,
precision=1.0e-6,
guess_centroid=galsim.PositionD(center.getX(), center.getY()),
strict=True,
hsmparams=None,
).meas_type

return source, alg.measTypeSymbol, postEstimationMeasType

def testHsmShape(self):
"""Test that we can instantiate and play with a measureShape"""
Expand All @@ -417,7 +441,16 @@ def testHsmShape(self):
enumerate(file_indices)):
algorithmName = "ext_shapeHSM_HsmShape" + algName[0:1].upper() + algName[1:].lower()

source = self.runMeasurement(algorithmName, imageid, x_centroid[i], y_centroid[i], sky_var[i])
source, preEstimationMeasType, postEstimationMeasType = self.runMeasurement(
algorithmName, imageid, x_centroid[i], y_centroid[i], sky_var[i]
)

# Check consistency with GalSim output
self.assertEqual(
preEstimationMeasType,
postEstimationMeasType,
"The plugin setup is incompatible with GalSim output.",
)

##########################################
# see how we did
Expand Down

0 comments on commit cecf0d3

Please sign in to comment.