Skip to content

Commit

Permalink
feat: TQDM added to trainers (#1593)
Browse files Browse the repository at this point in the history
Co-authored-by: Fedor Ignatov <ignatov.fedor@gmail.com>
  • Loading branch information
dimakarp1996 and IgnatovFedor committed Nov 5, 2022
1 parent a54b265 commit df54cbd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 4 additions & 2 deletions deeppavlov/core/trainers/fit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from logging import getLogger
from typing import Tuple, Dict, Union, Optional, Iterable, Any, Collection

from tqdm import tqdm

from deeppavlov.core.commands.infer import build_model
from deeppavlov.core.common.chainer import Chainer
from deeppavlov.core.common.params import from_params
Expand Down Expand Up @@ -90,7 +92,7 @@ def fit_chainer(self, iterator: Union[DataFittingIterator, DataLearningIterator]
targets = [targets]

if self.batch_size > 0 and callable(getattr(component, 'partial_fit', None)):
for i, (x, y) in enumerate(iterator.gen_batches(self.batch_size, shuffle=False)):
for i, (x, y) in tqdm(enumerate(iterator.gen_batches(self.batch_size, shuffle=False))):
preprocessed = self._chainer.compute(x, y, targets=targets)
# noinspection PyUnresolvedReferences
component.partial_fit(*preprocessed)
Expand Down Expand Up @@ -160,7 +162,7 @@ def test(self, data: Iterable[Tuple[Collection[Any], Collection[Any]]],

data = islice(data, self.max_test_batches)

for x, y_true in data:
for x, y_true in tqdm(data):
examples += len(x)
y_predicted = list(self._chainer.compute(list(x), list(y_true), targets=expected_outputs))
if len(expected_outputs) == 1:
Expand Down
7 changes: 5 additions & 2 deletions deeppavlov/core/trainers/nn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
from pathlib import Path
from typing import List, Tuple, Union, Optional, Iterable

from tqdm import tqdm

from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.log_events import get_tb_writer
from deeppavlov.core.common.registry import register
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
from deeppavlov.core.trainers.fit_trainer import FitTrainer
from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder
from deeppavlov.core.common.log_events import get_tb_writer

log = getLogger(__name__)
report_log = getLogger('train_report')

Expand Down Expand Up @@ -273,7 +276,7 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
while True:
impatient = False
self._send_event(event_name='before_train')
for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'):
for x, y_true in tqdm(iterator.gen_batches(self.batch_size, data_type='train')):
self.last_result = self._chainer.train_on_batch(x, y_true)
if self.last_result is None:
self.last_result = {}
Expand Down

0 comments on commit df54cbd

Please sign in to comment.