Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
fix generator behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Jun 3, 2019
1 parent 44ef716 commit b510ffb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
19 changes: 11 additions & 8 deletions delira/training/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def __init__(self,
def _setup(self, network, lr_scheduler_cls, lr_scheduler_params, gpu_ids,
key_mapping, convert_batch_to_npy_fn, prepare_batch_fn):

super()._setup(network, key_mapping, convert_batch_to_npy_fn,
super()._setup(network, key_mapping, convert_batch_to_npy_fn,
prepare_batch_fn)

self.closure_fn = network.closure
Expand Down Expand Up @@ -427,10 +427,13 @@ def reduce_fn(batch): return batch[-1]

# validate network
if datamgr_valid is not None and (epoch % self.val_freq == 0):
val_predictions, val_metrics = self.predict_data_mgr(
# next must be called here because self.predict_data_mgr
# returns a generator (of size 1) and we want to get the first
# (and only) item
val_predictions, val_metrics = next(self.predict_data_mgr(
datamgr_valid, datamgr_valid.batch_size,
metrics=val_metric_fns, metric_keys=val_metric_keys,
verbose=verbose)
verbose=verbose, lazy_gen=False))

total_metrics.update(val_metrics)

Expand Down Expand Up @@ -458,7 +461,7 @@ def reduce_fn(batch): return batch[-1]

# set best_val_score to new_val_score if is_best
best_val_score = int(is_best) * new_val_score + \
(1 - int(is_best)) * best_val_score
(1 - int(is_best)) * best_val_score

if is_best and verbose:
logging.info("New Best Value at Epoch %03d : %03.3f" %
Expand Down Expand Up @@ -532,12 +535,12 @@ def register_callback(self, callback: AbstractCallback):
"""
assertion_str = "Given callback is not valid; Must be instance of " \
"AbstractCallback or provide functions " \
"'at_epoch_begin' and 'at_epoch_end'"
"AbstractCallback or provide functions " \
"'at_epoch_begin' and 'at_epoch_end'"

assert isinstance(callback, AbstractCallback) or \
(hasattr(callback, "at_epoch_begin")
and hasattr(callback, "at_epoch_end")), assertion_str
(hasattr(callback, "at_epoch_begin")
and hasattr(callback, "at_epoch_end")), assertion_str

self._callbacks.append(callback)

Expand Down
42 changes: 22 additions & 20 deletions delira/training/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def predict(self, data: dict, **kwargs):
)[1]

def predict_data_mgr(self, datamgr, batchsize=None, metrics={},
metric_keys=None, verbose=False, as_generator=False,
metric_keys=None, verbose=False, lazy_gen=False,
**kwargs):

"""
Expand All @@ -158,32 +158,27 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},
the ``batch_dict`` items to use for metric calculation
verbose : bool
whether to show a progress-bar or not, default: False
as_generator : bool
lazy_gen : bool
if True: Yields results instead of returning them; should be
specified if predicting on a low-memory device or when results
should be saved immediately
kwargs :
keyword arguments passed to :func:`prepare_batch_fn`
Returns
-------
dict
a dictionary containing all predictions;
if ``as_generator`` is False
dict
a dictionary containing all validation metrics (maybe empty);
if ``as_generator`` is False
None
if ``as_generator`` is True
Yields
------
dict
a dictionary containing all predictions of the current batch
if ``as_generator`` is True
if ``lazy_gen`` is True
dict
a dictionary containing all metrics of the current batch
if ``as_generator`` is True
if ``lazy_gen`` is True
dict
a dictionary containing all predictions;
if ``lazy_gen`` is False
dict
a dictionary containing all validation metrics (maybe empty);
if ``lazy_gen`` is False
"""

Expand All @@ -198,7 +193,7 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},

batchgen = datamgr.get_batchgen()

if not as_generator:
if not lazy_gen:
predictions_all, metric_vals = [], {k: [] for k in metrics.keys()}

n_batches = batchgen.generator.num_batches * batchgen.num_processes
Expand Down Expand Up @@ -250,7 +245,7 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},
metrics=metrics,
metric_keys=metric_keys)

if as_generator:
if lazy_gen:
yield preds, _metric_vals
else:
for k, v in _metric_vals.items():
Expand All @@ -264,7 +259,8 @@ def predict_data_mgr(self, datamgr, batchsize=None, metrics={},
datamgr.batch_size = orig_batch_size
datamgr.n_process_augmentation = orig_num_aug_processes

if as_generator:
if lazy_gen:
# triggers StopIteration
return

# convert predictions from list of dicts to dict of lists
Expand Down Expand Up @@ -346,7 +342,13 @@ def __concatenate_dict_items(dict_like: dict):
for k, v in metric_vals.items():
metric_vals[k] = np.array(v)

return predictions_all, metric_vals
# must yield these instead of returning them,
# because every function with a yield in its body returns a
# generator object (even if the yield is never triggered)
yield predictions_all, metric_vals

# triggers StopIteration
return

def __setattr__(self, key, value):
"""
Expand Down Expand Up @@ -401,4 +403,4 @@ def calc_metrics(batch: LookupConfig, metrics={}, metric_keys=None):
metric_keys = {k: ("pred", "label") for k in metrics.keys()}

return {key: metric_fn(*[batch.nested_get(k) for k in metric_keys[key]])
for key, metric_fn in metrics.items()}
for key, metric_fn in metrics.items()}

0 comments on commit b510ffb

Please sign in to comment.