Clean up scico.solver module #125

Merged
merged 46 commits on Dec 7, 2021

46 commits
e8715de
Various cleanup, mainly docstrings
bwohlberg Nov 19, 2021
5eefaf8
Merge branch 'main' into brendt/cleanup
bwohlberg Nov 20, 2021
9791e8d
Docstring cleanup
bwohlberg Nov 20, 2021
f2e0621
Docstring cleanup
bwohlberg Nov 21, 2021
4e09adc
Typo bug fix
bwohlberg Nov 21, 2021
bdef1d1
Use intersphinx refs
bwohlberg Nov 21, 2021
bb07468
Merge branch 'main' into brendt/cleanup
bwohlberg Nov 24, 2021
a06faec
Merge branch 'main' into brendt/cleanup
bwohlberg Nov 30, 2021
2fe5a31
Some improvements
bwohlberg Dec 2, 2021
73f116b
Remove explicit copy of jax docstrings
bwohlberg Dec 2, 2021
2c5570c
Clean up docstrings
bwohlberg Dec 3, 2021
085cf7e
Fix indentation issues
bwohlberg Dec 3, 2021
dddcf1c
Clean up docstrings
bwohlberg Dec 3, 2021
29dd7c9
Clean up docstrings
bwohlberg Dec 3, 2021
6b57421
Clean up docstrings
bwohlberg Dec 3, 2021
a0b84dc
Clean up docstrings
bwohlberg Dec 3, 2021
fa12ed9
Clean up docstrings
bwohlberg Dec 3, 2021
d5ba36e
Clean up docstrings
bwohlberg Dec 3, 2021
5d2259e
Minor edit
bwohlberg Dec 3, 2021
c80c574
Clean up docstrings
bwohlberg Dec 3, 2021
a7ec6c7
Style guide compliance
bwohlberg Dec 3, 2021
5c9efc5
Docstring cleanup
bwohlberg Dec 3, 2021
acc3bd0
Style guide compliance
bwohlberg Dec 3, 2021
61302af
Cleanup
bwohlberg Dec 3, 2021
376c06a
Fix docstring format problem
bwohlberg Dec 3, 2021
84d64af
Cleanup and style compliance
bwohlberg Dec 4, 2021
b6e0c77
Merge branch 'main' into brendt/cleanup
bwohlberg Dec 4, 2021
4527d2b
Apply black manually
bwohlberg Dec 4, 2021
5290e01
Docstring cleanup and style compliance
bwohlberg Dec 4, 2021
4834cb2
Docstring cleanup and style compliance
bwohlberg Dec 5, 2021
81791d2
Docstring cleanup and style compliance
bwohlberg Dec 5, 2021
81f732b
Style compliance
bwohlberg Dec 5, 2021
2eed6b7
Docstring cleanup
bwohlberg Dec 5, 2021
6f2bd1e
Improve formatting of returns specification
bwohlberg Dec 6, 2021
1f876dd
Alternative formatting of multiple returns
bwohlberg Dec 6, 2021
6616a70
Formatting improvement
bwohlberg Dec 6, 2021
b363832
Minor changes
bwohlberg Dec 6, 2021
e273edc
Docstring cleanup
bwohlberg Dec 6, 2021
32aaeaa
Merge branch 'main' into brendt/cleanup
bwohlberg Dec 6, 2021
75eacd1
Format as code
Michael-T-McCann Dec 6, 2021
400e1e8
Remove explicit copies of docs from scipy.optimize
bwohlberg Dec 6, 2021
e6c26aa
Make auxiliary functions private
bwohlberg Dec 6, 2021
850e21c
Merge branch 'main' into brendt/solver
bwohlberg Dec 6, 2021
da515b3
Merge branch 'main' into brendt/solver
bwohlberg Dec 7, 2021
891c909
Remove todo notes from docstring and code
bwohlberg Dec 7, 2021
dad8e0a
Address oversight in function renaming
bwohlberg Dec 7, 2021
211 changes: 20 additions & 191 deletions scico/solver.py
@@ -5,12 +5,7 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Optimization algorithms.

.. todo::
Add motivation for this module; when to choose over jax optimizers

"""
"""Optimization algorithms."""


from functools import wraps
@@ -86,7 +81,7 @@ def wrapper(x, *args):
return wrapper


def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
def _split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
"""Split an array of shape (N,M,...) into real and imaginary parts.

Args:
@@ -101,12 +96,12 @@ def split_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
BlockArray.
"""
if isinstance(x, BlockArray):
return BlockArray.array([split_real_imag(_) for _ in x])
return BlockArray.array([_split_real_imag(_) for _ in x])
else:
return snp.stack((snp.real(x), snp.imag(x)))


def join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
def _join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
"""Join a real array of shape (2,N,M,...) into a complex array.

Join a real array of shape (2,N,M,...) into a complex array of length
@@ -120,16 +115,11 @@ def join_real_imag(x: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:
and ``x[1]`` respectively.
"""
if isinstance(x, BlockArray):
return BlockArray.array([join_real_imag(_) for _ in x])
return BlockArray.array([_join_real_imag(_) for _ in x])
else:
return x[0] + 1j * x[1]
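The pair of helpers above amounts to stacking real and imaginary parts and recombining them; a minimal round-trip sketch in plain jax.numpy (illustration only, not part of this diff):

import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, 3.0 - 4.0j])      # complex input, shape (N,)
xs = jnp.stack((jnp.real(x), jnp.imag(x)))   # real-valued, shape (2, N)
xj = xs[0] + 1j * xs[1]                      # recovers the original complex array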


# TODO: Use jax to compute Hessian-vector products for use in Newton methods
# see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Hessian-vector-products-using-both-forward--and-reverse-mode
# for examples of constructing Hessians in jax
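The TODO being removed here referenced the JAX autodiff cookbook; the Hessian-vector product pattern it points to looks roughly like this (a sketch, not part of this PR):

import jax

def hvp(f, x, v):
    # forward-over-reverse Hessian-vector product of scalar-valued f at x with vector v
    return jax.jvp(jax.grad(f), (x,), (v,))[1]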


def minimize(
func: Callable,
x0: Union[JaxArray, BlockArray],
@@ -154,152 +144,27 @@ def minimize(
supported.
- Functions mapping from complex arrays -> float are supported.

Docstring for :func:`scipy.optimize.minimize` follows. For
descriptions of the optimization methods and custom minimizers, refer
to the original docstring for :func:`scipy.optimize.minimize`.

Args:
func: The objective function to be minimized.

``func(x, *args) -> float``

where ``x`` is an array and ``args`` is a tuple of the fixed parameters
needed to completely specify the function. Unlike
:func:`scipy.optimize.minimize`, ``x`` need not be a 1D array.
x0: Initial guess. If ``func`` is a mapping from complex arrays to floats,
x0 must have a complex data type.
args: Extra arguments passed to the objective function and `hess`.
method: Type of solver. Should be one of:

- 'Nelder-Mead' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-neldermead.html>`__
- 'Powell' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-powell.html>`__
- 'CG' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-cg.html>`__
- 'BFGS' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-bfgs.html>`__
- 'Newton-CG' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-newtoncg.html>`__
- 'L-BFGS-B' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html>`__
- 'TNC' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-tnc.html>`__
- 'COBYLA' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-cobyla.html>`__
- 'SLSQP' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html>`__
- 'trust-constr'`(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustconstr.html>`__
- 'dogleg' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-dogleg.html>`__
- 'trust-ncg' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustncg.html>`__
- 'trust-exact' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustexact.html>`__
- 'trust-krylov' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize-trustkrylov.html>`__
- custom - a callable object (added in version SciPy 0.14.0), see :func:`scipy.optimize.minmize_scalar`.

If not given, chosen to be one of ``BFGS``, ``L-BFGS-B``, ``SLSQP``,
depending if the problem has constraints or bounds.

hess: Method for computing the Hessian matrix. Only for Newton-CG, dogleg,
trust-ncg, trust-krylov, trust-exact and trust-constr. If it is
callable, it should return the Hessian matrix:

``hess(x, *args) -> {LinearOperator, spmatrix, array}, (n, n)``

where x is a (n,) ndarray and `args` is a tuple with the fixed
parameters. LinearOperator and sparse matrix returns are
allowed only for 'trust-constr' method. Alternatively, the keywords
{'2-point', '3-point', 'cs'} select a finite difference scheme
for numerical estimation. Or, objects implementing
`HessianUpdateStrategy` interface can be used to approximate
the Hessian. Available quasi-Newton methods implementing
this interface are:

- `BFGS`;
- `SR1`.

Whenever the gradient is estimated via finite-differences,
the Hessian cannot be estimated with options
{'2-point', '3-point', 'cs'} and needs to be
estimated using one of the quasi-Newton strategies.
Finite-difference options {'2-point', '3-point', 'cs'} and
`HessianUpdateStrategy` are available only for 'trust-constr' method.
NOTE: In the future, `hess` may be determined using jax.
hessp: Hessian of objective function times an arbitrary vector p.
Only for Newton-CG, trust-ncg, trust-krylov, trust-constr.
Only one of `hessp` or `hess` needs to be given. If `hess` is
provided, then `hessp` will be ignored. `hessp` must compute the
Hessian times an arbitrary vector:

``hessp(x, p, *args) -> array``

where x is a ndarray, p is an arbitrary vector with
dimension equal to x, and `args` is a tuple with the fixed parameters.
bounds (None, optional): Bounds on variables for L-BFGS-B, TNC, SLSQP, Powell, and
trust-constr methods. There are two ways to specify the bounds:

1. Instance of `Bounds` class.
2. Sequence of ``(min, max)`` pairs for each element in `x`. None
is used to specify no bound.

constraints: Constraints definition (only for COBYLA, SLSQP and trust-constr).
Constraints for 'trust-constr' are defined as a single object or a
list of objects specifying constraints to the optimization problem.

Available constraints are:

- `LinearConstraint`
- `NonlinearConstraint`

Constraints for COBYLA, SLSQP are defined as a list of dictionaries.
Each dictionary with fields:

type : str
Constraint type: 'eq' for equality, 'ineq' for inequality.
fun : callable
The function defining the constraint.
jac : callable, optional
The Jacobian of `fun` (only for SLSQP).
args : sequence, optional
Extra arguments to be passed to the function and Jacobian.

Equality constraint means that the constraint function result is to
be zero whereas inequality means that it is to be non-negative.
Note that COBYLA only supports inequality constraints.

tol: Tolerance for termination. For detailed control, use solver-specific options.
callback: Called after each iteration. For 'trust-constr' it is a callable with
the signature:

``callback(xk, OptimizeResult state) -> bool``

where ``xk`` is the current parameter vector. and ``state``
is an `OptimizeResult` object, with the same fields
as the ones from the return. If callback returns True
the algorithm execution is terminated.
For all the other methods, the signature is:

``callback(xk)``

where ``xk`` is the current parameter vector.
options: A dictionary of solver options. All methods accept the following
generic options:

maxiter : int
Maximum number of iterations to perform.
disp : bool
Set to True to print convergence messages.

See :func:`scipy.optimize.show_options()` for solver-specific options.

For more detail, including descriptions of the optimization methods
and custom minimizers, refer to the original docs for
:func:`scipy.optimize.minimize`.
"""

if snp.iscomplexobj(x0):
# scipy minimize function requires real-valued arrays, so
# we split x0 into a vector with real/imaginary parts stacked
# and compose `func` with a `join_real_imag`
# and compose `func` with a `_join_real_imag`
iscomplex = True
func_ = lambda x: func(join_real_imag(x))
x0 = split_real_imag(x0)
func_ = lambda x: func(_join_real_imag(x))
x0 = _split_real_imag(x0)
else:
iscomplex = False
func_ = func

x0_shape = x0.shape
x0_dtype = x0.dtype
x0 = x0.ravel() # If x0 is a BlockArray it will become a DeviceArray here
x0 = x0.ravel() # if x0 is a BlockArray it will become a DeviceArray here
if isinstance(x0, jax.interpreters.xla.DeviceArray):
dev = x0.device_buffer.device() # device where x0 resides; used to put result back in place
dev = x0.device_buffer.device() # device for x0; used to put result back in place
x0 = x0.copy().astype(float)
else:
dev = None
@@ -330,15 +195,15 @@ def minimize(
# un-vectorize the output array, put on device
res.x = snp.reshape(
res.x, x0_shape
) # If x0 was originally a BlockArray be converted back to one here
) # if x0 was originally a BlockArray be converted back to one here

res.x = res.x.astype(x0_dtype)

if dev:
res.x = jax.device_put(res.x, dev)

if iscomplex:
res.x = join_real_imag(res.x)
res.x = _join_real_imag(res.x)

return res
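With the real/imaginary handling above, a complex objective can be passed to minimize directly; a minimal usage sketch, assuming scico.numpy mirrors jax.numpy for array, abs, sum, and zeros_like (illustration only, not part of this diff):

import scico.numpy as snp
from scico import solver

y = snp.array([1.0 + 1.0j, 2.0 - 3.0j])
f = lambda x: snp.sum(snp.abs(x - y) ** 2)   # maps a complex array to a float
res = solver.minimize(f, x0=snp.zeros_like(y), method="BFGS")
# res.x is complex and should be close to y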

@@ -355,47 +220,11 @@ def minimize_scalar(

"""Minimization of scalar function of one variable.

Wrapper around :func:`scipy.optimize.minimize_scalar`. Docstring for
:func:`scipy.optimize.minimize_scalar` follows. For descriptions of
the optimization methods and custom minimizers, refer to the original
docstring for :func:`scipy.optimize.minimize_scalar`.

Args:
func: Objective function. Scalar function, must return a scalar.
bracket: For methods 'brent' and 'golden', `bracket` defines the bracketing
interval and can either have three items ``(a, b, c)`` so that
``a < b < c`` and ``fun(b) < fun(a), fun(c)`` or two items ``a`` and
``c`` which are assumed to be a starting interval for a downhill
bracket search (see `bracket`); it doesn't always mean that the
obtained solution will satisfy ``a <= x <= c``.
bounds: For method 'bounded', `bounds` is mandatory and must have two items
corresponding to the optimization bounds.
args: Extra arguments passed to the objective function.
method: Type of solver. Should be one of:

- 'Brent' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-brent.html>`__
- 'Bounded' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-bounded.html>`__
- 'Golden' `(see here) <https://docs.scipy.org/doc/scipy/reference/optimize.minimize_scalar-golden.html>`__
- custom - a callable object (added in SciPy version 0.14.0), see :func:`scipy.optimize.minmize_scalar`.


tol: Tolerance for termination. For detailed control, use solver-specific
options.
options: A dictionary of solver options.
maxiter : int
Maximum number of iterations to perform.
disp : bool
Set to True to print convergence messages.

See :func:`scipy.optimize.show_options()` for solver-specific options.

Returns:
The optimization result represented as a ``OptimizeResult`` object.
Important attributes are: ``x`` the solution array, ``success`` a
Boolean flag indicating if the optimizer exited successfully and
``message`` which describes the cause of the termination. See
:class:`scipy.optimize.OptimizeResult` for a description of other attributes.
Wrapper around :func:`scipy.optimize.minimize_scalar`.

For more detail, including descriptions of the optimization methods
and custom minimizers, refer to the original docstring for
:func:`scipy.optimize.minimize_scalar`.
"""

def f(x, *args):
@@ -437,7 +266,7 @@ def cg(
x0: Initial solution.
tol: Relative residual stopping tolerance. Convergence occurs
when ``norm(residual) <= max(tol * norm(b), atol)``.
atol : Absolute residual stopping tolerance. Convergence occurs
atol: Absolute residual stopping tolerance. Convergence occurs
when ``norm(residual) <= max(tol * norm(b), atol)``.
maxiter: Maximum iterations. Default: 1000.
M: Preconditioner for A. The preconditioner should approximate
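The cg stopping rule quoted in the docstring above compares the residual norm against both a relative and an absolute floor; written out as a standalone check (a sketch with illustrative tolerance values, not scico code):

import numpy as np

def cg_converged(residual, b, tol=1e-5, atol=0.0):
    # convergence when norm(residual) <= max(tol * norm(b), atol)
    return np.linalg.norm(residual) <= max(tol * np.linalg.norm(b), atol)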
8 changes: 4 additions & 4 deletions scico/test/test_solver.py
@@ -182,25 +182,25 @@ def f(x):

def test_split_join_array():
x, key = random.randn((4, 4), dtype=np.complex64)
x_s = solver.split_real_imag(x)
x_s = solver._split_real_imag(x)
assert x_s.shape == (2, 4, 4)
np.testing.assert_allclose(x_s[0], snp.real(x))
np.testing.assert_allclose(x_s[1], snp.imag(x))

x_j = solver.join_real_imag(x_s)
x_j = solver._join_real_imag(x_s)
np.testing.assert_allclose(x_j, x, rtol=1e-4)


def test_split_join_blockarray():
x, key = random.randn(((4, 4), (3,)), dtype=np.complex64)
x_s = solver.split_real_imag(x)
x_s = solver._split_real_imag(x)
assert x_s.shape == ((2, 4, 4), (2, 3))

real_block = BlockArray.array((x_s[0][0], x_s[1][0]))
imag_block = BlockArray.array((x_s[0][1], x_s[1][1]))
np.testing.assert_allclose(real_block.ravel(), snp.real(x).ravel(), rtol=1e-4)
np.testing.assert_allclose(imag_block.ravel(), snp.imag(x).ravel(), rtol=1e-4)

x_j = solver.join_real_imag(x_s)
x_j = solver._join_real_imag(x_s)
assert x_j.shape == x.shape
np.testing.assert_allclose(x_j.ravel(), x.ravel(), rtol=1e-4)