From 4ac2c5f2cc8a8b1f221f1e8e9b7839f07c25d997 Mon Sep 17 00:00:00 2001
From: Stephan Peitz
Date: Sun, 29 Sep 2019 05:08:24 -0700
Subject: [PATCH] Implementation of the WeCNLP abstract "Cross+Self-Attention for Transformer Models" (#1097)

Summary:
This PR implements a new attention module which combines cross-attention
(encoder-decoder attention) with decoder self-attention. This work was
accepted as an abstract at WeCNLP 2019 (https://www.wecnlp.ai/wecnlp-2019).

Cross+Self-Attention reduces the number of parameters and increases inference
speed without any degradation in translation quality.

More details can be found in the attached
[abstract](https://github.com/pytorch/fairseq/files/3561282/paper.pdf).

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1097

Differential Revision: D17653168

Pulled By: myleott

fbshipit-source-id: deb834c2c78a229d7418ffbfea20ba3ce252991c
---
 fairseq/models/transformer.py          | 66 ++++++++++++++++++++++++--
 fairseq/modules/multihead_attention.py | 10 +++-
 fairseq/modules/transformer_layer.py   | 34 ++++++++++---
 tests/test_binaries.py                 | 22 ++++++++-
 4 files changed, 120 insertions(+), 12 deletions(-)

diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py
index dd10ae5357..910c2eda09 100644
--- a/fairseq/models/transformer.py
+++ b/fairseq/models/transformer.py
@@ -122,6 +122,13 @@ def add_args(parser):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
+        parser.add_argument('--no-cross-attention', default=False, action='store_true',
+                            help='do not perform cross-attention')
+        parser.add_argument('--cross-self-attention', default=False, action='store_true',
+                            help='perform cross+self-attention')
+        parser.add_argument('--layer-wise-attention', default=False, action='store_true',
+                            help='perform layer-wise attention (cross-attention or cross+self-attention)')
         # fmt: on

     @classmethod
@@ -180,7 +187,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):

     @classmethod
     def build_decoder(cls, args, tgt_dict, embed_tokens):
-        return TransformerDecoder(args, tgt_dict, embed_tokens)
+        return TransformerDecoder(
+            args,
+            tgt_dict,
+            embed_tokens,
+            no_encoder_attn=getattr(args, 'no_cross_attention', False),
+        )


 class TransformerEncoder(FairseqEncoder):
@@ -211,6 +223,8 @@ def __init__(self, args, dictionary, embed_tokens):
             learned=args.encoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None

+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerEncoderLayer(args)
@@ -230,13 +244,15 @@ def forward_embedding(self, src_tokens):
         x = F.dropout(x, p=self.dropout, training=self.training)
         return x, embed

-    def forward(self, src_tokens, src_lengths, cls_input=None):
+    def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
             src_lengths (torch.LongTensor): lengths of each source sentence of
                 shape `(batch)`
+            return_all_hiddens (bool, optional): also return all of the
+                intermediate hidden states (default: False).

         Returns:
             dict:
@@ -244,7 +260,13 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_states** (List[Tensor]): all intermediate
+                  hidden states of shape `(src_len, batch, embed_dim)`.
+                  Only populated if *return_all_hiddens* is True.
         """
+        if self.layer_wise_attention:
+            return_all_hiddens = True
+
         x, encoder_embedding = self.forward_embedding(src_tokens)

         # B x T x C -> T x B x C
@@ -255,17 +277,24 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
         if not encoder_padding_mask.any():
             encoder_padding_mask = None

+        encoder_states = [] if return_all_hiddens else None
+
         # encoder layers
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
+            if return_all_hiddens:
+                encoder_states.append(x)

         if self.layer_norm:
             x = self.layer_norm(x)
+            if return_all_hiddens:
+                encoder_states[-1] = x

         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
             'encoder_embedding': encoder_embedding,  # B x T x C
+            'encoder_states': encoder_states,  # List[T x B x C]
         }

     def reorder_encoder_out(self, encoder_out, new_order):
@@ -285,6 +314,9 @@ def reorder_encoder_out(self, encoder_out, new_order):
         if encoder_out['encoder_padding_mask'] is not None:
             encoder_out['encoder_padding_mask'] = \
                 encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        if encoder_out.get('encoder_states', None) is not None:
+            for idx, state in enumerate(encoder_out['encoder_states']):
+                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
         return encoder_out

     def max_positions(self):
@@ -293,6 +325,14 @@ def max_positions(self):
             return self.max_source_positions
         return min(self.max_source_positions, self.embed_positions.max_positions())

+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
     def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -350,6 +390,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
             learned=args.decoder_learned_pos,
         ) if not args.no_token_positional_embeddings else None

+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
+        self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
+
         self.layers = nn.ModuleList([])
         self.layers.extend([
             TransformerDecoderLayer(args, no_encoder_attn)
@@ -435,14 +478,26 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta

         inner_states = [x]

+        self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
+        if not self_attn_padding_mask.any() and not self.cross_self_attention:
+            self_attn_padding_mask = None
+
         # decoder layers
-        for layer in self.layers:
+        for idx, layer in enumerate(self.layers):
+            encoder_state = None
+            if encoder_out is not None:
+                if self.layer_wise_attention:
+                    encoder_state = encoder_out['encoder_states'][idx]
+                else:
+                    encoder_state = encoder_out['encoder_out']
+
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_state,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
                 self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
+                self_attn_padding_mask=self_attn_padding_mask,
             )
             inner_states.append(x)
@@ -553,6 +608,9 @@ def base_architecture(args):
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
     args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
     args.adaptive_input = getattr(args, 'adaptive_input', False)
+    args.no_cross_attention = getattr(args, 'no_cross_attention', False)
+    args.cross_self_attention = getattr(args, 'cross_self_attention', False)
+    args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py
index 8c28255dfb..9aaea82484 100644
--- a/fairseq/modules/multihead_attention.py
+++ b/fairseq/modules/multihead_attention.py
@@ -186,8 +186,15 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
                     v = prev_value
                 else:
                     v = torch.cat((prev_value, v), dim=1)
+            if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+                prev_key_padding_mask = saved_state['prev_key_padding_mask']
+                if static_kv:
+                    key_padding_mask = prev_key_padding_mask
+                else:
+                    key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
             saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
             saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state['prev_key_padding_mask'] = key_padding_mask

             self._set_input_buffer(incremental_state, saved_state)

@@ -311,7 +318,8 @@ def reorder_incremental_state(self, incremental_state, new_order):
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
-                input_buffer[k] = input_buffer[k].index_select(0, new_order)
+                if input_buffer[k] is not None:
+                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
             self._set_input_buffer(incremental_state, input_buffer)

     def _get_input_buffer(self, incremental_state):
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py
index f4a80cceea..63c6cdf552 100644
--- a/fairseq/modules/transformer_layer.py
+++ b/fairseq/modules/transformer_layer.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils
@@ -134,13 +135,14 @@ class TransformerDecoderLayer(nn.Module):
     def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
+        self.cross_self_attention = getattr(args, 'cross_self_attention', False)
         self.self_attn = MultiheadAttention(
             embed_dim=self.embed_dim,
             num_heads=args.decoder_attention_heads,
             dropout=args.attention_dropout,
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
-            self_attention=True
+            self_attention=not self.cross_self_attention,
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(
@@ -208,13 +210,27 @@ def forward(
         if prev_self_attn_state is not None:
             if incremental_state is None:
                 incremental_state = {}
-            prev_key, prev_value = prev_self_attn_state
+            prev_key, prev_value = prev_self_attn_state[:2]
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            if len(prev_self_attn_state) >= 3:
+                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
             self.self_attn._set_input_buffer(incremental_state, saved_state)
+
+        if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
+            if self_attn_mask is not None:
+                self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
+            if self_attn_padding_mask is not None:
+                if encoder_padding_mask is None:
+                    encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
+                self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
+            y = torch.cat((encoder_out, x), dim=0)
+        else:
+            y = x
+
         x, attn = self.self_attn(
             query=x,
-            key=x,
-            value=x,
+            key=y,
+            value=y,
             key_padding_mask=self_attn_padding_mask,
             incremental_state=incremental_state,
             need_weights=False,
@@ -230,9 +246,12 @@ def forward(
             if prev_attn_state is not None:
                 if incremental_state is None:
                     incremental_state = {}
-                prev_key, prev_value = prev_attn_state
+                prev_key, prev_value = prev_attn_state[:2]
                 saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+                if len(prev_attn_state) >= 3:
+                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                 self.encoder_attn._set_input_buffer(incremental_state, saved_state)
+
             x, attn = self.encoder_attn(
                 query=x,
                 key=encoder_out,
@@ -256,7 +275,10 @@ def forward(
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         if self.onnx_trace and incremental_state is not None:
             saved_state = self.self_attn._get_input_buffer(incremental_state)
-            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
+            if self_attn_padding_mask is not None:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
+            else:
+                self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
             return x, attn, self_attn_state
         return x, attn

diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 8cede3c9fa..f77806bd6a 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -154,6 +154,23 @@ def test_transformer(self):
                 ], run_validation=True)
                 generate_main(data_dir)

+    def test_transformer_cross_self_attention(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'transformer_iwslt_de_en', [
+                    '--encoder-layers', '2',
+                    '--decoder-layers', '2',
+                    '--encoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--decoder-embed-dim', '8',
+                    '--no-cross-attention',
+                    '--cross-self-attention',
+                    '--layer-wise-attention',
+                ], run_validation=True)
+                generate_main(data_dir, extra_flags=[])
+
     def test_lightconv(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
@@ -543,6 +560,10 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'


 def generate_main(data_dir, extra_flags=None):
+    if extra_flags is None:
+        extra_flags = [
+            '--print-alignment',
+        ]
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,
@@ -554,7 +575,6 @@ def generate_main(data_dir, extra_flags=None):
             '--max-len-b', '5',
             '--gen-subset', 'valid',
             '--no-progress-bar',
-            '--print-alignment',
         ] + (extra_flags or []),
     )
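
Note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of the idea that the modified TransformerDecoderLayer.forward implements when cross+self-attention is enabled: the decoder's self-attention keys and values become the concatenation of the encoder output and the decoder states, and the causal mask is extended with zero columns over the encoder positions so every target position can attend to the full source. For simplicity it uses torch.nn.MultiheadAttention in place of fairseq's own MultiheadAttention, and the tensor sizes are arbitrary.

import torch
import torch.nn as nn

src_len, tgt_len, bsz, dim, heads = 7, 5, 2, 16, 4

encoder_out = torch.rand(src_len, bsz, dim)  # T_src x B x C (fairseq layout)
decoder_x = torch.rand(tgt_len, bsz, dim)    # T_tgt x B x C

# One attention call covers both cross-attention and self-attention:
# keys/values are the encoder states followed by the decoder states.
kv = torch.cat((encoder_out, decoder_x), dim=0)  # (T_src + T_tgt) x B x C

# Causal mask over the decoder part; the first T_src columns are zeros,
# i.e. encoder positions are always visible.
causal = torch.triu(decoder_x.new_full((tgt_len, tgt_len), float('-inf')), 1)
attn_mask = torch.cat((decoder_x.new_zeros(tgt_len, src_len), causal), dim=1)

attn = nn.MultiheadAttention(dim, heads)
out, _ = attn(query=decoder_x, key=kv, value=kv, attn_mask=attn_mask)
print(out.shape)  # torch.Size([5, 2, 16])

In the patch itself this behaviour is switched on through the new options registered in fairseq/models/transformer.py (--no-cross-attention and --cross-self-attention, optionally combined with --layer-wise-attention), which is exactly the combination exercised by test_transformer_cross_self_attention in tests/test_binaries.py.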