-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
432 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import time | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
from torchgpipe_balancing import balance_by_size, balance_by_time, blockpartition | ||
|
||
|
||
def test_blockpartition():
    # Six layers with costs 1..6 should split so both halves carry roughly
    # equal total cost: 1+2+3+4 == 10 vs 5+6 == 11.
    expected = [[1, 2, 3, 4], [5, 6]]
    assert blockpartition.solve([1, 2, 3, 4, 5, 6], partitions=2) == expected
|
||
|
||
def test_blockpartition_zeros():
    # All-zero costs must still produce one non-empty bucket per partition.
    result = blockpartition.solve([0, 0], partitions=2)
    assert result == [[0], [0]]
|
||
|
||
def test_blockpartition_non_positive_partitions():
    # Partition counts of zero or below are rejected outright.
    for bad_count in (0, -1):
        with pytest.raises(ValueError):
            blockpartition.solve([42], partitions=bad_count)
|
||
|
||
def test_blockpartition_short_sequence():
    # An empty sequence cannot be partitioned at all.
    with pytest.raises(ValueError):
        blockpartition.solve([], partitions=1)

    # Nor can a single element be spread across two partitions.
    with pytest.raises(ValueError):
        blockpartition.solve([42], partitions=2)
|
||
|
||
def test_balance_by_time():
    class Delay(nn.Module):
        """Identity layer that sleeps for a fixed duration in ``forward``."""

        def __init__(self, seconds):
            super().__init__()
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            return x

    # Per-layer costs proportional to 1..6: the first four layers (~10 units)
    # should balance against the last two (~11 units).
    durations = [i / 100 for i in [1, 2, 3, 4, 5, 6]]
    model = nn.Sequential(*(Delay(seconds) for seconds in durations))
    sample = torch.rand(1)
    assert balance_by_time(model, sample, partitions=2, device='cpu') == [4, 2]
|
||
|
||
# balance_by_size supports only CUDA device.
@pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda required')
def test_balance_by_size():
    class Expand(nn.Module):
        """Layer whose peak memory usage grows with ``times``."""

        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            for _ in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    # Memory costs proportional to 1..6 should split as [4, 2], mirroring
    # the time-based balancing test above.
    model = nn.Sequential(*(Expand(times) for times in [1, 2, 3, 4, 5, 6]))
    sample = torch.rand(10, 100, 100)
    balance = balance_by_size(model, sample, partitions=2, device='cuda')
    assert balance == [4, 2]
|
||
|
||
def test_sandbox():
    # Profiling must not leak state mutations — BatchNorm running statistics
    # would otherwise change as samples flow through during timing runs.
    model = nn.Sequential(nn.BatchNorm2d(3))
    snapshot = {name: tensor.clone() for name, tensor in model.state_dict().items()}

    sample = torch.rand(1, 3, 10, 10)
    balance_by_time(model, sample, partitions=1, device='cpu')

    restored = model.state_dict()
    assert snapshot.keys() == restored.keys()
    for name, tensor in snapshot.items():
        assert torch.allclose(restored[name], tensor)
|
||
|
||
def test_not_training():
    class AssertTraining(nn.Module):
        """Layer that fails unless the module is in training mode."""

        def forward(self, x):
            assert self.training
            return x

    model = nn.Sequential(AssertTraining())

    # Put the model in eval mode before profiling.
    model.eval()
    assert not model.training

    # Profiling flips the model into training mode internally (the inner
    # assert would fire otherwise) ...
    balance_by_time(model, torch.rand(1), partitions=1, device='cpu')

    # ... but the original eval mode must be restored afterwards.
    assert not model.training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
"""A helper to roughly balance a sequential module. | ||
Usage:: | ||
import torch | ||
from torchgpipe import GPipe | ||
from torchgpipe_balancing import balance_by_time | ||
sample = torch.rand(128, 3, 224, 224) | ||
balance = balance_by_time(model, sample, partitions=4) | ||
gpipe = GPipe(model, balance, chunks=8) | ||
""" | ||
import time | ||
from typing import List, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
import torch.nn as nn | ||
|
||
from torchgpipe_balancing import utils | ||
|
||
__all__ = ['balance_by_time', 'balance_by_size'] | ||
|
||
|
||
Device = Union[torch.device, int, str] | ||
|
||
|
||
def balance_by_time(module: nn.Sequential,
                    sample: Tensor,
                    *,
                    partitions: int = 1,
                    device: Optional[Device] = None,
                    timeout: float = 1.0,
                    ) -> List[int]:
    """Balances the given sequential module by elapsed time per layer.
    Args:
        module (nn.Sequential):
            sequential module to be partitioned
        sample (Tensor):
            example input
    Keyword Args:
        partitions (int):
            intended number of partitions (default: 1)
        device (torch.device):
            CUDA device where the module is profiled (default: any related CUDA
            device or ``torch.device('cuda')``)
        timeout (float):
            profiling iterates again if the timeout (as second) is not exceeded
            (default: 1 second)
    Returns:
        A list of number of layers in each partition. Use it for the
        ``balance`` parameter of :class:`torchgpipe.GPipe`.
    """
    sample, device = utils.concentrate_on_device(module, sample, device)

    # One list of timing samples per layer; repeated passes smooth out noise.
    elapsed: List[List[float]] = [[] for _ in module]

    started = time.time()
    while time.time() - started < timeout:
        x = sample
        # The sandbox restores parameters/buffers and training mode afterwards,
        # so profiling leaves the module untouched.
        with utils.training_sandbox(module):
            for i, layer in enumerate(module):
                # Synchronize before and after so asynchronous CUDA kernels
                # are fully accounted to the layer that launched them.
                utils.synchronize_device(device)
                t0 = time.time()

                x = layer(x)

                utils.synchronize_device(device)
                t1 = time.time()

                elapsed[i].append(t1 - t0)

    # Partition on the total elapsed time accumulated per layer.
    return utils.balance_cost(map(sum, elapsed), partitions)
|
||
|
||
def balance_by_size(module: nn.Sequential,
                    sample: Tensor,
                    *,
                    partitions: int = 1,
                    device: Optional[Device] = None,
                    ) -> List[int]:
    """Balances the given sequential module by memory usage per layer.
    Note:
        This function relies on :func:`torch.cuda.reset_max_memory_allocated`
        which was introduced in PyTorch 1.1. Therefore, it supports neither
        CPU tensors nor PyTorch 1.0.x.
    Args:
        module (nn.Sequential):
            sequential module to be partitioned
        sample (Tensor):
            example input
    Keyword Args:
        partitions (int):
            intended number of partitions (default: 1)
        device (torch.device):
            CUDA device where the module is profiled (default: any related CUDA
            device or ``torch.device('cuda')``)
    Returns:
        A list of number of layers in each partition. Use it for the
        ``balance`` parameter of :class:`torchgpipe.GPipe`.
    """
    # Fail early on PyTorch 1.0.x, which lacks the max-memory reset API.
    if not hasattr(torch.cuda, 'reset_max_memory_allocated'):
        raise NotImplementedError('balance_by_size requires PyTorch>=1.1')

    sample, device = utils.concentrate_on_device(module, sample, device)

    # Memory profiling is only meaningful through the CUDA allocator.
    if device.type != 'cuda':
        raise ValueError('balance_by_size supports only CUDA device')

    # Peak memory allocated (bytes) attributed to each layer's forward pass.
    sizes: List[int] = []

    x = sample
    # The sandbox restores parameters/buffers and training mode afterwards,
    # so profiling leaves the module untouched.
    with utils.training_sandbox(module):
        for layer in module:  # index was unused; iterate layers directly
            # Reset the high-water mark so the peak reflects only this layer.
            torch.cuda.reset_max_memory_allocated(device)
            x = layer(x)
            sizes.append(torch.cuda.max_memory_allocated(device))

    return utils.balance_cost(sizes, partitions)
Oops, something went wrong.