Transformer dropout (#544)
* positional and residual dropout

* attention dropout
msperber authored and neubig committed Nov 14, 2018
1 parent c2028c7 commit 3977c4b
Showing 4 changed files with 69 additions and 40 deletions.
4 changes: 4 additions & 0 deletions examples/21_self_attention.yaml
@@ -36,6 +36,7 @@
- !PositionalSeqTransducer
input_dim: 512
max_pos: 100
dropout: 0.1
- !ModularSeqTransducer
modules: !Repeat
times: 2
@@ -45,13 +46,16 @@
input_dim: 512
child: !MultiHeadAttentionSeqTransducer
num_heads: 8
dropout: 0.1
layer_norm: True
dropout: 0.1
- !ResidualSeqTransducer
input_dim: 512
child: !TransformSeqTransducer
transform: !MLP
activation: relu
layer_norm: True
dropout: 0.1
attender: !MlpAttender
hidden_dim: 512
state_dim: 512
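
The dropout: 0.1 entries above take effect only during training; every transducer touched by this commit guards it with "if self.train and self.dropout > 0.0". As a point of reference, here is a minimal NumPy sketch of inverted dropout, the usual scale-by-1/(1-p) convention (an illustration only, not xnmt or DyNet code; the function name and signature are invented for this example):

import numpy as np

def dropout(x, p, train, rng=np.random):
    # No-op unless we are training and p > 0, mirroring the guards added in this commit.
    if not train or p <= 0.0:
        return x
    keep = (rng.random(x.shape) >= p).astype(x.dtype)
    # Rescale the survivors by 1/(1-p) so the expected value is unchanged
    # and nothing needs to be rescaled at inference time.
    return x * keep / (1.0 - p)

With p = 0.1, roughly one in ten activations is zeroed at each training step and the rest are scaled by 1/0.9; at test time the input passes through untouched.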
30 changes: 19 additions & 11 deletions xnmt/transducers/positional.py
@@ -3,47 +3,53 @@

import dynet as dy

from xnmt.expression_seqs import ExpressionSequence
from xnmt.param_collections import ParamManager
from xnmt.param_initializers import ParamInitializer, GlorotInitializer
from xnmt import events, expression_seqs, param_collections, param_initializers
from xnmt.transducers import base as transducers
from xnmt.persistence import serializable_init, Serializable, bare, Ref
from xnmt.transducers.base import SeqTransducer, FinalTransducerState


# Note: alternatively, this could wrap "PositionEmbedder", but it seems to me
# that PositionEmbedder is probably not necessary in the first place, so
# it probably makes more sense to have this as a SeqTransducer that
# adds positional embeddings to an input
class PositionalSeqTransducer(SeqTransducer, Serializable):
class PositionalSeqTransducer(transducers.SeqTransducer, Serializable):
yaml_tag = '!PositionalSeqTransducer'

@events.register_xnmt_handler
@serializable_init
def __init__(self,
max_pos: numbers.Integral,
op: str = 'sum',
emb_type: str = 'param',
input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
dropout=Ref("exp_global.dropout", default=0.0),
param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer))):
"""
max_pos: largest embedded position
op: how to combine positional encodings with the original encodings, can be "sum" or "concat"
emb_type: what type of embeddings to use; "param" = parameterized (others, such as trigonometric embeddings, are TODO)
input_dim: embedding size
dropout: dropout rate to apply to the output of this transducer
param_init: how to initialize embedding matrix
"""
self.max_pos = max_pos
self.input_dim = input_dim
self.dropout = dropout
self.op = op
self.emb_type = emb_type
param_init = param_init
dim = (self.input_dim, max_pos)
param_collection = ParamManager.my_params(self)
param_collection = param_collections.ParamManager.my_params(self)
self.embedder = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))

def get_final_states(self) -> List[FinalTransducerState]:
@ events.handle_xnmt_event
def on_set_train(self, val):
self.train = val

def get_final_states(self) -> List[transducers.FinalTransducerState]:
return self._final_states

def transduce(self, src: ExpressionSequence) -> ExpressionSequence:
def transduce(self, src: expression_seqs.ExpressionSequence) -> expression_seqs.ExpressionSequence:
sent_len = len(src)
embeddings = dy.strided_select(dy.parameter(self.embedder), [1,1], [0,0], [self.input_dim, sent_len])
if self.op == 'sum':
@@ -52,6 +58,8 @@ def transduce(self, src: ExpressionSequence) -> ExpressionSequence:
output = dy.concatenate([embeddings, src.as_tensor()])
else:
raise ValueError(f'Illegal op {self.op} in PositionalSeqTransducer (options are "sum"/"concat")')
output_seq = ExpressionSequence(expr_tensor=output, mask=src.mask)
self._final_states = [FinalTransducerState(output_seq[-1])]
if self.train and self.dropout > 0.0:
output = dy.dropout(output, self.dropout)
output_seq = expression_seqs.ExpressionSequence(expr_tensor=output, mask=src.mask)
self._final_states = [transducers.FinalTransducerState(output_seq[-1])]
return output_seq
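
For intuition, a framework-free sketch of what PositionalSeqTransducer now computes when op == 'sum': take the first sent_len columns of the learned (input_dim, max_pos) embedding table, add them to the input, and drop out the result only in training mode. positional_sum and pos_table are illustrative names, and dropout refers to the NumPy helper sketched after the YAML example above, not DyNet's dy.dropout.

def positional_sum(src, pos_table, p, train):
    # src: (input_dim, sent_len); pos_table: (input_dim, max_pos), with max_pos >= sent_len.
    sent_len = src.shape[1]
    out = src + pos_table[:, :sent_len]   # analogous to dy.strided_select followed by the sum
    return dropout(out, p, train)         # positional dropout, applied after the combination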
35 changes: 23 additions & 12 deletions xnmt/transducers/residual.py
@@ -3,43 +3,54 @@

import dynet as dy

from xnmt.expression_seqs import ExpressionSequence
from xnmt.persistence import serializable_init, Serializable
from xnmt.transducers.base import SeqTransducer, FinalTransducerState
from xnmt.param_collections import ParamManager
from xnmt import events, expression_seqs, param_collections
from xnmt.transducers import base as transducers
from xnmt.persistence import Ref, serializable_init, Serializable

class ResidualSeqTransducer(SeqTransducer, Serializable):
class ResidualSeqTransducer(transducers.SeqTransducer, Serializable):
"""
A sequence transducer that wraps a :class:`xnmt.transducer.SeqTransducer` in an additive residual
A sequence transducer that wraps a :class:`xnmt.transducers.base.SeqTransducer` in an additive residual
connection, and optionally performs some variety of normalization.
Args:
child: the child transducer to wrap
layer_norm: whether to perform layer normalization
dropout: dropout rate to apply to the child transducer's output (residual dropout)
"""

yaml_tag = '!ResidualSeqTransducer'

@events.register_xnmt_handler
@serializable_init
def __init__(self, child: SeqTransducer, input_dim: numbers.Integral, layer_norm: bool = False):
def __init__(self, child: transducers.SeqTransducer, input_dim: numbers.Integral, layer_norm: bool = False,
dropout=Ref("exp_global.dropout", default=0.0)) -> None:
self.child = child
self.dropout = dropout
self.input_dim = input_dim
self.layer_norm = layer_norm
if layer_norm:
model = ParamManager.my_params(self)
model = param_collections.ParamManager.my_params(self)
self.ln_g = model.add_parameters(dim=(input_dim,))
self.ln_b = model.add_parameters(dim=(input_dim,))

def transduce(self, seq: ExpressionSequence) -> ExpressionSequence:
seq_tensor = self.child.transduce(seq).as_tensor() + seq.as_tensor()
@ events.handle_xnmt_event
def on_set_train(self, val):
self.train = val

def transduce(self, seq: expression_seqs.ExpressionSequence) -> expression_seqs.ExpressionSequence:

if self.train and self.dropout > 0.0:
seq_tensor = dy.dropout(self.child.transduce(seq).as_tensor(), self.dropout) + seq.as_tensor()
else:
seq_tensor = self.child.transduce(seq).as_tensor() + seq.as_tensor()
if self.layer_norm:
d = seq_tensor.dim()
seq_tensor = dy.reshape(seq_tensor, (d[0][0],), batch_size=d[0][1]*d[1])
seq_tensor = dy.layer_norm(seq_tensor, self.ln_g, self.ln_b)
seq_tensor = dy.reshape(seq_tensor, d[0], batch_size=d[1])
return ExpressionSequence(expr_tensor=seq_tensor)
return expression_seqs.ExpressionSequence(expr_tensor=seq_tensor)

def get_final_states(self) -> List[FinalTransducerState]:
def get_final_states(self) -> List[transducers.FinalTransducerState]:
# TODO: is this OK to do?
return self.child.get_final_states()
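
The essential change here is the order of operations: dropout is applied to the child transducer's output before it is added back onto the skip connection, which is the residual dropout described in "Attention is All You Need". A hedged NumPy sketch of that ordering, with layer normalization reduced to per-position standardization and gain/bias as illustrative parameter names (dropout is the helper sketched earlier):

def residual_block(x, child, p, train, gain=None, bias=None):
    # x: (input_dim, sent_len); child is any callable mapping x to the same shape.
    # Dropout hits only the child branch, never the identity (skip) path.
    y = dropout(child(x), p, train) + x
    if gain is not None:
        # Optional layer norm over the feature dimension, one position at a time.
        mean = y.mean(axis=0, keepdims=True)
        std = y.std(axis=0, keepdims=True) + 1e-8
        y = gain[:, None] * (y - mean) / std + bias[:, None]
    return y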

40 changes: 23 additions & 17 deletions xnmt/transducers/self_attention.py
@@ -3,33 +3,37 @@
from math import sqrt
from typing import List

from xnmt.events import register_xnmt_handler, handle_xnmt_event
from xnmt.expression_seqs import ExpressionSequence
from xnmt.param_collections import ParamManager
from xnmt.param_initializers import GlorotInitializer, ZeroInitializer
from xnmt import events, expression_seqs, param_collections, param_initializers
from xnmt.persistence import serializable_init, Serializable, bare, Ref
from xnmt.transducers.base import SeqTransducer, FinalTransducerState
from xnmt.transducers import base as transducers

class MultiHeadAttentionSeqTransducer(SeqTransducer, Serializable):
class MultiHeadAttentionSeqTransducer(transducers.SeqTransducer, Serializable):
"""
This implements the Multi-headed attention layer of "Attention is All You Need":
https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
Args:
input_dim: size of inputs
dropout: dropout rate to apply to the attention matrix
param_init: how to initialize param matrices
bias_init: how to initialize bias params
num_heads: number of attention heads
"""
yaml_tag = '!MultiHeadAttentionSeqTransducer'

@register_xnmt_handler
@events.register_xnmt_handler
@serializable_init
def __init__(self,
input_dim=Ref("exp_global.default_layer_dim"),
param_init=Ref("exp_global.param_init", default=bare(GlorotInitializer)),
bias_init=Ref("exp_global.bias_init", default=bare(ZeroInitializer)),
dropout=Ref("exp_global.dropout", default=0.0),
param_init=Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
bias_init=Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer)),
num_heads=8):
assert(input_dim % num_heads == 0)

param_collection = ParamManager.my_params(self)
self.dropout = dropout

param_collection = param_collections.ParamManager.my_params(self)

self.input_dim = input_dim
self.num_heads = num_heads
@@ -38,18 +42,18 @@ def __init__(self,
self.pWq, self.pWk, self.pWv, self.pWo = [param_collection.add_parameters(dim=(input_dim, input_dim), init=param_init.initializer((input_dim, input_dim))) for _ in range(4)]
self.pbq, self.pbk, self.pbv, self.pbo = [param_collection.add_parameters(dim=(1, input_dim), init=bias_init.initializer((1, input_dim,))) for _ in range(4)]

@handle_xnmt_event
@events.handle_xnmt_event
def on_start_sent(self, src):
self._final_states = None

def get_final_states(self) -> List[FinalTransducerState]:
def get_final_states(self) -> List[transducers.FinalTransducerState]:
return self._final_states

@handle_xnmt_event
@events.handle_xnmt_event
def on_set_train(self, val):
self.train = val

def transduce(self, expr_seq: ExpressionSequence) -> ExpressionSequence:
def transduce(self, expr_seq: expression_seqs.ExpressionSequence) -> expression_seqs.ExpressionSequence:
"""
transduce the sequence
@@ -84,15 +88,17 @@ def transduce(self, expr_seq: ExpressionSequence) -> ExpressionSequence:
mask = dy.inputTensor(np.repeat(expr_seq.mask.np_arr, self.num_heads, axis=0).transpose(), batched=True) * -1e10
attn_score = attn_score + mask
attn_prob = dy.softmax(attn_score, d=1)
if self.train and self.dropout > 0.0:
attn_prob = dy.dropout(attn_prob, self.dropout)
# Reduce using attention and resize to match [(length, model_size) x batch]
o = dy.reshape(attn_prob * v, (x_len, self.input_dim), batch_size=x_batch)
# Final transformation
# o = dy.affine_transform([bo, attn_prob * v, Wo])
o = bo + o * Wo

expr_seq = ExpressionSequence(expr_transposed_tensor=o, mask=expr_seq.mask)
expr_seq = expression_seqs.ExpressionSequence(expr_transposed_tensor=o, mask=expr_seq.mask)

self._final_states = [FinalTransducerState(expr_seq[-1], None)]
self._final_states = [transducers.FinalTransducerState(expr_seq[-1], None)]

return expr_seq
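
Attention dropout lands on the softmax-normalized attention matrix, just before it is used to mix the value vectors. A single-head NumPy sketch of that placement (the Wq/Wk/Wv/Wo projections and multi-head reshapes from the diff are omitted; attend is an illustrative name and dropout is the helper sketched earlier):

def attend(q, k, v, p, train):
    # q, k, v: (seq_len, d_k). Scaled dot-product attention with attention dropout.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)    # softmax over the keys
    probs = dropout(probs, p, train)              # attention dropout introduced in this commit
    return probs @ v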
