Showing 4 changed files with 256 additions and 5 deletions.
@@ -0,0 +1,116 @@
import gc
import time

import pytest
import torch
from torch import nn

from torchgpipe import GPipe, microbatchlocal


def test_current_microbatch():
    class Twice(nn.Module):
        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        def forward(self, _):
            return microbatchlocal.current_microbatch()

    # Not in a partition.
    with pytest.raises(RuntimeError):
        microbatchlocal.current_microbatch()

    input = torch.tensor([1., 2., 3.])

    model = nn.Sequential(Twice(), CurrentMicrobatch())
    model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=3)

    output = model(input)

    assert torch.allclose(output, torch.tensor([1., 2., 3.]))

    # Not in a partition.
    with pytest.raises(RuntimeError):
        microbatchlocal.current_microbatch()


def test_local():
    local = microbatchlocal.local()
    dicts = object.__getattribute__(local, 'local_dicts').dicts

    class Keep(nn.Module):
        def forward(self, x):
            local.x = x
            return x - x  # return zero tensor.

    class InvDelay(nn.Module):
        def forward(self, x):
            # Sleep for 0.01 seconds for tensor([3.]), 0.03 seconds for tensor([1.]).
            time.sleep((4 - x.item()) / 100)
            return x

    class Restore(nn.Module):
        def forward(self, x):
            del x
            return local.x

    input = torch.tensor([1., 2., 3.])

    model = nn.Sequential(Keep(), InvDelay(), Restore())
    model = GPipe(model, balance=[1, 1, 1], devices=['cpu', 'cpu', 'cpu'], chunks=3)

    output = model(input)

    assert torch.allclose(output, torch.tensor([1., 2., 3.]))

    # The micro-batch local storage has to be cleared by GC.
    gc.collect()
    assert len(dicts) == 0


def test_local_current_microbatch():
    local = microbatchlocal.local()

    class KeepCurrentMicrobatch(nn.Module):
        def forward(self, x):
            local.x = microbatchlocal.current_microbatch()
            return x

    input = torch.tensor([1.])

    model = nn.Sequential(KeepCurrentMicrobatch())
    model = GPipe(model, balance=[1], devices=['cpu'])

    with pytest.raises(RuntimeError):
        model(input)


def test_local_out_of_partition():
    local = microbatchlocal.local()
    with pytest.raises(RuntimeError):
        local.test = 'test'


def test_local_attrerror():
    local = microbatchlocal.local()

    class Test(nn.Module):
        def forward(self, x):
            with pytest.raises(AttributeError):
                local.not_exists
            with pytest.raises(AttributeError):
                del local.not_exists

            local.exists = True
            local.exists
            del local.exists

            return x

    input = torch.tensor([1.])

    model = nn.Sequential(Test())
    model = GPipe(model, balance=[1], devices=['cpu'])

    model(input)
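
For reference, here is a minimal sketch of how the API exercised by these tests reads in user code. It is not part of this commit: SaveMean, SubtractMean, and the stored statistic are illustrative, and it assumes the package-root re-exports added in the __init__.py hunk below.

# A minimal usage sketch, not part of this commit. It assumes the package
# root re-exports `local` (see the __init__.py diff below).
import torch
from torch import nn
from torchgpipe import GPipe, local

storage = local()

class SaveMean(nn.Module):
    def forward(self, x):
        storage.mean = x.mean()  # stored per micro-batch lane
        return x

class SubtractMean(nn.Module):
    def forward(self, x):
        return x - storage.mean  # reads this lane's value, not another lane's

model = nn.Sequential(SaveMean(), SubtractMean())
model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=2)
output = model(torch.arange(4.))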
@@ -1,5 +1,6 @@
 """A GPipe implementation in PyTorch."""
 from torchgpipe.__version__ import __version__  # noqa
 from torchgpipe.gpipe import GPipe
+from torchgpipe.microbatchlocal import current_microbatch, local

-__all__ = ['GPipe']
+__all__ = ['GPipe', 'current_microbatch', 'local']
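
A quick, hypothetical interpreter check of the new re-exports; both import paths should resolve to the same objects, since importing a submodule in __init__.py also binds it as an attribute of the package:

# Hypothetical sanity check, not part of this commit.
import torchgpipe

assert torchgpipe.local is torchgpipe.microbatchlocal.local
assert torchgpipe.current_microbatch is torchgpipe.microbatchlocal.current_microbatch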
@@ -0,0 +1,125 @@
"""Provides tools to implement per micro-batch features."""
import threading
from typing import Any, Callable, Dict, Optional, Tuple
import weakref

from torch import Tensor

__all__ = ['current_microbatch', 'local']


LocalDict = Dict[str, Any]
WeakTensor = Callable[[], Optional[Tensor]]


# The micro-batch leaf tensor storage for each partition worker thread.
_local = threading.local()


def _record_microbatch(microbatch: Tensor) -> None:
    # GPipe.worker calls it to record the current micro-batch lane.
    _local.microbatch = microbatch


def current_microbatch() -> Tensor:
    """Gets the current micro-batch identifier as a tensor.

    If your modules rely on which micro-batch lane is currently running, use
    it to identify the lane.
    """
    try:
        return _local.microbatch
    except AttributeError:
        raise RuntimeError('not in partition')


class local:
    """A micro-batch local object, just like :class:`threading.local`. It
    stores attributes per micro-batch. If your modules need storage isolated
    from the other micro-batch lanes, use it to keep your implementation
    simple::

        local = torchgpipe.local()

        class RememberMax(nn.Module):
            def forward(self, x):
                local.max = x.max()
                return x

        class NormalizeByRememberedMax(nn.Module):
            def forward(self, x):
                return x / local.max

    """

    def __init__(self) -> None:
        # Call object.__setattr__ instead of the overridden __setattr__ to
        # avoid recursive __getattr__ calls.
        object.__setattr__(self, 'local_dicts', _local_dicts())

    def __getattr__(self, attr: str) -> Any:
        __dict__ = self.local_dicts.dict()
        return __dict__[attr]

    def __setattr__(self, attr: str, value: Any) -> None:
        if value is current_microbatch():
            raise RuntimeError('current_microbatch() is not allowed to be stored in local')

        __dict__ = self.local_dicts.dict()
        __dict__[attr] = value

    def __delattr__(self, attr: str) -> None:
        __dict__ = self.local_dicts.dict()
        del __dict__[attr]


class _local_dicts:
    """The internal local storage manager for :class:`local`. It provides a
    dictionary attached to the current micro-batch and removes it when the
    micro-batch has gone.

    This implementation imitates ``_threading_local._localimpl``.
    """

    def __init__(self) -> None:
        self.key = 'torchgpipe.microbatchlocal.%d' % id(self)
        self.dicts: Dict[int, Tuple[LocalDict, WeakTensor]] = {}

    def dict(self) -> LocalDict:
        """Gets the micro-batch local dictionary. If the dictionary doesn't
        exist yet, it creates one.
        """
        microbatch = current_microbatch()
        microbatch_id = id(microbatch)

        try:
            local_dict, _ = self.dicts[microbatch_id]
            return local_dict
        except KeyError:
            return self._create_dict(microbatch, microbatch_id)

    def _create_dict(self, microbatch: Tensor, microbatch_id: int) -> LocalDict:
        local_dict: LocalDict = {}
        local_key = self.key

        # The two weakref callbacks below disconnect the weak references
        # between the micro-batch and this _local_dicts object when either
        # side is collected.
        def local_dicts_deleted(_: Any, local_key: str = local_key) -> None:
            microbatch = weak_microbatch()
            if microbatch is not None:
                del microbatch.__dict__[local_key]

        def microbatch_deleted(_: Any, microbatch_id: int = microbatch_id) -> None:
            local_dicts = weak_local_dicts()
            if local_dicts is not None:
                local_dicts.dicts.pop(microbatch_id)

        weak_local_dicts = weakref.ref(self, local_dicts_deleted)
        weak_microbatch = weakref.ref(microbatch, microbatch_deleted)

        microbatch.__dict__[local_key] = weak_local_dicts
        self.dicts[microbatch_id] = local_dict, weak_microbatch

        return local_dict
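
The bookkeeping in _create_dict imitates _threading_local._localimpl: each side keeps only a weak reference to the other, and a callback fires when either referent is collected. That is what lets test_local assert that gc.collect() empties dicts. Below is a standalone sketch of the object-death half of that pattern, with illustrative names (Owner, Thing) and plain objects instead of micro-batch tensors:

# A standalone sketch of the weakref cleanup pattern, not part of this
# commit; Owner and Thing are illustrative stand-ins.
import weakref

class Owner:
    def __init__(self):
        # id(obj) -> (per-object dict, weak reference to obj)
        self.dicts = {}

    def attach(self, obj):
        data = {}

        def obj_deleted(_, obj_id=id(obj)):
            # Fires when obj is collected: drop its entry so Owner does not
            # keep growing. The real code also registers the mirror callback
            # for when the _local_dicts side dies first.
            self.dicts.pop(obj_id, None)

        self.dicts[id(obj)] = (data, weakref.ref(obj, obj_deleted))
        return data

class Thing:
    pass

owner = Owner()
thing = Thing()
owner.attach(thing)['x'] = 42
assert len(owner.dicts) == 1

del thing  # on CPython the callback fires here and clears the entry
assert len(owner.dicts) == 0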