-
Notifications
You must be signed in to change notification settings - Fork 414
/
matcher.py
executable file
·256 lines (243 loc) · 10.3 KB
/
matcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import os
import torch
import torch.nn as nn
from pretrained_models import MODEL_CLASSES
from module.dropout_wrapper import DropoutWrapper
from module.san import SANClassifier, MaskLmHeader
from module.san_model import SanModel
from module.pooler import Pooler
from torch.nn.modules.normalization import LayerNorm
from data_utils.task_def import EncoderModelType, TaskType
import tasks
from experiments.exp_def import TaskDef
def generate_decoder_opt(enable_san, max_opt):
    """Return the decoder option to use for one task.

    The requested ``max_opt`` is honoured only when SAN is enabled for the
    task and ``max_opt`` is below 2; in every other case ``0`` is returned.
    """
    use_requested = enable_san and max_opt < 2
    return max_opt if use_requested else 0
class SANBertNetwork(nn.Module):
    """Transformer encoder shared by all tasks, with one output head per task.

    ``self.bert`` is the encoder selected by ``opt["encoder_type"]``.  For
    task ``i`` of ``opt["task_def_list"]``, ``self.dropout_list[i]`` and
    ``self.scoring_list[i]`` hold that task's dropout layer and output layer
    (the output layer is ``None`` for sequence-generation tasks, which use the
    encoder's own output logits).
    """

    def __init__(self, opt, bert_config=None, initial_from_local=False):
        """Build the encoder and, unless dumping features, every task head.

        Args:
            opt: configuration dict.  Keys read here: ``encoder_type``,
                ``init_checkpoint`` and ``transformer_cache`` (pretrained
                loading), ``dump_feature``, ``update_bert_opt``,
                ``task_def_list``, ``answer_opt``, ``dropout_p``,
                ``vb_dropout`` and ``pooler_actf``.  With
                ``initial_from_local`` the whole dict is also handed to
                ``config_class.from_dict``.
            bert_config: unused; kept for backward compatibility.
            initial_from_local: if true, instantiate the encoder from a config
                built from ``opt`` instead of calling ``from_pretrained``
                (presumably weights are loaded from a checkpoint afterwards).

        Raises:
            ValueError: if ``opt["encoder_type"]`` is not an
                ``EncoderModelType`` value.
        """
        super().__init__()
        self.dropout_list = nn.ModuleList()
        if opt["encoder_type"] not in EncoderModelType._value2member_map_:
            raise ValueError("encoder_type is out of pre-defined types")
        self.encoder_type = opt["encoder_type"]
        self.preloaded_config = None

        literal_encoder_type = EncoderModelType(self.encoder_type).name.lower()
        config_class, model_class, _ = MODEL_CLASSES[literal_encoder_type]
        if not initial_from_local:
            self.bert = model_class.from_pretrained(
                opt["init_checkpoint"], cache_dir=opt["transformer_cache"]
            )
        else:
            self.preloaded_config = config_class.from_dict(opt)  # load config from opt
            # Ask the encoder to return the hidden states of every layer.
            self.preloaded_config.output_hidden_states = True
            self.bert = model_class(self.preloaded_config)
        hidden_size = self.bert.config.hidden_size

        if opt.get("dump_feature", False):
            # Feature-dump mode: only the encoder is built, no task heads.
            self.config = opt
            return

        if opt["update_bert_opt"] > 0:
            # Freeze the encoder parameters.
            for p in self.bert.parameters():
                p.requires_grad = False

        task_def_list = opt["task_def_list"]
        self.task_def_list = task_def_list
        self.decoder_opt = []
        self.task_types = []
        self.scoring_list = nn.ModuleList()
        for task_def in task_def_list:
            decoder_opt = generate_decoder_opt(task_def.enable_san, opt["answer_opt"])
            self.decoder_opt.append(decoder_opt)
            self.task_types.append(task_def.task_type)

            # A task-level dropout_p overrides the global opt["dropout_p"].
            task_dropout_p = (
                opt["dropout_p"] if task_def.dropout_p is None else task_def.dropout_p
            )
            dropout = DropoutWrapper(task_dropout_p, opt["vb_dropout"])
            self.dropout_list.append(dropout)
            self.scoring_list.append(
                self._build_output_layer(
                    task_def, decoder_opt, hidden_size, opt, dropout
                )
            )
        self.config = opt

    def _ensure_pooler(self, hidden_size, opt):
        """Create ``self.pooler`` once; it is shared by every task that needs it."""
        if not hasattr(self, "pooler"):
            self.pooler = Pooler(
                hidden_size, dropout_p=opt["dropout_p"], actf=opt["pooler_actf"]
            )

    def _build_output_layer(self, task_def, decoder_opt, hidden_size, opt, dropout):
        """Create the output layer for one task (``None`` for sequence generation)."""
        lab = task_def.n_class
        task_type = task_def.task_type
        task_obj = tasks.get_task_obj(task_def)
        if task_obj is not None:
            # TODO: move pooler ownership into task_obj.
            self._ensure_pooler(hidden_size, opt)
            return task_obj.train_build_task_layer(
                decoder_opt, hidden_size, lab, opt, prefix="answer", dropout=dropout
            )
        if task_type in (TaskType.Span, TaskType.SpanYN):
            # Two scores per token: span start and span end.
            assert decoder_opt != 1
            return nn.Linear(hidden_size, 2)
        if task_type == TaskType.SeqenceLabeling:
            return nn.Linear(hidden_size, lab)
        if task_type == TaskType.MaskLM:
            # Built from the encoder's input word-embedding matrix for every
            # encoder type (presumably weight tying inside MaskLmHeader).
            return MaskLmHeader(self.bert.embeddings.word_embeddings.weight)
        if task_type == TaskType.SeqenceGeneration:
            # The encoder's own output logits are used; no extra layer.
            return None
        if decoder_opt == 1:
            return SANClassifier(
                hidden_size,
                hidden_size,
                lab,
                opt,
                prefix="answer",
                dropout=dropout,
            )
        # Plain linear classifier over the pooled representation; forward()
        # needs the pooler for this head.
        self._ensure_pooler(hidden_size, opt)
        return nn.Linear(hidden_size, lab)

    def embed_encode(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return the encoder's embedding-layer output for ``input_ids``.

        ``token_type_ids`` defaults to all zeros.  ``attention_mask`` is
        accepted for interface symmetry but is not used.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        embedding_output = self.bert.embeddings(input_ids, token_type_ids)
        return embedding_output

    def encode(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        inputs_embeds=None,
        y_input_ids=None,
    ):
        """Run the encoder and return ``(last_hidden_state, all_hidden_states)``.

        For ``EncoderModelType.T5G`` the pair is instead
        ``(outputs.logits, outputs.encoder_last_hidden_state)``: the LM-head
        logits for ``y_input_ids`` (passed as ``decoder_input_ids``) and the
        encoder's final hidden state; ``inputs_embeds`` is ignored there.
        ``token_type_ids`` is only passed to non-T5 encoders.
        """
        if self.encoder_type == EncoderModelType.T5:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
            )
            last_hidden_state = outputs.last_hidden_state
            all_hidden_states = outputs.hidden_states  # num_layers + 1 (embeddings)
        elif self.encoder_type == EncoderModelType.T5G:
            outputs = self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=y_input_ids,
            )
            # return logits from LM header
            last_hidden_state = outputs.logits
            # A single tensor: the encoder's final-layer hidden state.
            all_hidden_states = outputs.encoder_last_hidden_state
        else:
            outputs = self.bert(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
            )
            last_hidden_state = outputs.last_hidden_state
            all_hidden_states = outputs.hidden_states  # num_layers + 1 (embeddings)
        return last_hidden_state, all_hidden_states

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        premise_mask=None,
        hyp_mask=None,
        task_id=0,
        y_input_ids=None,
        fwd_type=0,
        embed=None,
    ):
        """Run the model for task ``task_id``.

        ``fwd_type`` selects the mode:

        * ``3``: generation with ``self.bert.generate`` using
          ``max_answer_len``, ``num_beams``, ``repetition_penalty`` and
          ``length_penalty`` from ``self.config``; its result is returned as is.
        * ``1``: return only the embedding-layer output (``embed_encode``).
        * ``2``: encode the precomputed embeddings ``embed`` (``input_ids`` is
          not passed), then apply the task head.
        * anything else: encode ``input_ids`` and apply the task head.

        ``premise_mask``/``hyp_mask`` are consumed by SAN heads
        (``decoder_opt == 1``) and by task objects.

        Returns:
            A ``(start_scores, end_scores)`` pair for span tasks, otherwise a
            single logits tensor.  Sequence-labeling and sequence-generation
            logits are flattened to 2-D.
        """
        if fwd_type == 3:
            generated = self.bert.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=self.config["max_answer_len"],
                num_beams=self.config["num_beams"],
                repetition_penalty=self.config["repetition_penalty"],
                length_penalty=self.config["length_penalty"],
                early_stopping=True,
            )
            return generated
        elif fwd_type == 2:
            assert embed is not None
            last_hidden_state, _ = self.encode(
                None, token_type_ids, attention_mask, embed, y_input_ids
            )
        elif fwd_type == 1:
            return self.embed_encode(input_ids, token_type_ids, attention_mask)
        else:
            last_hidden_state, _ = self.encode(
                input_ids, token_type_ids, attention_mask, y_input_ids=y_input_ids
            )

        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        dropout = self.dropout_list[task_id]
        scorer = self.scoring_list[task_id]
        task_obj = tasks.get_task_obj(self.task_def_list[task_id])

        if task_obj is not None:
            pooled_output = self.pooler(last_hidden_state)
            return task_obj.train_forward(
                last_hidden_state,
                pooled_output,
                premise_mask,
                hyp_mask,
                decoder_opt,
                dropout,
                scorer,
            )
        if task_type in (TaskType.Span, TaskType.SpanYN):
            assert decoder_opt != 1
            logits = scorer(dropout(last_hidden_state))
            start_scores, end_scores = logits.split(1, dim=-1)
            return start_scores.squeeze(-1), end_scores.squeeze(-1)
        if task_type == TaskType.SeqenceLabeling:
            hidden = dropout(last_hidden_state)
            # One row per token position.
            hidden = hidden.contiguous().view(-1, hidden.size(2))
            return scorer(hidden)
        if task_type == TaskType.MaskLM:
            return scorer(dropout(last_hidden_state))
        if task_type == TaskType.SeqenceGeneration:
            # encode() already returned LM logits for this encoder type.
            return last_hidden_state.view(-1, last_hidden_state.size(-1))

        if decoder_opt == 1:
            # Validate the masks before using them.
            assert premise_mask is not None
            assert hyp_mask is not None
            max_query = hyp_mask.size(1)
            assert max_query > 0
            hyp_mem = last_hidden_state[:, :max_query, :]
            return scorer(last_hidden_state, hyp_mem, premise_mask, hyp_mask)

        # Plain linear head: pool, then dropout, then project.  Previously
        # `pooled_output` was unbound on this path (UnboundLocalError).
        pooled_output = dropout(self.pooler(last_hidden_state))
        return scorer(pooled_output)