diff --git a/deeppavlov/core/trainers/nn_trainer.py b/deeppavlov/core/trainers/nn_trainer.py index 1765a108e6..2eded68c4a 100644 --- a/deeppavlov/core/trainers/nn_trainer.py +++ b/deeppavlov/core/trainers/nn_trainer.py @@ -25,6 +25,7 @@ from deeppavlov.core.data.data_learning_iterator import DataLearningIterator from deeppavlov.core.trainers.fit_trainer import FitTrainer from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder +from deeppavlov.core.models.serializable import Serializable log = getLogger(__name__) @@ -72,6 +73,8 @@ class NNTrainer(FitTrainer): log_on_k_batches: count of random train batches to calculate metrics in log (default is ``1``) max_test_batches: maximum batches count for pipeline testing and evaluation, overrides ``log_on_k_batches``, ignored if negative (default is ``-1``) + save_every_n_batches: how often (in batches) to save model into f'{save_path}_{current_step}, the best model + is still saved to `save_path`, ignored if negative or zero (default is ``-1``) **kwargs: additional parameters whose names will be logged but otherwise ignored @@ -103,6 +106,7 @@ def __init__(self, chainer_config: dict, *, validate_first: bool = True, validation_patience: int = 5, val_every_n_epochs: int = -1, val_every_n_batches: int = -1, log_every_n_batches: int = -1, log_every_n_epochs: int = -1, log_on_k_batches: int = 1, + save_every_n_batches: int = -1, **kwargs) -> None: super().__init__(chainer_config, batch_size=batch_size, metrics=metrics, evaluation_targets=evaluation_targets, show_examples=show_examples, tensorboard_log_dir=tensorboard_log_dir, @@ -134,6 +138,7 @@ def _improved(op): self.log_every_n_epochs = log_every_n_epochs self.log_every_n_batches = log_every_n_batches self.log_on_k_batches = log_on_k_batches if log_on_k_batches >= 0 else None + self.save_every_n_batches = save_every_n_batches self.max_epochs = epochs self.epoch = start_epoch_num @@ -296,6 +301,14 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None: self._validate(iterator, tensorboard_tag='every_n_batches', tensorboard_index=self.train_batches_seen) + if self.save_every_n_batches > 0 and self.train_batches_seen % self.save_every_n_batches == 0: + # can't specify save path for chainer from here! + # have to duplicate code from chainer.save + main_component = self._chainer.get_main_component() + if isinstance(main_component, Serializable) and hasattr(main_component, 'save_path'): + log.info(f'Saving model at step: {self.train_batches_seen}') + main_component.save(f'{main_component.save_path}_{self.train_batches_seen}') + self._send_event(event_name='after_batch') if 0 < self.max_batches <= self.train_batches_seen: