-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
49 lines (40 loc) · 1.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import traceback
from omegaconf import OmegaConf
from torch import manual_seed
from pytorch_lightning import Trainer
from utils import (
get_curr_time_w_random_shift,
maybe_save_checkpoint,
init_log_directory,
get_logger,
get_callbacks,
)
from vggishishmodule import VGGishishModule
from datamodule import GreatesHitDataModule
def train(cfg: OmegaConf):
    """Build the VGGishish model, data module and Lightning trainer from *cfg*, then fit.

    Args:
        cfg: Resolved experiment config (in practice the ``DictConfig`` produced
            by ``OmegaConf.merge`` in the ``__main__`` block). Keys read:

            * ``start_time`` and ``log_dir``: passed to ``init_log_directory``.
            * ``model``: forwarded as kwargs to ``VGGishishModule``.
            * ``dataloader``: forwarded as kwargs to ``GreatesHitDataModule``.
            * ``trainer``: forwarded as kwargs to ``Trainer``.
            * ``exp_name`` (optional, default ``""``): printed and used as the
              logger name.
            * ``seed`` (optional, default ``666``): passed to ``torch.manual_seed``.

    Raises:
        BaseException: Any exception raised by ``trainer.fit`` other than
            ``KeyboardInterrupt`` is re-raised after ``maybe_save_checkpoint`` is
            attempted, so failures reach the caller and the process exit status.
    """
    # exp_name is read from the top level of the config: every key under
    # cfg.trainer is splatted into Trainer(...), so it cannot live there.
    exp_name = cfg.get("exp_name", "")
    print(cfg.start_time, exp_name)
    manual_seed(cfg.get("seed", 666))
    log_dir, ckpt_dir = init_log_directory(cfg.start_time, cfg.log_dir)

    model = VGGishishModule(**cfg.model)
    datamodule = GreatesHitDataModule(**cfg.dataloader)
    trainer = Trainer(
        callbacks=get_callbacks(ckpt_dir),
        logger=get_logger(log_dir, name=exp_name),
        **cfg.trainer,
    )

    try:
        trainer.fit(model=model, datamodule=datamodule)
    except KeyboardInterrupt:
        # A manual interrupt is treated as a deliberate early stop: report it,
        # try to keep the progress made so far, and return normally.
        traceback.print_exc()
        maybe_save_checkpoint(trainer)
    except BaseException:
        # Any other failure: still attempt to persist progress, then propagate
        # so the error is not silently turned into a successful (exit 0) run.
        # No print_exc() here; the re-raised exception is reported by the caller.
        maybe_save_checkpoint(trainer)
        raise
if __name__ == "__main__":
    # Dotlist arguments from the command line (e.g. `config=<path>` or
    # `some.key=value`) are layered on top of the YAML config they select.
    cli_overrides = OmegaConf.from_cli()
    config_path = cli_overrides.get("config", "./configs/vggishish.yaml")
    config = OmegaConf.merge(OmegaConf.load(config_path), cli_overrides)

    # Generate a run timestamp unless one was supplied explicitly.
    start_time_unset = "start_time" not in config or config.start_time is None
    if start_time_unset:
        config.start_time = get_curr_time_w_random_shift()

    # Resolve interpolations in place (e.g. "${model.size}" -> its value)
    # before the config is handed to train().
    OmegaConf.resolve(config)
    train(config)