-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_data.py
45 lines (30 loc) · 974 Bytes
/
test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
def absolute_value(M, complex_=True):
    r"""Generate sample points and targets for the absolute-value test function.

    Parameters
    ----------
    M : int
        Requested number of samples. In the complex case the points form a
        square grid with ``int(np.sqrt(M))`` nodes per axis, so only
        ``int(np.sqrt(M))**2`` samples are produced (fewer than ``M`` unless
        ``M`` is a perfect square).
    complex_ : bool, default True
        If true, sample the complex square ``[-1, 1] + [-1, 1]i`` on a grid;
        otherwise sample ``M`` equispaced real points on ``[-1, 1]``.

    Returns
    -------
    X : ndarray, shape (N, 1)
        Sample points (complex when ``complex_`` is true, real otherwise).
    y : ndarray, shape (N,)
        ``|X|``, flattened.
    """
    if not complex_:
        points = np.linspace(-1, 1, M).reshape(-1, 1)
        return points, np.abs(points).flatten()

    # Both axes share the same 1-D grid; meshgrid (default 'xy' indexing)
    # makes the real part vary fastest after flattening.
    nodes_per_axis = int(np.sqrt(M))
    axis = np.linspace(-1, 1, nodes_per_axis)
    real_part, imag_part = np.meshgrid(axis, axis)
    points = real_part.reshape(-1, 1) + 1j * imag_part.reshape(-1, 1)
    return points, np.abs(points).flatten()
def array_absolute_value(M, output_dim=()):
    r"""Generate data for a (possibly array-valued) absolute-value test function.

    Samples ``M`` equispaced points ``x`` on ``[-1, 1]``. With an empty
    ``output_dim`` the target is ``|x|``. Otherwise output component ``j``
    (numbered in C order over ``output_dim``) is ``|x - c_j|``, where the
    shifts ``c_j`` are ``prod(output_dim)`` equispaced values on
    ``[-0.5, 0.5]``.

    Parameters
    ----------
    M : int
        Number of sample points.
    output_dim : int or sequence of int, default ()
        Shape of each output value. An int ``d`` is treated as ``(d,)``;
        lists and other sequences are accepted as well as tuples.

    Returns
    -------
    X : ndarray, shape (M, 1)
        Sample points.
    Y : ndarray
        Shape ``(M,)`` when ``output_dim`` is empty, else ``(M, *output_dim)``.
    """
    # Normalise to a tuple so ints and lists work like the tuple form
    # (an int previously failed in len(), a list in np.ndindex).
    output_dim = (output_dim,) if np.ndim(output_dim) == 0 else tuple(output_dim)

    X = np.linspace(-1, 1, M)
    if not output_dim:
        return X.reshape(-1, 1), np.abs(X)

    shifts = np.linspace(-0.5, 0.5, int(np.prod(output_dim)))
    # Column j of the broadcast difference belongs to the j-th multi-index of
    # output_dim in C order, so a C-order reshape restores that layout.
    Y = np.abs(X[:, np.newaxis] - shifts[np.newaxis, :]).reshape((M, *output_dim))
    return X.reshape(-1, 1), Y
def random_data(M, dim, complex_, seed, output_dim=()):
    r"""Generate reproducible standard-normal inputs and targets.

    Parameters
    ----------
    M : int
        Number of samples.
    dim : int
        Number of input features per sample.
    complex_ : bool
        If true, add an independent standard-normal imaginary part to both
        inputs and targets.
    seed : int
        Seed passed to ``np.random.seed``. Note that this reseeds NumPy's
        global random state as a side effect.
    output_dim : sequence of int, default ()
        Shape of each target value.

    Returns
    -------
    X : ndarray, shape (M, dim)
    y : ndarray, shape (M, *output_dim)
    """
    np.random.seed(seed)
    input_shape = (M, dim)
    target_shape = (M, *output_dim)

    # Draw order (real X, real y, then imag X, imag y) fixes the values
    # produced for a given seed, so it must not be rearranged.
    inputs = np.random.randn(*input_shape)
    targets = np.random.randn(*target_shape)
    if not complex_:
        return inputs, targets

    inputs = inputs + 1j * np.random.randn(*input_shape)
    targets = targets + 1j * np.random.randn(*target_shape)
    return inputs, targets