Printing of debugging on crashes (#432)
* Printing of debugging on crashes

* Added repr

* prototype context manager

* use logger consistently

* fixes

* more fixes

* Split train/dev error reporting

* Added missing import
neubig committed Jun 19, 2018
1 parent 31bfb59 commit 977c432
Showing 4 changed files with 46 additions and 19 deletions.
18 changes: 10 additions & 8 deletions xnmt/eval_task.py
@@ -14,6 +14,7 @@
 from xnmt.evaluator import LossScore
 from xnmt.loss import FactoredLossExpr, FactoredLossVal
 import xnmt.xnmt_evaluate
+from xnmt import util
 
 class EvalTask(object):
   """
@@ -72,16 +73,17 @@ def eval(self) -> tuple:
     loss_val = FactoredLossVal()
     ref_words_cnt = 0
     for src, trg in zip(self.src_batches, self.ref_batches):
-      dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
+      with util.ReportOnException({"src": src, "trg": trg, "graph": dy.print_text_graphviz}):
+        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
 
-      loss_builder = FactoredLossExpr()
-      standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
-      additional_loss = self.model.calc_additional_loss(standard_loss)
-      loss_builder.add_factored_loss_expr(standard_loss)
-      loss_builder.add_factored_loss_expr(additional_loss)
+        loss_builder = FactoredLossExpr()
+        standard_loss = self.model.calc_loss(src, trg, self.loss_calculator)
+        additional_loss = self.model.calc_additional_loss(standard_loss)
+        loss_builder.add_factored_loss_expr(standard_loss)
+        loss_builder.add_factored_loss_expr(additional_loss)
 
-      ref_words_cnt += self.model.trg_reader.count_words(trg)
-      loss_val += loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method)
+        ref_words_cnt += self.model.trg_reader.count_words(trg)
+        loss_val += loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method)
 
     loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}
 
5 changes: 4 additions & 1 deletion xnmt/input.py
@@ -52,6 +52,9 @@ def __init__(self, words: Sequence[int], vocab: vocab.Vocab = None):
     self.words = words
     self.vocab = vocab
 
+  def __repr__(self):
+    return '{}'.format(self.words)
+
   def __len__(self):
     return len(self.words)
 
@@ -134,4 +137,4 @@ def get_padded_sent(self, token, pad_len):
     return ArrayInput(new_nparr, padded_len=self.padded_len + pad_len)
 
   def get_array(self):
-    return self.nparr
\ No newline at end of file
+    return self.nparr
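The new __repr__ exists so that the crash dump added in xnmt/util.py (further down in this commit) prints something readable: when an input sentence is passed to logger.error(str(val)), it now shows its underlying word-ID list rather than a default object repr. A minimal sketch of that behavior, where the class name is an illustrative stand-in for the actual input class in xnmt/input.py:

# Illustrative stand-in for the sentence-input class in xnmt/input.py
class SentenceInput:
  def __init__(self, words, vocab=None):
    self.words = words   # list of word IDs
    self.vocab = vocab
  def __repr__(self):
    return '{}'.format(self.words)

print(repr(SentenceInput([5, 12, 7, 1])))  # prints: [5, 12, 7, 1]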
21 changes: 11 additions & 10 deletions xnmt/training_regimen.py
@@ -11,7 +11,7 @@
 from xnmt.loss_calculator import LossCalculator, MLELoss
 from xnmt.param_collection import ParamManager
 from xnmt.persistence import serializable_init, Serializable, bare, Ref
-from xnmt import training_task, optimizer, batcher, eval_task
+from xnmt import training_task, optimizer, batcher, eval_task, util
 
 class TrainingRegimen(object):
   """
@@ -139,15 +139,16 @@ def run_training(self, save_fct, update_weights=True):
         if self.dev_zero:
           self.checkpoint_and_save(save_fct)
           self.dev_zero = False
-        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
-        with self.train_loss_tracker.time_tracker:
-          self.model.set_train(True)
-          loss_builder = self.training_step(src, trg)
-          loss = loss_builder.compute()
-          if update_weights:
-            self.backward(loss, self.dynet_profiling)
-            self.update(self.trainer)
-          self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
+        with util.ReportOnException({"src": src, "trg": trg, "graph": dy.print_text_graphviz}):
+          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
+          with self.train_loss_tracker.time_tracker:
+            self.model.set_train(True)
+            loss_builder = self.training_step(src, trg)
+            loss = loss_builder.compute()
+            if update_weights:
+              self.backward(loss, self.dynet_profiling)
+              self.update(self.trainer)
+            self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
         if self.checkpoint_needed():
           self.checkpoint_and_save(save_fct)
         if self.should_stop_training(): break
21 changes: 21 additions & 0 deletions xnmt/util.py
@@ -52,3 +52,24 @@ def update(self, new):
       self.stddev = math.sqrt(self.variance)
     else:
       assert len(self.vals) < self.N
+
+class ReportOnException(object):
+  """
+  Context manager that prints debug information when an exception occurs.
+  Args:
+    args: a dictionary containing debug info. Callable items are called, other items are passed to logger.error()
+  """
+  def __init__(self, args: dict):
+    self.args = args
+  def __enter__(self):
+    return self
+  def __exit__(self, et, ev, traceback):
+    if et is not None: # exception occurred
+      logger.error("------ Fatal Error During Training! ------")
+      for key, val in self.args.items():
+        logger.error(f"*** {key} ***")
+        if callable(val):
+          val()
+        else:
+          logger.error(str(val))

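For reference, a minimal usage sketch of the new context manager, mirroring how eval_task.py and training_regimen.py wrap their per-batch work. It assumes xnmt with this commit is importable; the batch values and the dump_graph callable are illustrative stand-ins for the real src/trg batches and dy.print_text_graphviz:

from xnmt.util import ReportOnException  # added by this commit

# Illustrative stand-ins for the real minibatches and the graph-dump callable
src_batch = [[5, 12, 7], [3, 9, 1]]
trg_batch = [[4, 8], [2, 6]]
def dump_graph():
  print("<text rendering of the DyNet computation graph would go here>")

try:
  with ReportOnException({"src": src_batch, "trg": trg_batch, "graph": dump_graph}):
    raise RuntimeError("simulated crash during the forward pass")
except RuntimeError:
  pass

# On the exception, xnmt's logger emits (at ERROR level):
#   ------ Fatal Error During Training! ------
#   *** src ***
#   [[5, 12, 7], [3, 9, 1]]
#   *** trg ***
#   [[4, 8], [2, 6]]
#   *** graph ***
#   (dump_graph() is called here, since callable values are invoked)
# __exit__ returns None, so the exception still propagates and is caught above.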