Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add test for x in op.domain in all solvers, closes #291 #502

Merged
merged 1 commit into from Aug 15, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions odl/solvers/advanced/chambolle_pock.py
Expand Up @@ -148,13 +148,13 @@ def chambolle_pock_solver(op, x, tau, sigma, proximal_primal, proximal_dual,
"""
# Forward operator
if not isinstance(op, Operator):
raise TypeError('`op` {} is not an instance of {}'
''.format(op, Operator))
raise TypeError('`op` {!r} is not an `Operator` instance'
''.format(op))

# Starting point
if x not in op.domain:
raise TypeError('`x` {} is not in the domain of `op` {}'
''.format(x.space, op.domain))
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

# Step size parameter
tau, tau_in = float(tau), tau
Expand Down
5 changes: 5 additions & 0 deletions odl/solvers/findroot/newton.py
Expand Up @@ -74,6 +74,11 @@ def bfgs_method(grad, x, line_search, niter=1, callback=None):
-------
`None`
"""

if x not in grad.domain:
raise TypeError('`x` {!r} is not in the domain of `grad` {!r}'
''.format(x, grad.domain))

hess = ident = IdentityOperator(grad.range)
grad_x = grad(x)
for _ in range(niter):
Expand Down
18 changes: 17 additions & 1 deletion odl/solvers/iterative/iterative.py
Expand Up @@ -102,6 +102,10 @@ def landweber(op, x, rhs, niter=1, omega=1, projection=None, callback=None):
"""
# TODO: add a book reference

if x not in op.domain:
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

# Reusable temporaries
tmp_ran = op.range.element()
tmp_dom = op.domain.element()
Expand Down Expand Up @@ -166,6 +170,10 @@ def conjugate_gradient(op, x, rhs, niter=1, callback=None):
if op.domain != op.range:
raise ValueError('operator needs to be self-adjoint')

if x not in op.domain:
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

r = op(x)
r.lincomb(1, rhs, -1, r) # r = rhs - A x
p = r.copy()
Expand Down Expand Up @@ -250,6 +258,10 @@ def conjugate_gradient_normal(op, x, rhs, niter=1, callback=None):
# TODO: add a book reference
# TODO: update doc

if x not in op.domain:
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

d = op(x)
d.lincomb(1, rhs, -1, d) # d = rhs - A x
p = op.derivative(x).adjoint(d)
Expand Down Expand Up @@ -351,9 +363,13 @@ def gauss_newton(op, x, rhs, niter=1, zero_seq=exp_zero_seq(2.0),
-------
None
"""
if x not in op.domain:
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

x0 = x.copy()
id_op = IdentityOperator(op.domain)
dx = x.space.zero()
dx = op.domain.zero()

tmp_dom = op.domain.element()
u = op.domain.element()
Expand Down
4 changes: 4 additions & 0 deletions odl/solvers/scalar/gradient.py
Expand Up @@ -79,6 +79,10 @@ def steepest_descent(grad, x, niter=1, line_search=1, projection=None,
Optimized solver for the case ``f(x) = x^T Ax - 2 x^T b``
"""

if x not in grad.domain:
raise TypeError('`x` {!r} is not in the domain of `grad` {!r}'
''.format(x, grad.domain))

if not callable(line_search):
step = float(line_search)
smart_line_search = False
Expand Down
8 changes: 6 additions & 2 deletions odl/solvers/vector/newton.py
Expand Up @@ -78,6 +78,10 @@ def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,
solved using the conjugate gradient method.
"""
# TODO: update doc
if x not in op.domain:
raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
''.format(x, op.domain))

if cg_iter is None:
# Motivated by that if it is Ax = b, x and b in Rn, it takes at most n
# iterations to solve with cg
Expand All @@ -91,12 +95,12 @@ def newtons_method(op, x, line_search, num_iter=10, cg_iter=None,

# Compute hessian (as operator) and gradient in the current point
hessian = op.derivative(x)
deriv_in_point = op(x).copy()
deriv_in_point = op(x)

# Solving A*x = b for x, in this case f'(x)*p = -f(x)
# TODO: Let the user provide/choose method for how to solve this?
conjugate_gradient(hessian, search_direction,
-1 * deriv_in_point, cg_iter)
-deriv_in_point, cg_iter)

# Computing step length
dir_deriv = search_direction.inner(deriv_in_point)
Expand Down