-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
120 lines (94 loc) · 2.94 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import os
import pytorch_lightning as pl
import sys
from torch.utils.data import DataLoader
from baseline.data import ArgsBase
from baseline.data import OLIDataModule
from baseline.data import OLIDataset
from baseline.model import ClassificationModule
from utils import find_best_ckpt
def test(test_file, args):
    """Evaluate a ``ClassificationModule`` on a single OLI data file.

    Args:
        test_file: path of the file wrapped by ``OLIDataset``.
        args: parsed CLI namespace. Fields read here: ``bert``,
            ``max_seq_len``, ``batch_size``, ``load_from``, ``best`` and
            ``device``; the whole namespace is also forwarded to
            ``ClassificationModule``.

    Model selection, driven by ``args.load_from``:
        * ``None``               -- a freshly constructed model, after seeding
                                    every RNG with 123 so runs are repeatable.
        * ends with ``'ckpt'``   -- that checkpoint is loaded directly.
        * anything else          -- ``find_best_ckpt`` chooses a checkpoint
                                    using the ``val_<args.best>`` metric.

    Returns:
        The value returned by ``Trainer.test``. The original version returned
        ``None``, so callers that ignore the result are unaffected.
    """
    test_dataset = OLIDataset(
        filepath=test_file,
        enc_model=args.bert,
        max_seq_len=args.max_seq_len,
    )
    # shuffle=False keeps the evaluation order deterministic.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=5,
        shuffle=False,
    )

    if args.load_from is None:
        # Random initialization: seed first so the initial weights are reproducible.
        SEED = 123
        pl.seed_everything(SEED)
        task_model = ClassificationModule(args=args)
    else:
        # Resolve which checkpoint to use, then load it through one code path.
        if args.load_from.endswith('ckpt'):
            ckpt_path = args.load_from
        else:
            ckpt_path = find_best_ckpt(args.load_from, metric=f'val_{args.best}')
        print(f'Loaded model from {ckpt_path}')
        # strict=False: do not require checkpoint keys to match the model exactly.
        task_model = ClassificationModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            args=args,
            strict=False,
        )

    task_model.eval()
    task_model.freeze()

    trainer = pl.Trainer(gpus=[args.device])
    # The dataloader is passed positionally on purpose. Lightning renamed this
    # keyword from `test_dataloaders` to `dataloaders` and later dropped the
    # old name. It is the second positional parameter in both signatures, so
    # this call works on old and new releases alike.
    return trainer.test(task_model, test_dataloader, verbose=False)
def main(args):
    """Run :func:`test` on the validation split, then on the test split.

    Both file names come from ``args`` (``val_file``, then ``test_file``).
    They are resolved relative to ``<args.data_dir>/<args.lang>``.
    """
    lang_dir = os.path.join(args.data_dir, args.lang)
    # Validation first, then test. Each attribute is read only when its turn
    # comes, matching the original statement order.
    for split_attr in ('val_file', 'test_file'):
        split_path = os.path.join(lang_dir, getattr(args, split_attr))
        test(split_path, args)
if __name__ == '__main__':
    # Append the parent of this script's directory to sys.path.
    # This uses the module's real __file__ rather than the string "__file__".
    # With the string, the result depended on the current working directory.
    # abspath() is applied first, so a relative __file__ also resolves correctly.
    # NOTE(review): this runs after the module-level imports, so it cannot help
    # resolve `baseline` / `utils`. Move it above them if that was the intent.
    sys.path.append(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--bert',
        type=str,
        default='mbert',
        help='pre-trained model to use: bert, kobert, mbert, xlm'
    )
    parser.add_argument(
        '--lang',
        type=str,
        default='da',
        help='task language: da, ko, en'
    )
    parser.add_argument(
        '--device',
        default=0,
        type=int,
        help='GPU index passed to pl.Trainer(gpus=[device])'
    )
    parser.add_argument(
        '--load_from',
        default=None,
        type=str,
        help=("checkpoint to evaluate: a path ending in 'ckpt' is loaded "
              "directly; any other path is given to find_best_ckpt to pick "
              "the best checkpoint by --best. If omitted, a randomly "
              "initialized model is evaluated.")
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=64,
        help='evaluation batch size'
    )
    parser.add_argument(
        '--best',
        type=str,
        default='f1',
        help=("metric used to select the best checkpoint (looks up "
              "val_<best>) when --load_from does not end in 'ckpt'")
    )
    # The project modules contribute their own CLI arguments
    # (e.g. data_dir, val_file, test_file and max_seq_len, which main/test read).
    parser = ArgsBase.add_model_specific_args(parser)
    parser = ClassificationModule.add_model_specific_args(parser)
    parser = OLIDataModule.add_model_specific_args(parser)
    args = parser.parse_args()
    main(args)