Self-Attentional Acoustic Model (#416)
* add self_attentional_am

* PositionEmbedder

* add nn.py

* add unit test

* initial fixes

* fixed some type annotations

* fixed PositionEmbedder

* use add_serializable_component in SAAMMultiHeadedSelfAttention

* use add_serializable_component for TransformerEncoderLayer

* use add_serializable_component for TransformerSeqTransducer

* add some missing add_serializable_component calls

* rename to SAAMSeqTransducer

* add some comments, and import sklearn only if needed
msperber committed Jun 12, 2018
1 parent d9e227b commit 90f9d89
Showing 9 changed files with 735 additions and 4 deletions.
48 changes: 48 additions & 0 deletions test/config/self_attentional_am.yaml
@@ -0,0 +1,48 @@
speech-self-att: !Experiment
  exp_global: !ExpGlobal
    model_file: examples/output/{EXP}.mod
    log_file: examples/output/{EXP}.log
    dropout: 0.2
  preproc: !PreprocRunner
    overwrite: False
    tasks:
    - !PreprocExtract
      in_files:
      - examples/data/LDC94S13A.yaml
      out_files:
      - examples/data/LDC94S13A.h5
      specs: !MelFiltExtractor {}
  model: !DefaultTranslator
    src_reader: !H5Reader
      transpose: True
    trg_reader: !PlainTextReader
      vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
    src_embedder: !NoopEmbedder
      emb_dim: 40
    encoder: !SAAMSeqTransducer
      layers: 2
      input_dim: 40
      hidden_dim: 32
      downsample_factor: 2
      ff_hidden_dim: 32
      pos_encoding_type: embedding
      diag_gauss_mask: 3.0
      ff_lstm: True
    attender: !MlpAttender
      state_dim: 32
      hidden_dim: 32
      input_dim: 32
    trg_embedder: !SimpleWordEmbedder
      emb_dim: 32
  train: !SimpleTrainingRegimen
    src_file: examples/data/LDC94S13A.h5
    trg_file: examples/data/LDC94S13A.char
    run_for_epochs: 1
    batcher: !SrcBatcher
      batch_size: 3
      pad_src_to_multiple: 4
      src_pad_token: ~
    dev_tasks:
    - !LossEvalTask
      src_file: examples/data/LDC94S13A.h5
      ref_file: examples/data/LDC94S13A.char
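The config follows xnmt's standard experiment layout (exp_global, preproc, model, train): mel-filterbank features are extracted from the LDC94S13A sample data and fed through a SAAMSeqTransducer encoder into an attentional translator trained on character-level targets (LDC94S13A.char). A minimal sketch of how such a config is launched programmatically, mirroring the run.main call used by the new test case below (the xnmt.xnmt_run module path is an assumption; test/test_run.py simply aliases it as run):

# Sketch only, to be run from the repository root.
import xnmt.xnmt_run as run  # assumed entry-point module, aliased as `run` in test_run.py

run.main(["test/config/self_attentional_am.yaml"])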
7 changes: 5 additions & 2 deletions test/test_run.py
@@ -64,11 +64,14 @@ def test_report(self):
   def test_retrieval(self):
     run.main(["test/config/retrieval.yaml"])

+  def test_score(self):
+    run.main(["test/config/score.yaml"])
+
   def test_segmenting(self):
     run.main(["test/config/segmenting.yaml"])

-  def test_score(self):
-    run.main(["test/config/score.yaml"])
+  def test_self_attentional_am(self):
+    run.main(["test/config/self_attentional_am.yaml"])

   def test_speech(self):
     run.main(["test/config/speech.yaml"])
3 changes: 2 additions & 1 deletion xnmt/__init__.py
@@ -45,7 +45,8 @@
 import xnmt.retriever
 import xnmt.segmenting_composer
 import xnmt.segmenting_encoder
-import xnmt.specialized_encoders
+import xnmt.specialized_encoders.tilburg_harwath
+import xnmt.specialized_encoders.self_attentional_am
 import xnmt.training_regimen
 import xnmt.training_task
 import xnmt.transformer
27 changes: 26 additions & 1 deletion xnmt/embedder.py
@@ -7,7 +7,7 @@
 from xnmt.expression_sequence import ExpressionSequence, LazyNumpyExpressionSequence
 from xnmt.linear import Linear
 from xnmt.param_collection import ParamManager
-from xnmt.param_init import GlorotInitializer, ZeroInitializer
+from xnmt.param_init import GlorotInitializer, ZeroInitializer, ParamInitializer
 from xnmt.persistence import serializable_init, Serializable, Ref, Path, bare

 class Embedder(object):
@@ -277,6 +277,31 @@ def embed(self, x):
       ret = dy.noise(ret, self.weight_noise)
     return ret

+class PositionEmbedder(Embedder, Serializable):
+
+  yaml_tag = '!PositionEmbedder'
+
+  @serializable_init
+  @register_xnmt_handler
+  def __init__(self, max_pos: int, emb_dim: int = Ref("exp_global.default_layer_dim"),
+               param_init: ParamInitializer = Ref("exp_global.param_init", default=bare(GlorotInitializer))):
+    """
+    max_pos: largest embedded position
+    emb_dim: embedding size
+    param_init: how to initialize embedding matrix
+    """
+    self.max_pos = max_pos
+    self.emb_dim = emb_dim
+    param_collection = ParamManager.my_params(self)
+    param_init = param_init
+    dim = (self.emb_dim, max_pos)
+    self.embeddings = param_collection.add_parameters(dim, init=param_init.initializer(dim, is_lookup=True))
+
+  def embed(self, word): raise NotImplementedError("Position-embedding for individual words not implemented yet.")
+  def embed_sent(self, sent_len):
+    embeddings = dy.strided_select(dy.parameter(self.embeddings), [1,1], [0,0], [self.emb_dim, sent_len])
+    return ExpressionSequence(expr_tensor=embeddings, mask=None)
+
 class NoopEmbedder(Embedder, Serializable):
   """
   This embedder performs no lookups but only passes through the inputs.
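A hypothetical usage sketch for the new PositionEmbedder (illustrative values, not part of the commit; xnmt's ParamManager is assumed to have been initialized, as it is during normal experiment setup):

import dynet as dy
from xnmt.embedder import PositionEmbedder
from xnmt.param_init import GlorotInitializer

# Illustrative sizes; param_init is passed explicitly to avoid relying on exp_global refs.
pos_embedder = PositionEmbedder(max_pos=500, emb_dim=32, param_init=GlorotInitializer())

dy.renew_cg()
# One embedding column per position 0..119, backed by a single 32 x 120 tensor.
pos_embs = pos_embedder.embed_sent(sent_len=120)
print(pos_embs.dim())  # ((32, 120), 1)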
12 changes: 12 additions & 0 deletions xnmt/expression_sequence.py
@@ -100,6 +100,18 @@ def has_tensor(self):
     """
     return self.expr_tensor is not None

+  def dim(self):
+    """
+    Returns:
+      result of self.as_tensor().dim(), without explicitly constructing that tensor
+    """
+    if self.has_tensor(): return self.as_tensor().dim()
+    else:
+      if self.tensor_transposed:
+        return tuple([len(self)] + list(self[0].dim()[0])), self[0].dim()[1]
+      else:
+        return tuple(list(self[0].dim()[0]) + [len(self)]), self[0].dim()[1]
+
 class LazyNumpyExpressionSequence(ExpressionSequence):
   """
   This is initialized via numpy arrays, and dynet expressions are only created
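For reference, the new ExpressionSequence.dim() mirrors DyNet's Expression.dim() convention of (shape tuple, batch size) while avoiding construction of the concatenated tensor. A small illustration of that convention (not part of the commit):

import numpy as np
import dynet as dy

dy.renew_cg()
x = dy.inputTensor(np.zeros((40, 7)))  # a single, unbatched 40x7 expression
print(x.dim())                         # ((40, 7), 1): shape tuple, then batch size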
23 changes: 23 additions & 0 deletions xnmt/norm.py
@@ -0,0 +1,23 @@
"""
This module holds normalizers for neural networks. Currently implemented is layer norm, later batch norm etc. may be added.
"""

import dynet as dy

from xnmt import param_collection
from xnmt.persistence import Serializable, serializable_init


class LayerNorm(Serializable):
yaml_tag = "!LayerNorm"

@serializable_init
def __init__(self, d_hid):
subcol = param_collection.ParamManager.my_params(self)
self.p_g = subcol.add_parameters(dim=d_hid, init=dy.ConstInitializer(1.0))
self.p_b = subcol.add_parameters(dim=d_hid, init=dy.ConstInitializer(0.0))

def __call__(self, x):
g = dy.parameter(self.p_g)
b = dy.parameter(self.p_b)
return dy.layer_norm(x, g, b)
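A hypothetical LayerNorm usage sketch (not part of the commit; assumes ParamManager has been initialized during experiment setup):

import dynet as dy
from xnmt.norm import LayerNorm

ln = LayerNorm(d_hid=32)        # registers gain and bias parameters of size 32

dy.renew_cg()
h = dy.inputVector([0.1] * 32)
h_normed = ln(h)                # applies dy.layer_norm(h, g, b) with the learned g, b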
