Detect duplicate children or parameters
sublee authored and GitHub Enterprise committed Sep 24, 2019
1 parent 4684b8d commit eb53350
Showing 2 changed files with 115 additions and 82 deletions.
36 changes: 29 additions & 7 deletions tests/test_gpipe.py
@@ -7,6 +7,7 @@
import torch.nn as nn

from torchgpipe import GPipe
from torchgpipe.gpipe import verify_module


def test_parameters():
@@ -15,11 +16,6 @@ def test_parameters():
assert list(gpipe.parameters()) != []


def test_non_sequential():
with pytest.raises(TypeError):
GPipe(nn.Module(), balance=[1], devices=['cpu'])


@pytest.mark.parametrize('balance', [[2], [1, 1]])
def test_sequential_like(balance):
a = nn.Linear(1, 1)
@@ -77,8 +73,7 @@ def test_chunks_less_than_1():


def test_too_few_devices():
x = nn.Linear(1, 1)
model = nn.Sequential(x, x, x, x)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))

with pytest.raises(IndexError):
# len(balance) > len(devices)
@@ -485,3 +480,30 @@ def test_recommend_torchgpipe_balancing():
with pytest.raises(ValueError, match='torchgpipe_balancing'):
# module and sum of balance have different length (module: 2, sum of balance: 1)
GPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])


def test_verify_module_non_sequential():
with pytest.raises(TypeError, match='module must be nn.Sequential to be partitioned'):
verify_module(nn.Module())


def test_verify_module_duplicate_children():
conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(conv, conv)

with pytest.raises(ValueError, match='module with duplicate children is not supported'):
verify_module(model)


def test_verify_module_duplicate_parameters_in_distinct_children():
class Surrogate(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module

conv = nn.Conv2d(3, 3, 1)
model = nn.Sequential(Surrogate(conv), Surrogate(conv))

with pytest.raises(ValueError, match='module with duplicate parameters in '
'distinct children is not supported'):
verify_module(model)
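
For context on the two new tests above: the check they exercise relies on PyTorch's identity-based deduplication in named_children() and parameters(). The snippet below is a minimal standalone sketch of that behavior, not part of the commit; Surrogate mirrors the wrapper defined in the test.

import torch.nn as nn

conv = nn.Conv2d(3, 3, 1)

# nn.Sequential counts every registered slot, but named_children() yields each
# child object only once, so reusing the same child makes the counts diverge.
model = nn.Sequential(conv, conv)
assert len(model) == 2
assert len(list(model.named_children())) == 1

# parameters() deduplicates too: two distinct wrappers around one conv expose
# fewer parameters on the parent than the per-child sum suggests.
class Surrogate(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

shared = nn.Sequential(Surrogate(conv), Surrogate(conv))
assert len(list(shared.parameters())) == 2  # weight and bias, counted once
assert sum(len(list(child.parameters())) for child in shared.children()) == 4

These are exactly the two comparisons the new verify_module makes before partitioning.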
161 changes: 86 additions & 75 deletions torchgpipe/gpipe.py
@@ -1,4 +1,4 @@
"""The GPipe implementation."""
"""The GPipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast

@@ -30,11 +30,9 @@
NamedModules = OrderedDict


def recommend_torchgpipe_balancing(title: str) -> ValueError:
"""Creates a :exc:`ValueError` with recommendation to
:mod:`torchgpipe_balancing`.
"""
return ValueError('''{title}
def recommend_torchgpipe_balancing(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchgpipe_balancing`."""
return '''{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend torchgpipe_balancing for naive
@@ -47,7 +45,82 @@ def recommend_torchgpipe_balancing(title: str) -> ValueError:
balance = balance_by_time(model, sample, partitions=...)
model = GPipe(model, balance, chunks=...)
'''.format(title=title))
'''.format(message=message)


def verify_module(module: nn.Sequential) -> None:
if not isinstance(module, nn.Sequential):
raise TypeError('module must be nn.Sequential to be partitioned')

named_children = list(module.named_children())
if len(named_children) != len(module):
raise ValueError('module with duplicate children is not supported')

num_parameters = len(list(module.parameters()))
num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
if num_parameters != num_child_parameters:
raise ValueError('module with duplicate parameters in distinct children is not supported')


class BalanceError(ValueError):
pass


def split_module(module: nn.Sequential,
balance: List[int],
devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance, devices).
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
if len(module) != sum(balance):
raise BalanceError('module and sum of balance have different length '
'(module: %d, sum of balance: %d)' % (len(module), sum(balance)))

if any(x <= 0 for x in balance):
raise BalanceError('all balance numbers must be positive integer (balance: %r)' % balance)

if len(balance) > len(devices):
raise IndexError('too few devices to hold given partitions '
'(devices: %s, partitions: %d)' % (len(devices), len(balance)))

i = 0
partitions = []
layers: NamedModules = OrderedDict()

for name, layer in module.named_children():
layers[name] = layer

if len(layers) == balance[i]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)

device = devices[i]
partition.to(device)

partitions.append(partition)

# Prepare for the next partition.
layers.clear()
i += 1

partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
del devices[i:]

return partitions, balance, devices


MOVING_DENIED = TypeError('denied to move parameters and buffers, '
@@ -130,96 +203,34 @@ def __init__(self,
) -> None:
super().__init__()

if not isinstance(module, nn.Sequential):
raise TypeError('non-sequential module cannot be partitioned')

if balance is None:
raise recommend_torchgpipe_balancing('balance is required')

raise ValueError(recommend_torchgpipe_balancing('balance is required'))
if chunks <= 0:
raise ValueError('number of chunks must be positive integer')

if checkpoint not in ['always', 'except_last', 'never']:
raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

self.chunks = chunks
self.checkpoint = checkpoint

verify_module(module)

if deferred_batch_norm:
module = DeferredBatchNorm.convert_deferred_batch_norm(module, self.chunks)

# Split the module into multiple partitions.
balance = list(balance)

if devices is None:
devices = range(torch.cuda.device_count())
devices = [torch.device(d) for d in devices]

try:
self.partitions, self.balance, self.devices = \
self._split_module(module, balance, devices)
except ValueError as exc:
raise recommend_torchgpipe_balancing(str(exc))
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_torchgpipe_balancing(str(exc)))

self._copy_streams: List[List[AbstractStream]] = []

@staticmethod
def _split_module(module: nn.Sequential,
balance: List[int],
devices: List[torch.device],
) -> Tuple[List[nn.Sequential], List[int], List[torch.device]]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance, devices).
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
ValueError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
if len(module) != sum(balance):
raise ValueError('module and sum of balance have different length '
'(module: %d, sum of balance: %d)' % (len(module), sum(balance)))
if any(x <= 0 for x in balance):
raise ValueError('all balance numbers must be positive integer '
'(balance: %r)' % balance)

if len(balance) > len(devices):
raise IndexError('too few devices to hold given partitions '
'(devices: %s, partitions: %d)' % (len(devices), len(balance)))

i = 0
partitions = []
layers: NamedModules = OrderedDict()

for name, layer in module.named_children():
layers[name] = layer

if len(layers) == balance[i]:
# Group buffered layers as a partition.
partition = nn.Sequential(layers)

device = devices[i]
partition.to(device)

partitions.append(partition)

# Prepare for the next partition.
layers.clear()
i += 1

partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
del devices[i:]

return partitions, balance, devices

def __len__(self) -> int:
"""Counts the length of the underlying sequential module."""
return sum(len(p) for p in self.partitions)
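
As a usage note, not part of the commit: since split_module is now a module-level helper rather than a GPipe static method, it can be exercised directly. A minimal CPU-only sketch, assuming torchgpipe.gpipe exposes split_module as shown in the diff above:

import torch
import torch.nn as nn
from torchgpipe.gpipe import split_module

model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
devices = [torch.device('cpu'), torch.device('cpu')]

# Split four layers into two partitions of two layers each.
partitions, balance, devices = split_module(model, [2, 2], devices)
assert len(partitions) == 2
assert all(len(p) == 2 for p in partitions)

# A wrong balance now surfaces as BalanceError, which GPipe.__init__ rewraps
# into a ValueError carrying the torchgpipe_balancing recommendation.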
