Skip to content

Commit

Permalink
Use microbatch functions with explicit prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Oct 30, 2019
1 parent 0cb3860 commit 06ee106
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchgpipe/gpipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.autograd
import torch.cuda

from torchgpipe import microbatch
from torchgpipe.batchnorm import DeferredBatchNorm
from torchgpipe.microbatch import check, gather, scatter
from torchgpipe.pipeline import Pipeline
from torchgpipe.stream import AbstractStream, new_stream

Expand Down Expand Up @@ -323,14 +323,14 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
TypeError: input is not a tensor or tensors.
"""
check(input)
microbatch.check(input)

if not self.devices:
# Empty sequential module is not illegal.
return input

# Divide a mini-batch into micro-batches.
batches = scatter(input, self.chunks)
batches = microbatch.scatter(input, self.chunks)

# Separate CUDA streams for copy.
copy_streams = self._ensure_copy_streams()
Expand All @@ -354,5 +354,5 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore
pipeline.run()

# Merge the micro-batches into one mini-batch.
output = gather(batches)
output = microbatch.gather(batches)
return output

0 comments on commit 06ee106

Please sign in to comment.