-
Notifications
You must be signed in to change notification settings - Fork 6.3k
/
transformer_lm.py
276 lines (239 loc) · 14.7 KB
/
transformer_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq import options, utils
from fairseq.models import (
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
Embedding,
TransformerDecoder,
)
from fairseq.modules import (
AdaptiveInput,
CharacterTokenEmbedder,
)
# Fallback for ``args.max_target_positions`` used by
# TransformerLanguageModel.build_model when neither ``max_target_positions``
# nor ``tokens_per_sample`` is available on the argument namespace.
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('transformer_lm')
class TransformerLanguageModel(FairseqLanguageModel):
    """Language model consisting of a single TransformerDecoder.

    The decoder is built with ``no_encoder_attn=True``, i.e. without any
    encoder-decoder attention, and is handed to FairseqLanguageModel.
    """

    @classmethod
    def hub_models(cls):
        """Return a mapping from hub model name to its pretrained archive URL."""
        return {
            'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
            'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
            'transformer_lm.wmt19.en': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2',
            'transformer_lm.wmt19.de': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2',
            'transformer_lm.wmt19.ru': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2',
        }

    def __init__(self, decoder):
        """Wrap ``decoder`` (a TransformerDecoder) as a language model."""
        super().__init__(decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension')
        parser.add_argument('--decoder-input-dim', type=int, metavar='N',
                            help='decoder input dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        # NOTE: base_lm_architecture currently forces
        # decoder_normalize_before=True, so this flag has no visible effect.
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--no-decoder-final-norm', action='store_true',
                            help='don\'t add an extra layernorm after the last decoder block')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N',
                            help='adaptive softmax factor')
        parser.add_argument('--no-token-positional-embeddings', action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--character-embeddings', action='store_true',
                            help='if set, uses character embedding convolutions to produce token embeddings')
        parser.add_argument('--character-filters', type=str, metavar='LIST',
                            default='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]',
                            help='list of (kernel width, number of filters) tuples for the character CNN')
        parser.add_argument('--character-embedding-dim', default=4, type=int, metavar='N',
                            help='size of character embeddings')
        parser.add_argument('--char-embedder-highway-layers', default=2, type=int, metavar='N',
                            help='number of highway layers for character token embedder')
        parser.add_argument('--adaptive-input', action='store_true',
                            help='if set, uses adaptive input')
        parser.add_argument('--adaptive-input-factor', type=float, metavar='N',
                            help='adaptive input factor')
        parser.add_argument('--adaptive-input-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive input cutoff points.')
        parser.add_argument('--tie-adaptive-weights', action='store_true',
                            help='if set, ties the weights of adaptive softmax and adaptive input')
        parser.add_argument('--tie-adaptive-proj', action='store_true',
                            help='if set, ties the projection weights of adaptive softmax and adaptive input')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: argument namespace (possibly restored from an older
                checkpoint); it is modified in place with defaults.
            task: task whose ``source_dictionary`` sizes the input embeddings
                and whose ``target_dictionary`` is passed to the decoder.

        Returns:
            An instance of ``cls`` wrapping a TransformerDecoder built with
            ``no_encoder_attn=True``.
        """
        # make sure all arguments are present in older models
        base_lm_architecture(args)

        # Checkpoints saved before LayerDrop existed carry no
        # ``decoder_layers_to_keep`` attribute; treat that as "keep all".
        layers_to_keep = getattr(args, 'decoder_layers_to_keep', None)
        if layers_to_keep:
            # Pruned model: only the listed layers are instantiated.
            args.decoder_layers = len(layers_to_keep.split(","))

        if getattr(args, 'max_target_positions', None) is None:
            args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

        if args.character_embeddings:
            # NOTE(review): ``eval`` runs arbitrary Python taken from the
            # argument namespace; ``ast.literal_eval`` would suffice for the
            # expected list-of-tuples literal. Kept as-is for compatibility.
            embed_tokens = CharacterTokenEmbedder(
                task.source_dictionary, eval(args.character_filters),
                args.character_embedding_dim, args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.source_dictionary), task.source_dictionary.pad(), args.decoder_input_dim,
                args.adaptive_input_factor, args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
            )
        else:
            embed_tokens = Embedding(len(task.source_dictionary), args.decoder_input_dim, task.source_dictionary.pad())

        if args.tie_adaptive_weights:
            # Tying adaptive input and adaptive softmax weights requires the
            # two to agree on factor, cutoffs and input/output dimensions.
            assert args.adaptive_input, '--tie-adaptive-weights requires --adaptive-input'
            assert args.adaptive_input_factor == args.adaptive_softmax_factor, '{} != {}'.format(
                args.adaptive_input_factor, args.adaptive_softmax_factor)
            assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
            assert args.decoder_input_dim == args.decoder_output_dim, '{} != {}'.format(
                args.decoder_input_dim, args.decoder_output_dim)

        decoder = TransformerDecoder(
            args, task.target_dictionary, embed_tokens, no_encoder_attn=True,
        )
        # Use ``cls`` so subclasses relying on this factory get their own type.
        return cls(decoder)
@register_model_architecture('transformer_lm', 'transformer_lm')
def base_lm_architecture(args):
    """Fill in default hyper-parameters for the base ``transformer_lm`` model.

    Attributes already present on ``args`` are kept (defaults are applied via
    ``getattr``), with one exception: ``decoder_normalize_before`` is always
    forced to True. Flags used by older checkpoints are translated to their
    current names first.

    Args:
        args: argument namespace, modified in place.
    """
    # backward compatibility for older model checkpoints
    if hasattr(args, 'no_tie_adaptive_proj'):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, 'decoder_final_norm'):
        # older checkpoints stored the positive form of the flag
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, 'dropout', 0.1)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.0)

    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    args.activation_fn = getattr(args, 'activation_fn', 'relu')

    args.add_bos_token = getattr(args, 'add_bos_token', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
    args.character_embeddings = getattr(args, 'character_embeddings', False)

    # input/output dims default to the (already resolved) embedding dim
    args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', False)

    args.adaptive_input = getattr(args, 'adaptive_input', False)
    args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
    args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None)

    args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
    args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)

    # LayerDrop options (Fan et al., 2019). Checkpoints saved before these
    # flags existed lack them; default them (matching the argparse defaults)
    # so code reading them as plain attributes does not raise AttributeError.
    args.decoder_layerdrop = getattr(args, 'decoder_layerdrop', 0)
    args.decoder_layers_to_keep = getattr(args, 'decoder_layers_to_keep', None)
@register_model_architecture('transformer_lm', 'transformer_lm_big')
def transformer_lm_big(args):
    """Larger ``transformer_lm``: 12 layers, 1024-dim embeddings, 4096-dim FFN,
    16 attention heads; remaining defaults come from base_lm_architecture.

    Values already present on ``args`` are kept; ``args`` is modified in place.
    """
    big_overrides = (
        ('decoder_layers', 12),
        ('decoder_embed_dim', 1024),
        ('decoder_ffn_embed_dim', 4096),
        ('decoder_attention_heads', 16),
    )
    for attr_name, fallback in big_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
@register_model_architecture('transformer_lm', 'transformer_lm_baevski_wiki103')
def transformer_lm_baevski_wiki103(args):
    """WikiText-103 adaptive-input configuration layered on transformer_lm_big.

    Uses 16 layers, adaptive input/softmax with tied weights and projections
    (cutoffs '20000,60000'), and heavier dropout. Values already present on
    ``args`` are kept; ``args`` is modified in place.
    """
    wiki103_overrides = (
        ('decoder_layers', 16),
        ('decoder_attention_heads', 8),
        ('dropout', 0.3),
        ('adaptive_input', True),
        ('tie_adaptive_weights', True),
        ('adaptive_input_cutoff', '20000,60000'),
        ('adaptive_softmax_cutoff', '20000,60000'),
        ('adaptive_softmax_dropout', 0.2),
        ('attention_dropout', 0.1),
        ('activation_dropout', 0.1),
        ('no_decoder_final_norm', True),
        ('tie_adaptive_proj', True),
    )
    for attr_name, fallback in wiki103_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    transformer_lm_big(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gbw')
@register_model_architecture('transformer_lm', 'transformer_lm_baevski_gbw')
def transformer_lm_baevski_gbw(args):
    """Google Billion Word configuration layered on transformer_lm_big.

    Keeps 512-dim embeddings, uses 0.1 dropout/attention dropout and no final
    decoder layernorm. Values already present on ``args`` are kept; ``args``
    is modified in place.
    """
    gbw_overrides = (
        ('decoder_embed_dim', 512),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('no_decoder_final_norm', True),
    )
    for attr_name, fallback in gbw_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    transformer_lm_big(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gpt')
def transformer_lm_gpt(args):
    """GPT-sized configuration: 12 layers, 768-dim embeddings, 3072-dim FFN,
    12 heads, GELU activation; remaining defaults from base_lm_architecture.

    Values already present on ``args`` are kept; ``args`` is modified in place.
    """
    gpt_overrides = (
        ('decoder_embed_dim', 768),
        ('decoder_ffn_embed_dim', 3072),
        ('decoder_layers', 12),
        ('decoder_attention_heads', 12),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_fn', 'gelu'),
    )
    for attr_name, fallback in gpt_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_small')
def transformer_lm_gpt2_small(args):
    """GPT-2 "small" configuration in this file: 24 layers, 1024-dim
    embeddings, 4096-dim FFN, 16 heads, GELU; remaining defaults from
    base_lm_architecture.

    Values already present on ``args`` are kept; ``args`` is modified in place.
    """
    gpt2_small_overrides = (
        ('decoder_embed_dim', 1024),
        ('decoder_ffn_embed_dim', 4096),
        ('decoder_layers', 24),
        ('decoder_attention_heads', 16),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_fn', 'gelu'),
    )
    for attr_name, fallback in gpt2_small_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_medium')
def transformer_lm_gpt2_medium(args):
    """GPT-2 "medium" configuration in this file: 36 layers, 1280-dim
    embeddings, 5120-dim FFN, 20 heads, GELU; remaining defaults from
    base_lm_architecture.

    Values already present on ``args`` are kept; ``args`` is modified in place.
    """
    gpt2_medium_overrides = (
        ('decoder_embed_dim', 1280),
        ('decoder_ffn_embed_dim', 5120),
        ('decoder_layers', 36),
        ('decoder_attention_heads', 20),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_fn', 'gelu'),
    )
    for attr_name, fallback in gpt2_medium_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gpt2_big')
def transformer_lm_gpt2_big(args):
    """GPT-2 "big" configuration in this file: 48 layers, 1600-dim
    embeddings, 6400-dim FFN, 25 heads, GELU; remaining defaults from
    base_lm_architecture.

    Values already present on ``args`` are kept; ``args`` is modified in place.
    """
    gpt2_big_overrides = (
        ('decoder_embed_dim', 1600),
        ('decoder_ffn_embed_dim', 6400),
        ('decoder_layers', 48),
        ('decoder_attention_heads', 25),
        ('dropout', 0.1),
        ('attention_dropout', 0.1),
        ('activation_fn', 'gelu'),
    )
    for attr_name, fallback in gpt2_big_overrides:
        setattr(args, attr_name, getattr(args, attr_name, fallback))
    base_lm_architecture(args)