Merge pull request #180 from inferno-pytorch/remove-variable
Remove variable
constantinpape committed Jun 10, 2019
2 parents d451bf4 + 93f8580 commit 0561e8a
Showing 17 changed files with 94 additions and 160 deletions.
9 changes: 7 additions & 2 deletions .travis.yml
@@ -7,10 +7,15 @@ python:
- 3.7

env:
# FIXME multi-processing hangs with pytorch 1.0 and unittest test discovery
# - PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch
- PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch
- PYTORCH_CONDA="pytorch=0.4.1" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch

# exclude hanging build
matrix:
exclude:
- python: 3.6
env: PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch

install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
10 changes: 5 additions & 5 deletions docs/usage.rst
@@ -43,9 +43,9 @@ With our model built, it's time to worry about the data generators. Or is it?
.. code:: python
from inferno.io.box.cifar import get_cifar10_loaders
train_loader, validate_loader = get_cifar10_loaders('path/to/cifar10',
download=True,
train_batch_size=128,
train_loader, validate_loader = get_cifar10_loaders('path/to/cifar10',
download=True,
train_batch_size=128,
test_batch_size=100)
CIFAR-10 works out-of-the-`box` (pun very much intended) with all the fancy data-augmentation and normalization. Of course, it's perfectly fine if you have your own [`DataLoader`](http://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader).
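As a side note on the hunk above: the returned loaders are ordinary torch `DataLoader`s, so consuming them needs no `Variable` wrapping on pytorch >= 0.4. A minimal, self-contained sketch with stand-in data (in practice the batches would come from `get_cifar10_loaders`; the linear model is purely illustrative):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in batches with CIFAR-10 shapes: (N, 3, 32, 32) images, (N,) labels.
    images = torch.rand(8, 3, 32, 32)
    labels = torch.randint(0, 10, (8,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = nn.Linear(3 * 32 * 32, 10)  # illustrative stand-in for a real network

    for x, y in loader:
        with torch.no_grad():           # plain tensors, no Variable wrapping needed
            logits = model(x.view(x.size(0), -1))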
@@ -117,7 +117,7 @@ or
trainer.build_metric(MyMetric, **my_metric_kwargs)
Note that the metric applies to `torch.Tensor`s, and not on `torch.autograd.Variable`s. Also, a metric might be way too expensive to evaluate every training iteration without slowing down the training. If this is the case and you'd like to evaluate the metric every (say) 10 *training* iterations:
A metric might be way too expensive to evaluate every training iteration without slowing down the training. If this is the case and you'd like to evaluate the metric every (say) 10 *training* iterations:

.. code:: python
@@ -254,7 +254,7 @@ Inferno supports logging scalars and images to Tensorboard out-of-the-box, thoug
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every=(20, 'iterations')),
log_directory='/path/to/log/directory')
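Relatedly, for the `build_metric` hunk above: with `Variable` gone, a metric is simply a callable on plain tensors. A minimal sketch of such a callable; `MyMetric` here is a hypothetical example, not part of inferno's API:

    import torch

    class MyMetric(object):
        """Hypothetical metric: categorical error rate on plain tensors."""
        def __call__(self, prediction, target):
            # prediction: (N, C) scores; target: (N,) integer class labels
            predicted_classes = prediction.max(1)[1]
            return (predicted_classes != target).float().mean()

    metric = MyMetric()
    error = metric(torch.rand(4, 10), torch.randint(0, 10, (4,)))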
4 changes: 1 addition & 3 deletions inferno/extensions/containers/graph.py
@@ -48,7 +48,6 @@ def copy(self, **init_kwargs):
return new



class Graph(nn.Module):
"""
A graph structure to build networks with complex architectures. The resulting graph model
@@ -64,7 +63,6 @@ class Graph(nn.Module):
>>> from inferno.extensions.layers.reshape import Concatenate
>>> from inferno.extensions.layers.convolutional import ConvELU2D
>>> import torch
>>> from torch.autograd import Variable
>>> # Build the model
>>> inception_module = Graph()
>>> inception_module.add_input_node('input')
@@ -75,7 +73,7 @@ class Graph(nn.Module):
>>> previous=['conv1x1', 'conv3x3', 'conv5x5'])
>>> inception_module.add_output_node('output', 'cat')
>>> # Build dummy variable
>>> input = Variable(torch.rand(1, 64, 100, 100))
>>> input = torch.rand(1, 64, 100, 100)
>>> # Get output
>>> output = inception_module(input)
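The docstring change above reflects the pytorch >= 0.4 merge of `Tensor` and `Variable`: a plain tensor is a valid module input and carries autograd itself. A quick framework-level illustration (not specific to inferno):

    import torch

    x = torch.rand(1, 64, 100, 100)            # plain tensor, valid module input
    w = torch.rand(3, 3, requires_grad=True)   # autograd now lives on tensors
    loss = (w * 2).sum()
    loss.backward()                            # populates w.grad, no Variable needed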
3 changes: 1 addition & 2 deletions inferno/extensions/criteria/elementwise_measures.py
@@ -1,5 +1,4 @@
import torch.nn as nn
from torch.autograd import Variable
from ...utils.exceptions import assert_


@@ -26,5 +25,5 @@ def forward(self, input, target):
# Get final weight by adding weight differential to a tensor with negative weights
weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT)
# `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with.
sqrt_weights = Variable(weights.sqrt_(), requires_grad=False)
sqrt_weights = weights.sqrt_()
return self.mse(input * sqrt_weights, target * sqrt_weights)
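A note on the hunk above: the weights are derived from `target`, so they carry no gradient history and need no `requires_grad=False` wrapper. Where an explicit cut from the graph is still wanted, `.detach()` is the pytorch >= 0.4 replacement for `Variable(t, requires_grad=False)`; an illustrative sketch:

    import torch

    weights = torch.rand(4, 10).add_(1.)  # data-derived weights, no grad history
    sqrt_weights = weights.sqrt_()        # in-place sqrt on a plain tensor
    # If a tensor did track gradients, detach() would cut it out of the graph:
    frozen = sqrt_weights.detach()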
16 changes: 7 additions & 9 deletions inferno/extensions/criteria/set_similarity_measures.py
@@ -1,6 +1,5 @@
import torch.nn as nn
from ...utils.torch_utils import flatten_samples
from torch.autograd import Variable

__all__ = ['SorensenDiceLoss', 'GeneralizedDiceLoss']

@@ -53,11 +52,9 @@ def forward(self, input, target):
# With pytorch < 0.2, channelwise_loss.size = (C, 1).
if channelwise_loss.dim() == 2:
channelwise_loss = channelwise_loss.squeeze(1)
# Wrap weights in a variable
weight = Variable(self.weight, requires_grad=False)
assert weight.size() == channelwise_loss.size()
assert self.weight.size() == channelwise_loss.size()
# Apply weight
channelwise_loss = weight * channelwise_loss
channelwise_loss = self.weight * channelwise_loss
# Sum over the channels to compute the total loss
loss = channelwise_loss.sum()
return loss
@@ -104,7 +101,7 @@ def forward(self, input, target):
else:
def flatten_and_preserve_channels(tensor):
tensor_dim = tensor.dim()
assert tensor_dim >= 3
assert tensor_dim >= 3
num_channels = tensor.size(1)
num_classes = tensor.size(2)
# Permute the channel axis to first
@@ -131,10 +128,11 @@ def flatten_and_preserve_channels(tensor):
if self.weight is not None:
if channelwise_loss.dim() == 2:
channelwise_loss = channelwise_loss.squeeze(1)
channel_weights = Variable(self.weight, requires_grad=False)
assert channel_weights.size() == channelwise_loss.size(), "`weight` should have shape (nb_channels, ), `target` should have shape (batch_size, nb_channels, nb_classes, ...)"
assert self.weight.size() == channelwise_loss.size(),\
"""`weight` should have shape (nb_channels, ),
`target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
# Apply channel weights:
channelwise_loss = channel_weights * channelwise_loss
channelwise_loss = self.weight * channelwise_loss

loss = channelwise_loss.sum()

7 changes: 0 additions & 7 deletions inferno/extensions/initializers/presets.py
@@ -1,6 +1,5 @@
import numpy as np
import torch.nn.init as init
from torch.autograd import Variable
from functools import partial

from .base import Initialization, Initializer
@@ -19,8 +18,6 @@ def __init__(self, constant):
self.constant = constant

def call_on_tensor(self, tensor):
if isinstance(tensor, Variable):
tensor = tensor.data
tensor.fill_(self.constant)
return tensor

@@ -42,9 +39,6 @@ def compute_fan_in(self, tensor):
return np.prod(list(tensor.size())[1:])

def call_on_weight(self, tensor):
if isinstance(tensor, Variable):
self.call_on_weight(tensor.data)
return tensor
# Compute stddev if required
if self.sqrt_gain_over_fan_in is not None:
stddev = self.stddev * \
@@ -85,4 +79,3 @@ def __init__(self):
super(ELUWeightsZeroBias, self)\
.__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277),
bias_initializer=Constant(0.))
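With the `Variable` branches removed above, initializers act on parameter tensors directly. A minimal sketch of the same idea in plain pytorch (the fill values are illustrative, not inferno's presets):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 16, 3)
    with torch.no_grad():              # initialize in place without tracking grads
        conv.bias.fill_(0.)            # what a Constant(0.) initializer amounts to
        conv.weight.normal_(std=0.1)   # a NormalWeights-style fill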

40 changes: 7 additions & 33 deletions inferno/trainers/basic.py
@@ -3,7 +3,6 @@
import os
import shutil
import contextlib
import warnings

# These are fetched from globals, they're not unused
# noinspection PyUnresolvedReferences
@@ -14,7 +13,6 @@

import torch
from numpy import inf
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn.parallel.data_parallel import data_parallel
from .callbacks.logging.base import Logger
@@ -1189,30 +1187,9 @@ def wrap_batch(self, batch, from_loader=None, requires_grad=False, volatile=Fals
else:
raise ValueError("Internal Error: Invalid base_device_ordinal: {}."
.format(base_device_ordinal))
# Cast to the right dtype

# Cast to the right dtype and return
batch = self.cast(batch)
# Second, wrap as variable
variable_batch = []
for batch_num, _batch in enumerate(batch):
if thu.is_tensor(_batch):
# This suppresses the volatile deprecated warning
# TODO remove after Pytorch 1.0
with warnings.catch_warnings():
warnings.simplefilter('ignore')
variable_batch.append(Variable(_batch, requires_grad=requires_grad,
volatile=volatile))
elif pyu.is_listlike(_batch):
# This suppresses the volatile deprecated warning
# TODO remove after Pytorch 1.0
with warnings.catch_warnings():
warnings.simplefilter('ignore')
variable_batch.append([Variable(__batch, requires_grad=requires_grad,
volatile=volatile)
for __batch in _batch])
else:
raise RuntimeError(f"Was Expecting batch at index {batch_num} to be either a "
f"tensor or a list of tensors. Got {type(_batch)} instead.")
batch = type(batch)(variable_batch)
return batch

def next_iteration(self):
@@ -1408,7 +1385,7 @@ def train_for(self, num_iterations=None, break_callback=None):
self.console.info("Breaking on request from callback.")
break
self.console.progress("Training iteration {} (batch {} of epoch {})."
.format(iteration_num, self._batch_count, self._epoch_count))
.format(iteration_num, self._batch_count, self._epoch_count))
# Call callback
self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_ITERATION,
iteration_num=iteration_num)
@@ -1512,8 +1489,6 @@ def validate_for(self, num_iterations=None, loader_name='validate'):
num_iterations_in_generator=len(self._loader_iters[loader_name]),
last_validated_at_epoch=self._last_validated_at_epoch)



while True:
if num_iterations is not None and iteration_num >= num_iterations:
break
@@ -1523,8 +1498,7 @@ def validate_for(self, num_iterations=None, loader_name='validate'):

try:
batch = self.fetch_next_batch(loader_name,
restart_exhausted_generators=
num_iterations is not None,
restart_exhausted_generators=num_iterations is not None,
update_batch_count=False,
update_epoch_count_if_generator_exhausted=False)
except StopIteration:
@@ -1545,7 +1519,7 @@ def validate_for(self, num_iterations=None, loader_name='validate'):
# Apply model, compute loss
output, loss = self.apply_model_and_loss(inputs, target, backward=False,
mode='eval')
if isinstance(target, (list,tuple)):
if isinstance(target, (list, tuple)):
batch_size = target[0].size(self._target_batch_dim)
else:
batch_size = target.size(self._target_batch_dim)
@@ -1589,8 +1563,8 @@ def validate_for(self, num_iterations=None, loader_name='validate'):

self.callbacks.call(self.callbacks.END_OF_VALIDATION_RUN,
validation_loss_meter=validation_loss_meter,
validation_error_meter=
validation_error_meter if self.metric_is_defined else None)
validation_error_meter=validation_error_meter if
self.metric_is_defined else None)
return self

def record_validation_results(self, validation_loss, validation_error):
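The simplified `wrap_batch` above is the core of this commit: batches are only cast now, since plain tensors feed models directly. A generic sketch of that style of batch handling (the names are illustrative, not the trainer's real attributes):

    import torch

    def cast_batch(batch, dtype=torch.float32, device='cpu'):
        """Cast every tensor in a (possibly nested) batch; no Variable wrapping."""
        if torch.is_tensor(batch):
            return batch.to(device=device, dtype=dtype)
        if isinstance(batch, (list, tuple)):
            return type(batch)(cast_batch(b, dtype, device) for b in batch)
        raise RuntimeError("Expected a tensor or a list/tuple of tensors, "
                           "got {}.".format(type(batch).__name__))

    inputs, target = cast_batch((torch.rand(4, 3, 8, 8), torch.rand(4, 1, 8, 8)))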
1 change: 0 additions & 1 deletion inferno/utils/model_utils.py
@@ -1,5 +1,4 @@
import torch
from torch.autograd import Variable
from .exceptions import assert_, NotTorchModuleError, ShapeError


63 changes: 24 additions & 39 deletions inferno/utils/torch_utils.py
@@ -1,33 +1,30 @@
import numpy as np
import torch
from torch.autograd import Variable

from .python_utils import delayed_keyboard_interrupt
from .exceptions import assert_, ShapeError, NotUnwrappableError


def unwrap(tensor_or_variable, to_cpu=True, as_numpy=False, extract_item=False):
if isinstance(tensor_or_variable, (list, tuple)):
return type(tensor_or_variable)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)
for _t in tensor_or_variable])
elif isinstance(tensor_or_variable, Variable):
tensor = tensor_or_variable.data
elif torch.is_tensor(tensor_or_variable):
tensor = tensor_or_variable
elif isinstance(tensor_or_variable, np.ndarray):
return tensor_or_variable
elif isinstance(tensor_or_variable, (float, int)):
return tensor_or_variable
def unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False):
if isinstance(input_, (list, tuple)):
return type(input_)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)
for _t in input_])
elif torch.is_tensor(input_):
tensor = input_
elif isinstance(input_, np.ndarray):
return input_
elif isinstance(input_, (float, int)):
return input_
else:
raise NotUnwrappableError("Cannot unwrap a '{}'."
.format(type(tensor_or_variable).__name__))
.format(type(input_).__name__))
# Transfer to CPU if required
if to_cpu:
with delayed_keyboard_interrupt():
tensor = tensor.cpu()
# Convert to numpy if required
if as_numpy:
return tensor.cpu().numpy()
return tensor.cpu().detach().numpy()
elif extract_item:
try:
return tensor.item()
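The `.detach()` added to the numpy path above matters now that tensors themselves can require grad; calling `.numpy()` on such a tensor raises. For illustration:

    import torch

    t = torch.rand(3, requires_grad=True) * 2  # part of the autograd graph
    array = t.detach().cpu().numpy()           # .numpy() alone would raise here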
@@ -94,11 +91,11 @@ def where(condition, if_true, if_false):
Parameters
----------
condition : torch.ByteTensor or torch.cuda.ByteTensor or torch.autograd.Variable
condition : torch.ByteTensor or torch.cuda.ByteTensor
Condition to check.
if_true : torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable
if_true : torch.Tensor or torch.cuda.Tensor
Output value if condition is true.
if_false: torch.Tensor or torch.cuda.Tensor or torch.autograd.Variable
if_false: torch.Tensor or torch.cuda.Tensor
Output value if condition is false
Returns
@@ -107,30 +104,18 @@ def where(condition, if_true, if_false):
Raises
------
AssertionError
if if_true and if_false are not both variables or both tensors.
AssertionError
if if_true and if_false don't have the same datatype.
"""
if isinstance(if_true, Variable) or isinstance(if_false, Variable):
assert isinstance(condition, Variable), \
"Condition must be a variable if either if_true or if_false is a variable."
assert isinstance(if_false, Variable) and isinstance(if_false, Variable), \
"Both if_true and if_false must be variables if either is one."
assert if_true.data.type() == if_false.data.type(), \
"Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type())
else:
assert not isinstance(condition, Variable), \
"Condition must not be a variable because neither if_true nor if_false is one."
# noinspection PyArgumentList
assert if_true.type() == if_false.type(), \
"Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type())
# noinspection PyArgumentList
assert if_true.type() == if_false.type(), \
"Type mismatch: {} and {}".format(if_true.data.type(), if_false.data.type())
casted_condition = condition.type_as(if_true)
output = casted_condition * if_true + (1 - casted_condition) * if_false
return output
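Since pytorch 0.4 the same selection is also available as the built-in `torch.where`; a quick equivalence check (illustrative only):

    import torch

    condition = torch.rand(3, 4) > 0.5
    if_true, if_false = torch.ones(3, 4), torch.zeros(3, 4)

    built_in = torch.where(condition, if_true, if_false)
    cast = condition.type_as(if_true)
    manual = cast * if_true + (1 - cast) * if_false  # what where() above computes
    assert torch.equal(built_in, manual)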


def flatten_samples(tensor_or_variable):
def flatten_samples(input_):
"""
Flattens a tensor or a variable such that the channel axis is first and the sample axis
is second. The shapes are transformed as follows:
@@ -139,17 +124,17 @@ def flatten_samples(tensor_or_variable):
(N, C) --> (C, N)
The input must be atleast 2d.
"""
assert_(tensor_or_variable.dim() >= 2,
assert_(input_.dim() >= 2,
"Tensor or variable must be atleast 2D. Got one of dim {}."
.format(tensor_or_variable.dim()),
.format(input_.dim()),
ShapeError)
# Get number of channels
num_channels = tensor_or_variable.size(1)
num_channels = input_.size(1)
# Permute the channel axis to first
permute_axes = list(range(tensor_or_variable.dim()))
permute_axes = list(range(input_.dim()))
permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]
# For input shape (say) NCHW, this should have the shape CNHW
permuted = tensor_or_variable.permute(*permute_axes).contiguous()
permuted = input_.permute(*permute_axes).contiguous()
# Now flatten out all but the first axis and return
flattened = permuted.view(num_channels, -1)
return flattened
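A quick shape check for `flatten_samples`, matching the docstring above (illustrative):

    import torch

    x = torch.rand(2, 3, 4, 5)                            # (N, C, H, W)
    flat = x.permute(1, 0, 2, 3).contiguous().view(3, -1)
    assert flat.shape == (3, 2 * 4 * 5)                   # (C, N * H * W)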

