-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
119 lines (81 loc) · 3.17 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os, yaml, argparse, torch
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing
from module import (
load_dataloader,
load_model,
Trainer,
Tester,
Translator
)
def set_seed(SEED=42):
    """Seed every RNG in play (python, numpy, torch CPU and CUDA) and
    force cuDNN into deterministic mode so runs are reproducible."""
    import random
    import numpy as np
    import torch.backends.cudnn as cudnn

    # All seeders take the same integer — apply them uniformly.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(SEED)

    # Deterministic conv algorithms; disable the autotuner that would
    # otherwise pick kernels non-deterministically per run.
    cudnn.deterministic = True
    cudnn.benchmark = False
class Config(object):
    """Runtime configuration: flattens config.yaml groups into attributes,
    then layers CLI options (mode / sampling / search) and derived values
    (device, checkpoint path, generation kwargs) on top.
    """

    def __init__(self, args):
        # Flatten every group of config.yaml into flat attributes.
        # NOTE(review): yaml.load with FullLoader is acceptable only for
        # a trusted local config file; prefer yaml.safe_load if possible.
        with open('config.yaml', 'r') as f:
            params = yaml.load(f, Loader=yaml.FullLoader)
            for group in params.keys():
                for key, val in params[group].items():
                    setattr(self, key, val)

        # Normalize CLI values: argparse always delivers strings, so
        # '-sampling None' arrives as the literal string 'None' and the
        # original `is None` check could never be true.  Map the string
        # 'none' (any case) to a real None here.
        self.mode = args.mode.lower()  # main() dispatches on exact lowercase
        sampling = args.sampling
        if sampling is not None and sampling.lower() == 'none':
            sampling = None
        self.sampling = sampling.lower() if sampling is not None else None
        self.search_method = args.search.lower()

        # Inference runs on CPU; training/testing use CUDA when available.
        device_type = 'cuda' if torch.cuda.is_available() and \
                                self.mode != 'inference' else 'cpu'
        self.device_type = device_type
        self.device = torch.device(device_type)
        self.tokenizer_path = 'data/tokenizer.json'

        if self.sampling is None:
            self.ckpt = 'ckpt/base_model.pt'
            # NOTE(review): `generate_config` looks like a typo for
            # `generate_kwargs` (the attribute used in the else branch) —
            # kept as-is to preserve behavior; verify downstream usage.
            self.generate_config = None
        else:
            self.ckpt = f'ckpt/{self.sampling}_sampled_model.pt'
            # Narrow the per-strategy kwargs loaded from config.yaml down
            # to the chosen sampling strategy's entry.
            self.generate_kwargs = self.generate_kwargs[self.sampling]

    def print_attr(self):
        """Print every attribute/value pair for quick inspection."""
        for attribute, value in self.__dict__.items():
            print(f"* {attribute}: {value}")
def load_tokenizer(config):
    """Load the trained tokenizer and attach a BOS/EOS post-processor.

    Args:
        config: Config providing `tokenizer_path`, the `bos_token` /
            `eos_token` strings and their vocabulary ids.

    Returns:
        A `tokenizers.Tokenizer` whose post-processor wraps every encoded
        sequence as "<bos> $A <eos>".

    Raises:
        FileNotFoundError: if the tokenizer file does not exist.  (The
            original `assert` would be silently stripped under `python -O`.)
    """
    if not os.path.exists(config.tokenizer_path):
        raise FileNotFoundError(
            f"Tokenizer file not found: {config.tokenizer_path}"
        )
    tokenizer = Tokenizer.from_file(config.tokenizer_path)
    tokenizer.post_processor = TemplateProcessing(
        single=f"{config.bos_token} $A {config.eos_token}",
        special_tokens=[(config.bos_token, config.bos_id),
                        (config.eos_token, config.eos_id)]
    )
    return tokenizer
def main(args):
    """Entry point: seed RNGs, build config/model/tokenizer, then hand
    control to the Trainer, Tester, or Translator chosen by `args.mode`."""
    set_seed()
    config = Config(args)
    model = load_model(config)
    tokenizer = load_tokenizer(config)

    mode = config.mode
    if mode == 'train':
        # Training needs both the train and validation splits.
        trainer = Trainer(
            config, model,
            load_dataloader(config, tokenizer, 'train'),
            load_dataloader(config, tokenizer, 'valid'),
        )
        trainer.train()
    elif mode == 'test':
        tester = Tester(
            config, model, tokenizer,
            load_dataloader(config, tokenizer, 'test'),
        )
        tester.test()
    elif mode == 'inference':
        # Interactive translation — no dataloader required.
        Translator(config, model, tokenizer).translate()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # `type=str.lower` normalizes case once at parse time; `choices=`
    # replaces the old assert-based validation, which produced an ugly
    # AssertionError and was stripped entirely under `python -O`.
    # The old `default='None'` on a required argument was dead code.
    parser.add_argument('-mode', required=True, type=str.lower,
                        choices=['train', 'test', 'inference'])
    parser.add_argument('-sampling', required=True, type=str.lower,
                        choices=['none', 'greedy', 'beam', 'topk'])
    parser.add_argument('-search', required=False, type=str.lower,
                        default='greedy', choices=['greedy', 'beam'])

    main(parser.parse_args())