Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- integrated device transfers in graph model
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Sep 15, 2017
1 parent 24c5169 commit 8d05e2e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
16 changes: 16 additions & 0 deletions inferno/extensions/containers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ...utils import python_utils as pyu
from ...utils.exceptions import assert_
from ..layers.device import OnDevice


__all__ = ['NNGraph', 'Graph']
Expand Down Expand Up @@ -360,6 +361,21 @@ def get_module_for_nodes(self, names):
modules.append(module)
return pyu.from_iterable(modules)

def to_device(self, names, target_device, device_ordinal=None, async=False):
"""Transfer nodes in the network to a specified device."""
names = pyu.to_iterable(names)
for name in names:
assert self.is_node_in_graph(name), "Node '{}' is not in graph.".format(name)
module = getattr(self, name, None)
assert module is not None, "Node '{}' is in the graph but could not find a module " \
"corresponding to it.".format(name)
# Transfer
module_on_device = OnDevice(module, target_device,
device_ordinal=device_ordinal,
async=async)
setattr(self, name, module_on_device)
return self

def get_parameters_for_nodes(self, names, named=False):
"""Get parameters of all nodes listed in `names`."""
if not named:
Expand Down
2 changes: 1 addition & 1 deletion inferno/extensions/layers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ...utils.python_utils import from_iterable, to_iterable
from ...utils.exceptions import assert_, DeviceError

__all__ = ['DeviceTransfer']
__all__ = ['DeviceTransfer', 'OnDevice']


class DeviceTransfer(nn.Module):
Expand Down
17 changes: 17 additions & 0 deletions tests/extensions/containers/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def test_graph_basic(self):
model.add_output_node('output_0', previous='conv1')
ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model)

def test_graph_device_transfers(self):
from inferno.extensions.containers.graph import Graph
from inferno.extensions.layers.convolutional import ConvELU2D
import torch
from torch.autograd import Variable
# Build graph
model = Graph()
model.add_input_node('input_0')
model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0')
model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0')
model.add_output_node('output_0', previous='conv1')
# Transfer
model.to_device('conv0', 'cpu').to_device('conv1', 'cuda', 0)
x = Variable(torch.rand(1, 1, 100, 100))
y = model(x)
self.assertIsInstance(y.data, torch.cuda.FloatTensor)

@unittest.skip("Needs machine with 4 GPUs")
def test_multi_gpu(self):
import torch
Expand Down

0 comments on commit 8d05e2e

Please sign in to comment.