Tidy up internal interfaces
- Publish is_recomputing().
- Prefix non-public methods of GPipe with underscore.
- Remove non-public interfaces from __all__.
sublee committed Jun 17, 2019
1 parent fceb607 commit afdee27
Showing 4 changed files with 34 additions and 33 deletions.
3 changes: 2 additions & 1 deletion torchgpipe/__init__.py
@@ -1,5 +1,6 @@
"""A GPipe implementation in PyTorch."""
from torchgpipe.__version__ import __version__ # noqa
from torchgpipe.checkpoint import is_recomputing
from torchgpipe.gpipe import GPipe, current_microbatch

__all__ = ['GPipe', 'current_microbatch']
__all__ = ['GPipe', 'current_microbatch', 'is_recomputing']
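
With is_recomputing() exported from the package root, user modules can tell a checkpoint recomputation pass apart from the first forward pass and avoid repeating side effects. A minimal sketch of the intended use (the Counter module below is illustrative, not part of this commit):

import torch
from torch import nn
from torchgpipe import is_recomputing

class Counter(nn.Module):
    """Counts forward passes, ignoring checkpoint recomputation."""

    def __init__(self) -> None:
        super().__init__()
        self.count = 0

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not is_recomputing():
            # Skip the bookkeeping when this forward is only being replayed
            # to rebuild activations for the backward pass.
            self.count += 1
        return input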
2 changes: 1 addition & 1 deletion torchgpipe/checkpoint.py
@@ -6,7 +6,7 @@
 from torch import Tensor
 import torch.autograd

-__all__ = ['checkpoint', 'is_recomputing', 'first']
+__all__ = ['is_recomputing']


 Tensors = Tuple[Tensor, ...]
48 changes: 24 additions & 24 deletions torchgpipe/gpipe.py
@@ -143,7 +143,7 @@ def __init__(self,
         if deferred_batch_norm:
             module = DeferredBatchNorm.convert_deferred_batch_norm(module, self.chunks)

-        self.partitions, self.balance, self.devices = self.partition(module, balance, devices)
+        self.partitions, self.balance, self.devices = self._partition(module, balance, devices)

     def __len__(self) -> int:
         """Counts the length of the underlying sequential module."""
@@ -201,10 +201,10 @@ def to(self, *args: Any, **kwargs: Any) -> 'GPipe':
         return super().to(*args, **kwargs)

     @staticmethod
-    def partition(module: nn.Sequential,
-                  balance: Iterable[int],
-                  devices: Optional[Devices],
-                  ) -> Tuple[nn.ModuleList, Tuple[int, ...], Tuple[torch.device, ...]]:
+    def _partition(module: nn.Sequential,
+                   balance: Iterable[int],
+                   devices: Optional[Devices],
+                   ) -> Tuple[nn.ModuleList, Tuple[int, ...], Tuple[torch.device, ...]]:
         """Partitions the given sequential module onto the devices.

         Returns:
@@ -261,7 +261,7 @@ def partition(module: nn.Sequential,

         return nn.ModuleList(partitions), tuple(balance), tuple(devices)

-    def spawn_workers(self) -> Tuple[PriorityQueue, PriorityQueue]:
+    def _spawn_workers(self) -> Tuple[PriorityQueue, PriorityQueue]:
         """Creates worker threads."""
         partitions = cast(List[Partition], self.partitions)

@@ -274,18 +274,18 @@ def spawn_workers(self) -> Tuple[PriorityQueue, PriorityQueue]:
             out_queue = queues[i+1]

             args = (partition, in_queue, out_queue, grad_enabled)
-            t = threading.Thread(target=GPipe.worker, args=args)
+            t = threading.Thread(target=GPipe._worker, args=args)
             t.daemon = True
             t.start()

         return queues[0], queues[-1]

     @staticmethod
-    def worker(partition: Partition,
-               in_queue: PriorityQueue,
-               out_queue: PriorityQueue,
-               grad_enabled: bool,
-               ) -> None:
+    def _worker(partition: Partition,
+                in_queue: PriorityQueue,
+                out_queue: PriorityQueue,
+                grad_enabled: bool,
+                ) -> None:
         """Run by worker threads."""
         torch.set_grad_enabled(grad_enabled)

@@ -349,10 +349,10 @@ def worker(partition: Partition,
             msg = Message(msg.i, (output, leaf, checkpoint))
             out_queue.put(msg)

-    def push_input(self,
-                   input: TensorOrTensors,
-                   in_queue: PriorityQueue,
-                   ) -> int:
+    def _push_input(self,
+                    input: TensorOrTensors,
+                    in_queue: PriorityQueue,
+                    ) -> int:
         """Pushes chunked inputs to the first partition."""
         # Divide a mini-batch into micro-batches.
         in_device = self.devices[0]
@@ -382,11 +382,11 @@ def push_input(self,

         return num_inputs

-    def pull_output(self,
-                    num_inputs: int,
-                    in_queue: PriorityQueue,
-                    out_queue: PriorityQueue,
-                    ) -> Tensor:
+    def _pull_output(self,
+                     num_inputs: int,
+                     in_queue: PriorityQueue,
+                     out_queue: PriorityQueue,
+                     ) -> Tensor:
         """Collects and concatenates chunked outputs from the last partition.

         If an exception from a parititon is detected, all workers are closed
@@ -440,6 +440,6 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors:
             check(input)
             return input

-        in_queue, out_queue = self.spawn_workers()
-        num_inputs = self.push_input(input, in_queue)
-        return self.pull_output(num_inputs, in_queue, out_queue)
+        in_queue, out_queue = self._spawn_workers()
+        num_inputs = self._push_input(input, in_queue)
+        return self._pull_output(num_inputs, in_queue, out_queue)
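
With the pipeline helpers underscore-prefixed, forward() remains the only intended entry point of GPipe. A rough usage sketch under stated assumptions (a toy two-layer model, CPU devices, and an arbitrary chunk count; the constructor also exposes options such as deferred_batch_norm, as seen in the hunk above):

import torch
from torch import nn
from torchgpipe import GPipe

# Two partitions of one layer each; balance gives the number of layers per partition.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 1))
model = GPipe(model, balance=[1, 1], devices=['cpu', 'cpu'], chunks=2)

# forward() internally calls _spawn_workers(), _push_input() and _pull_output().
output = model(torch.rand(8, 16))  # a mini-batch of 8, split into 2 micro-batches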
14 changes: 7 additions & 7 deletions torchgpipe/microbatch.py
@@ -5,7 +5,7 @@
 from torch import Tensor
 import torch.cuda.comm

-__all__ = ['check', 'scatter', 'gather']
+__all__: List[str] = []


 Tensors = Tuple[Tensor, ...]
@@ -32,13 +32,13 @@ def check(input: TensorOrTensors) -> None:
 def scatter(input: TensorOrTensors, chunks: int, device: torch.device) -> ChunkedTensorOrTensors:
     """Splits an input mini-batch into multiple micro-batches."""
     if isinstance(input, tuple):
-        buf = [_scatter_1(x, chunks, device) for x in input]
+        buf = [scatter_1(x, chunks, device) for x in input]
         return list(zip(*buf))

-    return _scatter_1(input, chunks, device)
+    return scatter_1(input, chunks, device)


-def _scatter_1(tensor: Tensor, chunks: int, device: torch.device) -> List[Tensor]:
+def scatter_1(tensor: Tensor, chunks: int, device: torch.device) -> List[Tensor]:
     """Choose the best PyTorch API for :func:`scatter`."""
     if not isinstance(tensor, Tensor):
         raise TypeError('expected Tensor to scatter, but got %s' % tensor.__class__.__name__)
@@ -56,14 +56,14 @@ def _scatter_1(tensor: Tensor, chunks: int, device: torch.device) -> List[Tensor
 def gather(outputs: ChunkedTensorOrTensors, device: torch.device) -> TensorOrTensors:
     """Concatenates output micro-batches into a mini-batch."""
     if isinstance(outputs[0], tuple):
-        buf = [_gather_1(list(chunks), device) for chunks in zip(*outputs)]
+        buf = [gather_1(list(chunks), device) for chunks in zip(*outputs)]
         return tuple(buf)

     # NOTE(sublee): mypy could not infer the type after the above isinstance.
-    return _gather_1(outputs, device)  # type: ignore
+    return gather_1(outputs, device)  # type: ignore


-def _gather_1(tensors: List[Tensor], device: torch.device) -> Tensor:
+def gather_1(tensors: List[Tensor], device: torch.device) -> Tensor:
     """Choose the best PyTorch API for :func:`gather`."""
     if device.type == 'cpu':
         tensor = torch.cat(tensors)
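
For context on the helpers renamed above: scatter() splits a mini-batch into micro-batches along the batch dimension and gather() concatenates them back, with scatter_1()/gather_1() choosing the right PyTorch call per device. A rough CPU-only equivalent of that behavior, not the module's actual code:

import torch

batch = torch.arange(8.).reshape(8, 1)     # a mini-batch of 8 samples

# Roughly what scatter(batch, chunks=4, device) does for a CPU device:
micro_batches = batch.chunk(4)             # 4 micro-batches of 2 samples each

# Roughly what gather(outputs, device) does for a CPU device:
restored = torch.cat(micro_batches)

assert torch.equal(batch, restored)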
