
Introduce torchgpipe_balancing
sublee committed Jun 20, 2019
2 parents bd13f16 + 1455413 commit e852db5
Showing 10 changed files with 432 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/index.rst
@@ -39,6 +39,9 @@ API
.. autoattribute:: devices
   :annotation:

.. autofunction:: torchgpipe_balancing.balance_by_time(module, canary, partitions, device, timeout)

.. autofunction:: torchgpipe_balancing.balance_by_size(module, canary, partitions, device)
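(For orientation, a minimal sketch of how these two documented helpers are meant to be called, following the usage block in torchgpipe_balancing/__init__.py below; ``model``, the sample shape, and the partition count are illustrative, not part of this diff)::

    import torch
    from torchgpipe import GPipe
    from torchgpipe_balancing import balance_by_time

    # model: any nn.Sequential you want to partition (assumed defined elsewhere)
    sample = torch.rand(128, 3, 224, 224)
    balance = balance_by_time(model, sample, partitions=4, timeout=1.0)
    model = GPipe(model, balance, chunks=8)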

Licensing and Authors
---------------------
2 changes: 1 addition & 1 deletion setup.cfg
@@ -51,7 +51,7 @@ mypy_path = ./stubs/
follow_imports = normal

# This project must be strictly typed.
[mypy-torchgpipe.*]
[mypy-torchgpipe.*,torchgpipe_balancing.*]
check_untyped_defs = true
disallow_untyped_defs = true
disallow_untyped_calls = true
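(Aside, not part of the diff: with these flags in force, mypy now rejects unannotated definitions and calls in torchgpipe_balancing as well. A minimal illustration, where the function name is hypothetical)::

    # Under the settings above, mypy would reject this in torchgpipe_balancing:
    def balance(layers):  # error: Function is missing a type annotation
        return [len(layers)]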
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@

    zip_safe=False,

    packages=['torchgpipe'],
    packages=['torchgpipe', 'torchgpipe_balancing'],
    package_data={'torchgpipe': ['py.typed']},

    install_requires=['torch>=1'],
95 changes: 95 additions & 0 deletions tests/test_balancing.py
@@ -0,0 +1,95 @@
import time

import pytest
import torch
from torch import nn

from torchgpipe_balancing import balance_by_size, balance_by_time, blockpartition


def test_blockpartition():
    assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == [[1, 2, 3, 4], [5, 6]]


def test_blockpartition_zeros():
    assert blockpartition.solve([0, 0], partitions=2) == [[0], [0]]


def test_blockpartition_non_positive_partitions():
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=0)
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=-1)


def test_blockpartition_short_sequence():
    with pytest.raises(ValueError):
        blockpartition.solve([], partitions=1)
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=2)


def test_balance_by_time():
    class Delay(nn.Module):
        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    model = nn.Sequential(*[Delay(i/100) for i in [1, 2, 3, 4, 5, 6]])
    sample = torch.rand(1)
    balance = balance_by_time(model, sample, partitions=2, device='cpu')
    assert balance == [4, 2]


# balance_by_size supports only CUDA device.
@pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda required')
def test_balance_by_size():
    class Expand(nn.Module):
        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            for i in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
    sample = torch.rand(10, 100, 100)
    balance = balance_by_size(model, sample, partitions=2, device='cuda')
    assert balance == [4, 2]


def test_sandbox():
    model = nn.Sequential(nn.BatchNorm2d(3))

    before = {k: v.clone() for k, v in model.state_dict().items()}

    sample = torch.rand(1, 3, 10, 10)
    balance_by_time(model, sample, partitions=1, device='cpu')

    after = model.state_dict()

    assert before.keys() == after.keys()
    for key, value in before.items():
        assert torch.allclose(after[key], value)


def test_not_training():
    class AssertTraining(nn.Module):
        def forward(self, x):
            assert self.training
            return x
    model = nn.Sequential(AssertTraining())

    model.eval()
    assert not model.training

    sample = torch.rand(1)
    balance_by_time(model, sample, partitions=1, device='cpu')

    assert not model.training
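(Read together, the blockpartition tests suggest that ``blockpartition.solve`` splits a sequence into contiguous blocks while evening out the block sums. An illustrative check, not part of the diff, of why [[1, 2, 3, 4], [5, 6]] is the expected answer)::

    from torchgpipe_balancing import blockpartition

    blocks = blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2)
    assert blocks == [[1, 2, 3, 4], [5, 6]]
    assert [sum(block) for block in blocks] == [10, 11]
    # Any other contiguous split is less even, e.g.
    # [[1, 2, 3], [4, 5, 6]] -> sums (6, 15)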
14 changes: 14 additions & 0 deletions tests/test_gpipe.py
@@ -571,3 +571,17 @@ def test_named_children():
    # several methods in its namespace.
    with pytest.raises(AttributeError):
        model.a


def test_recommend_torchgpipe_balancing():
    with pytest.raises(ValueError, match='torchgpipe_balancing'):
        # balance is required
        GPipe(nn.Sequential())

    with pytest.raises(ValueError, match='torchgpipe_balancing'):
        # module and sum of balance have different length (module: 0, sum of balance: 1)
        GPipe(nn.Sequential(), [1])

    with pytest.raises(ValueError, match='torchgpipe_balancing'):
        # module and sum of balance have different length (module: 2, sum of balance: 1)
        GPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
30 changes: 28 additions & 2 deletions torchgpipe/gpipe.py
@@ -60,6 +60,26 @@ def current_microbatch() -> Optional[Tensor]:
    return None


def recommend_torchgpipe_balancing(title: str) -> ValueError:
    """Creates a :exc:`ValueError` with a recommendation to use
    :mod:`torchgpipe_balancing`.
    """
    return ValueError('''{title}

If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend torchgpipe_balancing for naive
automatic balancing:

  from torchgpipe import GPipe
  from torchgpipe_balancing import balance_by_time

  sample = torch.rand(...)
  balance = balance_by_time(model, sample, partitions=...)

  model = GPipe(model, balance, chunks=...)
'''.format(title=title))


MOVING_DENIED = TypeError('denied to move parameters and buffers, '
                          'because GPipe should manage device placement')

@@ -130,7 +150,7 @@ class GPipe(nn.Module):

    def __init__(self,
                 module: nn.Sequential,
                 balance: Iterable[int],
                 balance: Optional[Iterable[int]] = None,
                 *,
                 devices: Optional[Devices] = None,
                 chunks: int = 1,
@@ -142,6 +162,9 @@ def __init__(self,
        if not isinstance(module, nn.Sequential):
            raise TypeError('non-sequential module cannot be partitioned')

        if balance is None:
            raise recommend_torchgpipe_balancing('balance is required')

        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')

@@ -161,7 +184,10 @@ def __init__(self,
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]

        self.partitions, self.balance, self.devices = self._partition(module, balance, devices)
        try:
            self.partitions, self.balance, self.devices = self._partition(module, balance, devices)
        except ValueError as exc:
            raise recommend_torchgpipe_balancing(str(exc))

    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
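(A quick sketch of the new error path, grounded in the tests above: constructing GPipe without a balance now raises a ValueError that points at torchgpipe_balancing rather than a TypeError about a missing argument)::

    from torch import nn
    from torchgpipe import GPipe

    try:
        GPipe(nn.Sequential(nn.Linear(1, 1)))  # no balance given
    except ValueError as exc:
        assert 'torchgpipe_balancing' in str(exc)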
131 changes: 131 additions & 0 deletions torchgpipe_balancing/__init__.py
@@ -0,0 +1,131 @@
"""A helper to roughly balance a sequential module.
Usage::
import torch
from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_time
sample = torch.rand(128, 3, 224, 224)
balance = balance_by_time(model, sample, partitions=4)
gpipe = GPipe(model, balance, chunks=8)
"""
import time
from typing import List, Optional, Union

import torch
from torch import Tensor
import torch.nn as nn

from torchgpipe_balancing import utils

__all__ = ['balance_by_time', 'balance_by_size']


Device = Union[torch.device, int, str]


def balance_by_time(module: nn.Sequential,
sample: Tensor,
*,
partitions: int = 1,
device: Optional[Device] = None,
timeout: float = 1.0,
) -> List[int]:
"""Balances the given seqeuntial module by elapsed time per layer.
Args:
module (nn.Sequential):
sequential module to be partitioned
sample (Tensor):
example input
Keyword Args:
partitions (int):
intended number of partitions (default: 1)
device (torch.device):
CUDA device where the module is profiled (default: any related CUDA
device or ``torch.device('cuda')``)
timeout (float):
profiling iterates again if the timeout (as second) is not exceeded
(default: 1 second)
Returns:
A list of number of layers in each partition. Use it for the
``balance`` parameter of :class:`torchgpipe.GPipe`.
"""
sample, device = utils.concentrate_on_device(module, sample, device)

times: List[List[float]] = [[] for _ in module]

begun_at = time.time()
while time.time() - begun_at < timeout:

x = sample
with utils.training_sandbox(module):
for i, layer in enumerate(module):
utils.synchronize_device(device)
tick = time.time()

x = layer(x)

utils.synchronize_device(device)
tock = time.time()

times[i].append(tock - tick)

return utils.balance_cost(map(sum, times), partitions)
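(The loop above is essentially a tiny per-layer profiler: it repeatedly feeds the sample through the layers until the timeout, accumulates per-layer wall-clock times, and hands the sums to the cost-based partitioner. A simplified CPU-only sketch of the same idea, without the ``utils`` helpers, which are not shown in this diff)::

    import time

    import torch
    from torch import nn

    def profile_layer_times(module: nn.Sequential, sample: torch.Tensor,
                            timeout: float = 1.0) -> list:
        # Accumulated wall-clock time per layer.
        times = [0.0 for _ in module]
        begun_at = time.time()
        while time.time() - begun_at < timeout:
            x = sample
            for i, layer in enumerate(module):
                tick = time.time()
                x = layer(x)
                times[i] += time.time() - tick
        return times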


def balance_by_size(module: nn.Sequential,
                    sample: Tensor,
                    *,
                    partitions: int = 1,
                    device: Optional[Device] = None,
                    ) -> List[int]:
    """Balances the given sequential module by memory usage per layer.

    Note:
        This function relies on :func:`torch.cuda.reset_max_memory_allocated`,
        which was introduced in PyTorch 1.1. Therefore, it supports neither
        CPU tensors nor PyTorch 1.0.x.

    Args:
        module (nn.Sequential):
            sequential module to be partitioned
        sample (Tensor):
            example input

    Keyword Args:
        partitions (int):
            intended number of partitions (default: 1)
        device (torch.device):
            CUDA device where the module is profiled (default: any related
            CUDA device or ``torch.device('cuda')``)

    Returns:
        A list of the number of layers in each partition. Use it as the
        ``balance`` parameter of :class:`torchgpipe.GPipe`.

    """
    if not hasattr(torch.cuda, 'reset_max_memory_allocated'):
        raise NotImplementedError('balance_by_size requires PyTorch>=1.1')

    sample, device = utils.concentrate_on_device(module, sample, device)

    if device.type != 'cuda':
        raise ValueError('balance_by_size supports only CUDA device')

    sizes: List[int] = []

    x = sample
    with utils.training_sandbox(module):
        for i, layer in enumerate(module):
            torch.cuda.reset_max_memory_allocated(device)
            x = layer(x)
            sizes.append(torch.cuda.max_memory_allocated(device))

    return utils.balance_cost(sizes, partitions)
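(Usage mirrors balance_by_time, except profiling must happen on a CUDA device and PyTorch >= 1.1 is required. A sketch, with an illustrative model, assuming, as in the test above, that ``concentrate_on_device`` moves the module and sample to the target device)::

    import torch
    from torch import nn
    from torchgpipe_balancing import balance_by_size

    if torch.cuda.is_available():
        model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 100))
        sample = torch.rand(10, 100)
        balance = balance_by_size(model, sample, partitions=2, device='cuda')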
