-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
47 lines (38 loc) · 1.66 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate dataset for ascend 310"""
import os
import numpy as np
from sklearn import preprocessing
from src import dataloader, config
from src.argparser import arg_parser
args = arg_parser()
cfg = config.stgcn_chebconv_45min_cfg
# Force batch size 1 so each exported .bin file holds exactly one sample
# (the batch size is also embedded in the output file names below).
cfg.batch_size = 1

if __name__ == "__main__":
    # Export the dataset as raw .bin input files plus a label_ids.npy file,
    # for offline inference on Ascend 310.
    zscore = preprocessing.StandardScaler()
    # NOTE(review): data_url and data_path are concatenated without a
    # separator, so data_url is expected to end with "/" (or data_path to
    # start with one) — confirm against the argparser defaults.
    # The positional True and mode=2 are forwarded to create_dataset; their
    # meaning is defined in src/dataloader.py — TODO confirm there.
    dataset = dataloader.create_dataset(args.data_url + args.data_path, cfg.batch_size,
                                        cfg.n_his, cfg.n_pred, zscore, True, mode=2)
    img_path = os.path.join(args.result_path, "00_data")
    # makedirs also creates a missing result_path and tolerates re-runs,
    # where os.mkdir raised FileNotFoundError / FileExistsError.
    # Files from a previous run are overwritten by name.
    os.makedirs(img_path, exist_ok=True)
    label_list = []
    # dataset is an instance of Dataset object
    iterator = dataset.create_dict_iterator(output_numpy=True)
    for i, data in enumerate(iterator):
        file_name = f"STGCN_data_bs{cfg.batch_size}_{i}.bin"
        file_path = os.path.join(img_path, file_name)
        # Raw binary dump: ndarray.tofile writes no shape/dtype header, so the
        # consumer must know the input shape and dtype independently.
        data['inputs'].tofile(file_path)
        label_list.append(data['labels'])
    # os.path.join works whether or not result_path ends with a separator;
    # plain concatenation wrote e.g. "resultlabel_ids.npy" beside the dir.
    np.save(os.path.join(args.result_path, "label_ids.npy"), label_list)
    print("="*20, "export bin files finished", "="*20)