Skip to content

Commit

Permalink
Merge pull request #8219 from boku13/streampriority
Browse files Browse the repository at this point in the history
add `cudaStreamCreateWithPriority`
  • Loading branch information
takagi committed Mar 13, 2024
2 parents be5d7f6 + f7d162a commit 50d48e7
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 5 deletions.
31 changes: 26 additions & 5 deletions cupy/cuda/stream.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,21 @@ class _BaseStream:
except RuntimeError: # can be RuntimeError or CUDARuntimeError
raise

@property
def is_non_blocking(self):
    """Whether the stream carries the non-blocking creation flag.

    ``False`` indicates the default stream creation flag.
    """
    cdef unsigned int stream_flags = runtime.streamGetFlags(self.ptr)
    # Only the non-blocking bit matters; mask everything else out.
    return (stream_flags & runtime.streamNonBlocking) != 0

@property
def priority(self):
    """Priority of the stream as reported by ``runtime.streamGetPriority``.

    Lower numbers represent higher priorities.
    """
    cdef int stream_priority = runtime.streamGetPriority(self.ptr)
    return stream_priority


class Stream(_BaseStream):

Expand All @@ -445,6 +460,8 @@ class Stream(_BaseStream):
per-thread default stream object.
non_blocking (bool): If ``True`` and both ``null`` and ``ptds`` are
``False``, the stream does not synchronize with the NULL stream.
priority (int): Priority of the stream. Lower numbers represent higher
priorities.
Attributes:
~Stream.ptr (intptr_t): Raw stream handle.
Expand All @@ -457,7 +474,8 @@ class Stream(_BaseStream):
null = None
ptds = None

def __init__(self, null=False, non_blocking=False, ptds=False):
def __init__(self, null=False, non_blocking=False, ptds=False,
priority=None):
if null:
# TODO(pentschev): move to streamLegacy. This wasn't possible
# because of a NCCL bug that should be fixed in the version
Expand All @@ -470,11 +488,14 @@ class Stream(_BaseStream):
'default stream (ptds)')
ptr = runtime.streamPerThread
device_id = -1
elif non_blocking:
ptr = runtime.streamCreateWithFlags(runtime.streamNonBlocking)
device_id = runtime.getDevice()
else:
ptr = runtime.streamCreate()
if priority is None:
priority = 0 # default
if non_blocking:
flag = runtime.streamNonBlocking
else:
flag = runtime.streamDefault
ptr = runtime.streamCreateWithPriority(flag, priority)
device_id = runtime.getDevice()
super().__init__(ptr, device_id)

Expand Down
4 changes: 4 additions & 0 deletions cupy_backends/cuda/api/_runtime_extern.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ cdef extern from '../../cupy_backend_runtime.h' nogil:
# Stream creation, query, teardown and synchronization entry points.
# Every function is declared as returning an int status code.
int cudaStreamCreate(driver.Stream* pStream)
int cudaStreamCreateWithFlags(driver.Stream* pStream,
unsigned int flags)
# Creation with an explicit priority in addition to the flags.
int cudaStreamCreateWithPriority(driver.Stream* pStream,
unsigned int flags, int priority)
# Queries: the result is passed back through the pointer argument.
int cudaStreamGetFlags(driver.Stream pStream, unsigned int* flags)
int cudaStreamGetPriority(driver.Stream pStream, int* priority)
int cudaStreamDestroy(driver.Stream stream)
int cudaStreamSynchronize(driver.Stream stream)
int cudaStreamAddCallback(driver.Stream stream, StreamCallback callback,
Expand Down
4 changes: 4 additions & 0 deletions cupy_backends/cuda/api/runtime.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ cpdef memPoolSetAttribute(intptr_t, int, object)

# Declarations of the stream wrappers.
# ``except? 0`` means that a return value of 0 makes Cython check whether an
# exception is pending; 0 itself remains a legitimate result (for example a
# stream priority of 0).
cpdef intptr_t streamCreate() except? 0
cpdef intptr_t streamCreateWithFlags(unsigned int flags) except? 0
cpdef intptr_t streamCreateWithPriority(unsigned int flags,
int priority) except? 0
cpdef unsigned int streamGetFlags(intptr_t stream) except? 0
cpdef int streamGetPriority(intptr_t stream) except? 0
cpdef streamDestroy(intptr_t stream)
cpdef streamSynchronize(intptr_t stream)
cpdef streamAddCallback(intptr_t stream, callback, intptr_t arg,
Expand Down
22 changes: 22 additions & 0 deletions cupy_backends/cuda/api/runtime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,28 @@ cpdef intptr_t streamCreateWithFlags(unsigned int flags) except? 0:
return <intptr_t>stream


cpdef intptr_t streamCreateWithPriority(unsigned int flags,
                                        int priority) except? 0:
    """Create a stream with the given flags and priority.

    Args:
        flags (unsigned int): Stream creation flags.
        priority (int): Requested stream priority.

    Returns:
        intptr_t: The new stream handle cast to an integer.

    Any non-success status is handed to ``check_status``.
    """
    cdef driver.Stream handle
    check_status(cudaStreamCreateWithPriority(&handle, flags, priority))
    return <intptr_t>handle


cpdef unsigned int streamGetFlags(intptr_t stream) except? 0:
    """Query the flags of the stream identified by the raw handle ``stream``.

    Args:
        stream (intptr_t): Raw stream handle.

    Returns:
        unsigned int: The flags value filled in by ``cudaStreamGetFlags``.

    Any non-success status is handed to ``check_status``.
    """
    cdef unsigned int stream_flags
    check_status(cudaStreamGetFlags(<driver.Stream>stream, &stream_flags))
    return stream_flags


cpdef int streamGetPriority(intptr_t stream) except? 0:
    """Query the priority of the stream identified by ``stream``.

    Args:
        stream (intptr_t): Raw stream handle.

    Returns:
        int: The priority value filled in by ``cudaStreamGetPriority``.

    Any non-success status is handed to ``check_status``.
    """
    cdef int stream_priority
    check_status(cudaStreamGetPriority(<driver.Stream>stream,
                                       &stream_priority))
    return stream_priority


cpdef streamDestroy(intptr_t stream):
    """Destroy the stream identified by the raw handle ``stream``.

    Any non-success status is handed to ``check_status``.
    """
    check_status(cudaStreamDestroy(<driver.Stream>stream))
Expand Down
14 changes: 14 additions & 0 deletions cupy_backends/hip/cupy_hip_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,20 @@ cudaError_t cudaStreamCreateWithFlags(cudaStream_t *stream,
return hipStreamCreateWithFlags(stream, flags);
}

// CUDA-named shim: forwards to hipStreamCreateWithPriority with the
// arguments passed through unchanged and returns its status.
cudaError_t cudaStreamCreateWithPriority(
        cudaStream_t *stream, unsigned int flags, int priority) {
    return hipStreamCreateWithPriority(stream, flags, priority);
}

// CUDA-named shim: forwards to hipStreamGetFlags with the arguments
// passed through unchanged and returns its status.
cudaError_t cudaStreamGetFlags(cudaStream_t stream,
                               unsigned int *flags) {
    return hipStreamGetFlags(stream, flags);
}

// CUDA-named shim: forwards to hipStreamGetPriority with the arguments
// passed through unchanged and returns its status.
cudaError_t cudaStreamGetPriority(cudaStream_t stream,
                                  int *priority) {
    return hipStreamGetPriority(stream, priority);
}

// CUDA-named shim: forwards to hipStreamDestroy and returns its status.
cudaError_t cudaStreamDestroy(
        cudaStream_t stream) {
    return hipStreamDestroy(stream);
}
Expand Down
12 changes: 12 additions & 0 deletions cupy_backends/stub/cupy_cuda_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,18 @@ cudaError_t cudaStreamCreateWithFlags(...) {
return cudaSuccess;
}

// Stub: accepts any arguments, ignores them, and always reports cudaSuccess.
// Nothing is written through any output pointer passed in.
cudaError_t cudaStreamCreateWithPriority(...) {
return cudaSuccess;
}

// Stub: accepts any arguments, ignores them, and always reports cudaSuccess.
// Nothing is written through any output pointer passed in.
cudaError_t cudaStreamGetFlags(...) {
return cudaSuccess;
}

// Stub: accepts any arguments, ignores them, and always reports cudaSuccess.
// Nothing is written through any output pointer passed in.
cudaError_t cudaStreamGetPriority(...) {
return cudaSuccess;
}

// Stub: accepts any arguments, ignores them, and always reports cudaSuccess.
cudaError_t cudaStreamDestroy(...) {
return cudaSuccess;
}
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/cuda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ to use these functions.
cupy.cuda.runtime.pointerGetAttributes
cupy.cuda.runtime.streamCreate
cupy.cuda.runtime.streamCreateWithFlags
cupy.cuda.runtime.streamCreateWithPriority
cupy.cuda.runtime.streamGetFlags
cupy.cuda.runtime.streamGetPriority
cupy.cuda.runtime.streamDestroy
cupy.cuda.runtime.streamSynchronize
cupy.cuda.runtime.streamAddCallback
Expand Down
17 changes: 17 additions & 0 deletions tests/cupy_tests/cuda_tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def f2(barrier, errors):
for err in errors:
assert err is False

def test_create_with_flags(self):
    # A stream only reports non-blocking when it was requested at creation.
    default_stream = cuda.Stream()
    non_blocking_stream = cuda.Stream(non_blocking=True)
    assert default_stream.is_non_blocking is False
    assert non_blocking_stream.is_non_blocking is True

def test_create_with_priority(self):
    # Parameterization is deliberately not used here: a priority that lies
    # outside the range returned by `cudaDeviceGetStreamPriorityRange`
    # gets clamped.
    # NOTE(review): this assumes the device's priority range covers -3;
    # confirm for every tested platform.
    requested = (0, -1, -3)
    # Create every stream first, then check each reported priority.
    streams = [cuda.Stream(priority=value) for value in requested]
    for stream, expected in zip(streams, requested):
        assert stream.priority == expected


class TestExternalStream(unittest.TestCase):

Expand Down

0 comments on commit 50d48e7

Please sign in to comment.