Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-21129: Improve "unsupported operand types" error for afwImage arithmetic #483

Merged
merged 1 commit into from
Aug 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 56 additions & 0 deletions python/lsst/afw/image/disableArithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from lsst.afw.image.imageSlice import ImageSliceF, ImageSliceD

__all__ = ("disableImageArithmetic", "disableMaskArithmetic")


def wrapNotImplemented(cls, attr):
"""Wrap a method providing a helpful error message about image arithmetic

Parameters
----------
cls : `type`
Class in which the method is to be defined.
attr : `str`
Name of the method.

Returns
-------
method : callable
Wrapped method.
"""
existing = getattr(cls, attr, None)

def notImplemented(self, other):
"""Provide a helpful error message about image arithmetic

Unless we're operating on an ImageSlice, in which case it might be
defined.

Parameters
----------
self : subclass of `lsst.afw.image.ImageBase`
Image someone's attempting to do arithmetic with.
other : anything
The operand of the arithmetic operation.
"""
if existing is not None and isinstance(other, (ImageSliceF, ImageSliceD)):
return existing(self, other)
raise NotImplementedError("This arithmetic operation is not implemented, in order to prevent the "
"accidental proliferation of temporaries. Please use the in-place "
"arithmetic operations (e.g., += instead of +) or operate on the "
"underlying arrays.")
return notImplemented


def disableImageArithmetic(cls):
"""Add helpful error messages about image arithmetic"""
for attr in ("__add__", "__sub__", "__mul__", "__truediv__",
"__radd__", "__rsub__", "__rmul__", "__rtruediv__"):
setattr(cls, attr, wrapNotImplemented(cls, attr))


def disableMaskArithmetic(cls):
"""Add helpful error messages about mask arithmetic"""
for attr in ("__or__", "__and__", "__xor__",
"__ror__", "__rand__", "__rxor__"):
setattr(cls, attr, wrapNotImplemented(cls, attr))
2 changes: 2 additions & 0 deletions python/lsst/afw/image/exposure/exposureContinued.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from lsst.utils import TemplateMeta

from ..slicing import supportSlicing
from ..disableArithmetic import disableImageArithmetic
from ..image.fitsIoWithOptions import imageReadFitsWithOptions, exposureWriteFitsWithOptions
from .exposure import ExposureI, ExposureF, ExposureD, ExposureU, ExposureL

Expand Down Expand Up @@ -121,3 +122,4 @@ def setCalib(self, *args, **kwargs):

for cls in set(Exposure.values()):
supportSlicing(cls)
disableImageArithmetic(cls)
3 changes: 3 additions & 0 deletions python/lsst/afw/image/image/imageContinued.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from lsst.utils import TemplateMeta

from ..slicing import supportSlicing
from ..disableArithmetic import disableImageArithmetic
from .fitsIoWithOptions import imageReadFitsWithOptions, imageWriteFitsWithOptions
from .image import ImageI, ImageF, ImageD, ImageU, ImageL
from .image import DecoratedImageI, DecoratedImageF, DecoratedImageD, DecoratedImageU, DecoratedImageL
Expand Down Expand Up @@ -87,6 +88,8 @@ def convertD(self):

for cls in set(Image.values()):
supportSlicing(cls)
disableImageArithmetic(cls)

for cls in set(DecoratedImage.values()):
supportSlicing(cls)
disableImageArithmetic(cls)
2 changes: 2 additions & 0 deletions python/lsst/afw/image/image/maskContinued.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .fitsIoWithOptions import imageReadFitsWithOptions, imageWriteFitsWithOptions
from .image import MaskX
from ..slicing import supportSlicing
from ..disableArithmetic import disableMaskArithmetic

MaskPixel = np.int32

Expand Down Expand Up @@ -56,3 +57,4 @@ def __repr__(self):

for cls in (MaskX, ):
supportSlicing(cls)
disableMaskArithmetic(cls)
2 changes: 2 additions & 0 deletions python/lsst/afw/image/maskedImage/maskedImageContinued.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from ..image.fitsIoWithOptions import imageReadFitsWithOptions, exposureWriteFitsWithOptions
from ..slicing import supportSlicing
from ..disableArithmetic import disableImageArithmetic
from .maskedImage import MaskedImageI, MaskedImageF, MaskedImageD, MaskedImageU, MaskedImageL


Expand Down Expand Up @@ -127,3 +128,4 @@ def __repr__(self):

for cls in set(MaskedImage.values()):
supportSlicing(cls)
disableImageArithmetic(cls)
131 changes: 131 additions & 0 deletions tests/test_imageArithmetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# LSST Data Management System
# Copyright 2019 LSST Corporation.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <http://www.lsstcorp.org/LegalNotices/>.

import unittest

import lsst.utils.tests
from lsst.geom import Box2I, Point2I, Extent2I
import lsst.afw.image


class ImageArithmeticTestCase(lsst.utils.tests.TestCase):
def setUp(self):
self.bbox = Box2I(Point2I(12345, 56789), Extent2I(12, 34))

def tearDown(self):
del self.bbox

def testImage(self):
for cls in (lsst.afw.image.ImageI,
lsst.afw.image.ImageL,
lsst.afw.image.ImageU,
lsst.afw.image.ImageF,
lsst.afw.image.ImageD,
lsst.afw.image.DecoratedImageI,
lsst.afw.image.DecoratedImageL,
lsst.afw.image.DecoratedImageU,
lsst.afw.image.DecoratedImageF,
lsst.afw.image.DecoratedImageD,
lsst.afw.image.MaskedImageI,
lsst.afw.image.MaskedImageL,
lsst.afw.image.MaskedImageU,
lsst.afw.image.MaskedImageF,
lsst.afw.image.MaskedImageD,
lsst.afw.image.ExposureI,
lsst.afw.image.ExposureL,
lsst.afw.image.ExposureU,
lsst.afw.image.ExposureF,
lsst.afw.image.ExposureD,
):
im1 = cls(self.bbox)
im2 = cls(self.bbox)

# Image and image
with self.assertRaises(NotImplementedError):
im1 + im2
with self.assertRaises(NotImplementedError):
im1 - im2
with self.assertRaises(NotImplementedError):
im1 * im2
with self.assertRaises(NotImplementedError):
im1 / im2

# Image and scalar
with self.assertRaises(NotImplementedError):
im1 + 12345
with self.assertRaises(NotImplementedError):
im1 - 12345
with self.assertRaises(NotImplementedError):
im1 * 12345
with self.assertRaises(NotImplementedError):
im1 / 12345

# Scalar and image
with self.assertRaises(NotImplementedError):
54321 + im2
with self.assertRaises(NotImplementedError):
54321 - im2
with self.assertRaises(NotImplementedError):
54321 * im2
with self.assertRaises(NotImplementedError):
54321 / im2

def testMask(self):
for cls in (lsst.afw.image.MaskX,
):
im1 = cls(self.bbox)
im2 = cls(self.bbox)

# Image and image
with self.assertRaises(NotImplementedError):
im1 | im2
with self.assertRaises(NotImplementedError):
im1 & im2
with self.assertRaises(NotImplementedError):
im1 ^ im2

# Image and scalar
with self.assertRaises(NotImplementedError):
im1 | 12345
with self.assertRaises(NotImplementedError):
im1 & 12345
with self.assertRaises(NotImplementedError):
im1 ^ 12345

# Scalar and image
with self.assertRaises(NotImplementedError):
54321 | im2
with self.assertRaises(NotImplementedError):
54321 & im2
with self.assertRaises(NotImplementedError):
54321 ^ im2


class TestMemory(lsst.utils.tests.MemoryTestCase):
pass


def setup_module(module):
lsst.utils.tests.init()


if __name__ == "__main__":
lsst.utils.tests.init()
unittest.main()