train.py
from importlib import import_module

import torch

import utils
from trainer.trainer import Trainer
# Modules handed to config.build() when instantiating the objects named in the config.
packages = {
    "dataloader_module": import_module(".dataloader", "dataloader"),
    "model_module": import_module("model"),
    # Relative import of model.losses (assumed location of the loss classes).
    "loss_module": import_module(".losses", "model"),
    "optimizer_module": import_module("torch.optim"),
    "lr_scheduler_module": import_module("torch.optim.lr_scheduler"),
}
def main(config):
    project_name = config["name"]

    # Seed all RNGs when a deterministic (reproducible) run is requested.
    if config["deterministic"]:
        utils.set_seed(config["seed"])

    logger = config.get_logger("train")

    train_dataloader = config.build("train_dataloader", packages["dataloader_module"])
    validation_dataloader = config.build(
        "validation_dataloader", packages["dataloader_module"]
    )
    logger.info("Dataloader build success.")

    # GAN setup: a generator and a discriminator, each built from the "model" config entry.
    generator = config.build("model", packages["model_module"], submodule="generator")
    discriminator = config.build(
        "model", packages["model_module"], submodule="discriminator"
    )
    logger.info(f"Model build success.\n{generator}\n{discriminator}")

    loss_function = config.build("loss", packages["loss_module"])
    logger.info("Loss build success.")

    # Separate optimizers for the generator and the discriminator.
    optimizer_g = config.build(
        "optimizer",
        packages["optimizer_module"],
        generator.parameters(),
        submodule="optim_g",
    )
    optimizer_d = config.build(
        "optimizer",
        packages["optimizer_module"],
        discriminator.parameters(),
        submodule="optim_d",
    )
    logger.info("Optimizer build success.")

    lr_scheduler = config.build(
        "lr_scheduler", packages["lr_scheduler_module"], optimizer_g
    )
if config["lr_scheduler"]["use_warmup"]:
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer_g,
lr_lambda=lambda current_step: 1
/ (10 ** (float(config["lr_scheduler"]["warmup_epochs"] - current_step))),
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer_g,
[warmup_scheduler, lr_scheduler],
[config["lr_scheduler"]["warmup_epochs"]],
)
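        # Exponential warmup on the generator LR: the multiplier grows tenfold each
        # epoch. Illustrative numbers, assuming warmup_epochs = 3 (not a value taken
        # from any config): epochs 0-2 use multipliers 1e-3, 1e-2, 1e-1, after which
        # SequentialLR hands control to the configured lr_scheduler.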
    if lr_scheduler is not None:
        logger.info("LR scheduler build success.")

    trainer = Trainer(
        project_name,
        train_dataloader,
        validation_dataloader,
        generator,
        discriminator,
        loss_function,
        optimizer_g,
        optimizer_d,
        lr_scheduler,
        config["trainer"],
        logger,
    )
    logger.info("Trainer build success.")

    logger.info("Start training. Good luck !!")
    trainer.train()


if __name__ == "__main__":
    config_file, args = utils.parse_argument()
    config = utils.ConfigParser(config_file, args)
    main(config)
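

# Illustrative config layout (an assumption inferred from the keys read in this
# script; the authoritative schema is defined by utils.ConfigParser and the
# project's own config files):
#
#   name: my_gan_run
#   deterministic: true
#   seed: 123
#   train_dataloader: ...        # class/args resolved inside dataloader.dataloader
#   validation_dataloader: ...
#   model:
#     generator: ...             # resolved inside the model package
#     discriminator: ...
#   loss: ...                    # resolved inside model.losses
#   optimizer:
#     optim_g: ...               # e.g. a torch.optim class for the generator
#     optim_d: ...               # e.g. a torch.optim class for the discriminator
#   lr_scheduler:
#     use_warmup: true
#     warmup_epochs: 3
#   trainer: ...                 # passed straight through to Trainer
#
# The script is launched as `python train.py ...`; the exact command-line flags
# are defined by utils.parse_argument().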