Skip to content

Commit

Permalink
Cascaded inference + updated model interface (#462)
Browse files Browse the repository at this point in the history
* minor doc improvements

* fix/update expression sequence doc

* remove unused initialize_generator()

* GeneratorModel no longer subclass of TrainableModel

* implement cascade; clean up interface a bit

* subdivide trainable models into supervised and unsupervised models

* fix unit tests

* fix some minor issues

* fix type for src_file

* rename to ConditionedModel/UnconditionedModel

* fix unit test
  • Loading branch information
msperber committed Jul 16, 2018
1 parent 9ed5a34 commit 8143fad
Show file tree
Hide file tree
Showing 20 changed files with 228 additions and 134 deletions.
7 changes: 0 additions & 7 deletions test/test_beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator(beam=1)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -80,7 +79,6 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator(beam=1)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand All @@ -91,7 +89,6 @@ def test_single(self):
trg=self.trg_data[0],
loss_calculator=AutoRegressiveMLELoss()).value()
dy.renew_cg()
self.model.initialize_generator()
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], BeamSearch(beam_size=1),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
self.assertAlmostEqual(-outputs[0].score, train_loss, places=4)
Expand All @@ -117,14 +114,12 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator(beam=1)

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))

def test_single(self):
dy.renew_cg()
self.model.initialize_generator(beam=1)
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], BeamSearch(),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
dy.renew_cg()
Expand Down Expand Up @@ -163,13 +158,11 @@ def setUp(self):

def test_greedy_vs_beam(self):
dy.renew_cg()
self.model.initialize_generator()
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], BeamSearch(beam_size=1),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
output_score1 = outputs[0].score

dy.renew_cg()
self.model.initialize_generator()
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], GreedySearch(),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
output_score2 = outputs[0].score
Expand Down
5 changes: 0 additions & 5 deletions test/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator()

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand Down Expand Up @@ -82,7 +81,6 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator()

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))
Expand All @@ -93,7 +91,6 @@ def test_single(self):
trg=self.trg_data[0],
loss_calculator=AutoRegressiveMLELoss()).value()
dy.renew_cg()
self.model.initialize_generator()
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], GreedySearch(),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
output_score = outputs[0].score
Expand All @@ -120,14 +117,12 @@ def setUp(self):
bridge=CopyBridge(dec_dim=layer_dim, dec_layers=1)),
)
self.model.set_train(False)
self.model.initialize_generator()

self.src_data = list(self.model.src_reader.read_sents("examples/data/head.ja"))
self.trg_data = list(self.model.trg_reader.read_sents("examples/data/head.en"))

def test_single(self):
dy.renew_cg()
self.model.initialize_generator()
outputs = self.model.generate(xnmt.batcher.mark_as_batch([self.src_data[0]]), [0], GreedySearch(),
forced_trg_ids=xnmt.batcher.mark_as_batch([self.trg_data[0]]))
output_score = outputs[0].score
Expand Down
2 changes: 1 addition & 1 deletion xnmt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import xnmt.input_reader
import xnmt.lm
import xnmt.lstm
import xnmt.exp_global
import xnmt.model_base
import xnmt.optimizer
import xnmt.param_init
import xnmt.preproc_runner
Expand Down
2 changes: 1 addition & 1 deletion xnmt/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from xnmt.persistence import serializable_init, Serializable, bare
import xnmt.inference

class SequenceClassifier(model_base.GeneratorModel, Serializable, model_base.EventTrigger):
class SequenceClassifier(model_base.ConditionedModel, model_base.GeneratorModel, Serializable, model_base.EventTrigger):
"""
A sequence classifier.
Expand Down
33 changes: 15 additions & 18 deletions xnmt/eval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from xnmt.batcher import Batcher, SrcBatcher
from xnmt.evaluator import Evaluator
from xnmt.model_base import GeneratorModel, TrainableModel
from xnmt.inference import Inference
from xnmt import model_base
import xnmt.inference
import xnmt.input_reader
from xnmt.persistence import serializable_init, Serializable, Ref, bare
from xnmt.loss_calculator import LossCalculator, AutoRegressiveMLELoss
Expand Down Expand Up @@ -41,7 +41,7 @@ class LossEvalTask(EvalTask, Serializable):
yaml_tag = '!LossEvalTask'

@serializable_init
def __init__(self, src_file: str, ref_file: Optional[str] = None, model: TrainableModel = Ref("model"),
def __init__(self, src_file: str, ref_file: Optional[str] = None, model: 'model_base.GeneratorModel' = Ref("model"),
batcher: Batcher = Ref("train.batcher", default=bare(xnmt.batcher.SrcBatcher, batch_size=32)),
loss_calculator: LossCalculator = bare(AutoRegressiveMLELoss), max_src_len: Optional[int] = None,
max_trg_len: Optional[int] = None,
Expand All @@ -57,12 +57,12 @@ def __init__(self, src_file: str, ref_file: Optional[str] = None, model: Trainab
self.loss_comb_method = loss_comb_method
self.desc=desc

def eval(self) -> tuple:
def eval(self) -> 'EvalScore':
"""
Perform evaluation task.
Returns:
tuple of score and reference length
Evaluated score
"""
self.model.set_train(False)
if self.src_data is None:
Expand Down Expand Up @@ -92,7 +92,10 @@ def eval(self) -> tuple:
loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

try:
return LossScore(loss_stats[self.model.get_primary_loss()], loss_stats=loss_stats, desc=self.desc), ref_words_cnt
return LossScore(loss_stats[self.model.get_primary_loss()],
loss_stats=loss_stats,
num_ref_words = ref_words_cnt,
desc=self.desc)
except KeyError:
raise RuntimeError("Did you wrap your loss calculation with FactoredLossExpr({'primary_loss': loss_value}) ?")

Expand All @@ -114,8 +117,8 @@ class AccuracyEvalTask(EvalTask, Serializable):

@serializable_init
def __init__(self, src_file: Union[str,Sequence[str]], ref_file: Union[str,Sequence[str]], hyp_file: str,
model: GeneratorModel = Ref("model"), eval_metrics: Union[str, Sequence[Evaluator]] = "bleu",
inference: Optional[Inference] = None, desc: Any = None):
model: 'model_base.GeneratorModel' = Ref("model"), eval_metrics: Union[str, Sequence[Evaluator]] = "bleu",
inference: Optional[xnmt.inference.Inference] = None, desc: Any = None):
self.model = model
if isinstance(eval_metrics, str):
eval_metrics = [xnmt.xnmt_evaluate.eval_shortcuts[shortcut]() for shortcut in eval_metrics.split(",")]
Expand All @@ -139,13 +142,7 @@ def eval(self):
eval_scores = xnmt.xnmt_evaluate.xnmt_evaluate(hyp_file=self.hyp_file, ref_file=self.ref_file, desc=self.desc,
evaluators=self.eval_metrics)

# Calculate the reference file size
ref_words_cnt = 0
for ref_sent in self.model.trg_reader.read_sents(
self.ref_file if isinstance(self.ref_file, str) else self.ref_file[0]):
ref_words_cnt += ref_sent.len_unpadded()
ref_words_cnt += 0
return eval_scores, ref_words_cnt
return eval_scores

class DecodingEvalTask(EvalTask, Serializable):
"""
Expand All @@ -161,8 +158,8 @@ class DecodingEvalTask(EvalTask, Serializable):
yaml_tag = '!DecodingEvalTask'

@serializable_init
def __init__(self, src_file: Union[str,Sequence[str]], hyp_file: str, model: GeneratorModel = Ref("model"),
inference: Optional[Inference] = None):
def __init__(self, src_file: Union[str,Sequence[str]], hyp_file: str, model: 'model_base.GeneratorModel' = Ref("model"),
inference: Optional[xnmt.inference.Inference] = None):

self.model = model
self.src_file = src_file
Expand All @@ -174,4 +171,4 @@ def eval(self):
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
trg_file=self.hyp_file)
return None, None
return None
9 changes: 6 additions & 3 deletions xnmt/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,19 @@ class LossScore(EvalScore, Serializable):
Args:
loss: the (primary) loss value
loss_stats: info on additional loss values
num_ref_words: number of reference tokens
desc: human-readable description to include in log outputs
"""

yaml_tag = "!LossScore"

@serializable_init
def __init__(self, loss: float, loss_stats: Dict[str, float] = None, desc: Any = None) -> None:
def __init__(self, loss: float, loss_stats: Dict[str, float] = None, num_ref_words: Optional[int] = None,
desc: Any = None) -> None:
super().__init__(desc=desc)
self.loss = loss
self.loss_stats = loss_stats
self.num_ref_words = num_ref_words
self.serialize_params = {"loss":loss}
if desc is not None: self.serialize_params["desc"] = desc
if loss_stats is not None: self.serialize_params["loss_stats"] = desc
Expand All @@ -136,9 +139,9 @@ def metric_name(self): return "Loss"
def higher_is_better(self): return False
def score_str(self):
if self.loss_stats is not None and len(self.loss_stats) > 1:
return "{" + ", ".join("%s: %.5f" % (k, v) for k, v in self.loss_stats.items()) + "}"
return "{" + ", ".join(f"{k}: {v:.5f}" for k, v in self.loss_stats.items()) + f"}} (ref_len={self.num_ref_words})"
else:
return f"{self.value():.3f}"
return f"{self.value():.3f} (ref_len={self.num_ref_words})"

class BLEUScore(EvalScore, Serializable):
"""
Expand Down
2 changes: 1 addition & 1 deletion xnmt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __call__(self, save_fct):
logger.info("> Performing final evaluation")
eval_scores = []
for evaluator in evaluate_args:
eval_score, _ = evaluator.eval()
eval_score = evaluator.eval()
if type(eval_score) == list:
eval_scores.extend(eval_score)
else:
Expand Down
9 changes: 7 additions & 2 deletions xnmt/expression_sequence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional, Sequence

import dynet as dy
import numpy as np

Expand All @@ -9,13 +11,16 @@ class ExpressionSequence(object):
Internal representation is either a list of expressions or a single tensor or both.
If necessary, both forms of representation are created from the other on demand.
"""
def __init__(self, expr_list=None, expr_tensor=None, expr_transposed_tensor=None, mask=None):
def __init__(self, expr_list: Optional[Sequence[dy.Expression]] = None, expr_tensor: Optional[dy.Expression] = None,
expr_transposed_tensor: Optional[dy.Expression] = None, mask: Optional['xnmt.batcher.Mask'] = None) \
-> None:
"""Constructor.
Args:
expr_list: a python list of expressions
expr_tensor: a tensor where last dimension are the sequence items
mask: a numpy array consisting of whether things should be masked or not
expr_transposed_tensor: a tensor in transposed form (first dimension are sequence items)
mask: an optional mask object indicating what positions in a batched tensor should be masked
Raises:
valueError: raises an exception if neither expr_list nor expr_tensor are given,
or if both have inconsistent length
Expand Down
52 changes: 42 additions & 10 deletions xnmt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = Non
self.batcher = batcher
self.reporter = reporter

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids) -> List[output.Output]:
# TODO: src should probably a batch of inputs for consistency with return values being a batch of outputs
def generate_one(self, generator: 'model_base.GeneratorModel', src: xnmt.batcher.Batch, src_i: int, forced_ref_ids) \
-> List[output.Output]:
raise NotImplementedError("must be implemented by subclasses")

def compute_losses_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input,
def compute_losses_one(self, generator: 'model_base.GeneratorModel', src: xnmt.input.Input,
ref: xnmt.input.Input) -> loss.FactoredLossExpr:
raise NotImplementedError("must be implemented by subclasses")


def perform_inference(self, generator: model_base.GeneratorModel, src_file: str = None, trg_file: str = None):
def perform_inference(self, generator: 'model_base.GeneratorModel', src_file: str = None, trg_file: str = None):
"""
Perform inference.
Expand Down Expand Up @@ -86,7 +86,7 @@ def perform_inference(self, generator: model_base.GeneratorModel, src_file: str
src_corpus=src_corpus, trg_file=trg_file, batcher=self.batcher,
max_src_len=self.max_src_len)

def _generate_output(self, generator: model_base.GeneratorModel, src_corpus: Sequence[xnmt.input.Input],
def _generate_output(self, generator: 'model_base.GeneratorModel', src_corpus: Sequence[xnmt.input.Input],
trg_file: str, batcher: Optional[xnmt.batcher.Batcher] = None, max_src_len: Optional[int] = None,
forced_ref_corpus: Optional[Sequence[xnmt.input.Input]] = None,
assert_scores: Optional[Sequence[float]] = None) -> None:
Expand Down Expand Up @@ -173,7 +173,7 @@ def _write_rescored_output(ref_scores: Sequence[float], ref_file: str, trg_file:
fp.write("{} ||| score={}\n".format(nbest.strip(), score))

@staticmethod
def _read_corpus(generator: model_base.GeneratorModel, src_file: str, mode: str, ref_file: str) -> Tuple[List, List]:
def _read_corpus(generator: 'model_base.GeneratorModel', src_file: str, mode: str, ref_file: str) -> Tuple[List, List]:
src_corpus = list(generator.src_reader.read_sents(src_file))
# Get reference if it exists and is necessary
if mode == "forced" or mode == "forceddebug" or mode == "score":
Expand Down Expand Up @@ -238,12 +238,12 @@ def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = Non
max_num_sents=max_num_sents, mode=mode, batcher=batcher, reporter=reporter)
self.post_processor = output.OutputProcessor.get_output_processor(post_process)

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids)\
def generate_one(self, generator: 'model_base.GeneratorModel', src: xnmt.batcher.Batch, src_i: int, forced_ref_ids)\
-> List[output.Output]:
outputs = generator.generate(src, src_i, forced_trg_ids=forced_ref_ids)
return outputs

def compute_losses_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input,
def compute_losses_one(self, generator: 'model_base.GeneratorModel', src: xnmt.input.Input,
ref: xnmt.input.Input) -> loss.FactoredLossExpr:
loss_expr = generator.calc_loss(src, ref, loss_calculator=loss_calculator.AutoRegressiveMLELoss())
return loss_expr
Expand Down Expand Up @@ -289,12 +289,44 @@ def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = Non
self.post_processor = output.OutputProcessor.get_output_processor(post_process)
self.search_strategy = search_strategy

def generate_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input, src_i: int, forced_ref_ids)\
def generate_one(self, generator: 'model_base.GeneratorModel', src: xnmt.batcher.Batch, src_i: int, forced_ref_ids)\
-> List[output.Output]:
outputs = generator.generate(src, src_i, forced_trg_ids=forced_ref_ids, search_strategy=self.search_strategy)
return outputs

def compute_losses_one(self, generator: model_base.GeneratorModel, src: xnmt.input.Input,
def compute_losses_one(self, generator: 'model_base.GeneratorModel', src: xnmt.input.Input,
ref: xnmt.input.Input) -> loss.FactoredLossExpr:
loss_expr = generator.calc_loss(src, ref, loss_calculator=loss_calculator.AutoRegressiveMLELoss())
return loss_expr

class CascadeInference(Inference, Serializable):
  """Inference class that performs inference as a series of independent inference steps.

  Steps are performed using a list of inference sub-objects and a list of models.
  Intermediate outputs are written out to disk and then read by the next time step.

  The generator passed to ``perform_inference`` must be a :class:`xnmt.model_base.CascadeGenerator`.

  Args:
    steps: list of inference objects, one per sub-generator of the cascade
  """

  yaml_tag = "!CascadeInference"
  @serializable_init
  def __init__(self, steps: Sequence[Inference]) -> None:
    self.steps = steps

  def perform_inference(self, generator: 'model_base.CascadeGenerator', src_file: str = None, trg_file: str = None):
    """Run each cascade step in order, chaining outputs to inputs via temporary files.

    Args:
      generator: a :class:`xnmt.model_base.CascadeGenerator` with exactly one sub-generator per step
      src_file: path of the input file fed to the first step
      trg_file: path where the final step's outputs are written; intermediate steps write to
                ``{trg_file}.tmp.{i}``
    """
    assert isinstance(generator, model_base.CascadeGenerator)
    assert len(generator.generators) == len(self.steps)
    # step i writes to the file that step i+1 reads from
    src_files = [src_file] + [f"{trg_file}.tmp.{i}" for i in range(len(self.steps)-1)]
    trg_files = src_files[1:] + [trg_file]
    for step_i, step in enumerate(self.steps):
      step.perform_inference(generator=generator.generators[step_i],
                             src_file=src_files[step_i],
                             trg_file=trg_files[step_i])

  def compute_losses_one(self, *args, **kwargs):
    # fixed: error message previously named the class "CascadedInference"
    raise ValueError("cannot call CascadeInference.compute_losses_one() directly, use the sub-inference objects")

  def generate_one(self, *args, **kwargs):
    # fixed: error message previously named the class "CascadedInference"
    raise ValueError("cannot call CascadeInference.generate_one() directly, use the sub-inference objects")
2 changes: 1 addition & 1 deletion xnmt/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from xnmt import batcher, embedder, events, input_reader, loss, lstm, model_base, scorer, transducer, transform
from xnmt.persistence import serializable_init, Serializable, bare

class LanguageModel(model_base.TrainableModel, model_base.EventTrigger, Serializable):
class LanguageModel(model_base.ConditionedModel, model_base.EventTrigger, Serializable):
"""
A simple unidirectional language model predicting the next token.
Expand Down

0 comments on commit 8143fad

Please sign in to comment.