forked from HypoX64/candock
-
Notifications
You must be signed in to change notification settings - Fork 7
/
options.py
67 lines (56 loc) · 4.04 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import os
import numpy as np
import torch
#python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edf --signal_name 'EEG Fpz-Cz' --sample_num 8 --model_name multi_scale_resnet_1d --batchsize 32 --network_save_freq 100 --epochs 40 --lr 0.0005
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 200 --model_name resnet18 --batchsize 32 --epochs 10 --fold_num 5 --pretrained
#python3 train_new.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name LSTM --batchsize 32 --network_save_freq 100 --epochs 10
#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name CinC_Challenge_2018 --signal_name C4-M1 --sample_num 10 --model_name resnet18 --batchsize 32
#filedir = '/media/hypo/Hypo/physionet_org_train'
# filedir ='E:\physionet_org_train'
#python3 train.py --dataset_name sleep-edf --model_name resnet50 --batchsize 4 --epochs 50 --pretrained
#'/media/hypo/Hypo/physionet_org_train'
class Options:
    """Command-line options for the training scripts.

    Typical use::

        opt = Options().getparse()

    Arguments are registered lazily on the first ``getparse`` call. After
    parsing, a few post-parse adjustments are applied (see ``getparse``).
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # Set to True by initialize() so arguments are registered only once.
        self.initialized = False

    def initialize(self):
        """Register every supported command-line argument on ``self.parser``."""
        # Device / backend switches.
        self.parser.add_argument('--no_cuda', action='store_true', help='if input, do not use gpu')
        self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
        self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
        # Optimisation / cross-validation settings.
        self.parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
        self.parser.add_argument('--Cross_Validation', type=str, default='k_fold', help='k-fold')
        self.parser.add_argument('--fold_num', type=int, default=5, help='k-fold')
        self.parser.add_argument('--batchsize', type=int, default=16, help='batchsize')
        # Dataset selection.
        self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
                                 help='your dataset path')
        self.parser.add_argument('--dataset_name', type=str, default='sleep-edf',
                                 help='Choose dataset sleep-edf|sleep-edfx|CinC_Challenge_2018')
        self.parser.add_argument('--select_sleep_time', action='store_true',
                                 help='if input, for sleep-cassette only use sleep time to train')
        self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',
                                 help='Choose the EEG channel C4-M1|EEG Fpz-Cz')
        self.parser.add_argument('--sample_num', type=int, default=20, help='the amount you want to load')
        # Model / training-loop settings.
        self.parser.add_argument('--model_name', type=str, default='lstm', help='Choose model')
        self.parser.add_argument('--epochs', type=int, default=50, help='end epoch')
        self.parser.add_argument('--weight_mod', type=str, default='avg_best',
                                 help='Choose weight mode: avg_best|normal')
        self.parser.add_argument('--network_save_freq', type=int, default=5, help='the freq to save network')
        self.initialized = True

    def getparse(self, args=None):
        """Parse the command line and apply post-parse adjustments.

        Args:
            args: Optional list of argument strings to parse instead of the
                real command line. ``None`` (the default) parses
                ``sys.argv[1:]``, exactly as before.

        Returns:
            The parsed ``argparse.Namespace``, which is also stored on
            ``self.opt``.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args(args)

        # For sleep-edf, sample_num is forced to 8. This overrides any
        # --sample_num given on the command line.
        if self.opt.dataset_name == 'sleep-edf':
            self.opt.sample_num = 8
        # Disabling CUDA also disables cuDNN. This is an assignment; the
        # previous `==` compared the values and discarded the result.
        if self.opt.no_cuda:
            self.opt.no_cudnn = True

        # NOTE: class-weight computation from --weight_mod is disabled.
        # opt.weight is NOT set here; it is kept below for reference only.
        # if self.opt.weight_mod == 'normal':
        #     weight = np.array([1,1,1,1,1])
        # elif self.opt.weight_mod == 'avg_best':
        #     if self.opt.dataset_name == 'CinC_Challenge_2018':
        #         weight = np.log(1/np.array([0.15,0.3,0.08,0.13,0.18]))
        #     elif self.opt.dataset_name == 'sleep-edfx':
        #         weight = np.log(1/np.array([0.04,0.20,0.04,0.08,0.63]))
        #     elif self.opt.dataset_name == 'sleep-edf':
        #         weight = np.log(1/np.array([0.08,0.23,0.01,0.10,0.53]))
        #     if self.opt.select_sleep_time:
        #         weight = np.log(1/np.array([0.16,0.44,0.05,0.19,0.53]))
        # self.opt.weight = weight
        return self.opt