import os
import re
import torch
import numpy as np
import yaml
import datetime
from socket import gethostname
from pathlib import Path
from common.flags import flags
from common.utils import (
build_config,
create_grid,
save_experiment_log,
)
from base_runner import MoserRunner, FFJORDRunner
from sweep import create_sweep_jobs, VanillaJobRunner
import waic
import submitit


CONFIGURATION_FILENAME = "configuration.yaml"
CHECKPOINT_NAME = "checkpoint.pt"

RUNNER_CLASSES = {
    "moser": MoserRunner,
    "ffjord": FFJORDRunner,
}


def get_run_dir(config):
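    """Resolve the run directory for this configuration.

    If ``cmd.run_dir`` is given it is used as-is. Otherwise, for a fresh run,
    a timestamped path (optionally suffixed with ``cmd.identifier``) is built
    under ``_experiments/<dataset type>``.
    """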
run_dir = config["cmd"]["run_dir"]
if run_dir is not None:
return Path(run_dir)
identifier = config["cmd"].get("identifier")
run_dir = os.path.join(os.getcwd(), "_experiments", config["dataset"]["type"])
if not config["cmd"]["continue_saved"]:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if identifier is not None:
            dir_name = "%s_%s" % (timestamp, identifier)
else:
dir_name = timestamp
run_dir = os.path.join(run_dir, dir_name)
return Path(run_dir)


def get_last_checkpoint(checkpoint_dir):
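    """Return the checkpoint with the highest epoch number in ``checkpoint_dir``
    (files named ``checkpoint_<epoch>.pt``).
    """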
checkpoints = os.listdir(checkpoint_dir)
if len(checkpoints) == 1:
return os.path.join(checkpoint_dir, checkpoints[0])
    epochs = [
        float(re.match(r"checkpoint_(\d+\.?\d*)\.pt", checkpoint).groups()[0])
        for checkpoint in checkpoints
    ]
max_epoch_index = np.argmax(epochs)
return os.path.join(checkpoint_dir, checkpoints[max_epoch_index])


def run(config):
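    """Build the runner for the configured model, save the resolved
    configuration next to its checkpoints, optionally restore weights, and
    then train or validate according to ``cmd.mode``.
    """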
run_dir = get_run_dir(config)
model_name = config["model"].get("name", "moser")
runner_class = RUNNER_CLASSES[model_name]
runner = runner_class(config, run_dir)
run_dir = os.path.dirname(runner.config["cmd"]["checkpoint_dir"])
with open(os.path.join(run_dir, CONFIGURATION_FILENAME), "w") as f:
yaml.safe_dump(config, f)
# Load model
if config["cmd"]["continue_saved"]:
checkpoint_path = get_last_checkpoint(runner.config["cmd"]["checkpoint_dir"])
print("continuing from checkpoint %s" %checkpoint_path)
runner.load_pretrained(checkpoint_path)
if config["cmd"]["checkpoint"] is not None:
runner.load_pretrained(config["cmd"]["checkpoint"])
# Train model
runner.start()
if config["cmd"]["mode"] == "train":
try:
runner.train()
except KeyboardInterrupt:
runner.finalize()
# Test model
if config["cmd"]["mode"] == "validate":
        runner.train_loader.dataset.initial_plots(
            runner.config["cmd"]["results_dir"], model=runner.model
        )
runner.validate(split="test")
return runner


class Runner(submitit.helpers.Checkpointable):
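    """Picklable entry point for submitit jobs.

    Subclassing ``submitit.helpers.Checkpointable`` lets submitit requeue the
    job with the same arguments if it is preempted or times out.
    """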
def __init__(self):
self.config = None
self.chkpt_path = None

    def __call__(self, config):
run(config)


def main():
parser = flags.get_parser()
args = parser.parse_args()
    if not args.config_yml and not args.checkpoint and not args.continue_saved:
        raise ValueError(
            "Either config-yml, checkpoint, or continue-saved needs to be given"
        )
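    # Resuming from an explicit checkpoint: reuse the configuration saved next
    # to it and continue in a sibling "<run_dir>_continued" directory.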
if args.checkpoint:
checkpoint_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(os.path.dirname(checkpoint_dir), CONFIGURATION_FILENAME)
args.config_yml = config_path
        args.run_dir = os.path.join(
            os.path.dirname(checkpoint_dir) + "_continued",
            datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
        )
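    # Continuing a saved run in place: read the configuration that was written
    # into the existing run directory.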
if args.continue_saved:
config_path = os.path.join(args.run_dir, CONFIGURATION_FILENAME)
args.config_yml = config_path
config = build_config(args)
    if args.submit == 'vanilla':  # Run sweep jobs locally, one per GPU
run_dir = os.path.join(os.getcwd(), "_experiments", "sweep_runs")
jobs = create_sweep_jobs(args, run_dir)
runner = VanillaJobRunner()
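        # Give each sweep job its own GPU: start from args.local_rank if set,
        # otherwise pack the jobs onto the last len(jobs) CUDA devices.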
device = args.local_rank if args.local_rank is not None else torch.cuda.device_count() - len(jobs)
for i, job in enumerate(jobs):
job.params["local_rank"] = device + i
for job in jobs:
runner.run_job(job)
elif args.submit == 'submitit':
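        # Submit one job per configuration to the cluster through submitit.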
if args.sweep_yml: # Run grid search
configs = create_grid(config, args.sweep_yml)
else:
configs = [config]
print(f"Submitting {len(configs)} jobs")
        executor = submitit.AutoExecutor(
            folder=os.path.join(args.run_dir, "slurm", "%j")
        )
executor.update_parameters(
name=args.identifier,
mem_gb=args.slurm_mem,
timeout_min=args.slurm_timeout * 60,
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=(args.num_workers + 1),
tasks_per_node=(args.num_gpus if args.distributed else 1),
nodes=args.num_nodes,
)
jobs = executor.map_array(Runner(), configs)
print("Submitted jobs:", ", ".join([job.job_id for job in jobs]))
log_file = save_experiment_log(args, jobs, configs)
print(f"Experiment log saved to: {log_file}")
else: # Run locally
Runner()(config)


if __name__ == "__main__":
main()