adding layerdrop code for training, pruning, and readme (#890)
Summary:
TEST 1: EVALUATION TIME WORKS
checked
achieves correct model perplexity: 18.68

TEST 2: TRAINING NEW MODEL WORKS
checked

without layerdrop:
--decoder-layerdrop 0 OR no flag at all
| epoch 001:     10 / 11201 loss=27.469, nll_loss=27.469, ppl=185799477.36, wps=1764, ups=0, wpb=9216.000, bsz=3.000, num_updates=7, lr=0.0004376, gnorm=25.471, clip=1.000, oom=0.000, loss_scale=8.000, wall=37, train_wall=30
| epoch 001:     20 / 11201 loss=27.443, nll_loss=27.443, ppl=182500427.22, wps=2449, ups=0, wpb=9216.000, bsz=3.000, num_updates=17, lr=0.0010626, gnorm=25.273, clip=1.000, oom=0.000, loss_scale=8.000, wall=64, train_wall=57
| epoch 001:     30 / 11201 loss=27.404, nll_loss=27.404, ppl=177612215.78, wps=2720, ups=0, wpb=9216.000, bsz=3.000, num_updates=27, lr=0.0016876, gnorm=25.136, clip=1.000, oom=0.000, loss_scale=8.000, wall=91, train_wall=84
| epoch 001:     40 / 11201 loss=27.009, nll_loss=27.009, ppl=135079983.00, wps=2865, ups=0, wpb=9216.000, bsz=3.000, num_updates=37, lr=0.0023126, gnorm=24.311, clip=1.000, oom=0.000, loss_scale=8.000, wall=119, train_wall=112
| epoch 001:     50 / 11201 loss=26.418, nll_loss=26.418, ppl=89680259.41, wps=2952, ups=0, wpb=9216.000, bsz=3.000, num_updates=47, lr=0.0029376, gnorm=22.775, clip=1.000, oom=0.000, loss_scale=8.000, wall=147, train_wall=140

with layerdrop (regularization effect should be seen in PPL):
--decoder-layerdrop 0.2

| epoch 001:     10 / 11201 loss=25.186, nll_loss=25.186, ppl=38182937.27, wps=2428, ups=0, wpb=9216.000, bsz=3.000, num_updates=8, lr=0.0005001, gnorm=17.082, clip=1.000, oom=0.000, loss_scale=16.000, wall=30, train_wall=24
| epoch 001:     20 / 11201 loss=25.270, nll_loss=25.270, ppl=40451933.50, wps=3173, ups=0, wpb=9216.000, bsz=3.000, num_updates=18, lr=0.0011251, gnorm=17.162, clip=1.000, oom=0.000, loss_scale=16.000, wall=52, train_wall=45
| epoch 001:     30 / 11201 loss=25.349, nll_loss=25.349, ppl=42752256.68, wps=3454, ups=0, wpb=9216.000, bsz=3.000, num_updates=28, lr=0.0017501, gnorm=17.370, clip=1.000, oom=0.000, loss_scale=16.000, wall=75, train_wall=68
| epoch 001:     40 / 11201 loss=25.115, nll_loss=25.115, ppl=36343806.30, wps=3619, ups=0, wpb=9216.000, bsz=3.000, num_updates=38, lr=0.0023751, gnorm=16.945, clip=1.000, oom=0.000, loss_scale=16.000, wall=97, train_wall=90
| epoch 001:     50 / 11201 loss=24.804, nll_loss=24.804, ppl=29284345.78, wps=3716, ups=0, wpb=9216.000, bsz=3.000, num_updates=48, lr=0.0030001, gnorm=16.406, clip=1.000, oom=0.000, loss_scale=16.000, wall=119, train_wall=112

TEST 3: PICKING UP TRAINING FROM EXISTING MODEL
checked

| loaded checkpoint /checkpoint/angelafan/structured_0.1_block_8_sd02/checkpoint_last.pt (epoch 272 @ 381066 updates)
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train

TEST 4: EVALUATING EXISTING BERT MODEL REPRODUCES RESULTS
| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy:  0.9231651376146789
achieves the correct SST-2 accuracy for this model

TEST 5: TRAINING NEW BERT MODEL WORKS
checked and works

TEST 6: NMT

without layerdrop
--encoder-layerdrop 0 --decoder-layerdrop 0, or any combination of one flag specified and the other left unset

| epoch 001:     10 / 92203 loss=15.820, nll_loss=15.830, ppl=58267.93, wps=4902, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=7.207, clip=0.000, oom=0.000, loss_scale=128.000, wall=60, train_wall=3
| epoch 001:     20 / 92203 loss=15.523, nll_loss=15.501, ppl=46359.29, wps=5037, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.869, clip=0.000, oom=0.000, loss_scale=128.000, wall=63, train_wall=6
| epoch 001:     30 / 92203 loss=15.185, nll_loss=15.123, ppl=35695.79, wps=5085, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.186, clip=0.000, oom=0.000, loss_scale=128.000, wall=66, train_wall=9
| epoch 001:     40 / 92203 loss=14.940, nll_loss=14.849, ppl=29505.60, wps=5116, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=5.610, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=12
| epoch 001:     50 / 92203 loss=14.745, nll_loss=14.630, ppl=25346.87, wps=5070, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.104, clip=0.000, oom=0.000, loss_scale=128.000, wall=71, train_wall=15

with layerdrop (regularization effect should be seen in PPL)

A) works with --encoder-layerdrop 0.2 --decoder-layerdrop 0.2
B) works with different settings --encoder-layerdrop 0.3 --decoder-layerdrop 0.5
C) works with one on and one off --encoder-layerdrop 0.2 --decoder-layerdrop 0

| epoch 001:     10 / 92203 loss=15.817, nll_loss=15.828, ppl=58158.54, wps=5355, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=6.959, clip=0.000, oom=0.000, loss_scale=128.000, wall=59, train_wall=3
| epoch 001:     20 / 92203 loss=15.650, nll_loss=15.641, ppl=51111.63, wps=5515, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.825, clip=0.000, oom=0.000, loss_scale=128.000, wall=61, train_wall=6
| epoch 001:     30 / 92203 loss=15.440, nll_loss=15.408, ppl=43491.58, wps=5602, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.576, clip=0.000, oom=0.000, loss_scale=128.000, wall=64, train_wall=8
| epoch 001:     40 / 92203 loss=15.247, nll_loss=15.193, ppl=37457.14, wps=5676, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=6.124, clip=0.000, oom=0.000, loss_scale=128.000, wall=67, train_wall=11
| epoch 001:     50 / 92203 loss=15.055, nll_loss=14.977, ppl=32259.92, wps=5598, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.661, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=14

TEST 7: PRUNING TESTCASES

A) after adding the pruning flags, the model can still be evaluated as a full model
checked, reaches correct PPL
num. model params: 246933504
| Evaluated 217646 tokens in 196.3s (1108.99 tokens/s)
| Loss: 2.9275, Perplexity: 18.68

B) after adding the pruning flags, the model can be pruned; this works with multiple flag settings
checked three cases:
num. model params: 146163712
| Evaluated 217646 tokens in 106.0s (2054.07 tokens/s)
| Loss: 3.0932, Perplexity: 22.05

num. model params: 209144832
| Evaluated 217646 tokens in 162.8s (1336.99 tokens/s)
| Loss: 2.9526, Perplexity: 19.16

C) the model can resume training if you want to fine-tune the pruned model
checked:
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train
| WARNING: overflow detected, setting loss scale to: 64.0
| WARNING: overflow detected, setting loss scale to: 32.0
| epoch 272:   1500 / 5601 loss=5.015, nll_loss=5.015, ppl=32.33, wps=11598, ups=1, wpb=18432.000, bsz=6.000, num_updates=98, lr=0.0061251, gnorm=0.613, clip=1.000, oom=0.000, loss_scale=32.000, wall=156, train_wall=252396

D) works with BERT
checked:
without specifying any flags, reproduces the correct standard accuracy
with flags, produces the correct pruned accuracy

| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy:  0.9231651376146789

| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop
| Accuracy:  0.9220183486238532
Pull Request resolved: fairinternal/fairseq-py#890

Reviewed By: edunov

Differential Revision: D18094657

Pulled By: huihuifan

fbshipit-source-id: 2bbaa2ff0039e906782694fc2038b8c17a8693e7
huihuifan authored and facebook-github-bot committed Oct 27, 2019
1 parent eb68afc commit dabbef4
Showing 8 changed files with 209 additions and 24 deletions.
66 changes: 66 additions & 0 deletions examples/layerdrop/README.md
@@ -0,0 +1,66 @@
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
This page contains information on how to train models with LayerDrop.

Looking for pretrained models? They will be added shortly.

Looking for code for other forms of Structured Dropout? It will be added shortly.

## Citation:
```bibtex
@article{fan2019reducing,
title={Reducing Transformer Depth on Demand with Structured Dropout},
author={Fan, Angela and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1909.11556},
year={2019}
}
```

## Example usage

To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For decoder-only Language Models, you need only the decoder flag; for RoBERTa, which is encoder-only, you need only the encoder flag. The encoder and decoder LayerDrop values can be set independently.
```
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
```
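
For intuition, the mechanism is simple: during training each layer is skipped with probability equal to the LayerDrop rate, while at inference time every layer runs. The snippet below is a minimal sketch of that idea only (real fairseq encoder/decoder layers also take padding masks and attention state), not the library implementation:
```python
import random

import torch
from torch import nn


def forward_with_layerdrop(layers: nn.ModuleList, x: torch.Tensor,
                           layerdrop: float, training: bool) -> torch.Tensor:
    """Apply a stack of layers, skipping each one with probability
    `layerdrop` during training; at inference time every layer runs."""
    for layer in layers:
        if not training or random.uniform(0, 1) > layerdrop:
            x = layer(x)
    return x
```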

To prune a model that has been trained with LayerDrop, add the following flags, each followed by a comma-separated list of the layers you would like to keep.
```
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
```
Setting these flags should print a message such as:
```
| Pruning model to specified layer configuration
```
You should also see a smaller number of model parameters. For example, the 16-layer Transformer Language Model prints:
```
num. model params: 246933504
```
while the same model pruned to 8 layers prints:
```
num. model params: 146163712
```

If you would like to resume training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A concrete example for language modeling:
```
python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This model override replaces the stored training parameters and updates the model arguments so that the pruned model is run instead of the full model.
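
The same override can also be applied programmatically. A minimal sketch, assuming a placeholder checkpoint path and using the `load_model_ensemble_and_task` helper touched in this commit:
```python
from fairseq import checkpoint_utils

# Load a 16-layer LM checkpoint and prune it to every other decoder layer.
# The override is merged into the checkpoint's stored training args before
# the model is built, so load_state_dict prunes and renumbers the layers.
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
    ['/path/to/model/checkpoint'],
    arg_overrides={'decoder_layers_to_keep': '0,2,4,6,8,10,12,14'},
)
model = models[0]
print(sum(p.numel() for p in model.parameters()))  # smaller than the full model
```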


Looking to reproduce the results in the paper?

1. For Translation on WMT en-de, we followed the setup described [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md)
2. To train RoBERTa, we followed the setup described [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta)
3. To train Language Models on Wikitext-103, we followed the setup described [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)


## Tips

1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop adds too much regularization, so the model may not reach its best performance. Since LayerDrop adds regularization, you may achieve the best results by slightly reducing the amount of standard dropout (for example, reduce by 0.1).

2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2.

3. When pruning layers at inference time, it is best to spread out the remaining layers so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer works well (see the sketch below).
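
For instance, a small illustrative helper (hypothetical, not part of fairseq) for building an evenly spaced `--decoder-layers-to-keep` value:
```python
def evenly_spaced_layers(total_layers: int, keep: int) -> str:
    """Return a comma-separated list of `keep` layer indices spread evenly
    over a stack of `total_layers` layers."""
    step = total_layers / keep
    indices = sorted({int(i * step) for i in range(keep)})
    return ",".join(str(i) for i in indices)


print(evenly_spaced_layers(16, 8))  # -> 0,2,4,6,8,10,12,14
```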

## Having an issue or have a question?

Please open an issue in this repository with the details of your question. Thanks!
66 changes: 65 additions & 1 deletion fairseq/checkpoint_utils.py
Expand Up @@ -183,7 +183,7 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):

# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state['model'], strict=True)
model.load_state_dict(state['model'], strict=True, args=args)
ensemble.append(model)
return ensemble, args, task

Expand Down Expand Up @@ -334,6 +334,70 @@ def _upgrade_state_dict(state):
return state


def prune_state_dict(state_dict, args):
"""Prune the given state_dict if desired for LayerDrop
(https://arxiv.org/abs/1909.11556).
Training with LayerDrop allows models to be robust to pruning at inference
time. This function prunes the state_dict so that a smaller model can be loaded
from a larger model, re-mapping the remaining layer indices accordingly.
It's called by functions that load models from checkpoints and does not
need to be called directly.
"""
if not args:
# args should not be None, but don't crash if it is.
return state_dict

encoder_layers_to_keep = args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
decoder_layers_to_keep = args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None

if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict

# apply pruning
print("| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop")

def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted([int(layer_string) for layer_string in layers_to_keep.split(",")])
mapping_dict = {}
for i in range(len(keep_layers)):
mapping_dict[str(keep_layers[i])] = str(i)
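# e.g. layers_to_keep="0,2,4,6" gives mapping_dict={"0": "0", "2": "1", "4": "2", "6": "3"}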

regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
return {
"substitution_regex": regex,
"mapping_dict": mapping_dict
}

pruning_passes = []
if encoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
if decoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

new_state_dict = {}
for layer_name in state_dict.keys():
match = re.search(r"\.layers\.(\d+)\.", layer_name)
# if layer has no number in it, it is a supporting layer, such as an
# embedding
if not match:
new_state_dict[layer_name] = state_dict[layer_name]
continue

# otherwise, this is a per-layer parameter: keep and renumber it if its layer was kept, drop it otherwise.
original_layer_number = match.group(1)
# figure out which mapping dict to replace from
for pruning_pass in pruning_passes:
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
substitution_match = pruning_pass["substitution_regex"].search(layer_name)
new_state_key = layer_name[:substitution_match.start(1)] + new_layer_number + layer_name[substitution_match.end(1):]
new_state_dict[new_state_key] = state_dict[layer_name]

return new_state_dict


def load_pretrained_component_from_model(
component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
Expand Down
6 changes: 4 additions & 2 deletions fairseq/models/fairseq_model.py
Expand Up @@ -13,6 +13,7 @@
import torch.nn.functional as F

from fairseq import utils
from fairseq.checkpoint_utils import prune_state_dict
from fairseq.data import Dictionary
from fairseq.models import FairseqDecoder, FairseqEncoder

Expand Down Expand Up @@ -58,15 +59,16 @@ def max_positions(self):
"""Maximum length supported by the model."""
return None

def load_state_dict(self, state_dict, strict=True):
def load_state_dict(self, state_dict, strict=True, args=None):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
return super().load_state_dict(state_dict, strict)
new_state_dict = prune_state_dict(state_dict, args)
return super().load_state_dict(new_state_dict, strict)

def upgrade_state_dict(self, state_dict):
"""Upgrade old state dicts to work with newer code."""
Expand Down
15 changes: 15 additions & 0 deletions fairseq/models/roberta/model.py
Expand Up @@ -78,6 +78,11 @@ def add_args(parser):
help='number of positional embeddings to learn')
parser.add_argument('--load-checkpoint-heads', action='store_true',
help='(re-)register and load heads when loading checkpoints')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')

@classmethod
def build_model(cls, args, task):
Expand Down Expand Up @@ -245,6 +250,15 @@ class RobertaEncoder(FairseqDecoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args

# RoBERTa is a sentence encoder model, so users will intuitively prune
# encoder layers. However, the implementation is built on the fairseq decoder,
# so we remap the encoder arguments to their decoder equivalents here.
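# e.g. --encoder-layers-to-keep 0,2,4 becomes encoder_layers=3 and decoder_layers_to_keep='0,2,4'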
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
args.decoder_layers_to_keep = args.encoder_layers_to_keep
args.encoder_layers_to_keep = None

self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=dictionary.pad(),
vocab_size=len(dictionary),
Expand All @@ -255,6 +269,7 @@ def __init__(self, args, dictionary):
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
layerdrop=args.encoder_layerdrop,
max_seq_len=args.max_positions,
num_segments=0,
encoder_normalize_before=True,
Expand Down
57 changes: 40 additions & 17 deletions fairseq/models/transformer.py
Expand Up @@ -25,6 +25,7 @@
TransformerDecoderLayer,
TransformerEncoderLayer,
)
import random

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
Expand Down Expand Up @@ -130,6 +131,15 @@ def add_args(parser):
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# fmt: on

@classmethod
Expand All @@ -139,6 +149,11 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_architecture(args)

if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

if not hasattr(args, 'max_source_positions'):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
Expand Down Expand Up @@ -275,6 +290,7 @@ def __init__(self, args, dictionary, embed_tokens):
self.register_buffer('version', torch.Tensor([3]))

self.dropout = args.dropout
self.encoder_layerdrop = args.encoder_layerdrop

embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
Expand All @@ -300,6 +316,7 @@ def __init__(self, args, dictionary, embed_tokens):
else:
self.layer_norm = None


def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
Expand Down Expand Up @@ -345,9 +362,12 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa

# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
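# during training each layer is skipped with probability encoder_layerdrop; at inference every layer is applied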
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.encoder_layerdrop):
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)

if self.layer_norm:
x = self.layer_norm(x)
Expand Down Expand Up @@ -435,6 +455,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.register_buffer('version', torch.Tensor([3]))

self.dropout = args.dropout
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed

input_embed_dim = embed_tokens.embedding_dim
Expand Down Expand Up @@ -594,20 +615,22 @@ def extract_features(
else:
self_attn_mask = None

x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)

inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.decoder_layerdrop):
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()

if attn is not None:
if alignment_heads is not None:
Expand Down
8 changes: 8 additions & 0 deletions fairseq/models/transformer_lm.py
Expand Up @@ -98,6 +98,11 @@ def add_args(parser):
help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# fmt: on

@classmethod
Expand All @@ -107,6 +112,9 @@ def build_model(cls, args, task):
# make sure all arguments are present in older models
base_lm_architecture(args)

if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

if getattr(args, 'max_target_positions', None) is None:
args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

Expand Down
13 changes: 10 additions & 3 deletions fairseq/modules/transformer_sentence_encoder.py
Expand Up @@ -14,6 +14,7 @@
PositionalEmbedding,
TransformerSentenceEncoderLayer,
)
import random


def init_bert_params(module):
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
layerdrop: float = 0.0,
max_seq_len: int = 256,
num_segments: int = 2,
use_position_embeddings: bool = True,
Expand All @@ -97,6 +99,7 @@ def __init__(
self.padding_idx = padding_idx
self.vocab_size = vocab_size
self.dropout = dropout
self.layerdrop = layerdrop
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
self.num_segments = num_segments
Expand Down Expand Up @@ -208,9 +211,13 @@ def forward(
inner_states.append(x)

for layer in self.layers:
x, _ = layer(x, self_attn_padding_mask=padding_mask)
if not last_state_only:
inner_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.layerdrop):
x, _ = layer(x, self_attn_padding_mask=padding_mask)
if not last_state_only:
inner_states.append(x)


# T x B x C -> B x T x C
x = x.transpose(0, 1)
Expand Down
2 changes: 1 addition & 1 deletion fairseq/trainer.py
Expand Up @@ -181,7 +181,7 @@ def load_checkpoint(

# load model parameters
try:
self.get_model().load_state_dict(state['model'], strict=True)
self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(state['criterion'], strict=True)
except Exception:
Expand Down
