# train.py — CLI entry point for training the translation Transformer.
# (Removed: GitHub UI scrape residue — navigation chrome and gutter line
# numbers 1-46 — which was not part of the original source and prevented
# the file from parsing as Python.)
from argparse import ArgumentParser
from utils.model.trainer import ModelTrainingArguments, Trainer
from utils.data import load_dataset, DataCollator
from translation.transformer import ModelConfig, Transformer, Tokenizer
def parse_cli_args(argv=None):
    """Build the training CLI and parse *argv* (defaults to sys.argv[1:]).

    Returns an argparse.Namespace with: model_config, data_config (paths to
    JSON config files), epochs, init_lr, train_data_dir,
    validation_data_dir, and previous_state_path (None unless resuming).
    """
    parser = ArgumentParser(description="Train the translation Transformer.")
    parser.add_argument("--model-config", default="model_config.json", type=str,
                        help="Path to the model configuration JSON file.")
    parser.add_argument("--data-config", default="data_config.json", type=str,
                        help="Path to the data configuration JSON file.")
    parser.add_argument("--epochs", default=50, type=int,
                        help="Number of training epochs.")
    parser.add_argument("--init-lr", default=1e-4, type=float,
                        help="Initial learning rate.")
    parser.add_argument("--train-data-dir", default="data/preload/PhoMT/train/", type=str,
                        help="Directory containing the training split.")
    parser.add_argument("--validation-data-dir", default="data/preload/PhoMT/dev/", type=str,
                        help="Directory containing the validation split.")
    parser.add_argument("--previous-state-path", default=None, type=str,
                        help="Trainer state to resume from; omit to start fresh.")
    return parser.parse_args(argv)


def main(args):
    """Wire up configs, datasets, collators, model and trainer, then train."""
    model_config = ModelConfig.from_json(args.model_config)
    # NOTE(review): the data config is also loaded via ModelConfig.from_json —
    # this looks like a copy-paste; confirm whether a dedicated DataConfig
    # type was intended (left as-is since none is imported here).
    data_config = ModelConfig.from_json(args.data_config)

    tokenizer = Tokenizer(model_config)
    train_dataset = load_dataset(args.train_data_dir, data_config)
    val_dataset = load_dataset(args.validation_data_dir, data_config)

    # Masking is applied only during training; validation collates unmasked.
    train_collator = DataCollator(tokenizer, data_config.mlm_probability)
    val_collator = DataCollator(tokenizer, None)

    model = Transformer(model_config)
    train_arguments = ModelTrainingArguments(
        epochs=args.epochs,
        init_lr=args.init_lr,
        device="cuda",
        max_warmup_steps=200000,
    )
    trainer = Trainer(model=model,
                      args=train_arguments,
                      train_dataset=train_dataset,
                      validation_dataset=val_dataset,
                      train_data_collator=train_collator,
                      validation_data_collator=val_collator)

    # Resume from a saved trainer state when one is provided.
    if args.previous_state_path is not None:
        trainer.resume(args.previous_state_path)
    trainer.train()


if __name__ == "__main__":
    # Parse CLI args only when executed as a script, so importing this
    # module (e.g. from tests) has no side effects and cannot SystemExit.
    main(parse_cli_args())