Skip to content

Commit

Permalink
Implement support for callbacks in `colour.continuous.AbstractContinu…
Browse files Browse the repository at this point in the history
…ousFunction` class.
  • Loading branch information
KelSolaar committed May 1, 2023
1 parent 8e0c9c1 commit 09e79b4
Show file tree
Hide file tree
Showing 8 changed files with 395 additions and 41 deletions.
34 changes: 25 additions & 9 deletions colour/colorimetry/spectrum.py
Expand Up @@ -716,6 +716,19 @@ def __init__(
self._display_name: str = self.name
self.display_name = kwargs.get("display_name", self._display_name)

self._shape: SpectralShape | None = None

def _on_domain_changed(
self, name: str, value: NDArrayFloat
) -> NDArrayFloat:
"""Invalidate *self._shape* when *self._domain* is changed."""
if name == "_domain":
self._shape = None

return value

self.register_callback("on_domain_changed", _on_domain_changed)

@property
def display_name(self) -> str:
"""
Expand Down Expand Up @@ -836,17 +849,20 @@ def shape(self) -> SpectralShape:
SpectralShape(500.0, 600.0, 10.0)
"""

wavelengths = self.wavelengths
wavelengths_interval = interval(wavelengths)
if wavelengths_interval.size != 1:
runtime_warning(
f'"{self.name}" spectral distribution is not uniform, using '
f"minimum interval!"
if self._shape is None:
wavelengths = self.wavelengths
wavelengths_interval = interval(wavelengths)
if wavelengths_interval.size != 1:
runtime_warning(
f'"{self.name}" spectral distribution is not uniform, '
"using minimum interval!"
)

self._shape = SpectralShape(
wavelengths[0], wavelengths[-1], min(wavelengths_interval)
)

return SpectralShape(
wavelengths[0], wavelengths[-1], min(wavelengths_interval)
)
return self._shape

def interpolate(
self,
Expand Down
90 changes: 63 additions & 27 deletions colour/colorimetry/tests/test_spectrum.py
Expand Up @@ -1515,30 +1515,35 @@ def test_interpolate(self):
SpectralDistribution.interpolate` method.
"""

shape = SpectralShape(self._sd.shape.start, self._sd.shape.end, 1)
sd = reshape_sd(self._sd, shape, "Interpolate")
np.testing.assert_array_almost_equal(
reshape_sd(
self._sd,
SpectralShape(self._sd.shape.start, self._sd.shape.end, 1),
"Interpolate",
).values,
sd.values,
DATA_SAMPLE_INTERPOLATED,
decimal=7,
)
self.assertEqual(sd.shape, shape)

shape = SpectralShape(
self._non_uniform_sd.shape.start,
self._non_uniform_sd.shape.end,
1,
)
sd = reshape_sd(self._non_uniform_sd, shape, "Interpolate")
np.testing.assert_allclose(
reshape_sd(
self._non_uniform_sd,
SpectralShape(
self._non_uniform_sd.shape.start,
self._non_uniform_sd.shape.end,
1,
),
"Interpolate",
).values,
sd.values,
DATA_SAMPLE_INTERPOLATED_NON_UNIFORM,
rtol=0.0000001,
atol=0.0000001,
)
self.assertEqual(
sd.shape,
SpectralShape(
np.ceil(self._non_uniform_sd.shape.start),
np.floor(self._non_uniform_sd.shape.end),
1,
),
)

def test_extrapolate(self):
"""
Expand All @@ -1556,8 +1561,9 @@ def test_extrapolate(self):
sd = SpectralDistribution(
np.linspace(0, 1, 10), np.linspace(25, 35, 10)
)
shape = SpectralShape(10, 50, 10)
sd.extrapolate(
SpectralShape(10, 50, 10),
shape,
extrapolator_kwargs={
"method": "Linear",
"left": None,
Expand Down Expand Up @@ -1602,6 +1608,17 @@ def test_normalise(self):
self._sd.copy().normalise(100).values, DATA_SAMPLE_NORMALISED
)

def test_callback_on_domain_changed(self):
"""
Test :class:`colour.colorimetry.spectrum.\
SpectralDistribution` *on_domain_changed* callback.
"""

sd = self._sd.copy()
self.assertEqual(sd.shape, SpectralShape(340, 820, 20))
sd[840] = 0
self.assertEqual(sd.shape, SpectralShape(340, 840, 20))


class TestMultiSpectralDistributions(unittest.TestCase):
"""
Expand Down Expand Up @@ -1760,26 +1777,25 @@ def test_interpolate(self):
"""

# pylint: disable=E1102
msds = reshape_msds(
self._sample_msds,
SpectralShape(
self._sample_msds.shape.start, self._sample_msds.shape.end, 1
),
"Interpolate",
shape = SpectralShape(
self._sample_msds.shape.start, self._sample_msds.shape.end, 1
)
msds = reshape_msds(self._sample_msds, shape, "Interpolate")
for signal in msds.signals.values():
np.testing.assert_array_almost_equal(
signal.values, DATA_SAMPLE_INTERPOLATED, decimal=7
)
self.assertEqual(msds.shape, shape)

# pylint: disable=E1102
shape = SpectralShape(
self._non_uniform_sample_msds.shape.start,
self._non_uniform_sample_msds.shape.end,
1,
)
msds = reshape_msds(
self._non_uniform_sample_msds,
SpectralShape(
self._non_uniform_sample_msds.shape.start,
self._non_uniform_sample_msds.shape.end,
1,
),
shape,
"Interpolate",
)
for signal in msds.signals.values():
Expand All @@ -1789,6 +1805,14 @@ def test_interpolate(self):
rtol=0.0000001,
atol=0.0000001,
)
self.assertEqual(
msds.shape,
SpectralShape(
np.ceil(self._non_uniform_sample_msds.shape.start),
np.floor(self._non_uniform_sample_msds.shape.end),
1,
),
)

def test_extrapolate(self):
"""
Expand Down Expand Up @@ -1853,7 +1877,7 @@ def test_trim(self):

def test_normalise(self):
"""
Test :func:`colour.colorimetry.spectrum.
Test :func:`colour.colorimetry.spectrum.
MultiSpectralDistributions.normalise` method.
"""

Expand All @@ -1875,6 +1899,18 @@ def test_to_sds(self):
self.assertEqual(sd.name, self._labels[i])
self.assertEqual(sd.display_name, self._display_labels[i])

def test_callback_on_domain_changed(self):
"""
Test underlying :class:`colour.colorimetry.spectrum.\
SpectralDistribution` *on_domain_changed* callback when used with
:class:`colour.colorimetry.spectrum.MultiSpectralDistributions` class.
"""

msds = self._msds.copy()
self.assertEqual(msds.shape, SpectralShape(380, 780, 5))
msds[785] = 0
self.assertEqual(msds.shape, SpectralShape(380, 785, 5))


class TestReshapeSd(unittest.TestCase):
"""
Expand Down
3 changes: 2 additions & 1 deletion colour/continuous/abstract.py
Expand Up @@ -29,6 +29,7 @@
Type,
)
from colour.utilities import (
MixinCallback,
as_float,
attest,
closest,
Expand All @@ -49,7 +50,7 @@
]


class AbstractContinuousFunction(ABC):
class AbstractContinuousFunction(ABC, MixinCallback):
"""
Define the base class for abstract continuous function.
Expand Down
6 changes: 2 additions & 4 deletions colour/continuous/signal.py
Expand Up @@ -948,8 +948,7 @@ def _fill_domain_nan(
variable.
"""

self._domain = fill_nan(self._domain, method, default)
self._function = None # Invalidate the underlying continuous function.
self.domain = fill_nan(self.domain, method, default)

def _fill_range_nan(
self,
Expand All @@ -974,8 +973,7 @@ def _fill_range_nan(
variable.
"""

self._range = fill_nan(self._range, method, default)
self._function = None # Invalidate the underlying continuous function.
self.range = fill_nan(self.range, method, default)

def arithmetical_operation(
self,
Expand Down
8 changes: 8 additions & 0 deletions colour/utilities/__init__.py
Expand Up @@ -11,6 +11,10 @@
LazyCanonicalMapping,
Node,
)
from .callback import (
Callback,
MixinCallback,
)
from .common import (
CacheRegistry,
CACHE_REGISTRY,
Expand Down Expand Up @@ -124,6 +128,10 @@
"LazyCanonicalMapping",
"Node",
]
__all__ += [
"Callback",
"MixinCallback",
]
__all__ += [
"CacheRegistry",
"CACHE_REGISTRY",
Expand Down

0 comments on commit 09e79b4

Please sign in to comment.