Introduce GPipe.devices
It replaces GPipe.in_device and GPipe.out_device.
sublee authored and GitHub Enterprise committed Jun 10, 2019
1 parent b9e0658 commit c557922
Showing 3 changed files with 47 additions and 10 deletions.
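For callers, the change is mechanical. A minimal migration sketch (assuming an already constructed GPipe instance named gpipe):

    # Before this commit: two dedicated attributes.
    in_device = gpipe.in_device
    out_device = gpipe.out_device

    # After this commit: one tuple, with one device per partition.
    in_device = gpipe.devices[0]
    out_device = gpipe.devices[-1]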
docs/index.rst (3 additions & 0 deletions)
@@ -36,6 +36,9 @@ API
 
    .. automethod:: forward(input)
 
+   .. autoattribute:: devices
+      :annotation:
+
 
 Licensing and Authors
 ---------------------
tests/test_gpipe.py (16 additions & 0 deletions)
@@ -471,3 +471,19 @@ def forward(self, _):
 
     # Not in a partition.
     assert current_microbatch() is None
+
+
+def test_devices():
+    a = nn.Linear(1, 1)
+    b = nn.Linear(1, 1)
+    c = nn.Linear(1, 1)
+
+    # There are two extra devices.
+    devices = ['cpu', 'cpu', 'cpu', 'cpu', 'cpu']
+
+    model = nn.Sequential(a, b, c)
+    model = GPipe(model, [1, 1, 1], devices=devices)
+
+    cpu = torch.device('cpu')
+    # Extra devices must be discarded.
+    assert model.devices == (cpu, cpu, cpu)
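The test above pins every partition to CPU. For comparison, a hypothetical two-GPU setup (device names are assumed here, not part of this commit) would produce a matching devices tuple:

    model = nn.Sequential(a, b, c)
    model = GPipe(model, balance=[2, 1], devices=['cuda:0', 'cuda:1'])
    assert model.devices == (torch.device('cuda:0'), torch.device('cuda:1'))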
torchgpipe/gpipe.py (28 additions & 10 deletions)
@@ -97,6 +97,24 @@ class GPipe(nn.Module):
     """
 
+    #: The devices mapped to each partition.
+    #:
+    #: ``devices[-1]`` refers to the device of the last partition, which means
+    #: it is the output device. You will probably need it to transfer the
+    #: target tensor there, so that the loss can be computed without a device
+    #: mismatch :exc:`RuntimeError`. For example::
+    #:
+    #:     out_device = gpipe.devices[-1]
+    #:
+    #:     for input, target in loader:
+    #:         target = target.to(out_device, non_blocking=True)
+    #:         output = gpipe(input)
+    #:         loss = F.cross_entropy(output, target)
+    #:
+    devices: Tuple[torch.device, ...] = ()
+    #                                   ^^^^
+    # The default value () is required for Sphinx's autoattribute.
+
     def __init__(self,
                  module: nn.Sequential,
                  balance: Iterable[int],
@@ -109,18 +127,17 @@ def __init__(self,
 
         if chunks <= 0:
             raise ValueError('number of chunks must be positive integer')
-        self.chunks = chunks
 
         if checkpoint not in ['always', 'except_last', 'never']:
             raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
 
+        self.chunks = chunks
         self.checkpoint = checkpoint
 
         if deferred_batch_norm:
             module = DeferredBatchNorm.convert_deferred_batch_norm(module, self.chunks)
 
-        self._partitions, self.balance, self.in_device, self.out_device = \
-            self.partition(module, balance, devices)
+        self._partitions, self.balance, self.devices = self.partition(module, balance, devices)
 
     def __iter__(self) -> Iterable[nn.Module]:
         """Iterates over underlying sequential layers."""
@@ -168,11 +185,11 @@ def partitions(self) -> List[Partition]:
     def partition(module: nn.Sequential,
                   balance: Iterable[int],
                   devices: Optional[Devices],
-                  ) -> Tuple[nn.ModuleList, List[int], torch.device, torch.device]:
+                  ) -> Tuple[nn.ModuleList, Tuple[int, ...], Tuple[torch.device, ...]]:
         """Partitions the given sequential module onto the devices.
 
         Returns:
-            A tuple of (partitions, input device, output device).
+            A tuple of (partitions, balance, devices).
 
         Partitions are represented as a :class:`~torch.nn.ModuleList` whose
         item is a partition. All layers in a partition are placed in the
@@ -221,10 +238,9 @@ def partition(module: nn.Sequential,
             del partition_layers[:]
             i += 1
 
-        in_device = partitions[0].device
-        out_device = partitions[-1].device
+        del devices[i:]
 
-        return nn.ModuleList(partitions), balance, in_device, out_device
+        return nn.ModuleList(partitions), tuple(balance), tuple(devices)
 
     def spawn_workers(self) -> Tuple[PriorityQueue, PriorityQueue]:
         """Creates worker threads."""
@@ -318,7 +334,8 @@ def push_input(self,
                    ) -> int:
         """Pushes chunked inputs to the first partition."""
         # Divide a mini-batch into micro-batches.
-        inputs = scatter(input, chunks=self.chunks, device=self.in_device)
+        in_device = self.devices[0]
+        inputs = scatter(input, chunks=self.chunks, device=in_device)
 
         # The number of inputs might be smaller than the number of chunks.
         num_inputs = len(inputs)
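scatter is torchgpipe's internal helper. Roughly, it splits the mini-batch along the batch dimension and moves each micro-batch to the first partition's device. A simplified sketch, under the assumption that the input is a single tensor:

    import torch

    def scatter_sketch(input: torch.Tensor, chunks: int, device: torch.device):
        # torch.Tensor.chunk may return fewer pieces than requested,
        # which is why the number of inputs can be smaller than chunks.
        return [piece.to(device) for piece in input.chunk(chunks)]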
@@ -376,7 +393,8 @@ def pull_output(self,
             output, _, _ = msg.payload
             outputs.append(output)
 
-        output = gather(outputs, device=self.out_device)
+        out_device = self.devices[-1]
+        output = gather(outputs, device=out_device)
         out_queue.get()
 
         return output
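gather is the inverse helper: it collects the micro-batch outputs on the last partition's device and concatenates them back into one mini-batch. A simplified sketch under the same single-tensor assumption:

    def gather_sketch(outputs, device: torch.device):
        # Concatenate micro-batch outputs along the batch dimension
        # on the output device.
        return torch.cat([output.to(device) for output in outputs])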
