From 442ad5be7f64e10abc817c6b4333efdd2c7af2e3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 11 May 2022 21:16:00 -0600 Subject: [PATCH 1/4] Change finite difference operator boundary handling options --- scico/linop/__init__.py | 3 +- scico/linop/_diff.py | 168 ++++++++++++++++++++++++++++------ scico/test/linop/test_diff.py | 20 +++- 3 files changed, 158 insertions(+), 33 deletions(-) diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py index 8ee3ea2e4..9b0a52dc4 100644 --- a/scico/linop/__init__.py +++ b/scico/linop/__init__.py @@ -24,7 +24,7 @@ valid_adjoint, ) from ._matrix import MatrixOperator -from ._diff import FiniteDifference +from ._diff import FiniteDifference, SingleAxisFiniteDifference from ._convolve import Convolve, ConvolveByX from ._circconv import CircularConvolve from ._dft import DFT @@ -37,6 +37,7 @@ "DFT", "Diagonal", "FiniteDifference", + "SingleAxisFiniteDifference", "Identity", "LinearOperatorStack", "MatrixOperator", diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index e16b7e198..8c2534c1e 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -12,7 +12,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Optional +from typing import Optional, Union import numpy as np @@ -25,11 +25,14 @@ class FiniteDifference(LinearOperatorStack): - """Finite Difference operator. + """Finite difference operator. Computes finite differences along the specified axes, returning the results in a `DeviceArray` (whenever possible) or `BlockArray`. See :class:`LinearOperatorStack` for details on how this choice is made. + See :class:`SingleAxisFiniteDifference` for the mathematical + implications of the different boundary handling options `prepend`, + `append`, and `circular`. Example ------- @@ -48,7 +51,8 @@ def __init__( input_shape: Shape, input_dtype: DType = np.float32, axes: Optional[Axes] = None, - append: Optional[float] = None, + prepend: Optional[Union[0, 1]] = None, + append: Optional[Union[0, 1]] = None, circular: bool = False, jit: bool = True, **kwargs, @@ -63,12 +67,19 @@ def __init__( axes: Axis or axes over which to apply finite difference operator. If not specified, or ``None``, differences are evaluated along all axes. - append: Value to append to the input along each axis before - taking differences. Zero is a typical choice. If not - ``None``, `circular` must be ``False``. + prepend: Flag indicating handling of the left/top/etc. + boundary. If ``None``, there is no boundary extension. + Values of `0` or `1` indicate respectively that zeros or + the initial value in the array are prepended to the + difference array. + append: Flag indicating handling of the right/bottom/etc. + boundary. If ``None``, there is no boundary extension. + Values of `0` or `1` indicate respectively that zeros or + -1 times the final value in the array are appended to the + difference array. circular: If ``True``, perform circular differences, i.e., - include x[-1] - x[0]. If ``True``, `append` must be - ``None``. + include x[-1] - x[0]. If ``True``, `prepend` and `append + must both be ``None``. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ @@ -80,8 +91,13 @@ def __init__( else: axes_list = (axes,) self.axes = parse_axes(axes_list, input_shape) - single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False) - ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list] + single_kwargs = dict( + input_dtype=input_dtype, prepend=prepend, append=append, circular=circular, jit=False + ) + ops = [ + SingleAxisFiniteDifference(input_shape, axis=axis, **single_kwargs) + for axis in axes_list + ] super().__init__( ops, # type: ignore @@ -90,69 +106,146 @@ def __init__( ) -class FiniteDifferenceSingleAxis(LinearOperator): - """Finite Difference operator acting along a single axis.""" +class SingleAxisFiniteDifference(LinearOperator): + r"""Finite difference operator acting along a single axis. + + By default (i.e. `prepend` and `append` set to ``None`` and `circular` + set to ``False``), the difference operator corresponds to the matrix + + .. math:: + + \left(\begin{array}{rrrrr} + -1 & 1 & 0 & \ldots & 0\\ + 0 & -1 & 1 & \ldots & 0\\ + \vdots & \vdots & \ddots & \ddots & \vdots\\ + 0 & 0 & \ldots & -1 & 1 + \end{array}\right) \;, + + mapping :math:`\mbb{R}^N \rightarrow \mbb{R}^{N-1}`, while if `circular` + is ``True``, it corresponds to the :math:`\mbb{R}^N \rightarrow \mbb{R}^N` + mapping + + .. math:: + + \left(\begin{array}{rrrrr} + -1 & 1 & 0 & \ldots & 0\\ + 0 & -1 & 1 & \ldots & 0\\ + \vdots & \vdots & \ddots & \ddots & \vdots\\ + 0 & 0 & \ldots & -1 & 1\\ + 1 & 0 & \dots & 0 & -1 + \end{array}\right) \;. + + Other possible choices include `prepend` set to ``None`` and `append` + set to `0`, giving the :math:`\mbb{R}^N \rightarrow \mbb{R}^N` + mapping + + .. math:: + + \left(\begin{array}{rrrrr} + -1 & 1 & 0 & \ldots & 0\\ + 0 & -1 & 1 & \ldots & 0\\ + \vdots & \vdots & \ddots & \ddots & \vdots\\ + 0 & 0 & \ldots & -1 & 1\\ + 0 & 0 & \dots & 0 & 0 + \end{array}\right) \;, + + and both `prepend` and `append` set to `1`, giving the + :math:`\mbb{R}^N \rightarrow \mbb{R}^{N+1}` mapping + + .. math:: + + \left(\begin{array}{rrrrr} + 1 & 0 & 0 & \ldots & 0\\ + -1 & 1 & 0 & \ldots & 0\\ + 0 & -1 & 1 & \ldots & 0\\ + \vdots & \vdots & \ddots & \ddots & \vdots\\ + 0 & 0 & \ldots & -1 & 1\\ + 0 & 0 & \dots & 0 & -1 + \end{array}\right) \;. + """ def __init__( self, - axis: int, input_shape: Shape, input_dtype: DType = np.float32, - append: Optional[float] = None, + axis: int = -1, + prepend: Optional[Union[0, 1]] = None, + append: Optional[Union[0, 1]] = None, circular: bool = False, jit: bool = True, **kwargs, ): r""" Args: - axis: Axis over which to apply finite difference operator. input_shape: Shape of input array. input_dtype: `dtype` for input argument. Defaults to ``float32``. If `LinearOperator` implements complex-valued operations, this must be ``complex64`` for proper adjoint and gradient calculation. - append: Value to append to the input along `axis` before - taking differences. Defaults to 0. + axis: Axis over which to apply finite difference operator. + prepend: Flag indicating handling of the left/top/etc. + boundary. If ``None``, there is no boundary extension. + Values of `0` or `1` indicate respectively that zeros or + the initial value in the array are prepended to the + difference array. + append: Flag indicating handling of the right/bottom/etc. + boundary. If ``None``, there is no boundary extension. + Values of `0` or `1` indicate respectively that zeros or + -1 times the final value in the array are appended to the + difference array. circular: If ``True``, perform circular differences, i.e., - include x[-1] - x[0]. If ``True``, `append` must be - ``None``. + include x[-1] - x[0]. If ``True``, `prepend` and `append + must both be ``None``. jit: If ``True``, jit the evaluation, adjoint, and gram functions of the LinearOperator. """ if not isinstance(axis, int): - raise TypeError(f"Expected `axis` to be of type int, got {type(axis)} instead") + raise TypeError(f"Expected axis to be of type int, got {type(axis)} instead.") + if axis < 0: + axis = len(input_shape) + axis if axis >= len(input_shape): raise ValueError( - f"Invalid axis {axis} specified; `axis` must be less than " - f"`len(input_shape)`={len(input_shape)}" + f"Invalid axis {axis} specified; axis must be less than " + f"len(input_shape)={len(input_shape)}." ) self.axis = axis - if append is not None and circular: + if circular and (prepend is not None or append is not None): raise ValueError( - "`append` and `circular` are mutually exclusive but both were specified" + "Parameter circular must be False if either prepend or append is not None." ) + if prepend not in [None, 0, 1]: + raise ValueError("Parameter prepend may only take values None, 0, or 1.") + if append not in [None, 0, 1]: + raise ValueError("Parameter append may only take values None, 0, or 1.") - self.circular = circular + self.prepend = prepend self.append = append + self.circular = circular - if self.append is None and not circular: - output_shape = tuple(x - (i == axis) for i, x in enumerate(input_shape)) - else: + if self.circular: output_shape = input_shape + else: + output_shape = tuple( + x + ((i == axis) * ((self.prepend is not None) + (self.append is not None) - 1)) + for i, x in enumerate(input_shape) + ) super().__init__( input_shape=input_shape, output_shape=output_shape, input_dtype=input_dtype, + output_dtype=input_dtype, jit=jit, **kwargs, ) def _eval(self, x: JaxArray) -> JaxArray: + prepend = None + append = None if self.circular: # set append to the first slice along the specified axis ind = tuple( @@ -160,6 +253,21 @@ def _eval(self, x: JaxArray) -> JaxArray: ) append = x[ind] else: - append = self.append + if self.prepend == 0: + ind = tuple( + slice(0, 1) if i == self.axis else slice(None) + for i in range(len(self.input_shape)) + ) + prepend = x[ind] + elif self.prepend == 1: + prepend = 0 + if self.append == 0: + ind = tuple( + slice(-1, None) if i == self.axis else slice(None) + for i in range(len(self.input_shape)) + ) + append = x[ind] + elif self.append == 1: + append = 0 - return snp.diff(x, axis=self.axis, append=append) + return snp.diff(x, axis=self.axis, prepend=prepend, append=append) diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index 94ac3d605..ed12e540e 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -3,7 +3,7 @@ import pytest import scico.numpy as snp -from scico.linop import FiniteDifference +from scico.linop import FiniteDifference, SingleAxisFiniteDifference from scico.random import randn from scico.test.linop.test_linop import adjoint_test @@ -12,7 +12,7 @@ def test_eval(): with pytest.raises(ValueError): # axis 3 does not exist A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3)) - A = FiniteDifference(input_shape=(2, 3), append=0.0) + A = FiniteDifference(input_shape=(2, 3), append=1) x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32) @@ -25,6 +25,22 @@ def test_eval(): snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows +def test_eval_prepend(): + x = snp.arange(1, 6) + A = SingleAxisFiniteDifference(input_shape=(5,), prepend=0) + snp.testing.assert_allclose(A @ x, snp.array([0, 1, 1, 1, 1])) + A = SingleAxisFiniteDifference(input_shape=(5,), prepend=1) + snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 1])) + + +def test_eval_append(): + x = snp.arange(1, 6) + A = SingleAxisFiniteDifference(input_shape=(5,), append=0) + snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 0])) + A = SingleAxisFiniteDifference(input_shape=(5,), append=1) + snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, -5])) + + @pytest.mark.parametrize("input_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("input_shape", [(16,), (16, 24)]) @pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None]) From 710979a5407bc1cf2348e287af6b7039c456d801 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 12 May 2022 07:08:29 -0600 Subject: [PATCH 2/4] Resolve typing issues --- scico/linop/_diff.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 8c2534c1e..4625fd12b 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -12,7 +12,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Optional, Union +from typing import Literal, Optional, Union import numpy as np @@ -51,8 +51,8 @@ def __init__( input_shape: Shape, input_dtype: DType = np.float32, axes: Optional[Axes] = None, - prepend: Optional[Union[0, 1]] = None, - append: Optional[Union[0, 1]] = None, + prepend: Optional[Union[Literal[0], Literal[1]]] = None, + append: Optional[Union[Literal[0], Literal[1]]] = None, circular: bool = False, jit: bool = True, **kwargs, @@ -169,8 +169,8 @@ def __init__( input_shape: Shape, input_dtype: DType = np.float32, axis: int = -1, - prepend: Optional[Union[0, 1]] = None, - append: Optional[Union[0, 1]] = None, + prepend: Optional[Union[Literal[0], Literal[1]]] = None, + append: Optional[Union[Literal[0], Literal[1]]] = None, circular: bool = False, jit: bool = True, **kwargs, From 6b1eeaf613cbebec4a166a6e23822a44414bb5e2 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 12 May 2022 07:15:53 -0600 Subject: [PATCH 3/4] Add/modify comments --- scico/linop/_diff.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 4625fd12b..87caa218f 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -247,27 +247,36 @@ def _eval(self, x: JaxArray) -> JaxArray: prepend = None append = None if self.circular: - # set append to the first slice along the specified axis + # Append a copy of the initial value at the end of the array so that the difference + # array includes the difference across the right/bottom/etc. boundary. ind = tuple( slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) append = x[ind] else: if self.prepend == 0: + # Prepend a 0 to the difference array by prepending a copy of the initial value + # before the difference is computed. ind = tuple( slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) prepend = x[ind] elif self.prepend == 1: + # Prepend a copy of the initial value to the difference array by prepending a 0 + # before the difference is computed. prepend = 0 if self.append == 0: + # Append a 0 to the difference array by appending a copy of the initial value + # before the difference is computed. ind = tuple( slice(-1, None) if i == self.axis else slice(None) for i in range(len(self.input_shape)) ) append = x[ind] elif self.append == 1: + # Append a copy of the initial value to the difference array by appending a 0 + # before the difference is computed. append = 0 return snp.diff(x, axis=self.axis, prepend=prepend, append=append) From 0566b6da0bd5b9db0862ba2dca9b26db5101aa3f Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 12 May 2022 08:13:02 -0600 Subject: [PATCH 4/4] Add some tests --- scico/test/linop/test_diff.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/scico/test/linop/test_diff.py b/scico/test/linop/test_diff.py index ed12e540e..1a13cc475 100644 --- a/scico/test/linop/test_diff.py +++ b/scico/test/linop/test_diff.py @@ -25,6 +25,20 @@ def test_eval(): snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows +def test_except(): + with pytest.raises(TypeError): # axis is not an int + A = SingleAxisFiniteDifference(input_shape=(3,), axis=2.5) + + with pytest.raises(ValueError): # invalid parameter combination + A = SingleAxisFiniteDifference(input_shape=(3,), prepend=0, circular=True) + + with pytest.raises(ValueError): # invalid prepend value + A = SingleAxisFiniteDifference(input_shape=(3,), prepend=2) + + with pytest.raises(ValueError): # invalid append value + A = SingleAxisFiniteDifference(input_shape=(3,), append="a") + + def test_eval_prepend(): x = snp.arange(1, 6) A = SingleAxisFiniteDifference(input_shape=(5,), prepend=0)