-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
59 lines (45 loc) · 1.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import sys
from copy import deepcopy
import torch
import torch.nn as nn
from common import parse_args
from networks import create_network
from apprs import create_method
from datasets import create_dataset
from utils.utils import Logger, Criterion
if __name__ == '__main__':
    # Script entry point. Builds the logger, criterion, datasets, network and
    # method from the command-line arguments, then either trains (when
    # args.test_id is None) or re-evaluates every task from 0 up to and
    # including args.test_id.
    import platform  # only needed for the host name in the run header

    device = "cuda" if torch.cuda.is_available() else "cpu"
    args = parse_args()
    args.logger = Logger(args, args.folder)
    args.logger.now()
    args.device = device

    # Log host, working directory and the exact command line for this run.
    # platform.node() is the portable equivalent of os.uname()[1]; os.uname
    # does not exist on Windows.
    args.logger.print('\n\n',
                      platform.node() + ':' + os.getcwd(),
                      'python', ' '.join(sys.argv),
                      '\n\n')
    args.logger.print('\n', args, '\n')

    # Shared objects are attached to ``args`` because the factories below only
    # receive ``args``. args.net is set before create_method is called --
    # presumably create_method reads it; confirm before reordering.
    args.criterion = Criterion(args)
    train_data, test_data = create_dataset(args)
    args.net = create_network(args)
    model = create_method(args)

    if args.test_id is None:
        # Training mode.
        from row_pipeline import RowPipeline as Pipeline
        pipeline = Pipeline(args, train_data, test_data, model)
        args.logger.print("\nTraining starts\n")
        pipeline.train_all()
    else:
        # Evaluation mode: replay tasks 0..test_id (inclusive). The load_*
        # calls presumably restore per-task state saved by an earlier training
        # run -- verify against BasePipeline.
        args.logger.print("Testing")
        from base_pipeline import BasePipeline as Pipeline
        pipeline = Pipeline(args, train_data, test_data, model)
        for task_id in range(args.test_id + 1):
            pipeline.load_task_MD_stats(task_id)
            pipeline.preprocess_task(task_id)
            pipeline.load_train_step(task_id)
            pipeline.load_model_step(task_id)
            pipeline.test_auc(task_id, epoch=args.n_epochs)
            pipeline.print_results(task_id)