Workaround for record_stream() on shifted view tensors
sublee authored and GitHub Enterprise committed Oct 8, 2019
1 parent 4e5b4fb commit 6d2dfb3
Showing 5 changed files with 85 additions and 22 deletions.
docs/changelog.rst (5 additions, 0 deletions)
@@ -7,6 +7,11 @@ v0.0.4 (WIP)
 Not released yet.
 
 - Reduced GPU memory fragmentation by caching CUDA streams for copy.
+- Fixed potential GPU memory violation when copying shifted view tensors.
+  (`issue #27366`_ and `pull request #27371`_ on PyTorch)
+
+.. _issue #27366: https://github.com/pytorch/pytorch/issues/27366
+.. _pull request #27371: https://github.com/pytorch/pytorch/pull/27371
 
 v0.0.3
 ~~~~~~
tests/conftest.py (16 additions, 0 deletions)
@@ -7,5 +7,21 @@ def manual_seed_zero():
     torch.manual_seed(0)
 
 
+@pytest.fixture(scope='session')
+def cuda_sleep():
+    # From test/test_cuda.py in PyTorch.
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    start.record()
+    torch.cuda._sleep(1000000)
+    end.record()
+    end.synchronize()
+    cycles_per_ms = 1000000 / start.elapsed_time(end)
+
+    def cuda_sleep(seconds):
+        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
+    return cuda_sleep


 def pytest_report_header():
     return f'torch: {torch.__version__}'
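
The fixture above first calibrates how many GPU cycles torch.cuda._sleep() burns per millisecond, then returns a helper that converts wall-clock seconds into spin cycles. A minimal usage sketch (hypothetical test, not part of this commit; assumes the conftest above is on the test path):

    def test_sleep_half_second(cuda_sleep):
        # Roughly 0.5 s of GPU busy-wait on the current stream, independent of
        # the GPU clock rate, unlike a hard-coded torch.cuda._sleep(100000000).
        cuda_sleep(0.5)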
tests/test_copy.py (8 additions, 8 deletions)
@@ -8,12 +8,12 @@
 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason='cuda required')
 
 
-def _test_copy_wait(prev_stream, next_stream):
+def _test_copy_wait(prev_stream, next_stream, cuda_sleep=None):
     device = get_device(prev_stream)
 
     with use_stream(prev_stream):
         if is_cuda(prev_stream):
-            torch.cuda._sleep(100000000)
+            cuda_sleep(0.5)
         x = torch.ones(100, device=device, requires_grad=True)
 
     y, = Copy.apply(prev_stream, next_stream, x)
@@ -33,21 +33,21 @@ def test_copy_wait_cpu_cpu():


 @skip_if_no_cuda
-def test_copy_wait_cpu_cuda():
+def test_copy_wait_cpu_cuda(cuda_sleep):
     prev_stream = CPUStream
     next_stream = current_stream(torch.device('cuda'))
-    _test_copy_wait(prev_stream, next_stream)
+    _test_copy_wait(prev_stream, next_stream, cuda_sleep)
 
 
 @skip_if_no_cuda
-def test_copy_wait_cuda_cpu():
+def test_copy_wait_cuda_cpu(cuda_sleep):
     prev_stream = current_stream(torch.device('cuda'))
     next_stream = CPUStream
-    _test_copy_wait(prev_stream, next_stream)
+    _test_copy_wait(prev_stream, next_stream, cuda_sleep)
 
 
 @skip_if_no_cuda
-def test_copy_wait_cuda_cuda():
+def test_copy_wait_cuda_cuda(cuda_sleep):
     prev_stream = current_stream(torch.device('cuda'))
     next_stream = new_stream(torch.device('cuda'))
-    _test_copy_wait(prev_stream, next_stream)
+    _test_copy_wait(prev_stream, next_stream, cuda_sleep)
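
In these tests, cuda_sleep(0.5) keeps prev_stream busy before x is written, so Copy.apply must make next_stream wait on prev_stream; otherwise the copy could read x before the producer has finished. A rough standalone sketch of the race being guarded against (hypothetical code, assuming a CUDA device; not part of the test suite):

    import torch

    prev = torch.cuda.Stream()
    main = torch.cuda.current_stream()

    with torch.cuda.stream(prev):
        torch.cuda._sleep(100000000)        # keep prev busy for a while
        x = torch.ones(100, device='cuda')

    # Racy: reading x on main without waiting for prev may observe stale data.
    # y = x * 1

    # Correct: make the consuming stream wait for the producer first.
    main.wait_stream(prev)
    y = x * 1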
tests/test_stream.py (46 additions, 14 deletions)
@@ -77,10 +77,10 @@ def test_get_device_cuda(self):


 class TestWaitStream:
-    def _test_wait_stream(self, source, target):
+    def _test_wait_stream(self, source, target, cuda_sleep=None):
         with use_stream(target):
             if is_cuda(target):
-                torch.cuda._sleep(100000000)
+                cuda_sleep(0.5)
             x = torch.ones(100, 100, device=get_device(target))
 
         wait_stream(source, target)
@@ -94,47 +94,79 @@ def test_wait_stream_cpu_cpu(self):
         self._test_wait_stream(source, target)
 
     @skip_if_no_cuda
-    def test_wait_stream_cpu_cuda(self):
+    def test_wait_stream_cpu_cuda(self, cuda_sleep):
         source = CPUStream
         target = new_stream(torch.device('cuda'))
-        self._test_wait_stream(source, target)
+        self._test_wait_stream(source, target, cuda_sleep)
 
     @skip_if_no_cuda
-    def test_wait_stream_cuda_cpu(self):
+    def test_wait_stream_cuda_cpu(self, cuda_sleep):
         source = new_stream(torch.device('cuda'))
         target = CPUStream
-        self._test_wait_stream(source, target)
+        self._test_wait_stream(source, target, cuda_sleep)
 
     @skip_if_no_cuda
-    def test_wait_stream_cuda_cuda(self):
+    def test_wait_stream_cuda_cuda(self, cuda_sleep):
         source = current_stream(torch.device('cuda'))
         target = new_stream(torch.device('cuda'))
-        self._test_wait_stream(source, target)
+        self._test_wait_stream(source, target, cuda_sleep)


 class TestRecordStream:
     def test_record_stream_cpu(self):
         # It should silently ignore CPU tensors.
         x = torch.rand(1, device=torch.device('cpu'))
         record_stream(x, CPUStream)
 
     @skip_if_no_cuda
-    def test_record_stream_cuda(self):
-        x = torch.rand(1, device=torch.device('cuda'))
-        data_ptr = x.data_ptr()
+    def test_record_stream_cuda(self, cuda_sleep):
+        # This test detects unexpected block reallocation. To keep it
+        # reliable, tensors are allocated on an isolated stream: the caching
+        # allocator does not reuse free blocks allocated on another stream.
+        stream_alloc = new_stream(torch.device('cuda'))
+        with torch.cuda.stream(stream_alloc):
+            x = torch.rand(1, device=torch.device('cuda'))
 
         stream = new_stream(torch.device('cuda'))
         record_stream(x, stream)
         with use_stream(stream):
-            torch.cuda._sleep(100000000)
+            cuda_sleep(0.5)
 
         # 'x' is deleted from Python's perspective, but its block is still
         # in use by 'stream', so 'y' must not be allocated into that block.
+        data_ptr = x.data_ptr()
         del x
-        y = torch.rand(1, device=torch.device('cuda'))
+        stream_alloc.synchronize()
+        with torch.cuda.stream(stream_alloc):
+            y = torch.rand(1, device=torch.device('cuda'))
         assert y.data_ptr() != data_ptr
 
         # Block Python until 'stream' finishes its queued work. Now the block
         # of 'x' is free to be reallocated.
         wait_stream(CPUStream, stream)
-        z = torch.rand(1, device=torch.device('cuda'))
+        with torch.cuda.stream(stream_alloc):
+            z = torch.rand(1, device=torch.device('cuda'))
         assert z.data_ptr() == data_ptr

+    @skip_if_no_cuda
+    def test_record_stream_shifted_view(self, cuda_sleep):
+        # Issue: https://github.com/pytorch/pytorch/issues/27366
+        stream_alloc = new_stream(torch.device('cuda'))
+        with torch.cuda.stream(stream_alloc):
+            x = torch.rand(2, device=torch.device('cuda'))
+
+        y = x[1:]
+        assert y.data_ptr() > x.data_ptr()
+
+        stream = new_stream(torch.device('cuda'))
+        with use_stream(stream):
+            cuda_sleep(0.5)
+        record_stream(y, stream)
+
+        data_ptr = x.data_ptr()
+        del x, y
+
+        stream_alloc.synchronize()
+        with torch.cuda.stream(stream_alloc):
+            z = torch.rand(2, device=torch.device('cuda'))
+        assert z.data_ptr() != data_ptr
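
For context, the new shifted-view test targets the PyTorch behavior described in issue #27366: calling torch.Tensor.record_stream() directly on a view that starts partway into its storage either raises or is silently ignored, depending on the PyTorch version (per the note added to torchgpipe/stream.py below). A minimal, hypothetical illustration:

    import torch

    x = torch.rand(2, device='cuda')
    y = x[1:]                  # shifted view: same storage, nonzero offset
    s = torch.cuda.Stream()

    y.record_stream(s)         # RuntimeError on PyTorch 1.1.0; no-op on 1.2.0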
torchgpipe/stream.py (10 additions, 0 deletions)
@@ -88,6 +88,16 @@ def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
 def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
     """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
     if is_cuda(stream):
+        # NOTE(sublee): record_stream() on a shifted view tensor throws
+        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
+        # protect the tensor against unexpected reallocation, we record a
+        # temporary tensor bound to the same storage but without the shift,
+        # as a workaround.
+        #
+        # Issue: https://github.com/pytorch/pytorch/issues/27366
+        #
+        tensor = tensor.new_empty([0]).set_(tensor.storage())
+
         tensor.record_stream(as_cuda(stream))
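
The added line works because new_empty([0]).set_(storage) yields a zero-element tensor whose data pointer is the base of the underlying storage, so recording it marks the whole allocated block, not just the shifted view, as in use by the stream. A rough sketch of the same idea in isolation (assumes a CUDA device; 'tmp' is a throwaway name, not torchgpipe API):

    import torch

    x = torch.rand(2, device='cuda')
    y = x[1:]                                  # shifted view into x's storage
    s = torch.cuda.Stream()

    # Record a zero-sized tensor rebound to the full storage instead of the
    # view itself; its data_ptr() equals the base of the allocated block.
    tmp = y.new_empty([0]).set_(y.storage())
    assert tmp.data_ptr() == x.data_ptr()
    tmp.record_stream(s)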


