Skip to content

Commit

Permalink
feat: add save_every_n_batches arg to NNTraner, resolves #1418
Browse files Browse the repository at this point in the history
  • Loading branch information
yurakuratov committed Mar 24, 2021
1 parent f1efeec commit 0db66f9
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
from deeppavlov.core.trainers.fit_trainer import FitTrainer
from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder
from deeppavlov.core.models.serializable import Serializable

log = getLogger(__name__)

Expand Down Expand Up @@ -72,6 +73,8 @@ class NNTrainer(FitTrainer):
log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``)
max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``,
ignored if negative (default is ``-1``)
save_every_n_batches: how often (in batches) to save model into f'{save_path}_{current_step}, the best model
is still saved to `save_path`, ignored if negative or zero (default is ``-1``)
**kwargs: additional parameters whose names will be logged but otherwise ignored
Expand Down Expand Up @@ -103,6 +106,7 @@ def __init__(self, chainer_config: dict, *,
validate_first: bool = True,
validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1,
log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1,
save_every_n_batches: int = -1,
**kwargs) -> None:
super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets,
show_examples=show_examples, tensorboard_log_dir=tensorboard_log_dir,
Expand Down Expand Up @@ -134,6 +138,7 @@ def _improved(op):
self.log_every_n_epochs = log_every_n_epochs
self.log_every_n_batches = log_every_n_batches
self.log_on_k_batches = log_on_k_batches if log_on_k_batches >= 0 else None
self.save_every_n_batches = save_every_n_batches

self.max_epochs = epochs
self.epoch = start_epoch_num
Expand Down Expand Up @@ -296,6 +301,14 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
self._validate(iterator,
tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen)

if self.save_every_n_batches > 0 and self.train_batches_seen % self.save_every_n_batches == 0:
# can't specify save path for chainer from here!
# have to duplicate code from chainer.save
main_component = self._chainer.get_main_component()
if isinstance(main_component, Serializable) and hasattr(main_component, 'save_path'):
log.info(f'Saving model at step: {self.train_batches_seen}')
main_component.save(f'{main_component.save_path}_{self.train_batches_seen}')

self._send_event(event_name='after_batch')

if 0 < self.max_batches <= self.train_batches_seen:
Expand Down

0 comments on commit 0db66f9

Please sign in to comment.