/
equation.py
257 lines (214 loc) · 8.33 KB
/
equation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""User API to specify equations."""
import sympy
from sympy.solvers.solveset import linear_coeffs
from cached_property import cached_property
from devito.finite_differences import default_rules
from devito.logger import warning
from devito.tools import as_tuple
from devito.types.lazy import Evaluable
# Public names of this module: the two equation classes and the solver helper.
__all__ = ['Eq', 'Inc', 'solve']
class Eq(sympy.Eq, Evaluable):

    """
    An equal relation between two objects, the left-hand side and the
    right-hand side.

    The left-hand side may be a Function or a SparseFunction. The right-hand
    side may be any arbitrary expressions with numbers, Dimensions, Constants,
    Functions and SparseFunctions as operands.

    Parameters
    ----------
    lhs : Function or SparseFunction
        The left-hand side.
    rhs : expr-like, optional
        The right-hand side. Defaults to 0.
    subdomain : SubDomain, optional
        To restrict the computation of the Eq to a particular sub-region in the
        computational domain.
    coefficients : Substitutions, optional
        Can be used to replace symbolic finite difference weights with user
        defined weights.
    implicit_dims : Dimension or list of Dimension, optional
        An ordered list of Dimensions that do not explicitly appear in either the
        left-hand side or in the right-hand side, but that should be honored when
        constructing an Operator.

    Examples
    --------
    >>> from devito import Grid, Function, Eq
    >>> grid = Grid(shape=(4, 4))
    >>> f = Function(name='f', grid=grid)
    >>> Eq(f, f + 1)
    Eq(f(x, y), f(x, y) + 1)

    Any SymPy expressions may be used in the right-hand side.

    >>> from sympy import sin
    >>> Eq(f, sin(f.dx)**2)
    Eq(f(x, y), sin(Derivative(f(x, y), x))**2)

    Notes
    -----
    An Eq can be thought of as an assignment in an imperative programming language
    (e.g., ``a[i] = b[i]*c``).
    """

    is_Increment = False

    def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None,
                **kwargs):
        # Always build the relation unevaluated, overriding any `evaluate`
        # value supplied by the caller
        kwargs['evaluate'] = False
        obj = sympy.Eq.__new__(cls, lhs, rhs, **kwargs)
        obj._subdomain = subdomain
        obj._substitutions = coefficients
        obj._implicit_dims = as_tuple(implicit_dims)
        return obj

    @property
    def subdomain(self):
        """The SubDomain in which the Eq is defined."""
        return self._subdomain

    @cached_property
    def evaluate(self):
        """
        Evaluate the Equation or system of Equations.

        The RHS of the Equation is evaluated at the indices of the LHS if required.
        """
        try:
            lhs, rhs = self.lhs.evaluate, self.rhs._eval_at(self.lhs).evaluate
        except AttributeError:
            # One of the operands lacks `evaluate`/`_eval_at`; fall back to
            # the generic per-argument evaluation
            lhs, rhs = self._evaluate_args()
        eq = self.func(lhs, rhs, subdomain=self.subdomain,
                       coefficients=self.substitutions,
                       implicit_dims=self._implicit_dims)

        if eq._uses_symbolic_coefficients:
            # NOTE: As Coefficients.py is expanded we will not want
            # all rules to be expunged during this process.
            rules = default_rules(eq, eq._symbolic_functions)
            try:
                eq = eq.xreplace({**eq.substitutions.rules, **rules})
            except AttributeError:
                # No user-provided substitutions with a `rules` attribute
                # (e.g. `coefficients` is None): apply the default rules only
                if rules:
                    eq = eq.xreplace(rules)

        return eq

    @property
    def _flatten(self):
        """
        Flatten vectorial/tensorial Equation into list of scalar Equations.
        """
        if self.lhs.is_Matrix:
            # Maps the Equations to retrieve the rhs from relevant lhs
            eqs = dict(zip(as_tuple(self.lhs), as_tuple(self.rhs)))
            # Get the relevant equations from the lhs structure. .values removes
            # the symmetric duplicates and off-diagonal zeros.
            lhss = self.lhs.values()
            return [self.func(entry, eqs[entry], subdomain=self.subdomain,
                              coefficients=self.substitutions,
                              implicit_dims=self._implicit_dims)
                    for entry in lhss]
        else:
            return [self]

    @property
    def substitutions(self):
        """The ``coefficients`` object passed at construction time, if any."""
        return self._substitutions

    @property
    def implicit_dims(self):
        """The implicit Dimensions of the Eq, as a tuple."""
        return self._implicit_dims

    @cached_property
    def _uses_symbolic_coefficients(self):
        """True if at least one symbolic function is found in the Eq."""
        return bool(self._symbolic_functions)

    @cached_property
    def _symbolic_functions(self):
        """
        The ``_symbolic_functions`` of both sides merged together, or those of
        whichever side exposes them; an empty frozenset if neither side does.
        """
        try:
            return self.lhs._symbolic_functions.union(self.rhs._symbolic_functions)
        except AttributeError:
            pass
        try:
            return self.lhs._symbolic_functions
        except AttributeError:
            pass
        try:
            return self.rhs._symbolic_functions
        except AttributeError:
            return frozenset()

    @property
    def func(self):
        """
        A callable rebuilding an object of the same class. ``subdomain``,
        ``coefficients`` and ``implicit_dims`` default to the values carried
        by this Eq unless explicitly overridden.
        """
        return lambda *args, **kwargs:\
            self.__class__(*args,
                           subdomain=kwargs.pop('subdomain', self._subdomain),
                           coefficients=kwargs.pop('coefficients', self._substitutions),
                           implicit_dims=kwargs.pop('implicit_dims', self._implicit_dims),
                           **kwargs)

    def xreplace(self, rules):
        """Apply ``rules`` to both sides and rebuild the Eq through ``self.func``."""
        return self.func(self.lhs.xreplace(rules), self.rhs.xreplace(rules))

    def __str__(self):
        return "%s(%s, %s)" % (self.__class__.__name__, self.lhs, self.rhs)

    __repr__ = __str__
class Inc(Eq):

    """
    An increment relation: the right-hand side is accumulated into the
    left-hand side.

    Parameters
    ----------
    lhs : Function or SparseFunction
        The left-hand side, i.e. the object being incremented.
    rhs : expr-like
        The right-hand side, i.e. the increment.
    subdomain : SubDomain, optional
        Restricts the computation of the Inc to a particular sub-region of
        the computational domain.
    coefficients : Substitutions, optional
        User-defined weights replacing the symbolic finite difference weights.
    implicit_dims : Dimension or list of Dimension, optional
        An ordered list of Dimensions that appear neither in the left-hand
        side nor in the right-hand side, but that should be honored when
        constructing an Operator.

    Examples
    --------
    Inc may be used to express tensor contractions. Below, a summation along
    the user-defined Dimension ``i``.

    >>> from devito import Grid, Dimension, Function, Inc
    >>> grid = Grid(shape=(4, 4))
    >>> x, y = grid.dimensions
    >>> i = Dimension(name='i')
    >>> f = Function(name='f', grid=grid)
    >>> g = Function(name='g', shape=(10, 4, 4), dimensions=(i, x, y))
    >>> Inc(f, g)
    Inc(f(x, y), g(i, x, y))

    Notes
    -----
    An Inc is the counterpart of the augmented assignment '+=' of an
    imperative programming language (e.g., ``a[i] += c``).
    """

    is_Increment = True

    def __str__(self):
        # The prefix is hard-coded to "Inc" rather than taken from the class name
        operands = (self.lhs, self.rhs)
        return "Inc(%s, %s)" % operands

    __repr__ = __str__
def solve(eq, target, **kwargs):
    """
    Algebraically rearrange an Eq w.r.t. a given symbol.

    A linear solve based on ``sympy``'s ``linear_coeffs`` is attempted first;
    if the equation is not affine in the target, this falls back to
    ``sympy.solve``.

    Parameters
    ----------
    eq : expr-like
        The equation to be rearranged. An `Eq` is first turned into the
        expression ``lhs - rhs`` (or just ``lhs`` if ``rhs`` is 0).
    target : symbol
        The symbol w.r.t. which the equation is rearranged. May be a `Function`
        or any other symbolic object.
    **kwargs
        Symbolic optimizations applied while rearranging the equation. Only
        used by the ``sympy.solve`` fallback, where ``rational`` and
        ``simplify`` are always forced to False. For more information, refer
        to ``sympy.solve.__doc__``.

    Returns
    -------
    expr-like
        The single solution if only one equation/target pair was processed,
        otherwise ``target.new_from_mat(sols)`` built from all the solutions.
    """
    if isinstance(eq, Eq):
        eq = eq.lhs - eq.rhs if eq.rhs != 0 else eq.lhs

    # Flags for the `sympy.solve` fallback; loop-invariant, so built only once
    fallback_kwargs = dict(kwargs)
    fallback_kwargs['rational'] = False  # Avoid float indices
    fallback_kwargs['simplify'] = False  # Do not attempt premature optimisation

    sols = []
    for e, t in zip(as_tuple(eq), as_tuple(target)):
        # Try first linear solver
        try:
            cc = linear_coeffs(e._eval_at(t).evaluate, t)
            sols.append(-cc[1]/cc[0])
        except ValueError:
            warning("Equation is not affine w.r.t the target, falling back to standard "
                    "sympy.solve that may be slow")
            # Only the first solution returned by sympy.solve is retained
            sols.append(sympy.solve(e.evaluate, t, **fallback_kwargs)[0])

    # We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
    if len(sols) > 1:
        return target.new_from_mat(sols)
    else:
        return sols[0]