-
Notifications
You must be signed in to change notification settings - Fork 0
/
split_data.py
73 lines (49 loc) · 1.74 KB
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import h5py
import tqdm
import numpy as np
# ----------------------------------------------------------------------
# Open the source HDF5 file read-only. The handle is kept at module scope:
# get_XY reads pulse groups from it and it is closed at the end of the script.
fname = 'tomo_data.h5'
print('Reading:', fname)
f = h5py.File(fname, mode='r')
# Iterating an h5py file yields its top-level member names; each is treated
# as one pulse. Sorting makes the index-based train/valid split reproducible.
pulses = np.array(sorted(f))
print('pulses:', len(pulses))
# ----------------------------------------------------------------------
# Deterministic hold-out: pulses whose index satisfies index % N == N-1
# (i.e. every N-th pulse) go to validation; all others go to training.
N = 10
r = np.arange(len(pulses))
is_valid = (r % N) == (N - 1)
i_valid = r[is_valid]
i_train = r[~is_valid]
train_pulses, valid_pulses = pulses[i_train], pulses[i_valid]
print('train_pulses:', len(train_pulses))
print('valid_pulses:', len(valid_pulses))
# ----------------------------------------------------------------------
def get_XY(pulses, h5file=None):
    """Load, clean and stack the 'bolo' and 'tomo' datasets for some pulses.

    For every pulse name, reads ``h5file[pulse]['bolo']`` and
    ``h5file[pulse]['tomo']`` in full. Negative values are clipped to 0 and
    the result is divided by 1e6. The per-pulse arrays are then concatenated
    along axis 0.

    Parameters
    ----------
    pulses : iterable of str
        Group names in the HDF5 file to load, in the order to stack them.
    h5file : h5py.File or h5py.Group, optional
        Container to read the pulse groups from. Defaults to the
        module-level handle ``f``, which preserves the original behaviour.

    Returns
    -------
    X, Y : numpy.ndarray
        Concatenation of the cleaned 'bolo' arrays (X) and the cleaned
        'tomo' arrays (Y).

    Raises
    ------
    ValueError
        If ``pulses`` is empty, since there is nothing to concatenate.
    """
    if h5file is None:
        h5file = f  # backward-compatible fallback to the script's open file
    X = []
    Y = []
    for pulse in tqdm.tqdm(pulses):
        g = h5file[pulse]
        bolo = np.clip(g['bolo'][:], 0., None) / 1e6
        tomo = np.clip(g['tomo'][:], 0., None) / 1e6
        X.append(bolo)
        Y.append(tomo)
    # np.concatenate([]) fails with an opaque message. This happens e.g. when
    # the file holds fewer than N pulses and the validation split is empty.
    if not X:
        raise ValueError('get_XY: no pulses given, nothing to concatenate '
                         '(is the train/valid split empty?)')
    X = np.concatenate(X, axis=0)
    Y = np.concatenate(Y, axis=0)
    return X, Y
# ----------------------------------------------------------------------
# Build the training arrays first, report their shape/dtype, then do the
# same for the validation arrays.
X_train, Y_train = get_XY(train_pulses)
for _label, _arr in (('X_train:', X_train), ('Y_train:', Y_train)):
    print(_label, _arr.shape, _arr.dtype)
# ----------------------------------------------------------------------
X_valid, Y_valid = get_XY(valid_pulses)
for _label, _arr in (('X_valid:', X_valid), ('Y_valid:', Y_valid)):
    print(_label, _arr.shape, _arr.dtype)
del _label, _arr
# ----------------------------------------------------------------------
def save(fname, array):
    """Write *array* to *fname* in NumPy's binary ``.npy`` format.

    The destination is printed before writing so progress is visible.
    Note that ``np.save`` appends ``.npy`` if *fname* lacks that extension.
    Returns None.
    """
    print('Writing:', fname)
    np.save(file=fname, arr=array)
# Persist each split to its own .npy file (same order as before), then
# release the HDF5 handle opened at the top of the script.
for _out_name, _out_array in (
    ('X_train.npy', X_train),
    ('Y_train.npy', Y_train),
    ('X_valid.npy', X_valid),
    ('Y_valid.npy', Y_valid),
):
    save(_out_name, _out_array)
del _out_name, _out_array
# ----------------------------------------------------------------------
f.close()