/
index_methods.py
393 lines (305 loc) · 12.6 KB
/
index_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
"""Module with functions operating on IndexedBase, Indexed and Idx objects

- Check shape conformance
- Determine indices in resulting expression

etc.

Methods in this module could be implemented by calling methods on Expr
objects instead.  When things stabilize this could be a useful
refactoring.
"""
import functools
from ..core import Function
from .indexed import Idx, Indexed
class IndexConformanceExceptionError(Exception):
    """Signal that the index structure of an expression is inconsistent.

    Raised by ``_get_indices_Add`` when the terms of a sum do not all
    carry the same set of outer indices.
    """
def _remove_repeated(inds):
"""Removes repeated objects from sequences
Returns a set of the unique objects and a tuple of all that have been
removed.
>>> l1 = [1, 2, 3, 2]
>>> _remove_repeated(l1)
({1, 3}, (2,))
"""
sum_index = {}
for i in inds:
if i in sum_index:
sum_index[i] += 1
else:
sum_index[i] = 0
inds = [x for x in inds if not sum_index[x]]
return set(inds), tuple(k for k, v in sum_index.items() if v)
def _get_indices_Mul(expr, return_dummies=False):
    """Find the outer indices of a product.

    The outer indices of every factor are collected; an index that shows up
    in more than one factor is treated as a summation (dummy) index and is
    removed from the result.

    Returns ``(outer_indices, symmetries)``, or, if ``return_dummies`` is
    true, ``(outer_indices, symmetries, dummies)`` where ``dummies`` is a
    tuple of the contracted indices.  ``symmetries`` is always an empty dict.

    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Mul(x[i, k]*y[j, k])
    ({i, j}, {})
    >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True)
    ({i, j}, {}, (k,))

    """
    # get_indices() gives (outer indices, symmetries) for each factor; only
    # the outer indices are of interest here.
    outer_sets, _ = zip(*[get_indices(factor) for factor in expr.args])

    # Concatenate the per-factor indices so that repetitions across factors
    # become visible to _remove_repeated().
    flattened = [index for outer in outer_sets for index in outer]
    outer_inds, dummies = _remove_repeated(flattened)

    symmetries = {}
    if not return_dummies:
        return outer_inds, symmetries
    return outer_inds, symmetries, dummies
def _get_indices_Pow(expr):
    """Find the outer indices of a power or an exponential.

    A power is treated as a universal function: its indices are simply the
    union of the outer indices of the base and of the exponent.  This may
    look slightly inconsistent in the special case

        x[i]**2 = x[i]*x[i]                                (1)

    which could have been read as the contraction of x[i] with itself.
    Instead it is read as the function

        lambda y: y**2

    applied to every element of x (a universal function in numpy terms).
    Interpreting (1) as a contraction would require contravariant and
    covariant Idx subclasses.

    Contractions inside the base or inside the exponent are handled as
    usual, but an index appearing in the exponent is never contracted with
    the same index in its own base.  Indices within one exponent can still
    be contracted with each other.

    >>> A = IndexedBase('A')
    >>> x = IndexedBase('x')
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> _get_indices_Pow(exp(A[i, j]*x[j]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(x[i], x[i]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i]))
    ({i}, {})

    """
    base, exponent = expr.as_base_exp()
    base_inds = get_indices(base)[0]
    exponent_inds = get_indices(exponent)[0]
    # Union, not contraction: a shared index stays an outer index.
    return base_inds | exponent_inds, {}
def _get_indices_Add(expr):
    """Find the outer indices of a sum.

    All terms of a sum must carry the same set of outer indices; terms
    without any outer index (scalars) are broadcast.  A valid expression
    could be

        x(i)*y(j) - x(j)*y(i)

    whereas an expression like

        x(i)*y(j) - z(j)*z(j)

    is rejected.

    Returns ``(outer_indices, symmetries)``; ``symmetries`` is always an
    empty dict.  Raises IndexConformanceExceptionError if two non-scalar
    terms disagree on their outer indices.

    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Add(x[i] + x[k]*y[i, k])
    ({i}, {})

    """
    outer_sets, _ = zip(*map(get_indices, expr.args))

    # allow broadcast of scalars
    indexed_terms = [outer for outer in outer_sets if outer]
    if not indexed_terms:
        return set(), {}

    reference = indexed_terms[0]
    if any(outer != reference for outer in indexed_terms[1:]):
        raise IndexConformanceExceptionError(
            f'Indices are not consistent: {expr}')

    return reference, {}
def get_indices(expr):
    """Determine the outer indices of expression ``expr``

    By *outer* we mean indices that are not summation indices.  Returns a set
    and a dict.  The set contains outer indices and the dict contains
    information about index symmetries (currently always empty).

    ``None`` and atoms have no outer indices.  For an applied function the
    outer indices of all arguments are merged, without contracting indices
    repeated across arguments; a function without arguments has no outer
    indices.

    Raises NotImplementedError for expression types that are not handled.

    Examples
    ========

    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j = symbols('i j', integer=True)

    The indices of the total expression are determined.  Repeated indices
    imply a summation, for instance the trace of a matrix A:

    >>> get_indices(A[i, i])
    (set(), {})

    In the case of many terms, the terms are required to have identical
    outer indices.  Else an IndexConformanceExceptionError is raised.

    >>> get_indices(x[i] + A[i, j]*y[j])
    ({i}, {})

    :Exceptions:

    An IndexConformanceExceptionError means that the terms are not
    compatible, e.g.

    >>> get_indices(x[i] + y[j])
    Traceback (most recent call last):
    ...
    IndexConformanceExceptionError: Indices are not consistent: x(i) + y(j)

    .. warning::
       The concept of *outer* indices applies recursively, starting on the
       deepest level.  This implies that dummies inside parenthesis are
       assumed to be summed first, so that the following expression is
       handled gracefully:

       >>> get_indices((x[i] + A[i, j]*y[j])*x[j])
       ({i, j}, {})

       This is correct and may appear convenient, but you need to be careful
       with this as Diofant will happily .expand() the product, if requested.
       The resulting expression would mix the outer ``j`` with the dummies
       inside the parenthesis, which makes it a different expression.  To be
       on the safe side, it is best to avoid such ambiguities by using unique
       indices for all contractions that should be held separate.

    """
    # We call ourself recursively to determine indices of sub expressions.

    # break recursion
    if isinstance(expr, Indexed):
        inds, _ = _remove_repeated(expr.indices)
        return inds, {}
    if expr is None:
        return set(), {}
    if expr.is_Atom:
        return set(), {}
    if isinstance(expr, Idx):
        return {expr}, {}

    # recurse via specialized functions
    if expr.is_Mul:
        return _get_indices_Mul(expr)
    if expr.is_Add:
        return _get_indices_Add(expr)
    if expr.is_Pow:
        return _get_indices_Pow(expr)

    if isinstance(expr, Function):
        # Support ufunc like behaviour by returning indices from arguments.
        # Functions do not interpret repeated indices across arguments
        # as summation
        ind0 = set()
        for arg in expr.args:
            ind, _ = get_indices(arg)
            ind0 |= ind
        # Every path through get_indices() reports an empty symmetry dict,
        # so return one directly.  Relying on the loop variable instead
        # would leave it unbound for a function without arguments.
        return ind0, {}

    raise NotImplementedError('No specialized handling of '
                              f'type {type(expr)}')
def get_contraction_structure(expr):
    """Determine dummy indices of ``expr`` and describe its structure

    By *dummy* we mean indices that are summation indices.

    The result is a dict describing the expression as follows:

    1) A conforming summation of Indexed objects is described with a dict
       where the keys are summation indices and the corresponding values are
       sets containing all terms for which the summation applies.  All Add
       objects in the Diofant expression tree are described like this.

    2) For all nodes in the Diofant expression tree that are *not* of type
       Add, the following applies:

       If a node discovers contractions in one of its arguments, the node
       itself will be stored as a key in the dict.  For that key, the
       corresponding value is a list of dicts, each of which is the result of
       a recursive call to get_contraction_structure().  The list contains
       only dicts for the non-trivial deeper contractions, omitting dicts
       with None as the one and only key.

    .. Note:: The presence of expressions among the dictionary keys indicates
       multiple levels of index contractions.  A nested dict displays nested
       contractions and may itself contain dicts from a deeper level.  In
       practical calculations the summation in the deepest nested level must
       be calculated first so that the outer expression can access the
       resulting indexed object.

    Raises NotImplementedError for expression types that are not handled.

    Examples
    ========

    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j = map(Idx, ['i', 'j'])
    >>> get_contraction_structure(x[i]*y[i] + A[j, j])
    {(i,): {x[i]*y[i]}, (j,): {A[j, j]}}
    >>> get_contraction_structure(x[i]*y[j])
    {None: {x[i]*y[j]}}

    A multiplication of contracted factors results in nested dicts
    representing the internal contractions.

    >>> d = get_contraction_structure(x[i, i]*y[j, j])
    >>> sorted(d, key=default_sort_key)
    [None, x[i, i]*y[j, j]]

    In this case, the product has no contractions:

    >>> d[None]
    {x[i, i]*y[j, j]}

    Factors are contracted "first":

    >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key)
    [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}]

    A parenthesized Add object is also returned as a nested dictionary.  The
    term containing the parenthesis is a Mul with a contraction among the
    arguments, so it will be found as a key in the result.  It stores the
    dictionary resulting from a recursive call on the Add expression.

    >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j]))
    >>> sorted(d, key=default_sort_key)
    [(x[j]*A[i, j] + y[i])*x[i], (i,)]
    >>> d[(i,)]
    {(x[j]*A[i, j] + y[i])*x[i]}
    >>> d[x[i]*(A[i, j]*x[j] + y[i])]
    [{None: {y[i]}, (j,): {x[j]*A[i, j]}}]

    Powers with contractions in either base or exponent will also be found
    as keys in the dictionary, mapping to a list of results from recursive
    calls:

    >>> d = get_contraction_structure(A[j, j]**A[i, i])
    >>> d[None]
    {A[j, j]**A[i, i]}
    >>> nested_contractions = d[A[j, j]**A[i, i]]
    >>> nested_contractions[0]
    {(j,): {A[j, j]}}
    >>> nested_contractions[1]
    {(i,): {A[i, i]}}

    The description of the contraction structure may appear complicated when
    represented with a string in the above examples, but it is easy to
    iterate over:

    >>> for key in d:
    ...     if isinstance(key, Expr):
    ...         continue
    ...     for term in d[key]:
    ...         if term in d:
    ...             # treat deepest contraction first
    ...             pass
    ...         # treat outermost contractions here

    """
    # We call ourself recursively to inspect sub expressions.

    def nontrivial(structures):
        # A structure whose one and only key is None holds no contraction
        # and is left out of the nested lists.
        return [inner for inner in structures
                if None not in inner or len(inner) != 1]

    if isinstance(expr, Indexed):
        _, dummies = _remove_repeated(expr.indices)
        key = dummies or None
        return {key: {expr}}

    if expr.is_Atom:
        return {None: {expr}}

    if expr.is_Mul:
        *_, dummies = _get_indices_Mul(expr, return_dummies=True)
        key = dummies or None
        structure = {key: {expr}}
        # recurse on every factor
        nested = nontrivial(map(get_contraction_structure, expr.args))
        if nested:
            structure[expr] = nested
        return structure

    if expr.is_Pow:
        # Base and exponent are inspected separately.  If either has
        # internal contractions, the power itself becomes a key in the
        # returned dict.
        base, exponent = expr.as_base_exp()
        nested = nontrivial([get_contraction_structure(base),
                             get_contraction_structure(exponent)])
        structure = {None: {expr}}
        if nested:
            structure[expr] = nested
        return structure

    if expr.is_Add:
        # Note: terms with identical summation indices are merely collected.
        # Nothing is done to identify equivalent terms here, as this would
        # require substitutions or pattern matching in expressions of
        # unknown complexity.
        merged = {}
        for term in expr.args:
            # recurse on every term
            for key, terms in get_contraction_structure(term).items():
                if key not in merged:
                    merged[key] = terms
                else:
                    merged[key] |= terms
        return merged

    if isinstance(expr, Function):
        # Collect non-trivial contraction structures in each argument.
        # Repeated indices in separate arguments are not reported as a
        # contraction.
        nested = nontrivial(map(get_contraction_structure, expr.args))
        structure = {None: {expr}}
        if nested:
            structure[expr] = nested
        return structure

    raise NotImplementedError('No specialized handling of '
                              f'type {type(expr)}')