Skip to content

Commit

Permalink
Several miscellaneous robustness fixes (#464)
Browse files Browse the repository at this point in the history
* Several miscellaneous robustness fixes

* don't revert if model has never been saved
  • Loading branch information
neubig committed Jul 14, 2018
1 parent da9455d commit 9ed5a34
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
8 changes: 5 additions & 3 deletions xnmt/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from xnmt.exp_global import ExpGlobal
from xnmt.eval_task import EvalTask
from xnmt.model_base import TrainableModel
from xnmt.param_collection import ParamManager
from xnmt.param_collection import ParamManager, RevertingUnsavedModelException
from xnmt.preproc_runner import PreprocRunner
from xnmt.training_regimen import TrainingRegimen
from xnmt.persistence import serializable_init, Serializable, bare
Expand Down Expand Up @@ -51,10 +51,12 @@ def __call__(self, save_fct):
eval_scores = ["Not evaluated"]
if self.train:
logger.info("> Training")
save_fct() # save initial model
self.train.run_training(save_fct = save_fct)
logger.info('reverting learned weights to best checkpoint..')
ParamManager.param_col.revert_to_best_model()
try:
ParamManager.param_col.revert_to_best_model()
except RevertingUnsavedModelException:
pass

evaluate_args = self.evaluate
if evaluate_args:
Expand Down
4 changes: 3 additions & 1 deletion xnmt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dynet as dy

import xnmt.batcher
from xnmt import loss, loss_calculator, model_base, output, reports, search_strategy, util
from xnmt import logger, loss, loss_calculator, model_base, output, reports, search_strategy, util
from xnmt.persistence import serializable_init, Serializable, bare

NO_DECODING_ATTEMPTED = "@@NO_DECODING_ATTEMPTED@@"
Expand Down Expand Up @@ -67,6 +67,8 @@ def perform_inference(self, generator: model_base.GeneratorModel, src_file: str
trg_file = trg_file or self.trg_file
util.make_parent_dir(trg_file)

logger.info(f'Performing inference on {src_file}')

ref_corpus, src_corpus = self._read_corpus(generator, src_file, mode=self.mode, ref_file=self.ref_file)

generator.set_train(False)
Expand Down
14 changes: 10 additions & 4 deletions xnmt/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,16 @@ def update(self) -> None:
"""
Update the parameters.
"""
if not (self.skip_noisy and self._check_gradients_noisy()):
self.optimizer.update()
else:
logger.info("skipping noisy update")
try:
if not (self.skip_noisy and self._check_gradients_noisy()):
self.optimizer.update()
else:
logger.info("skipping noisy update")
except RuntimeError:
logger.warning("Failed to perform update. Skipping example and clearing gradients.")
for subcol in ParamManager.param_col.subcols.values():
for param in subcol.parameters_list():
param.scale_gradient(0)

def status(self):
"""
Expand Down
7 changes: 5 additions & 2 deletions xnmt/param_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def _update_data_files(self):
def add_subcollection(self, subcol_owner, subcol_name):
assert subcol_owner not in self.all_subcol_owners
self.all_subcol_owners.add(subcol_owner)
assert subcol_name not in self.subcols
if subcol_name in self.subcols:
raise RuntimeError(f'Duplicate subcol_name {subcol_name} found when loading')
new_subcol = self._param_col.add_subcollection(subcol_name)
self.subcols[subcol_name] = new_subcol
return new_subcol
Expand All @@ -165,7 +166,7 @@ def save(self):

def revert_to_best_model(self):
if not self._is_saved:
raise ValueError("revert_to_best_model() is illegal because this model has never been saved.")
raise RevertingUnsavedModelException("revert_to_best_model() is illegal because this model has never been saved.")
for subcol_name, subcol in self.subcols.items():
subcol.populate(os.path.join(self._data_files[0], subcol_name))

Expand All @@ -192,3 +193,5 @@ def _shift_saved_checkpoints(self):
for i in range(len(self._data_files)-1)[::-1]:
if os.path.exists(self._data_files[i]):
os.rename(self._data_files[i], self._data_files[i+1])

class RevertingUnsavedModelException(Exception):
  """Signals an attempt to revert to the best model before any save happened.

  Raised by ``revert_to_best_model()`` when the parameter collection has
  never been saved, so callers can catch this specific case.
  """

0 comments on commit 9ed5a34

Please sign in to comment.