Skip to content

Commit

Permalink
ENH: added ComponentProjection, see #53
Browse files Browse the repository at this point in the history
  • Loading branch information
adler-j committed Nov 19, 2015
1 parent 0302393 commit a5cbfd6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
44 changes: 33 additions & 11 deletions odl/operator/pspace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from odl.set.pspace import ProductSpace


__all__ = ('ProductSpaceOperator', )
__all__ = ('ProductSpaceOperator', 'ComponentProjection')


class ProductSpaceOperator(Operator):
Expand Down Expand Up @@ -64,19 +64,20 @@ def __init__(self, operators, dom=None, ran=None):
ranges = [None]*self.ops.shape[0]

for row, col, op in zip(self.ops.row, self.ops.col, self.ops.data):
if domains[col] is None:
domains[col] = op.domain
elif domains[col] != op.domain:
raise ValueError('Column {}, has inconcistient domains,'
'got {} and {}'
''.format(col, domains[col], op.domain))

if ranges[row] is None:
ranges[row] = op.range
elif ranges[row] != op.range:
raise ValueError('Row {}, has inconcistient ranges,'
'got {} and {}'
''.format(row, ranges[row], op.range))

if domains[col] is None:
domains[col] = op.domain
elif domains[col] != op.domain:
raise ValueError('Column {}, has inconcistient domains,'
'got {} and {}'
''.format(col, domains[col], op.domain))

if dom is None:
for col in range(len(domains)):
Expand Down Expand Up @@ -114,7 +115,7 @@ def _apply(self, x, out):
Returns
-------
None
Examples
--------
todo
Expand All @@ -126,8 +127,10 @@ def _apply(self, x, out):
else:
# TODO: optimize
out[i] += op(x[j])

# TODO: set zero on non-evaluated rows

for i, evaluated in enumerate(has_evaluated_row):
if not evaluated:
out[i].set_zero()

def _call(self, x):
""" TODO
Expand All @@ -150,7 +153,6 @@ def _call(self, x):
out[i] += op(x[j])
return out


@property
def adjoint(self):
""" The adjoint is given by taking the conjugate of the scalar
Expand All @@ -162,6 +164,26 @@ def __repr__(self):
"""op.__repr__() <==> repr(op)."""
return 'ProductSpaceOperator({!r})'.format(self.ops)


class ComponentProjection(Operator):
def __init__(self, space, index):
self.index = index
super().__init__(space, space[index], linear=True)

def _apply(self, x, out):
out.assign(x[self.index])

def _call(self, x):
return x[self.index].copy()

@property
def adjoint(self):
""" The adjoint is given by extending along indices, and setting
zero along the others
"""
# TODO: implement
raise NotImplementedError()


if __name__ == '__main__':
from doctest import testmod, NORMALIZE_WHITESPACE
Expand Down
26 changes: 26 additions & 0 deletions test/operator/pspace_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ def test_sum_call():
assert all_almost_equal(op(z)[0], x + y)


def test_project_call():
r3 = odl.Rn(3)
I = odl.IdentityOperator(r3)
op = odl.ProductSpaceOperator([[I],
[I]])

x = r3.element([1, 2, 3])
y = r3.element([7, 8, 9])
z = op.domain.element([x, y])

assert all_almost_equal(op(z)[0], x)


def test_diagonal_call():
r3 = odl.Rn(3)
I = odl.IdentityOperator(r3)
Expand All @@ -96,6 +109,19 @@ def test_swap_call():

assert all_almost_equal(op(z), result)

def test_projection():
r3 = odl.Rn(3)
r3xr3 = odl.ProductSpace(r3, 2)

x = r3.element([1, 2, 3])
y = r3.element([7, 8, 9])
z = r3xr3.element([x, y])
proj_0 = odl.ComponentProjection(r3xr3, 0)
assert x == proj_0(z)

proj_1 = odl.ComponentProjection(r3xr3, 1)
assert y == proj_1(z)


if __name__ == '__main__':
pytest.main(str(__file__.replace('\\', '/')) + ' -v')

0 comments on commit a5cbfd6

Please sign in to comment.