Resolve issue #240 (#354)

Merged · 7 commits · Oct 24, 2022
Changes from 5 commits
90 changes: 82 additions & 8 deletions scico/solver.py
@@ -65,6 +65,7 @@
import jax
import jax.experimental.host_callback as hcb

import scico.linop
import scico.numpy as snp
from scico.numpy import BlockArray
from scico.typing import BlockShape, DType, JaxArray, Shape
@@ -296,12 +297,12 @@ def f(x, *args):
def cg(
A: Callable,
b: JaxArray,
x0: JaxArray,
x0: Optional[JaxArray] = None,
*,
tol: float = 1e-5,
atol: float = 0.0,
maxiter: int = 1000,
info: bool = False,
info: bool = True,
M: Optional[Callable] = None,
) -> Tuple[JaxArray, dict]:
r"""Conjugate Gradient solver.
@@ -310,15 +311,20 @@ def cg(
positive definite, via the conjugate gradient method.

Args:
A: Function implementing linear operator :math:`A`, should be
positive definite.
A: Callable implementing linear operator :math:`A`, which should
be positive definite.
b: Input array :math:`\mb{b}`.
x0: Initial solution.
x0: Initial solution. If `A` is a :class:`.LinearOperator`, this
parameter need not be specified, and defaults to a zero array.
Otherwise, it is required.
tol: Relative residual stopping tolerance. Convergence occurs
when `norm(residual) <= max(tol * norm(b), atol)`.
atol: Absolute residual stopping tolerance. Convergence occurs
when `norm(residual) <= max(tol * norm(b), atol)`.
maxiter: Maximum iterations. Default: 1000.
info: If ``True``, return a tuple consisting of the solution array
and a dictionary containing diagnostic information, otherwise
just return the solution.
M: Preconditioner for `A`. The preconditioner should approximate
the inverse of `A`. The default, ``None``, uses no
preconditioner.
@@ -329,6 +335,11 @@
- **x** : Solution array.
- **info**: Dictionary containing diagnostic information.
"""
if x0 is None:
if isinstance(A, scico.linop.LinearOperator):
x0 = snp.zeros(A.input_shape, b.dtype)
else:
raise ValueError("Parameter x0 must be specified if A is not a LinearOperator")

if M is None:
M = lambda x: x
@@ -342,8 +353,7 @@
num = snp.sum(r.conj() * z)
ii = 0

# termination tolerance
# uses the "non-legacy" form of scipy.sparse.linalg.cg
# termination tolerance (uses the "non-legacy" form of scipy.sparse.linalg.cg)
termination_tol_sq = snp.maximum(tol * bn, atol) ** 2

while (ii < maxiter) and (num > termination_tol_sq):
@@ -358,7 +368,71 @@
p = z + beta * p
ii += 1

return (x, {"num_iter": ii, "rel_res": snp.sqrt(num).real / bn})
if info:
return (x, {"num_iter": ii, "rel_res": snp.sqrt(num).real / bn})
else:
return x
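
A minimal usage sketch of the revised `cg` interface (not part of the diff; the names `N`, `Ac`, `Am`, and `b` are illustrative): when `A` is a :class:`.LinearOperator`, `x0` can now be omitted, and `info=False` suppresses the diagnostics dictionary. Note that in this revision `info` defaults to ``True``.

```python
import numpy as np

import scico.linop as linop
from scico import solver

# Illustrative symmetric positive definite system.
N = 16
Ac = np.random.randn(N, N).astype(np.float32)
Am = Ac @ Ac.T  # symmetric positive definite (generically)
b = Am @ np.random.randn(N).astype(np.float32)

# x0 may be omitted when A is a LinearOperator; it defaults to a zero array.
x, info = solver.cg(linop.MatrixOperator(Am), b, tol=1e-6)

# With info=False, only the solution array is returned.
x = solver.cg(linop.MatrixOperator(Am), b, tol=1e-6, info=False)
```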


def lstsq(
A: Callable,
b: JaxArray,
x0: Optional[JaxArray] = None,
tol: float = 1e-5,
atol: float = 0.0,
maxiter: int = 1000,
info: bool = False,
M: Optional[Callable] = None,
) -> Tuple[JaxArray, dict]:
r"""Least squares solver.

Solve the least squares problem

.. math::
\argmin_{\mb{x}} \; (1/2) \norm{ A \mb{x} - \mb{b} }_2^2 \;,

where :math:`A` is a linear operator and :math:`\mb{b}` is a vector.
The problem is solved using :func:`cg`.

Args:
A: Callable implementing linear operator :math:`A`.
b: Input array :math:`\mb{b}`.
x0: Initial solution. If `A` is a :class:`.LinearOperator`, this
parameter need not be specified, and defaults to a zero array.
Otherwise, it is required.
tol: Relative residual stopping tolerance. Convergence occurs
when `norm(residual) <= max(tol * norm(b), atol)`.
atol: Absolute residual stopping tolerance. Convergence occurs
when `norm(residual) <= max(tol * norm(b), atol)`.
maxiter: Maximum iterations. Default: 1000.
info: If ``True``, return a tuple consisting of the solution array
and a dictionary containing diagnostic information, otherwise
just return the solution.
M: Preconditioner for `A`. The preconditioner should approximate
the inverse of `A`. The default, ``None``, uses no
preconditioner.

Returns:
tuple: A tuple (x, info) containing:

- **x** : Solution array.
- **info**: Dictionary containing diagnostic information.
"""
if isinstance(A, scico.linop.LinearOperator):
Aop = A
else:
assert x0 is not None
Aop = scico.linop.LinearOperator(
input_shape=x0.shape,
output_shape=b.shape,
eval_fn=A,
input_dtype=b.dtype,
output_dtype=b.dtype,
)

ATA = Aop.T @ Aop
ATb = Aop.T @ b
return cg(ATA, ATb, x0=x0, tol=tol, atol=atol, maxiter=maxiter, info=info, M=M)
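
Since :math:`A^T A` is symmetric positive semidefinite (and positive definite when :math:`A` has full column rank), the normal equations :math:`A^T A \mb{x} = A^T \mb{b}` can be handed directly to :func:`cg`, and their solution minimizes the least squares objective above. A hedged usage sketch (not part of the diff; `Ac`, `b`, and the shapes are illustrative):

```python
import jax
import numpy as np

import scico.linop as linop
import scico.numpy as snp
from scico import solver

# Illustrative overdetermined system.
N, M = 32, 24
Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
b = jax.device_put(np.random.randn(N).astype(np.float32))

# With a LinearOperator, x0 may be omitted (shape taken from A.input_shape).
x = solver.lstsq(linop.MatrixOperator(Ac), b, tol=1e-7)

# With a plain callable, x0 is required so the LinearOperator wrapper can
# infer the input shape.
x = solver.lstsq(Ac.dot, b, x0=snp.zeros((M,), dtype=np.float32), tol=1e-7)
```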


def bisect(
55 changes: 51 additions & 4 deletions scico/test/test_solver.py
@@ -5,7 +5,7 @@
import pytest

import scico.numpy as snp
from scico import random, solver
from scico import linop, random, solver


class TestSet:
@@ -42,7 +42,23 @@ def test_cg_std(self):
assert info["rel_res"].ndim == 0
assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6

def test_cg_info(self):
def test_cg_op(self):
N = 32
Ac = np.random.randn(N, N).astype(np.float32)
Am = Ac.dot(Ac.T)
A = Am.dot
x = np.random.randn(N).astype(np.float32)
b = Am.dot(x)
tol = 1e-12
try:
xcg, info = solver.cg(linop.MatrixOperator(Am), b, tol=tol)
except Exception as e:
print(e)
assert 0
assert info["rel_res"].ndim == 0
assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6

def test_cg_no_info(self):
N = 64
Ac = np.random.randn(N, N)
Am = Ac.dot(Ac.T)
@@ -52,11 +68,10 @@ def test_cg_info(self):
x0 = np.zeros((N,))
tol = 1e-12
try:
xcg, info = solver.cg(A, b, x0, tol=tol, info=True)
xcg = solver.cg(A, b, x0, tol=tol, info=False)
except Exception as e:
print(e)
assert 0
assert info["rel_res"] <= tol
assert np.linalg.norm(A(xcg) - b) / np.linalg.norm(b) < 1e-6

def test_cg_complex(self):
@@ -98,6 +113,38 @@ def test_preconditioned_cg(self):
# Assert that PCG converges faster in a few iterations
assert cg_info["rel_res"] > 3 * pcg_info["rel_res"]

def test_lstsq_func(self):
N = 24
M = 32
Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
Am = Ac.dot(Ac.T)
A = Am.dot
x = jax.device_put(np.random.randn(N).astype(np.float32))
b = Am.dot(x)
x0 = snp.zeros((N,), dtype=np.float32)
tol = 1e-6
try:
xlsq = solver.lstsq(A, b, x0=x0, tol=tol)
except Exception as e:
print(e)
assert 0
assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 5e-6

def test_lstsq_op(self):
N = 32
M = 24
Ac = jax.device_put(np.random.randn(N, M).astype(np.float32))
A = linop.MatrixOperator(Ac)
x = jax.device_put(np.random.randn(M).astype(np.float32))
b = Ac.dot(x)
tol = 1e-7
try:
xlsq = solver.lstsq(A, b, tol=tol)
except Exception as e:
print(e)
assert 0
assert np.linalg.norm(A(xlsq) - b) / np.linalg.norm(b) < 1e-6


class TestOptimizeScalar:
# Adopted from SciPy minimize_scalar tests