-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
69 lines (52 loc) · 2.21 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
'''
Main file for creating simulated data or loading real data
and running MetaRegGNN and sample selection methods.
Usage:
For data processing:
python demo.py --mode data
For inferences:
python demo.py --mode infer
For more information:
python demo.py -h
'''
import argparse
import pickle
import torch
import numpy as np
import proposed_method.data_utils as data_utils
import evaluators
from config import Config
# Command-line interface: a single --mode flag selects which pipeline stage runs.
parser = argparse.ArgumentParser()
# 'data' simulates and saves the dataset; 'infer' runs cross-validated inference.
# NOTE(review): --mode is optional here, so omitting it falls through to the
# final else branch of the dispatch below — consider required=True.
parser.add_argument('--mode', choices=['data', 'infer'],
                    help="Creates data and topological features OR make inferences on data")
opts = parser.parse_args()
if opts.mode == 'data':
    # Simulate connectomes and scores (plus topological features) into the
    # folder specified in config.py.
    data_utils.create_dataset()
    print(f"Data and topological features are created and saved at {Config.DATA_FOLDER} successfully.")
elif opts.mode == 'infer':
    # Cross-validation is used to train and generate inferences on the data
    # saved in the folder specified in config.py. Overall MAE and RMSE are
    # printed and predictions are pickled into the result folder.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def mae_evaluator(p, s):
        """Mean absolute error between predictions ``p`` and true scores ``s``."""
        return np.mean(np.abs(p - s))

    def rmse_evaluator(p, s):
        """Root mean squared error between predictions ``p`` and true scores ``s``."""
        return np.sqrt(np.mean((p - s) ** 2))

    preds, scores, _ = evaluators.evaluate_MetaRegGNN(shuffle=Config.SHUFFLE, random_state=Config.MODEL_SEED,
                                                      dropout=Config.MetaRegGNN.DROPOUT,
                                                      lr=Config.MetaRegGNN.LR, wd=Config.MetaRegGNN.WD, device=device,
                                                      num_epoch=Config.MetaRegGNN.NUM_EPOCH)
    print(f"MAE: {mae_evaluator(preds, scores):.3f}")
    print(f"RMSE: {rmse_evaluator(preds, scores):.3f}")
    # NOTE(review): RESULT_FOLDER is concatenated directly — presumably it ends
    # with a path separator; verify against config.py.
    with open(f"{Config.RESULT_FOLDER}preds.pkl", 'wb') as f:
        pickle.dump(preds, f)
    with open(f"{Config.RESULT_FOLDER}scores.pkl", 'wb') as f:
        pickle.dump(scores, f)
    print(f"Predictions are successfully saved at {Config.RESULT_FOLDER}.")
else:
    # Reachable only when --mode is omitted entirely: argparse's `choices`
    # already rejects any value other than 'data'/'infer' before we get here.
    # Emit a standard CLI usage error instead of a misleading bare Exception.
    parser.error("--mode is required: choose 'data' or 'infer'.")