Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #133 from justusschock/predictor_caching_patch
Browse files Browse the repository at this point in the history
Fix predictor caching behavior
  • Loading branch information
justusschock committed Jun 12, 2019
2 parents 9d25a68 + d60e964 commit 4e9643b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
11 changes: 6 additions & 5 deletions delira/training/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def _train_single_epoch(self, batchgen: Augmenter, epoch,
unit=' batch',
total=n_batches,
desc='Epoch %d' %
epoch)
epoch)
else:
iterable = enumerate(batchgen)

Expand Down Expand Up @@ -438,10 +438,11 @@ def reduce_fn(batch):
# next must be called here because self.predict_data_mgr
# returns a generator (of size 1) and we want to get the first
# (and only) item
val_predictions, val_metrics = next(self.predict_data_mgr(
datamgr_valid, datamgr_valid.batch_size,
metrics=val_metric_fns, metric_keys=val_metric_keys,
verbose=verbose, lazy_gen=False))
val_metrics = next(
self.predict_data_mgr_cache_metrics_only(
datamgr_valid, datamgr_valid.batch_size,
metrics=val_metric_fns, metric_keys=val_metric_keys,
verbose=verbose))

total_metrics.update(val_metrics)

Expand Down
5 changes: 2 additions & 3 deletions delira/training/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,8 @@ def test(self, network, test_data: BaseDataManager,
prepare_batch_fn=prepare_batch, **kwargs)

# return first item of generator
return next(predictor.predict_data_mgr(test_data, 1, metrics,
metric_keys, verbose,
lazy_gen=False))
return next(predictor.predict_data_mgr_cache_all(test_data, 1, metrics,
metric_keys, verbose))

def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
num_splits=None, shuffle=False, random_seed=None,
Expand Down
1 change: 1 addition & 0 deletions delira/training/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},

n_batches = batchgen.num_batches


if verbose:
iterable = tqdm(enumerate(batchgen), unit=' sample',
total=n_batches, desc=self._tqdm_desc)
Expand Down

0 comments on commit 4e9643b

Please sign in to comment.