Skip to content

Commit

Permalink
print CG on error only in debug mode (#517)
Browse files Browse the repository at this point in the history
* print CG on error only in debug mode

* remove debugging code

* fix readable sentence __str__ and __repr__
  • Loading branch information
msperber authored and neubig committed Aug 16, 2018
1 parent 776c1e0 commit ce31628
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion xnmt/eval/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def eval(self) -> 'EvalScore':
loss_val = FactoredLossVal()
ref_words_cnt = 0
for src, trg in zip(self.src_batches, self.ref_batches):
with utils.ReportOnException({"src": src, "trg": trg, "graph": dy.print_text_graphviz}):
with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)

loss = self.loss_calculator.calc_loss(self.model, src, trg)
Expand Down
2 changes: 1 addition & 1 deletion xnmt/inferences.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _generate_output(self, generator: 'models.GeneratorModel', src_corpus: Seque
fp.write(f"{output_txt}\n")
else:
if forced_ref_corpus: ref_batch = ref_batches[batch_i]
with utils.ReportOnException({"batchno":batch_i, "src": src_batch, "graph": dy.print_text_graphviz}):
with utils.ReportOnException({"batchno":batch_i, "src": src_batch, "graph": utils.print_cg_conditional}):
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
outputs = self.generate_one(generator, src_batch, ref_batch)
if self.reporter: self._create_sent_report()
Expand Down
11 changes: 4 additions & 7 deletions xnmt/sent.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ def sent_str(self, custom_output_procs=None, **kwargs) -> str:
if isinstance(pps, OutputProcessor): pps = [pps]
for pp in pps:
out_str = pp.process(out_str)
# TODO: change output processor interface accordingly
return out_str
def __repr__(self):
return f'"{self.sent_str()}"'
def __str__(self):
return self.sent_str()

class ScalarSentence(ReadableSentence):
"""
Expand Down Expand Up @@ -178,12 +181,6 @@ def __init__(self, words: Sequence[int], idx: Optional[int] = None, vocab: Optio
self.words = words
self.vocab = vocab

def __repr__(self):
return f"SimpleSentence({repr(self.words)})"

def __str__(self):
return self.sent_str()

def __getitem__(self, key):
ret = self.words[key]
if isinstance(ret, list): # support for slicing
Expand Down
2 changes: 2 additions & 0 deletions xnmt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Standard(object):
IMMEDIATE_COMPUTE = False
CHECK_VALIDITY = False
RESOURCE_WARNINGS = False
PRINT_CG_ON_ERROR = False
LOG_LEVEL_CONSOLE = "INFO"
LOG_LEVEL_FILE = "DEBUG"
DEFAULT_MOD_PATH = "{EXP_DIR}/models/{EXP}.mod"
Expand All @@ -47,6 +48,7 @@ class Debug(Standard):
IMMEDIATE_COMPUTE = True
CHECK_VALIDITY = True
RESOURCE_WARNINGS = True
PRINT_CG_ON_ERROR = True
LOG_LEVEL_CONSOLE = "DEBUG"
LOG_LEVEL_FILE = "DEBUG"

Expand Down
2 changes: 1 addition & 1 deletion xnmt/train/regimens.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def run_training(self, save_fct):
if self.dev_zero:
self.checkpoint_and_save(save_fct)
self.dev_zero = False
with utils.ReportOnException({"src": src, "trg": trg, "graph": dy.print_text_graphviz}):
with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
with self.train_loss_tracker.time_tracker:
event_trigger.set_train(True)
Expand Down
6 changes: 6 additions & 0 deletions xnmt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import functools

import numpy as np
import dynet as dy

from xnmt import logger, yaml_logger
from xnmt.settings import settings

def print_cg_conditional():
if settings.PRINT_CG_ON_ERROR:
dy.print_text_graphviz()

def make_parent_dir(filename):
if not os.path.exists(os.path.dirname(filename) or "."):
Expand Down

0 comments on commit ce31628

Please sign in to comment.