-
Notifications
You must be signed in to change notification settings - Fork 49
/
models.py
322 lines (272 loc) · 12.1 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import sys
sys.path += ['../']
import torch
from torch import nn
from transformers import (
RobertaConfig,
RobertaModel,
RobertaForSequenceClassification,
RobertaTokenizer,
BertModel,
BertTokenizer,
BertConfig
)
import torch.nn.functional as F
from data.process_fn import triple_process_fn, triple2dual_process_fn
from model.SEED_Encoder import SEEDEncoderConfig, SEEDTokenizer, SEEDEncoderForSequenceClassification,SEEDEncoderForMaskedLM
class EmbeddingMixin:
    """
    Mixin for common functions in most embedding models. Each model should define its own bert-like backbone
    and forward.

    The mixin itself has no base class; concrete models combine it with a pretrained ``transformers`` class
    (e.g. ``RobertaForSequenceClassification``) so that ``from_pretrained`` is available.
    """

    def __init__(self, model_argobj):
        """Read the pooling mode from ``model_argobj.use_mean`` (False when ``model_argobj`` is None)."""
        if model_argobj is None:
            self.use_mean = False
        else:
            self.use_mean = model_argobj.use_mean
        print("Using mean:", self.use_mean)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            # Only ``weight`` is touched; biases and other module types keep their existing values.
            module.weight.data.normal_(mean=0.0, std=0.02)

    def masked_mean(self, t, mask):
        """Average ``t`` over dim 1, weighting each position by ``mask``.

        Args:
            t: tensor whose dim 1 is the sequence axis (presumably [batch, seq, hidden] - TODO confirm).
            mask: tensor matching the first two dims of ``t``; non-zero entries are counted.

        Returns:
            ``t`` summed over dim 1 with the mask applied, divided by the per-row mask total.
            Rows whose mask sums to zero yield zeros instead of NaN.
        """
        masked_sum = torch.sum(t * mask.unsqueeze(-1).float(), dim=1)
        # Clamp so that an all-padding row gives 0 / eps = 0 rather than 0 / 0 = NaN.
        counts = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        return masked_sum / counts

    def masked_mean_or_first(self, emb_all, mask):
        """Pool a backbone output into one vector per example.

        Args:
            emb_all: tuple returned by the backbone; element 0 is the token-level output.
            mask: attention mask for ``masked_mean``, or None.

        Returns:
            With ``use_mean`` the masked mean over dim 1 (a plain mean when ``mask`` is None),
            otherwise the first token's vector.
        """
        # emb_all is a tuple from bert - sequence output, pooler
        assert isinstance(emb_all, tuple)
        token_emb = emb_all[0]
        if not self.use_mean:
            return token_emb[:, 0]
        if mask is None:
            # e.g. SEEDEncoderDot_NLL_LN.query_emb defaults attention_mask to None
            return token_emb.mean(dim=1)
        return self.masked_mean(token_emb, mask)

    def query_emb(self, input_ids, attention_mask):
        raise NotImplementedError("Please Implement this method")

    def body_emb(self, input_ids, attention_mask):
        raise NotImplementedError("Please Implement this method")
class NLL(EmbeddingMixin):
    """Mixin adding a two-way softmax negative-log-likelihood loss.

    Subclasses supply ``query_emb`` and ``body_emb``; this class only combines them.
    """

    def forward(
            self,
            query_ids,
            attention_mask_q,
            input_ids_a=None,
            attention_mask_a=None,
            input_ids_b=None,
            attention_mask_b=None,
            is_query=True):
        """Encode a single batch, or score (query, a, b) triples.

        When ``input_ids_b`` is None, ``query_ids`` is encoded with ``query_emb`` (if ``is_query``)
        or ``body_emb`` and that embedding is returned directly.

        Otherwise the query is scored against ``a`` and ``b`` by dot product, and the returned
        1-tuple holds the batch mean of ``-log softmax`` at ``a``'s position (index 0).
        """
        if input_ids_b is None:
            encode = self.query_emb if is_query else self.body_emb
            return encode(query_ids, attention_mask_q)

        query_vecs = self.query_emb(query_ids, attention_mask_q)
        vecs_a = self.body_emb(input_ids_a, attention_mask_a)
        vecs_b = self.body_emb(input_ids_b, attention_mask_b)

        scores = torch.stack(
            ((query_vecs * vecs_a).sum(-1), (query_vecs * vecs_b).sum(-1)),
            dim=1,
        )  # [B, 2]
        log_probs = F.log_softmax(scores, dim=1)
        per_example_loss = -log_probs[:, 0]
        return (per_example_loss.mean(),)
class NLL_MultiChunk(EmbeddingMixin):
    """Two-way softmax NLL loss for documents encoded as several chunks.

    ``body_emb`` must return one vector per chunk, i.e. [batch, chunk_factor, dim]; a document's
    score is the max over its chunks of the dot product with the query embedding. Subclasses
    must also set ``self.base_len`` (tokens per chunk).
    """

    def _max_chunk_logits(self, q_embs, doc_embs, attention_mask, batch_size, chunk_factor):
        """Return [batch] scores: the max query/chunk dot product over non-padding chunks.

        A chunk is treated as padding when the attention mask at its first token is 0; it receives a
        -9999 bias so that it effectively never wins the max.
        """
        # special handle of attention mask: the flag of each chunk's first token -> [batch, chunk_factor]
        chunk_mask = attention_mask.reshape(batch_size, chunk_factor, -1)[:, :, 0]
        inverted_bias = ((1 - chunk_mask) * (-9999)).float()
        chunk_scores = torch.matmul(
            q_embs.unsqueeze(1), doc_embs.transpose(1, 2))  # [batch, 1, chunk_factor]
        return (chunk_scores[:, 0, :] + inverted_bias).max(dim=-1, keepdim=False).values  # [batch]

    def forward(
            self,
            query_ids,
            attention_mask_q,
            input_ids_a=None,
            attention_mask_a=None,
            input_ids_b=None,
            attention_mask_b=None,
            is_query=True):
        """Encode a single batch, or compute the NLL loss for (query, a, b) triples.

        When ``input_ids_b`` is None, returns ``query_emb`` (if ``is_query``) or ``body_emb`` of
        ``query_ids``. Otherwise returns a 1-tuple with the mean ``-log softmax`` probability of ``a``
        over the two documents' max-over-chunk scores.
        """
        if input_ids_b is None and is_query:
            return self.query_emb(query_ids, attention_mask_q)
        elif input_ids_b is None:
            return self.body_emb(query_ids, attention_mask_q)
        q_embs = self.query_emb(query_ids, attention_mask_q)
        a_embs = self.body_emb(input_ids_a, attention_mask_a)
        b_embs = self.body_emb(input_ids_b, attention_mask_b)
        # Chunk geometry is taken from document a and assumed to hold for b as well.
        [batchS, full_length] = input_ids_a.size()
        chunk_factor = full_length // self.base_len
        logits_a = self._max_chunk_logits(q_embs, a_embs, attention_mask_a, batchS, chunk_factor)
        logits_b = self._max_chunk_logits(q_embs, b_embs, attention_mask_b, batchS, chunk_factor)
        logit_matrix = torch.cat(
            [logits_a.unsqueeze(1), logits_b.unsqueeze(1)], dim=1)  # [B, 2]
        lsm = F.log_softmax(logit_matrix, dim=1)
        loss = -1.0 * lsm[:, 0]
        return (loss.mean(),)
class RobertaDot_NLL_LN(NLL, RobertaForSequenceClassification):
    """RoBERTa dot-product encoder trained with the two-way NLL loss.

    The pooled representation (first token, or masked mean when ``use_mean``) is projected to 768
    dimensions by ``embeddingHead`` and then LayerNorm-ed. The output width is fixed at 768,
    independent of ``config.hidden_size``. Queries and passages share the same encoder.
    """

    def __init__(self, config, model_argobj=None):
        NLL.__init__(self, model_argobj)
        RobertaForSequenceClassification.__init__(self, config)
        self.embeddingHead = nn.Linear(config.hidden_size, 768)
        self.norm = nn.LayerNorm(768)
        self.apply(self._init_weights)

    def query_emb(self, input_ids, attention_mask):
        """Encode ``input_ids`` with ``self.roberta``, pool, project and normalize."""
        outputs1 = self.roberta(input_ids=input_ids,
                                attention_mask=attention_mask)
        full_emb = self.masked_mean_or_first(outputs1, attention_mask)
        query1 = self.norm(self.embeddingHead(full_emb))
        return query1

    def body_emb(self, input_ids, attention_mask):
        """Passages use exactly the same encoding path as queries."""
        return self.query_emb(input_ids, attention_mask)
class RobertaDot_CLF_ANN_NLL_MultiChunk(NLL_MultiChunk, RobertaDot_NLL_LN):
    """Multi-chunk variant of RobertaDot_NLL_LN.

    Passages are split into chunks of ``full_length // (full_length // base_len)`` tokens, with
    ``base_len`` = 512. Each chunk is encoded independently and represented by its projected,
    LayerNorm-ed first-token output. NLL_MultiChunk then scores a passage by the max over its chunks.
    """

    def __init__(self, config, model_argobj=None):
        # model_argobj is accepted (and forwarded) for parity with the other model classes.
        RobertaDot_NLL_LN.__init__(self, config, model_argobj)
        self.base_len = 512

    def body_emb(self, input_ids, attention_mask):
        """Return per-chunk passage embeddings of shape [batch, chunk_factor, embedding_dim].

        Uses the first token of every chunk; ``use_mean`` is not consulted here. ``full_length`` must be
        divisible by ``chunk_factor``, otherwise the reshape raises.
        """
        [batch_size, full_length] = input_ids.size()
        chunk_factor = full_length // self.base_len
        chunk_len = full_length // chunk_factor
        # Fold the chunk axis into the batch axis: [batch * chunk_factor, chunk_len]
        input_seq = input_ids.reshape(batch_size * chunk_factor, chunk_len)
        attention_mask_seq = attention_mask.reshape(batch_size * chunk_factor, chunk_len)
        outputs_k = self.roberta(input_ids=input_seq,
                                 attention_mask=attention_mask_seq)
        compressed_output_k = self.embeddingHead(
            outputs_k[0])  # [batch * chunk_factor, chunk_len, dim]
        first_token_emb = self.norm(compressed_output_k[:, 0, :])
        embedding_size = first_token_emb.size(-1)
        return first_token_emb.reshape(batch_size, chunk_factor, embedding_size)
class SEEDEncoderDot_NLL_LN(NLL, SEEDEncoderForSequenceClassification):
    """SEED-Encoder dot-product encoder trained with the two-way NLL loss.

    The pooled representation is projected from ``config.encoder_embed_dim`` to 768 dimensions and
    then LayerNorm-ed. Queries and passages share the same encoder.
    """

    def __init__(self, config, model_argobj=None):
        NLL.__init__(self, model_argobj)
        SEEDEncoderForSequenceClassification.__init__(self, config)
        self.embeddingHead = nn.Linear(config.encoder_embed_dim, 768)
        self.norm = nn.LayerNorm(768)
        self.apply(self._init_weights)

    def query_emb(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` with the SEED encoder, pool, project and normalize."""
        # NOTE(review): attention_mask is NOT passed to the encoder; it is only used for mean pooling.
        # Presumably the encoder derives padding from the token ids - TODO confirm.
        outputs1 = self.seed_encoder.encoder(input_ids)
        full_emb = self.masked_mean_or_first(outputs1, attention_mask)
        query1 = self.norm(self.embeddingHead(full_emb))
        return query1

    def body_emb(self, input_ids, attention_mask=None):
        """Passages use exactly the same encoding path as queries."""
        return self.query_emb(input_ids, attention_mask)
class HFBertEncoder(BertModel):
    """BERT encoder whose "pooled" output is the raw first-token ([CLS]) hidden state."""

    def __init__(self, config):
        BertModel.__init__(self, config)
        assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
        self.init_weights()

    @classmethod
    def init_encoder(cls, args, dropout: float = 0.1):
        """Load ``bert-base-uncased``, optionally overriding its dropout.

        Args:
            args: unused; kept so existing callers do not break.
            dropout: when non-zero, replaces both ``attention_probs_dropout_prob`` and
                ``hidden_dropout_prob``; 0 keeps the pretrained config's values.
        """
        cfg = BertConfig.from_pretrained("bert-base-uncased")
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout
        return cls.from_pretrained("bert-base-uncased", config=cfg)

    def forward(self, input_ids, attention_mask):
        """Return ``(sequence_output, cls_output, None)``.

        BERT's own pooler output is discarded; the first token's hidden state is used instead.
        The third element is always None.
        """
        outputs = super().forward(input_ids=input_ids,
                                  attention_mask=attention_mask)
        # Index instead of unpacking exactly two values, so extra outputs (e.g. hidden states) do not break this.
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0, :]
        hidden_states = None
        return sequence_output, pooled_output, hidden_states

    def get_out_size(self):
        """Embedding width: ``encode_proj.out_features`` if a projection exists, else ``hidden_size``."""
        # encode_proj is never created in this class, so look it up defensively instead of raising AttributeError.
        encode_proj = getattr(self, "encode_proj", None)
        if encode_proj is not None:
            return encode_proj.out_features
        return self.config.hidden_size
class BiEncoder(nn.Module):
    """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders.

    The two towers are independent ``HFBertEncoder`` instances; embeddings are their first-token outputs.
    """

    def __init__(self, args):
        super().__init__()
        self.question_model = HFBertEncoder.init_encoder(args)
        self.ctx_model = HFBertEncoder.init_encoder(args)

    def query_emb(self, input_ids, attention_mask):
        """Embed queries with the question tower."""
        _, pooled_output, _ = self.question_model(input_ids, attention_mask)
        return pooled_output

    def body_emb(self, input_ids, attention_mask):
        """Embed passages with the context tower."""
        _, pooled_output, _ = self.ctx_model(input_ids, attention_mask)
        return pooled_output

    def forward(self, query_ids, attention_mask_q, input_ids_a=None, attention_mask_a=None,
                input_ids_b=None, attention_mask_b=None, is_query=True):
        """Encode, embed pairs, or compute the two-way NLL loss.

        - Only ``query_ids`` given: return ``query_emb`` (if ``is_query``) or ``body_emb`` of it.
        - ``input_ids_a`` given, ``input_ids_b`` None: return ``(query embeddings, a embeddings)``.
        - Both given: return a 1-tuple with the mean ``-log softmax`` probability of ``a`` vs ``b``.
        """
        if input_ids_a is None and input_ids_b is None:
            # Single-tower encoding; previously this passed None into the context tower and crashed.
            if is_query:
                return self.query_emb(query_ids, attention_mask_q)
            return self.body_emb(query_ids, attention_mask_q)
        q_embs = self.query_emb(query_ids, attention_mask_q)
        a_embs = self.body_emb(input_ids_a, attention_mask_a)
        if input_ids_b is None:
            return (q_embs, a_embs)
        b_embs = self.body_emb(input_ids_b, attention_mask_b)
        logit_matrix = torch.cat([(q_embs * a_embs).sum(-1).unsqueeze(1),
                                  (q_embs * b_embs).sum(-1).unsqueeze(1)], dim=1)  # [B, 2]
        lsm = F.log_softmax(logit_matrix, dim=1)
        loss = -1.0 * lsm[:, 0]
        return (loss.mean(),)
# --------------------------------------------------
# Pretrained checkpoint names: the keys of ``pretrained_config_archive_map`` of each listed
# config class, in order. Classes lacking that attribute contribute nothing.
ALL_MODELS = tuple(
    checkpoint_name
    for config_cls in (RobertaConfig,)
    if hasattr(config_cls, 'pretrained_config_archive_map')
    for checkpoint_name in config_cls.pretrained_config_archive_map.keys()
)
default_process_fn = triple_process_fn


class MSMarcoConfig:
    """Registry entry bundling a model setup under a short name.

    Stores: ``name``, ``process_fn`` (data processing function), ``model_class`` (the ``model``
    argument), ``use_mean``, ``tokenizer_class`` and ``config_class``.
    """

    def __init__(self, name, model, process_fn=default_process_fn, use_mean=True,
                 tokenizer_class=RobertaTokenizer, config_class=RobertaConfig):
        self.name, self.process_fn = name, process_fn
        self.model_class, self.use_mean = model, use_mean
        self.tokenizer_class, self.config_class = tokenizer_class, config_class
# Supported model setups; looked up by name through MSMarcoConfigDict.
configs = [
    MSMarcoConfig(name="rdot_nll", model=RobertaDot_NLL_LN, use_mean=False),
    MSMarcoConfig(name="rdot_nll_multi_chunk", model=RobertaDot_CLF_ANN_NLL_MultiChunk, use_mean=False),
    MSMarcoConfig(
        name="dpr",
        model=BiEncoder,
        tokenizer_class=BertTokenizer,
        config_class=BertConfig,
        use_mean=False,
    ),
    MSMarcoConfig(
        name="seeddot_nll",
        model=SEEDEncoderDot_NLL_LN,
        use_mean=False,
        tokenizer_class=SEEDTokenizer,
        config_class=SEEDEncoderConfig,
    ),
]

# name -> MSMarcoConfig (a later entry with a duplicate name would win).
MSMarcoConfigDict = dict((entry.name, entry) for entry in configs)