Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify FiniteDifference linop #296

Merged
merged 4 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
valid_adjoint,
)
from ._matrix import MatrixOperator
from ._diff import FiniteDifference
from ._diff import FiniteDifference, SingleAxisFiniteDifference
from ._convolve import Convolve, ConvolveByX
from ._circconv import CircularConvolve
from ._dft import DFT
Expand All @@ -37,6 +37,7 @@
"DFT",
"Diagonal",
"FiniteDifference",
"SingleAxisFiniteDifference",
"Identity",
"LinearOperatorStack",
"MatrixOperator",
Expand Down
179 changes: 148 additions & 31 deletions scico/linop/_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from typing import Optional
from typing import Literal, Optional, Union

import numpy as np

Expand All @@ -25,11 +25,14 @@


class FiniteDifference(LinearOperatorStack):
"""Finite Difference operator.
"""Finite difference operator.

Computes finite differences along the specified axes, returning the
results in a `DeviceArray` (whenever possible) or `BlockArray`. See
:class:`LinearOperatorStack` for details on how this choice is made.
See :class:`SingleAxisFiniteDifference` for the mathematical
implications of the different boundary handling options `prepend`,
`append`, and `circular`.

Example
-------
Expand All @@ -48,7 +51,8 @@ def __init__(
input_shape: Shape,
input_dtype: DType = np.float32,
axes: Optional[Axes] = None,
append: Optional[float] = None,
prepend: Optional[Union[Literal[0], Literal[1]]] = None,
append: Optional[Union[Literal[0], Literal[1]]] = None,
circular: bool = False,
jit: bool = True,
**kwargs,
Expand All @@ -63,12 +67,19 @@ def __init__(
axes: Axis or axes over which to apply finite difference
operator. If not specified, or ``None``, differences are
evaluated along all axes.
append: Value to append to the input along each axis before
taking differences. Zero is a typical choice. If not
``None``, `circular` must be ``False``.
prepend: Flag indicating handling of the left/top/etc.
boundary. If ``None``, there is no boundary extension.
Values of `0` or `1` indicate respectively that zeros or
the initial value in the array are prepended to the
difference array.
append: Flag indicating handling of the right/bottom/etc.
boundary. If ``None``, there is no boundary extension.
Values of `0` or `1` indicate respectively that zeros or
-1 times the final value in the array are appended to the
difference array.
circular: If ``True``, perform circular differences, i.e.,
include x[-1] - x[0]. If ``True``, `append` must be
``None``.
include x[-1] - x[0]. If ``True``, `prepend` and `append
must both be ``None``.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""
Expand All @@ -80,8 +91,13 @@ def __init__(
else:
axes_list = (axes,)
self.axes = parse_axes(axes_list, input_shape)
single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False)
ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list]
single_kwargs = dict(
input_dtype=input_dtype, prepend=prepend, append=append, circular=circular, jit=False
)
ops = [
SingleAxisFiniteDifference(input_shape, axis=axis, **single_kwargs)
for axis in axes_list
]

super().__init__(
ops, # type: ignore
Expand All @@ -90,76 +106,177 @@ def __init__(
)


class FiniteDifferenceSingleAxis(LinearOperator):
"""Finite Difference operator acting along a single axis."""
class SingleAxisFiniteDifference(LinearOperator):
r"""Finite difference operator acting along a single axis.

By default (i.e. `prepend` and `append` set to ``None`` and `circular`
set to ``False``), the difference operator corresponds to the matrix

.. math::

\left(\begin{array}{rrrrr}
-1 & 1 & 0 & \ldots & 0\\
0 & -1 & 1 & \ldots & 0\\
\vdots & \vdots & \ddots & \ddots & \vdots\\
0 & 0 & \ldots & -1 & 1
\end{array}\right) \;,

mapping :math:`\mbb{R}^N \rightarrow \mbb{R}^{N-1}`, while if `circular`
is ``True``, it corresponds to the :math:`\mbb{R}^N \rightarrow \mbb{R}^N`
mapping

.. math::

\left(\begin{array}{rrrrr}
-1 & 1 & 0 & \ldots & 0\\
0 & -1 & 1 & \ldots & 0\\
\vdots & \vdots & \ddots & \ddots & \vdots\\
0 & 0 & \ldots & -1 & 1\\
1 & 0 & \dots & 0 & -1
\end{array}\right) \;.

Other possible choices include `prepend` set to ``None`` and `append`
set to `0`, giving the :math:`\mbb{R}^N \rightarrow \mbb{R}^N`
mapping

.. math::

\left(\begin{array}{rrrrr}
-1 & 1 & 0 & \ldots & 0\\
0 & -1 & 1 & \ldots & 0\\
\vdots & \vdots & \ddots & \ddots & \vdots\\
0 & 0 & \ldots & -1 & 1\\
0 & 0 & \dots & 0 & 0
\end{array}\right) \;,

and both `prepend` and `append` set to `1`, giving the
:math:`\mbb{R}^N \rightarrow \mbb{R}^{N+1}` mapping

.. math::

\left(\begin{array}{rrrrr}
1 & 0 & 0 & \ldots & 0\\
-1 & 1 & 0 & \ldots & 0\\
0 & -1 & 1 & \ldots & 0\\
\vdots & \vdots & \ddots & \ddots & \vdots\\
0 & 0 & \ldots & -1 & 1\\
0 & 0 & \dots & 0 & -1
\end{array}\right) \;.
"""

def __init__(
self,
axis: int,
input_shape: Shape,
input_dtype: DType = np.float32,
append: Optional[float] = None,
axis: int = -1,
prepend: Optional[Union[Literal[0], Literal[1]]] = None,
append: Optional[Union[Literal[0], Literal[1]]] = None,
circular: bool = False,
jit: bool = True,
**kwargs,
):
r"""
Args:
axis: Axis over which to apply finite difference operator.
input_shape: Shape of input array.
input_dtype: `dtype` for input argument. Defaults to
``float32``. If `LinearOperator` implements
complex-valued operations, this must be ``complex64`` for
proper adjoint and gradient calculation.
append: Value to append to the input along `axis` before
taking differences. Defaults to 0.
axis: Axis over which to apply finite difference operator.
prepend: Flag indicating handling of the left/top/etc.
boundary. If ``None``, there is no boundary extension.
Values of `0` or `1` indicate respectively that zeros or
the initial value in the array are prepended to the
difference array.
append: Flag indicating handling of the right/bottom/etc.
boundary. If ``None``, there is no boundary extension.
Values of `0` or `1` indicate respectively that zeros or
-1 times the final value in the array are appended to the
difference array.
circular: If ``True``, perform circular differences, i.e.,
include x[-1] - x[0]. If ``True``, `append` must be
``None``.
include x[-1] - x[0]. If ``True``, `prepend` and `append
must both be ``None``.
jit: If ``True``, jit the evaluation, adjoint, and gram
functions of the LinearOperator.
"""

if not isinstance(axis, int):
raise TypeError(f"Expected `axis` to be of type int, got {type(axis)} instead")
raise TypeError(f"Expected axis to be of type int, got {type(axis)} instead.")

if axis < 0:
axis = len(input_shape) + axis
if axis >= len(input_shape):
raise ValueError(
f"Invalid axis {axis} specified; `axis` must be less than "
f"`len(input_shape)`={len(input_shape)}"
f"Invalid axis {axis} specified; axis must be less than "
f"len(input_shape)={len(input_shape)}."
)

self.axis = axis

if append is not None and circular:
if circular and (prepend is not None or append is not None):
raise ValueError(
"`append` and `circular` are mutually exclusive but both were specified"
"Parameter circular must be False if either prepend or append is not None."
)
if prepend not in [None, 0, 1]:
raise ValueError("Parameter prepend may only take values None, 0, or 1.")
if append not in [None, 0, 1]:
raise ValueError("Parameter append may only take values None, 0, or 1.")

self.circular = circular
self.prepend = prepend
self.append = append
self.circular = circular

if self.append is None and not circular:
output_shape = tuple(x - (i == axis) for i, x in enumerate(input_shape))
else:
if self.circular:
output_shape = input_shape
else:
output_shape = tuple(
x + ((i == axis) * ((self.prepend is not None) + (self.append is not None) - 1))
for i, x in enumerate(input_shape)
)

super().__init__(
input_shape=input_shape,
output_shape=output_shape,
input_dtype=input_dtype,
output_dtype=input_dtype,
jit=jit,
**kwargs,
)

def _eval(self, x: JaxArray) -> JaxArray:
prepend = None
append = None
if self.circular:
# set append to the first slice along the specified axis
# Append a copy of the initial value at the end of the array so that the difference
# array includes the difference across the right/bottom/etc. boundary.
ind = tuple(
slice(0, 1) if i == self.axis else slice(None) for i in range(len(self.input_shape))
)
append = x[ind]
else:
append = self.append
if self.prepend == 0:
# Prepend a 0 to the difference array by prepending a copy of the initial value
# before the difference is computed.
ind = tuple(
slice(0, 1) if i == self.axis else slice(None)
for i in range(len(self.input_shape))
)
prepend = x[ind]
elif self.prepend == 1:
# Prepend a copy of the initial value to the difference array by prepending a 0
# before the difference is computed.
prepend = 0
if self.append == 0:
# Append a 0 to the difference array by appending a copy of the initial value
# before the difference is computed.
ind = tuple(
slice(-1, None) if i == self.axis else slice(None)
for i in range(len(self.input_shape))
)
append = x[ind]
elif self.append == 1:
# Append a copy of the initial value to the difference array by appending a 0
# before the difference is computed.
append = 0

return snp.diff(x, axis=self.axis, append=append)
return snp.diff(x, axis=self.axis, prepend=prepend, append=append)
34 changes: 32 additions & 2 deletions scico/test/linop/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import scico.numpy as snp
from scico.linop import FiniteDifference
from scico.linop import FiniteDifference, SingleAxisFiniteDifference
from scico.random import randn
from scico.test.linop.test_linop import adjoint_test

Expand All @@ -12,7 +12,7 @@ def test_eval():
with pytest.raises(ValueError): # axis 3 does not exist
A = FiniteDifference(input_shape=(3, 4, 5), axes=(0, 3))

A = FiniteDifference(input_shape=(2, 3), append=0.0)
A = FiniteDifference(input_shape=(2, 3), append=1)

x = snp.array([[1, 0, 1], [1, 1, 0]], dtype=snp.float32)

Expand All @@ -25,6 +25,36 @@ def test_eval():
snp.testing.assert_allclose(Ax[1], snp.array([[-1, 1, -1], [0, -1, 0]])) # along rows


def test_except():
with pytest.raises(TypeError): # axis is not an int
A = SingleAxisFiniteDifference(input_shape=(3,), axis=2.5)

with pytest.raises(ValueError): # invalid parameter combination
A = SingleAxisFiniteDifference(input_shape=(3,), prepend=0, circular=True)

with pytest.raises(ValueError): # invalid prepend value
A = SingleAxisFiniteDifference(input_shape=(3,), prepend=2)

with pytest.raises(ValueError): # invalid append value
A = SingleAxisFiniteDifference(input_shape=(3,), append="a")


def test_eval_prepend():
x = snp.arange(1, 6)
A = SingleAxisFiniteDifference(input_shape=(5,), prepend=0)
snp.testing.assert_allclose(A @ x, snp.array([0, 1, 1, 1, 1]))
A = SingleAxisFiniteDifference(input_shape=(5,), prepend=1)
snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 1]))


def test_eval_append():
x = snp.arange(1, 6)
A = SingleAxisFiniteDifference(input_shape=(5,), append=0)
snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, 0]))
A = SingleAxisFiniteDifference(input_shape=(5,), append=1)
snp.testing.assert_allclose(A @ x, snp.array([1, 1, 1, 1, -5]))


@pytest.mark.parametrize("input_dtype", [np.float32, np.complex64])
@pytest.mark.parametrize("input_shape", [(16,), (16, 24)])
@pytest.mark.parametrize("axes", [0, 1, (0,), (1,), None])
Expand Down