Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-42529: Use the accelerated utility functions in CloughTocher2DInterpolatorTask #881

Merged
merged 1 commit into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
194 changes: 0 additions & 194 deletions python/lsst/pipe/tasks/interpImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = (
"CloughTocher2DInterpolateConfig",
"CloughTocher2DInterpolateTask",
"InterpImageConfig",
"InterpImageTask",
)


from contextlib import contextmanager
from itertools import product
from typing import Iterable

import lsst.pex.config as pexConfig
import lsst.geom
Expand All @@ -39,7 +35,6 @@
import lsst.meas.algorithms as measAlg
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
from scipy.interpolate import CloughTocher2DInterpolator


class InterpImageConfig(pexConfig.Config):
Expand Down Expand Up @@ -272,192 +267,3 @@ def interpolateImage(self, maskedImage, psf, defectList, fallbackValue):
with self.transposeContext(maskedImage, defectList) as (image, defects):
measAlg.interpolateOverDefects(image, psf, defects, fallbackValue,
self.config.useFallbackValueAtEdge)


class CloughTocher2DInterpolateConfig(pexConfig.Config):
"""Config for CloughTocher2DInterpolateTask."""

badMaskPlanes = pexConfig.ListField[str](
doc="List of mask planes to interpolate over.",
default=["BAD", "SAT", "CR"],
)
fillValue = pexConfig.Field[float](
doc="Constant value to fill outside of the convex hull of the good "
"pixels. A long (longer than twice the ``interpLength``) streak of "
"bad pixels at an edge will be set to this value.",
default=0.0,
)
interpLength = pexConfig.Field[int](
doc="Maximum number of pixels away from a bad pixel to include in "
"building the interpolant. Must be greater than or equal to 1.",
default=4,
check=lambda x: x >= 1,
)


class CloughTocher2DInterpolateTask(pipeBase.Task):
"""Interpolated over bad pixels using CloughTocher2DInterpolator.

Pixels with mask bits set to any of those listed ``badMaskPlanes`` config
are considered bad and are interpolated over. All good (non-bad) pixels
within ``interpLength`` pixels of a bad pixel in either direction are used
to construct the interpolant. An extended streak of bad pixels at an edge,
longer than ``interpLength``, is set to `fillValue`` specified in config.
"""

ConfigClass = CloughTocher2DInterpolateConfig
_DefaultName = "cloughTocher2DInterpolate"

def run(self, maskedImage, badpix: set[tuple[int, int]] | None = None, goodpix: dict | None = None):
"""Interpolate over bad pixels in a masked image.

This modifies the ``image`` attribute of the ``maskedImage`` in place.
This method returns, and accepts, the coordinates of the bad pixels
that were interpolated over, and the coordinates and values of the
good pixels that were used to construct the interpolant. This avoids
having to search for the bad and the good pixels repeatedly when the
mask plane is shared among many images, as would be the case with
noise realizations.

Parameters
----------
maskedImage : `~lsst.afw.image.MaskedImage`
Image on which to perform interpolation (and modify in-place).
badpix: `set` [`tuple` [`int`, `int`]], optional
The coordinates of the bad pixels to interpolate over.
If None, then the coordinates of the bad pixels are determined by
an exhaustive search over the image.
goodpix: `dict` [`tuple` [`int`, `int`], `float`], optional
A mapping whose keys are the coordinates of the good pixels around
``badpix`` that must be included when constructing the
interpolant. If ``badpix`` is provided, then the pixels in
``goodpix`` are used as to construct the interpolatant. If not,
any additional good pixels around internally determined ``badpix``
are added to ``goodpix`` and used to construct the interpolant. In
all cases, the values are populated from the image plane of the
``maskedImage`` (provided values will be ignored.

Returns
-------
badpix: `set` [`tuple` [`int`, `int`]]
The coordinates of the bad pixels that were interpolated over.
goodpix: `dict` [`tuple` [`int`, `int`], `float`]
Mapping of the coordinates of the good pixels around ``badpix``
to their values that were included when constructing the
interpolant.

Raises
------
RuntimeError
If a pixel passed in as ``goodpix`` is found to be bad as specified by
``maskPlanes``.
ValueError
If an input ``badpix`` is not found to be bad as specified by
``maskPlanes``.
"""
max_window_extent = lsst.geom.Extent2I(
2 * self.config.interpLength + 1, 2 * self.config.interpLength + 1
)
# Even if badpix and/or goodpix is provided, make sure to update
# the values of goodpix.
badpix, goodpix = find_good_pixels_around_bad_pixels(
maskedImage,
self.config.badMaskPlanes,
max_window_extent=max_window_extent,
badpix=badpix,
goodpix=goodpix,
)

# Construct the interpolant.
interpolator = CloughTocher2DInterpolator(
list(goodpix.keys()),
list(goodpix.values()),
fill_value=self.config.fillValue,
)

# Fill in the bad pixels.
for x, y in badpix:
maskedImage.image[x, y] = interpolator((x, y))

return badpix, goodpix


def find_good_pixels_around_bad_pixels(
image: afwImage.MaskedImage,
maskPlanes: Iterable[str],
*,
max_window_extent: lsst.geom.Extent2I,
badpix: set | None = None,
goodpix: dict | None = None,
):
"""Find the location of bad pixels, and neighboring good pixels.

Parameters
----------
image : `~lsst.afw.image.MaskedImage`
Image from which to find the bad and the good pixels.
maskPlanes : `list` [`str`]
List of mask planes to consider as bad pixels.
max_window_extent : `lsst.geom.Extent2I`
Maximum extent of the window around a bad pixel to consider when
looking for good pixels.
badpix : `list` [`tuple` [`int`, `int`]], optional
A known list of bad pixels. If provided, the function does not look for
any additional bad pixels, but it verifies that the provided
coordinates correspond to bad pixels. If an input``badpix`` is not
found to be bad as specified by ``maskPlanes``, an exception is raised.
goodpix : `dict` [`tuple` [`int`, `int`], `float`], optional
A known mapping of the coordinates of good pixels to their values, to
which any newly found good pixels locations will be added, and the
values (even for existing items) will be updated.

Returns
-------
badpix : `list` [`tuple` [`int`, `int`]]
The coordinates of the bad pixels. If ``badpix`` was provided as an
input argument, the returned quantity is the same as the input.
goodpix : `dict` [`tuple` [`int`, `int`], `float`]
Updated mapping of the coordinates of good pixels to their values.

Raises
------
RuntimeError
If a pixel passed in as ``goodpix`` is found to be bad as specified by
``maskPlanes``.
ValueError
If an input ``badpix`` is not found to be bad as specified by
``maskPlanes``.
"""

bbox = image.getBBox()
if badpix is None:
iterator = product(range(bbox.minX, bbox.maxX + 1), range(bbox.minY, bbox.maxY + 1))
badpix = set()
else:
iterator = badpix

if goodpix is None:
goodpix = {}

for x, y in iterator:
if image.mask[x, y] & afwImage.Mask.getPlaneBitMask(maskPlanes):
if (x, y) in goodpix:
raise RuntimeError(f"Pixel ({x}, {y}) is bad as specified by maskPlanes {maskPlanes} but "
"passed in as goodpix")
badpix.add((x, y))
window = lsst.geom.Box2I.makeCenteredBox(
center=lsst.geom.Point2D(x, y), # center has to be a Point2D instance.
size=max_window_extent,
)
# Restrict to the bounding box of the image.
window.clip(bbox)

for xx, yy in product(range(window.minX, window.maxX + 1), range(window.minY, window.maxY + 1)):
if not (image.mask[xx, yy] & afwImage.Mask.getPlaneBitMask(maskPlanes)):
goodpix[(xx, yy)] = image.image[xx, yy]
elif (x, y) in badpix:
# If (x, y) is in badpix, but did not get flagged as bad,
# raise an exception.
raise ValueError(f"Pixel ({x}, {y}) is not bad as specified by maskPlanes {maskPlanes}")

return badpix, goodpix
133 changes: 1 addition & 132 deletions tests/test_interpImageTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import lsst.afw.image as afwImage
import lsst.pex.config as pexConfig
import lsst.ip.isr as ipIsr
from lsst.pipe.tasks.interpImage import CloughTocher2DInterpolateTask, InterpImageTask
from lsst.pipe.tasks.interpImage import InterpImageTask

try:
display
Expand Down Expand Up @@ -166,137 +166,6 @@ def testTranspose(self):
self.assertFloatsEqual(image.image.array, value)


class CloughTocher2DInterpolateTestCase(lsst.utils.tests.TestCase):
"""Test the CloughTocher2DInterpolateTask."""

def setUp(self):
super().setUp()

self.maskedimage = afwImage.MaskedImageF(100, 121)
for x in range(100):
for y in range(121):
self.maskedimage[x, y] = (3 * y + x * 5, 0, 1.0)

# Clone the maskedimage so we can compare it after running the task.
self.reference = self.maskedimage.clone()

# Set some central pixels as SAT
sliceX, sliceY = slice(30, 35), slice(40, 45)
self.maskedimage.mask[sliceX, sliceY] = afwImage.Mask.getPlaneBitMask("SAT")
self.maskedimage.image[sliceX, sliceY] = np.nan
# Put nans here to make sure interp is done ok

# Set an entire column as BAD
self.maskedimage.mask[54:55, :] = afwImage.Mask.getPlaneBitMask("BAD")
self.maskedimage.image[54:55, :] = np.nan

# Set an entire row as BAD
self.maskedimage.mask[:, 110:111] = afwImage.Mask.getPlaneBitMask("BAD")
self.maskedimage.image[:, 110:111] = np.nan

# Set a diagonal set of pixels as CR
for i in range(74, 78):
self.maskedimage.mask[i, i] = afwImage.Mask.getPlaneBitMask("CR")
self.maskedimage.image[i, i] = np.nan

# Set one of the edges as EDGE
self.maskedimage.mask[0:1, :] = afwImage.Mask.getPlaneBitMask("EDGE")
self.maskedimage.image[0:1, :] = np.nan

# Set a smaller streak at the edge
self.maskedimage.mask[25:28, 0:1] = afwImage.Mask.getPlaneBitMask("EDGE")
self.maskedimage.image[25:28, 0:1] = np.nan

# Update the reference image's mask alone, so we can compare them after
# running the task.
self.reference.mask.array[:, :] = self.maskedimage.mask.array

# Create a noise image
self.noise = self.maskedimage.clone()
np.random.seed(12345)
self.noise.image.array[:, :] = np.random.normal(size=self.noise.image.array.shape)

@lsst.utils.tests.methodParameters(n_runs=(1, 2))
def test_interpolation(self, n_runs: int):
"""Test that the interpolation is done correctly.

Parameters
----------
n_runs : `int`
Number of times to run the task. Running the task more than once
should have no effect.
"""
config = CloughTocher2DInterpolateTask.ConfigClass()
config.badMaskPlanes = (
"BAD",
"SAT",
"CR",
"EDGE",
)
config.fillValue = 0.5
task = CloughTocher2DInterpolateTask(config)
for n in range(n_runs):
task.run(self.maskedimage)

# Assert that the mask and the variance planes remain unchanged.
self.assertImagesEqual(self.maskedimage.variance, self.reference.variance)
self.assertMasksEqual(self.maskedimage.mask, self.reference.mask)

# Check that the long streak of bad pixels have been replaced with the
# fillValue, but not the short streak.
np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue)
with self.assertRaises(AssertionError):
np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue)

# Check that interpolated pixels are close to the reference (original),
# and that none of them is still NaN.
self.assertTrue(np.isfinite(self.maskedimage.image.array).all())
self.assertImagesAlmostEqual(
self.maskedimage.image[1:, :], self.reference.image[1:, :], rtol=1e-05, atol=1e-08
)

@lsst.utils.tests.methodParametersProduct(pass_badpix=(True, False), pass_goodpix=(True, False))
def test_interpolation_with_noise(self, pass_badpix: bool = True, pass_goodpix: bool = True):
"""Test that we can reuse the badpix and goodpix.

Parameters
----------
pass_badpix : `bool`
Whether to pass the badpix to the task?
pass_goodpix : `bool`
Whether to pass the goodpix to the task?
"""

config = CloughTocher2DInterpolateTask.ConfigClass()
config.badMaskPlanes = (
"BAD",
"SAT",
"CR",
"EDGE",
)
task = CloughTocher2DInterpolateTask(config)

badpix, goodpix = task.run(self.noise)
task.run(
self.maskedimage,
badpix=(badpix if pass_badpix else None),
goodpix=(goodpix if pass_goodpix else None),
)

# Check that the long streak of bad pixels by the edge have been
# replaced with fillValue, but not the short streak.
np.testing.assert_array_equal(self.maskedimage.image[0:1, :].array, config.fillValue)
with self.assertRaises(AssertionError):
np.testing.assert_array_equal(self.maskedimage.image[25:28, 0:1].array, config.fillValue)

# Check that interpolated pixels are close to the reference (original),
# and that none of them is still NaN.
self.assertTrue(np.isfinite(self.maskedimage.image.array).all())
self.assertImagesAlmostEqual(
self.maskedimage.image[1:, :], self.reference.image[1:, :], rtol=1e-05, atol=1e-08
)


def setup_module(module):
lsst.utils.tests.init()

Expand Down