Fix and refactor torchgpipe_balancing
Fix issue #3. Split time and size profilers into torchgpipe_balancing.profile.
sublee committed Aug 13, 2019
1 parent bba5482 commit 3011ff0
Showing 6 changed files with 97 additions and 49 deletions.
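
The shape of the refactor in one place: the public entry points stay in torchgpipe_balancing, while the per-layer profilers move into the new torchgpipe_balancing.profile module. A rough sketch of the resulting import layout (paths taken from the diffs below):

    from torchgpipe_balancing import balance_by_time, balance_by_size      # public API
    from torchgpipe_balancing.profile import profile_times, profile_sizes  # internal profilers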
3 changes: 2 additions & 1 deletion docs/guide.rst
@@ -84,7 +84,8 @@ Automatic Balancing
 It could be hard to determine the optimal balance of a model. In particular, if
 you are still designing a model, probably the model architecture may change
 over time. In this case, we highly recommend :mod:`torchgpipe_balancing` for
-automatic balancing. This library is also a part of `torchgpipe` package but
+automatic balancing. This library doesn't determine the optimal balance, but
+could be a handy tool. Note that it is also a part of `torchgpipe` package but
 not a part of the GPipe paper.
 
 There are two balancing tools, :func:`~torchgpipe_balancing.balance_by_time`
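
As context for this hunk, a minimal sketch of the workflow the guide describes. The model layers and sample shape are invented; the last line mirrors the `gpipe = GPipe(model, balance, chunks=8)` example from the torchgpipe_balancing docstring:

    import torch
    import torch.nn as nn

    from torchgpipe import GPipe
    from torchgpipe_balancing import balance_by_time

    # A stand-in model still under design; any nn.Sequential works here.
    model = nn.Sequential(nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 100))
    sample = torch.rand(10, 100)  # a representative input batch (shape invented)

    balance = balance_by_time(model, sample, partitions=2)
    gpipe = GPipe(model, balance, chunks=8)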
7 changes: 6 additions & 1 deletion tests/test_balancing.py
@@ -58,11 +58,16 @@ def forward(self, x):
                 x = x + torch.rand_like(x, requires_grad=True)
             return x
 
-    model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
     sample = torch.rand(10, 100, 100)
 
+    model = nn.Sequential(*[Expand(i) for i in [1, 2, 3, 4, 5, 6]])
     balance = balance_by_size(model, sample, partitions=2, device='cuda')
     assert balance == [4, 2]
 
+    model = nn.Sequential(*[Expand(i) for i in [6, 5, 4, 3, 2, 1]])
+    balance = balance_by_size(model, sample, partitions=2, device='cuda')
+    assert balance == [2, 4]
+
+
 def test_sandbox():
     model = nn.Sequential(nn.BatchNorm2d(3))
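
For context, the Expand module these assertions rely on is defined just above the hunk. A hedged reconstruction from the visible forward body (the constructor and the loop are assumptions):

    class Expand(nn.Module):
        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            # Each unit of `times` allocates one extra activation-sized
            # tensor, so a larger index means a more memory-hungry layer.
            for i in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

The point of the new assertion is symmetry: costs [1..6] should balance as [4, 2], and the reversed costs [6..1] as [2, 4].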
43 changes: 4 additions & 39 deletions torchgpipe_balancing/__init__.py
@@ -12,14 +12,14 @@
     gpipe = GPipe(model, balance, chunks=8)
 
 """
-import time
 from typing import List, Optional, Union
 
 import torch
 from torch import Tensor
 import torch.nn as nn
 
 from torchgpipe_balancing import utils
+from torchgpipe_balancing.profile import profile_sizes, profile_times
 
 __all__ = ['balance_by_time', 'balance_by_size']

@@ -57,27 +57,8 @@ def balance_by_time(module: nn.Sequential,
     ``balance`` parameter of :class:`~torchgpipe.GPipe`.
     """
-    sample, device = utils.concentrate_on_device(module, sample, device)
-
-    times: List[List[float]] = [[] for _ in module]
-
-    begun_at = time.time()
-    while time.time() - begun_at < timeout:
-
-        x = sample
-        with utils.training_sandbox(module):
-            for i, layer in enumerate(module):
-                utils.synchronize_device(device)
-                tick = time.time()
-
-                x = layer(x)
-
-                utils.synchronize_device(device)
-                tock = time.time()
-
-                times[i].append(tock - tick)
-
-    return utils.balance_cost(map(sum, times), partitions)
+    times = profile_times(module, sample, device, timeout)
+    return utils.balance_cost(times, partitions)
 
 
 def balance_by_size(module: nn.Sequential,
@@ -111,21 +92,5 @@
     ``balance`` parameter of :class:`~torchgpipe.GPipe`.
     """
-    if not hasattr(torch.cuda, 'reset_max_memory_allocated'):
-        raise NotImplementedError('balance_by_size requires PyTorch>=1.1')
-
-    sample, device = utils.concentrate_on_device(module, sample, device)
-
-    if device.type != 'cuda':
-        raise ValueError('balance_by_size supports only CUDA device')
-
-    sizes: List[int] = []
-
-    x = sample
-    with utils.training_sandbox(module):
-        for i, layer in enumerate(module):
-            torch.cuda.reset_max_memory_allocated(device)
-            x = layer(x)
-            sizes.append(torch.cuda.max_memory_allocated(device))
-
+    sizes = profile_sizes(module, sample, device)
     return utils.balance_cost(sizes, partitions)
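
Both public functions now reduce to profile-then-partition. A hedged usage sketch for the size-based variant, matching the test above (model is a placeholder; CUDA and PyTorch>=1.1 are required by profile_sizes):

    sample = torch.rand(10, 100, 100)
    balance = balance_by_size(model, sample, partitions=2, device='cuda')  # e.g. [4, 2]
    gpipe = GPipe(model, balance, chunks=8)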
13 changes: 8 additions & 5 deletions torchgpipe_balancing/blockpartition.py
@@ -3,13 +3,12 @@
 Paper: https://arxiv.org/pdf/1308.2452.pdf
 """
-import sys
 from typing import List
 
 __all__ = ['solve']
 
 
-def solve(sequence: List[float], partitions: int = 1) -> List[List[float]]:
+def solve(sequence: List[int], partitions: int = 1) -> List[List[int]]:
     """Splits a sequence into several partitions to minimize variance for each
     partition.
@@ -26,10 +25,14 @@ def solve(sequence: List[float], partitions: int = 1) -> List[List[float]]:
                          '' % (n, partitions))
 
     # Normalize the sequence in [0, 1].
-    maximum = max(sequence)
+    minimum = min(sequence)
+    maximum = max(sequence) - minimum
+
+    normal_sequence: List[float]
     if maximum == 0:
-        maximum = sys.float_info.epsilon
-    normal_sequence = [x / maximum for x in sequence]
+        normal_sequence = [0 for _ in sequence]
+    else:
+        normal_sequence = [(x-minimum)/maximum for x in sequence]
 
     splits = [n//partitions * (x+1) for x in range(partitions-1)] + [n]
 
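
The normalization change is easiest to see on invented numbers. Dividing by the maximum squashes large, nearly equal costs together, while min-max scaling spreads them across [0, 1]:

    seq = [1000, 1001, 1002]

    # Before: x / max(seq) gives roughly [0.998, 0.999, 1.0], nearly indistinguishable.
    # After: min-max scaling separates the costs cleanly.
    minimum = min(seq)
    maximum = max(seq) - minimum
    normal_sequence = [(x - minimum) / maximum for x in seq]  # [0.0, 0.5, 1.0]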
74 changes: 74 additions & 0 deletions torchgpipe_balancing/profile.py
@@ -0,0 +1,74 @@
"""Per-layer profilers."""
import time
from typing import List, Optional, Union

import torch
from torch import Tensor
import torch.nn as nn

from torchgpipe_balancing import utils

__all__: List[str] = []


Device = Union[torch.device, int, str]


def profile_times(module: nn.Sequential,
sample: Tensor,
device: Optional[Device],
timeout: float,
) -> List[int]:
"""Profiles elapsed times per layer."""
sample, device = utils.concentrate_on_device(module, sample, device)

time_bufs: List[List[float]] = [[] for _ in module]

begun_at = time.time()
while time.time() - begun_at < timeout:

x = sample
with utils.training_sandbox(module):
for i, layer in enumerate(module):
utils.synchronize_device(device)
tick = time.time()

x = layer(x)

utils.synchronize_device(device)
tock = time.time()

time_bufs[i].append(tock - tick)

us = 1_000_000
return [sum(int(t*us) for t in buf) for buf in time_bufs]


def profile_sizes(module: nn.Sequential,
sample: Tensor,
device: Optional[Device],
) -> List[int]:
"""Profiles CUDA memory usage per layer."""
if not hasattr(torch.cuda, 'reset_max_memory_allocated'):
raise NotImplementedError('balance_by_size requires PyTorch>=1.1')

sample, device = utils.concentrate_on_device(module, sample, device)

if device.type != 'cuda':
raise ValueError('balance_by_size supports only CUDA device')

sizes: List[int] = []

x = sample
with torch.cuda.device(device), utils.training_sandbox(module):
for i, layer in enumerate(module):
torch.cuda.reset_max_memory_allocated(device)

size_before = torch.cuda.max_memory_allocated(device)
x = layer(x)
size_after = torch.cuda.max_memory_allocated(device)

size = size_after - size_before
sizes.append(size)

return sizes
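
If one wanted to call the new profilers directly (they are internal; note the empty __all__), a sketch under the assumption that a one-second timeout is reasonable for the time profiler:

    times = profile_times(model, sample, device='cuda', timeout=1.0)  # integer microseconds per layer
    sizes = profile_sizes(model, sample, device='cuda')               # peak CUDA bytes per layer

Both return List[int], which is why blockpartition.solve and utils.balance_cost now take integer costs.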
6 changes: 3 additions & 3 deletions torchgpipe_balancing/utils.py
@@ -1,6 +1,6 @@
"""Internal utilities."""
from contextlib import contextmanager
from typing import Generator, Iterable, List, Optional, Tuple, Union
from typing import Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -67,6 +67,6 @@ def synchronize_device(device: torch.device):
         torch.cuda.synchronize()
 
 
-def balance_cost(cost: Iterable[float], partitions: int) -> List[int]:
-    partitioned = blockpartition.solve(list(cost), partitions)
+def balance_cost(cost: List[int], partitions: int) -> List[int]:
+    partitioned = blockpartition.solve(cost, partitions)
     return [len(p) for p in partitioned]
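
End to end, balance_cost hands the integer costs to blockpartition.solve, which groups consecutive layers so the per-partition sums stay as even as possible; the balance is just each group's length. An illustration with invented costs:

    costs = [10, 20, 30, 40, 50, 60]  # per-layer microseconds or bytes
    # blockpartition.solve(costs, 2) should yield [[10, 20, 30, 40], [50, 60]]
    # (sums 100 vs. 110), so balance_cost(costs, 2) -> [4, 2].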
