Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- made sure that the trainer never validates at epoch 0 / iteration 0
Browse files Browse the repository at this point in the history
- covered a few corner cases in `Trainer`
  • Loading branch information
nasimrahaman committed Oct 11, 2017
1 parent 428cdda commit 5cdb34b
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions inferno/trainers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self, model=None):
self._num_validation_iterations = None
# We should exclude the zero-th epoch from validation
self._last_validated_at_epoch = 0
self._last_validated_at_iteration = 0
# This is to allow a callback to trigger a validation by setting
# trainer.validate_now = True
self._validation_externally_triggered = False
Expand Down Expand Up @@ -563,10 +564,18 @@ def validate_now(self):
return False
else:
# If we haven't validated this epoch, check if we should
return self._validate_every.match(epoch_count=self._epoch_count)
return self._validate_every.match(epoch_count=self._epoch_count,
match_zero=False)
else:
return self._validate_every is not None and \
self._validate_every.match(iteration_count=self._iteration_count)
# Don't validate if we've done once already this iteration
if self._last_validated_at_iteration == self._iteration_count:
return False
else:
# If we haven't validated this iteration, check if we should. The `match_zero` is
# redundant, but we'll leave it on anyway.
return self._validate_every is not None and \
self._validate_every.match(iteration_count=self._iteration_count,
match_zero=False)

@validate_now.setter
def validate_now(self, value):
Expand Down Expand Up @@ -1259,7 +1268,7 @@ def validate_for(self, num_iterations=None, loader_name='validate'):

# Record the epoch we're validating in
self._last_validated_at_epoch = self._epoch_count

self._last_validated_at_iteration = self._iteration_count
self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_RUN,
num_iterations=num_iterations,
last_validated_at_epoch=self._last_validated_at_epoch)
Expand Down

0 comments on commit 5cdb34b

Please sign in to comment.