Remove CUDA stream pool
sublee authored and GitHub Enterprise committed Sep 25, 2019
1 parent 1916ce2 commit 146b686
Showing 2 changed files with 6 additions and 19 deletions.
torchgpipe/gpipe.py (1 addition, 14 deletions)

@@ -11,7 +11,6 @@
 from torchgpipe.batchnorm import DeferredBatchNorm
 from torchgpipe.microbatch import check, gather, scatter
 from torchgpipe.pipeline import Pipeline
-from torchgpipe.stream import AbstractStream, new_stream
 
 __all__ = ['GPipe']
 
@@ -229,8 +228,6 @@ def __init__(self,
         except BalanceError as exc:
             raise ValueError(recommend_torchgpipe_balancing(str(exc)))
 
-        self._copy_streams: List[List[AbstractStream]] = []
-
     def __len__(self) -> int:
         """Counts the length of the underlying sequential module."""
         return sum(len(p) for p in self.partitions)
@@ -285,13 +282,6 @@ def to(self, *args: Any, **kwargs: Any) -> 'GPipe':
 
         return super().to(*args, **kwargs)
 
-    def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
-        if not self._copy_streams:
-            for device in self.devices:
-                self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
-
-        return self._copy_streams
-
     def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
         """:class:`GPipe` is a fairly transparent module wrapper. It doesn't
         modify the input and output signature of the underlying module. But
@@ -315,9 +305,6 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
             # Empty sequential module is not illegal.
             return input
 
-        # Prepare separate CUDA streams only for copy.
-        copy_streams = self._ensure_copy_streams()
-
         # Divide a mini-batch into micro-batches.
         batches = scatter(input, self.chunks)
 
@@ -330,7 +317,7 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
         checkpoint_stop = 0
 
         # Run pipeline parallelism.
-        pipeline = Pipeline(batches, self.partitions, self.devices, copy_streams, checkpoint_stop)
+        pipeline = Pipeline(batches, self.partitions, self.devices, checkpoint_stop)
        pipeline.run()
 
         # Merge the micro-batches into one mini-batch.
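
Both files import from torchgpipe.stream, which this diff leans on. Below is a minimal sketch of what new_stream plausibly looks like, assuming the module's usual CPU/CUDA split; this is a hypothetical reconstruction, and the actual torchgpipe.stream module may differ in detail.

# Sketch of torchgpipe.stream.new_stream, assuming a stream is modeled
# as either a real torch.cuda.Stream or a CPU placeholder.
# Hypothetical reconstruction; the real module may differ.
from typing import Union

import torch


class CPUStreamType:
    """Placeholder for CPU devices, which have no CUDA streams."""


CPUStream = CPUStreamType()

# An "abstract" stream is either the CPU placeholder or a CUDA stream.
AbstractStream = Union[CPUStreamType, torch.cuda.Stream]


def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for the given device."""
    if device.type != 'cuda':
        # There is nothing to create for a CPU device.
        return CPUStream
    return torch.cuda.Stream(device)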
torchgpipe/pipeline.py (5 additions, 5 deletions)

@@ -10,7 +10,7 @@
 from torchgpipe.copy import Copy, Wait
 from torchgpipe.dependency import Fork, Join
 from torchgpipe.microbatch import Batch
-from torchgpipe.stream import AbstractStream, CPUStream, current_stream
+from torchgpipe.stream import AbstractStream, current_stream, new_stream
 from torchgpipe.worker import Task, spawn_workers
 
 __all__: List[str] = []
@@ -64,7 +64,6 @@ def __init__(self,
                  batches: List[Batch],
                  partitions: List[nn.Sequential],
                  devices: Optional[List[torch.device]] = None,
-                 copy_streams: Optional[List[List[AbstractStream]]] = None,
                  checkpoint_stop: int = 0,
                  ) -> None:
         self.batches = batches
@@ -74,9 +73,10 @@ def __init__(self,
             devices = [torch.device('cpu') for _ in partitions]
         self.devices = devices
 
-        if copy_streams is None:
-            copy_streams = [[CPUStream] * len(batches) for _ in partitions]
-        self.copy_streams = copy_streams
+        # NOTE(sublee): We don't need to manage a pool of CUDA streams because
+        # PyTorch already manages it.
+        # See https://github.com/pytorch/pytorch/pull/9938
+        self.copy_streams = [[new_stream(d) for _ in self.batches] for d in devices]
 
         self.checkpoint_stop = checkpoint_stop
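
The NOTE in the added block carries the rationale for the whole commit: it points at pytorch/pytorch#9938 and argues that PyTorch already manages a pool of CUDA streams, so Pipeline can simply request one copy stream per device and per micro-batch instead of GPipe caching them in _ensure_copy_streams. A small usage sketch of the resulting copy_streams grid follows; the names are illustrative, it assumes the new_stream sketch above, and running it requires at least two CUDA devices.

# Sketch: build the copy-stream grid the way the new constructor line does,
# then use one stream to overlap a device-to-device copy with compute.
# Hypothetical example, not torchgpipe's actual code.
import torch

devices = [torch.device('cuda', 0), torch.device('cuda', 1)]
chunks = 4  # number of micro-batches

# Mirrors: self.copy_streams = [[new_stream(d) for _ in self.batches] for d in devices]
copy_streams = [[torch.cuda.Stream(d) for _ in range(chunks)] for d in devices]

x = torch.ones(8, 8, device=devices[0])
copy_stream = copy_streams[1][0]  # copy stream on device 1 for micro-batch 0

# Issue the copy on the dedicated stream so it can overlap with work
# running on device 1's default stream.
with torch.cuda.stream(copy_stream):
    y = x.to(devices[1], non_blocking=True)

# The consumer stream must wait for the copy before reading y.
torch.cuda.current_stream(devices[1]).wait_stream(copy_stream)
z = y * 2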
