From dabbef467692ef4ffb7de8a01235876bd7320a93 Mon Sep 17 00:00:00 2001 From: Angela Fan Date: Sun, 27 Oct 2019 12:09:29 -0700 Subject: [PATCH] adding layerdrop code for training, pruning, and readme (#890) Summary: TEST 1: EVALUATION TIME WORKS checked achieves correct model perplexity: 18.68 TEST 2: TRAINING NEW MODEL WORKS checked without layerdrop: --decoder-layerdrop 0 OR no flag at all | epoch 001: 10 / 11201 loss=27.469, nll_loss=27.469, ppl=185799477.36, wps=1764, ups=0, wpb=9216.000, bsz=3.000, num_updates=7, lr=0.0004376, gnorm=25.471, clip=1.000, oom=0.000, loss_scale=8.000, wall=37, train_wall=30 | epoch 001: 20 / 11201 loss=27.443, nll_loss=27.443, ppl=182500427.22, wps=2449, ups=0, wpb=9216.000, bsz=3.000, num_updates=17, lr=0.0010626, gnorm=25.273, clip=1.000, oom=0.000, loss_scale=8.000, wall=64, train_wall=57 | epoch 001: 30 / 11201 loss=27.404, nll_loss=27.404, ppl=177612215.78, wps=2720, ups=0, wpb=9216.000, bsz=3.000, num_updates=27, lr=0.0016876, gnorm=25.136, clip=1.000, oom=0.000, loss_scale=8.000, wall=91, train_wall=84 | epoch 001: 40 / 11201 loss=27.009, nll_loss=27.009, ppl=135079983.00, wps=2865, ups=0, wpb=9216.000, bsz=3.000, num_updates=37, lr=0.0023126, gnorm=24.311, clip=1.000, oom=0.000, loss_scale=8.000, wall=119, train_wall=112 | epoch 001: 50 / 11201 loss=26.418, nll_loss=26.418, ppl=89680259.41, wps=2952, ups=0, wpb=9216.000, bsz=3.000, num_updates=47, lr=0.0029376, gnorm=22.775, clip=1.000, oom=0.000, loss_scale=8.000, wall=147, train_wall=140 with layerdrop (regularization effect should be seen in PPL): --decoder-layerdrop 0.2 | epoch 001: 10 / 11201 loss=25.186, nll_loss=25.186, ppl=38182937.27, wps=2428, ups=0, wpb=9216.000, bsz=3.000, num_updates=8, lr=0.0005001, gnorm=17.082, clip=1.000, oom=0.000, loss_scale=16.000, wall=30, train_wall=24 | epoch 001: 20 / 11201 loss=25.270, nll_loss=25.270, ppl=40451933.50, wps=3173, ups=0, wpb=9216.000, bsz=3.000, num_updates=18, lr=0.0011251, gnorm=17.162, clip=1.000, oom=0.000, loss_scale=16.000, wall=52, train_wall=45 | epoch 001: 30 / 11201 loss=25.349, nll_loss=25.349, ppl=42752256.68, wps=3454, ups=0, wpb=9216.000, bsz=3.000, num_updates=28, lr=0.0017501, gnorm=17.370, clip=1.000, oom=0.000, loss_scale=16.000, wall=75, train_wall=68 | epoch 001: 40 / 11201 loss=25.115, nll_loss=25.115, ppl=36343806.30, wps=3619, ups=0, wpb=9216.000, bsz=3.000, num_updates=38, lr=0.0023751, gnorm=16.945, clip=1.000, oom=0.000, loss_scale=16.000, wall=97, train_wall=90 | epoch 001: 50 / 11201 loss=24.804, nll_loss=24.804, ppl=29284345.78, wps=3716, ups=0, wpb=9216.000, bsz=3.000, num_updates=48, lr=0.0030001, gnorm=16.406, clip=1.000, oom=0.000, loss_scale=16.000, wall=119, train_wall=112 TEST 3: PICKING UP TRAINING FROM EXISTING MODEL checked | loaded checkpoint /checkpoint/angelafan/structured_0.1_block_8_sd02/checkpoint_last.pt (epoch 272 @ 381066 updates) | loading train data for epoch 272 | loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train TEST 4: EVALUATING EXISTING BERT MODEL REPROS RESULTS | [input] dictionary: 50265 types | [label] dictionary: 9 types | Accuracy: 0.9231651376146789 achieves correct accuracy on SST2 for this model TEST 5: TRAINING NEW BERT MODEL WORKS checked and works TEST 6: NMT without layerdrop --encoder-layerdrop 0 --decoder-layerdrop 0 OR combinations of flag specified and not specified | epoch 001: 10 / 92203 loss=15.820, nll_loss=15.830, ppl=58267.93, wps=4902, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=7.207, clip=0.000, oom=0.000, loss_scale=128.000, wall=60, train_wall=3 | epoch 001: 20 / 92203 loss=15.523, nll_loss=15.501, ppl=46359.29, wps=5037, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.869, clip=0.000, oom=0.000, loss_scale=128.000, wall=63, train_wall=6 | epoch 001: 30 / 92203 loss=15.185, nll_loss=15.123, ppl=35695.79, wps=5085, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.186, clip=0.000, oom=0.000, loss_scale=128.000, wall=66, train_wall=9 | epoch 001: 40 / 92203 loss=14.940, nll_loss=14.849, ppl=29505.60, wps=5116, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=5.610, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=12 | epoch 001: 50 / 92203 loss=14.745, nll_loss=14.630, ppl=25346.87, wps=5070, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.104, clip=0.000, oom=0.000, loss_scale=128.000, wall=71, train_wall=15 with layerdrop (regularization effect should be seen in PPL) A) works with --encoder-layerdrop 0.2 --decoder-layerdrop 0.2 B) works with different settings --encoder-layerdrop 0.3 --decoder-layerdrop 0.5 C) works with one on and one off --encoder-layerdrop 0.2 --decoder-layerdrop 0 | epoch 001: 10 / 92203 loss=15.817, nll_loss=15.828, ppl=58158.54, wps=5355, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=6.959, clip=0.000, oom=0.000, loss_scale=128.000, wall=59, train_wall=3 | epoch 001: 20 / 92203 loss=15.650, nll_loss=15.641, ppl=51111.63, wps=5515, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.825, clip=0.000, oom=0.000, loss_scale=128.000, wall=61, train_wall=6 | epoch 001: 30 / 92203 loss=15.440, nll_loss=15.408, ppl=43491.58, wps=5602, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.576, clip=0.000, oom=0.000, loss_scale=128.000, wall=64, train_wall=8 | epoch 001: 40 / 92203 loss=15.247, nll_loss=15.193, ppl=37457.14, wps=5676, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=6.124, clip=0.000, oom=0.000, loss_scale=128.000, wall=67, train_wall=11 | epoch 001: 50 / 92203 loss=15.055, nll_loss=14.977, ppl=32259.92, wps=5598, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.661, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=14 TEST 7: PRUNING TESTCASES A) after adding the pruning flags, model can evaluate as a full model checked, reaches correct PPL num. model params: 246933504 | Evaluated 217646 tokens in 196.3s (1108.99 tokens/s) | Loss: 2.9275, Perplexity: 18.68 B) after adding pruning flags, model can be pruned. this works with multiple flag settings checked three cases: num. model params: 146163712 | Evaluated 217646 tokens in 106.0s (2054.07 tokens/s) | Loss: 3.0932, Perplexity: 22.05 num. model params: 209144832 | Evaluated 217646 tokens in 162.8s (1336.99 tokens/s) | Loss: 2.9526, Perplexity: 19.16 C) model can pick up training if you want to finetune the pruned model checked: | loading train data for epoch 272 | loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train | WARNING: overflow detected, setting loss scale to: 64.0 | WARNING: overflow detected, setting loss scale to: 32.0 | epoch 272: 1500 / 5601 loss=5.015, nll_loss=5.015, ppl=32.33, wps=11598, ups=1, wpb=18432.000, bsz=6.000, num_updates=98, lr=0.0061251, gnorm=0.613, clip=1.000, oom=0.000, loss_scale=32.000, wall=156, train_wall=252396 D) works with BERT checked: without specifying any flags, reproduces the correct standard accuracy with flags, produces the correct pruned accuracy | [input] dictionary: 50265 types | [label] dictionary: 9 types | Accuracy: 0.9231651376146789 | [input] dictionary: 50265 types | [label] dictionary: 9 types | Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop | Accuracy: 0.9220183486238532 Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/890 Reviewed By: edunov Differential Revision: D18094657 Pulled By: huihuifan fbshipit-source-id: 2bbaa2ff0039e906782694fc2038b8c17a8693e7 --- examples/layerdrop/README.md | 66 +++++++++++++++++++ fairseq/checkpoint_utils.py | 66 ++++++++++++++++++- fairseq/models/fairseq_model.py | 6 +- fairseq/models/roberta/model.py | 15 +++++ fairseq/models/transformer.py | 57 +++++++++++----- fairseq/models/transformer_lm.py | 8 +++ .../modules/transformer_sentence_encoder.py | 13 +++- fairseq/trainer.py | 2 +- 8 files changed, 209 insertions(+), 24 deletions(-) create mode 100644 examples/layerdrop/README.md diff --git a/examples/layerdrop/README.md b/examples/layerdrop/README.md new file mode 100644 index 0000000000..82ec4b6d53 --- /dev/null +++ b/examples/layerdrop/README.md @@ -0,0 +1,66 @@ +# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019) +This page contains information for how to train models with LayerDrop. + +Looking for pretrained models? They will be added shortly. + +Looking for code for other forms of Structured Dropout? It will be added shortly. + +## Citation: +```bibtex +@article{fan2019reducing, + title={Reducing Transformer Depth on Demand with Structured Dropout}, + author={Fan, Angela and Grave, Edouard and Joulin, Armand}, + journal={arXiv preprint arXiv:1909.11556}, + year={2019} +} +``` + +## Example usage + +To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For Language Models that are decoder-only, you need only the decoder flag. For RoBERTa, an encoder, you need only the encoder flag. The encoder and decoder LayerDrop values can be set differently. +``` +--encoder-layerdrop 0.2 --decoder-layerdrop 0.2 +``` + +To prune a model that has been trained with LayerDrop, add the following flags followed by a comma separated list of which layers you would like to keep. +``` +--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14 +``` +Setting these flags should print a message such as: +``` +| Pruning model to specified layer configuration +``` +You should also see a smaller number of parameters in the model, for example the 16-Layer Transformer Language Model prints: +``` +num. model params: 246933504 +``` +while a model pruned to 8 Layers prints: +``` +num. model params: 146163712 +``` + +If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling: +``` +python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}" +``` +This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model. + + +Looking to reproduce the results in the paper? + +1. For Translation on WMT en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md) +2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta) +3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) + + +## Tips + +1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop will mean the model has too much regularization, so may not reach the best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce by 0.1). + +2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2. + +3. When pruning layers at inference time, it is best to spread out the layers remaining so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is good. + +## Having an issue or have a question? + +Please open an issue in this repository with the details of your question. Thanks! diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index ded8ce32f5..abf1bcc65f 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -183,7 +183,7 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None): # build model for ensemble model = task.build_model(args) - model.load_state_dict(state['model'], strict=True) + model.load_state_dict(state['model'], strict=True, args=args) ensemble.append(model) return ensemble, args, task @@ -334,6 +334,70 @@ def _upgrade_state_dict(state): return state +def prune_state_dict(state_dict, args): + """Prune the given state_dict if desired for LayerDrop + (https://arxiv.org/abs/1909.11556). + + Training with LayerDrop allows models to be robust to pruning at inference + time. This function prunes state_dict to allow smaller models to be loaded + from a larger model and re-maps the existing state_dict for this to occur. + + It's called by functions that load models from checkpoints and does not + need to be called directly. + """ + if not args: + # args should not be none, but don't crash if it is. + return state_dict + + encoder_layers_to_keep = args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None + decoder_layers_to_keep = args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None + + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + # apply pruning + print("| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop") + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted([int(layer_string) for layer_string in layers_to_keep.split(",")]) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + return { + "substitution_regex": regex, + "mapping_dict": mapping_dict + } + + pruning_passes = [] + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + new_state_dict = {} + for layer_name in state_dict.keys(): + match = re.search("\.layers\.(\d+)\.", layer_name) + # if layer has no number in it, it is a supporting layer, such as an + # embedding + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + # otherwise, layer should be pruned. + original_layer_number = match.group(1) + # figure out which mapping dict to replace from + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name): + new_layer_number = pruning_pass["mapping_dict"][original_layer_number] + substitution_match = pruning_pass["substitution_regex"].search(layer_name) + new_state_key = layer_name[:substitution_match.start(1)] + new_layer_number + layer_name[substitution_match.end(1):] + new_state_dict[new_state_key] = state_dict[layer_name] + + return new_state_dict + + def load_pretrained_component_from_model( component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str ): diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index bd73bd5c23..2d9e942d3b 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from fairseq import utils +from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary from fairseq.models import FairseqDecoder, FairseqEncoder @@ -58,7 +59,7 @@ def max_positions(self): """Maximum length supported by the model.""" return None - def load_state_dict(self, state_dict, strict=True): + def load_state_dict(self, state_dict, strict=True, args=None): """Copies parameters and buffers from *state_dict* into this module and its descendants. @@ -66,7 +67,8 @@ def load_state_dict(self, state_dict, strict=True): this additionally "upgrades" *state_dicts* from old checkpoints. """ self.upgrade_state_dict(state_dict) - return super().load_state_dict(state_dict, strict) + new_state_dict = prune_state_dict(state_dict, args) + return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): """Upgrade old state dicts to work with newer code.""" diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index ac94a04845..1c4c243e7e 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -78,6 +78,11 @@ def add_args(parser): help='number of positional embeddings to learn') parser.add_argument('--load-checkpoint-heads', action='store_true', help='(re-)register and load heads when loading checkpoints') + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for encoder') + parser.add_argument('--encoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') @classmethod def build_model(cls, args, task): @@ -245,6 +250,15 @@ class RobertaEncoder(FairseqDecoder): def __init__(self, args, dictionary): super().__init__(dictionary) self.args = args + + # RoBERTa is a sentence encoder model, so users will intuitively trim + # encoder layers. However, the implementation uses the fairseq decoder, + # so we fix here. + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + args.decoder_layers_to_keep = args.encoder_layers_to_keep + args.encoder_layers_to_keep = None + self.sentence_encoder = TransformerSentenceEncoder( padding_idx=dictionary.pad(), vocab_size=len(dictionary), @@ -255,6 +269,7 @@ def __init__(self, args, dictionary): dropout=args.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, + layerdrop=args.encoder_layerdrop, max_seq_len=args.max_positions, num_segments=0, encoder_normalize_before=True, diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index f5f23f1b95..573c41373b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -25,6 +25,7 @@ TransformerDecoderLayer, TransformerEncoderLayer, ) +import random DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -130,6 +131,15 @@ def add_args(parser): help='perform cross+self-attention') parser.add_argument('--layer-wise-attention', default=False, action='store_true', help='perform layer-wise attention (cross-attention or cross+self-attention)') + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for encoder') + parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for decoder') + parser.add_argument('--encoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument('--decoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') # fmt: on @classmethod @@ -139,6 +149,11 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + if not hasattr(args, 'max_source_positions'): args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS if not hasattr(args, 'max_target_positions'): @@ -275,6 +290,7 @@ def __init__(self, args, dictionary, embed_tokens): self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout + self.encoder_layerdrop = args.encoder_layerdrop embed_dim = embed_tokens.embedding_dim self.padding_idx = embed_tokens.padding_idx @@ -300,6 +316,7 @@ def __init__(self, args, dictionary, embed_tokens): else: self.layer_norm = None + def forward_embedding(self, src_tokens): # embed tokens and positions embed = self.embed_scale * self.embed_tokens(src_tokens) @@ -345,9 +362,12 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa # encoder layers for layer in self.layers: - x = layer(x, encoder_padding_mask) - if return_all_hiddens: - encoder_states.append(x) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not self.training or (dropout_probability > self.encoder_layerdrop): + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) if self.layer_norm: x = self.layer_norm(x) @@ -435,6 +455,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.register_buffer('version', torch.Tensor([3])) self.dropout = args.dropout + self.decoder_layerdrop = args.decoder_layerdrop self.share_input_output_embed = args.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim @@ -594,20 +615,22 @@ def extract_features( else: self_attn_mask = None - x, layer_attn = layer( - x, - encoder_state, - encoder_out['encoder_padding_mask'] if encoder_out is not None else None, - incremental_state, - self_attn_mask=self_attn_mask, - self_attn_padding_mask=self_attn_padding_mask, - need_attn=(idx == alignment_layer), - need_head_weights=(idx == alignment_layer), - ) - - inner_states.append(x) - if layer_attn is not None and idx == alignment_layer: - attn = layer_attn.float() + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not self.training or (dropout_probability > self.decoder_layerdrop): + x, layer_attn = layer( + x, + encoder_state, + encoder_out['encoder_padding_mask'] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=(idx == alignment_layer), + need_head_weights=(idx == alignment_layer), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float() if attn is not None: if alignment_heads is not None: diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 87c7719209..f04dd36032 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -98,6 +98,11 @@ def add_args(parser): help='if set, ties the projection weights of adaptive softmax and adaptive input') parser.add_argument('--decoder-learned-pos', action='store_true', help='use learned positional embeddings in the decoder') + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for decoder') + parser.add_argument('--decoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') # fmt: on @classmethod @@ -107,6 +112,9 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_lm_architecture(args) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + if getattr(args, 'max_target_positions', None) is None: args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 9be7ab3080..f7e3973080 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -14,6 +14,7 @@ PositionalEmbedding, TransformerSentenceEncoderLayer, ) +import random def init_bert_params(module): @@ -77,6 +78,7 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, + layerdrop : float = 0.0, max_seq_len: int = 256, num_segments: int = 2, use_position_embeddings: bool = True, @@ -97,6 +99,7 @@ def __init__( self.padding_idx = padding_idx self.vocab_size = vocab_size self.dropout = dropout + self.layerdrop = layerdrop self.max_seq_len = max_seq_len self.embedding_dim = embedding_dim self.num_segments = num_segments @@ -208,9 +211,13 @@ def forward( inner_states.append(x) for layer in self.layers: - x, _ = layer(x, self_attn_padding_mask=padding_mask) - if not last_state_only: - inner_states.append(x) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not self.training or (dropout_probability > self.layerdrop): + x, _ = layer(x, self_attn_padding_mask=padding_mask) + if not last_state_only: + inner_states.append(x) + # T x B x C -> B x T x C x = x.transpose(0, 1) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 545357ebef..5de30e2246 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -181,7 +181,7 @@ def load_checkpoint( # load model parameters try: - self.get_model().load_state_dict(state['model'], strict=True) + self.get_model().load_state_dict(state['model'], strict=True, args=self.args) if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict(state['criterion'], strict=True) except Exception: