-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
63 lines (54 loc) · 1.74 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import torch
import numpy
import random
from copy import deepcopy
from src.systems.recognition import CTC_System
from src.systems.setup import process_config
from src.utils import load_json
from quantization.utils import quantize_ctc_system
import pytorch_lightning as pl
import wandb
def run(config_path, gpu_device=-1):
    """Load an experiment config, build its system, and train it with Lightning.

    Args:
        config_path: Path handed to ``process_config`` to build the config object.
        gpu_device: GPU index to train on. Any value ``>= 0`` overrides
            ``config.gpu_device``; a negative value (the default ``-1``) keeps
            whatever the config already specifies.
    """
    config = process_config(config_path)
    if gpu_device >= 0:
        # An explicit non-negative index from the caller wins over the config.
        config.gpu_device = gpu_device

    seed_everything(config.seed)

    # ``config.system`` is the *name* of a class that must be visible in this
    # module's globals (e.g. the imported ``CTC_System``); an unknown name
    # raises KeyError here.
    system = globals()[config.system](config)

    # Quantization is applied in place only when a positive noise rate is set.
    if config.quant_params.noise_rate > 0:
        quantize_ctc_system(system, config)

    # Checkpoints go to <exp_dir>/checkpoints. save_top_k=-1 keeps every
    # checkpoint; period=1 is the (older) Lightning argument for checkpointing
    # every epoch -- confirm against the pinned Lightning version.
    checkpoint_dir = os.path.join(config.exp_dir, 'checkpoints')
    checkpointer = pl.callbacks.ModelCheckpoint(
        checkpoint_dir,
        save_top_k=-1,
        period=1,
    )

    # Mirror TensorBoard logs to Weights & Biases under this experiment's name.
    wandb.init(
        project='speech',
        entity='lyronctk',
        name=config.exp_name,
        config=config,
        sync_tensorboard=True,
    )

    trainer_kwargs = dict(
        default_root_dir=config.exp_dir,
        gpus=[config.gpu_device] if config.cuda else None,
        # Equal min/max pins training to exactly ``num_epochs`` epochs.
        max_epochs=config.num_epochs,
        min_epochs=config.num_epochs,
        checkpoint_callback=checkpointer,
        resume_from_checkpoint=config.continue_from_checkpoint,
    )
    pl.Trainer(**trainer_kwargs).fit(system)
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random``, torch (CPU), torch (all CUDA devices)
    and NumPy's global RNG, then exports ``PYTHONHASHSEED``.

    Note: changing ``PYTHONHASHSEED`` at runtime does not alter hashing in the
    current interpreter (it is fixed at startup); it only affects Python
    subprocesses launched afterwards.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # safe no-op-style call when CUDA is absent
        numpy.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('config', type=str, default='path to config file')
parser.add_argument('--gpu-device', type=int, default=-1)
args = parser.parse_args()
run(args.config, gpu_device=args.gpu_device)