-
Notifications
You must be signed in to change notification settings - Fork 13
/
darray.py
82 lines (72 loc) · 2.45 KB
/
darray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
from mpi4py import MPI
from mpi4py_fft.distarray import DistArray, newDistArray
from mpi4py_fft.mpifft import PFFT
# Test DistArray. Start with alignment in axis 0, then transfer to 2 and
# finally to 1.
# Small random integers stored as floats keep the sums below exact, so the
# totals can be compared with ==.
N = (16, 14, 12)
z0 = DistArray(N, dtype=float, alignment=0)
z0[:] = np.random.randint(0, 10, z0.shape)
total_axis0 = MPI.COMM_WORLD.allreduce(np.sum(z0))

# Each redistribution must leave the total over all ranks unchanged.
z1 = z0.redistribute(2)
total_axis2 = MPI.COMM_WORLD.allreduce(np.sum(z1))
z2 = z1.redistribute(1)
total_axis1 = MPI.COMM_WORLD.allreduce(np.sum(z2))
assert total_axis0 == total_axis2 and total_axis2 == total_axis1
# A forward transform followed by a backward transform should reproduce the
# input. z2 is passed as the output of the backward call, so a copy taken
# beforehand (z2c) serves as the reference.
fft = PFFT(MPI.COMM_WORLD, darray=z2, axes=(0, 2, 1))
z2c = z2.copy()
z3 = newDistArray(fft, forward_output=True)
fft.forward(z2, z3)
fft.backward(z3, z2)

# Compare the norms of the local arrays; the difference doubles as the
# assertion message.
norm_after = np.linalg.norm(z2)
norm_before = np.linalg.norm(z2c)
norm_drift = norm_after - norm_before
assert abs(norm_drift) < 1e-12, norm_drift
# Repeat the round trip for arrays created with rank=1, transforming one
# component at a time with the existing plan.
v0 = newDistArray(fft, forward_output=False, rank=1)
v0[:] = np.random.random(v0.shape)
v0c = v0.copy()  # reference copy; v0 is overwritten by the backward pass
v1 = newDistArray(fft, forward_output=True, rank=1)
for i in range(3):
    v1[i] = fft.forward(v0[i], v1[i])
for i in range(3):
    v0[i] = fft.backward(v1[i], v0[i])
s0, s1 = np.linalg.norm(v0c), np.linalg.norm(v0)
# Include the discrepancy in the assertion message so a failing rank shows
# how far off the round trip was.
assert abs(s0-s1) < 1e-12, s0-s1
# Build a second plan from a single component of v0 and repeat the
# component-wise round trip with it.
nfft = PFFT(MPI.COMM_WORLD, darray=v0[0], axes=(0, 2, 1))
for i in range(3):
    v1[i] = nfft.forward(v0[i], v1[i])
for i in range(3):
    v0[i] = nfft.backward(v1[i], v0[i])
s0, s1 = np.linalg.norm(v0c), np.linalg.norm(v0)
# Include the discrepancy in the assertion message so a failing rank shows
# how far off the round trip was.
assert abs(s0-s1) < 1e-12, s0-s1
# Redistribute to axis 2 and then back into the original array. One line of
# the global array (index 0 on axes 0 and 2, everything along axis 1) is
# fetched with get() before and after and must be unchanged.
N = (6, 6, 6)
mpi_rank = MPI.COMM_WORLD.Get_rank()
z = DistArray(N, dtype=float, alignment=0)
z[:] = mpi_rank
line = (0, slice(None), 0)
g0 = z.get(line)
z2 = z.redistribute(2)
z = z2.redistribute(out=z)
g1 = z.get(line)
assert np.all(g0 == g1)

# The reduced values are only checked on rank 0.
sq_z = MPI.COMM_WORLD.reduce(np.linalg.norm(z)**2)
sq_z2 = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
if mpi_rank == 0:
    assert abs(sq_z - sq_z2) < 1e-12
# Same norm check for an array created with rank=2, moving the alignment
# from axis 2 to axis 1 and then to axis 0.
N = (3, 3, 6, 6, 6)
mpi_rank = MPI.COMM_WORLD.Get_rank()
z2 = DistArray(N, dtype=float, val=1, alignment=2, rank=2)
z2[:] = mpi_rank
z1 = z2.redistribute(1)
z0 = z1.redistribute(0)

# The reduced values are only checked on rank 0.
sq_start = MPI.COMM_WORLD.reduce(np.linalg.norm(z2)**2)
sq_end = MPI.COMM_WORLD.reduce(np.linalg.norm(z0)**2)
if mpi_rank == 0:
    assert abs(sq_start - sq_end) < 1e-12

# Finally exercise redistribution into already existing arrays via out=.
z1 = z0.redistribute(out=z1)
z0 = z1.redistribute(out=z0)
# Five-dimensional array: redistribute from alignment 2 to axis 4, then back
# into m0 via out=, and compare the reduced squared norms.
N = (6, 6, 6, 6, 6)
mpi_rank = MPI.COMM_WORLD.Get_rank()
m0 = DistArray(N, dtype=float, alignment=2)
m0[:] = mpi_rank
m1 = m0.redistribute(4)
m0 = m1.redistribute(out=m0)

# The reduced values are only checked on rank 0.
sq_m0 = MPI.COMM_WORLD.reduce(np.linalg.norm(m0)**2)
sq_m1 = MPI.COMM_WORLD.reduce(np.linalg.norm(m1)**2)
if mpi_rank == 0:
    assert abs(sq_m0 - sq_m1) < 1e-12