From 26c2bbab24a2cc5c679c4fafe594e3cdfa92f1bc Mon Sep 17 00:00:00 2001 From: Thomas Mansencal Date: Mon, 30 Aug 2021 23:16:34 +1200 Subject: [PATCH] Improve coverage of "colour.recovery.otsu2018" module. --- colour/recovery/otsu2018.py | 427 +++++++++++-------------- colour/recovery/tests/test_otsu2018.py | 128 +++++++- 2 files changed, 311 insertions(+), 244 deletions(-) diff --git a/colour/recovery/otsu2018.py b/colour/recovery/otsu2018.py index d6256167ec..14b43e81e4 100644 --- a/colour/recovery/otsu2018.py +++ b/colour/recovery/otsu2018.py @@ -32,7 +32,7 @@ if is_tqdm_installed(): from tqdm import tqdm -else: +else: # pragma: no cover from unittest import mock tqdm = mock.MagicMock() @@ -250,12 +250,6 @@ def read(self, path): path : unicode Path to the file. - Raises - ------ - ValueError, KeyError - Raised when loading the file succeeded but it did not contain the - expected data. - Examples -------- >>> import os @@ -275,23 +269,13 @@ def read(self, path): >>> dataset.read(path) # doctest: +SKIP """ - npz = np.load(path) + data = np.load(path) - if not isinstance(npz, np.lib.npyio.NpzFile): - raise ValueError('The loaded file is not an ".npz" type file!') - - start, end, interval = npz['shape'] + start, end, interval = data['shape'] self._shape = SpectralShape(start, end, interval) - self._basis_functions = npz['basis_functions'] - self._means = npz['means'] - self._selector_array = npz['selector_array'] - - n, three, m = self._basis_functions.shape - if (three != 3 or self._means.shape != (n, m) or - self._selector_array.shape[1] != 4): - raise ValueError( - 'Unexpected array shapes encountered, the file could be ' - 'corrupted or in a wrong format!') + self._basis_functions = data['basis_functions'] + self._means = data['means'] + self._selector_array = data['selector_array'] def write(self, path): """ @@ -526,14 +510,18 @@ class Data: ---------- - :attr:`~colour.recovery.otsu2018.Data.tree` - :attr:`~colour.recovery.otsu2018.Data.reflectances` - - :attr:`~colour.recovery.otsu2018.Data.XYZ` - - :attr:`~colour.recovery.otsu2018.Data.xy` + - :attr:`~colour.recovery.otsu2018.Data.reflectances` + - :attr:`~colour.recovery.otsu2018.Data.basis_functions` + - :attr:`~colour.recovery.otsu2018.Data.mean` Methods ------- - - :meth:`~colour.recovery.otsu2018.Data.__init__` - :meth:`~colour.recovery.otsu2018.Data.__str__` - :meth:`~colour.recovery.otsu2018.Data.__len__` + - :meth:`~colour.recovery.otsu2018.Data.PCA` + - :meth:`~colour.recovery.otsu2018.Data.reconstruct` + - :meth:`~colour.recovery.otsu2018.Data.reconstruction_error` + - :meth:`~colour.recovery.otsu2018.Data.origin` - :meth:`~colour.recovery.otsu2018.Data.partition` """ @@ -541,9 +529,18 @@ def __init__(self, tree, reflectances): self._tree = tree self._XYZ = None self._xy = None + + self._M = None + self._XYZ_mu = None + + self._mean = None + self._basis_functions = None + self._reflectances = None self.reflectances = reflectances + self._reconstruction_error = None + @property def tree(self): """ @@ -589,33 +586,34 @@ def reflectances(self, value): self.tree.cmfs, self.tree.illuminant, shape=self.tree.cmfs.shape) / 100 + self._xy = XYZ_to_xy(self._XYZ) @property - def XYZ(self): + def basis_functions(self): """ - Getter property for the colour data *CIE XYZ* tristimulus values. + Getter property for the node basis functions. Returns ------- - ndarray - Colour data *CIE XYZ* tristimulus values. + array_like + Node basis functions. """ - return self._XYZ + return self._basis_functions @property - def xy(self): + def mean(self): """ - Getter property for the colour data *CIE xy* tristimulus values. + Getter property for the node mean distribution. Returns ------- - ndarray - Colour data *CIE xy* tristimulus values. + array_like + Node mean distribution. """ - return self._xy + return self._mean def __str__(self): """ @@ -642,6 +640,105 @@ def __len__(self): return self._reflectances.shape[0] + def PCA(self): + """ + Performs the *Principal Component Analysis* (PCA) on the colours data + of the node and sets the relevant private attributes accordingly. + + Raises + ------ + RuntimeError + If the node is not a leaf node. + """ + + if self._M is not None: + return + + settings = { + 'cmfs': self._tree.cmfs, + 'illuminant': self._tree.illuminant, + 'shape': self._tree.cmfs.shape + } + self._mean = np.mean(self.reflectances, axis=0) + self._XYZ_mu = msds_to_XYZ_integration(self._mean, **settings) / 100 + + matrix_data = self.reflectances - self._mean + matrix_covariance = np.dot(np.transpose(matrix_data), matrix_data) + _eigenvalues, eigenvectors = np.linalg.eigh(matrix_covariance) + self._basis_functions = np.transpose(eigenvectors[:, -3:]) + + self._M = np.transpose( + msds_to_XYZ_integration(self._basis_functions, **settings) / 100) + + def reconstruct(self, XYZ): + """ + Reconstructs the reflectance for the given *CIE XYZ* tristimulus + values. + + If the node is a leaf, the colour data from the node is used, otherwise + the branch is traversed recursively to find the leaves. + + Parameters + ---------- + XYZ : ndarray, (3,) + *CIE XYZ* tristimulus values to recover the spectral distribution + from. + + Returns + ------- + SpectralDistribution + Recovered spectral distribution. + """ + + weights = np.dot(np.linalg.inv(self._M), XYZ - self._XYZ_mu) + reflectance = np.dot(weights, self._basis_functions) + self._mean + reflectance = np.clip(reflectance, 0, 1) + + return SpectralDistribution(reflectance, self._tree.cmfs.wavelengths) + + def reconstruction_error(self): + """ + Reconstructs the reflectance of the *CIE XYZ* tristimulus values in + the colour data of this node using PCA and compares the reconstructed + spectrum against the measured spectrum. The reconstruction errors are + then summed up and returned. + + Returns + ------- + error : float + The reconstruction errors summation for the node. + + Notes + ----- + The reconstruction error is cached upon being computed and thus is only + computed once per node. + + Raises + ------ + RuntimeError + If the node is not a leaf node. + """ + + if self._reconstruction_error is not None: + return self._reconstruction_error + else: + + self.PCA() + + error = 0 + for i in range(len(self)): + sd = self._reflectances[i, :] + XYZ = self._XYZ[i, :] + recovered_sd = self.reconstruct(XYZ) + error += np.sum((sd - recovered_sd.values) ** 2) + + self._reconstruction_error = error + + return error + + def origin(self, i, direction): + return self._xy[i, direction] + def partition(self, axis): """ Parameters @@ -660,16 +757,16 @@ def partition(self, axis): lesser = Data(self.tree, None) greater = Data(self.tree, None) - mask = self.xy[:, axis.direction] <= axis.origin + mask = self._xy[:, axis.direction] <= axis.origin - lesser._reflectances = self.reflectances[mask, :] - greater._reflectances = self.reflectances[~mask, :] + lesser._reflectances = self._reflectances[mask, :] + greater._reflectances = self._reflectances[~mask, :] - lesser._XYZ = self.XYZ[mask, :] - greater._XYZ = self.XYZ[~mask, :] + lesser._XYZ = self._XYZ[mask, :] + greater._XYZ = self._XYZ[~mask, :] - lesser._xy = self.xy[mask, :] - greater._xy = self.xy[~mask, :] + lesser._xy = self._xy[mask, :] + greater._xy = self._xy[~mask, :] return lesser, greater @@ -728,15 +825,8 @@ def __init__(self, tree, data): self._data = data self._children = [] self._partition_axis = None - self._mean = None - self._basis_functions = None - - self._M = None - self._M_inverse = None - self._XYZ_mu = None self._best_partition = None - self._cached_leaf_reconstruction_error = None @property def id(self): @@ -790,45 +880,6 @@ def children(self): return self._children - @property - def partition_axis(self): - """ - Getter property for the node partition axis. - - Returns - ------- - PartitionAxis - Node partition axis. - """ - - return self._partition_axis - - @property - def basis_functions(self): - """ - Getter property for the node basis functions. - - Returns - ------- - array_like - Node basis functions. - """ - - return self._basis_functions - - @property - def mean(self): - """ - Getter property for the node mean distribution. - - Returns - ------- - array_like - Node mean distribution. - """ - - return self._mean - @property def leaves(self): """ @@ -844,9 +895,20 @@ def leaves(self): yield self else: for child in self._children: - # TODO: Python 3 "yield from child.leaves". - for leaf in child.leaves: - yield leaf + yield from child.leaves + + @property + def partition_axis(self): + """ + Getter property for the node partition axis. + + Returns + ------- + PartitionAxis + Node partition axis. + """ + + return self._partition_axis def __str__(self): """ @@ -902,96 +964,59 @@ def split(self, children, partition_axis): Partition axis. """ - self._data = None self._children = children self._partition_axis = partition_axis - self._mean = None - self._basis_functions = None - - self._M = None - self._M_inverse = None - self._XYZ_mu = None + self._data = None self._best_partition = None - self._cached_leaf_reconstruction_error = None - - # - # PCA and Reconstruction - # - def PCA(self): + def find_best_partition(self): """ - Performs the *Principal Component Analysis* (PCA) on the colours data - of the node and sets the relevant private attributes accordingly. + Finds the best partition for the node. - Raises - ------ - RuntimeError - If the node is not a leaf node. + Returns + ------- + partition_error : float + Partition error + axis : PartitionAxis + Horizontal or vertical line, partitioning the 2D space in + two half-planes. + partition : tuple + Nodes created by splitting a node with a given partition. """ - if not self.is_leaf(): - raise RuntimeError('{0} is not a leaf node!'.format(self)) - - if self._M is not None: - return - - settings = { - 'cmfs': self._tree.cmfs, - 'illuminant': self._tree.illuminant, - 'shape': self._tree.cmfs.shape - } - self._mean = np.mean(self._data.reflectances, axis=0) - self._XYZ_mu = msds_to_XYZ_integration(self._mean, **settings) / 100 - - matrix_data = self._data.reflectances - self._mean - matrix_covariance = np.dot(np.transpose(matrix_data), matrix_data) - _eigenvalues, eigenvectors = np.linalg.eigh(matrix_covariance) - self._basis_functions = np.transpose(eigenvectors[:, -3:]) - - self._M = np.transpose( - msds_to_XYZ_integration(self._basis_functions, **settings) / 100) - self._M_inverse = np.linalg.inv(self._M) - - def reconstruct(self, XYZ): - """ - Reconstructs the reflectance for the given *CIE XYZ* tristimulus - values. + if self._best_partition is not None: + return self._best_partition - If the node is a leaf, the colour data from the node is used, otherwise - the branch is traversed recursively to find the leaves. + leaf_error = self.leaf_reconstruction_error() + best_error = None - Parameters - ---------- - XYZ : ndarray, (3,) - *CIE XYZ* tristimulus values to recover the spectral distribution - from. + with tqdm(total=2 * len(self.data)) as progress: + for direction in [0, 1]: + for i in range(len(self.data)): + progress.update() - Returns - ------- - SpectralDistribution - Recovered spectral distribution. - """ + axis = PartitionAxis( + self.data.origin(i, direction), direction) - xy = XYZ_to_xy(XYZ) + try: + partition_error, partition = ( + self.partition_reconstruction_error(axis)) + except RuntimeError: + continue - if not self.is_leaf(): - if (xy[self._partition_axis.direction] <= - self._partition_axis.origin): - return self._children[0].reconstruct(XYZ) - else: - return self._children[1].reconstruct(XYZ) + if partition_error >= leaf_error: + continue - weights = np.dot(self._M_inverse, XYZ - self._XYZ_mu) - reflectance = np.dot(weights, self._basis_functions) + self._mean - reflectance = np.clip(reflectance, 0, 1) + if best_error is None or partition_error < best_error: + self._best_partition = (partition_error, axis, + partition) - return SpectralDistribution(reflectance, self._tree.cmfs.wavelengths) + if self._best_partition is None: + raise RuntimeError('Could not find a best partition!') - # - # Optimisation - # + return self._best_partition def leaf_reconstruction_error(self): """ @@ -1016,25 +1041,7 @@ def leaf_reconstruction_error(self): If the node is not a leaf node. """ - if not self.is_leaf(): - raise RuntimeError('{0} is not a leaf node!'.format(self)) - - if self._cached_leaf_reconstruction_error: - return self._cached_leaf_reconstruction_error - - if self._M is None: - self.PCA() - - error = 0 - for i in range(len(self.data)): - sd = self.data.reflectances[i, :] - XYZ = self.data.XYZ[i, :] - recovered_sd = self.reconstruct(XYZ) - error += np.sum((sd - recovered_sd.values) ** 2) - - self._cached_leaf_reconstruction_error = error - - return error + return self._data.reconstruction_error() def branch_reconstruction_error(self): """ @@ -1082,62 +1089,16 @@ def partition_reconstruction_error(self, axis): 'than the minimum cluster size!') lesser = Node(self._tree, partition[0]) - lesser.PCA() + lesser._data.PCA() greater = Node(self._tree, partition[1]) - greater.PCA() + greater._data.PCA() error = (lesser.leaf_reconstruction_error() + greater.leaf_reconstruction_error()) return error, (lesser, greater) - def find_best_partition(self): - """ - Finds the best partition for the node. - - Returns - ------- - partition_error : float - Partition error - axis : PartitionAxis - Horizontal or vertical line, partitioning the 2D space in - two half-planes. - partition : tuple - Nodes created by splitting a node with a given partition. - """ - - if self._best_partition is not None: - return self._best_partition - - leaf_error = self.leaf_reconstruction_error() - best_error = None - - with tqdm(total=2 * len(self.data)) as progress: - for direction in [0, 1]: - for i in range(len(self.data)): - progress.update() - origin = self.data.xy[i, direction] - axis = PartitionAxis(origin, direction) - - try: - partition_error, partition = ( - self.partition_reconstruction_error(axis)) - except RuntimeError: - continue - - if partition_error >= leaf_error: - continue - - if best_error is None or partition_error < best_error: - self._best_partition = (partition_error, axis, - partition) - - if self._best_partition is None: - raise RuntimeError('Could not find a best partition!') - - return self._best_partition - class NodeTree_Otsu2018(Node): """ @@ -1499,7 +1460,7 @@ def optimise(self, best_partition = partition if optimised_total_error is None: - print_callable('\nNo further improvements are possible!\n' + print_callable('\nNo further improvement is possible!\n' 'Terminating at iteration {0}.\n'.format(i)) break @@ -1547,8 +1508,8 @@ def to_dataset(self): """ - basis_functions = [leaf.basis_functions for leaf in self.leaves] - means = [leaf.mean for leaf in self.leaves] + basis_functions = [leaf._data._basis_functions for leaf in self.leaves] + means = [leaf._data._mean for leaf in self.leaves] selector_array = self._create_selector_array() return Dataset_Otsu2018(self._cmfs.shape, basis_functions, means, diff --git a/colour/recovery/tests/test_otsu2018.py b/colour/recovery/tests/test_otsu2018.py index 1d4b510adb..e965764a76 100644 --- a/colour/recovery/tests/test_otsu2018.py +++ b/colour/recovery/tests/test_otsu2018.py @@ -16,7 +16,7 @@ from colour.models import XYZ_to_Lab, XYZ_to_xy from colour.recovery import (XYZ_to_sd_Otsu2018, SPECTRAL_SHAPE_OTSU2018, Dataset_Otsu2018, NodeTree_Otsu2018) -from colour.recovery.otsu2018 import Data, Node +from colour.recovery.otsu2018 import DATASET_REFERENCE_OTSU2018, Data, Node from colour.utilities import domain_range_scale, metric_mse __author__ = 'Colour Developers' @@ -38,6 +38,27 @@ class TestDataset_Otsu2018(unittest.TestCase): tests methods. """ + def setUp(self): + """ + Initialises common tests attributes. + """ + + self._dataset = DATASET_REFERENCE_OTSU2018 + self._xy = np.array([0.54369557, 0.32107944]) + + self._temporary_directory = tempfile.mkdtemp() + + self._path = os.path.join(self._temporary_directory, + 'Test_Otsu2018.npz') + self._dataset.write(self._path) + + def tearDown(self): + """ + After tests actions. + """ + + shutil.rmtree(self._temporary_directory) + def test_required_attributes(self): """ Tests presence of required attributes. @@ -54,11 +75,95 @@ def test_required_methods(self): Tests presence of required methods. """ - required_methods = ('__init__', 'select', 'cluster', 'read', 'write') + required_methods = ('__init__', '__str__', 'select', 'cluster', 'read', + 'write') for method in required_methods: self.assertIn(method, dir(Dataset_Otsu2018)) + def test_shape(self): + """ + Tests :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.shape` property. + """ + + self.assertEqual(self._dataset.shape, SPECTRAL_SHAPE_OTSU2018) + + def test_basis_functions(self): + """ + Tests :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.basis_functions` + property. + """ + + self.assertTupleEqual(self._dataset.basis_functions.shape, (8, 3, 36)) + + def test_means(self): + """ + Tests :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.means` + property. + """ + + self.assertTupleEqual(self._dataset.means.shape, (8, 36)) + + def test_selector_array(self): + """ + Tests :attr:`colour.recovery.otsu2018.Dataset_Otsu2018.selector_array` + property. + """ + + self.assertTupleEqual(self._dataset.selector_array.shape, (7, 4)) + + def test__str__(self): + """ + Tests :func:`colour.recovery.otsu2018.Dataset_Otsu2018.__str__` method. + """ + + self.assertEqual( + str(self._dataset), 'Dataset_Otsu2018(8 basis functions)') + + def test_select(self): + """ + Tests :func:`colour.recovery.otsu2018.Dataset_Otsu2018.select` method. + """ + + self.assertEqual(self._dataset.select(self._xy), 6) + + def test_cluster(self): + """ + Tests :func:`colour.recovery.otsu2018.Dataset_Otsu2018.cluster` method. + """ + + basis_functions, means = self._dataset.cluster(self._xy) + self.assertTupleEqual(basis_functions.shape, (3, 36)) + self.assertTupleEqual(means.shape, (36, )) + + def test_read(self): + """ + Tests :func:`colour.recovery.otsu2018.Dataset_Otsu2018.read` method. + """ + + dataset = Dataset_Otsu2018() + dataset.read(self._path) + + self.assertEqual(dataset.shape, SPECTRAL_SHAPE_OTSU2018) + self.assertTupleEqual(dataset.basis_functions.shape, (8, 3, 36)) + self.assertTupleEqual(dataset.means.shape, (8, 36)) + self.assertTupleEqual(dataset.selector_array.shape, (7, 4)) + + def test_write(self): + """ + Tests :func:`colour.recovery.otsu2018.Dataset_Otsu2018.write` method. + """ + + self._dataset.write(self._path) + + dataset = Dataset_Otsu2018() + dataset.read(self._path) + + self.assertEqual(dataset.shape, SPECTRAL_SHAPE_OTSU2018) + self.assertTupleEqual(dataset.basis_functions.shape, (8, 3, 36)) + self.assertTupleEqual(dataset.means.shape, (8, 36)) + self.assertTupleEqual(dataset.selector_array.shape, (7, 4)) + class TestXYZ_to_sd_Otsu2018(unittest.TestCase): """ @@ -135,7 +240,8 @@ def test_required_attributes(self): Tests presence of required attributes. """ - required_attributes = ('tree', 'reflectances', 'XYZ', 'xy') + required_attributes = ('tree', 'reflectances', 'reflectances', + 'basis_functions', 'mean') for attribute in required_attributes: self.assertIn(attribute, dir(Data)) @@ -145,7 +251,9 @@ def test_required_methods(self): Tests presence of required methods. """ - required_methods = ('__init__', '__str__', '__len__', 'partition') + required_methods = ('__init__', '__str__', '__len__', 'PCA', + 'reconstruct', 'reconstruction_error', 'origin', + 'partition') for method in required_methods: self.assertIn(method, dir(Data)) @@ -162,9 +270,8 @@ def test_required_attributes(self): Tests presence of required attributes. """ - required_attributes = ('id', 'tree', 'data', 'children', - 'partition_axis', 'basis_functions', 'mean', - 'leaves') + required_attributes = ('id', 'tree', 'data', 'children', 'leaves', + 'partition_axis') for attribute in required_attributes: self.assertIn(attribute, dir(Node)) @@ -175,11 +282,10 @@ def test_required_methods(self): """ required_methods = ('__init__', '__str__', '__len__', 'is_leaf', - 'split', 'PCA', 'reconstruct', + 'split', 'find_best_partition', 'leaf_reconstruction_error', 'branch_reconstruction_error', - 'partition_reconstruction_error', - 'find_best_partition') + 'partition_reconstruction_error') for method in required_methods: self.assertIn(method, dir(Node)) @@ -246,7 +352,7 @@ def test_NodeTree_Otsu2018_and_Dataset_Otsu2018(self): reflectances.append(reshape_sd(sd, self._shape).values) node_tree = NodeTree_Otsu2018(reflectances, self._cmfs, self._sd_D65) - node_tree.optimise(iterations=2) + node_tree.optimise(iterations=5) path = os.path.join(self._temporary_directory, 'Test_Otsu2018.npz') dataset = node_tree.to_dataset()