Language Model (#452)
* first working LM

* fix unit tests

* fix issues related to batch loss evaluation

* fix unit test
msperber committed Jul 13, 2018
1 parent bc5bff5 commit 7cdd8e4
Showing 11 changed files with 135 additions and 26 deletions.
27 changes: 27 additions & 0 deletions test/config/lm.yaml
@@ -0,0 +1,27 @@
lm: !Experiment
  model: !LanguageModel
    src_reader: !PlainTextReader
      vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
    src_embedder: !SimpleWordEmbedder
      emb_dim: 512
    rnn: !UniLSTMSeqTransducer
      layers: 1
    scorer: !Softmax
      vocab: !Vocab {vocab_file: examples/data/head.en.vocab}
  train: !SimpleTrainingRegimen
    batcher: !SrcBatcher
      batch_size: 32
    trainer: !AdamTrainer
      alpha: 0.001
    run_for_epochs: 2
    src_file: examples/data/head.en
    trg_file: examples/data/head.en
    dev_tasks:
      - !LossEvalTask
        src_file: examples/data/head.en
        ref_file: examples/data/head.en
  # final evaluation
  evaluate:
    - !LossEvalTask
      src_file: examples/data/head.en
      ref_file: examples/data/head.en
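The config trains on examples/data/head.en and reuses the same file for the dev and final loss evaluations, so it runs self-contained on the bundled toy data. A minimal sketch of driving it programmatically, mirroring the new unit test below (the runner-module import is an assumption based on xnmt's usual layout):

import xnmt.xnmt_run_experiments as run  # assumed runner module

run.main(["test/config/lm.yaml"])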
3 changes: 3 additions & 0 deletions test/test_run.py
@@ -28,6 +28,9 @@ def test_ensembling(self):
  def test_forced(self):
    run.main(["test/config/forced.yaml"])

+ def test_lm(self):
+   run.main(["test/config/lm.yaml"])
+
  def test_load_model(self):
    run.main(["test/config/load_model.yaml"])

6 changes: 3 additions & 3 deletions test/test_training.py
@@ -41,14 +41,14 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    batch_size=5
    src_sents = self.src_data[:batch_size]
    src_min = min([x.sent_len() for x in src_sents])
-   src_sents_trunc = [s[:src_min] for s in src_sents]
+   src_sents_trunc = [s.words[:src_min] for s in src_sents]
    for single_sent in src_sents_trunc:
      single_sent[src_min-1] = Vocab.ES
      while len(single_sent)%pad_src_to_multiple != 0:
        single_sent.append(Vocab.ES)
    trg_sents = self.trg_data[:batch_size]
    trg_min = min([x.sent_len() for x in trg_sents])
-   trg_sents_trunc = [s[:trg_min] for s in trg_sents]
+   trg_sents_trunc = [s.words[:trg_min] for s in trg_sents]
    for single_sent in trg_sents_trunc: single_sent[trg_min-1] = Vocab.ES

    src_sents_trunc = [SimpleSentenceInput(s) for s in src_sents_trunc]
@@ -176,7 +176,7 @@ def assert_single_loss_equals_batch_loss(self, model, pad_src_to_multiple=1):
    batch_size = 5
    src_sents = self.src_data[:batch_size]
    src_min = min([x.sent_len() for x in src_sents])
-   src_sents_trunc = [s[:src_min] for s in src_sents]
+   src_sents_trunc = [s.words[:src_min] for s in src_sents]
    for single_sent in src_sents_trunc:
      single_sent[src_min-1] = Vocab.ES
      while len(single_sent)%pad_src_to_multiple != 0:
1 change: 1 addition & 0 deletions xnmt/__init__.py
@@ -36,6 +36,7 @@
import xnmt.inference
import xnmt.input
import xnmt.input_reader
+import xnmt.lm
import xnmt.lstm
import xnmt.exp_global
import xnmt.optimizer
18 changes: 11 additions & 7 deletions xnmt/eval_task.py
@@ -4,9 +4,9 @@

import dynet as dy

-from xnmt.batcher import Batcher
+from xnmt.batcher import Batcher, SrcBatcher
from xnmt.evaluator import Evaluator
-from xnmt.model_base import GeneratorModel
+from xnmt.model_base import GeneratorModel, TrainableModel
from xnmt.inference import Inference
import xnmt.input_reader
from xnmt.persistence import serializable_init, Serializable, Ref, bare
@@ -41,8 +41,8 @@ class LossEvalTask(EvalTask, Serializable):
  yaml_tag = '!LossEvalTask'

  @serializable_init
- def __init__(self, src_file: str, ref_file: str, model: GeneratorModel = Ref("model"),
-              batcher: Optional[Batcher] = Ref("train.batcher", default=None),
+ def __init__(self, src_file: str, ref_file: Optional[str] = None, model: TrainableModel = Ref("model"),
+              batcher: Batcher = Ref("train.batcher", default=bare(xnmt.batcher.SrcBatcher, batch_size=32)),
               loss_calculator: LossCalculator = bare(AutoRegressiveMLELoss), max_src_len: Optional[int] = None,
               max_trg_len: Optional[int] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"), desc: Any = None):
@@ -67,9 +67,13 @@ def eval(self) -> tuple:
    self.model.set_train(False)
    if self.src_data is None:
      self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
-       xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
-                                              self.src_file, self.ref_file, batcher=self.batcher,
-                                              max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
+       xnmt.input_reader.read_parallel_corpus(src_reader=self.model.src_reader,
+                                              trg_reader=self.model.trg_reader,
+                                              src_file=self.src_file,
+                                              trg_file=self.ref_file,
+                                              batcher=self.batcher,
+                                              max_src_len=self.max_src_len,
+                                              max_trg_len=self.max_trg_len)
    loss_val = FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
4 changes: 2 additions & 2 deletions xnmt/experiment.py
@@ -3,7 +3,7 @@
from xnmt import logger
from xnmt.exp_global import ExpGlobal
from xnmt.eval_task import EvalTask
-from xnmt.model_base import GeneratorModel
+from xnmt.model_base import TrainableModel
from xnmt.param_collection import ParamManager
from xnmt.preproc_runner import PreprocRunner
from xnmt.training_regimen import TrainingRegimen
@@ -31,7 +31,7 @@ class Experiment(Serializable):
  def __init__(self,
               exp_global:Optional[ExpGlobal] = bare(ExpGlobal),
               preproc:Optional[PreprocRunner] = None,
-              model:Optional[GeneratorModel] = None,
+              model:Optional[TrainableModel] = None,
               train:Optional[TrainingRegimen] = None,
               evaluate:Optional[List[EvalTask]] = None,
               random_search_report:Optional[dict] = None) -> None:
5 changes: 4 additions & 1 deletion xnmt/input.py
@@ -113,6 +113,9 @@ def len_unpadded(self):
    return sum(x != vocab.Vocab.ES for x in self.words)

  def __getitem__(self, key):
+   ret = self.words[key]
+   if isinstance(ret, list): # support for slicing
+     return SimpleSentenceInput(ret)
    return self.words[key]

  def get_padded_sent(self, token, pad_len):
@@ -141,7 +144,7 @@ def get_truncated_sent(self, trunc_len: int) -> 'Input':
    if trunc_len == 0:
      return self
    new_words = self.words[:-trunc_len]
-   return self.__class__(new_words, self.vocab)
+   return self.__class__(new_words)

  def __str__(self):
    return " ".join(map(str, self.words))
5 changes: 3 additions & 2 deletions xnmt/input_reader.py
@@ -16,6 +16,7 @@
from xnmt.events import register_xnmt_handler, handle_xnmt_event
from xnmt.vocab import Vocab
import xnmt.input
+import xnmt.batcher

class InputReader(object):
"""
@@ -424,8 +425,8 @@ def read_sents(self, filename, filter_ids=None):
    return [l for l in self.iterate_filtered(filename, filter_ids)]

###### A utility function to read a parallel corpus
-def read_parallel_corpus(src_reader, trg_reader, src_file, trg_file,
-                         batcher=None, sample_sents=None, max_num_sents=None, max_src_len=None, max_trg_len=None):
+def read_parallel_corpus(src_reader: InputReader, trg_reader: InputReader, src_file: str, trg_file: str,
+                         batcher: xnmt.batcher.Batcher=None, sample_sents=None, max_num_sents=None, max_src_len=None, max_trg_len=None):
"""
A utility function to read a parallel corpus.
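A hedged sketch of calling the annotated function directly, with readers and paths assumed from the lm.yaml test config:

from xnmt.batcher import SrcBatcher
from xnmt.input_reader import PlainTextReader, read_parallel_corpus
from xnmt.vocab import Vocab

reader = PlainTextReader(vocab=Vocab(vocab_file="examples/data/head.en.vocab"))
src_data, trg_data, src_batches, trg_batches = read_parallel_corpus(
    src_reader=reader, trg_reader=reader,  # an LM reads both sides with one reader
    src_file="examples/data/head.en", trg_file="examples/data/head.en",
    batcher=SrcBatcher(batch_size=32))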
66 changes: 66 additions & 0 deletions xnmt/lm.py
@@ -0,0 +1,66 @@
import dynet as dy
import numpy as np

from xnmt import batcher, embedder, events, input_reader, loss, lstm, model_base, scorer, transducer, transform
from xnmt.persistence import serializable_init, Serializable, bare

class LanguageModel(model_base.TrainableModel, model_base.EventTrigger, Serializable):
  """
  A simple unidirectional language model predicting the next token.

  Args:
    src_reader: A reader for the source side.
    src_embedder: A word embedder for the input language
    rnn: An RNN, usually unidirectional LSTM with one or more layers
    transform: A transform to be applied before making predictions
    scorer: The class to actually make predictions
  """

  yaml_tag = '!LanguageModel'

  @events.register_xnmt_handler
  @serializable_init
  def __init__(self,
               src_reader:input_reader.InputReader,
               src_embedder:embedder.Embedder=bare(embedder.SimpleWordEmbedder),
               rnn:transducer.SeqTransducer=bare(lstm.UniLSTMSeqTransducer),
               transform:transform.Transform=bare(transform.NonLinear),
               scorer:scorer.Scorer=bare(scorer.Softmax)):
    super().__init__(src_reader=src_reader, trg_reader=src_reader)
    self.src_embedder = src_embedder
    self.rnn = rnn
    self.transform = transform
    self.scorer = scorer

  def shared_params(self):
    return [{".src_embedder.emb_dim", ".encoder.input_dim"},]

  def get_primary_loss(self):
    return "mle"

  def calc_loss(self, src, trg, loss_calculator):
    if not batcher.is_batched(src):
      src = batcher.ListBatch([src])

    # Shift by one: feed tokens 0..n-2, predict tokens 1..n-1.
    src_inputs = batcher.ListBatch([s[:-1] for s in src], mask=batcher.Mask(src.mask.np_arr[:,:-1]) if src.mask else None)
    src_targets = batcher.ListBatch([s[1:] for s in src], mask=batcher.Mask(src.mask.np_arr[:,1:]) if src.mask else None)

    self.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src_inputs)
    encodings = self.rnn.transduce(embeddings)
    encodings_tensor = encodings.as_tensor()
    # Fold the time dimension into the batch dimension so the scorer can
    # evaluate all timesteps with a single softmax call.
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
    outputs = self.transform(encoding_reshaped)

    ref_action = np.asarray([sent.words for sent in src_targets]).reshape((seq_len * batch_size,))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batcher.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
    if src_targets.mask:
      # Zero out loss contributions from padded positions (mask entries of 1).
      loss_expr_perstep = dy.cmult(loss_expr_perstep, dy.inputTensor(1.0-src_targets.mask.np_arr.T, batched=True))
    loss_expr = dy.sum_elems(loss_expr_perstep)

    model_loss = loss.FactoredLossExpr()
    model_loss.add_loss("mle", loss_expr)

    return model_loss
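A self-contained numpy sketch of the next-token shift that calc_loss applies to each batch (toy ids; 1 stands in for the end-of-sentence token):

import numpy as np

batch = np.array([[5, 8, 2, 1],
                  [6, 3, 1, 1]])  # 2 sentences, padded to length 4
inputs = batch[:, :-1]   # fed to the RNN: positions 0..n-2
targets = batch[:, 1:]   # scored against: positions 1..n-1
# Position t of `inputs` is trained to predict position t of `targets`,
# i.e. each RNN state predicts the token that follows it.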
2 changes: 1 addition & 1 deletion xnmt/seq_labeler.py
@@ -7,7 +7,7 @@

class SeqLabeler(model_base.GeneratorModel, Serializable, reports.Reportable, model_base.EventTrigger):
  """
- A default translator based on attentional sequence-to-sequence models.
+ A simple sequence labeler based on an encoder and an output softmax layer.

  Args:
    src_reader: A reader for the source side.
24 changes: 14 additions & 10 deletions xnmt/training_task.py
@@ -162,11 +162,15 @@ def _augment_data_next_epoch(self):
    # reload the data
    self.model.src_reader.train = self.model.trg_reader.train = True
    self.src_data, self.trg_data, self.src_batches, self.trg_batches = \
-     input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
-                                       self.src_file, self.trg_file,
-                                       batcher=self.batcher, sample_sents=self.sample_train_sents,
-                                       max_num_sents=self.max_num_train_sents,
-                                       max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
+     input_reader.read_parallel_corpus(src_reader=self.model.src_reader,
+                                       trg_reader=self.model.trg_reader,
+                                       src_file=self.src_file,
+                                       trg_file=self.trg_file,
+                                       batcher=self.batcher,
+                                       sample_sents=self.sample_train_sents,
+                                       max_num_sents=self.max_num_train_sents,
+                                       max_src_len=self.max_src_len,
+                                       max_trg_len=self.max_trg_len)
    self.model.src_reader.train = self.model.trg_reader.train = False
    # restart data generation
    self._augmentation_handle = Popen(augment_command + " --epoch %d" % self.training_state.epoch_num, shell=True)
@@ -219,11 +223,11 @@ def advance_epoch(self):
       self.model.src_reader.needs_reload() or self.model.trg_reader.needs_reload():
      self.model.set_train(True)
      self.src_data, self.trg_data, self.src_batches, self.trg_batches = \
-       input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
-                                         self.src_file, self.trg_file,
-                                         batcher=self.batcher, sample_sents=self.sample_train_sents,
-                                         max_num_sents=self.max_num_train_sents,
-                                         max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
+       input_reader.read_parallel_corpus(src_reader=self.model.src_reader, trg_reader=self.model.trg_reader,
+                                         src_file=self.src_file, trg_file=self.trg_file,
+                                         batcher=self.batcher, sample_sents=self.sample_train_sents,
+                                         max_num_sents=self.max_num_train_sents,
+                                         max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
      self.model.src_reader.train = self.model.trg_reader.train = False
      self.training_state.epoch_seed = random.randint(1,2147483647)
      random.seed(self.training_state.epoch_seed)