Skip to content

Commit

Permalink
Sentence level evaluation (#396)
Browse files Browse the repository at this point in the history
* initial refactoring of Levenshtein scores

* multi-ref evaluation for sentence-level evaluators

* sentence-based recall + doc improvements

* implement write_sentence_scores

* some cleanup

* make GLEU evaluator adhere to SentenceLevelEvaluator interface

* clean up speech example

* make SequenceAccuracyEvaluator adhere to SentenceLevelEvaluator interface

* make batcher optional in LossEvalTask
  • Loading branch information
msperber authored and neubig committed May 25, 2018
1 parent a7d1338 commit fe0c891
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 231 deletions.
6 changes: 3 additions & 3 deletions docs/api_doc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ Experiment
Model
-----

GeneratorModel
~~~~~~~~~~~~~~
Model Base Classes
~~~~~~~~~~~~~~~~~~

.. automodule:: xnmt.generator
.. automodule:: xnmt.model_base
:members:
:show-inheritance:

Expand Down
44 changes: 44 additions & 0 deletions test/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,49 @@ def test_bleu_4gram_fast(self):
act_bleu = bleu.evaluate(self.ref_id, self.hyp_id)
self.assertEqual(act_bleu, exp_bleu)

class TestGLEU(unittest.TestCase):
    """Checks GLEUEvaluator scores against fixed expected values."""

    def setUp(self):
        # A fresh evaluator instance is created before every test method.
        self.evaluator = evaluator.GLEUEvaluator()

    def _gleu(self, ref_sents, hyp_sents):
        """Split each sentence on whitespace, evaluate, and return the score's value.

        ``ref_sents`` is passed as the first argument of ``evaluate`` and
        ``hyp_sents`` as the second.
        NOTE(review): the reference/hypothesis naming is presumed from the
        argument order -- confirm against the evaluator's signature.
        """
        first_arg = [sentence.split() for sentence in ref_sents]
        second_arg = [sentence.split() for sentence in hyp_sents]
        score = self.evaluator.evaluate(first_arg, second_arg)
        return score.value()

    def test_gleu_single_1(self):
        observed = self._gleu(['the cat is on the mat'],
                              ['the the the the the the the'])
        self.assertAlmostEqual(observed, 0.0909, places=4)

    def test_gleu_single_2(self):
        observed = self._gleu(
            ['It is a guide to action that ensures that the military will forever heed Party commands'],
            ['It is a guide to action which ensures that the military always obeys the commands of the party'])
        self.assertAlmostEqual(observed, 0.4393, places=3)

    def test_gleu_single_3(self):
        observed = self._gleu(
            ['It is a guide to action that ensures that the military will forever heed Party commands'],
            ['It is to insure the troops forever hearing the activity guidebook that party direct'])
        self.assertAlmostEqual(observed, 0.1206, places=3)

    def test_gleu_corpus(self):
        # Two sentence pairs evaluated in a single call.
        observed = self._gleu(
            ['It is a guide to action that ensures that the military will forever heed Party commands',
             'It is a guide to action that ensures that the military will forever heed Party commands'],
            ['It is a guide to action which ensures that the military always obeys the commands of the party',
             'It is to insure the troops forever hearing the activity guidebook that party direct'])
        self.assertAlmostEqual(observed, 0.2903, places=3)

class TestSequenceAccuracy(unittest.TestCase):
    """Checks SequenceAccuracyEvaluator on matching and non-matching token sequences."""

    def setUp(self):
        # A fresh evaluator instance is created before every test method.
        self.evaluator = evaluator.SequenceAccuracyEvaluator()

    def test_correct(self):
        # Both sides carry the identical token sequence.
        first_arg = ["1 2 3".split()]
        second_arg = ["1 2 3".split()]
        result = self.evaluator.evaluate(first_arg, second_arg).value()
        self.assertEqual(result, 1.0)

    def test_incorrect(self):
        # The sequences differ (the first one lacks the leading "1").
        first_arg = ["2 3".split()]
        second_arg = ["1 2 3".split()]
        result = self.evaluator.evaluate(first_arg, second_arg).value()
        self.assertEqual(result, 0.0)

    def test_corpus(self):
        # One matching pair and one differing pair evaluated in a single call.
        first_arg = ["1 2 3".split(), "2 3".split()]
        second_arg = ["1 2 3".split(), "1 2 3".split()]
        result = self.evaluator.evaluate(first_arg, second_arg).value()
        self.assertEqual(result, 0.5)

# Run the unittest runner when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
8 changes: 5 additions & 3 deletions xnmt/eval_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class LossEvalTask(EvalTask, Serializable):

@serializable_init
def __init__(self, src_file: str, ref_file: str, model: GeneratorModel = Ref("model"),
batcher: Batcher = Ref("train.batcher", default=None), loss_calculator: LossCalculator = bare(MLELoss),
max_src_len: Optional[int] = None, max_trg_len: Optional[int] = None, desc: Any = None):
batcher: Optional[Batcher] = Ref("train.batcher", default=None),
loss_calculator: LossCalculator = bare(MLELoss), max_src_len: Optional[int] = None,
max_trg_len: Optional[int] = None, desc: Any = None):
self.model = model
self.loss_calculator = loss_calculator
self.src_file = src_file
Expand Down Expand Up @@ -107,10 +108,11 @@ class AccuracyEvalTask(EvalTask, Serializable):
def __init__(self, src_file: Union[str,Sequence[str]], ref_file: Union[str,Sequence[str]], hyp_file: str,
model: GeneratorModel = Ref("model"), eval_metrics: Union[str, Sequence[Evaluator]] = "bleu",
inference: Optional[SimpleInference] = None, candidate_id_file: Optional[str] = None,
desc: Optional[Any] = None):
desc: Any = None):
self.model = model
if isinstance(eval_metrics, str):
eval_metrics = [xnmt.xnmt_evaluate.eval_shortcuts[shortcut]() for shortcut in eval_metrics.split(",")]
elif not isinstance(eval_metrics, str): eval_metrics = [eval_metrics]
self.eval_metrics = eval_metrics
self.src_file = src_file
self.ref_file = ref_file
Expand Down

0 comments on commit fe0c891

Please sign in to comment.