Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 25.1.0
rev: 25.9.0
hooks:
- id: black
language_version: python3.12
language_version: python3.13
- repo: https://github.com/pycqa/isort
rev: 6.0.1
rev: 7.0.0
hooks:
- id: isort
name: isort (python)
210 changes: 146 additions & 64 deletions python/lsst/drp/tasks/assemble_cell_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,23 @@
"AssembleCellCoaddTask",
"AssembleCellCoaddConfig",
"ConvertMultipleCellCoaddToExposureTask",
"EmptyCellCoaddError",
)

import dataclasses
import itertools

import numpy as np

import lsst.afw.geom as afwGeom
import lsst.afw.image as afwImage
import lsst.afw.math as afwMath
import lsst.geom as geom
from lsst.afw.detection import InvalidPsfError
from lsst.afw.geom import SinglePolygonException, makeWcsPairTransform
from lsst.cell_coadds import (
CellIdentifiers,
CoaddApCorrMapStacker,
CoaddInputs,
CoaddUnits,
CommonComponents,
GridContainer,
Expand All @@ -52,7 +56,6 @@
from lsst.meas.algorithms import AccumulatorMeanStack
from lsst.pex.config import ConfigField, ConfigurableField, DictField, Field, ListField, RangeField
from lsst.pipe.base import (
AlgorithmError,
InMemoryDatasetHandle,
NoWorkFound,
PipelineTask,
Expand All @@ -67,19 +70,6 @@
from lsst.skymap import BaseSkyMap


class EmptyCellCoaddError(AlgorithmError):
"""Raised if no cells could be populated."""

def __init__(self):
msg = "No cells could be populated for the cell coadd."
super().__init__(msg)

@property
def metadata(self) -> dict:
"""There is no metadata associated with this error."""
return {}


@dataclasses.dataclass
class WarpInputs:
"""Collection of associate inputs along with warps."""
Expand Down Expand Up @@ -223,13 +213,27 @@ class AssembleCellCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCe
"but may be turned off for parity with deepCoadd.",
default=False,
)
min_overlap_fraction = RangeField[float](
doc="The minimum overlap fraction required for a single (visit, detector) input to be included in a "
"cell.",
# A value of 1.0 corresponds to ideal, edge-free cells.
# A value of 0.0 corresponds to the deep_coadd style coadds.
# This has to be at least 0.5 to ensure that the an input overlaps the
# cell center. Inputs will overlap fraction less than 0.25 will
# definitely not overlap the cell center.
default=0.0,
min=0.0,
max=1.0,
inclusiveMin=True,
inclusiveMax=True,
)
bad_mask_planes = ListField[str](
doc="Mask planes that count towards the masked fraction within a cell.",
default=("BAD", "NO_DATA", "SAT", "CLIPPED"),
)
remove_mask_planes = ListField[str](
doc="Mask planes to remove before coadding",
default=["NOT_DEBLENDED", "EDGE"],
default=["NOT_DEBLENDED"],
)
calc_error_from_input_variance = Field[bool](
doc="Calculate coadd variance from input variance by stacking "
Expand Down Expand Up @@ -530,11 +534,11 @@ def run(
# get the corresponding PSFs as well.
cell_centers_sky = GridContainer[geom.SpherePoint](warp_stacker_gc.shape)
# Make a container to hold the observation identifiers for each cell.
observation_identifiers_gc = GridContainer[list](warp_stacker_gc.shape)
observation_identifiers_gc = GridContainer[dict](warp_stacker_gc.shape)
# Populate them.
for cellInfo in skyInfo.patchInfo:
# Make a list to hold the observation identifiers for each cell.
observation_identifiers_gc[cellInfo.index] = []
observation_identifiers_gc[cellInfo.index] = {}
cell_centers_sky[cellInfo.index] = skyInfo.wcs.pixelToSky(cellInfo.inner_bbox.getCenter())
psf_bbox_gc[cellInfo.index] = geom.Box2I.makeCenteredBox(
geom.Point2D(cellInfo.inner_bbox.getCenter()),
Expand Down Expand Up @@ -601,9 +605,6 @@ def run(
f"Warp {warp_input.dataId} has BUNIT {warp.metadata['BUNIT']}, expected nJy"
)

# Coadd the warp onto the cells it completely overlaps.
edge = afwImage.Mask.getPlaneBitMask(["NO_DATA", "SENSOR_EDGE"])
reject = afwImage.Mask.getPlaneBitMask(["CLIPPED", "REJECTED"])
removeMaskPlanes(warp.mask, self.config.remove_mask_planes, self.log)

# Compute the weight for each CCD in the warp from the visitSummary
Expand Down Expand Up @@ -638,47 +639,59 @@ def run(

noise_warps = [ref.get(parameters={"bbox": skyInfo.bbox}) for ref in warp_input.noise_warps]

for cellInfo in skyInfo.patchInfo:
bbox = cellInfo.outer_bbox
inner_bbox = cellInfo.inner_bbox
mi = warp[bbox].maskedImage

if (mi.mask[inner_bbox].array & edge).any():
self.log.debug(
"Skipping %s in cell %s because it has a pixel with SENSOR_EDGE or NO_DATA bit set",
warp_input.dataId,
cellInfo.index,
visit_polygons: dict[ObservationIdentifiers, afwGeom.Polygon] = {}
# Create an image where each pixel value corresponds to the
# detector ID that pixel comes from.
detector_map = afwImage.ImageI(bbox=warp.getBBox(), initialValue=-1)
for row in full_ccd_table:
transform = makeWcsPairTransform(row.wcs, warp.wcs)
if (src_polygon := row.validPolygon) is None:
src_polygon = afwGeom.Polygon(geom.Box2D(row.getBBox()))
try:
dest_polygon = src_polygon.transform(transform).intersectionSingle(
geom.Box2D(warp.getBBox())
)
except SinglePolygonException:
continue

if (mi.mask[inner_bbox].array & reject).any():
self.log.debug(
"Skipping %s in cell %s because it has a pixel with CLIPPED or REJECTED bit set",
observation_identifier = ObservationIdentifiers.from_data_id(
warp_input.dataId,
backup_detector=row["ccd"],
)
visit_polygons[observation_identifier] = dest_polygon

detector_map_slice = dest_polygon.createImage(detector_map.getBBox()).array > 0
if not (detector_map.array[detector_map_slice] < 0).all():
self.log.warning("Multiple detectors from visit %s are overlapping", warp_input.dataId)
detector_map.array[detector_map_slice] = row["ccd"]

# Update common with the visit polygons.
self.common = dataclasses.replace(
self.common,
visit_polygons=visit_polygons,
)

if (detector_map.array < 0).all():
self.log.warning("Unable to split the warp %s into single-detector warps.", warp_input.dataId)
detector_map.array[:, :] = 0

for cellInfo, ccd_row in itertools.product(skyInfo.patchInfo, full_ccd_table):
bbox = cellInfo.outer_bbox
inner_bbox = cellInfo.inner_bbox

overlap_fraction = (detector_map[inner_bbox].array == ccd_row["ccd"]).mean()
assert -1e-4 < overlap_fraction < 1.0001, "Overlap fraction is not within [0, 1]."
if (overlap_fraction < self.config.min_overlap_fraction) or (overlap_fraction <= 0.0):
self.log.log(
self.log.DEBUG if overlap_fraction == 0.0 else self.log.INFO,
"Skipping %s in cell %s because it had only %.3f < %.3f fractional overlap.",
warp_input.dataId,
cellInfo.index,
overlap_fraction,
self.config.min_overlap_fraction,
)
continue

# Find the CCD that contributed to this cell.
if len(warp.getInfo().getCoaddInputs().ccds) == 1:
# If there is only one, don't bother with a WCS look up.
ccd_row = full_ccd_table[0]
else:
ccd_table = full_ccd_table.subsetContaining(cell_centers_sky[cellInfo.index])

if len(ccd_table) == 0:
# This condition rarely occurs in test runs, if ever.
# But the QG generated upfront in campaign management
# land routinely has extra quanta that should be
# dropped during runtime. These cases arise when
# the tasks upstream didn't process.
# See DM-52306 for example.
self.log.debug("No CCD found for %s in cell %s", warp_input.dataId, cellInfo.index)
continue

assert len(ccd_table) == 1, "More than one CCD from a warp found within a cell."
ccd_row = ccd_table[0]

weight = weights[ccd_row["ccd"]]
if not np.isfinite(weight):
self.log.warn(
Expand All @@ -692,21 +705,87 @@ def run(
)
continue

observation_identifier = ObservationIdentifiers.from_data_id(
warp_input.dataId,
backup_detector=ccd_row["ccd"],
)
observation_identifiers_gc[cellInfo.index].append(observation_identifier)
# Decide if a deep copy is necessary to apply the single
# detector cuts since it involves modifying the image in-place.
# If within the inner cell, there are three or more different
# values that detector map takes, then there are definitely
# multiple detectors (one for chip gaps, two for two detectors)
deep_copy = len(set(detector_map[inner_bbox].array.ravel())) >= 3
if deep_copy:
single_detector_mask_array = detector_map[bbox].array != ccd_row["ccd"]

mi = afwImage.MaskedImageF(warp[bbox].maskedImage, deep=deep_copy)
if deep_copy:
mi.image.array[single_detector_mask_array] = 0.0
mi.variance.array[single_detector_mask_array] = np.inf
nodata_or_mask = (single_detector_mask_array) * afwImage.Mask.getPlaneBitMask("NO_DATA")
mi.mask[bbox].array |= nodata_or_mask
warp_stacker_gc[cellInfo.index].add_masked_image(mi, weight=weight)

stacker = warp_stacker_gc[cellInfo.index]
stacker.add_masked_image(mi, weight=weight)
if masked_fraction_image:
mi = afwImage.ImageF(masked_fraction_image[bbox], deep=deep_copy)
if deep_copy:
mi.array[single_detector_mask_array] = 0.0
maskfrac_stacker_gc[cellInfo.index].add_image(masked_fraction_image[bbox], weight=weight)

for n in range(self.config.num_noise_realizations):
mi = noise_warps[n][bbox]
mi = afwImage.MaskedImageF(noise_warps[n][bbox], deep=deep_copy)
if deep_copy:
mi.image.array[single_detector_mask_array] = 0.0
mi.variance.array[single_detector_mask_array] = np.inf
mi.mask[bbox].array |= nodata_or_mask
noise_stacker_gc_list[n][cellInfo.index].add_masked_image(mi, weight=weight)

# Set the defaults for PSF shape quantities.
psf_shape = afwGeom.Quadrupole()
psf_shape_flag = True
psf_eval_point = None
try:
if overlap_fraction < 1.0:
psf_eval_point = dest_polygon.intersectionSingle(
geom.Box2D(inner_bbox)
).calculateCenter()
else:
psf_eval_point = geom.Point2D(inner_bbox.getCenter())
psf_shape = warp.psf.computeShape(psf_eval_point)
psf_shape_flag = False
except SinglePolygonException:
self.log.info(
"Unable to find the overlapping polygon between %d detector in %s and cell %s",
ccd_row["ccd"],
warp_input.dataId,
cellInfo.index,
)
except InvalidPsfError:
self.log.info(
"Unable to compute PSF shape from %d detector in %s at %s",
ccd_row["ccd"],
warp_input.dataId,
psf_eval_point,
)

overlaps_center = detector_map[geom.Point2I(bbox.getCenter())] == ccd_row["ccd"]

observation_identifier = ObservationIdentifiers.from_data_id(
warp_input.dataId,
backup_detector=ccd_row["ccd"],
)
observation_identifiers_gc[cellInfo.index][observation_identifier] = CoaddInputs(
overlaps_center=overlaps_center,
overlap_fraction=overlap_fraction,
weight=weight,
psf_shape=psf_shape,
psf_shape_flag=psf_shape_flag,
)
if overlaps_center is False:
self.log.debug(
"%s does not overlap with the center of the cell %s",
warp_input.dataId,
cellInfo.index,
)
continue

# Everything below this has to do with the center of the cell
calexp_point = ccd_row.getWcs().skyToPixel(cell_centers_sky[cellInfo.index])
undistorted_psf_im = ccd_row.getPsf().computeImage(calexp_point)

Expand Down Expand Up @@ -764,7 +843,10 @@ def run(
ap_corr_map = None

# Post-process the coadd before converting to new data structures.
if self.config.do_interpolate_coadd:
if np.isnan(cell_masked_image.image.array).all():
cell_masked_image.image.array[:, :] = 0.0
cell_masked_image.variance.array[:, :] = np.inf
elif self.config.do_interpolate_coadd:
self.interpolate_coadd.run(cell_masked_image, planeName="NO_DATA")
for noise_image in cell_noise_images:
self.interpolate_coadd.run(noise_image, planeName="NO_DATA")
Expand Down Expand Up @@ -799,7 +881,7 @@ def run(
cells.append(singleCellCoadd)

if not cells:
raise EmptyCellCoaddError()
raise NoWorkFound("No cells could be populated for the cell coadd.")

grid = self._construct_grid(skyInfo)
multipleCellCoadd = MultipleCellCoadd(
Expand Down
9 changes: 3 additions & 6 deletions tests/test_assemble_cell_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
import lsst.afw.image as afwImage
import lsst.pipe.base as pipeBase
import lsst.utils.tests
from lsst.cell_coadds import CommonComponents
from lsst.drp.tasks.assemble_cell_coadd import (
AssembleCellCoaddConfig,
AssembleCellCoaddTask,
EmptyCellCoaddError,
WarpInputs,
)

Expand Down Expand Up @@ -89,7 +89,7 @@ def runQuantum(
The coadded exposure and associated metadata.
"""

self.common = pipeBase.Struct(
self.common = CommonComponents(
units=None,
wcs=mockSkyInfo.wcs,
band="i",
Expand Down Expand Up @@ -152,9 +152,6 @@ def setUpClass(cls) -> None:
]
cls.skyInfo = makeMockSkyInfo(testData.bbox, testData.wcs, patch=patch)

def tearDown(self) -> None:
del self.result

def runTask(
self,
config=None,
Expand Down Expand Up @@ -237,7 +234,7 @@ def test_assemble_empty(self):
"""Test that AssembleCellCoaddTask runs successfully without errors
when no input exposures are provided."""
self.result = None # so tearDown has something.
with self.assertRaises(EmptyCellCoaddError, msg="No cells could be populated for the cell coadd."):
with self.assertRaises(pipeBase.NoWorkFound, msg="No cells could be populated for the cell coadd."):
self.runTask(warpRefList=[], maskedFractionRefList=[], noise0RefList=[], visitSummaryList=[])

def test_assemble_without_visitSummary(self):
Expand Down