-
Notifications
You must be signed in to change notification settings - Fork 222
/
tools.py
273 lines (227 loc) · 8.88 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from functools import wraps, partial
from itertools import product
import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify
from devito.tools import Tag, as_tuple
from devito.finite_differences import Differentiable
class Transpose(Tag):
    """
    Tag used to flip the sign of a derivative.

    The sign change only matters for odd order derivatives, whose transpose
    requires a minus sign.
    """
# The two Transpose tags: `direct` is created with the value 1 and
# `transpose` with the value -1.
direct = Transpose('direct', 1)
transpose = Transpose('transpose', -1)
class Side(Tag):
    """
    Tag describing the side towards which a derivative is shifted.
    """

    def adjoint(self, matvec):
        """
        Return the side to use for the given ``matvec`` mode.

        With ``direct`` the side is returned unchanged. For any other mode
        ``left`` and ``right`` are swapped, while ``centered`` maps to itself.

        Raises
        ------
        ValueError
            If ``matvec`` is not ``direct`` and this side is none of
            ``centered``, ``left`` or ``right``.
        """
        # `centered` is its own adjoint, so it shares the "unchanged" path.
        if matvec is direct or self is centered:
            return self
        if self is right:
            return left
        if self is left:
            return right
        raise ValueError("Unsupported side value")
# The three Side tags, created with the values -1 (left), 1 (right) and
# 0 (centered).
left = Side('left', -1)
right = Side('right', 1)
centered = Side('centered', 0)
def check_input(func):
    """
    Decorator validating the first argument of ``func``.

    The wrapped callable behaves as follows:

    * if ``expr.is_Number`` is true, ``S.Zero`` is returned and ``func`` is
      not called;
    * if ``expr`` is not a ``Differentiable``, a ``ValueError`` is raised;
    * otherwise ``func(expr, *args, **kwargs)`` is returned.
    """
    @wraps(func)
    def wrapper(expr, *args, **kwargs):
        if expr.is_Number:
            return S.Zero
        if isinstance(expr, Differentiable):
            return func(expr, *args, **kwargs)
        raise ValueError("`%s` must be of type Differentiable (found `%s`)"
                         % (expr, type(expr)))
    return wrapper
def check_symbolic(func):
    """
    Decorator forwarding whether ``expr`` uses symbolic coefficients.

    The wrapped callable reads ``expr._uses_symbolic_coefficients`` and passes
    it to ``func`` as the ``symbolic`` keyword argument, overriding any value
    supplied by the caller.

    Raises
    ------
    NotImplementedError
        If ``expr`` uses symbolic coefficients and
        ``expr.as_coefficients_dict()`` contains more than one entry.
    """
    @wraps(func)
    def wrapper(expr, *args, **kwargs):
        symbolic = expr._uses_symbolic_coefficients
        # More than one entry in the coefficients dict means the expression
        # is a sum of several terms, which is rejected here.
        if symbolic and len(expr.as_coefficients_dict()) > 1:
            raise NotImplementedError("Applying the chain rule to functions "
                                      "with symbolic coefficients is not currently "
                                      "supported")
        kwargs['symbolic'] = symbolic
        return func(expr, *args, **kwargs)
    return wrapper
def dim_with_order(dims, orders):
    """
    Enumerate all admissible combinations of per-dimension derivative orders.

    Parameters
    ----------
    dims : sequence
        The dimensions; only the number of entries is used.
    orders : int or sequence of int
        Maximum derivative order allowed along each dimension. A single
        integer is applied to every dimension.

    Returns
    -------
    list of tuple
        Every tuple of per-dimension orders, in lexicographic order, such
        that each entry does not exceed the corresponding maximum (itself
        capped at 6). The all-zero combination is excluded.

    Examples
    --------
    >>> dim_with_order(('x', 'y'), 1)
    [(0, 1), (1, 0), (1, 1)]
    >>> dim_with_order(('x', 'y'), (2, 1))
    [(0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
    """
    ndim = len(dims)
    if ndim == 0:
        # No dimension, hence no derivative to enumerate
        return []
    if np.ndim(orders) == 0:
        # Broadcast a scalar order to all dimensions
        orders = (orders,) * ndim
    # Never enumerate beyond order 6, whatever the requested orders are
    max_order = min(6, max(orders))
    candidates = product(range(max_order + 1), repeat=ndim)
    # Drop the all-zero combination and anything exceeding a dimension's
    # own maximum order
    return [c for c in candidates
            if any(c) and all(o <= orders[k] for k, o in enumerate(c))]
def deriv_name(dims, orders):
    """
    Build the shortcut name of a derivative from its dimensions and orders.

    Each (dimension, order) pair contributes ``d<name>`` followed by the order
    when the order is greater than 1; the pieces are concatenated, e.g.
    ``dx``, ``dx2`` or ``dxdy2``. Dimensions with ``is_Time`` set are always
    named ``t``; all others use the name of their root.
    """
    def label(d, o):
        axis = 't' if d.is_Time else d.root.name
        if o > 1:
            return 'd%s%s' % (axis, o)
        return 'd%s' % axis

    return ''.join(label(d, o) for d, o in zip(dims, orders))
def generate_fd_shortcuts(function):
    """
    Create all legal finite-difference derivatives for the given Function.

    Parameters
    ----------
    function : Function
        Object whose ``dimensions``, ``space_order``, ``is_TimeFunction``,
        ``is_SparseTimeFunction`` and ``is_Staggered`` attributes are read
        (and ``time_order`` when it is a time function).

    Returns
    -------
    dict
        Maps each shortcut name (e.g. ``dx``, ``dx2``, ``dxdy``, ``dxl``,
        ``dxr``, ``dxc``) to a ``(callable, description)`` pair. The callable
        takes the expression to differentiate (plus optional keyword
        arguments) and returns the corresponding ``Derivative``.
    """
    dimensions = function.dimensions
    s_fd_order = function.space_order
    t_fd_order = function.time_order if (function.is_TimeFunction or
                                         function.is_SparseTimeFunction) else 0
    orders = tuple(t_fd_order if i.is_Time else s_fd_order for i in dimensions)

    from devito.finite_differences.derivative import Derivative

    def deriv_function(expr, deriv_order, dims, fd_order, side=None, **kwargs):
        return Derivative(expr, *as_tuple(dims), deriv_order=deriv_order,
                          fd_order=fd_order, side=side, **kwargs)

    derivatives = {}

    # All conventional FD shortcuts
    for combo in dim_with_order(dimensions, orders):
        # Only the dimensions actually differentiated in this combination
        fd_dims = tuple(d for d, o_d in zip(dimensions, combo) if o_d > 0)
        d_orders = tuple(o_d for o_d in combo if o_d > 0)
        fd_orders = tuple(t_fd_order if d.is_Time else s_fd_order for d in fd_dims)
        deriv = partial(deriv_function, deriv_order=d_orders, dims=fd_dims,
                        fd_order=fd_orders)
        name_fd = deriv_name(fd_dims, d_orders)
        # Materialise the names: formatting a bare generator would embed
        # its repr in the description instead of the dimension names
        dname = tuple(d.root.name for d in fd_dims)
        description = 'derivative of order %s w.r.t dimension %s' % (d_orders, dname)
        derivatives[name_fd] = (deriv, description)

    # Add non-conventional, non-centered first-order FDs
    for d, o in zip(dimensions, orders):
        name = 't' if d.is_Time else d.root.name
        if function.is_Staggered or o == 2:
            # Add centered first derivatives if staggered
            deriv = partial(deriv_function, deriv_order=1, dims=d,
                            fd_order=o, side=centered)
            name_fd = 'd%sc' % name
            description = 'centered derivative staggered w.r.t dimension %s' % d.name
            derivatives[name_fd] = (deriv, description)
        if not function.is_Staggered:
            # Left
            deriv = partial(deriv_function, deriv_order=1,
                            dims=d, fd_order=o, side=left)
            name_fd = 'd%sl' % name
            description = 'left first order derivative w.r.t dimension %s' % d.name
            derivatives[name_fd] = (deriv, description)
            # Right
            deriv = partial(deriv_function, deriv_order=1,
                            dims=d, fd_order=o, side=right)
            name_fd = 'd%sr' % name
            description = 'right first order derivative w.r.t dimension %s' % d.name
            derivatives[name_fd] = (deriv, description)

    return derivatives
def symbolic_weights(function, deriv_order, indices, dim):
    """
    Return one symbolic coefficient per stencil index.

    Each entry is ``function._coeff_symbol(index, deriv_order, function, dim)``
    for the corresponding ``index`` in ``indices``, in the same order.
    """
    make_symbol = function._coeff_symbol
    return [make_symbol(index, deriv_order, function, dim) for index in indices]
@cacheit
def numeric_weights(deriv_order, indices, x0):
    """
    Return the numeric finite-difference weights (cached).

    ``finite_diff_weights`` returns a nested list; the last entry of its last
    entry is selected, which per SymPy's documentation holds the weights for
    the highest requested derivative order using all the given ``indices``.
    """
    all_weights = finite_diff_weights(deriv_order, indices, x0)
    return all_weights[-1][-1]
def generate_indices(func, dim, order, side=None, x0=None):
    """
    Indices for the finite-difference scheme.

    Parameters
    ----------
    func: Function
        Function that is differentiated
    dim: Dimension
        Dimension w.r.t which the derivative is taken
    order: Int
        Order of the finite-difference scheme
    side: Side
        Side of the scheme, (centered, left, right)
    x0: dict, optional
        Mapping from a dimension to the origin of the scheme along it,
        ie. `{x: x}`, `{x: x + .5 * x.spacing}`, ... When missing or without
        an entry for `dim`, the non-staggered path uses `dim` itself.

    Returns
    -------
    Tuple `(indices, origin)`: the ordered indices and the origin of the
    scheme along `dim`.
    """
    if not func.is_Staggered or dim.is_Time:
        # Cartesian scheme: default the origin to the dimension itself
        origin = x0.get(dim, dim) if x0 else dim
        indices = generate_indices_cartesian(dim, order, side, origin)
    else:
        # Staggered finite difference
        origin, indices = generate_indices_staggered(func, dim, order,
                                                     side=side, x0=x0)
    return indices, origin
def generate_indices_cartesian(dim, order, side, x0):
    """
    Indices for the finite-difference scheme on a cartesian grid.

    Parameters
    ----------
    dim: Dimension
        Dimension w.r.t which the derivative is taken
    order: Int
        Order of the finite-difference scheme
    side: Side
        Side of the scheme, (centered, left, right)
    x0: Dimension or Expr or Number
        Origin of the scheme, ie. `x`, `x + .5 * x.spacing`, ...

    Returns
    -------
    Ordered tuple of indices
    """
    spacing = dim.spacing

    # Offset of x0 from the grid, in units of the spacing; taken as zero
    # when x0 is an integer
    if sympify(x0).is_Integer:
        frac = 0
    else:
        frac = (dim - x0)/spacing
    frac = np.sign(frac) * (frac % 1)

    # Left and right max offsets for indices
    first = -order//2 + int(np.ceil(-frac))
    last = order//2 + 1 - int(np.ceil(frac))
    offset = frac * spacing

    # One-sided schemes are shifted by one point and step in the direction
    # given by the side's value
    if side in [left, right]:
        shift = 1
        step = spacing * side.val
    else:
        shift = 0
        step = spacing

    if order >= 2:
        return tuple(x0 + (i + shift) * step + offset for i in range(first, last))
    # Two-point scheme
    if offset == 0:
        return (x0, x0 + step)
    return (x0 - offset, x0 + offset)
def generate_indices_staggered(func, dim, order, side=None, x0=None):
    """
    Indices for the finite-difference scheme on a staggered grid.

    Parameters
    ----------
    func: Function
        Function that is differentiated
    dim: Dimension
        Dimension w.r.t which the derivative is taken
    order: Int
        Order of the finite-difference scheme
    side: Side
        Side of the scheme, (centered, left, right). Not used here.
    x0: dict, optional
        Mapping from a dimension to the origin of the scheme along it. When
        it has no (truthy) entry for `dim`, `func.indices_ref[dim]` is used.

    Returns
    -------
    Tuple `(origin, indices)`: the origin of the scheme along `dim` and the
    ordered tuple of indices.
    """
    spacing = dim.spacing
    origin = (x0 or {}).get(dim) or func.indices_ref[dim]
    try:
        reference = func.indices_ref[dim]
    except AttributeError:
        # No reference index available: treat the origin as the reference
        reference = origin

    half = order//2
    if origin != reference:
        # Origin differs from the function's reference index: points sit
        # half a spacing away on either side of the origin
        if order < 2:
            points = [origin - spacing/2, origin + spacing/2]
        else:
            points = [origin - spacing/2 - i * spacing
                      for i in reversed(range(half))]
            points.extend(origin + spacing/2 + i * spacing for i in range(half))
    elif order < 2:
        points = [origin, origin - spacing]
    else:
        points = [origin + i * spacing for i in range(-order//2, half + 1)]

    return origin, tuple(points)