This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit 8c94732
Merge remote-tracking branch 'origin/super-dev'
constantinpape committed Jun 10, 2019
2 parents 0561e8a + e8f204d commit 8c94732
Showing 1 changed file with 17 additions and 7 deletions.

inferno/trainers/basic.py: 17 additions & 7 deletions
@@ -2,7 +2,6 @@
 from inspect import signature
 import os
 import shutil
-import contextlib
 
 # These are fetched from globals, they're not unused
 # noinspection PyUnresolvedReferences
@@ -1159,7 +1158,7 @@ def restart_generators(self, of_loader=None):
                                 for from_loader in of_loader})
         return self
 
-    def wrap_batch(self, batch, from_loader=None, requires_grad=False, volatile=False):
+    def wrap_batch(self, batch, from_loader=None, requires_grad=False):
         base_device_ordinal = \
             self._base_device_ordinal if hasattr(self, '_base_device_ordinal') else None
         # First, send to the right device
@@ -1179,6 +1178,7 @@ def wrap_batch(self, batch, from_loader=None, requires_grad=False, volatile=False):
                 RuntimeError)
         # Get number of targets
         num_targets = loader_spec['num_targets']
+        assert_(num_targets > 0, "Number of targets must be larger than zero.", RuntimeError)
         # Fetch input batches and send'em to device (leave the targets alone)
         inputs = batch[:-num_targets]
         inputs = self.to_device(inputs)
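
For reference, the new assertion guards the `batch[:-num_targets]` slice two lines below it: with `num_targets == 0` that slice evaluates to `batch[:0]`, i.e. an empty list, and every input would be dropped silently. A standalone illustration of this slicing behaviour (plain Python, not inferno code):

# Standalone illustration: a "negative zero" slice bound is an empty slice,
# not the full list, which is what the assert above rules out.
batch = ["input_0", "input_1", "target_0"]

num_targets = 1
print(batch[:-num_targets])   # ['input_0', 'input_1'] -- the inputs, as intended

num_targets = 0
print(batch[:-num_targets])   # [] -- because batch[:-0] == batch[:0]
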
@@ -1190,6 +1190,18 @@ def wrap_batch(self, batch, from_loader=None, requires_grad=False, volatile=False):
 
         # Cast to the right dtype and return
         batch = self.cast(batch)
+        # Set gradients if required
+        variable_batch = []
+        for batch_num, _batch in enumerate(batch):
+            if thu.is_tensor(_batch):
+                variable_batch.append(_batch.requires_grad_() if requires_grad else _batch)
+            elif pyu.is_listlike(_batch):
+                variable_batch.append([__batch.requires_grad_() if requires_grad else __batch
+                                       for __batch in _batch])
+            else:
+                raise RuntimeError(f"Was Expecting batch at index {batch_num} to be either a "
+                                   f"tensor or a list of tensors. Got {type(_batch)} instead.")
+        batch = type(batch)(variable_batch)
         return batch
 
     def next_iteration(self):
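
The dropped `volatile` keyword belongs to the pre-0.4 PyTorch API, where autograd was switched off per `Variable`; from 0.4 onwards gradient tracking is toggled on tensors directly, which is what the new loop does via `requires_grad_()`. Below is a rough standalone sketch of the same pattern, with `torch.is_tensor` and an `isinstance` check standing in for inferno's `thu.is_tensor` / `pyu.is_listlike` helpers and a hypothetical `set_batch_requires_grad` name:

import torch

def set_batch_requires_grad(batch, requires_grad=False):
    # Enable gradient tracking in place on every tensor (or list of tensors)
    # in the batch when requested; otherwise pass items through untouched.
    wrapped = []
    for index, item in enumerate(batch):
        if torch.is_tensor(item):
            wrapped.append(item.requires_grad_() if requires_grad else item)
        elif isinstance(item, (list, tuple)):
            wrapped.append([tensor.requires_grad_() if requires_grad else tensor
                            for tensor in item])
        else:
            raise RuntimeError(f"Expected a tensor or a list of tensors at index {index}, "
                               f"got {type(item)} instead.")
    # Preserve the container type of the incoming batch (list, tuple, ...).
    return type(batch)(wrapped)

# Example usage: float tensors only, since requires_grad_() rejects integer dtypes.
inputs = torch.rand(4, 3, 32, 32)
targets = torch.rand(4, 10)
inputs, targets = set_batch_requires_grad([inputs, targets], requires_grad=True)
print(inputs.requires_grad, targets.requires_grad)  # True True
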
@@ -1408,6 +1420,7 @@ def train_for(self, num_iterations=None, break_callback=None):
             # Compute metric
             if self.metric_is_defined and self.evaluate_metric_now:
                 self._last_metric_evaluated_at_epoch = self._epoch_count
+                # TODO Make unwrap a method for folks to overload
                 error = self.metric(thu.unwrap(prediction, to_cpu=False),
                                     thu.unwrap(target, to_cpu=False))
                 self.update_state('training_error', thu.unwrap(error))
@@ -1507,13 +1520,10 @@ def validate_for(self, num_iterations=None, loader_name='validate'):
 
             self.console.progress("Validating iteration {}.".format(iteration_num))
 
-            no_grad = torch.no_grad if hasattr(torch, 'no_grad') else contextlib.suppress
             # Delay SIGINTs till after computation
-            with pyu.delayed_keyboard_interrupt(), no_grad():
+            with pyu.delayed_keyboard_interrupt(), torch.no_grad():
                 # Wrap
-                # FIXME The volatile=True is required for compatibility with older 0.3 code.
-                # FIXME Remove when support is deprecated.
-                batch = self.wrap_batch(batch, from_loader=loader_name, volatile=True)
+                batch = self.wrap_batch(batch, from_loader=loader_name)
                 # Separate
                 inputs, target = self.split_batch(batch, from_loader=loader_name)
                 # Apply model, compute loss
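
Since `torch.no_grad()` is available in every PyTorch release from 0.4 onwards, the `hasattr` guard and its `contextlib.suppress` fallback are dead code, which is also why the `contextlib` import disappears at the top of the file. For context, a minimal sketch (outside of inferno, with placeholder `model`, `loader`, and `criterion` arguments) of the validation pattern the new code relies on:

import torch

def validate(model, loader, criterion, device="cpu"):
    # torch.no_grad() disables autograd bookkeeping for everything computed
    # inside the block, so validation neither builds a graph nor stores
    # intermediate buffers for backward.
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            prediction = model(inputs)
            total_loss += criterion(prediction, targets).item()
            num_batches += 1
    return total_loss / max(num_batches, 1)
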
