-
Notifications
You must be signed in to change notification settings - Fork 6
/
web_run.py
127 lines (94 loc) · 3.78 KB
/
web_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import argparse
import pickle
import os
import torch
import yaml
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from corpus.BratWriter import Writer, BratFile
from corpus.InferenceDataset import InferenceDataset
from corpus.WLPDataset import WLPDataset
def multi_batchify(samples):
    """Collate function for the DataLoader.

    Sorts the batch by sentence length, longest first (required for packed
    RNN sequences), then splits the samples into three parallel tuples.

    Returns:
        (SENT, X, C): tuples of sentences, word-id sequences, char-id sequences.
    """
    ordered = sorted(samples, key=lambda item: len(item.SENT), reverse=True)
    sent_batch = tuple(item.SENT for item in ordered)
    x_batch = tuple(item.X for item in ordered)
    c_batch = tuple(item.C for item in ordered)
    return sent_batch, x_batch, c_batch
def argmax(var):
    """Return the per-row argmax of a 2-D score tensor as a Python list.

    Args:
        var: 2-D tensor of shape (num_positions, num_tags).

    Returns:
        List of int indices, one per row (ties resolve to the first maximum).
    """
    # `.detach()` replaces the deprecated `.data` attribute; the old
    # `assert isinstance(var, Variable)` is dropped — Variable was merged
    # into Tensor in torch 0.4, and `assert` is stripped under `-O` anyway.
    _, preds = torch.max(var.detach(), 1)
    return preds.cpu().numpy().tolist()
def write_brat(sents, pred_list, save_path):
    """Write every sentence's predictions to a Brat annotation file at `save_path`."""
    print("Writing Brat File ...")
    out_file = BratFile(save_path, "brat")
    # NOTE(review): the full pred_list is passed for each sentence — presumably
    # BratFile.writer indexes into it per sentence; verify against BratWriter.
    for sentence in sents:
        out_file.writer(sentence, pred_list, "brat", ignore_label=[])
def to_variables(X, C, lm_vocab_size):
    """Prepare model inputs from raw id batches.

    Word ids in X that fall outside the language-model vocabulary are clamped
    to the last LM vocab id (lm_vocab_size - 1) to build the LM target batch.

    Returns:
        (x_var, c_var, lm_x): the word batch, char batch, and clamped LM batch.
    """
    oov_id = lm_vocab_size - 1
    lm_x = [[token if token < lm_vocab_size else oov_id for token in row] for row in X]
    return X, C, lm_x
def roll(pred, seq_lengths):
    """Split the flat list `pred` into consecutive chunks of the given lengths.

    Inverse of concatenation: chunk i has length seq_lengths[i].
    """
    chunks = []
    cursor = 0
    for length in seq_lengths:
        end = cursor + length
        chunks.append(pred[cursor:end])
        cursor = end
    return chunks
def test(name, data, tag_idx, model, lm_vocab_size, char_level):
    """Run the model over `data` and collect per-sentence tag predictions.

    Args:
        name: label shown on the tqdm progress bar.
        data: iterable of (SENT, X, C) batches (see multi_batchify).
        tag_idx: tag-to-index mapping (unused here; kept for interface parity).
        model: trained model exposing init_state(batch_size) and __call__.
        lm_vocab_size: language-model vocabulary size, used to clamp OOV ids.
        char_level: "Attention" selects the model variant that also returns
            embeddings; only seq_out/seq_lengths are consumed either way.

    Returns:
        (sents, pred_list): sentences and their predicted tag sequences, with
        the first and last positions (sentence start/end sentinels) stripped.
    """
    pred_list = []
    sents = []
    for SENT, X, C in tqdm(data, desc=name, total=len(data)):
        # NOTE(review): removed `np.set_printoptions(threshold=np.nan)` — a
        # per-batch debug leftover that raises on modern NumPy, where the
        # threshold must be an integer.
        model.init_state(len(X))
        x_var, c_var, lm_x = to_variables(X=X, C=C, lm_vocab_size=lm_vocab_size)
        if char_level == "Attention":
            lm_f_out, lm_b_out, seq_out, seq_lengths, emb, char_emb = model(x_var, c_var)
        else:
            lm_f_out, lm_b_out, seq_out, seq_lengths = model(x_var, c_var)
        flat_pred = argmax(seq_out)
        preds = roll(flat_pred, seq_lengths)
        # `seq_pred` (renamed from `pred`, which shadowed the flat list above)
        # is one sequence's predictions; drop the start/end sentinel positions.
        for seq_pred, sent in zip(preds, SENT):
            pred_list.append(seq_pred[1:-1])
            sents.append(sent)
    return sents, pred_list
def inference(p_txt, cfg):
    """Tag raw protocol text with the saved model and write Brat output.

    Args:
        p_txt: raw protocol text to run inference on.
        cfg: configuration dict loaded from the YAML file (paths, batch size,
            sentinel tokens, vocab size, char-level mode).
    """
    model_save_path = cfg['MODEL_SAVE_PATH']
    print("Loading Dataset ...")
    # Fix: the corpus file handle was previously opened inline and never
    # closed. NOTE(review): pickle.load executes arbitrary code if the file
    # is untrusted — acceptable only because CORPUS_FILE is our own artifact.
    with open(cfg["CORPUS_FILE"], "rb") as corpus_f:
        corpus = pickle.load(corpus_f)
    dataset = InferenceDataset(p_txt=p_txt,
                               word_index=corpus.word_index,
                               char_index=corpus.char_index,
                               is_oov=corpus.is_oov,
                               sent_start=cfg['SENT_START'],
                               sent_end=cfg['SENT_END'],
                               word_start=cfg['WORD_START'],
                               word_end=cfg['WORD_END'],
                               unk=cfg['UNK'])
    data_loader = DataLoader(dataset, batch_size=cfg['BATCH_SIZE'], num_workers=8, collate_fn=multi_batchify)
    print("Loading Model ...")
    the_model = torch.load(model_save_path)
    print("Testing ...")
    sents, pred_list = test("test", data_loader, corpus.tag_idx, the_model, cfg['LM_VOCAB_SIZE'], cfg['CHAR_LEVEL'])
    brat_writer = Writer(cfg['CONF_DIR'], cfg['BRAT_SAVE_PATH'], "full_out", corpus.tag_idx)
    # (Removed debug `print(sents, pred_list)` — dumps entire corpus to stdout.)
    # Restore the original sentence order: multi_batchify sorts within batches.
    sents = dataset.undo_sort(sents)
    pred_list = dataset.undo_sort(pred_list)
    brat_writer.gen_one_file(sents, pred_list, cfg['BRAT_SAVE_PATH'], "brat")
def init_args():
    """Parse command-line arguments.

    Returns:
        Namespace with a single attribute `yaml`: path to the config file.
    """
    arg_parser = argparse.ArgumentParser(description='Sequence labeler.')
    arg_parser.add_argument('--yaml', type=str, help='config file path')
    return arg_parser.parse_args()
def parse_yaml(cfg_path):
    """Load the YAML config file at `cfg_path` and return its contents.

    Uses safe_load: bare `yaml.load(stream)` can instantiate arbitrary Python
    objects from the file and, since PyYAML 6.0, raises a TypeError because
    an explicit Loader argument is required.
    """
    with open(cfg_path, 'r') as stream:
        return yaml.safe_load(stream)
if __name__ == '__main__':
    # Entry point: parse CLI args, load the YAML config, read the sample
    # protocol text, and run inference over it.
    cli_args = init_args()
    config = parse_yaml(cli_args.yaml)
    with open(config['SAMPLE_PROTOCOL_FILE'], 'r') as protocol_file:
        protocol_text = protocol_file.read()
    inference(protocol_text, config)