Define GPipe.partitions as attribute
sublee authored and GitHub Enterprise committed Jun 10, 2019
1 parent c557922 commit 38458bf
Showing 2 changed files with 22 additions and 21 deletions.
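In short, the commit below replaces the private `self._partitions` attribute plus a `partitions()` accessor method with a single public `self.partitions` attribute holding the nn.ModuleList directly. It also drops `__iter__`; GPipe objects remain iterable through Python's legacy sequence protocol, since `__getitem__` raises IndexError past the last layer. A minimal sketch of the resulting API, based on the diff below (the construction call mirrors the new test; treat it as illustrative, not a full usage guide):

    import torch.nn as nn
    from torchgpipe import GPipe

    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'])

    # Before this commit: a partitions() method wrapping self._partitions.
    # partitions = model.partitions()

    # After this commit: partitions is the registered nn.ModuleList itself.
    assert isinstance(model.partitions, nn.ModuleList)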
15 changes: 15 additions & 0 deletions tests/test_gpipe.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 
 from torchgpipe import GPipe, current_microbatch
+from torchgpipe.partition import Partition
 
 
 def test_parameters():
@@ -487,3 +488,17 @@ def test_devices():
     cpu = torch.device('cpu')
     # Extra devices must be discarded.
     assert model.devices == (cpu, cpu, cpu)
+
+
+def test_partitions():
+    a = nn.Linear(1, 1)
+    b = nn.Linear(1, 1)
+
+    model = nn.Sequential(a, b)
+    model = GPipe(model, [1, 1], devices=['cpu', 'cpu'])
+
+    assert isinstance(model.partitions, nn.ModuleList)
+    assert isinstance(model.partitions[0], Partition)
+    assert isinstance(model.partitions[1], Partition)
+
+    assert 'partitions.0.module.0.weight' in model.state_dict()
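The final assertion leans on standard nn.Module behavior: an nn.ModuleList assigned to an attribute registers its parameters in state_dict under that attribute's name, which is why the key becomes 'partitions.0.module.0.weight' once the list lives on self.partitions. A minimal standalone sketch of that mechanism (Wrapper and layers are illustrative names, not part of torchgpipe):

    import torch.nn as nn

    class Wrapper(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Assigning an nn.ModuleList to an attribute registers every
            # contained parameter under the attribute's name.
            self.layers = nn.ModuleList([nn.Linear(1, 1)])

    assert 'layers.0.weight' in Wrapper().state_dict()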
28 changes: 7 additions & 21 deletions torchgpipe/gpipe.py
@@ -137,27 +137,16 @@ def __init__(self,
         if deferred_batch_norm:
             module = DeferredBatchNorm.convert_deferred_batch_norm(module, self.chunks)
 
-        self._partitions, self.balance, self.devices = self.partition(module, balance, devices)
-
-    def __iter__(self) -> Iterable[nn.Module]:
-        """Iterates over underlying sequential layers."""
-        # NOTE(sublee): self._partitions is typed as nn.ModuleList which
-        # iterates over nn.Modules. But actually, it includes only Partitions.
-        # Here we cast it to List[Partition] for activation of Partition's
-        # iteration capabilities during type checking.
-        partitions = cast(List[Partition], self._partitions)
-
-        for partition in partitions:
-            yield from partition
+        self.partitions, self.balance, self.devices = self.partition(module, balance, devices)
 
     def __len__(self) -> int:
         """Counts the length of the underlying sequential module."""
-        partitions = cast(List[Partition], self._partitions)
+        partitions = cast(List[Partition], self.partitions)
         return sum(len(p) for p in partitions)
 
     def __getitem__(self, index: int) -> nn.Module:
         """Gets a layer in the underlying sequential module."""
-        partitions = cast(List[Partition], self._partitions)
+        partitions = cast(List[Partition], self.partitions)
         if index < 0:
             partitions = cast(List[Partition], reversed(partitions))
@@ -176,11 +165,6 @@ def __getitem__(self, index: int) -> nn.Module:
 
         raise IndexError
 
-    def partitions(self) -> List[Partition]:
-        """The underlying partitions."""
-        partitions = cast(List[Partition], self._partitions)
-        return list(partitions)
-
     @staticmethod
     def partition(module: nn.Sequential,
                   balance: Iterable[int],
@@ -244,11 +228,13 @@ def partition(module: nn.Sequential,
 
     def spawn_workers(self) -> Tuple[PriorityQueue, PriorityQueue]:
         """Creates worker threads."""
-        n = len(self._partitions)
+        partitions = cast(List[Partition], self.partitions)
+
+        n = len(partitions)
         queues: List[PriorityQueue] = [PriorityQueue() for _ in range(n+1)]
         grad_enabled = torch.is_grad_enabled()
 
-        for i, partition in enumerate(self._partitions):
+        for i, partition in enumerate(partitions):
             in_queue = queues[i]
             out_queue = queues[i+1]
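For context on the hunk above: spawn_workers chains the n partitions through n+1 queues, so partition i consumes queues[i] and produces into queues[i+1], making queues[0] the pipeline input and queues[n] the output. A rough sketch of that wiring in isolation (spawn_chain is a hypothetical name, and plain FIFO queues stand in for the PriorityQueue and worker logic torchgpipe actually uses):

    import threading
    from queue import Queue
    from typing import Callable, List

    def spawn_chain(stages: List[Callable]) -> List[Queue]:
        # One queue more than there are stages: stage i reads queues[i]
        # and writes queues[i + 1], mirroring in_queue/out_queue above.
        queues: List[Queue] = [Queue() for _ in range(len(stages) + 1)]
        for i, stage in enumerate(stages):
            def work(stage=stage, in_queue=queues[i], out_queue=queues[i + 1]):
                while True:
                    item = in_queue.get()
                    if item is None:          # sentinel shuts the chain down
                        out_queue.put(None)
                        break
                    out_queue.put(stage(item))
            threading.Thread(target=work, daemon=True).start()
        return queues

    queues = spawn_chain([lambda x: x + 1, lambda x: x * 2])
    queues[0].put(3)
    queues[0].put(None)
    assert queues[-1].get() == 8  # (3 + 1) * 2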
