/
run_grbm_toy_ksd.py
57 lines (48 loc) · 1.89 KB
/
run_grbm_toy_ksd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from interface.runner import run_train_profile
import profiles.grbm.toy as profiles
from interface.utils import dict_utils, task_schedule
from multiprocessing import Process
from interface.utils.task_schedule import Task
import datetime
import os
def iter_run_specs(prefix, seeds, n_particles_choices):
    """Yield ``(method, overrides)`` pairs describing every training run.

    For each seed, runs are yielded in this order:

    1. one ``"isksd"`` run per entry of *n_particles_choices*;
    2. one ``"vagesksd"`` run per entry of *n_particles_choices*;
    3. a single ``"ksd"`` run.

    ``overrides`` is the dict meant to be merged onto the method's base
    profile. It always has ``"seed"`` and ``"workspace_root"`` keys. Only the
    particle-based methods also carry ``{"criterion": {"kwargs":
    {"n_particles": ...}}}``; the ``ksd`` run has no ``"criterion"`` key.

    :param prefix: root directory. Each run's ``workspace_root`` is
        ``prefix/<method><n_particles>_seed<seed>``, or ``prefix/ksd_seed<seed>``
        for the ksd run.
    :param seeds: iterable of integer seeds.
    :param n_particles_choices: iterable of integer particle counts swept for
        isksd and vagesksd.
    """
    # Materialize once: the choices are re-iterated for every seed and every
    # particle-based method, which would silently yield nothing after the
    # first pass if a one-shot iterator/generator were passed in.
    n_particles_choices = tuple(n_particles_choices)
    for seed in seeds:
        for method in ("isksd", "vagesksd"):
            for n_particles in n_particles_choices:
                yield method, {
                    "seed": seed,
                    "workspace_root": os.path.join(
                        prefix, "%s%d_seed%d" % (method, n_particles, seed)),
                    "criterion": {
                        "kwargs": {
                            "n_particles": n_particles,
                        },
                    },
                }
        yield "ksd", {
            "seed": seed,
            "workspace_root": os.path.join(prefix, "ksd_seed%d" % seed),
        }


def train(seeds=(1, 22, 77, 123, 666, 1234, 3333, 7777, 9090, 23333),
          n_particles_choices=(2, 5, 10)):
    """Launch the GRBM toy KSD sweep (isksd / vagesksd / ksd over seeds).

    For every run produced by ``iter_run_specs``, this builds a profile with
    ``dict_utils.merge_dict(<base profile>, overrides)``. The base profile is
    the matching ``train_isksd`` / ``train_vagesksd`` / ``train_ksd`` from
    ``profiles.grbm.toy``. Each profile is wrapped in a
    ``multiprocessing.Process`` targeting ``run_train_profile``. All processes
    are handed, in generation order, to ``task_schedule.wait_schedule``
    together with ``task_schedule.available_devices()``.

    Calling ``train()`` with no arguments reproduces the original hard-coded
    grid: 10 seeds x (3 isksd + 3 vagesksd + 1 ksd) = 70 runs.

    :param seeds: seeds to sweep (defaults to the original ten).
    :param n_particles_choices: particle counts for isksd/vagesksd
        (defaults to ``(2, 5, 10)``).
    """
    base_profiles = {
        "isksd": profiles.train_isksd,
        "vagesksd": profiles.train_vagesksd,
        "ksd": profiles.train_ksd,
    }
    # Every run of this invocation shares one timestamped root directory.
    # NOTE(review): resolution is one second, so two launches within the same
    # second would share (and possibly clobber) a prefix.
    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    prefix = "workspace/runner/grbm/toy_ksd/{}".format(now)

    tasks = []
    for method, overrides in iter_run_specs(prefix, seeds, n_particles_choices):
        profile = dict_utils.merge_dict(base_profiles[method], overrides)
        p = Process(target=run_train_profile, args=(profile,))
        # NOTE(review): the second Task argument is presumably the number of
        # devices each run needs; confirm against interface.utils.task_schedule.
        tasks.append(Task(p, 1))
    task_schedule.wait_schedule(tasks, devices=task_schedule.available_devices())
if __name__ == "__main__":
    # Start the sweep only when run as a script; importing this module
    # (e.g. by multiprocessing child processes) must not re-launch it.
    train()