train_pu.py
'''
* User: Hojun Lim
* Date: 2020-06-05
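*
* Trains one fully-connected network per UCI regression dataset to estimate
* parametric uncertainty. Checkpoints, TensorBoard logs and a loss-trend figure
* are written to ./output_pu/<dataset>/parametric_uncertainty/.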
'''
import os

import torch
import torch.utils.data
from torch import optim

from utils.train_utils import Trainer, load_checkpoint
from utils.eval_utils import model_fn_for_pu, model_fn_eval, eval, eval_with_training_dataset
from utils.dataset import *
from utils.log_utils import create_tb_logger
from utils.utils import *
if __name__ == "__main__":
    # Basic configuration; some entries (e.g. "ckpt", "num_networks") are only used later.
    # torch.autograd.set_detect_anomaly(True)
    # TODO: put all cfgs and hyperparameters into a config file, e.g. a YAML file.
    cfg = {}
    cfg["ckpt"] = None
    cfg["num_epochs"] = 150
    cfg["ckpt_save_interval"] = 10
    cfg["batch_size"] = 100
    cfg["grad_norm_clip"] = None
    cfg["num_networks"] = 10
    learning_rate = 0.001  # 0.001 makes every dataset trainable.
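
    # A minimal sketch of the YAML config loading that the TODO above refers to.
    # "cfg_file" and its contents are assumptions for illustration (requires PyYAML);
    # it is disabled by default so the hard-coded cfg above stays authoritative.
    cfg_file = None  # e.g. './cfg/train_pu.yaml'
    if cfg_file is not None:
        import yaml
        with open(cfg_file, 'r') as f:
            cfg.update(yaml.safe_load(f))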

    # TODO: simplify and automate this setup.
    # Create a directory for storing results for each UCI dataset.
    output_dirs = {}
    output_dirs["boston"] = []
    output_dirs["concrete"] = []
    output_dirs["energy"] = []
    output_dirs["kin8nm"] = []
    output_dirs["naval"] = []
    output_dirs["power_plant"] = []
    output_dirs["protein"] = []
    output_dirs["wine"] = []
    output_dirs["yacht"] = []
    # output_dirs["year"] = []
    for key, output_dir in output_dirs.items():
        output_dirs[key] = os.path.join('./output_pu', key, 'parametric_uncertainty')
    ckpt_dirs = {}
    for key, output_dir in output_dirs.items():
        ckpt_dirs[key] = os.path.join(output_dir, 'ckpts')
        os.makedirs(ckpt_dirs[key], exist_ok=True)

    data_dirs = {}
    for key, val in output_dirs.items():
        data_dirs[key] = os.path.join("./data", key)

    data_files = {}
    for key, _ in data_dirs.items():
        data_files[key] = ["{}_train.csv".format(key), "{}_eval.csv".format(key), "{}_test.csv".format(key)]

    train_datasets = {}
    train_loaders = {}
    eval_datasets = {}
    eval_loaders = {}
print("Prepare training data")
for key, fname in data_files.items():
train_datasets[key] = UCIDataset(os.path.join(data_dirs[key], fname[0]))
train_loaders[key] = torch.utils.data.DataLoader(train_datasets[key],
batch_size=cfg["batch_size"],
num_workers=0,
collate_fn=train_datasets[key].collate_batch)
eval_datasets[key] = UCIDataset(os.path.join(data_dirs[key], fname[0]), testing=True)
eval_loaders[key] = torch.utils.data.DataLoader(eval_datasets[key],
batch_size=cfg["batch_size"],
num_workers=0,
collate_fn=eval_datasets[key].collate_batch)

    # Prepare model
    print("Prepare model")
    from model.pu_fc import pu_fc, pu_fc2
    models = {}
    for key, dataset in train_datasets.items():
        if key in ["protein", "year"]:
            models[key] = pu_fc2(dataset.input_dim)
        else:
            models[key] = pu_fc(dataset.input_dim)
        models[key].cuda()
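    # Note: .cuda() above assumes a CUDA-capable GPU is available. A device-agnostic
    # variant (a sketch, not this repo's convention) would move each model with e.g.
    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #   models[key].to(device)
    # but the Trainer / model_fn code may still expect CUDA tensors.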

    # Prepare training
    print("Prepare training")
    optimizers = {}
    for key, model in models.items():
        optimizers[key] = optim.Adam(model.parameters(), lr=learning_rate)
    starting_iteration, starting_epoch = 0, 0
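
    # Optional resume sketch (inert while cfg["ckpt"] is None). It bypasses the imported
    # load_checkpoint helper, whose exact signature is not shown here, and assumes each
    # dataset's checkpoint is a plain state_dict saved under ckpt_dirs[key] with a
    # filename given by cfg["ckpt"], e.g. "checkpoint_epoch_50.pth" (an assumption).
    if cfg["ckpt"] is not None:
        for key, model in models.items():
            state = torch.load(os.path.join(ckpt_dirs[key], cfg["ckpt"]))
            model.load_state_dict(state)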

    # Logging
    tb_loggers = {}
    for key, val in output_dirs.items():
        tb_loggers[key] = create_tb_logger(val)

    # Training
    print("Start training")
    trainers = {}
    for key, model in models.items():
        print("*******************************Training {}*******************************\n".format(key))
        trainers[key] = Trainer(model=model,
                                model_fn=model_fn_for_pu,
                                model_fn_eval=model_fn_eval,
                                optimizer=optimizers[key],
                                ckpt_dir=ckpt_dirs[key],
                                output_dir=output_dirs[key],
                                title='pu_train_{}'.format(key),
                                grad_norm_clip=cfg["grad_norm_clip"],
                                tb_logger=tb_loggers[key])
        trainers[key].train(num_epochs=cfg["num_epochs"],
                            train_loader=train_loaders[key],
                            eval_loader=eval_loaders[key],
                            # eval_loader=None,
                            ckpt_save_interval=cfg["ckpt_save_interval"],
                            starting_iteration=starting_iteration,
                            starting_epoch=starting_epoch)
        draw_loss_trend_figure(key, len(trainers[key].train_loss), trainers[key].train_loss, output_dir=output_dirs[key])
        print("*******************************Finished training {}*******************************\n".format(key))

    # Finalizing
    print("Training finished\n")