Drop masked decoder states (#370)
* sort by trg len within batches

* laid out plan

* make sentencepiece dependency optional

* [WIP] implement truncate_batches(); unit tests failing

* fixed most tests

* document truncate_batches()

* attenders take truncate_dec_batches field

* truncate_dec_batches arg for decoders

* some fixes

* remove some outdated comments
msperber authored and neubig committed Jun 12, 2018
1 parent 90f9d89 commit 9c56b09
Showing 9 changed files with 177 additions and 65 deletions.
8 changes: 4 additions & 4 deletions test/test_batcher.py
@@ -16,10 +16,10 @@ def test_batch_src(self):
trg_sents = [xnmt.input.SimpleSentenceInput([0] * ((i+3)%6 + 1)) for i in range(1,7)]
my_batcher = xnmt.batcher.SrcBatcher(batch_size=3, src_pad_token=1, trg_pad_token=2)
src, trg = my_batcher.pack(src_sents, trg_sents)
- self.assertEqual([[0, 1, 1], [0, 0, 1], [0, 0, 0]], [x.words for x in src[0]])
- self.assertEqual([[0, 0, 0, 0, 0, 2], [0, 0, 0, 0, 0, 0], [0, 2, 2, 2, 2, 2]], [x.words for x in trg[0]])
- self.assertEqual([[0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0]], [x.words for x in src[1]])
- self.assertEqual([[0, 0, 2, 2], [0, 0, 0, 2], [0, 0, 0, 0]], [x.words for x in trg[1]])
+ self.assertEqual([[0, 0, 1], [0, 1, 1], [0, 0, 0]], [x.words for x in src[0]])
+ self.assertEqual([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 2], [0, 2, 2, 2, 2, 2]], [x.words for x in trg[0]])
+ self.assertEqual([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1]], [x.words for x in src[1]])
+ self.assertEqual([[0, 0, 0, 0], [0, 0, 0, 2], [0, 0, 2, 2]], [x.words for x in trg[1]])

def test_batch_word_src(self):
src_sents = [xnmt.input.SimpleSentenceInput([0] * i) for i in range(1,7)]
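The updated expectations reflect that `pack()` now sorts sentence pairs within each batch by descending target length. A minimal standalone sketch of the sorting idiom (the same `zip`/`sorted` trick that `add_single_batch()` uses further down; the sentence contents here are made up for illustration):

```python
# Standalone sketch of the sort-by-target-length idiom from add_single_batch().
src_curr = [[0] * 1, [0] * 2, [0] * 3]  # source lengths 1, 2, 3
trg_curr = [[0] * 5, [0] * 6, [0] * 1]  # target lengths 5, 6, 1

# Sort the (src, trg) pairs jointly so targets come in descending length order;
# zip(*...) then splits the sorted pairs back into two tuples.
src_curr, trg_curr = zip(*sorted(zip(src_curr, trg_curr),
                                 key=lambda x: len(x[1]), reverse=True))

print([len(t) for t in trg_curr])  # [6, 5, 1]
print([len(s) for s in src_curr])  # [2, 1, 3] -- sources follow their targets
```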
6 changes: 3 additions & 3 deletions test/test_training.py
@@ -63,7 +63,7 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
trg=mark_as_batch(trg_sents_trunc),
loss_calculator=MLELoss()).value()
- self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
+ self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)

def test_loss_model1(self):
layer_dim = 512
@@ -189,7 +189,7 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
single_sent[src_min-1] = Vocab.ES
while len(single_sent)%pad_src_to_multiple != 0:
single_sent.append(Vocab.ES)
- trg_sents = self.trg_data[:batch_size]
+ trg_sents = sorted(self.trg_data[:batch_size], key=lambda x: len(x), reverse=True)
trg_max = max([len(x) for x in trg_sents])
trg_masks = Mask(np.zeros([batch_size, trg_max]))
for i in range(batch_size):
@@ -210,7 +210,7 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
batched_loss = model.calc_loss(src=mark_as_batch(src_sents_trunc),
trg=mark_as_batch(trg_sents_padded, trg_masks),
loss_calculator=MLELoss()).value()
- self.assertAlmostEqual(single_loss, sum(batched_loss), places=4)
+ self.assertAlmostEqual(single_loss, np.sum(batched_loss), places=4)

def test_loss_model1(self):
layer_dim = 512
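The switch from the builtin `sum` to `np.sum` matters once the batched loss value can be a multi-dimensional array: builtin `sum` only reduces over the first axis, while `np.sum` reduces over every element. A quick illustration (assuming the loss arrives as a numpy array; this snippet is illustrative, not from the repo):

```python
import numpy as np

# Per-sentence, per-timestep losses for a batch of 2 sentences, 3 steps each.
batched_loss = np.array([[0.5, 1.0, 0.25],
                         [2.0, 0.0, 1.25]])

print(sum(batched_loss))     # array([2.5, 1. , 1.5]) -- builtin sum adds rows
print(np.sum(batched_loss))  # 5.0 -- reduces over all elements, a true scalar
```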
66 changes: 41 additions & 25 deletions xnmt/attender.py
@@ -2,8 +2,9 @@
import dynet as dy

from xnmt import logger
+ import xnmt.batcher
from xnmt.param_collection import ParamManager
- from xnmt.param_init import GlorotInitializer, ZeroInitializer
+ from xnmt.param_init import GlorotInitializer, ZeroInitializer, ParamInitializer
from xnmt.persistence import serializable_init, Serializable, Ref, bare

class Attender(object):
@@ -39,29 +40,32 @@ def get_last_attention(self):
return self.attention_vecs[-1]

class MlpAttender(Attender, Serializable):
- '''
+ """
Implements the attention model of Bahdanau et al. (2014)
Args:
- input_dim (int): input dimension
- state_dim (int): dimension of state inputs
- hidden_dim (int): hidden MLP dimension
- param_init (ParamInitializer): how to initialize weight matrices
- bias_init (ParamInitializer): how to initialize bias vectors
- '''
+ input_dim: input dimension
+ state_dim: dimension of state inputs
+ hidden_dim: hidden MLP dimension
+ param_init: how to initialize weight matrices
+ bias_init: how to initialize bias vectors
+ truncate_dec_batches: whether the decoder drops batch elements as soon as they are masked at some time step.
+ """

yaml_tag = '!MlpAttender'

@serializable_init
def __init__(self,
- input_dim=Ref("exp_global.default_layer_dim"),
- state_dim=Ref("exp_global.default_layer_dim"),
- hidden_dim=Ref("exp_global.default_layer_dim"),
- param_init=Ref("exp_global.param_init", default=bare(GlorotInitializer)),
- bias_init=Ref("exp_global.bias_init", default=bare(ZeroInitializer))):
+ input_dim: int = Ref("exp_global.default_layer_dim"),
+ state_dim: int = Ref("exp_global.default_layer_dim"),
+ hidden_dim: int = Ref("exp_global.default_layer_dim"),
+ param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer)),
+ bias_init: ParamInitializer = Ref("exp_global.bias_init", default=bare(ZeroInitializer)),
+ truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
self.input_dim = input_dim
self.state_dim = state_dim
self.hidden_dim = hidden_dim
+ self.truncate_dec_batches = truncate_dec_batches
param_collection = ParamManager.my_params(self)
self.pW = param_collection.add_parameters((hidden_dim, input_dim), init=param_init.initializer((hidden_dim, input_dim)))
self.pV = param_collection.add_parameters((hidden_dim, state_dim), init=param_init.initializer((hidden_dim, state_dim)))
@@ -86,32 +90,41 @@ def calc_attention(self, state):
V = dy.parameter(self.pV)
U = dy.parameter(self.pU)

- h = dy.tanh(dy.colwise_add(self.WI, V * state))
+ WI = self.WI
+ curr_sent_mask = self.curr_sent.mask
+ if self.truncate_dec_batches:
+   if curr_sent_mask: state, WI, curr_sent_mask = xnmt.batcher.truncate_batches(state, WI, curr_sent_mask)
+   else: state, WI = xnmt.batcher.truncate_batches(state, WI)
+ h = dy.tanh(dy.colwise_add(WI, V * state))
scores = dy.transpose(U * h)
- if self.curr_sent.mask is not None:
-   scores = self.curr_sent.mask.add_to_tensor_expr(scores, multiplicator = -100.0)
+ if curr_sent_mask is not None:
+   scores = curr_sent_mask.add_to_tensor_expr(scores, multiplicator = -100.0)
normalized = dy.softmax(scores)
self.attention_vecs.append(normalized)
return normalized

def calc_context(self, state):
attention = self.calc_attention(state)
I = self.curr_sent.as_tensor()
+ if self.truncate_dec_batches: I, attention = xnmt.batcher.truncate_batches(I, attention)
return I * attention

class DotAttender(Attender, Serializable):
- '''
+ """
Implements dot product attention of https://arxiv.org/abs/1508.04025
Also (optionally) performs scaling as in https://arxiv.org/abs/1706.03762
Args:
- scale (bool): whether to perform scaling
- '''
+ scale: whether to perform scaling
+ truncate_dec_batches: currently unsupported
+ """

yaml_tag = '!DotAttender'

@serializable_init
- def __init__(self, scale:bool=True):
+ def __init__(self, scale: bool = True,
+              truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
+   if truncate_dec_batches: raise NotImplementedError("truncate_dec_batches not yet implemented for DotAttender")
self.curr_sent = None
self.scale = scale
self.attention_vecs = []
@@ -137,23 +150,26 @@ def calc_context(self, state):
return I * attention

class BilinearAttender(Attender, Serializable):
- '''
+ """
Implements a bilinear attention, equivalent to the 'general' linear
attention of https://arxiv.org/abs/1508.04025
Args:
input_dim (int): input dimension; if None, use exp_global.default_layer_dim
state_dim (int): dimension of state inputs; if None, use exp_global.default_layer_dim
param_init (ParamInitializer): how to initialize weight matrices; if None, use ``exp_global.param_init``
- '''
+ truncate_dec_batches: currently unsupported
+ """

yaml_tag = '!BilinearAttender'

@serializable_init
def __init__(self,
- input_dim=Ref("exp_global.default_layer_dim"),
- state_dim=Ref("exp_global.default_layer_dim"),
- param_init=Ref("exp_global.param_init", default=bare(GlorotInitializer))):
+ input_dim: int = Ref("exp_global.default_layer_dim"),
+ state_dim: int = Ref("exp_global.default_layer_dim"),
+ param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer)),
+ truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
+   if truncate_dec_batches: raise NotImplementedError("truncate_dec_batches not yet implemented for BilinearAttender")
self.input_dim = input_dim
self.state_dim = state_dim
param_collection = ParamManager.my_params(self)
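For reference, the masking in `calc_attention()` works by pushing the scores of padded source positions toward -inf before the softmax (the `multiplicator = -100.0` line); with `truncate_dec_batches`, fully masked batch elements are instead dropped outright. A hand-rolled numpy sketch of the score-masking step (an analogue of the DyNet code, not the code itself):

```python
import numpy as np

def masked_softmax(scores, mask):
    # mask is 1.0 at padded source positions and 0.0 at real tokens,
    # mirroring the Mask convention used here; -100.0 as in the diff.
    scores = scores + mask * -100.0
    e = np.exp(scores - scores.max())
    return e / e.sum()

scores = np.array([1.2, 0.7, 0.3, 0.1])
mask = np.array([0.0, 0.0, 1.0, 1.0])  # last two source positions are padding
print(masked_softmax(scores, mask))    # padded positions get ~0 attention
```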
52 changes: 51 additions & 1 deletion xnmt/batcher.py
@@ -1,9 +1,12 @@
+ from typing import Union, Sequence
import math
import random
import numpy as np
import dynet as dy
from xnmt.vocab import Vocab
from xnmt.persistence import serializable_init, Serializable
+ import xnmt.expression_sequence
+ from xnmt import lstm

class Batch(list):
"""
@@ -108,7 +111,9 @@ def is_random(self):
"""
return False

- def add_single_batch(self, src_curr, trg_curr, src_ret, trg_ret):
+ def add_single_batch(self, src_curr, trg_curr, src_ret, trg_ret, sort_by_trg_len=True):
+   if trg_curr is not None and sort_by_trg_len:
+     src_curr, trg_curr = zip(*sorted(zip(src_curr, trg_curr), key=lambda x: len(x[1]), reverse=True))
src_id, src_mask = pad(src_curr, pad_token=self.src_pad_token, pad_to_multiple=self.pad_src_to_multiple)
src_ret.append(Batch(src_id, src_mask))
if trg_ret is not None:
@@ -508,3 +513,48 @@ def pack_by_order(self, src, trg, order):
self.batch_size = (sum([len(s) for s in src]) + sum([len(s) for s in trg])) / len(src) * self.avg_batch_size
return super(WordTrgSrcBatcher, self).pack_by_order(src, trg, order)

+ def truncate_batches(*xl: Union[dy.Expression, Batch, Mask, lstm.UniLSTMState]) \
+     -> Sequence[Union[dy.Expression, Batch, Mask, lstm.UniLSTMState]]:
+   """
+   Truncate a list of batched items so that all items have the batch size of the input with the smallest batch size.
+   Inputs can be of various types and would usually correspond to a single time step.
+   Assume that the batch elements with index 0 correspond across the inputs, so that batch elements will be truncated
+   from the top, i.e. starting with the highest-indexed batch elements.
+   Masks are not considered even if attached to an input of :class:`Batch` type.
+   Args:
+     *xl: batched timesteps of various types
+   Returns:
+     Copies of the inputs, truncated to consistent batch size.
+   """
+   batch_sizes = []
+   for x in xl:
+     if isinstance(x, dy.Expression) or isinstance(x, xnmt.expression_sequence.ExpressionSequence):
+       batch_sizes.append(x.dim()[1])
+     elif isinstance(x, Batch):
+       batch_sizes.append(len(x))
+     elif isinstance(x, Mask):
+       batch_sizes.append(x.batch_size())
+     elif isinstance(x, lstm.UniLSTMState):
+       batch_sizes.append(x.output().dim()[1])
+     else:
+       raise ValueError(f"unsupported type {type(x)}")
+     assert batch_sizes[-1] > 0
+   ret = []
+   for i, x in enumerate(xl):
+     if batch_sizes[i] > min(batch_sizes):
+       if isinstance(x, dy.Expression) or isinstance(x, xnmt.expression_sequence.ExpressionSequence):
+         ret.append(x[tuple([slice(None)]*len(x.dim()[0]) + [slice(min(batch_sizes))])])
+       elif isinstance(x, Batch):
+         ret.append(mark_as_batch(x[:min(batch_sizes)]))
+       elif isinstance(x, Mask):
+         ret.append(Mask(x.np_arr[:min(batch_sizes)]))
+       elif isinstance(x, lstm.UniLSTMState):
+         ret.append(x[:,:min(batch_sizes)])
+       else:
+         raise ValueError(f"unsupported type {type(x)}")
+     else:
+       ret.append(x)
+   return ret
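To make the semantics concrete, here is a hand-rolled numpy analogue of `truncate_batches()` (the real function dispatches on DyNet expressions, `Batch`, `Mask`, and LSTM states; this sketch assumes plain arrays whose last axis is the batch dimension):

```python
import numpy as np

def truncate_batches_np(*arrays):
    """Truncate all inputs to the smallest batch size among them.

    Assumes the trailing axis of each array is the batch axis and that
    batch element 0 corresponds across inputs, as in xnmt.batcher.
    """
    min_bs = min(a.shape[-1] for a in arrays)
    return [a[..., :min_bs] for a in arrays]

state = np.zeros((512, 3))    # decoder state: 3 still-active sentences
enc = np.zeros((512, 10, 5))  # encodings were built for the full batch of 5
state_t, enc_t = truncate_batches_np(state, enc)
print(state_t.shape, enc_t.shape)  # (512, 3) (512, 10, 3)
```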
49 changes: 28 additions & 21 deletions xnmt/decoder.py
@@ -7,7 +7,7 @@
import xnmt.residual
from xnmt.param_init import GlorotInitializer, ZeroInitializer
from xnmt import logger
- from xnmt.bridge import CopyBridge
+ from xnmt.bridge import Bridge, CopyBridge
from xnmt.lstm import UniLSTMSeqTransducer
from xnmt.mlp import MLP
from xnmt.param_collection import ParamManager
@@ -43,16 +43,17 @@ class MlpSoftmaxDecoder(Decoder, Serializable):
Standard MLP softmax decoder.
Args:
- input_dim (int): input dimension
- trg_embed_dim (int): dimension of target embeddings
- input_feeding (bool): whether to activate input feeding
- rnn_layer (UniLSTMSeqTransducer): recurrent layer of the decoder
- mlp_layer (MLP): final prediction layer of the decoder
- bridge (Bridge): how to initialize decoder state
- label_smoothing (float): label smoothing value (if used, 0.1 is a reasonable value).
-   Label Smoothing is implemented with reference to Section 7 of the paper
-   "Rethinking the Inception Architecture for Computer Vision"
-   (https://arxiv.org/pdf/1512.00567.pdf)
+ input_dim: input dimension
+ trg_embed_dim: dimension of target embeddings
+ input_feeding: whether to activate input feeding
+ rnn_layer: recurrent layer of the decoder
+ mlp_layer: final prediction layer of the decoder
+ bridge: how to initialize decoder state
+ truncate_dec_batches: whether the decoder drops batch elements as soon as they are masked at some time step.
+ label_smoothing: label smoothing value (if used, 0.1 is a reasonable value).
+   Label Smoothing is implemented with reference to Section 7 of the paper
+   "Rethinking the Inception Architecture for Computer Vision"
+   (https://arxiv.org/pdf/1512.00567.pdf)
"""

# TODO: This should probably take a softmax object, which can be normal or class-factored, etc.
@@ -62,15 +63,17 @@ class MlpSoftmaxDecoder(Decoder, Serializable):

@serializable_init
def __init__(self,
- input_dim=Ref("exp_global.default_layer_dim"),
- trg_embed_dim=Ref("exp_global.default_layer_dim"),
- input_feeding=True,
- rnn_layer=bare(UniLSTMSeqTransducer),
- mlp_layer=bare(MLP),
- bridge=bare(CopyBridge),
- label_smoothing=0.0):
+ input_dim: int = Ref("exp_global.default_layer_dim"),
+ trg_embed_dim: int = Ref("exp_global.default_layer_dim"),
+ input_feeding: bool = True,
+ rnn_layer: UniLSTMSeqTransducer = bare(UniLSTMSeqTransducer),
+ mlp_layer: MLP = bare(MLP),
+ bridge: Bridge = bare(CopyBridge),
+ truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False),
+ label_smoothing: float = 0.0) -> None:
self.param_col = ParamManager.my_params(self)
self.input_dim = input_dim
+ self.truncate_dec_batches = truncate_dec_batches
self.label_smoothing = label_smoothing
# Input feeding
self.input_feeding = input_feeding
@@ -123,7 +126,9 @@ def add_input(self, mlp_dec_state, trg_embedding):
inp = trg_embedding
if self.input_feeding:
inp = dy.concatenate([inp, mlp_dec_state.context])
- return MlpSoftmaxDecoderState(rnn_state=mlp_dec_state.rnn_state.add_input(inp),
+ rnn_state = mlp_dec_state.rnn_state
+ if self.truncate_dec_batches: rnn_state, inp = xnmt.batcher.truncate_batches(rnn_state, inp)
+ return MlpSoftmaxDecoderState(rnn_state=rnn_state.add_input(inp),
context=mlp_dec_state.context)

def get_scores(self, mlp_dec_state):
@@ -148,6 +153,7 @@ def calc_loss(self, mlp_dec_state, ref_action):
return dy.pickneglogsoftmax(scores, ref_action)
# minibatch mode
else:
+ if self.truncate_dec_batches: scores, ref_action = xnmt.batcher.truncate_batches(scores, ref_action)
return dy.pickneglogsoftmax_batch(scores, ref_action)

else:
@@ -181,11 +187,12 @@ def __init__(self,
lexicon_type='bias',
lexicon_alpha=0.001,
linear_projector=None,
+ truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False),
param_init_lin=Ref("exp_global.param_init", default=bare(GlorotInitializer)),
bias_init_lin=Ref("exp_global.bias_init", default=bare(ZeroInitializer)),
- ):
+ ) -> None:
super().__init__(input_dim, trg_embed_dim, input_feeding, rnn_layer,
- mlp_layer, bridge, label_smoothing)
+ mlp_layer, bridge, truncate_dec_batches, label_smoothing)
assert lexicon_file is not None
self.lexicon_file = lexicon_file
self.src_vocab = src_vocab
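The pieces fit together as follows: because batches are now sorted by descending target length, at any decoding time step the finished (masked) sentences form a contiguous block at the high-index end of the batch, so the decoder can simply shrink the batch instead of multiplying by a mask. A toy simulation of this invariant (illustrative only, not repo code):

```python
# Toy check: with targets sorted longest-first, the set of still-active
# sentences at each step is always a prefix of the batch, so truncating
# from the high-index end keeps batch elements aligned across tensors.
trg_lens = sorted([6, 5, 1], reverse=True)  # as produced by the new batcher

for t in range(max(trg_lens)):
    active = [length > t for length in trg_lens]
    # Invariant: once a sentence finishes, everything after it is finished too.
    assert active == sorted(active, reverse=True)
    batch_size_at_t = sum(active)
    print(f"step {t}: keep first {batch_size_at_t} of {len(trg_lens)} elements")
```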
3 changes: 3 additions & 0 deletions xnmt/exp_global.py
@@ -17,6 +17,7 @@ class ExpGlobal(Serializable):
default_layer_dim: Default layer dimension that should be used by supporting components but can be overwritten
param_init: Default parameter initializer that should be used by supporting components but can be overwritten
bias_init: Default initializer for bias parameters that should be used by supporting components but can be overwritten
+ truncate_dec_batches: whether the decoder drops batch elements as soon as they are masked at some time step.
save_num_checkpoints: save DyNet parameters for the most recent n checkpoints, useful for model averaging/ensembling
loss_comb_method: method for combining loss across batch elements ('sum' or 'avg').
commandline_args: Holds commandline arguments with which XNMT was launched
@@ -36,6 +37,7 @@ def __init__(self,
default_layer_dim: int = 512,
param_init: ParamInitializer = bare(GlorotInitializer),
bias_init: ParamInitializer = bare(ZeroInitializer),
+ truncate_dec_batches: bool = False,
save_num_checkpoints: int = 1,
loss_comb_method: str = "sum",
commandline_args=None,
@@ -47,6 +49,7 @@ def __init__(self,
self.default_layer_dim = default_layer_dim
self.param_init = param_init
self.bias_init = bias_init
+ self.truncate_dec_batches = truncate_dec_batches
self.commandline_args = commandline_args
self.save_num_checkpoints = save_num_checkpoints
self.loss_comb_method = loss_comb_method
2 changes: 2 additions & 0 deletions xnmt/expression_sequence.py
@@ -102,6 +102,8 @@ def has_tensor(self):

def dim(self):
"""
+ Return the dimension of the expression sequence.
+ Returns:
result of self.as_tensor().dim(), without explicitly constructing that tensor
"""
