Configure how to combine batched losses (#411)
* some documentation and cleanup for the loss builders

* configure batch_comb

* fix unit test
msperber committed May 31, 2018
1 parent 08ea387 commit d9e227b
Showing 13 changed files with 177 additions and 90 deletions.
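In short: per-sentence losses in a batch can now be combined either by summing (the previous behavior) or by averaging, controlled by a new `loss_comb_method` setting ('sum' or 'avg'). A minimal DyNet sketch of the two combination methods, mirroring `_combine_batches` in the loss.py diff below (tensor values are illustrative):

import dynet as dy

# A batched scalar loss: one value per batch element (values illustrative).
batched_loss = dy.inputTensor([1.0, 2.0, 3.0], batched=True)

# 'sum': total loss over the batch (previous default behavior).
summed = dy.sum_batches(batched_loss)                                     # -> 6.0

# 'avg': per-element mean; the batch size is read from dim()[1].
averaged = dy.sum_batches(batched_loss) * (1.0 / batched_loss.dim()[1])  # -> 2.0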
4 changes: 2 additions & 2 deletions docs/api_doc.rst
@@ -98,8 +98,8 @@ Multi-layer Perceptron
Loss
----

LossBuilder
~~~~~~~~~~~
Loss
~~~~

.. automodule:: xnmt.loss
   :members:
4 changes: 2 additions & 2 deletions test/test_training.py
@@ -330,7 +330,7 @@ def test_train_dev_loss_equal(self):
    train_args['run_for_epochs'] = 1
    training_regimen = xnmt.training_regimen.SimpleTrainingRegimen(**train_args)
    training_regimen.run_training(save_fct = lambda: None, update_weights=False)
    self.assertAlmostEqual(training_regimen.train_loss_tracker.epoch_loss.sum() / training_regimen.train_loss_tracker.epoch_words,
    self.assertAlmostEqual(training_regimen.train_loss_tracker.epoch_loss.sum_factors() / training_regimen.train_loss_tracker.epoch_words,
                           training_regimen.dev_loss_tracker.dev_score.loss, places=5)

class TestOverfitting(unittest.TestCase):
@@ -379,7 +379,7 @@ def test_overfitting(self):
    for _ in range(50):
      training_regimen.run_training(save_fct=lambda:None, update_weights=True)
    self.assertAlmostEqual(0.0,
                           training_regimen.train_loss_tracker.epoch_loss.sum() / training_regimen.train_loss_tracker.epoch_words,
                           training_regimen.train_loss_tracker.epoch_loss.sum_factors() / training_regimen.train_loss_tracker.epoch_words,
                           places=2)

if __name__ == '__main__':
21 changes: 12 additions & 9 deletions xnmt/eval_task.py
@@ -12,7 +12,7 @@
from xnmt.persistence import serializable_init, Serializable, Ref, bare
from xnmt.loss_calculator import LossCalculator, MLELoss
from xnmt.evaluator import LossScore
from xnmt.loss import LossBuilder, LossScalarBuilder
from xnmt.loss import FactoredLossExpr, FactoredLossVal
import xnmt.xnmt_evaluate

class EvalTask(object):
@@ -33,7 +33,8 @@ class LossEvalTask(EvalTask, Serializable):
    batcher: batcher to use
    loss_calculator: loss calculator
    max_src_len: omit sentences with source length greater than specified number
    max_trg_len:omit sentences with target length greater than specified number
    max_trg_len: omit sentences with target length greater than specified number
    loss_comb_method: method for combining loss across batch elements ('sum' or 'avg').
    desc: description to pass on to computed score objects
  """
  yaml_tag = '!LossEvalTask'
@@ -42,7 +43,8 @@ class LossEvalTask(EvalTask, Serializable):
  def __init__(self, src_file: str, ref_file: str, model: GeneratorModel = Ref("model"),
               batcher: Optional[Batcher] = Ref("train.batcher", default=None),
               loss_calculator: LossCalculator = bare(MLELoss), max_src_len: Optional[int] = None,
               max_trg_len: Optional[int] = None, desc: Any = None):
               max_trg_len: Optional[int] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"), desc: Any = None):
    self.model = model
    self.loss_calculator = loss_calculator
    self.src_file = src_file
@@ -51,6 +53,7 @@ def __init__(self, src_file: str, ref_file: str, model: GeneratorModel = Ref("mo
    self.src_data = None
    self.max_src_len = max_src_len
    self.max_trg_len = max_trg_len
    self.loss_comb_method = loss_comb_method
    self.desc=desc

  def eval(self) -> tuple:
@@ -66,26 +69,26 @@ def eval(self) -> tuple:
        xnmt.input_reader.read_parallel_corpus(self.model.src_reader, self.model.trg_reader,
                                               self.src_file, self.ref_file, batcher=self.batcher,
                                               max_src_len=self.max_src_len, max_trg_len=self.max_trg_len)
    loss_val = LossScalarBuilder()
    loss_val = FactoredLossVal()
    ref_words_cnt = 0
    for src, trg in zip(self.src_batches, self.ref_batches):
      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)

      loss_builder = LossBuilder()
      loss_builder = FactoredLossExpr()
      standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
      additional_loss = self.model.calc_additional_loss(standard_loss)
      loss_builder.add_loss("standard_loss", standard_loss)
      loss_builder.add_loss("additional_loss", additional_loss)
      loss_builder.add_factored_loss_expr(standard_loss)
      loss_builder.add_factored_loss_expr(additional_loss)

      ref_words_cnt += self.model.trg_reader.count_words(trg)
      loss_val += loss_builder.get_loss_stats()
      loss_val += loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method)

    loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}

    try:
      return LossScore(loss_stats[self.model.get_primary_loss()], loss_stats=loss_stats, desc=self.desc), ref_words_cnt
    except KeyError:
      raise RuntimeError("Did you wrap your loss calculation with LossBuilder({'primary_loss': loss_value}) ?")
      raise RuntimeError("Did you wrap your loss calculation with FactoredLossExpr({'primary_loss': loss_value}) ?")

class AccuracyEvalTask(EvalTask, Serializable):
  """
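The eval() loop above accumulates a FactoredLossVal across batches and then normalizes by reference word count. The same pattern, condensed into a self-contained sketch (the batch contents and word counts are made-up stand-ins for what model.calc_loss and trg_reader.count_words would produce):

import dynet as dy
from xnmt.loss import FactoredLossExpr, FactoredLossVal

# Two "batches" of per-sentence losses with their reference word counts (illustrative).
batches = [([0.9, 1.1], 7), ([0.5, 1.5, 1.0], 11)]

loss_val = FactoredLossVal()
ref_words_cnt = 0
for per_sent_losses, n_words in batches:
  dy.renew_cg()
  loss_builder = FactoredLossExpr({"mle": dy.inputTensor(per_sent_losses, batched=True)})
  ref_words_cnt += n_words
  loss_val += loss_builder.get_factored_loss_val(comb_method="sum")

# Per-word loss for each factor, as in loss_stats above: {'mle': 5.0 / 18}
loss_stats = {k: v / ref_words_cnt for k, v in loss_val.items()}

Note that with comb_method='avg' the division by ref_words_cnt no longer yields a strict per-word loss, since each batch then contributes its per-sentence mean rather than its total.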
2 changes: 1 addition & 1 deletion xnmt/evaluator.py
@@ -251,7 +251,7 @@ def ref_len(self):
  def higher_is_better(self): return False
  def score_str(self):
    return f"{self.value()*100.0:.2f}% " \
           f"( C/I/D/S: {self.correct}/{self.insertions}/{self.deletions}/{self.substitutions}; " \
           f"( C/S/I/D: {self.correct}/{self.substitutions}/{self.insertions}/{self.deletions}; " \
           f"hyp_len={self.hyp_len()}, ref_len={self.ref_len()} )"
  @staticmethod
  def aggregate(scores: Sequence['LevenshteinScore'], desc: Any = None) -> 'LevenshteinScore':
3 changes: 3 additions & 0 deletions xnmt/exp_global.py
@@ -18,6 +18,7 @@ class ExpGlobal(Serializable):
    param_init: Default parameter initializer that should be used by supporting components but can be overwritten
    bias_init: Default initializer for bias parameters that should be used by supporting components but can be overwritten
    save_num_checkpoints: save DyNet parameters for the most recent n checkpoints, useful for model averaging/ensembling
    loss_comb_method: method for combining loss across batch elements ('sum' or 'avg').
    commandline_args: Holds commandline arguments with which XNMT was launched
    placeholders: these will be used as arguments for a format() call applied to every string in the config.
                  For example, ``placeholders: {"PATH":"/some/path"}`` will cause each occurrence of ``"{PATH}"`` in a
@@ -36,6 +37,7 @@ def __init__(self,
               param_init: ParamInitializer = bare(GlorotInitializer),
               bias_init: ParamInitializer = bare(ZeroInitializer),
               save_num_checkpoints: int = 1,
               loss_comb_method: str = "sum",
               commandline_args=None,
               placeholders: Dict[str, str] = {}) -> None:
    self.model_file = model_file
@@ -47,4 +49,5 @@
    self.bias_init = bias_init
    self.commandline_args = commandline_args
    self.save_num_checkpoints = save_num_checkpoints
    self.loss_comb_method = loss_comb_method
    self.placeholders = placeholders
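Since consumers pull this setting via Ref("exp_global.loss_comb_method", default="sum") (see LossEvalTask above), changing it once on ExpGlobal switches every supporting component. A hedged sketch in Python; XNMT experiments are normally configured through YAML, so the direct constructor call here merely stands in for the corresponding exp_global config entry:

from xnmt.exp_global import ExpGlobal

# Experiment-wide default: average batched losses instead of summing them.
# Components declaring Ref("exp_global.loss_comb_method", default="sum")
# pick this up automatically; no per-component change is needed.
exp_global = ExpGlobal(loss_comb_method="avg")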
170 changes: 120 additions & 50 deletions xnmt/loss.py
@@ -1,70 +1,140 @@
Removed (old implementation):

import dynet as dy
import collections

class LossBuilder(object):

  # TODO: document me

  def __init__(self, init_loss=None):
    self.loss_values = collections.defaultdict(lambda: dy.scalarInput(0))
    if init_loss is not None:
      for key, val in init_loss.items():
        self.loss_values[key] = val

  def add_loss(self, loss_name, loss_expr):
    if loss_expr is None:
      return
    if type(loss_expr) == LossBuilder:
      for loss_name, loss in loss_expr.loss_values.items():
        self.loss_values[loss_name] += loss
    else:
      self.loss_values[loss_name] += loss_expr

  def compute(self):
    return dy.sum_batches(dy.esum(list(self.loss_values.values())))

  def value(self):
    return dy.esum(list(self.loss_values.values())).value()

  def __getitem__(self, index):
    return self.loss_values[index]

  def get_loss_stats(self):
    return LossScalarBuilder({k: dy.sum_batches(v).value() for k, v in self.loss_values.items()})

  def __len__(self):
    return len(self.loss_values)

  def __repr__(self):
    loss_str = ", ".join(["%s %f" % (loss_name, dy.sum_batches(loss_value).value()) for loss_name, loss_value in self.loss_values.items()])
    return "{Loss Builder: %s}" % (loss_str)

class LossScalarBuilder(object):

  # TODO: document me

  def __init__(self, loss_stats=None):
    if loss_stats is None:
      loss_stats = {}
    self.__loss_stats = loss_stats

  def __iadd__(self, other):
    for name, value in other.__loss_stats.items():
      if name in self.__loss_stats:
        self.__loss_stats[name] += value
      else:
        self.__loss_stats[name] = value
    return self

  def sum(self):
    return sum([x for x in self.__loss_stats.values()])

  def items(self):
    return self.__loss_stats.items()

  def __len__(self):
    return len(self.__loss_stats)

  def zero(self):
    self.__loss_stats.clear()

Added (new implementation):

from typing import Optional, Dict, List, Tuple
import collections

import dynet as dy

class FactoredLossExpr(object):
  """
  Loss consisting of (possibly batched) DyNet expressions, with one expression per loss factor.

  Used to represent losses within a training step.

  Args:
    init_loss: initial loss values
  """

  def __init__(self, init_loss: Optional[Dict[str, dy.Expression]] = None) -> None:
    self.expr_factors = collections.defaultdict(lambda: dy.scalarInput(0))
    if init_loss is not None:
      for key, val in init_loss.items():
        self.expr_factors[key] = val

  def add_loss(self, loss_name: str, loss_expr: Optional[dy.Expression]) -> None:
    if loss_expr:
      self.expr_factors[loss_name] += loss_expr

  def add_factored_loss_expr(self, factored_loss_expr: Optional['FactoredLossExpr']) -> None:
    if factored_loss_expr:
      for loss_name, loss in factored_loss_expr.expr_factors.items():
        self.expr_factors[loss_name] += loss

  def compute(self, comb_method: str = "sum") -> dy.Expression:
    """
    Compute loss as DyNet expression by summing over factors and batch elements.

    Args:
      comb_method: method for combining loss across batch elements ('sum' or 'avg').

    Returns:
      Scalar DyNet expression.
    """
    return self._combine_batches(dy.esum(list(self.expr_factors.values())), comb_method)

  def value(self) -> List[float]:
    """
    Get list of per-batch-element loss values, summed over factors.

    Returns:
      List of same length as batch-size.
    """
    return dy.esum(list(self.expr_factors.values())).value()

  def __getitem__(self, loss_name: str) -> dy.Expression:
    return self.expr_factors[loss_name]

  def get_factored_loss_val(self, comb_method: str = "sum") -> 'FactoredLossVal':
    """
    Create factored loss values by calling ``.value()`` for each DyNet loss expression and applying batch combination.

    Args:
      comb_method: method for combining loss across batch elements ('sum' or 'avg').

    Returns:
      Factored loss values.
    """
    return FactoredLossVal({k: self._combine_batches(v, comb_method).value() for k, v in self.expr_factors.items()})

  def _combine_batches(self, batched_expr, comb_method: str = "sum"):
    if comb_method == "sum":
      return dy.sum_batches(batched_expr)
    elif comb_method == "avg":
      return dy.sum_batches(batched_expr) * (1.0 / batched_expr.dim()[1])
    else:
      raise ValueError(f"Unknown batch combination method '{comb_method}', expected 'sum' or 'avg'.")

  def __len__(self):
    return len(self.expr_factors)

  def __repr__(self):
    loss_str = ", ".join(
      [f"{loss_name} {dy.sum_batches(loss_value).value()}" for loss_name, loss_value in self.expr_factors.items()])
    return f"{{Loss Builder: {loss_str}}}"

class FactoredLossVal(object):
  """
  Loss consisting of (unbatched) float values, with one value per loss factor.

  Used to represent losses accumulated across several training steps.
  """

  def __init__(self, loss_dict = None) -> None:
    if loss_dict is None:
      loss_dict = {}
    self._loss_dict = loss_dict

  def __iadd__(self, other: 'FactoredLossVal'):
    """
    Implements += operator, adding up factors individually.

    Args:
      other: other factored float loss

    Returns:
      self
    """
    for name, value in other._loss_dict.items():
      if name in self._loss_dict:
        self._loss_dict[name] += value
      else:
        self._loss_dict[name] = value
    return self

  def sum_factors(self) -> float:
    """
    Return the sum of all loss factors.

    Returns:
      A float value.
    """
    return sum([x for x in self._loss_dict.values()])

  def items(self) -> List[Tuple[str, float]]:
    """
    Get name/value tuples for loss factors.

    Returns:
      Name/value tuples.
    """
    return self._loss_dict.items()

  def __len__(self):
    return len(self._loss_dict)

  def clear(self) -> None:
    """
    Clears all loss factors.
    """
    self._loss_dict.clear()
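A short usage sketch of the two renamed classes (values illustrative): FactoredLossExpr holds per-factor batched DyNet expressions within one training step, while FactoredLossVal accumulates plain floats across steps:

import dynet as dy
from xnmt.loss import FactoredLossExpr, FactoredLossVal

dy.renew_cg()
loss_expr = FactoredLossExpr({"mle": dy.inputTensor([0.5, 1.5], batched=True)})
loss_expr.add_loss("regularizer", dy.inputTensor([0.1, 0.3], batched=True))

step_loss = loss_expr.compute(comb_method="avg")  # scalar expr: (0.6 + 1.8) / 2 = 1.2

epoch_loss = FactoredLossVal()
epoch_loss += loss_expr.get_factored_loss_val(comb_method="avg")  # {'mle': 1.0, 'regularizer': 0.2}
print(epoch_loss.sum_factors())  # 1.2
epoch_loss.clear()               # reset between epochs, cf. TrainLossTracker.on_new_epoch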

8 changes: 4 additions & 4 deletions xnmt/loss_calculator.py
@@ -1,7 +1,7 @@
import dynet as dy
import numpy as np

from xnmt.loss import LossBuilder
from xnmt.loss import FactoredLossExpr
from xnmt.persistence import serializable_init, Serializable, Ref
from xnmt.vocab import Vocab
from xnmt.constants import INFINITY
@@ -54,7 +54,7 @@ def __call__(self, translator, initial_state, src, trg):
      if i < seq_len-1:
        dec_state = translator.decoder.add_input(dec_state, translator.trg_embedder.embed(ref_word))

    return dy.esum(losses)
    return FactoredLossExpr({"mle": dy.esum(losses)})

class ReinforceLoss(Serializable, LossCalculator):
  yaml_tag = '!ReinforceLoss'
@@ -93,7 +93,7 @@ def __call__(self, translator, initial_state, src, trg):
      self.eval_score.append(score)
    self.true_score = dy.inputTensor(self.eval_score, batched=True)
    # Composing losses
    loss = LossBuilder()
    loss = FactoredLossExpr()
    if self.use_baseline:
      baseline_loss = []
      losses = []
@@ -168,5 +168,5 @@ def __call__(self, translator, initial_state, src, trg):
    #print("----------------------")
    ### End debug

    return LossBuilder({"risk": risk})
    return FactoredLossExpr({"risk": risk})

6 changes: 3 additions & 3 deletions xnmt/loss_tracker.py
@@ -32,7 +32,7 @@ class TrainLossTracker(object):
  def __init__(self, training_task):
    self.training_task = training_task

    self.epoch_loss = loss.LossScalarBuilder()
    self.epoch_loss = loss.FactoredLossVal()
    self.epoch_words = 0
    self.last_report_sents_into_epoch = 0
    self.last_report_sents_since_start = 0
@@ -46,7 +46,7 @@ def __init__(self, training_task):
  @events.handle_xnmt_event
  def on_new_epoch(self, training_task, num_sents):
    if training_task is self.training_task:
      self.epoch_loss.zero()
      self.epoch_loss.clear()
      self.epoch_words = 0
      self.last_report_sents_since_start = 0
      self.last_report_words = 0
@@ -70,7 +70,7 @@ def report(self, trg, loss):
                   TrainLossTracker.REPORT_TEMPLATE_SPEED if accum_time else TrainLossTracker.REPORT_TEMPLATE,
                   {"key": "train_loss", "data": "train",
                    "epoch": fractional_epoch,
                    "loss": self.epoch_loss.sum() / self.epoch_words,
                    "loss": self.epoch_loss.sum_factors() / self.epoch_words,
                    "words": self.epoch_words,
                    "words_per_sec": (self.epoch_words - self.last_report_words) / (
                      accum_time) if accum_time else "-",
2 changes: 1 addition & 1 deletion xnmt/model_base.py
@@ -16,7 +16,7 @@ def __init__(self, src_reader: input_reader.InputReader, trg_reader: input_reade
    self.trg_reader = trg_reader

  def calc_loss(self, src: Union[batcher.Batch, input.Input], trg: Union[batcher.Batch, input.Input],
                loss_calculator: loss_calculator.LossCalculator) -> loss.LossBuilder:
                loss_calculator: loss_calculator.LossCalculator) -> loss.FactoredLossExpr:
    '''Calculate loss based on input-output pairs.

    Losses are accumulated only across unmasked timesteps in each batch element.
