Commit
Implement micro-batch local
sublee committed May 16, 2019
1 parent ad0ea56 commit 2e482a8
Showing 4 changed files with 256 additions and 5 deletions.
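In short: the commit adds a torchgpipe.microbatchlocal module with two user-facing helpers, current_microbatch() and local(), re-exports them from torchgpipe, and threads a per-chunk leaf tensor through GPipe's worker messages so each partition knows which micro-batch lane it is running. A minimal usage sketch, written against the API shown in the tests and docstrings below (the Scale/Unscale modules are hypothetical, for illustration only):

import torch
from torch import nn
from torchgpipe import GPipe, current_microbatch, local

storage = local()  # per-micro-batch storage, analogous to threading.local()


class Scale(nn.Module):
    def forward(self, x):
        storage.factor = x.abs().max()  # remembered only for this micro-batch
        return x / storage.factor


class Unscale(nn.Module):
    def forward(self, x):
        lane = current_microbatch()  # leaf tensor identifying this micro-batch lane
        return x * storage.factor    # the factor stored earlier in the same lane


model = GPipe(nn.Sequential(Scale(), Unscale()),
              balance=[1, 1], devices=['cpu', 'cpu'], chunks=2)
output = model(torch.tensor([1., 2., 3., 4.]))  # equal to the input, since each
                                                # lane is unscaled by its own factor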
116 changes: 116 additions & 0 deletions tests/test_microbatchlocal.py
@@ -0,0 +1,116 @@
import gc
import time

import pytest
import torch
from torch import nn

from torchgpipe import GPipe, microbatchlocal


def test_current_microbatch():
    class Twice(nn.Module):
        def forward(self, x):
            return x * 2

    class CurrentMicrobatch(nn.Module):
        def forward(self, _):
            return microbatchlocal.current_microbatch()

    # Not in a partition.
    with pytest.raises(RuntimeError):
        microbatchlocal.current_microbatch()

    input = torch.tensor([1., 2., 3.])

    model = nn.Sequential(Twice(), CurrentMicrobatch())
    model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=3)

    output = model(input)

    assert torch.allclose(output, torch.tensor([1., 2., 3.]))

    # Not in a partition.
    with pytest.raises(RuntimeError):
        microbatchlocal.current_microbatch()


def test_local():
    local = microbatchlocal.local()
    dicts = object.__getattribute__(local, 'local_dicts').dicts

    class Keep(nn.Module):
        def forward(self, x):
            local.x = x
            return x - x  # return zero tensor.

    class InvDelay(nn.Module):
        def forward(self, x):
            # Sleep for 0.01 seconds for tensor([3.]), 0.03 seconds for tensor([1.]).
            time.sleep((4 - x.item()) / 100)
            return x

    class Restore(nn.Module):
        def forward(self, x):
            del x
            return local.x

    input = torch.tensor([1., 2., 3.])

    model = nn.Sequential(Keep(), InvDelay(), Restore())
    model = GPipe(model, balance=[1, 1, 1], devices=['cpu', 'cpu', 'cpu'], chunks=3)

    output = model(input)

    assert torch.allclose(output, torch.tensor([1., 2., 3.]))

    # The micro-batch local storage has to be cleared by GC.
    gc.collect()
    assert len(dicts) == 0


def test_local_current_microbatch():
    local = microbatchlocal.local()

    class KeepCurrentMicrobatch(nn.Module):
        def forward(self, x):
            local.x = microbatchlocal.current_microbatch()
            return x

    input = torch.tensor([1.])

    model = nn.Sequential(KeepCurrentMicrobatch())
    model = GPipe(model, balance=[1], devices=['cpu'])

    with pytest.raises(RuntimeError):
        model(input)


def test_local_out_of_partition():
    local = microbatchlocal.local()
    with pytest.raises(RuntimeError):
        local.test = 'test'


def test_local_attrerror():
    local = microbatchlocal.local()

    class Test(nn.Module):
        def forward(self, x):
            with pytest.raises(AttributeError):
                local.not_exists
            with pytest.raises(AttributeError):
                del local.not_exists

            local.exists = True
            local.exists
            del local.exists

            return x

    input = torch.tensor([1.])

    model = nn.Sequential(Test())
    model = GPipe(model, balance=[1], devices=['cpu'])

    model(input)
3 changes: 2 additions & 1 deletion torchgpipe/__init__.py
@@ -1,5 +1,6 @@
"""A GPipe implementation in PyTorch."""
from torchgpipe.__version__ import __version__ # noqa
from torchgpipe.gpipe import GPipe
from torchgpipe.microbatchlocal import current_microbatch, local

__all__ = ['GPipe']
__all__ = ['GPipe', 'current_microbatch', 'local']
17 changes: 13 additions & 4 deletions torchgpipe/gpipe.py
@@ -13,6 +13,7 @@
 from torchgpipe.batchnorm import patch_deferred_batch_norm
 from torchgpipe.checkpoint import first
 from torchgpipe.microbatch import gather, scatter
+from torchgpipe.microbatchlocal import _record_microbatch
 from torchgpipe.partition import Partition

 __all__ = ['GPipe']
@@ -252,7 +253,11 @@ def worker(partition: Partition,
             if failed:
                 continue

-            input, checkpoint = msg.payload
+            input, leaf, checkpoint = msg.payload

+            # Track the current micro-batch lane by the leaf tensor of the
+            # lane. It can be accessed by current_microbatch().
+            _record_microbatch(leaf)
+
             # Linearize micro-batches by dependency between nth micro-batch
             # input and n-1th micro-batch output. It prevents unexpected
@@ -282,7 +287,7 @@ def worker(partition: Partition,
             # don't send the current micro-batch until the next partition is ready to receive it.
             out_queue.join()

-            msg = Message(msg.i, (output, checkpoint))
+            msg = Message(msg.i, (output, leaf, checkpoint))
             out_queue.put(msg)

     def push_input(self,
@@ -305,7 +310,11 @@ def push_input(self,
             elif self.checkpoint == 'never':
                 checkpoint = False

-            msg = Message(i, (_input, checkpoint))
+            # Every partition should track the current micro-batch. A
+            # micro-batch lane can be identified by its detached leaf tensor.
+            leaf = (_input[0] if isinstance(_input, tuple) else _input).detach()
+
+            msg = Message(i, (_input, leaf, checkpoint))
             in_queue.put(msg)

         close = Message(num_inputs, None)
Expand Down Expand Up @@ -342,7 +351,7 @@ def pull_output(self,
                 exc_info = msg.payload
                 raise exc_info[0].with_traceback(exc_info[1], exc_info[2])

-            output, _ = msg.payload
+            output, _, _ = msg.payload
             outputs.append(output)

         output = gather(outputs, device=self.out_device)
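The gpipe.py change above is the plumbing for the new module: push_input detaches the first input tensor of each chunk to create a leaf, the same leaf object travels in the message payload (input, leaf, checkpoint) through every partition, and the worker records it with _record_microbatch() so current_microbatch() can return it. A conceptual sketch of why a detached per-chunk leaf works as a lane identifier (this is not torchgpipe's actual scatter or worker code):

import torch

batch = torch.arange(6.)
chunks = batch.chunk(3)                        # three micro-batches
leaves = [chunk.detach() for chunk in chunks]  # one detached leaf per lane, as in push_input

# Distinct Python objects -> distinct lane identities, usable as dict keys via id().
assert len({id(leaf) for leaf in leaves}) == 3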
125 changes: 125 additions & 0 deletions torchgpipe/microbatchlocal.py
@@ -0,0 +1,125 @@
"""Provides tools to implement per micro-batch features."""
import threading
from typing import Any, Callable, Dict, Optional, Tuple
import weakref

from torch import Tensor

__all__ = ['current_microbatch', 'local']


LocalDict = Dict[str, Any]
WeakTensor = Callable[[], Optional[Tensor]]


# The micro-batch leaf tensor storage for each partition worker thread.
_local = threading.local()


def _record_microbatch(microbatch: Tensor) -> None:
# GPipe.worker calls it to record the current micro-batch lane.
_local.microbatch = microbatch


def current_microbatch() -> Tensor:
"""Gets the current micro-batch identifier as a tensor.
If your modules should rely on where the current micro-batch lane, use it
to identify the lane.
"""
try:
return _local.microbatch
except AttributeError:
raise RuntimeError('not in partition')


class local:
"""A micro-batch local object, just like :class:`threading.local`.
It stores any attributes as per micro-batch. If your modules need a storage
isolated with another micro-batch lanes, use it to keep your implementation
simple::
local = torchgpipe.local()
class RememberMax(nn.Module):
def forward(self, x):
self.local.max = x.max()
return x
class NormalzeByRememberedMax(nn.Module):
def forward(self, x):
return x / self.local.max
"""

def __init__(self) -> None:
# Call object.__setattr__ instead of the overridden __setattr__ to
# avoid recursive __getattr__ calls.
object.__setattr__(self, 'local_dicts', _local_dicts())

def __getattr__(self, attr: str) -> Any:
__dict__ = self.local_dicts.dict()
return __dict__[attr]

def __setattr__(self, attr: str, value: Any) -> None:
if value is current_microbatch():
raise RuntimeError('current_microbatch() is not allowed to store in local')

__dict__ = self.local_dicts.dict()
__dict__[attr] = value

def __delattr__(self, attr: str) -> None:
__dict__ = self.local_dicts.dict()
del __dict__[attr]


class _local_dicts:
"""The internal local storage manager for :class:`local`. It provides a
dictionary attached to the current micro-batch and removes it when the
micro-batch has gone.
This implementation imitates ``_threading_local._localimpl``.
"""
def __init__(self) -> None:
self.key = 'torchgpipe.microbatchlocal.%d' % id(self)
self.dicts: Dict[int, Tuple[LocalDict, WeakTensor]] = {}

def dict(self) -> LocalDict:
"""Gets the micro-batch local dictionary. If the dictionary doesn't
exist yet, it creates one.
"""
microbatch = current_microbatch()
microbatch_id = id(microbatch)

try:
local_dict, _ = self.dicts[microbatch_id]
return local_dict
except KeyError:
return self._create_dict(microbatch, microbatch_id)

def _create_dict(self, microbatch: Tensor, microbatch_id: int) -> LocalDict:
local_dict: LocalDict = {}
local_key = self.key

# The below two weakref callback functions disconnect weak references
# between the micro-batch and this _local_dicts object.
def local_dicts_deleted(_: Any, local_key: str = local_key) -> None:
microbatch = weak_microbatch()
if microbatch is not None:
del microbatch.__dict__[local_key]

def microbatch_deleted(_: Any, microbatch_id: int = microbatch_id) -> None:
local_dicts = weak_local_dicts()
if local_dicts is not None:
local_dicts.dicts.pop(microbatch_id)

weak_local_dicts = weakref.ref(self, local_dicts_deleted)
weak_microbatch = weakref.ref(microbatch, microbatch_deleted)

microbatch.__dict__[local_key] = weak_local_dicts
self.dicts[microbatch_id] = local_dict, weak_microbatch

return local_dict
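The gc.collect() assertion in test_local is backed by the two weakref callbacks above: each per-micro-batch dictionary is keyed by id(microbatch), and microbatch_deleted pops that entry once the leaf tensor is garbage collected, so nothing leaks across batches. A standalone sketch of the same cleanup pattern, with a plain object standing in for the leaf tensor:

import gc
import weakref

dicts = {}  # id(lane) -> (local_dict, weak_lane), as in _local_dicts.dicts


class Lane:  # stand-in for a micro-batch leaf tensor
    pass


lane = Lane()
lane_id = id(lane)


def lane_deleted(_, lane_id=lane_id):
    dicts.pop(lane_id, None)  # drop the storage once the lane is gone


weak_lane = weakref.ref(lane, lane_deleted)
dicts[lane_id] = ({}, weak_lane)

del lane
gc.collect()
assert lane_id not in dicts  # cleared by the weakref callback, like in test_local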
