#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Transformer Agents.
"""
from typing import Optional
from parlai.core.params import ParlaiParser
from parlai.core.opt import Opt
from parlai.core.agents import Agent
from parlai.utils.torch import padded_3d
from parlai.core.torch_classifier_agent import TorchClassifierAgent
from parlai.core.torch_ranker_agent import TorchRankerAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.utils.misc import recursive_getattr, warn_once
from parlai.utils.logging import logging
from parlai.utils.fsdp import should_use_fsdp
from .modules import (
    TransformerMemNetModel,
    TransformerGeneratorModel,
    TransformerLinearWrapper,
)
import torch


def _check_positional_embeddings(opt):
    """
    Checks positional embedding compatibility with FSDP.
    """
    if not opt.get('learn_positional_embeddings') and should_use_fsdp(opt):
        # note: we're doing on-the-fly setting here, abusing pass-by-reference
        # this only works because we're calling this from build_model, which is
        # only done in the original instantiation of an agent.
        opt['learn_positional_embeddings'] = True
        warn_once(
            "Using --ddp_backend zeroX requires --learn-positional-embeddings "
            "true. Forcing this to be true."
        )
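

# Behavior note (restating the logic above): opt is mutated in place, so any
# FSDP run started with --learn-positional-embeddings false will still build
# its model with learned positional embeddings.

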
def add_common_cmdline_args(parser):
    """
    Add common command line args.
    """
    parser.add_argument(
        '-esz',
        '--embedding-size',
        type=int,
        default=300,
        help='Size of all embedding layers. Must be a multiple of --n-heads.',
    )
    parser.add_argument(
        '-nl', '--n-layers', type=int, default=2, help='Number of transformer layers.'
    )
    parser.add_argument(
        '-hid',
        '--ffn-size',
        type=int,
        default=300,
        help='Hidden size of the FFN layers.',
    )
    parser.add_argument(
        '--dropout',
        type=float,
        default=0.0,
        help='Dropout used around embeddings and before layer normalizations. '
        'This is used in Vaswani 2017 and works well on large datasets.',
    )
    parser.add_argument(
        '--attention-dropout',
        type=float,
        default=0.0,
        help='Dropout used after attention softmax. This is not used in Vaswani 2017.',
    )
    parser.add_argument(
        '--relu-dropout',
        type=float,
        default=0.0,
        help='Dropout used after the ReLU in the FFN. Not used in Vaswani 2017, '
        'but used in Tensor2Tensor.',
    )
    parser.add_argument(
        '--n-heads', type=int, default=2, help='Number of multihead attention heads.'
    )
    parser.add_argument(
        '--learn-positional-embeddings',
        type='bool',
        default=False,
        help='If off, sinusoidal embeddings are used. If on, position embeddings are '
        'learned from scratch.',
    )
    parser.add_argument('--embeddings-scale', type='bool', default=True)
    parser.add_argument(
        '--n-positions',
        type=int,
        default=None,
        hidden=True,
        help='Number of positional embeddings to learn. Defaults '
        'to truncate or 1024 if not provided.',
    )
    parser.add_argument(
        '--n-segments',
        type=int,
        default=0,
        help='Number of segments supported by the model for segment (language) '
        'embeddings. If zero, no segment embeddings are used.',
    )
    parser.add_argument(
        '--variant',
        choices={'aiayn', 'xlm', 'prelayernorm', 'bart'},
        default='aiayn',
        help='Chooses locations of layer norms, etc. prelayernorm '
        'is used to match some fairseq models.',
        recommended='xlm',
    )
    parser.add_argument(
        '--activation',
        choices={'relu', 'gelu'},
        default='relu',
        help='Nonlinear activation to use. AIAYN uses relu, but '
        'more recent papers prefer gelu.',
        recommended='gelu',
    )
    parser.add_argument(
        '--output-scaling',
        type=float,
        default=1.0,
        help='Scale the output of every transformer by this quantity.',
    )
    parser.add_argument(
        '--share-word-embeddings',
        type='bool',
        default=True,
        help='Share the word embeddings table for candidate and context '
        'in the memory network.',
    )
    parser.add_argument(
        '-nel',
        '--n-encoder-layers',
        type=int,
        default=-1,
        help='Overrides --n-layers for the encoder in asymmetrical transformers.',
    )
    parser.add_argument(
        '-ndl',
        '--n-decoder-layers',
        type=int,
        default=-1,
        help='Overrides --n-layers for the decoder in asymmetrical transformers.',
    )
    parser.add_argument(
        '--model-parallel',
        type='bool',
        default=False,
        help='Shard the layers across multiple GPUs.',
    )
    parser.add_argument(
        '--checkpoint-activations',
        type='bool',
        default=False,
        help='Recompute activations on the backward pass to conserve memory.',
    )
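

# Usage sketch (illustrative only; in practice these arguments are registered
# through each agent's add_cmdline_args rather than called directly):
#   parser = ParlaiParser(add_model_args=True)
#   group = parser.add_argument_group('Transformer Arguments')
#   add_common_cmdline_args(group)
#   opt = parser.parse_args(['--embedding-size', '256', '--n-heads', '4'])

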
class Transformer(Agent):
    """
    Placeholder Transformer Agent.

    Placeholder class, which just throws an error telling the user to specify
    whether they want the ranker or the generator.
    """

    def __init__(self, opt, shared=None):
        raise RuntimeError(
            "`--model transformer` is not a valid choice. Please select either "
            "`--model transformer/ranker` or `--model transformer/generator`."
        )
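

# Example (illustrative): select a concrete variant on the command line, e.g.
#   parlai train_model --model transformer/generator --task convai2
# rather than `--model transformer`, which raises the RuntimeError above.

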
class TransformerRankerAgent(TorchRankerAgent):
    """
    Transformer Ranker Agent.

    Implementation of a TorchRankerAgent, where the model is a Transformer.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        """
        Add command-line arguments specifically for this agent.
        """
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        agent = parser.add_argument_group('Transformer Arguments')
        add_common_cmdline_args(agent)
        # memory and knowledge arguments
        agent.add_argument(
            '--use-memories',
            type='bool',
            default=False,
            help='use memories: must implement the function '
            '`_vectorize_memories` to use this',
        )
        agent.add_argument(
            '--wrap-memory-encoder',
            type='bool',
            default=False,
            help='wrap memory encoder with MLP',
        )
        agent.add_argument(
            '--memory-attention',
            type=str,
            default='sqrt',
            choices=['cosine', 'dot', 'sqrt'],
            help='similarity for basic attention mechanism '
            'when using transformer to encode memories',
        )
        # model specific arguments
        agent.add_argument('--normalize-sent-emb', type='bool', default=False)
        agent.add_argument('--share-encoders', type='bool', default=True)
        parser.add_argument(
            '--share-word-embeddings',
            type='bool',
            default=True,
            help='Share the word embeddings table for candidate and context '
            'in the memory network.',
        )
        agent.add_argument(
            '--learn-embeddings', type='bool', default=True, help='learn embeddings'
        )
        agent.add_argument(
            '--data-parallel',
            type='bool',
            default=False,
            help='use model in data parallel, requires multiple gpus',
        )
        agent.add_argument(
            '--reduction-type',
            type=str,
            default='mean',
            choices=['first', 'max', 'mean'],
            help='Type of reduction at the end of the transformer',
        )
        parser.set_defaults(learningrate=0.0001, optimizer='adamax', truncate=1024)
        cls.dictionary_class().add_cmdline_args(parser, partial_opt=partial_opt)
        return agent

    def _score(self, output, cands):
        if cands.dim() == 2:
            return torch.matmul(output, cands.t())
        elif cands.dim() == 3:
            return torch.bmm(output.unsqueeze(1), cands.transpose(1, 2)).squeeze(1)
        else:
            raise RuntimeError(
                'Unexpected candidate dimensions {}'.format(cands.dim())
            )
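
    # Shape sketch (illustrative): given B context encodings of size D,
    #   cands (C, D)    -> scores (B, C) via output @ cands.T
    #                      (one candidate set shared across the whole batch)
    #   cands (B, C, D) -> scores (B, C) via batched matmul
    #                      (a separate candidate set per example)
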
    def build_model(self, states=None):
        """
        Build and return model.
        """
        _check_positional_embeddings(self.opt)
        model = TransformerMemNetModel(self.opt, self.dict)
        if self.opt['embedding_type'] != 'random':
            self._copy_embeddings(model.embeddings.weight, self.opt['embedding_type'])
        return model

    def batchify(self, obs_batch, sort=False):
        """
        Override so that we can add memories to the Batch object.
        """
        batch = super().batchify(obs_batch, sort)
        if self.opt['use_memories']:
            valid_obs = [
                (i, ex) for i, ex in enumerate(obs_batch) if self.is_valid(ex)
            ]
            valid_inds, exs = zip(*valid_obs)
            mems = None
            if any('memory_vecs' in ex for ex in exs):
                mems = [ex.get('memory_vecs', None) for ex in exs]
            batch.memory_vecs = mems
        return batch
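
    # Data sketch (illustrative; values hypothetical): a valid example may carry
    # obs['memory_vecs'] = [LongTensor([...]), ...]; batchify collects these
    # per-example lists into batch.memory_vecs so that score_candidates can pad
    # them into a single 3D tensor with padded_3d.
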
    def _vectorize_memories(self, obs):
        # TODO: move this to Torch Ranker Agent
        raise NotImplementedError(
            'Abstract class: user must implement this function to use memories'
        )

    def vectorize(self, *args, **kwargs):
        """
        Override to include vectorization of memories.
        """
        kwargs['add_start'] = False
        kwargs['add_end'] = False
        obs = super().vectorize(*args, **kwargs)
        if self.opt['use_memories']:
            obs = self._vectorize_memories(obs)
        return obs

    def encode_candidates(self, padded_cands):
        """
        Encode candidates.
        """
        _, cands = self.model(xs=None, mems=None, cands=padded_cands)
        return cands

    def score_candidates(self, batch, cand_vecs, cand_encs=None):
        """
        Score candidates.
        """
        # convoluted check that not all memories are empty
        if (
            self.opt['use_memories']
            and batch.memory_vecs is not None
            and sum(len(m) for m in batch.memory_vecs)
        ):
            mems = padded_3d(batch.memory_vecs, pad_idx=self.NULL_IDX)
        else:
            mems = None
        if cand_encs is not None:
            # we pre-encoded the candidates, do not re-encode here
            cand_vecs = None
        context_h, cands_h = self.model(xs=batch.text_vec, mems=mems, cands=cand_vecs)
        if cand_encs is not None:
            cands_h = cand_encs
        scores = self._score(context_h, cands_h)
        return scores
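
    # Usage sketch (illustrative; `padded_fixed_cands` is a hypothetical
    # pre-padded candidate tensor): with a fixed candidate set, encode once and
    # reuse the encodings for every batch instead of re-encoding:
    #   cand_encs = agent.encode_candidates(padded_fixed_cands)  # (C, D)
    #   scores = agent.score_candidates(batch, cand_vecs=None, cand_encs=cand_encs)

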
class TransformerGeneratorAgent(TorchGeneratorAgent):
    """
    TransformerGeneratorAgent.

    Implementation of TorchGeneratorAgent, where the model is a Transformer.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        """
        Add command-line arguments specifically for this agent.
        """
        agent = parser.add_argument_group('Transformer Arguments')
        add_common_cmdline_args(agent)
        cls.dictionary_class().add_cmdline_args(parser, partial_opt=partial_opt)
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        return agent

    def build_model(self, states=None):
        """
        Build and return model.
        """
        _check_positional_embeddings(self.opt)
        model = TransformerGeneratorModel(self.opt, self.dict)
        if self.opt['embedding_type'] != 'random':
            self._copy_embeddings(
                model.encoder.embeddings.weight, self.opt['embedding_type']
            )
        return model

    def _resize_token_embeddings(self, state_dict, msg=None):
        """
        Resize the token embeddings when we are adding extra special tokens.
        """
        # map extra special tokens carefully
        new_size = self.model.embeddings.weight.size()[0]
        orig_size = state_dict['embeddings.weight'].size()[0]
        logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
        if new_size <= orig_size:
            # new size should be greater than original size,
            # as we are adding special tokens
            raise RuntimeError(msg)

        for emb_weights in [
            'embeddings.weight',
            'encoder.embeddings.weight',
            'decoder.embeddings.weight',
        ]:
            # get new_embs
            old_embs = state_dict[emb_weights]
            new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
            # copy over old weights
            new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
            # reset in state dict
            state_dict[emb_weights] = new_embs

        return state_dict
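
    # Context sketch (illustrative; the flag value is a placeholder): this path
    # is typically reached when a checkpoint is loaded after the dictionary was
    # extended with special tokens, e.g. something like
    #   --special-tok-lst __knowledge__,__end-knowledge__
    # which makes the live embedding matrix larger than the checkpoint's.

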
class TransformerClassifierAgent(TorchClassifierAgent):
    """
    Classifier based on Transformer.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        TransformerRankerAgent.add_cmdline_args(
            parser, partial_opt=partial_opt
        )  # add transformer args
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        parser.add_argument(
            '--load-from-pretrained-ranker',
            type='bool',
            default=False,
            help='load model from base transformer ranking model '
            '(used for pretraining)',
        )
        parser.set_defaults(reduction_type='first')
        return parser

    def build_model(self):
        _check_positional_embeddings(self.opt)
        num_classes = len(self.class_list)
        self.base_model = TransformerMemNetModel(self.opt, self.dict)
        return TransformerLinearWrapper(self.base_model.context_encoder, num_classes)

    def vectorize(self, *args, **kwargs):
        """
        Add the start and end token to the text.
        """
        kwargs['add_start'] = True
        kwargs['add_end'] = True
        obs = super().vectorize(*args, **kwargs)
        return obs

    def _set_text_vec(self, *args, **kwargs):
        """
        Add the start and end token to the text.
        """
        obs = super()._set_text_vec(*args, **kwargs)
        if 'text_vec' in obs and 'added_start_end' not in obs:
            obs.force_set(
                'text_vec', self._add_start_end_tokens(obs['text_vec'], True, True)
            )
            obs['added_start_end'] = True

        # check truncation after adding start end tokens
        if obs.get('text_vec') is not None:
            truncated_vec = self._check_truncate(
                obs['text_vec'], self.text_truncate, True
            )
            obs.force_set('text_vec', torch.LongTensor(truncated_vec))

        return obs

    def score(self, batch):
        return self.model(batch.text_vec)

    def load_state_dict(self, state_dict):
        """
        Load the state dict into the model.

        This is easily overridable to facilitate transfer of state dicts.
        """
        if self.is_finetune and self.opt['load_from_pretrained_ranker']:
            self.base_model.load_state_dict(state_dict, strict=False)
        else:
            self.model.load_state_dict(state_dict)
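

# End-to-end usage sketch (illustrative; the flag values are placeholders, not
# a recommended configuration):
#   from parlai.core.params import ParlaiParser
#   from parlai.core.agents import create_agent
#
#   parser = ParlaiParser(add_model_args=True)
#   opt = parser.parse_args(
#       ['--model', 'transformer/generator', '--embedding-size', '256',
#        '--n-heads', '4', '--n-layers', '2']
#   )
#   agent = create_agent(opt)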