- validate_for can now (optionally) be specified a loader to validate on
nasimrahaman committed Oct 11, 2017
1 parent 72e4c23 commit f76ded1
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions inferno/trainers/basic.py
@@ -1226,7 +1226,27 @@ def train_for(self, num_iterations=None, break_callback=None):
         self.callbacks.call(self.callbacks.END_OF_TRAINING_RUN, num_iterations=num_iterations)
         return self
 
-    def validate_for(self, num_iterations=None):
+    def validate_for(self, num_iterations=None, loader_name='validate'):
+        """
+        Validate for a given number of iterations (if `num_iterations` is not None)
+        or over the entire (validation) data set.
+
+        Parameters
+        ----------
+        num_iterations : int
+            Number of iterations to validate for. To validate on the entire dataset,
+            leave this as `None`.
+        loader_name : str
+            Name of the data loader to use for validation. 'validate' is the obvious default.
+
+        Returns
+        -------
+        Trainer
+            self.
+        """
+        assert_(loader_name in ['validate', 'test', 'train'],
+                "Invalid `loader_name`: {}".format(loader_name),
+                ValueError)
         # Average over errors
         validation_error_meter = tu.AverageMeter()
         validation_loss_meter = tu.AverageMeter()
@@ -1247,7 +1267,7 @@ def validate_for(self, num_iterations=None):
         # If we don't know num_iterations, we're validating the entire dataset - so we might as
         # well restart the loader now
         if num_iterations is None:
-            self.restart_generators('validate')
+            self.restart_generators(loader_name)
 
         while True:
             if num_iterations is not None and iteration_num > num_iterations:
@@ -1257,23 +1277,23 @@ def validate_for(self, num_iterations=None):
                                 iteration_num=iteration_num)
 
             try:
-                batch = self.fetch_next_batch('validate',
+                batch = self.fetch_next_batch(loader_name,
                                               restart_exhausted_generators=
                                               num_iterations is not None,
                                               update_batch_count=False,
                                               update_epoch_count_if_generator_exhausted=False)
             except StopIteration:
-                self.print("Validation generator exhausted, breaking.")
+                self.print("{} generator exhausted, breaking.".format(loader_name))
                 break
 
             self.print("Validating iteration {}.".format(iteration_num))
 
             # Delay SIGINTs till after computation
             with pyu.delayed_keyboard_interrupt():
                 # Wrap
-                batch = self.wrap_batch(batch, from_loader='validate', volatile=True)
+                batch = self.wrap_batch(batch, from_loader=loader_name, volatile=True)
                 # Separate
-                inputs, target = self.split_batch(batch, from_loader='validate')
+                inputs, target = self.split_batch(batch, from_loader=loader_name)
                 # Apply model, compute loss
                 output, loss = self.apply_model_and_loss(inputs, target, backward=False)
                 batch_size = target.size(0)
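As a usage illustration (not part of the commit), here is a minimal sketch of the new call pattern, assuming `trainer` is an already configured `inferno.trainers.basic.Trainer` with data loaders bound under the names 'train', 'validate', and 'test':

    # Sketch only: `trainer` is assumed to be a fully set-up Trainer with loaders
    # registered as 'train', 'validate', and 'test'; that setup is not shown here.

    # Previous behaviour, still the default: iterate the 'validate' loader.
    trainer.validate_for(num_iterations=10)

    # New in this commit: point validation at a different registered loader by name.
    trainer.validate_for(num_iterations=10, loader_name='test')

    # Leaving num_iterations as None sweeps the chosen loader in full.
    trainer.validate_for(loader_name='test')

    # Any other loader_name fails the assert_ guard above and raises ValueError.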
