Apply Wait to every tensor
sublee authored and GitHub Enterprise committed Oct 8, 2019
1 parent 622025c commit d5acb37
Showing 6 changed files with 68 additions and 10 deletions.
3 changes: 2 additions & 1 deletion docs/changelog.rst
@@ -7,7 +7,8 @@ v0.0.4 (WIP)
Not released yet.

- Reduced GPU memory fragmentation by caching CUDA streams for copy.
-- Fixed potential GPU memory violation when copying shifted view tensors.
+- Fixed potential GPU memory violation on tuple of multiple tensors.
+- Fixed potential GPU memory violation on shifted view tensors.
  (`issue #27366`_ and `pull request #27371`_ on PyTorch)

.. _issue #27366: https://github.com/pytorch/pytorch/issues/27366
3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -9,6 +9,9 @@ def manual_seed_zero():

@pytest.fixture(scope='session')
def cuda_sleep():
+    # Warm-up CUDA.
+    torch.empty(1, device='cuda')
+
    # From test/test_cuda.py in PyTorch.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
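The warm-up call presumably keeps the one-time cost of CUDA context initialization out of the Event-based timing below, from which cuda_sleep is calibrated; without it, the first measured kernel would absorb the initialization overhead.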
43 changes: 43 additions & 0 deletions tests/test_bugs.py
@@ -59,3 +59,46 @@ def forward(self, x):

    with pytest.raises(ExpectedException):
        model(torch.rand(3))
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='2 cuda devices required')
+def test_tuple_wait(cuda_sleep):
+    # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
+    # Under this behavior, if checkpointing was disabled, there's a possibility
+    # that gradient accumulations on other tensors are not synchronized
+    # properly to the copy stream.
+    class Sleep(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x):
+            return x.detach()
+
+        @staticmethod
+        def backward(ctx, grad):
+            with torch.cuda.device(grad.device):
+                cuda_sleep(0.05)
+            return grad
+
+    class Layer1(nn.Module):
+        def forward(self, pair):
+            a, b = pair
+            return a*1, b*2, b*3
+
+    class Layer2(nn.Module):
+        def forward(self, triple):
+            a, b, c = triple
+            b = Sleep.apply(b)
+            return a+b+c
+
+    model = nn.Sequential(Layer1(), Layer2())
+    model = GPipe(model, [1, 1], devices=[0, 1], chunks=32, checkpoint='never')
+
+    a = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
+    b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
+
+    y = model((a, b))
+    y.norm().backward()
+
+    torch.cuda.synchronize(0)
+    torch.cuda.synchronize(1)
+
+    assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
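For reference, the expected value in that last assertion follows from the test's own arithmetic: Layer2 computes y = a*1 + b*2 + b*3 = a + 5*b, so the gradient of y.norm() with respect to b is 5 * y / y.norm(), whose norm is exactly 5. A gradient accumulation on b that raced with the copy stream would therefore show up as a deviation from 5.000.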
12 changes: 11 additions & 1 deletion tests/test_copy.py
@@ -17,7 +17,7 @@ def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
    x = torch.ones(100, device=device, requires_grad=True)

    y, = Copy.apply(prev_stream, next_stream, x)
-    y = Wait.apply(prev_stream, next_stream, x)
+    y, = Wait.apply(prev_stream, next_stream, x)

    with use_stream(next_stream):
        assert torch.allclose(y.sum(), torch.tensor(100.0, device=device))
@@ -51,3 +51,13 @@ def test_copy_wait_cuda_cuda(cuda_sleep):
    prev_stream = current_stream(torch.device('cuda'))
    next_stream = new_stream(torch.device('cuda'))
    _test_copy_wait(prev_stream, next_stream, cuda_sleep)
+
+
+def test_wait_multiple_tensors():
+    a = torch.rand(1, requires_grad=True)
+    b = torch.rand(1, requires_grad=True)
+
+    a, b = Wait.apply(CPUStream, CPUStream, a, b)
+
+    assert a.grad_fn is b.grad_fn
+    assert a.grad_fn.__class__ is Wait._backward_cls
15 changes: 8 additions & 7 deletions torchgpipe/copy.py
@@ -85,22 +85,23 @@ class Wait(torch.autograd.Function):
    def forward(ctx: Context,  # type: ignore
                prev_stream: AbstractStream,
                next_stream: AbstractStream,
-                input: Tensor,
-                ) -> Tensor:
+                *input: Tensor,
+                ) -> Tensors:
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        wait_stream(next_stream, prev_stream)

-        return input.detach()
+        return tuple(x.detach() for x in input)

    @staticmethod
-    def backward(ctx: Context,  # type: ignore
-                 grad_input: Tensor,
-                 ) -> Tuple[None, None, Tensor]:
+    def backward(ctx: Context,
+                 *grad_input: Tensor,
+                 ) -> Tuple[Optional[Tensor], ...]:
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        wait_stream(prev_stream, next_stream)

-        return None, None, grad_input
+        grad_streams: Tuple[Optional[Tensor], ...] = (None, None)
+        return grad_streams + grad_input
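Because forward now takes a variadic tensor list after the two stream arguments, backward has to return one gradient slot per forward argument: None for each stream, then the incoming gradients passed through unchanged, which is exactly what the (None, None) prefix above provides. A minimal, self-contained sketch of the same variadic autograd.Function pattern (illustration only, not torchgpipe code; the PassThrough name and the toy loss are invented):

from typing import Optional, Tuple

import torch
from torch import Tensor


class PassThrough(torch.autograd.Function):
    """Detach every input while keeping all outputs on one backward node."""

    @staticmethod
    def forward(ctx, tag: str, *inputs: Tensor) -> Tuple[Tensor, ...]:
        # One output per input; autograd attaches all of them to this node.
        return tuple(x.detach() for x in inputs)

    @staticmethod
    def backward(ctx, *grad_inputs: Tensor) -> Tuple[Optional[Tensor], ...]:
        # One slot per forward argument: None for `tag`, then the incoming
        # gradients, forwarded unchanged.
        return (None,) + grad_inputs


a = torch.rand(3, requires_grad=True)
b = torch.rand(3, requires_grad=True)
x, y = PassThrough.apply('demo', a, b)
assert x.grad_fn is y.grad_fn          # both outputs share one backward node
(x.sum() + 2 * y.sum()).backward()
assert torch.allclose(b.grad, torch.full((3,), 2.0))

As with Wait, detaching inside forward does not break the graph: autograd wraps every tensor returned from forward with this function's backward node, which is what the grad_fn assertions in tests/test_copy.py rely on.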
2 changes: 1 addition & 1 deletion torchgpipe/pipeline.py
@@ -38,7 +38,7 @@ def copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream)


def wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
-    batch[0] = Wait.apply(prev_stream, next_stream, batch[0])
+    batch[:] = Wait.apply(prev_stream, next_stream, *batch)


def clock_cycles(n: int, m: int) -> Iterable[List[Tuple[int, int]]]:
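With this one-line change, every tensor of a micro-batch goes through the same Wait node, so the backward pass of each tensor synchronizes with the copy stream instead of only the first one. A rough usage sketch of the difference, assuming the torchgpipe.copy and torchgpipe.stream import paths used by tests/test_copy.py, with a plain list standing in for Batch:

import torch
from torchgpipe.copy import Wait
from torchgpipe.stream import CPUStream

# A three-tensor micro-batch; CPUStream keeps the stream waits trivial.
batch = [torch.rand(1, requires_grad=True) for _ in range(3)]

# Old wait(): batch[0] = Wait.apply(prev_stream, next_stream, batch[0])
# tied only batch[0] to the Wait node, so gradient work on the other tensors
# could race with the copy stream.

# New wait(): the whole micro-batch shares a single Wait backward node.
batch[:] = Wait.apply(CPUStream, CPUStream, *batch)
assert batch[0].grad_fn is batch[1].grad_fn is batch[2].grad_fn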
