-
Notifications
You must be signed in to change notification settings - Fork 222
/
test_timestepping.py
110 lines (83 loc) · 3.03 KB
/
test_timestepping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import pytest
from devito import Grid, Eq, Operator, TimeFunction
@pytest.fixture
def grid(shape=(11, 11)):
    """Provide the Grid the other fixtures build on (default shape 11x11)."""
    domain = Grid(shape=shape)
    return domain
@pytest.fixture
def a(grid):
    """Time function `a` (time_order=1, save=6): unrolled, marched forward."""
    unrolled = TimeFunction(name='a', grid=grid, time_order=1, save=6)
    return unrolled
@pytest.fixture
def b(grid):
    """Time function `b` (time_order=1, save=6): unrolled, marched backward."""
    unrolled = TimeFunction(name='b', grid=grid, time_order=1, save=6)
    return unrolled
@pytest.fixture
def c(grid):
    """Time function `c` (time_order=1, save=None): buffered, not unrolled."""
    buffered = TimeFunction(name='c', grid=grid, time_order=1, save=None)
    return buffered
@pytest.fixture
def d(grid):
    """Time function `d` (time_order=2, save=6): unrolled, second order."""
    unrolled = TimeFunction(name='d', grid=grid, time_order=2, save=6)
    return unrolled
def test_forward(a):
    """March `a` forward from 1 and check every time slice grows by one."""
    a.data[0, :] = 1.

    step_eq = Eq(a.forward, a + 1.)
    op = Operator(step_eq)
    op()

    n_slices = a.shape[0]
    for t in range(n_slices):
        expected = 1. + t
        assert np.allclose(a.data[t, :], expected, rtol=1.e-12)
def test_backward(b):
    """March `b` backward from 7 in its last slice, losing one per step."""
    b.data[-1, :] = 7.

    step_eq = Eq(b.backward, b - 1.)
    op = Operator(step_eq)
    op()

    n_slices = b.shape[0]
    for t in range(n_slices):
        expected = 2. + t
        assert np.allclose(b.data[t, :], expected, rtol=1.e-12)
def test_forward_unroll(a, c, nt=5):
    """Check forward marching when a buffered function feeds an unrolled one.

    `c` is advanced by one each step and its new value is assigned to
    `a.forward`; the first `nt` slices of `a` must then read 1, 2, ..., nt.
    """
    a.data[0, :] = 1.
    c.data[0, :] = 1.

    advance_c = Eq(c.forward, c + 1.)
    copy_into_a = Eq(a.forward, c.forward)
    op = Operator([advance_c, copy_into_a])
    op(time=nt-1)

    for t in range(nt):
        expected = 1. + t
        assert np.allclose(a.data[t, :], expected, rtol=1.e-12)
def test_forward_backward(a, b, nt=5):
    """Run a forward operator on `a`, then a second operator that fills `b`.

    The second operator is only built after the first one has been applied.
    The first `nt` slices of `b` are expected to read 2, 3, ..., nt + 1.

    NOTE(review): the second equation writes `b`, not `b.backward`, so it
    does not obviously march backward -- confirm this is intended.
    """
    a.data[0, :] = 1.
    b.data[0, :] = 1.

    # First pass: build and apply the forward update on `a`.
    forward_eq = Eq(a.forward, a + 1.)
    forward_op = Operator(forward_eq)
    forward_op(time=nt-1)

    # Second pass: only now build and apply the operator that writes `b`.
    second_eq = Eq(b, a + 1.)
    second_op = Operator(second_eq)
    second_op(time=nt-1)

    for t in range(nt):
        expected = 2. + t
        assert np.allclose(b.data[t, :], expected, rtol=1.e-12)
def test_forward_backward_overlapping(a, b, nt=5):
    """
    Same scenario as a forward operator followed by a second one writing
    `b`, except that both operators are constructed before either of them
    is applied (overlapping operator definitions).
    """
    a.data[0, :] = 1.
    b.data[0, :] = 1.

    # Build both operators up front ...
    forward_eq = Eq(a.forward, a + 1.)
    forward_op = Operator(forward_eq)
    second_eq = Eq(b, a + 1.)
    second_op = Operator(second_eq)

    # ... and only then apply them, in the same order.
    forward_op(time=nt-1)
    second_op(time=nt-1)

    for t in range(nt):
        expected = 2. + t
        assert np.allclose(b.data[t, :], expected, rtol=1.e-12)
def test_loop_bounds_forward(d):
    """Check automatically detected time-loop bounds for a forward loop.

    Every slice of `d` starts at 1. After the run the first and last slices
    must still be 1, while interior slice `t` must equal 1 + t.
    """
    d.data[:] = 1.

    # Presumably `dt2` reads neighbouring time levels, which is why the
    # outermost slices are expected to stay untouched -- TODO confirm.
    update = Eq(d, 2. + d.dt2)
    op = Operator(update, opt=None)
    op(dt=1.)

    assert np.allclose(d.data[0, :], 1., rtol=1.e-12)
    assert np.allclose(d.data[-1, :], 1., rtol=1.e-12)

    last = d.data.shape[0] - 1
    for t in range(1, last):
        expected = 1. + t
        assert np.allclose(d.data[t, :], expected, rtol=1.e-12)
def test_loop_bounds_backward(d):
    """Check automatically detected time-loop bounds for a backward loop.

    Every slice of `d` starts at 5. After the run the first slice must be 0,
    the last slice must still be 5, and interior slice `t` must equal t.
    """
    d.data[:] = 5.

    update = Eq(d.backward, d - 1)
    Operator(update, opt=None)()

    assert np.allclose(d.data[0, :], 0., rtol=1.e-12)
    assert np.allclose(d.data[-1, :], 5., rtol=1.e-12)

    last = d.data.shape[0] - 1
    for t in range(1, last):
        assert np.allclose(d.data[t, :], t, rtol=1.e-12)