"""
This module implements self-attentional acoustic models as described in the following paper:
Self-Attentional Acoustic Models
Matthias Sperber, Jan Niehues, Graham Neubig, Sebastian Stüker, Alex Waibel
Interspeech 2018
https://arxiv.org/abs/1803.09519
The main class to be aware of is :class:`SAAMSeqTransducer`.
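Illustrative usage (a minimal sketch, not taken from the paper or the xnmt docs; it assumes the
enclosing xnmt experiment has already set up ``ParamManager`` and that 40-dimensional speech
features are used)::

  encoder = SAAMSeqTransducer(input_dim=40, layers=2, hidden_dim=512, head_count=8,
                              downsample_factor=2, ff_lstm=True, diag_gauss_mask=1.0)
  # encoded = encoder.transduce(feature_seq)  # feature_seq: an ExpressionSequence of acoustic features

In YAML experiment configurations, the same component is referenced via its ``!SAAMSeqTransducer``
tag, using the constructor arguments as keys.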
"""
import typing
import logging
import numbers
yaml_logger = logging.getLogger('yaml')
from collections.abc import Sequence
import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.stats import entropy
import numpy as np
import xnmt.tensor_tools as tt
import xnmt.param_initializers
from xnmt import norms, events, param_collections
from xnmt.modelparts import transforms
from xnmt.transducers import recurrent, base as transducers
from xnmt.expression_seqs import ExpressionSequence
from xnmt.persistence import Serializable, serializable_init, Ref, bare
from xnmt.modelparts import embedders
if xnmt.backend_dynet:
import dynet as dy
LOG_ATTENTION = False
@xnmt.require_dynet
class SAAMTimeDistributed(object):
"""
A Callable that puts the time-dimension of an input expression into the batch dimension via a reshape.
"""
def __call__(self, x: ExpressionSequence) -> tt.Tensor:
"""
Move the time-dimension of an input expression into the batch dimension via a reshape.
Args:
x: expression of dimensions ((hidden, timesteps), batch_size)
Returns:
expression of dimensions ((hidden,), timesteps*batch_size)
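Illustrative example (dimensions only, following DyNet's ``dim()`` convention): an input with
``x.as_tensor().dim() == ((512, 7), 32)``, i.e. hidden size 512, 7 timesteps, batch size 32,
is returned as a tensor with ``dim() == ((512,), 224)``, since 7 * 32 = 224.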
"""
batch_size = x[0].dim()[1]
model_dim = x[0].dim()[0][0]
seq_len = x.sent_len()
total_words = seq_len * batch_size
input_tensor = x.as_tensor()
return dy.reshape(input_tensor, (model_dim,), batch_size=total_words)
@xnmt.require_dynet
class SAAMPositionwiseFeedForward(Serializable):
"""
Interleaved feed-forward components of the transformer, computed as layer_norm(dropout(linear(relu(linear))) + x).
Args:
input_dim: the size of the input to the first layer of the FFN.
hidden_dim: the hidden layer size of the FFN (output size of the first linear layer, input size of the second).
nonlinearity: non-linearity, a DyNet unary operation ("rectify" according to the transformer paper)
param_init: initializer for the weight matrices of the two linear transforms
bias_init: initializer for the bias vectors of the two linear transforms
linear_transforms: list with 2 items corresponding to the first and second linear transformations to be used
(pass ``None`` to create automatically)
layer_norm: layer norm object to be applied (pass ``None`` to create automatically)
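Illustrative sketch of the computation performed by ``__call__`` (assumes the input has already
been time-distributed to shape ``((input_dim,), timesteps*batch)`` and that ``ParamManager`` is
set up)::

  ff = SAAMPositionwiseFeedForward(input_dim=512, hidden_dim=2048)
  # y = layer_norm(dropout(w_2(rectify(w_1(x)))) + x), applied position-wise
  # y = ff(x, p=0.1)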
"""
yaml_tag = "!SAAMPositionwiseFeedForward"
@serializable_init
def __init__(self, input_dim: int, hidden_dim: int, nonlinearity: str = "rectify",
             param_init: xnmt.param_initializers.ParamInitializer = xnmt.param_initializers.GlorotInitializer(),
             bias_init: xnmt.param_initializers.ParamInitializer = xnmt.param_initializers.ZeroInitializer(),
             linear_transforms: typing.Optional[typing.Sequence[transforms.Linear]] = None,
             layer_norm: typing.Optional[norms.LayerNorm] = None) -> None:
  w_12 = self.add_serializable_component("linear_transforms", linear_transforms,
                                         lambda: [transforms.Linear(input_dim, hidden_dim,
                                                                    param_init=param_init, bias_init=bias_init),
                                                  transforms.Linear(hidden_dim, input_dim,
                                                                    param_init=param_init, bias_init=bias_init)])
self.w_1 = w_12[0]
self.w_2 = w_12[1]
self.layer_norm = self.add_serializable_component("layer_norm", layer_norm, lambda: norms.LayerNorm(input_dim))
self.nonlinearity = getattr(dy, nonlinearity)
def __call__(self, x, p):
residual = x
output = self.w_2.transform(self.nonlinearity(self.w_1.transform(x)))
if p > 0.0:
output = dy.dropout(output, p)
return self.layer_norm.transform(output + residual)
@xnmt.require_dynet
class SAAMMultiHeadedSelfAttention(Serializable):
"""
Args:
head_count: number of self-att heads
model_dim: model dimension
downsample_factor: downsampling factor (>=1)
input_dim: input dimension
ignore_masks: don't apply any masking
plot_attention: None or path to directory to write plots to
diag_gauss_mask: False to disable, otherwise a float denoting the (initial) std of the learnable soft Gaussian mask;
                 the string "rand" initializes each head's std randomly
square_mask_std: whether to square the std parameter to facilitate training
cross_pos_encoding_type: None or 'embedding'
kq_pos_encoding_type: None or 'embedding'
kq_pos_encoding_size: size of the position embeddings concatenated to keys/queries
max_len: max sequence length (used to determine positional embedding size)
param_init: initializer for weight matrices
bias_init: initializer for bias vectors
linear_kvq: linear transform(s) for keys/values/queries (pass ``None`` to create automatically)
kq_positional_embedder: position embedder for keys/queries (pass ``None`` to create automatically)
layer_norm: layer norm object to be applied (pass ``None`` to create automatically)
res_shortcut: linear transform for the residual shortcut when dimensions differ (pass ``None`` to create automatically)
desc: layer description, used to label plots and logged attention statistics (useful if plot_attention is given)
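Illustrative instantiation (a sketch only; assumes ``ParamManager`` has been set up by the
enclosing experiment)::

  self_att = SAAMMultiHeadedSelfAttention(head_count=8, model_dim=512,
                                          downsample_factor=2, diag_gauss_mask=1.0)
  # out = self_att(x, att_mask=None, batch_mask=None, p=0.1)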
"""
yaml_tag = "!SAAMMultiHeadedSelfAttention"
@events.register_xnmt_handler
@serializable_init
def __init__(self, head_count: int, model_dim: int, downsample_factor: int = 1, input_dim: typing.Optional[int] = None,
ignore_masks: bool = False, plot_attention: typing.Optional[str] = None,
diag_gauss_mask: typing.Union[bool, numbers.Real] = False,
square_mask_std: bool = True, cross_pos_encoding_type: typing.Optional[str] = None,
kq_pos_encoding_type: typing.Optional[str] = None, kq_pos_encoding_size: int = 40, max_len: int = 1500,
param_init: xnmt.param_initializers.ParamInitializer = xnmt.param_initializers.GlorotInitializer(),
bias_init: xnmt.param_initializers.ParamInitializer = xnmt.param_initializers.ZeroInitializer(),
linear_kvq = None, kq_positional_embedder = None, layer_norm = None, res_shortcut = None,
desc: typing.Any = None) -> None:
if input_dim is None: input_dim = model_dim
self.input_dim = input_dim
assert model_dim % head_count == 0
self.dim_per_head = model_dim // head_count
self.model_dim = model_dim
self.head_count = head_count
assert downsample_factor >= 1
self.downsample_factor = downsample_factor
self.plot_attention = plot_attention
self.plot_attention_counter = 0
self.desc = desc
self.ignore_masks = ignore_masks
self.diag_gauss_mask = diag_gauss_mask
self.square_mask_std = square_mask_std
self.kq_pos_encoding_type = kq_pos_encoding_type
self.kq_pos_encoding_size = kq_pos_encoding_size
self.max_len = max_len
subcol = param_collections.ParamManager.my_params(self)
if self.kq_pos_encoding_type is None:
self.linear_kvq = self.add_serializable_component("linear_kvq", linear_kvq,
lambda: transforms.Linear(input_dim * downsample_factor,
head_count * self.dim_per_head * 3,
param_init=param_init,
bias_init=bias_init))
else:
self.linear_kq, self.linear_v = \
self.add_serializable_component("linear_kvq",
linear_kvq,
lambda: [
transforms.Linear(input_dim * downsample_factor + self.kq_pos_encoding_size,
head_count * self.dim_per_head * 2, param_init=param_init,
bias_init=bias_init),
transforms.Linear(input_dim * downsample_factor, head_count * self.dim_per_head,
param_init=param_init, bias_init=bias_init)])
assert self.kq_pos_encoding_type == "embedding"
self.kq_positional_embedder = self.add_serializable_component("kq_positional_embedder",
kq_positional_embedder,
lambda: embedders.PositionEmbedder(
max_pos=self.max_len,
emb_dim=self.kq_pos_encoding_size,
param_init=param_init))
if self.diag_gauss_mask:
if self.diag_gauss_mask == "rand":
rand_init = np.exp((np.random.random(size=(self.head_count,))) * math.log(1000))
self.diag_gauss_mask_sigma = subcol.add_parameters(dim=(1, 1, self.head_count),
init=dy.NumpyInitializer(rand_init))
else:
self.diag_gauss_mask_sigma = subcol.add_parameters(dim=(1, 1, self.head_count),
init=dy.ConstInitializer(self.diag_gauss_mask))
self.layer_norm = self.add_serializable_component("layer_norm", layer_norm, lambda: norms.LayerNorm(model_dim))
if model_dim != input_dim * downsample_factor:
self.res_shortcut = self.add_serializable_component("res_shortcut", res_shortcut,
lambda: transforms.Linear(input_dim * downsample_factor,
model_dim,
param_init=param_init,
bias_init=bias_init))
self.cross_pos_encoding_type = cross_pos_encoding_type
if cross_pos_encoding_type == "embedding":
self.cross_pos_emb_p1 = subcol.add_parameters(dim=(self.max_len, self.dim_per_head, self.head_count),
init=dy.NormalInitializer(mean=1.0, var=0.001))
self.cross_pos_emb_p2 = subcol.add_parameters(dim=(self.max_len, self.dim_per_head, self.head_count),
init=dy.NormalInitializer(mean=1.0, var=0.001))
elif cross_pos_encoding_type is not None:
raise NotImplementedError()
def plot_att_mat(self, mat, filename, dpi=1200):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.matshow(mat)
ax.set_aspect('auto')
fig.savefig(filename, dpi=dpi)
fig.clf()
plt.close('all')
def shape_projection(self, x, batch_size):
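"""
Reshape a time-distributed projection of dimensions ((head_count*dim_per_head,), seq_len*batch)
into ((seq_len, dim_per_head), batch*head_count), so that each attention head becomes a
separate batch element.
"""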
total_words = x.dim()[1]
seq_len = total_words // batch_size
out = dy.reshape(x, (self.model_dim, seq_len), batch_size=batch_size)
out = dy.transpose(out)
return dy.reshape(out, (seq_len, self.dim_per_head), batch_size=batch_size * self.head_count)
def __call__(self, x: tt.Tensor, att_mask: np.ndarray, batch_mask: np.ndarray, p: numbers.Real):
"""
Args:
  x: expression of dimensions ((input_dim, time), batch)
  att_mask: numpy array of dimensions (time, time); pre-transposed
  batch_mask: numpy array of dimensions (batch, time)
  p: dropout probability

Returns:
  expression of dimensions ((model_dim,), out_time*batch), where out_time = time // downsample_factor
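Illustrative shapes: with ``downsample_factor=2``, an input of dimensions ((input_dim, 10), batch)
is internally reshaped to ((2*input_dim, 5), batch) before the key/value/query projections, and
the returned tensor has dimensions ((model_dim,), 5*batch).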
"""
sent_len = x.dim()[0][1]
batch_size = x[0].dim()[1]
if self.downsample_factor > 1:
if sent_len % self.downsample_factor != 0:
raise ValueError(
"For 'reshape' downsampling, sequence lengths must be multiples of the downsampling factor. "
"Configure batcher accordingly.")
if batch_mask is not None: batch_mask = batch_mask[:, ::self.downsample_factor]
sent_len_out = sent_len // self.downsample_factor
sent_len = sent_len_out
out_mask = x.mask
if self.downsample_factor > 1 and out_mask is not None:
out_mask = out_mask.lin_subsampled(reduce_factor=self.downsample_factor)
x = ExpressionSequence(expr_tensor=dy.reshape(x.as_tensor(), (
x.dim()[0][0] * self.downsample_factor, x.dim()[0][1] // self.downsample_factor), batch_size=batch_size),
mask=out_mask)
residual = SAAMTimeDistributed()(x)
else:
residual = SAAMTimeDistributed()(x)
sent_len_out = sent_len
if self.model_dim != self.input_dim * self.downsample_factor:
residual = self.res_shortcut.transform(residual)
# Concatenate all the words together for doing vectorized affine transform
if self.kq_pos_encoding_type is None:
kvq_lin = self.linear_kvq.transform(SAAMTimeDistributed()(x))
key_up = self.shape_projection(dy.pick_range(kvq_lin, 0, self.head_count * self.dim_per_head), batch_size)
value_up = self.shape_projection(
dy.pick_range(kvq_lin, self.head_count * self.dim_per_head, 2 * self.head_count * self.dim_per_head),
batch_size)
query_up = self.shape_projection(
dy.pick_range(kvq_lin, 2 * self.head_count * self.dim_per_head, 3 * self.head_count * self.dim_per_head),
batch_size)
else:
assert self.kq_pos_encoding_type == "embedding"
encoding = self.kq_positional_embedder.embed_sent(sent_len).as_tensor()
kq_lin = self.linear_kq.transform(
SAAMTimeDistributed()(
ExpressionSequence(expr_tensor=dy.concatenate([x.as_tensor(), encoding]))))
key_up = self.shape_projection(dy.pick_range(kq_lin, 0, self.head_count * self.dim_per_head), batch_size)
query_up = self.shape_projection(
dy.pick_range(kq_lin, self.head_count * self.dim_per_head, 2 * self.head_count * self.dim_per_head), batch_size)
v_lin = self.linear_v.transform(SAAMTimeDistributed()(x))
value_up = self.shape_projection(v_lin, batch_size)
if self.cross_pos_encoding_type:
assert self.cross_pos_encoding_type == "embedding"
emb1 = dy.pick_range(dy.parameter(self.cross_pos_emb_p1), 0, sent_len)
emb2 = dy.pick_range(dy.parameter(self.cross_pos_emb_p2), 0, sent_len)
key_up = dy.reshape(key_up, (sent_len, self.dim_per_head, self.head_count), batch_size=batch_size)
key_up = dy.concatenate_cols([dy.cmult(key_up, emb1), dy.cmult(key_up, emb2)])
key_up = dy.reshape(key_up, (sent_len, self.dim_per_head * 2), batch_size=self.head_count * batch_size)
query_up = dy.reshape(query_up, (sent_len, self.dim_per_head, self.head_count), batch_size=batch_size)
query_up = dy.concatenate_cols([dy.cmult(query_up, emb2), dy.cmult(query_up, -emb1)])
query_up = dy.reshape(query_up, (sent_len, self.dim_per_head * 2), batch_size=self.head_count * batch_size)
scaled = query_up * dy.transpose(
key_up / math.sqrt(self.dim_per_head)) # scale before the matrix multiplication to save memory
# Apply Mask here
if not self.ignore_masks:
if att_mask is not None:
att_mask_inp = att_mask * -100.0
if self.downsample_factor > 1:
att_mask_inp = att_mask_inp[::self.downsample_factor, ::self.downsample_factor]
scaled += dy.inputTensor(att_mask_inp)
if batch_mask is not None:
# reshape (batch, time) -> (time, head_count*batch), then *-100
inp = np.resize(np.broadcast_to(batch_mask.T[:, np.newaxis, :],
(sent_len, self.head_count, batch_size)),
(1, sent_len, self.head_count * batch_size)) \
* -100
mask_expr = dy.inputTensor(inp, batched=True)
scaled += mask_expr
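# Soft diagonal Gaussian mask: add -(i-j)^2 / (2*sigma_h^2) to the pre-softmax scores,
# biasing each head h toward the diagonal with a learnable (per-head) std sigma_h.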
if self.diag_gauss_mask:
diag_growing = np.zeros((sent_len, sent_len, self.head_count))
for i in range(sent_len):
for j in range(sent_len):
diag_growing[i, j, :] = -(i - j) ** 2 / 2.0
e_diag_gauss_mask = dy.inputTensor(diag_growing)
e_sigma = dy.parameter(self.diag_gauss_mask_sigma)
if self.square_mask_std:
e_sigma = dy.square(e_sigma)
e_sigma_sq_inv = dy.cdiv(dy.ones(e_sigma.dim()[0], batch_size=batch_size), dy.square(e_sigma))
e_diag_gauss_mask_final = dy.cmult(e_diag_gauss_mask, e_sigma_sq_inv)
scaled += dy.reshape(e_diag_gauss_mask_final, (sent_len, sent_len), batch_size=batch_size * self.head_count)
# Computing Softmax here.
attn = dy.softmax(scaled, d=1)
if LOG_ATTENTION:
yaml_logger.info({"key": "selfatt_mat_ax0", "value": np.average(attn.value(), axis=0).dumps(), "desc": self.desc})
yaml_logger.info({"key": "selfatt_mat_ax1", "value": np.average(attn.value(), axis=1).dumps(), "desc": self.desc})
yaml_logger.info({"key": "selfatt_mat_ax0_ent", "value": entropy(attn.value()).dumps(), "desc": self.desc})
yaml_logger.info(
{"key": "selfatt_mat_ax1_ent", "value": entropy(attn.value().transpose()).dumps(), "desc": self.desc})
self.select_att_head = None  # debugging option: set to a head index to zero out all other heads
if self.select_att_head is not None:
attn = dy.reshape(attn, (sent_len, sent_len, self.head_count), batch_size=batch_size)
sel_mask = np.zeros((1, 1, self.head_count))
sel_mask[0, 0, self.select_att_head] = 1.0
attn = dy.cmult(attn, dy.inputTensor(sel_mask))
attn = dy.reshape(attn, (sent_len, sent_len), batch_size=self.head_count * batch_size)
# Applying dropout to attention
if p > 0.0:
drop_attn = dy.dropout(attn, p)
else:
drop_attn = attn
# Computing weighted attention score
attn_prod = drop_attn * value_up
# Reshaping the attn_prod to input query dimensions
out = dy.reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count), batch_size=batch_size)
out = dy.transpose(out)
out = dy.reshape(out, (self.model_dim,), batch_size=batch_size * sent_len_out)
# out = dy.reshape_transpose_reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count), (self.model_dim,), pre_batch_size=batch_size, post_batch_size=batch_size*sent_len_out)
if self.plot_attention:
from sklearn.metrics.pairwise import cosine_similarity
assert batch_size == 1
mats = []
for i in range(attn.dim()[1]):
mats.append(dy.pick_batch_elem(attn, i).npvalue())
self.plot_att_mat(mats[-1],
"{}.sent_{}.head_{}.png".format(self.plot_attention, self.plot_attention_counter, i),
300)
avg_mat = np.average(mats, axis=0)
self.plot_att_mat(avg_mat,
"{}.sent_{}.head_avg.png".format(self.plot_attention, self.plot_attention_counter),
300)
cosim_before = cosine_similarity(x.as_tensor().npvalue().T)
self.plot_att_mat(cosim_before,
"{}.sent_{}.cosim_before.png".format(self.plot_attention, self.plot_attention_counter),
600)
cosim_after = cosine_similarity(out.npvalue().T)
self.plot_att_mat(cosim_after,
"{}.sent_{}.cosim_after.png".format(self.plot_attention, self.plot_attention_counter),
600)
self.plot_attention_counter += 1
# Adding dropout and layer normalization
if p > 0.0:
res = dy.dropout(out, p) + residual
else:
res = out + residual
ret = self.layer_norm.transform(res)
return ret
@events.handle_xnmt_event
def on_new_epoch(self, training_task, num_sents):
  if self.diag_gauss_mask:
    yaml_logger.info(
      {"key": "self_att_mask_var", "val": [float(x) for x in list(self.diag_gauss_mask_sigma.as_array().flat)],
       "desc": self.desc})
@xnmt.require_dynet
class TransformerEncoderLayer(Serializable):
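"""
A single encoder layer of the self-attentional acoustic model: multi-headed self-attention
(:class:`SAAMMultiHeadedSelfAttention`), optionally downsampling its input, followed by either a
position-wise feed-forward block (:class:`SAAMPositionwiseFeedForward`) or an interleaved BiLSTM
(``ff_lstm=True``). Arguments mirror those of :class:`SAAMSeqTransducer`, which stacks these layers.
"""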
yaml_tag = "!TransformerEncoderLayer"
@serializable_init
def __init__(self, hidden_dim, head_count=8, ff_hidden_dim=2048, downsample_factor=1,
input_dim=None, diagonal_mask_width=None, ignore_masks=False,
plot_attention=None, nonlinearity="rectify", diag_gauss_mask=False,
square_mask_std=True, cross_pos_encoding_type=None,
ff_lstm=False, kq_pos_encoding_type=None, kq_pos_encoding_size=40, max_len=1500,
param_init=Ref("exp_global.param_init", default=bare(xnmt.param_initializers.GlorotInitializer)),
bias_init=Ref("exp_global.bias_init", default=bare(xnmt.param_initializers.ZeroInitializer)),
dropout=None, self_attn=None, feed_forward=None, desc=None):
self.self_attn = self.add_serializable_component("self_attn",
self_attn,
lambda: SAAMMultiHeadedSelfAttention(head_count, hidden_dim,
downsample_factor,
input_dim=input_dim,
ignore_masks=ignore_masks,
plot_attention=plot_attention,
diag_gauss_mask=diag_gauss_mask,
square_mask_std=square_mask_std,
param_init=param_init,
bias_init=bias_init,
cross_pos_encoding_type=cross_pos_encoding_type,
kq_pos_encoding_type=kq_pos_encoding_type,
kq_pos_encoding_size=kq_pos_encoding_size,
max_len=max_len,
desc=desc))
self.ff_lstm = ff_lstm
if ff_lstm:
self.feed_forward = self.add_serializable_component("feed_forward",
feed_forward,
lambda: recurrent.BiLSTMSeqTransducer(layers=1,
input_dim=hidden_dim,
hidden_dim=hidden_dim,
var_dropout=dropout,
param_init=param_init,
bias_init=bias_init))
else:
self.feed_forward = self.add_serializable_component("feed_forward",
feed_forward,
lambda: SAAMPositionwiseFeedForward(hidden_dim, ff_hidden_dim,
nonlinearity=nonlinearity,
param_init=param_init))
self.head_count = head_count
self.downsample_factor = downsample_factor
self.diagonal_mask_width = diagonal_mask_width
if diagonal_mask_width: assert diagonal_mask_width % 2 == 1
self.dropout = dropout
def set_dropout(self, dropout):
self.dropout = dropout
def transduce(self, x: ExpressionSequence) -> ExpressionSequence:
seq_len = x.sent_len()
batch_size = x[0].dim()[1]
att_mask = None
if self.diagonal_mask_width is not None:
  # hard local mask: positions outside a band of width diagonal_mask_width around the diagonal
  # are marked with 1 and will be suppressed (scaled by -100) inside the self-attention
  att_mask = np.ones((seq_len, seq_len))
  for i in range(seq_len):
    from_i = max(0, i - self.diagonal_mask_width // 2)
    to_i = min(seq_len, i + self.diagonal_mask_width // 2 + 1)
    att_mask[from_i:to_i, from_i:to_i] = 0.0
mid = self.self_attn(x=x, att_mask=att_mask, batch_mask=x.mask.np_arr if x.mask else None, p=self.dropout)
if self.downsample_factor > 1:
seq_len = int(math.ceil(seq_len / float(self.downsample_factor)))
hidden_dim = mid.dim()[0][0]
out_mask = x.mask
if self.downsample_factor > 1 and out_mask is not None:
out_mask = out_mask.lin_subsampled(reduce_factor=self.downsample_factor)
if self.ff_lstm:
mid_re = dy.reshape(mid, (hidden_dim, seq_len), batch_size=batch_size)
out = self.feed_forward.transduce(ExpressionSequence(expr_tensor=mid_re, mask=out_mask))
out = dy.reshape(out.as_tensor(), (hidden_dim,), batch_size=seq_len * batch_size)
else:
out = self.feed_forward(mid, p=self.dropout)
self._recent_output = out
return ExpressionSequence(
expr_tensor=dy.reshape(out, (out.dim()[0][0], seq_len), batch_size=batch_size),
mask=out_mask)
@xnmt.require_dynet
class SAAMSeqTransducer(transducers.SeqTransducer, Serializable):
"""
Args:
input_dim: input dimension
layers: number of layers
hidden_dim: hidden dimension
head_count: number of self-attention heads
ff_hidden_dim: hidden dimension of the interleaved feed-forward layers
dropout: dropout probability
downsample_factor: downsampling factor (>=1)
diagonal_mask_width: if given, apply hard masking of this width
ignore_masks: if True, don't apply any masking
plot_attention: if given, plot self-attention matrices for each layer
nonlinearity: nonlinearity to apply in FF module (string that corresponds to a DyNet unary operation)
pos_encoding_type: None, 'trigonometric', or 'embedding'
pos_encoding_combine: 'add' or 'concat'
pos_encoding_size: if 'add', must equal input_dim, otherwise can be chosen freely
max_len: needed to initialize pos embeddings
diag_gauss_mask: False to disable, otherwise a float denoting the initial std of the learnable soft Gaussian mask
square_mask_std: whether to square the std parameter to facilitate training
cross_pos_encoding_type: None or 'embedding'
ff_lstm: if True, use interleaved LSTMs, otherwise add depth using position-wise feed-forward components
kq_pos_encoding_type: None or 'embedding'
kq_pos_encoding_size: size of position embeddings for keys/queries
param_init: initializer for weight matrices
bias_init: initializer for bias vectors
positional_embedder: position embedder (pass ``None`` to create automatically)
modules: the stack of :class:`TransformerEncoderLayer` modules (pass ``None`` to create automatically)
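Illustrative example of the positional-encoding options: with ``pos_encoding_type='embedding'``,
``pos_encoding_combine='concat'`` and ``pos_encoding_size=40``, a 40-dimensional input is extended
to an effective ``input_dim`` of 80 before the first layer (see the constructor, which adds
``pos_encoding_size`` to ``input_dim`` in that case).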
"""
yaml_tag = "!SAAMSeqTransducer"
@serializable_init
@events.register_xnmt_handler
def __init__(self, input_dim:int=512, layers:int=1, hidden_dim:int=512, head_count:int=8, ff_hidden_dim:int=2048,
dropout:numbers.Real=Ref("exp_global.dropout", default=0.0), downsample_factor:int=1, diagonal_mask_width:typing.Optional[int]=None,
ignore_masks:bool=False, plot_attention:typing.Optional[str]=None, nonlinearity:str="rectify", pos_encoding_type:typing.Optional[str]=None,
pos_encoding_combine:str="concat", pos_encoding_size:int=40, max_len:int=1500, diag_gauss_mask:typing.Union[bool,numbers.Real]=False,
square_mask_std:bool=True, cross_pos_encoding_type:typing.Optional[str]=None, ff_lstm:bool=False, kq_pos_encoding_type:typing.Optional[str]=None,
kq_pos_encoding_size:int=40,
param_init:xnmt.param_initializers.ParamInitializer=Ref("exp_global.param_init", default=bare(xnmt.param_initializers.GlorotInitializer)),
bias_init:xnmt.param_initializers.ParamInitializer=Ref("exp_global.bias_init", default=bare(xnmt.param_initializers.ZeroInitializer)),
positional_embedder=None, modules=None):
self.input_dim = input_dim = (
input_dim + (pos_encoding_size if (pos_encoding_type and pos_encoding_combine == "concat") else 0))
self.hidden_dim = hidden_dim
self.dropout = dropout
self.layers = layers
self.pos_encoding_type = pos_encoding_type
self.pos_encoding_combine = pos_encoding_combine
self.pos_encoding_size = pos_encoding_size
self.max_len = max_len
self.position_encoding_block = None
if self.pos_encoding_type == "embedding":
self.positional_embedder = \
self.add_serializable_component("positional_embedder",
positional_embedder,
lambda: embedders.PositionEmbedder(max_pos=self.max_len,
emb_dim=input_dim if self.pos_encoding_combine == "add" else self.pos_encoding_size))
self.modules = self.add_serializable_component("modules", modules,
lambda: self.make_modules(layers=layers,
plot_attention=plot_attention,
hidden_dim=hidden_dim,
downsample_factor=downsample_factor,
input_dim=input_dim,
head_count=head_count,
ff_hidden_dim=ff_hidden_dim,
diagonal_mask_width=diagonal_mask_width,
ignore_masks=ignore_masks,
nonlinearity=nonlinearity,
diag_gauss_mask=diag_gauss_mask,
square_mask_std=square_mask_std,
cross_pos_encoding_type=cross_pos_encoding_type,
ff_lstm=ff_lstm,
kq_pos_encoding_type=kq_pos_encoding_type,
kq_pos_encoding_size=kq_pos_encoding_size,
dropout=dropout,
param_init=param_init,
bias_init=bias_init))
def make_modules(self, layers, plot_attention, hidden_dim, downsample_factor, input_dim, head_count, ff_hidden_dim,
                 diagonal_mask_width, ignore_masks, nonlinearity, diag_gauss_mask, square_mask_std,
                 cross_pos_encoding_type, ff_lstm, kq_pos_encoding_type, kq_pos_encoding_size, dropout,
                 param_init, bias_init):
modules = []
for layer_i in range(layers):
if plot_attention is not None:
plot_attention_layer = f"{plot_attention}.layer_{layer_i}"
else:
plot_attention_layer = None
modules.append(TransformerEncoderLayer(hidden_dim,
downsample_factor=downsample_factor,
input_dim=input_dim if layer_i == 0 else hidden_dim,
head_count=head_count, ff_hidden_dim=ff_hidden_dim,
diagonal_mask_width=diagonal_mask_width,
ignore_masks=ignore_masks,
plot_attention=plot_attention_layer,
nonlinearity=nonlinearity,
diag_gauss_mask=diag_gauss_mask,
square_mask_std=square_mask_std,
cross_pos_encoding_type=cross_pos_encoding_type,
ff_lstm=ff_lstm,
max_len=self.max_len,
kq_pos_encoding_type=kq_pos_encoding_type,
kq_pos_encoding_size=kq_pos_encoding_size,
dropout=dropout,
param_init=param_init[layer_i] if isinstance(param_init,
Sequence) else param_init,
bias_init=bias_init[layer_i] if isinstance(bias_init,
Sequence) else bias_init,
desc=f"layer_{layer_i}"))
return modules
def transduce(self, sent: ExpressionSequence) -> ExpressionSequence:
if self.pos_encoding_type == "trigonometric":
if self.position_encoding_block is None or self.position_encoding_block.shape[2] < sent.sent_len():
self.initialize_position_encoding(int(sent.sent_len() * 1.2),
self.input_dim if self.pos_encoding_combine == "add" else self.pos_encoding_size)
encoding = dy.inputTensor(self.position_encoding_block[0, :, :sent.sent_len()])
elif self.pos_encoding_type == "embedding":
encoding = self.positional_embedder.embed_sent(sent.sent_len()).as_tensor()
if self.pos_encoding_type:
if self.pos_encoding_combine == "add":
sent = ExpressionSequence(expr_tensor=sent.as_tensor() + encoding, mask=sent.mask)
else: # concat
sent = ExpressionSequence(expr_tensor=dy.concatenate([sent.as_tensor(), encoding]),
mask=sent.mask)
elif self.pos_encoding_type:
raise ValueError(f"unknown encoding type {self.pos_encoding_type}")
for module in self.modules:
enc_sent = module.transduce(sent)
sent = enc_sent
self._final_states = [transducers.FinalTransducerState(sent[-1])]
return sent
def get_final_states(self):
return self._final_states
@events.handle_xnmt_event
def on_set_train(self, val):
for module in self.modules:
module.set_dropout(self.dropout if val else 0.0)
def initialize_position_encoding(self, length, n_units):
# Implementation in the Google tensor2tensor repo
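# Standard transformer sinusoidal encodings: for position pos and timescale index i,
# the first half of the channels holds sin(pos * inv_timescale_i) and the second half
# cos(pos * inv_timescale_i), with inverse timescales spaced geometrically between 1 and 1/10000
# (half-split rather than interleaved, following tensor2tensor).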
channels = n_units
position = np.arange(length, dtype='f')
num_timescales = channels // 2
log_timescale_increment = (np.log(10000. / 1.) / (float(num_timescales) - 1))
inv_timescales = 1. * np.exp(np.arange(num_timescales).astype('f') * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.reshape(signal, [1, length, channels])
self.position_encoding_block = np.transpose(signal, (0, 2, 1))