Skip to content

Commit

Permalink
Merge branch 'main' into brendt/notebooks-build
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg authored Nov 23, 2021
2 parents 24819c3 + 76854da commit f7abee9
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 33 deletions.
64 changes: 42 additions & 22 deletions scico/linop/_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,50 +95,70 @@ def valid_adjoint(
A: LinearOperator,
AT: LinearOperator,
eps: Optional[float] = 1e-7,
x: Optional[JaxArray] = None,
y: Optional[JaxArray] = None,
key: Optional[PRNGKey] = None,
) -> Union[bool, float]:
r"""Check whether :class:`.LinearOperator` `AT` is the adjoint of `A`.
The test exploits the identity
Check whether :class:`.LinearOperator` :math:`\mathsf{AT}` is the
adjoint of :math:`\mathsf{A}`. The test exploits the identity
.. math::
\mathbf{y}^T (A \mathbf{x}) = (\mathbf{y}^T A) \mathbf{x} =
(A^T \mathbf{y})^T \mathbf{x}
by computing :math:`\mathbf{u} = A \mathbf{x}` and
:math:`\mathbf{v} = A^T \mathbf{y}` for random :math:`\mathbf{x}`
and :math:`\mathbf{y}` and confirming that :math:`\| \mathbf{y}^T
\mathbf{u} - \mathbf{v}^T \mathbf{x} \|_2 < \epsilon` since
by computing :math:`\mathbf{u} = \mathsf{A}(\mathbf{x})` and
:math:`\mathbf{v} = \mathsf{AT}(\mathbf{y})` for random
:math:`\mathbf{x}` and :math:`\mathbf{y}` and confirming that
.. math::
\mathbf{y}^T \mathbf{u} = \mathbf{y}^T (A \mathbf{x}) =
(A^T \mathbf{y})^T \mathbf{x} = \mathbf{v}^T \mathbf{x}
\frac{| \mathbf{y}^T \mathbf{u} - \mathbf{v}^T \mathbf{x} |}
{\max \left\{ | \mathbf{y}^T \mathbf{u} |,
| \mathbf{v}^T \mathbf{x} | \right\}}
< \epsilon \;.
when :math:`A^T` is a valid adjoint of :math:`A`. If :math:`A` is a
complex operator (with a complex `input_dtype`) then the test checks
whether `AT` is the Hermitian conjugate of `A`, with a test as above,
but with all the :math:`\cdot^T` replaced with :math:`\cdot^H`.
If :math:`\mathsf{A}` is a complex operator (with a complex
`input_dtype`) then the test checks whether :math:`\mathsf{AT}` is
the Hermitian conjugate of :math:`\mathsf{A}`, with a test as above,
but with all the :math:`(\cdot)^T` replaced with :math:`(\cdot)^H`.
Args:
A: Primary :class:`.LinearOperator`.
AT: Adjoint :class:`.LinearOperator`.
eps: Error threshold for validation of `AT` as adjoint of `A`. If
None, the relative error is returned instead of a boolean value.
eps: Error threshold for validation of :math:`\mathsf{AT}` as
adjoint of :math:`\mathsf{AT}`. If None, the relative error
is returned instead of a boolean value.
x : If not the default None, use the specified array instead of a
random array as test vector :math:`\mb{x}`. If specified, the
array must have shape ``A.input_shape``.
y : If not the default None, use the specified array instead of a
random array as test vector :math:`\mb{y}`. If specified, the
array must have shape ``AT.input_shape``.
key: Jax PRNG key. Defaults to None, in which case a new key is
created.
Returns:
Boolean value indicating that validation passed, or relative error
of test, depending on type of parameter `eps`.
Boolean value indicating whether validation passed, or relative
error of test, depending on type of parameter `eps`.
"""

x0, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype)
x1, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype)
y0 = A(x0)
y1 = AT(x1)
x1y0 = snp.dot(x1.ravel().conj(), y0.ravel())
y1x0 = snp.dot(y1.ravel().conj(), x0.ravel())
err = snp.linalg.norm(x1y0 - y1x0) / max(snp.linalg.norm(x1y0), snp.linalg.norm(y1x0))
if x is None:
x, key = randn(shape=A.input_shape, key=key, dtype=A.input_dtype)
else:
if x.shape != A.input_shape:
raise ValueError("Shape of x array not appropriate as an input for operator A")
if y is None:
y, key = randn(shape=AT.input_shape, key=key, dtype=AT.input_dtype)
else:
if y.shape != AT.input_shape:
raise ValueError("Shape of y array not appropriate as an input for operator AT")

u = A(x)
v = AT(y)
yTu = snp.dot(y.ravel().conj(), u.ravel())
vTx = snp.dot(v.ravel().conj(), x.ravel())
err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx))
if eps is None:
return err
else:
Expand Down
25 changes: 22 additions & 3 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@
import scico.numpy as snp
from scico import linop
from scico.random import randn
from scico.typing import PRNGKey
from scico.typing import JaxArray, PRNGKey


def adjoint_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
def adjoint_test(
A: linop.LinearOperator,
key: Optional[PRNGKey] = None,
rtol: float = 1e-4,
x: Optional[JaxArray] = None,
y: Optional[JaxArray] = None,
):
"""Check the validity of A.conj().T as the adjoint for a LinearOperator A.
Args:
Expand All @@ -27,7 +33,20 @@ def adjoint_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: f
rtol: Relative tolerance
"""

assert linop.valid_adjoint(A, A.H, rtol, key)
assert linop.valid_adjoint(A, A.H, key=key, eps=rtol, x=x, y=y)


def test_valid_adjoint():

diagonal, key = randn((16,), dtype=np.float32)
D = linop.Diagonal(diagonal=diagonal)
assert linop.valid_adjoint(D, D.T, key=key, eps=None) < 1e-4
x, key = randn((8,), dtype=np.float32)
y, key = randn((8,), dtype=np.float32)
with pytest.raises(ValueError):
linop.valid_adjoint(D, D.T, key=key, x=x)
with pytest.raises(ValueError):
linop.valid_adjoint(D, D.T, key=key, y=y)


class AbsMatOp(linop.LinearOperator):
Expand Down
35 changes: 27 additions & 8 deletions scico/test/linop/test_radon_astra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,35 @@

import pytest

import scico
from scico.test.linop.test_linop import adjoint_test
from scico.test.linop.test_radon_svmbir import make_im

try:
from scico.linop.radon_astra import ParallelBeamProjector
except ImportError as e:
pytest.skip("astra not installed", allow_module_level=True)

import scico

N = 128
rtol_cpu = 5e-5
rtol_gpu = 7e-2
rtol_gpu_random_input = 1.0


def get_tol():
if jax.devices()[0].device_kind == "cpu":
rtol = 5e-5
rtol = rtol_cpu
else:
rtol = 7e-2
rtol = rtol_gpu # astra inaccurate in GPU
return rtol


def get_tol_random_input():
if jax.devices()[0].device_kind == "cpu":
rtol = rtol_cpu
else:
rtol = rtol_gpu_random_input # astra more inaccurate in GPU for random inputs
return rtol


Expand All @@ -37,9 +49,9 @@ def __init__(self, volume_geometry):
self.A = ParallelBeamProjector(
input_shape=(N, N),
volume_geometry=volume_geometry,
detector_spacing=1,
det_count=384,
angles=np.linspace(0, np.pi, 180, False),
detector_spacing=detector_spacing,
det_count=N_det,
angles=angles,
)


Expand Down Expand Up @@ -93,6 +105,13 @@ def test_adjoint_grad(testobj):
np.testing.assert_allclose(scico.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=get_tol())


def test_adjoint(testobj):
def test_adjoint_random(testobj):
A = testobj.A
adjoint_test(A, rtol=get_tol())
adjoint_test(A, rtol=get_tol_random_input())


def test_adjoint_typical_input(testobj):
A = testobj.A
x = make_im(A.input_shape[0], A.input_shape[1], is_3d=False)

adjoint_test(A, x=x, rtol=get_tol())

0 comments on commit f7abee9

Please sign in to comment.