Move torchgpipe_balancing to torchgpipe.balance
sublee authored and GitHub Enterprise committed Oct 25, 2019
2 parents 61a8a54 + e50f2da commit 9aa856c
Showing 14 changed files with 40 additions and 28 deletions.
4 changes: 2 additions & 2 deletions docs/api.rst
@@ -19,6 +19,6 @@ Inspecting GPipe Timeline
Automatic Balancing
~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchgpipe_balancing.balance_by_time(module, canary, partitions, device, timeout)
.. autofunction:: torchgpipe.balance.balance_by_time(module, canary, partitions, device, timeout)

.. autofunction:: torchgpipe_balancing.balance_by_size(module, canary, partitions, device)
.. autofunction:: torchgpipe.balance.balance_by_size(module, canary, partitions, device)
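
For reference, the documented signatures above can be exercised roughly as follows. This is a minimal sketch, not part of the diff: the toy model, the sample shape, and the keyword-style call to balance_by_size are illustrative assumptions, and balance_by_size presumably needs a CUDA device since it profiles CUDA memory.

    import torch
    from torch import nn

    from torchgpipe.balance import balance_by_size, balance_by_time

    # toy sequential model; any nn.Sequential works as the module argument
    model = nn.Sequential(
        nn.Linear(784, 512), nn.ReLU(),
        nn.Linear(512, 256), nn.ReLU(),
        nn.Linear(256, 10),
    )

    # the "canary" argument is a sample batch used only for profiling
    sample = torch.rand(128, 784)

    # profile per-layer elapsed time on CPU
    balance = balance_by_time(model, sample, partitions=2, device='cpu')

    # or profile per-layer CUDA memory usage (assumed to require a GPU)
    # balance = balance_by_size(model, sample, partitions=2, device='cuda:0')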
6 changes: 5 additions & 1 deletion docs/changelog.rst
@@ -6,7 +6,11 @@ v0.0.5 (WIP)

Not released yet.

- Checkpointing deterministically handles randomness managed by PyTorch.
Improvements:
- Checkpointing deterministically handles randomness managed by PyTorch.

Breaking Changes:
- Moved ``torchgpipe_balancing`` module to :mod:`torchgpipe.balance`.

v0.0.4
~~~~~~
14 changes: 7 additions & 7 deletions docs/guide.rst
@@ -93,20 +93,20 @@ Automatic Balancing

It could be hard to determine the optimal balance of a model. In particular, if
you are still designing a model, the model architecture may change over time.
In this case, we highly recommend :mod:`torchgpipe_balancing` for automatic
In this case, we highly recommend :mod:`torchgpipe.balance` for automatic
balancing. This won't give you the optimal balance, but a good-enough balance.
Note that this is provided by the `torchgpipe` package, and is not from the GPipe
paper.

There are two balancing tools, :func:`~torchgpipe_balancing.balance_by_time`
and :func:`~torchgpipe_balancing.balance_by_size`. Both are based on per-layer
There are two balancing tools, :func:`~torchgpipe.balance.balance_by_time` and
:func:`~torchgpipe.balance.balance_by_size`. Both are based on per-layer
profiling. Just like `PyTorch JIT`_, you need to feed a sample input into the
model. :func:`~torchgpipe_balancing.balance_by_time` traces elapsed time of
each layer, while :func:`~torchgpipe_balancing.balance_by_size` detects the
CUDA memory usage of each layer. Choose the balancing tool for your needs::
model. :func:`~torchgpipe.balance.balance_by_time` traces elapsed time of each
layer, while :func:`~torchgpipe.balance.balance_by_size` detects the CUDA
memory usage of each layer. Choose the balancing tool for your needs::

from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_time
from torchgpipe.balance import balance_by_time

sample = torch.rand(128, 3, 224, 224)
balance = balance_by_time(model, sample, partitions=4)
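
The rest of this guide example is collapsed in the diff; presumably the computed balance is then passed to GPipe, which accepts a balance argument (as the tests in tests/test_gpipe.py further down show). A minimal sketch under that assumption, with an illustrative chunks value:

    from torchgpipe import GPipe

    # wrap the profiled model; chunks sets the number of micro-batches
    model = GPipe(model, balance=balance, chunks=8)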
2 changes: 1 addition & 1 deletion setup.cfg
@@ -55,7 +55,7 @@ mypy_path = ./stubs/
follow_imports = normal

# This project must be strictly typed.
[mypy-torchgpipe.*,torchgpipe_balancing.*]
[mypy-torchgpipe.*]
check_untyped_defs = true
disallow_untyped_defs = true
disallow_untyped_calls = true
3 changes: 2 additions & 1 deletion setup.py
@@ -66,8 +66,9 @@

zip_safe=False,

packages=['torchgpipe', 'torchgpipe_balancing'],
packages=['torchgpipe', 'torchgpipe.balance'],
package_data={'torchgpipe': ['py.typed']},
py_modules=['torchgpipe_balancing'],

install_requires=['torch>=1.1'],
setup_requires=['pytest-runner'],
7 changes: 6 additions & 1 deletion tests/test_balancing.py → tests/test_balance.py
@@ -4,7 +4,7 @@
import torch
from torch import nn

from torchgpipe_balancing import balance_by_size, balance_by_time, blockpartition
from torchgpipe.balance import balance_by_size, balance_by_time, blockpartition


def test_blockpartition():
@@ -98,3 +98,8 @@ def forward(self, x):
balance_by_time(model, sample, partitions=1, device='cpu')

assert not model.training


def test_deprecated_torchgpipe_balancing():
with pytest.raises(ImportError, match='torchgpipe.balance'):
__import__('torchgpipe_balancing')
8 changes: 4 additions & 4 deletions tests/test_gpipe.py
@@ -468,16 +468,16 @@ def test_named_children():
model.a


def test_recommend_torchgpipe_balancing():
with pytest.raises(ValueError, match='torchgpipe_balancing'):
def test_recommend_auto_balance():
with pytest.raises(ValueError, match='torchgpipe.balance'):
# balance is required
GPipe(nn.Sequential())

with pytest.raises(ValueError, match='torchgpipe_balancing'):
with pytest.raises(ValueError, match='torchgpipe.balance'):
# module and sum of balance have different length (module: 0, sum of balance: 1)
GPipe(nn.Sequential(), [1])

with pytest.raises(ValueError, match='torchgpipe_balancing'):
with pytest.raises(ValueError, match='torchgpipe.balance'):
# module and sum of balance have different length (module: 2, sum of balance: 1)
GPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])

@@ -4,7 +4,7 @@
import torch
from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_time
from torchgpipe.balance import balance_by_time
sample = torch.rand(128, 3, 224, 224)
balance = balance_by_time(model, sample, partitions=4)
@@ -18,8 +18,8 @@
from torch import Tensor
import torch.nn as nn

from torchgpipe_balancing import utils
from torchgpipe_balancing.profile import profile_sizes, profile_times
from torchgpipe.balance import utils
from torchgpipe.balance.profile import profile_sizes, profile_times

__all__ = ['balance_by_time', 'balance_by_size']

File renamed without changes.
@@ -6,7 +6,7 @@
from torch import Tensor
import torch.nn as nn

from torchgpipe_balancing import utils
from torchgpipe.balance import utils

__all__: List[str] = []

File renamed without changes.
@@ -6,7 +6,7 @@
from torch import Tensor
import torch.nn as nn

from torchgpipe_balancing import blockpartition
from torchgpipe.balance import blockpartition

__all__: List[str] = []

12 changes: 6 additions & 6 deletions torchgpipe/gpipe.py
@@ -29,16 +29,16 @@
NamedModules = OrderedDict


def recommend_torchgpipe_balancing(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchgpipe_balancing`."""
def recommend_auto_balance(message: str) -> str:
"""Expands a message with recommendation to :mod:`torchgpipe.balance`."""
return '''{message}
If your model is still under development, its optimal balance would change
frequently. In this case, we highly recommend torchgpipe_balancing for naive
frequently. In this case, we highly recommend 'torchgpipe.balance' for naive
automatic balancing:
from torchgpipe import GPipe
from torchgpipe_balancing import balance_by_time
from torchgpipe.balance import balance_by_time
sample = torch.rand(...)
balance = balance_by_time(model, sample, partitions=...)
@@ -204,7 +204,7 @@ def __init__(self,
super().__init__()

if balance is None:
raise ValueError(recommend_torchgpipe_balancing('balance is required'))
raise ValueError(recommend_auto_balance('balance is required'))
if chunks <= 0:
raise ValueError('number of chunks must be positive integer')
if checkpoint not in ['always', 'except_last', 'never']:
@@ -227,7 +227,7 @@ def __init__(self,
try:
self.partitions, self.balance, self.devices = split_module(module, balance, devices)
except BalanceError as exc:
raise ValueError(recommend_torchgpipe_balancing(str(exc)))
raise ValueError(recommend_auto_balance(str(exc)))

self._copy_streams: List[List[AbstractStream]] = []

2 changes: 2 additions & 0 deletions torchgpipe_balancing.py
@@ -0,0 +1,2 @@
# 'torchgpipe_balancing' has moved to 'torchgpipe.balance' in v0.0.5.
raise ImportError("import 'torchgpipe.balance' instead")
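
With this shim in place, any code that still imports the old top-level module fails at import time with a pointer to the new location, as the added test in tests/test_balance.py exercises. A minimal sketch of what downstream code sees (the try/except fallback is illustrative):

    try:
        import torchgpipe_balancing  # old name; raises ImportError as of v0.0.5
    except ImportError:
        # the shim's message says: import 'torchgpipe.balance' instead
        from torchgpipe import balance as torchgpipe_balancing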
