train_dpr.py
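"""Train a dense passage retrieval (DPR) model: a passage encoder and a question
encoder are fine-tuned from the same pretrained checkpoint, with settings read
from a YAML file under ./config/ and metrics reported to Weights & Biases."""
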
import argparse
import datetime
import logging
import os
import sys

import pytz
import torch
from omegaconf import OmegaConf
from transformers import AutoTokenizer, TrainingArguments, set_seed
import wandb

from dataset.DPR_Dataset import DenseRetrievalTrainDataset, DenseRetrievalValidDataset
from model.Retrieval.BertEncoder import BertEncoder
from model.Retrieval.RobertaEncoder import RobertaEncoder
from trainer.DenseRetrievalTrainer import DenseRetrievalTrainer

logger = logging.getLogger(__name__)


def main(args):
    config = OmegaConf.load(f"./config/{args.config}.yaml")
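
    # The config is expected to provide at least the keys read in this script
    # (a sketch inferred from the accesses below, not the full schema):
    #   wandb:  name, team, project, tags
    #   model:  name_or_path                # used as the wandb group
    #   utils:  seed
    #   dense:
    #     model:     name_or_path
    #     path:      train, valid
    #     tokenizer: max_context_length, max_question_length, return_token_type_ids
    #     train:     output_dir, batch_size, num_train_epochs, hard_negative
    #     optimizer: learning_rate, weight_decay, gradient_accumulation_steps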

    # wandb setup
    now_time = datetime.datetime.now(pytz.timezone("Asia/Seoul")).strftime("%m-%d-%H-%M")
    run_id = f"{config.wandb.name}_{now_time}"
    wandb.init(
        entity=config.wandb.team,
        project=config.wandb.project,
        group=config.model.name_or_path,
        id=run_id,
        tags=config.wandb.tags,
    )

    # Merge optimizer settings into the train settings and fall back to a default output_dir
    config.dense.train.update(config.dense.optimizer)
    if config.dense.train.output_dir is None:
        config.dense.train.output_dir = os.path.join("saved_models/DPR", config.dense.model.name_or_path, run_id)

    # logging setup
    if not os.path.exists("./logs"):
        os.makedirs("./logs")
    with open("./logs/DPR_logs.log", "w+") as f:
        f.write("***** Log file Start *****\n")
    LOG_FORMAT = "%(asctime)s - %(message)s"
    logging.basicConfig(
        level=logging.INFO,
        format=LOG_FORMAT,
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    file_handler = logging.FileHandler("./logs/DPR_logs.log", mode="a", encoding="utf-8")
    file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.addHandler(file_handler)

    # Verbosity setting: used for the Transformers logger's info output (on main process only)
    logger.info("config: %s", config)

    # Fix the random seed before initializing the models.
    set_seed(config.utils.seed)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.dense.model.name_or_path)

    # Datasets
    train_dataset = DenseRetrievalTrainDataset(
        data_path=config.dense.path.train,
        max_context_length=config.dense.tokenizer.max_context_length,
        max_question_length=config.dense.tokenizer.max_question_length,
        tokenizer=tokenizer,
        hard_negative=config.dense.train.hard_negative,
    )
    valid_dataset = DenseRetrievalValidDataset(
        data_path=config.dense.path.valid,
        max_context_length=config.dense.tokenizer.max_context_length,
        tokenizer=tokenizer,
    )
    logger.info(f" train_dataset: {len(train_dataset)} | valid_dataset: {len(valid_dataset)}")

    # Models: use the RoBERTa-style encoder when the tokenizer returns token_type_ids,
    # otherwise the BERT-style encoder.
    logger.info(f" Encoder model: {config.dense.model.name_or_path}")
    if config.dense.tokenizer.return_token_type_ids:
        p_encoder = RobertaEncoder.from_pretrained(config.dense.model.name_or_path)
        q_encoder = RobertaEncoder.from_pretrained(config.dense.model.name_or_path)
    else:
        p_encoder = BertEncoder.from_pretrained(config.dense.model.name_or_path)
        q_encoder = BertEncoder.from_pretrained(config.dense.model.name_or_path)
    if torch.cuda.is_available():
        p_encoder.cuda()
        q_encoder.cuda()

    # Training
    training_args = TrainingArguments(
        output_dir=config.dense.train.output_dir,
        evaluation_strategy="epoch",
        learning_rate=config.dense.optimizer.learning_rate,
        per_device_train_batch_size=config.dense.train.batch_size,
        per_device_eval_batch_size=config.dense.train.batch_size,
        num_train_epochs=config.dense.train.num_train_epochs,
        weight_decay=config.dense.optimizer.weight_decay,
        gradient_accumulation_steps=config.dense.optimizer.gradient_accumulation_steps,
    )
    training_args.report_to = ["wandb"]

    trainer = DenseRetrievalTrainer(training_args, config, tokenizer, p_encoder, q_encoder, train_dataset, valid_dataset)
    trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, default="custom_config")
    args, _ = parser.parse_known_args()
    main(args)
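
# Example invocation (expects ./config/custom_config.yaml, the default config name):
#   python train_dpr.py --config custom_config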