Skip to content

Commit

Permalink
Merge pull request #4651 from leofang/fix_ptds_hip
Browse files — browse the repository at this point in the history
ROCm: disable PTDS (per-thread default stream)
  • Loading branch information
mergify[bot] committed Feb 10, 2021
2 parents 4e8f39d + a47e624 commit cc6f3e4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
6 changes: 5 additions & 1 deletion cupy/cuda/stream.pyx
Expand Up @@ -327,6 +327,9 @@ class Stream(BaseStream):
# following 2.8.3-1.
self.ptr = 0
elif ptds:
if runtime._is_hip_environment:
raise ValueError('HIP does not support per-thread '
'default stream (ptds)')
self.ptr = runtime.streamPerThread
elif non_blocking:
self.ptr = runtime.streamCreateWithFlags(
Expand Down Expand Up @@ -374,4 +377,5 @@ class ExternalStream(BaseStream):


# Module-level singleton for the null stream.
Stream.null = Stream(null=True)
# Stream(ptds=True) raises ValueError in a HIP environment (see the ptds
# branch of Stream.__init__), so the per-thread default stream singleton
# must only be created when not running on HIP. Creating it
# unconditionally would make importing this module fail on HIP.
if not runtime._is_hip_environment:
    Stream.ptds = Stream(ptds=True)
49 changes: 34 additions & 15 deletions tests/cupy_tests/cuda_tests/test_stream.py
Expand Up @@ -3,17 +3,26 @@
from cupy._creation import from_data
from cupy import cuda
from cupy import testing
from cupy.testing import attr


@testing.parameterize(
*testing.product({
'stream': [cuda.Stream.null, cuda.Stream.ptds],
'stream_name': ['null', 'ptds'],
}))
@testing.gpu
class TestStream(unittest.TestCase):

@attr.gpu
def test_eq(self):
def setUp(self):
    """Bind ``self.stream`` to the stream named by ``self.stream_name``.

    ``stream_name`` comes from the class parameterization ('null' or
    'ptds'). The 'ptds' case is skipped on HIP before ``cuda.Stream.ptds``
    is ever looked up.
    """
    name = self.stream_name

    if name == 'ptds':
        if cuda.runtime.is_hip:
            self.skipTest('HIP does not support PTDS')
        self.stream = cuda.Stream.ptds
    elif name == 'null':
        self.stream = cuda.Stream.null

@unittest.skipIf(cuda.runtime.is_hip, 'This test is only for CUDA')
def test_eq_cuda(self):
null0 = self.stream
if self.stream == cuda.Stream.null:
null1 = cuda.Stream(True)
Expand All @@ -30,6 +39,17 @@ def test_eq(self):
assert null2 != null3
assert null2 != null4

@unittest.skipIf(not cuda.runtime.is_hip, 'This test is only for HIP')
def test_eq_hip(self):
    """Check stream equality semantics when running on HIP.

    Streams built with ``cuda.Stream(True)`` must compare equal to the
    fixture stream and to each other, while a default-constructed
    ``cuda.Stream()`` must compare unequal to them.
    """
    # NOTE(review): the positional ``True`` is presumably the ``null``
    # flag -- confirm against the Stream constructor signature.
    fixture_stream = self.stream
    first_null = cuda.Stream(True)
    second_null = cuda.Stream(True)
    fresh_stream = cuda.Stream()

    assert fixture_stream == first_null
    assert first_null == second_null
    assert second_null != fresh_stream

def check_del(self, null, ptds):
stream = cuda.Stream(null=null, ptds=ptds).use()
stream_ptr = stream.ptr
Expand All @@ -41,17 +61,18 @@ def check_del(self, null, ptds):
del stream_ptr
del x

@attr.gpu
def test_del_default(self):
    """Run ``check_del`` for a stream created with both flags off."""
    self.check_del(
        null=False,
        ptds=False,
    )

@attr.gpu
def test_del(self):
    """Run ``check_del`` with the null/ptds flags matching ``self.stream``.

    ``null`` is true when the fixture stream equals ``cuda.Stream.null``.
    ``ptds`` is forced to ``False`` on HIP; otherwise it is true when the
    fixture stream equals ``cuda.Stream.ptds``.
    """
    null = self.stream == cuda.Stream.null
    # Only touch cuda.Stream.ptds when not on HIP: that class attribute is
    # assigned solely outside a HIP environment, so reading it
    # unconditionally would fail there (and the value would be overwritten
    # by this branch anyway).
    if cuda.runtime.is_hip:
        ptds = False
    else:
        ptds = self.stream == cuda.Stream.ptds

    self.check_del(null=null, ptds=ptds)

@attr.gpu
def test_get_and_add_callback(self):
N = 100
cupy_arrays = [testing.shaped_random((2, 3)) for _ in range(N)]
Expand Down Expand Up @@ -79,8 +100,8 @@ def _callback(s, _, t):
assert out == list(range(N))
assert all(s == stream.ptr for s in stream_list)

@attr.gpu
@unittest.skipIf(cuda.runtime.is_hip, 'HIP does not support this')
@unittest.skipIf(cuda.runtime.is_hip,
'HIP does not support launch_host_func')
@unittest.skipIf(cuda.driver.get_build_version() < 10000,
'Only CUDA 10.0+ supports this')
def test_launch_host_func(self):
Expand All @@ -98,7 +119,6 @@ def test_launch_host_func(self):
stream.synchronize()
assert out == list(range(N))

@attr.gpu
def test_with_statement(self):
stream1 = cuda.Stream()
stream2 = cuda.Stream()
Expand All @@ -110,14 +130,14 @@ def test_with_statement(self):
assert stream1 == cuda.get_current_stream()
assert self.stream == cuda.get_current_stream()

@attr.gpu
def test_use(self):
    """Check that ``use()`` makes a stream the current stream.

    First for a freshly created stream, then for the fixture stream.
    """
    activated = cuda.Stream().use()
    current = cuda.get_current_stream()
    assert activated == current

    self.stream.use()
    current = cuda.get_current_stream()
    assert self.stream == current


@testing.gpu
class TestExternalStream(unittest.TestCase):

def setUp(self):
Expand All @@ -127,7 +147,6 @@ def setUp(self):
def tearDown(self):
    """Destroy the raw stream handle stored in ``self.stream_ptr``."""
    raw_ptr = self.stream_ptr
    cuda.runtime.streamDestroy(raw_ptr)

@attr.gpu
def test_get_and_add_callback(self):
N = 100
cupy_arrays = [testing.shaped_random((2, 3)) for _ in range(N)]
Expand All @@ -144,8 +163,8 @@ def test_get_and_add_callback(self):
stream.synchronize()
assert out == list(range(N))

@attr.gpu
@unittest.skipIf(cuda.runtime.is_hip, 'HIP does not support this')
@unittest.skipIf(cuda.runtime.is_hip,
'HIP does not support launch_host_func')
@unittest.skipIf(cuda.driver.get_build_version() < 10000,
'Only CUDA 10.0+ supports this')
def test_launch_host_func(self):
Expand Down

0 comments on commit cc6f3e4

Please sign in to comment.