-
Notifications
You must be signed in to change notification settings - Fork 103
/
pspace.py
499 lines (404 loc) · 15.9 KB
/
pspace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
# Copyright 2014, 2015 The ODL development group
#
# This file is part of ODL.
#
# ODL is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ODL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ODL. If not, see <http://www.gnu.org/licenses/>.
"""Cartesian products of `LinearSpace`'s.
TODO: document public interface
"""
# Imports for common Python 2/3 codebase
from __future__ import print_function, division, absolute_import
from builtins import str, zip, super
from future import standard_library
standard_library.install_aliases()
from numbers import Integral
# External
import numpy as np
# ODL imports
from odl.set.space import LinearSpace
__all__ = ('ProductSpace',)
def _strip_space(x):
"""Strip the SPACE.element( ... ) part from a repr."""
r = repr(x)
space_repr = '{!r}.element('.format(x.space)
if r.startswith(space_repr) and r.endswith(')'):
r = r[len(space_repr):-1]
return r
def _indent(x):
"""Indent a string by 4 characters."""
lines = x.split('\n')
for i in range(len(lines)):
lines[i] = ' ' + lines[i]
return '\n'.join(lines)
def _prod_inner_sum_not_defined(x):
raise NotImplementedError('inner product not defined with custom product '
'norm.')
class ProductSpace(LinearSpace):

    """The Cartesian product of N linear spaces.

    The product X1 x ... x XN is itself a linear space, where the
    linear combination is defined component-wise.

    TODO: document public interface
    """

    def __init__(self, *spaces, **kwargs):
        """Initialize a new ProductSpace.

        The product X1 x ... x XN is itself a linear space, where the
        linear combination is defined component-wise.

        Parameters
        ----------
        spaces : `LinearSpace` and `int`, or `LinearSpace` instances
            Either a space and an integer, in which case the power of
            the space is taken (R^n), or several spaces, in which case
            their product is taken (R x R x R x C). All spaces must
            have the same field, and at least one space is required.
        kwargs : {'ord', 'weights', 'prod_norm'}
            'ord' : float, optional
                Order of the product distance/norm, i.e.
                dist(x, y) = np.linalg.norm(x-y, ord=ord)
                norm(x) = np.linalg.norm(x, ord=ord)
                Default: 2.0
            'weights' : array-like, optional, only usable with 'ord'
                Array of weights, same size as number of space
                components. All weights must be positive. See the
                table below for how they enter the norm.
                Default: (1.0,...,1.0)
            'prod_norm' : callable, optional
                Function that should be applied to the array of
                distances/norms. Specifying a product norm causes
                the space to NOT be a Hilbert space. If given, 'ord'
                and 'weights' are ignored.
                Default: np.linalg.norm(x, ord=ord)

        The following values for `ord` can be specified.
        Note that any value of ord < 1 only gives a pseudo-norm.

        +----------+---------------------------+
        |ord       |Distance Definition        |
        +==========+===========================+
        |inf       |`max(w * z)`               |
        +----------+---------------------------+
        |-inf      |`min(w * z)`               |
        +----------+---------------------------+
        |other     |`sum(w * z**ord)**(1/ord)` |
        +----------+---------------------------+

        Here, z = (x[0].dist(y[0]),..., x[n-1].dist(y[n-1])) and
        w = weights.

        Note that `0 <= ord < 1` are not allowed since these
        pseudo-norms are very unstable numerically.

        An inner product is only defined for ord=2 without a custom
        product norm. It is the weighted sum
        sum(w * (x[0].inner(y[0]),..., x[n-1].inner(y[n-1]))),
        so that norm(x)**2 == inner(x, x).

        Raises
        ------
        TypeError
            If a factor is not a `LinearSpace`, the fields differ, or
            'prod_norm' is not callable.
        ValueError
            If no factor spaces are given, `0 <= ord < 1`, or the
            weights are not positive or have the wrong length.

        Examples
        --------
        >>> from odl import Rn
        >>> r2x3 = ProductSpace(Rn(2), Rn(3))
        """
        if (len(spaces) == 2 and
                isinstance(spaces[0], LinearSpace) and
                isinstance(spaces[1], Integral)):
            # Power space initialization: X^n = X x ... x X (n times)
            spaces = [spaces[0]] * spaces[1]

        wrong_spaces = [spc for spc in spaces
                        if not isinstance(spc, LinearSpace)]
        if wrong_spaces:
            raise TypeError('{!r} not LinearSpace instance(s).'
                            ''.format(wrong_spaces))

        if not spaces:
            # The field (and str/repr) is taken from the first factor
            raise ValueError('ProductSpace requires at least one space.')

        if not all(spc.field == spaces[0].field for spc in spaces):
            raise TypeError('All spaces must have the same field')

        prod_norm = kwargs.get('prod_norm', None)
        if prod_norm is not None:
            if not callable(prod_norm):
                raise TypeError('product norm is not callable.')
            self._prod_norm = prod_norm
            self._prod_inner_sum = _prod_inner_sum_not_defined
        else:
            order = float(kwargs.get('ord', 2.0))
            if 0 <= order < 1:
                raise ValueError('Cannot use {:.2}-norm due to numerical '
                                 'instability.'.format(order))

            weights = kwargs.get('weights', None)
            if weights is not None:
                weights = np.atleast_1d(weights)
                if not np.all(weights > 0):
                    raise ValueError('weights must all be positive')
                if not len(weights) == len(spaces):
                    raise ValueError('spaces and weights have different '
                                     'lengths ({} != {}).'
                                     ''.format(len(spaces), len(weights)))

                if np.isinf(order):
                    # max(w * z) resp. min(w * z)
                    scaling = weights
                else:
                    # sum(w * z**ord)**(1/ord) == ||w**(1/ord) * z||_ord;
                    # for ord=2 this makes norm(x)**2 == inner(x, x).
                    scaling = weights ** (1.0 / order)

                def w_norm(x):
                    return np.linalg.norm(scaling * x, ord=order)

                self._prod_norm = w_norm

                if order == 2.0:
                    def w_inner_sum(x):
                        return np.dot(x, weights)

                    self._prod_inner_sum = w_inner_sum
                else:
                    self._prod_inner_sum = _prod_inner_sum_not_defined
            else:
                def norm(x):
                    return np.linalg.norm(x, ord=order)

                self._prod_norm = norm
                if order == 2.0:
                    self._prod_inner_sum = np.sum
                else:
                    self._prod_inner_sum = _prod_inner_sum_not_defined

        self._spaces = tuple(spaces)
        self._size = len(spaces)
        self._field = spaces[0].field
        super().__init__()

    @property
    def size(self):
        """The number of factors."""
        return self._size

    @property
    def field(self):
        """The common underlying field of all factors."""
        return self._field

    @property
    def spaces(self):
        """A tuple containing all spaces."""
        return self._spaces

    def element(self, inp=None):
        """Create an element in the product space.

        Parameters
        ----------
        inp : None, or sequence of vectors or array-like, optional
            The method has three call patterns:

            - ``None``: create a new vector from scratch; each part is
              created by ``element()`` of the corresponding factor.
            - A sequence of `LinearSpace.Vector` whose spaces match
              the factors: the vectors are wrapped as parts (not
              copied).
            - A sequence of array-like objects: each part is created
              by ``element(arg)`` of the corresponding factor.

        Returns
        -------
        element : ProductSpace.Vector

        Raises
        ------
        ValueError
            If ``inp`` does not have exactly one entry per factor.

        Examples
        --------
        >>> from odl import Rn
        >>> r2, r3 = Rn(2), Rn(3)
        >>> vec_2, vec_3 = r2.element(), r3.element()
        >>> r2x3 = ProductSpace(r2, r3)
        >>> vec_2x3 = r2x3.element()
        >>> vec_2.space == vec_2x3[0].space
        True
        >>> vec_3.space == vec_2x3[1].space
        True

        Creates an element in the product space

        >>> from odl import Rn
        >>> r2, r3 = Rn(2), Rn(3)
        >>> prod = ProductSpace(r2, r3)
        >>> x2 = r2.element([1, 2])
        >>> x3 = r3.element([1, 2, 3])
        >>> x = prod.element([x2, x3])
        >>> print(x)
        {[1.0, 2.0], [1.0, 2.0, 3.0]}
        """
        if inp is None:
            inp = [space.element() for space in self.spaces]
        else:
            # Materialize once: the input is traversed several times
            # below, which would silently exhaust a generator.
            inp = list(inp)

        if len(inp) != self.size:
            # zip() would otherwise truncate silently and produce an
            # element with missing parts.
            raise ValueError('expected {} parts, got {}.'
                             ''.format(self.size, len(inp)))

        if (all(isinstance(v, LinearSpace.Vector) for v in inp) and
                all(part.space == space
                    for part, space in zip(inp, self.spaces))):
            # Wrap the given vectors as parts (their data is not copied)
            parts = list(inp)
        else:
            # Delegate construction of each part to its factor space
            parts = [space.element(arg)
                     for arg, space in zip(inp, self.spaces)]

        return self.Vector(self, parts)

    def zero(self):
        """Create the zero vector of the product space.

        The i:th component of the product space zero vector is the
        zero vector of the i:th space in the product.

        Parameters
        ----------
        None

        Returns
        -------
        zero : ProductSpace.Vector
            The zero vector in the product space

        Examples
        --------
        >>> from odl import Rn
        >>> r2, r3 = Rn(2), Rn(3)
        >>> zero_2, zero_3 = r2.zero(), r3.zero()
        >>> r2x3 = ProductSpace(r2, r3)
        >>> zero_2x3 = r2x3.zero()
        >>> zero_2 == zero_2x3[0]
        True
        >>> zero_3 == zero_2x3[1]
        True
        """
        return self.element([space.zero() for space in self.spaces])

    def _lincomb(self, a, x, b, y, out):
        """Delegate the linear combination to each factor, part by part."""
        # pylint: disable=protected-access
        for space, xp, yp, outp in zip(self.spaces, x.parts, y.parts,
                                       out.parts):
            space._lincomb(a, xp, b, yp, outp)

    def _dist(self, x1, x2):
        """Apply the product norm to the component-wise distances."""
        dists = np.fromiter(
            (spc._dist(x1p, x2p)
             for spc, x1p, x2p in zip(self.spaces, x1.parts, x2.parts)),
            dtype=np.float64, count=self.size)
        return self._prod_norm(dists)

    def _norm(self, x):
        """Apply the product norm to the component-wise norms."""
        norms = np.fromiter(
            (spc._norm(xp)
             for spc, xp in zip(self.spaces, x.parts)),
            dtype=np.float64, count=self.size)
        return self._prod_norm(norms)

    def _inner(self, x1, x2):
        """Combine the component-wise inner products."""
        # Let numpy infer the dtype: a fixed float64 (as for distances
        # and norms) would reject complex inner products of complex
        # factor spaces.
        inners = np.array(
            [spc._inner(x1p, x2p)
             for spc, x1p, x2p in zip(self.spaces, x1.parts, x2.parts)])
        return self._prod_inner_sum(inners)

    def _multiply(self, x1, x2, out):
        """Delegate the multiplication to each factor, part by part."""
        for spc, xp, yp, outp in zip(self.spaces, x1.parts, x2.parts,
                                     out.parts):
            spc._multiply(xp, yp, outp)

    def __eq__(self, other):
        """`ps.__eq__(other) <==> ps == other`.

        Returns
        -------
        equals : bool
            `True` if `other` is a `ProductSpace` instance, has
            the same length and the same factors. `False` otherwise.

        Examples
        --------
        >>> from odl import Rn
        >>> r2, r3 = Rn(2), Rn(3)
        >>> rn, rm = Rn(2), Rn(3)
        >>> r2x3, rnxm = ProductSpace(r2, r3), ProductSpace(rn, rm)
        >>> r2x3 == rnxm
        True
        >>> r3x2 = ProductSpace(r3, r2)
        >>> r2x3 == r3x2
        False
        >>> r5 = ProductSpace(*[Rn(1)]*5)
        >>> r2x3 == r5
        False
        >>> r5 = Rn(5)
        >>> r2x3 == r5
        False
        """
        if other is self:
            return True
        else:
            return (isinstance(other, ProductSpace) and
                    len(self) == len(other) and
                    all(x == y for x, y in zip(self.spaces,
                                               other.spaces)))

    def __len__(self):
        """`ps.__len__() <==> len(ps)`."""
        return self._size

    def __getitem__(self, indices):
        """`ps.__getitem__(indices) <==> ps[indices]`."""
        return self.spaces[indices]

    def __str__(self):
        """`ps.__str__() <==> str(ps)`."""
        if all(self.spaces[0] == space for space in self.spaces):
            return '{' + str(self.spaces[0]) + '}^' + str(self.size)
        else:
            return ' x '.join(str(space) for space in self.spaces)

    def __repr__(self):
        """`ps.__repr__() <==> repr(ps)`."""
        if all(self.spaces[0] == space for space in self.spaces):
            return 'ProductSpace({!r}, {})'.format(self.spaces[0],
                                                   self.size)
        else:
            inner_str = ', '.join(repr(space) for space in self.spaces)
            return 'ProductSpace({})'.format(inner_str)

    class Vector(LinearSpace.Vector):

        """An element of a `ProductSpace`, held as one part per factor."""

        def __init__(self, space, parts):
            """Initialize a new instance."""
            super().__init__(space)
            self._parts = parts

        @property
        def parts(self):
            """The parts of this vector."""
            return self._parts

        @property
        def size(self):
            """The number of factors of this vector's space."""
            return self.space.size

        def __eq__(self, other):
            """`ps.__eq__(other) <==> ps == other`.

            Overrides the default `LinearSpace` method since it is
            implemented with the distance function, which is prone to
            numerical errors. This function checks equality per
            component.
            """
            if other not in self.space:
                return False
            elif other is self:
                return True
            else:
                return all(sp == op
                           for sp, op in zip(self.parts, other.parts))

        def __len__(self):
            """`v.__len__() <==> len(v)`."""
            return len(self.space)

        def __getitem__(self, indices):
            """`ps.__getitem__(indices) <==> ps[indices]`."""
            return self.parts[indices]

        def __setitem__(self, indices, values):
            """`ps.__setitem__(indcs, vals) <==> ps[indcs] = vals`."""
            self.parts[indices] = values

        def __str__(self):
            """`ps.__str__() <==> str(ps)`."""
            inner_str = ', '.join(str(part) for part in self.parts)
            return '{{{}}}'.format(inner_str)

        def __repr__(self):
            """`s.__repr__() <==> repr(s)`.

            Examples
            --------
            >>> from odl import Rn
            >>> r2, r3 = Rn(2), Rn(3)
            >>> r2x3 = ProductSpace(r2, r3)
            >>> x = r2x3.element([[1, 2], [3, 4, 5]])
            >>> eval(repr(x)) == x
            True

            The result is readable:

            >>> x
            ProductSpace(Rn(2), Rn(3)).element([
                [1.0, 2.0],
                [3.0, 4.0, 5.0]
            ])

            Nestled spaces work as well

            >>> X = ProductSpace(r2x3, r2x3)
            >>> x = X.element([[[1, 2], [3, 4, 5]],[[1, 2], [3, 4, 5]]])
            >>> eval(repr(x)) == x
            True
            >>> x
            ProductSpace(ProductSpace(Rn(2), Rn(3)), 2).element([
                [
                    [1.0, 2.0],
                    [3.0, 4.0, 5.0]
                ],
                [
                    [1.0, 2.0],
                    [3.0, 4.0, 5.0]
                ]
            ])
            """
            def fmt(part):
                # Part repr without its space prefix, indented one level
                return _indent(_strip_space(part))

            if len(self) < 5:
                body = ',\n'.join(fmt(part) for part in self.parts)
            else:
                # Abbreviate long products: first three parts, an
                # ellipsis line, then the last part.
                body = (',\n'.join(fmt(part) for part in self.parts[:3]) +
                        ',\n    ...\n' + fmt(self.parts[-1]))
            inner_str = '[\n' + body + '\n]'
            return '{!r}.element({})'.format(self.space, inner_str)
if __name__ == '__main__':
    # Run the module doctests. NORMALIZE_WHITESPACE lets the multi-line
    # repr examples match independently of exact line breaks/indents.
    import doctest
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)