-
Notifications
You must be signed in to change notification settings - Fork 6.3k
/
model.py
409 lines (345 loc) · 16.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from .hub_interface import RobertaHubInterface
@register_model('roberta')
class RobertaModel(FairseqLanguageModel):
    """RoBERTa language model with optional sentence-classification heads.

    The encoder given to the constructor is handed to the base class;
    :meth:`forward` invokes it through ``self.decoder``. Classification heads
    are kept in ``self.classification_heads`` (an ``nn.ModuleDict``) and are
    added with :meth:`register_classification_head`.
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping from hub model name to archive URL."""
        archives = {
            'roberta.base': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz',
            'roberta.large': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz',
            'roberta.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz',
            'roberta.large.wsc': 'http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz',
        }
        return archives

    def __init__(self, args, encoder):
        super().__init__(encoder)
        self.args = args

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        # Created after the init pass above; starts out empty.
        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        add = parser.add_argument

        # Encoder geometry.
        add('--encoder-layers', type=int, metavar='L',
            help='num encoder layers')
        add('--encoder-embed-dim', type=int, metavar='H',
            help='encoder embedding dimension')
        add('--encoder-ffn-embed-dim', type=int, metavar='F',
            help='encoder embedding dimension for FFN')
        add('--encoder-attention-heads', type=int, metavar='A',
            help='num encoder attention heads')

        # Activations and normalization.
        add('--activation-fn',
            choices=utils.get_available_activation_fns(),
            help='activation function to use')
        add('--pooler-activation-fn',
            choices=utils.get_available_activation_fns(),
            help='activation function to use for pooler layer')
        add('--encoder-normalize-before', action='store_true',
            help='apply layernorm before each encoder block')

        # Dropout rates.
        add('--dropout', type=float, metavar='D',
            help='dropout probability')
        add('--attention-dropout', type=float, metavar='D',
            help='dropout probability for attention weights')
        add('--activation-dropout', type=float, metavar='D',
            help='dropout probability after activation in FFN')
        add('--pooler-dropout', type=float, metavar='D',
            help='dropout probability in the masked_lm pooler layers')

        add('--max-positions', type=int,
            help='number of positional embeddings to learn')
        add('--load-checkpoint-heads', action='store_true',
            help='(re-)register and load heads when loading checkpoints')

        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        add('--encoder-layerdrop', type=float, metavar='D', default=0,
            help='LayerDrop probability for encoder')
        add('--encoder-layers-to-keep', default=None,
            help='which layers to *keep* when pruning as a comma-separated list')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # Fill in every architecture argument that was not supplied.
        base_architecture(args)

        # Fall back to the task's sample length when no explicit limit exists.
        if not hasattr(args, 'max_positions'):
            args.max_positions = args.tokens_per_sample

        return cls(args, RobertaEncoder(args, task.source_dictionary))

    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, classification_head_name=None, **kwargs):
        """Run the encoder and, if requested, one classification head.

        Args:
            src_tokens: input tokens, passed straight to the encoder.
            features_only (bool, optional): forwarded to the encoder; forced
                to True whenever a classification head is requested.
            return_all_hiddens (bool, optional): forwarded to the encoder.
            classification_head_name (str, optional): name of a registered
                head to apply to the encoder output.

        Returns:
            tuple: the encoder output (or the head's output when a head was
            requested) and the encoder's ``extra`` value.
        """
        use_head = classification_head_name is not None
        if use_head:
            # Heads consume encoder features, not LM logits.
            features_only = True

        output, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if use_head:
            output = self.classification_heads[classification_head_name](output)
        return output, extra

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head.

        If a head called *name* already exists it is replaced; a warning is
        printed first when the requested dimensions differ from the existing
        ones. ``inner_dim`` falls back to ``args.encoder_embed_dim`` when
        falsy.
        """
        heads = self.classification_heads
        if name in heads:
            existing = heads[name]
            prev_num_classes = existing.out_proj.out_features
            prev_inner_dim = existing.dense.out_features
            same_dims = num_classes == prev_num_classes and inner_dim == prev_inner_dim
            if not same_dims:
                print(
                    'WARNING: re-registering head "{}" with num_classes {} (prev: {}) '
                    'and inner_dim {} (prev: {})'.format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )

        embed_dim = self.args.encoder_embed_dim
        heads[name] = RobertaClassificationHead(
            input_dim=embed_dim,
            inner_dim=inner_dim or embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
        )

    @property
    def supported_targets(self):
        """Set of target kinds this model supports."""
        return {'self'}

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='gpt2', **kwargs):
        """Load a pretrained model and wrap it in a :class:`RobertaHubInterface`.

        Names are resolved through :meth:`hub_models`, and classification
        heads stored in the checkpoint are requested via
        ``load_checkpoint_heads=True``.
        """
        from fairseq import hub_utils

        loaded = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        first_model = loaded['models'][0]
        return RobertaHubInterface(loaded['args'], loaded['task'], first_model)

    def upgrade_state_dict_named(self, state_dict, name):
        """Reconcile classification heads between *state_dict* and this model.

        Heads found in the checkpoint are either registered on the model
        (when ``args.load_checkpoint_heads`` is truthy) or dropped from
        *state_dict* if the model lacks them or has them with different
        dimensions. Afterwards, any head parameter the model has but the
        checkpoint lacks is copied into *state_dict* from the model.
        *state_dict* is modified in place.
        """
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + '.' if name != '' else ''
        heads_prefix = prefix + 'classification_heads.'

        if hasattr(self, 'classification_heads'):
            current_head_names = self.classification_heads.keys()
        else:
            current_head_names = []

        # Handle new classification heads present in the state dict.
        # Deletions are collected and applied after the scan so the dict is
        # not resized while being iterated.
        keys_to_delete = []
        for key in state_dict.keys():
            if not key.startswith(heads_prefix):
                continue

            head_name = key[len(heads_prefix):].split('.')[0]
            this_head = heads_prefix + head_name
            num_classes = state_dict[this_head + '.out_proj.weight'].size(0)
            inner_dim = state_dict[this_head + '.dense.weight'].size(0)

            if getattr(self.args, 'load_checkpoint_heads', False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
                continue

            if head_name not in current_head_names:
                print(
                    'WARNING: deleting classification head ({}) from checkpoint '
                    'not present in current model: {}'.format(head_name, key)
                )
                keys_to_delete.append(key)
                continue

            model_head = self.classification_heads[head_name]
            if (
                num_classes != model_head.out_proj.out_features
                or inner_dim != model_head.dense.out_features
            ):
                print(
                    'WARNING: deleting classification head ({}) from checkpoint '
                    'with different dimensions than current model: {}'.format(head_name, key)
                )
                keys_to_delete.append(key)

        for key in keys_to_delete:
            del state_dict[key]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, 'classification_heads'):
            for sub_key, tensor in self.classification_heads.state_dict().items():
                full_key = heads_prefix + sub_key
                if full_key not in state_dict:
                    print('Overwriting', full_key)
                    state_dict[full_key] = tensor
@register_model('xlmr')
class XLMRModel(RobertaModel):
    """RoBERTa subclass registered as ``xlmr``.

    Differs from the parent only in its hub archive map and in defaulting
    ``bpe`` to ``'sentencepiece'`` when loading pretrained weights.
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping from hub model name to archive URL."""
        archives = {
            'xlmr.base.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz',
            'xlmr.large.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz',
        }
        return archives

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
        """Load a pretrained model; same as the parent but with a different ``bpe`` default."""
        # The parent implementation looks archives up through cls.hub_models(),
        # so this class's archive map is the one consulted.
        return super().from_pretrained(
            model_name_or_path,
            checkpoint_file=checkpoint_file,
            data_name_or_path=data_name_or_path,
            bpe=bpe,
            **kwargs,
        )
@register_model('camembert')
class CamembertModel(RobertaModel):
    """RoBERTa subclass registered as ``camembert``.

    Differs from the parent only in its hub archive map and in defaulting
    ``bpe`` to ``'sentencepiece'`` when loading pretrained weights.
    """

    @classmethod
    def hub_models(cls):
        """Return the mapping from hub model name to archive URL."""
        archives = {
            'camembert.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/camembert.v0.tar.gz',
        }
        return archives

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
        """Load a pretrained model; same as the parent but with a different ``bpe`` default."""
        # The parent implementation looks archives up through cls.hub_models(),
        # so this class's archive map is the one consulted.
        return super().from_pretrained(
            model_name_or_path,
            checkpoint_file=checkpoint_file,
            data_name_or_path=data_name_or_path,
            bpe=bpe,
            **kwargs,
        )
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Applies a dense layer, an activation and a layer norm, then projects to
    ``output_dim`` with a (possibly shared) weight matrix plus a learned bias.
    """

    def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
        """
        Args:
            embed_dim (int): size of the incoming features.
            output_dim (int): size of the projection output.
            activation_fn (str): name resolved via ``utils.get_activation_fn``.
            weight (optional): projection weight to reuse; when omitted a
                fresh bias-free ``nn.Linear`` weight is created.
        """
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.layer_norm = LayerNorm(embed_dim)

        # Reuse the caller's projection matrix when one is supplied.
        self.weight = (
            weight if weight is not None
            else nn.Linear(embed_dim, output_dim, bias=False).weight
        )
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features, masked_tokens=None, **kwargs):
        """Project *features* to ``output_dim`` scores.

        When ``masked_tokens`` is given, only the positions it selects are
        projected.
        """
        # Restrict the projection to the positions selected by masked_tokens,
        # which saves both memory and computation.
        if masked_tokens is not None:
            features = features[masked_tokens, :]

        hidden = self.layer_norm(self.activation_fn(self.dense(features)))

        # project back to size of vocabulary with bias
        return F.linear(hidden, self.weight) + self.bias
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Pipeline: dropout -> dense -> activation -> dropout -> output projection,
    applied to the feature vector at sequence position 0.
    """

    def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout):
        """
        Args:
            input_dim (int): size of the incoming features.
            inner_dim (int): width of the hidden dense layer.
            num_classes (int): number of output classes.
            activation_fn (str): name resolved via ``utils.get_activation_fn``.
            pooler_dropout (float): dropout probability used at both
                dropout steps.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, features, **kwargs):
        """Return class scores computed from the first position of *features*."""
        first_token = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        hidden = self.activation_fn(self.dense(self.dropout(first_token)))
        return self.out_proj(self.dropout(hidden))
class RobertaEncoder(FairseqDecoder):
    """RoBERTa encoder.

    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
    by :class:`~fairseq.models.FairseqLanguageModel`.
    """

    def __init__(self, args, dictionary):
        """Build the sentence encoder and the LM head from *args*.

        Note: when ``args.encoder_layers_to_keep`` is set, *args* is mutated
        (see the comment below).
        """
        super().__init__(dictionary)
        self.args = args

        # RoBERTa is a sentence encoder model, so users will intuitively trim
        # encoder layers. However, the implementation uses the fairseq decoder,
        # so we fix here.
        layers_to_keep = args.encoder_layers_to_keep
        if layers_to_keep:
            args.encoder_layers = len(layers_to_keep.split(","))
            args.decoder_layers_to_keep = layers_to_keep
            args.encoder_layers_to_keep = None

        vocab_size = len(dictionary)
        embed_dim = args.encoder_embed_dim

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=dictionary.pad(),
            vocab_size=vocab_size,
            num_encoder_layers=args.encoder_layers,
            embedding_dim=embed_dim,
            ffn_embedding_dim=args.encoder_ffn_embed_dim,
            num_attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
            activation_dropout=args.activation_dropout,
            layerdrop=args.encoder_layerdrop,
            max_seq_len=args.max_positions,
            num_segments=0,
            encoder_normalize_before=True,
            apply_bert_init=True,
            activation_fn=args.activation_fn,
        )

        # The LM head shares its projection weight with the token embeddings.
        self.lm_head = RobertaLMHead(
            embed_dim=embed_dim,
            output_dim=vocab_size,
            activation_fn=args.activation_fn,
            weight=self.sentence_encoder.embed_tokens.weight,
        )

    def forward(self, src_tokens, features_only=False, return_all_hiddens=False, masked_tokens=None, **unused):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            masked_tokens (optional): forwarded to the LM head, which uses it
                to select the positions that get projected.

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states.
        """
        features, extra = self.extract_features(src_tokens, return_all_hiddens)
        if features_only:
            return features, extra
        return self.output_layer(features, masked_tokens=masked_tokens), extra

    def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
        """Return the last hidden state plus a dict holding 'inner_states'.

        'inner_states' is the full list from the sentence encoder when
        ``return_all_hiddens`` is true, otherwise None.
        """
        inner_states, _ = self.sentence_encoder(
            src_tokens,
            last_state_only=not return_all_hiddens,
        )
        extra = {'inner_states': inner_states if return_all_hiddens else None}
        return inner_states[-1], extra

    def output_layer(self, features, masked_tokens=None, **unused):
        """Apply the LM head to *features*."""
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions
@register_model_architecture('roberta', 'roberta')
def base_architecture(args):
    """Fill in default values for any model argument missing from *args*.

    Each attribute is left untouched when already present; otherwise it is
    set to the default listed here. *args* is modified in place.
    """
    args.encoder_layers = getattr(args, 'encoder_layers', 12)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')

    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.activation_dropout = getattr(args, 'activation_dropout', 0.0)
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)

    # RobertaEncoder.__init__ reads both of these unconditionally. Argument
    # namespaces that did not come through add_args (e.g. ones restored from a
    # checkpoint saved before these options existed) would otherwise raise
    # AttributeError there.
    args.encoder_layers_to_keep = getattr(args, 'encoder_layers_to_keep', None)
    args.encoder_layerdrop = getattr(args, 'encoder_layerdrop', 0.0)
@register_model_architecture('roberta', 'roberta_base')
def roberta_base_architecture(args):
    """Architecture 'roberta_base': applies exactly the base defaults."""
    base_architecture(args)
@register_model_architecture('roberta', 'roberta_large')
def roberta_large_architecture(args):
    """Architecture 'roberta_large': larger encoder defaults, then the base ones.

    Attributes already present on *args* are kept as they are.
    """
    large_defaults = (
        ('encoder_layers', 24),
        ('encoder_embed_dim', 1024),
        ('encoder_ffn_embed_dim', 4096),
        ('encoder_attention_heads', 16),
    )
    for attr, default in large_defaults:
        setattr(args, attr, getattr(args, attr, default))

    # Everything not set above falls back to the base defaults.
    base_architecture(args)
@register_model_architecture('roberta', 'xlm')
def xlm_architecture(args):
    """Architecture 'xlm': its own encoder defaults, then the base ones.

    Attributes already present on *args* are kept as they are.
    """
    xlm_defaults = (
        ('encoder_layers', 16),
        ('encoder_embed_dim', 1280),
        ('encoder_ffn_embed_dim', 1280 * 4),
        ('encoder_attention_heads', 16),
    )
    for attr, default in xlm_defaults:
        setattr(args, attr, getattr(args, attr, default))

    # Everything not set above falls back to the base defaults.
    base_architecture(args)