Skip to content

Commit

Permalink
ENH: add test for x in op.domain in all solvers, closes #291
Browse files Browse the repository at this point in the history
  • Loading branch information
adler-j committed Aug 11, 2016
1 parent f89eca6 commit db4b3c0
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion odl/solvers/advanced/chambolle_pock.py
Expand Up @@ -154,7 +154,7 @@ def chambolle_pock_solver(op, x, tau, sigma, proximal_primal, proximal_dual,
# Starting point
if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x.space, op.domain))
''.format(x, op.domain))

# Step size parameter
tau, tau_in = float(tau), tau
Expand Down
5 changes: 5 additions & 0 deletions odl/solvers/findroot/newton.py
Expand Up @@ -74,6 +74,11 @@ def bfgs_method(grad, x, line_search, niter=1, callback=None):
-------
`None`
"""

if x not in grad.domain:
raise TypeError('`x` {} is not in the domain of `grad` {}'
''.format(x, grad.domain))

hess = ident = IdentityOperator(grad.range)
grad_x = grad(x)
for _ in range(niter):
Expand Down
18 changes: 17 additions & 1 deletion odl/solvers/iterative/iterative.py
Expand Up @@ -102,6 +102,10 @@ def landweber(op, x, rhs, niter=1, omega=1, projection=None, callback=None):
"""
# TODO: add a book reference

if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x, op.domain))

# Reusable temporaries
tmp_ran = op.range.element()
tmp_dom = op.domain.element()
Expand Down Expand Up @@ -166,6 +170,10 @@ def conjugate_gradient(op, x, rhs, niter=1, callback=None):
if op.domain != op.range:
raise ValueError('operator needs to be self-adjoint')

if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x, op.domain))

r = op(x)
r.lincomb(1, rhs, -1, r) # r = rhs - A x
p = r.copy()
Expand Down Expand Up @@ -250,6 +258,10 @@ def conjugate_gradient_normal(op, x, rhs, niter=1, callback=None):
# TODO: add a book reference
# TODO: update doc

if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x, op.domain))

d = op(x)
d.lincomb(1, rhs, -1, d) # d = rhs - A x
p = op.derivative(x).adjoint(d)
Expand Down Expand Up @@ -351,9 +363,13 @@ def gauss_newton(op, x, rhs, niter=1, zero_seq=exp_zero_seq(2.0),
-------
None
"""
if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x, op.domain))

x0 = x.copy()
id_op = IdentityOperator(op.domain)
dx = x.space.zero()
dx = op.domain.zero()

tmp_dom = op.domain.element()
u = op.domain.element()
Expand Down
4 changes: 4 additions & 0 deletions odl/solvers/scalar/gradient.py
Expand Up @@ -79,6 +79,10 @@ def steepest_descent(grad, x, niter=1, line_search=1, projection=None,
Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b``
"""

if x not in grad.domain:
raise TypeError('`x` {} is not in the domain of `grad` {}'
''.format(x, grad.domain))

if not callable(line_search):
step = float(line_search)
smart_line_search = False
Expand Down
16 changes: 10 additions & 6 deletions odl/solvers/vector/newton.py
Expand Up @@ -28,7 +28,7 @@
__all__ = ('newtons_method',)


def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,
def newtons_method(grad, x, line_search, num_iter=10, cg_iter=None,
callback=None):
"""Newton's method for solving a system of equations.
Expand All @@ -49,9 +49,9 @@ def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,
Parameters
----------
op : `Operator`
grad : `Operator`
Gradient of the objective function, ``x --> grad f(x)``
x : element in the domain of ``op``
x : element in the domain of ``grad``
Starting point of the iteration
line_search : `LineSearch`
Strategy to choose the step length
Expand All @@ -78,10 +78,14 @@ def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,
solved using the conjugate gradient method.
"""
# TODO: update doc
if x not in grad.domain:
raise TypeError('`x` {} is not in the domain of `grad` {}'
''.format(x, grad.domain))

if cg_iter is None:
# Motivated by that if it is Ax = b, x and b in Rn, it takes at most n
# iterations to solve with cg
cg_iter = op.domain.size
cg_iter = grad.domain.size

# TODO: optimize by using lincomb and avoiding to create copies
for _ in range(num_iter):
Expand All @@ -90,8 +94,8 @@ def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,
search_direction = x.space.zero()

# Compute hessian (as operator) and gradient in the current point
hessian = op.derivative(x)
deriv_in_point = op(x).copy()
hessian = grad.derivative(x)
deriv_in_point = grad(x).copy()

# Solving A*x = b for x, in this case f'(x)*p = -f(x)
# TODO: Let the user provide/choose method for how to solve this?
Expand Down

0 comments on commit db4b3c0

Please sign in to comment.