Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-43515: Evaluate PSFs at the center of cells #48

Merged
merged 8 commits into from
Apr 24, 2024
122 changes: 71 additions & 51 deletions python/lsst/drp/tasks/assemble_cell_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import lsst.afw.image as afwImage
import lsst.afw.math as afwMath
import lsst.geom as geom
import numpy as np
from lsst.cell_coadds import (
CellIdentifiers,
Expand All @@ -41,12 +42,11 @@
SingleCellCoadd,
UniformGrid,
)
from lsst.meas.algorithms import AccumulatorMeanStack, CoaddPsf, CoaddPsfConfig
from lsst.meas.algorithms import AccumulatorMeanStack
from lsst.pex.config import ConfigField, ConfigurableField, Field, ListField, RangeField
from lsst.pipe.base import NoWorkFound, PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
from lsst.pipe.base.connectionTypes import Input, Output
from lsst.pipe.tasks.coaddBase import makeSkyInfo
from lsst.pipe.tasks.coaddInputRecorder import CoaddInputRecorderTask
from lsst.pipe.tasks.interpImage import InterpImageTask
from lsst.pipe.tasks.scaleZeroPoint import ScaleZeroPointTask
from lsst.skymap import BaseSkyMap
Expand Down Expand Up @@ -113,14 +113,15 @@ class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCe
inclusiveMin=True,
inclusiveMax=False,
)
# The following config options are specific to the CoaddPsf.
coadd_psf = ConfigField(
doc="Configuration for CoaddPsf",
dtype=CoaddPsfConfig,
psf_warper = ConfigField(
doc="Configuration for the warper that warps the PSFs. It must have the same configuration used to "
"warp the images.",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will have to enforce this via a contract in the pipeline, since I don't think there's any other way to enforce consistency across tasks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that's the best approach for now, but I think it's also a nudge in the direction of moving PSF warping down to MakeDirectWarpTask (by having that warp yet another image plane that would for now have to be yet another separate dataset type). I don't think we should consider doing that on this ticket, but between configuration consistency, free parallelism over visits, and ruling out differences in how the WCS is interpolated during warping, I think that's the direction we want to go eventually.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that to be the natural progression, given that we've moved PSF evaluation from measurement to coaddition and the next natural step is to move it to make_direct_warp.py. We spoke about this, but my thought here has been to use MultipleCellCoadd (we might need to subclass it, more on that later) with no buffer between cells to be the data structure to hold warps. It has the capability to hold multiple noise image planes which are used by cell-based coadds as well. The PSF images can be evaluated at the cell centers at the time of warping.

The reason why we might need another class - or make a generic parent class to subclass from - is that we now have to allow for multiple detectors within each cell (we don't want to discard them just yet). This needs some minor changes to how we serialize metadata info.

But to reiterate again,

differences in how the WCS is interpolated during warping

is not an issue because we evaluate the PSF image on the calexp and warp them here.

Copy link
Member

@TallJimbo TallJimbo Apr 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But to reiterate again,

differences in how the WCS is interpolated during warping

is not an issue because we evaluate the PSF image on the calexp and warp them here.

We don't evaluate the PSF image on the full calexp; we evaluate it into a much smaller image in calexp coordinates. That would be sufficient if warping used the real WCS for every pixel, or if it evaluated the WCS at fixed points in PARENT coordinates (accounting for xy0) and interpolated between them. But I suspect that it might actually evaluate the WCS at fixed points in LOCAL coordinates (i.e. not accounting for xy0) and interpolate those. Just suspicion, though; if you know that it evaluates the PSF at fixed points in PARENT coordinates then this concern is indeed unfounded.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this related to your comment about xy0 below? The bounding box of the PSF image is evaluated in PARENT coordinates by default but I wonder if I'm missing your point. Happy to discuss this more tomorrow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This question doesn't need to hold up this branch anyway, and I think the answer is one that can be found in the depths of the warping code: we're making different assumptions about what it's doing, but the truth is out there.

dtype=afwMath.Warper.ConfigClass,
)
input_recorder = ConfigurableField(
doc="Subtask that helps fill CoaddInputs catalogs added to the final Exposure",
target=CoaddInputRecorderTask,
psf_dimensions = Field[int](
default=21,
doc="Dimensions of the PSF image stamp size to be assigned to cells (must be odd).",
check=lambda x: (x > 0) and (x % 2 == 1),
)


Expand Down Expand Up @@ -162,12 +163,13 @@ class AssembleCellCoaddTask(PipelineTask):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.makeSubtask("input_recorder")
if self.config.do_interpolate_coadd:
self.makeSubtask("interpolate_coadd")
if self.config.do_scale_zero_point:
self.makeSubtask("scale_zero_point")

self.psf_warper = afwMath.Warper.fromConfig(self.config.psf_warper)

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Docstring inherited.
inputData = butlerQC.get(inputRefs)
Expand Down Expand Up @@ -244,15 +246,13 @@ def _construct_grid(skyInfo):
grid = UniformGrid.from_bbox_cell_size(grid_bbox, skyInfo.patchInfo.getCellInnerDimensions())
return grid

def _construct_grid_container(self, skyInfo, statsCtrl):
def _construct_grid_container(self, skyInfo):
"""Construct a grid of AccumulatorMeanStack instances.

Parameters
----------
skyInfo : `~lsst.pipe.base.Struct`
A Struct object
statsCtrl : `~lsst.afw.math.StatisticsControl`
A control (config-like) object for StatisticsStack.

Returns
-------
Expand All @@ -267,7 +267,7 @@ def _construct_grid_container(self, skyInfo, statsCtrl):
stacker = AccumulatorMeanStack(
# The shape is for the numpy arrays, hence transposed.
shape=(cellInfo.outer_bbox.height, cellInfo.outer_bbox.width),
bit_mask_value=afwImage.Mask.getPlaneBitMask(self.config.bad_mask_planes),
bit_mask_value=0,
calc_error_from_input_variance=self.config.calc_error_from_input_variance,
compute_n_image=False,
)
Expand All @@ -284,28 +284,34 @@ def _construct_stats_control(self):
def run(self, inputWarps, skyInfo, **kwargs):
statsCtrl = self._construct_stats_control()

gc = self._construct_grid_container(skyInfo, statsCtrl)
coadd_inputs_gc = GridContainer(gc.shape)
gc = self._construct_grid_container(skyInfo)
psf_gc = GridContainer[AccumulatorMeanStack](gc.shape)
psf_bbox_gc = GridContainer[geom.Box2I](gc.shape)

# Make a container to hold the cell centers in sky coordinates now,
# so we don't have to recompute them for each warp
# (they share a common WCS). These are needed to find the various
# warp + detector combinations that contributed to each cell, and later
# get the corresponding PSFs as well.
cell_centers_sky = GridContainer(gc.shape)
cell_centers_sky = GridContainer[geom.SpherePoint](gc.shape)
# Make a container to hold the observation identifiers for each cell.
observation_identifiers_gc = GridContainer(gc.shape)
observation_identifiers_gc = GridContainer[list](gc.shape)
# Populate them.
for cellInfo in skyInfo.patchInfo:
coadd_inputs = self.input_recorder.makeCoaddInputs()
# Reserve the absolute maximum of how many ccds, visits
# we could potentially have.
coadd_inputs.ccds.reserve(len(inputWarps))
coadd_inputs.visits.reserve(len(inputWarps))
coadd_inputs_gc[cellInfo.index] = coadd_inputs
# Make a list to hold the observation identifiers for each cell.
observation_identifiers_gc[cellInfo.index] = []
cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cellInfo.inner_bbox.getCenter())
psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
geom.Point2D(cellInfo.inner_bbox.getCenter()),
geom.Extent2I(self.config.psf_dimensions, self.config.psf_dimensions),
)
psf_gc[cellInfo.index] = AccumulatorMeanStack(
# The shape is for the numpy arrays, hence transposed.
shape=(self.config.psf_dimensions, self.config.psf_dimensions),
bit_mask_value=0,
calc_error_from_input_variance=self.config.calc_error_from_input_variance,
compute_n_image=False,
)

# Read in one warp at a time, and accumulate it in all the cells that
# it completely overlaps.
Expand All @@ -323,7 +329,6 @@ def run(self, inputWarps, skyInfo, **kwargs):
edge = afwImage.Mask.getPlaneBitMask("EDGE")
for cellInfo in skyInfo.patchInfo:
bbox = cellInfo.outer_bbox
stacker = gc[cellInfo.index]
mi = warp[bbox].getMaskedImage()

if (mi.getMask().array & edge).any():
Expand All @@ -340,30 +345,49 @@ def run(self, inputWarps, skyInfo, **kwargs):
)
continue

stacker.add_masked_image(mi, weight=weight)

coadd_inputs = coadd_inputs_gc[cellInfo.index]
self.input_recorder.addVisitToCoadd(coadd_inputs, warp[bbox], weight)
if True:
ccd_table = (
warp.getInfo()
.getCoaddInputs()
.ccds.subsetContaining(cell_centers_sky[cellInfo.index])
)
assert len(ccd_table) > 0, "No CCD from a warp found within a cell."
assert len(ccd_table) == 1, "More than one CCD from a warp found within a cell."
ccd_row = ccd_table[0]
else:
for ccd_row in warp.getInfo().getCoaddInputs().ccds:
if ccd_row.contains(cell_centers_sky[cellInfo.index]):
break
ccd_table = (
warp.getInfo().getCoaddInputs().ccds.subsetContaining(cell_centers_sky[cellInfo.index])
)
assert len(ccd_table) > 0, "No CCD from a warp found within a cell."
assert len(ccd_table) == 1, "More than one CCD from a warp found within a cell."
ccd_row = ccd_table[0]

observation_identifier = ObservationIdentifiers.from_data_id(
warpRef.dataId,
backup_detector=ccd_row["ccd"],
)
observation_identifiers_gc[cellInfo.index].append(observation_identifier)

stacker = gc[cellInfo.index]
stacker.add_masked_image(mi, weight=weight)

calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)

assert undistorted_psf_im.getBBox() == geom.Box2I.makeCenteredBox(
calexp_point,
undistorted_psf_im.getDimensions(),
), "PSF image does not share the coordinates of the 'calexp'"

# Convert the PSF image from Image to MaskedImage.
undistorted_psf_maskedImage = afwImage.MaskedImageD(image=undistorted_psf_im)
# TODO: In DM-43585, use the variance plane value from noise.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it matters; the stacker had better not be using this variance for anything other than the output variance (we're passing it an explicit weight below), and we're throwing the output variance away.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to double check, but we need that variance because its inverse is used as the weight.

undistorted_psf_maskedImage.variance += 1.0 # Set variance to 1

warped_psf_maskedImage = self.psf_warper.warpImage(
destWcs=skyInfo.wcs,
srcImage=undistorted_psf_maskedImage,
srcWcs=ccd_row.getWcs(),
destBBox=psf_bbox_gc[cellInfo.index],
)

# There may be NaNs in the PSF image. Set them to 0.0
warped_psf_maskedImage.variance.array[np.isnan(warped_psf_maskedImage.image.array)] = 1.0
warped_psf_maskedImage.image.array[np.isnan(warped_psf_maskedImage.image.array)] = 0.0
arunkannawadi marked this conversation as resolved.
Show resolved Hide resolved

psf_stacker = psf_gc[cellInfo.index]
psf_stacker.add_masked_image(warped_psf_maskedImage, weight=weight)
arunkannawadi marked this conversation as resolved.
Show resolved Hide resolved

del warp

cells: list[SingleCellCoadd] = []
Expand All @@ -374,7 +398,9 @@ def run(self, inputWarps, skyInfo, **kwargs):

stacker = gc[cellInfo.index]
cell_masked_image = afwImage.MaskedImageF(cellInfo.outer_bbox)
stacker.fill_stacked_masked_image(cell_masked_image)
psf_masked_image = afwImage.MaskedImageF(psf_bbox_gc[cellInfo.index])
gc[cellInfo.index].fill_stacked_masked_image(cell_masked_image)
psf_gc[cellInfo.index].fill_stacked_masked_image(psf_masked_image)

# Post-process the coadd before converting to new data structures.
if self.config.do_interpolate_coadd:
Expand All @@ -384,12 +410,6 @@ def run(self, inputWarps, skyInfo, **kwargs):
with np.errstate(invalid="ignore"):
varArray[:] = np.where(varArray > 0, varArray, np.inf)

# Finalize the PSF on the cell coadds.
coadd_inputs = coadd_inputs_gc[cellInfo.index]
coadd_inputs.ccds.sort()
coadd_inputs.visits.sort()
cell_coadd_psf = CoaddPsf(coadd_inputs.ccds, skyInfo.wcs, self.config.coadd_psf.makeControl())

image_planes = OwnedImagePlanes.from_masked_image(cell_masked_image)
identifiers = CellIdentifiers(
cell=cellInfo.index,
Expand All @@ -401,9 +421,9 @@ def run(self, inputWarps, skyInfo, **kwargs):

singleCellCoadd = SingleCellCoadd(
outer=image_planes,
psf=cell_coadd_psf.computeKernelImage(cell_coadd_psf.getAveragePosition()),
psf=psf_masked_image.image,
inner_bbox=cellInfo.inner_bbox,
inputs=frozenset(observation_identifiers_gc[cellInfo.index]),
inputs=observation_identifiers_gc[cellInfo.index],
common=self.common,
identifiers=identifiers,
)
Expand Down
36 changes: 29 additions & 7 deletions tests/test_assemble_cell_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,23 @@ class AssembleCellCoaddTestCase(lsst.utils.tests.TestCase):
execution.
"""

def setUp(self):
@classmethod
def setUpClass(cls) -> None:
patch = 42
tract = 0
testData = MockCoaddTestData(fluxRange=1e4)
exposures = {}
matchedExposures = {}
for expId in range(100, 110):
exposures[expId], matchedExposures[expId] = testData.makeTestImage(expId)
self.dataRefList = testData.makeDataRefList(
cls.dataRefList = testData.makeDataRefList(
exposures, matchedExposures, "direct", patch=patch, tract=tract
)
self.skyInfo = makeMockSkyInfo(testData.bbox, testData.wcs, patch=patch)
cls.skyInfo = makeMockSkyInfo(testData.bbox, testData.wcs, patch=patch)

config = MockAssembleCellCoaddConfig()
assembleTask = MockAssembleCellCoaddTask(config=config)
cls.result = assembleTask.runQuantum(cls.skyInfo, cls.dataRefList)

def checkRun(self, assembleTask):
"""Check that the task runs successfully."""
Expand All @@ -122,15 +127,32 @@ def checkRun(self, assembleTask):
self.assertGreaterEqual(obsId.packed, packed)
packed = obsId.packed

def testAssembleBasic(self):
def test_assemble_basic(self):
"""Test that AssembleCellCoaddTask runs successfully without errors.

This test does not check the correctness of the coaddition algorithms.
This is intended to prevent the code from bit rotting.
"""
config = MockAssembleCellCoaddConfig()
assembleTask = MockAssembleCellCoaddTask(config=config)
self.checkRun(assembleTask)
# Check that we produced an exposure.
self.assertTrue(self.result.multipleCellCoadd is not None)

def test_visit_count(self):
"""Check that the visit_count method returns a number less than or
equal to the total number of input exposures available.
"""
max_visit_count = len(self.dataRefList)
for cellId, singleCellCoadd in self.result.multipleCellCoadd.cells.items():
with self.subTest(x=cellId.x, y=cellId.y):
self.assertLessEqual(singleCellCoadd.visit_count, max_visit_count)

def test_inputs_sorted(self):
"""Check that the inputs are sorted."""
for _, singleCellCoadd in self.result.multipleCellCoadd.cells.items():
packed = -np.inf
for obsId in singleCellCoadd.inputs:
with self.subTest(input_number=obsId):
self.assertGreaterEqual(obsId.packed, packed)
packed = obsId.packed


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
Expand Down