Skip to content

Commit

Permalink
API: Remove ResidualOperator, use op - x syntax instead. Add Operator…
Browse files Browse the repository at this point in the history
…VectorSum
  • Loading branch information
adler-j committed Sep 30, 2016
1 parent 2503425 commit 9b9996c
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 109 deletions.
120 changes: 28 additions & 92 deletions odl/operator/default_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
__all__ = ('ScalingOperator', 'ZeroOperator', 'IdentityOperator',
'LinCombOperator', 'MultiplyOperator', 'PowerOperator',
'InnerProductOperator', 'NormOperator', 'DistOperator',
'ConstantOperator', 'ResidualOperator')
'ConstantOperator')


class ScalingOperator(Operator):
Expand Down Expand Up @@ -765,8 +765,9 @@ def __init__(self, constant, domain=None, range=None):
if range is None:
range = constant.space

super().__init__(domain, range)
self.__constant = self.range.element(constant)
self.__constant = range.element(constant)
linear = self.constant.norm() == 0
super().__init__(domain, range, linear=linear)

@property
def constant(self):
Expand All @@ -780,6 +781,13 @@ def _call(self, x, out=None):
else:
out.assign(self.constant)

@property
def adjoint(self):
"""Adjoint of the operator.
Only defined if the operator is the constant operator.
"""

def derivative(self, point):
"""Derivative of this operator, always zero.
Expand Down Expand Up @@ -808,7 +816,7 @@ def __str__(self):
return "{}".format(self.constant)


class ZeroOperator(ConstantOperator):
class ZeroOperator(Operator):

"""Operator mapping each element to the zero element::
Expand All @@ -824,114 +832,42 @@ def __init__(self, domain, range=None):
Domain of the operator.
range : `LinearSpace`, optional
Range of the operator. Default: ``domain``
"""
if range is None:
range = domain

super().__init__(constant=range.zero(), domain=domain, range=range)

def __repr__(self):
"""Return ``repr(self)``."""
return '{}({!r})'.format(self.__class__.__name__, self.domain)

def __str__(self):
"""Return ``str(self)``."""
return '0'


class ResidualOperator(Operator):

"""Operator that calculates the residual ``op(x) - y``.
``ResidualOperator(op, y)(x) == op(x) - y``
"""

def __init__(self, operator, vector):
"""Initialize a new instance.
Parameters
----------
operator : `Operator`
Operator to be used in the residual expression. Its
`Operator.range` must be a `LinearSpace`.
vector : ``operator.range`` `element-like`
Vector to be subtracted from the operator result.
Examples
--------
>>> import odl
>>> r3 = odl.rn(3)
>>> y = r3.element([1, 2, 3])
>>> ident_op = odl.IdentityOperator(r3)
>>> res_op = odl.ResidualOperator(ident_op, y)
>>> x = r3.element([4, 5, 6])
>>> res_op(x)
rn(3).element([3.0, 3.0, 3.0])
>>> op = odl.ZeroOperator(odl.rn(3))
>>> op([1, 2, 3])
rn(3).element([0.0, 0.0, 0.0])
"""
if not isinstance(operator, Operator):
raise TypeError('`op` {!r} not a Operator instance'
''.format(operator))

if not isinstance(operator.range, LinearSpace):
raise TypeError('`op.range` {!r} not a LinearSpace instance'
''.format(operator.range))

self.__operator = operator
self.__vector = operator.range.element(vector)
super().__init__(operator.domain, operator.range)

@property
def operator(self):
"""The operator to apply."""
return self.__operator
if range is None:
range = domain

@property
def vector(self):
"""The constant operator range element to subtract."""
return self.__vector
super().__init__(domain, range, linear=True)

def _call(self, x, out=None):
"""Evaluate the residual at ``x`` and write to ``out`` if given."""
"""Return the constant vector or assign it to ``out``."""
if out is None:
out = self.operator(x)
out = 0 * x
else:
self.operator(x, out=out)

out -= self.vector
out.lincomb(0, x)
return out

def derivative(self, point):
"""Derivative the residual operator.
It is equal to the derivative of the "inner" operator:
``ResidualOperator(op, y).derivative(z) == op.derivative(z)``
Parameters
----------
point : `domain` element
Any element in the domain where the derivative should be taken
@property
def adjoint(self):
"""Adjoint of the operator.
Examples
--------
>>> import odl
>>> r3 = odl.rn(3)
>>> op = IdentityOperator(r3)
>>> res = ResidualOperator(op, r3.element([1, 2, 3]))
>>> x = r3.element([4, 5, 6])
>>> res.derivative(x)(x)
rn(3).element([4.0, 5.0, 6.0])
The zero operator is self adjoint.
"""
return self.operator.derivative(point)
return self

def __repr__(self):
"""Return ``repr(self)``."""
return '{}({!r}, {!r})'.format(self.__class__.__name__, self.operator,
self.vector)
return '{}({!r})'.format(self.__class__.__name__, self.domain)

def __str__(self):
"""Return ``str(self)``."""
return "{} - {}".format(self.op, self.vector)
return '0'


if __name__ == '__main__':
Expand Down
106 changes: 96 additions & 10 deletions odl/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from odl.set import LinearSpace, LinearSpaceElement, Set, Field


__all__ = ('Operator', 'OperatorComp', 'OperatorSum',
__all__ = ('Operator', 'OperatorComp', 'OperatorSum', 'OperatorVectorSum',
'OperatorLeftScalarMult', 'OperatorRightScalarMult',
'FunctionalLeftVectorMult',
'OperatorLeftVectorMult', 'OperatorRightVectorMult',
Expand Down Expand Up @@ -719,14 +719,11 @@ def __add__(self, other):
``self + other <==> (x --> self(x) + other(x))``
"""
from odl.operator.default_ops import ConstantOperator

if other in self.range:
return self + ConstantOperator(other, self.domain, self.range)
return OperatorVectorSum(self, other)
elif other in self.range.field:
constant_vector = other * self.range.one()
return self + ConstantOperator(constant_vector,
self.domain, self.range)
return OperatorVectorSum(self, constant_vector)
elif isinstance(other, Operator):
return OperatorSum(self, other)
else:
Expand Down Expand Up @@ -1158,6 +1155,97 @@ def __str__(self):
return '({} + {})'.format(self.left, self.right)


class OperatorVectorSum(Operator):

"""Operator that computes ``op(x) + y``.
``OperatorVectorSum(op, y)(x) == op(x) + y``
"""

def __init__(self, operator, vector):
"""Initialize a new instance.
Parameters
----------
operator : `Operator`
Operator to be used in the sum. Its
`Operator.range` must be a `LinearSpace`.
vector : ``operator.range`` `element-like`
Vector to be subtracted from the operator result.
Examples
--------
>>> import odl
>>> r3 = odl.rn(3)
>>> y = r3.element([1, 2, 3])
>>> ident_op = odl.IdentityOperator(r3)
>>> sum_op = odl.OperatorVectorSum(ident_op, y)
>>> x = r3.element([4, 5, 6])
>>> sum_op(x)
rn(3).element([5.0, 7.0, 9.0])
"""
if not isinstance(operator, Operator):
raise TypeError('`op` {!r} not a Operator instance'
''.format(operator))

if not isinstance(operator.range, LinearSpace):
raise TypeError('`op.range` {!r} not a LinearSpace instance'
''.format(operator.range))

self.__operator = operator
self.__vector = operator.range.element(vector)
super().__init__(operator.domain, operator.range)

@property
def operator(self):
"""The operator to apply."""
return self.__operator

@property
def vector(self):
"""The constant operator range element to subtract."""
return self.__vector

def _call(self, x, out=None):
"""Evaluate the residual at ``x`` and write to ``out`` if given."""
if out is None:
out = self.operator(x)
else:
self.operator(x, out=out)

out += self.vector
return out

def derivative(self, point):
"""Derivative the operator vector sum.
It is equal to the derivative of the "inner" operator:
``OperatorVectorSum(op, y).derivative(z) == op.derivative(z)``
Parameters
----------
point : `domain` element
Any element in the domain where the derivative should be taken
Examples
--------
>>> import odl
>>> r3 = odl.rn(3)
>>> op = odl.IdentityOperator(r3)
>>> res = odl.OperatorVectorSum(op, r3.element([1, 2, 3]))
>>> x = r3.element([4, 5, 6])
>>> res.derivative(x)(x)
rn(3).element([4.0, 5.0, 6.0])
"""
return self.operator.derivative(point)

def __repr__(self):
"""Return ``repr(self)``."""
return '{}({!r}, {!r})'.format(self.__class__.__name__, self.operator,
self.vector)


class OperatorComp(Operator):

"""Expression type for the composition of operators.
Expand Down Expand Up @@ -1470,8 +1558,7 @@ def derivative(self, x):
--------
>>> import odl
>>> space = odl.rn(3)
>>> operator = odl.ResidualOperator(odl.IdentityOperator(space),
... [1, 1, 1])
>>> operator = odl.IdentityOperator(space) - space.element([1, 1, 1])
>>> left_mul_op = OperatorLeftScalarMult(operator, 3)
>>> derivative = left_mul_op.derivative([0, 0, 0])
>>> derivative([1, 1, 1])
Expand Down Expand Up @@ -1653,8 +1740,7 @@ def derivative(self, x):
--------
>>> import odl
>>> space = odl.rn(3)
>>> operator = odl.ResidualOperator(odl.IdentityOperator(space),
... [1, 1, 1])
>>> operator = odl.IdentityOperator(space) - space.element([1, 1, 1])
>>> left_mul_op = OperatorRightScalarMult(operator, 3)
>>> derivative = left_mul_op.derivative([0, 0, 0])
>>> derivative([1, 1, 1])
Expand Down
6 changes: 3 additions & 3 deletions odl/operator/pspace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def derivative(self, x):
Example with affine operator
>>> residual_op = odl.ResidualOperator(I, r3.element([1, 1, 1]))
>>> residual_op = I - r3.element([1, 1, 1])
>>> op = ProductSpaceOperator([[0, residual_op], [0, 0]],
... domain=X, range=X)
Expand Down Expand Up @@ -635,7 +635,7 @@ def derivative(self, x):
>>> import odl
>>> I = odl.IdentityOperator(odl.rn(3))
>>> residual_op = odl.ResidualOperator(I, I.domain.element([1, 1, 1]))
>>> residual_op = I - I.domain.element([1, 1, 1])
>>> op = BroadcastOperator(residual_op, 2 * residual_op)
Calling operator offsets by ``[1, 1, 1]``:
Expand Down Expand Up @@ -787,7 +787,7 @@ def derivative(self, x):
Example with affine operator
>>> residual_op = odl.ResidualOperator(I, r3.element([1, 1, 1]))
>>> residual_op = I - r3.element([1, 1, 1])
>>> op = ReductionOperator(residual_op, 2 * residual_op)
Calling operator gives offset by [3, 3, 3]
Expand Down
5 changes: 2 additions & 3 deletions odl/solvers/nonsmooth/proximal_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import scipy.special

from odl.operator import (Operator, IdentityOperator, ScalingOperator,
ConstantOperator, ResidualOperator, DiagonalOperator)
ConstantOperator, DiagonalOperator)
from odl.space import ProductSpace
from odl.set import LinearSpaceElement

Expand Down Expand Up @@ -302,8 +302,7 @@ def quadratic_perturbation_prox_factory(step_size):
prox = proximal_arg_scaling(prox_factory, const)(step_size)
if u is not None:
return (const * prox *
ResidualOperator(ScalingOperator(u.space, const),
step_size * const * u))
(ScalingOperator(u.space, const) - step_size * const * u))
else:
return const * prox * const

Expand Down
2 changes: 1 addition & 1 deletion test/solvers/functional/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_arithmetic():
# Create elements needed for later
functional = odl.solvers.L2Norm(space).translated([1, 2, 3])
functional2 = odl.solvers.L2NormSquared(space)
operator = odl.ResidualOperator(odl.IdentityOperator(space), [4, 5, 6])
operator = odl.IdentityOperator(space) - space.element([4, 5, 6])
x = noise_element(functional.domain)
y = noise_element(functional.domain)
scalar = np.pi
Expand Down

0 comments on commit 9b9996c

Please sign in to comment.