
Support for Per Thread Default Stream (PTDS) #4322

Merged (31 commits into cupy:master, Feb 8, 2021)

Conversation

pentschev (Member Author)

This PR adds Per Thread Default Stream (PTDS) support to CuPy, entirely controllable at runtime via an environment variable. The advantage of this approach is that multiple builds aren't necessary, greatly simplifying the shipping procedure of CuPy.

I don't currently know if this is the "right" approach for CuPy, but it is certainly useful for us to get some experience with PTDS. So far I've verified this with some pure Python multithreaded code (see below) and with Dask; I was able to get a 2x speedup for some Dask workflows, but I didn't do any performance testing with standalone CuPy, as results will vary greatly with the application, number of threads, etc. As of now, I also don't know how to write tests to ensure that we aren't using the default stream when setting CUPY_CUDA_PER_THREAD_DEFAULT_STREAM=1, something I verified with the help of nsys and nvprof.

The following code serves as an example, verifiable with nsys/nvprof: with CUPY_CUDA_PER_THREAD_DEFAULT_STREAM=1 each host thread runs on a different stream, whereas without the variable everything runs on the default stream (shown as stream 7):

import threading

import cupy

NumThreads = 4


def thread_func():
    a = cupy.random.random((5000, 5000))
    c = cupy.fft.fft(a)


threads = [threading.Thread(target=thread_func, name="Thread %d" % (i,)) for i in range(NumThreads)]
for t in threads:
    t.start()
for t in threads:
    t.join()
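For reference, a minimal way to enable PTDS for the snippet above from Python itself, assuming the environment variable is read when CuPy is first imported (so it must be set before that import):

import os

# Assumption: the variable is read at CuPy import time, so set it before
# the first `import cupy` anywhere in the process.
os.environ["CUPY_CUDA_PER_THREAD_DEFAULT_STREAM"] = "1"

import cupy

Equivalently, export the variable in the shell before launching Python and inspect the resulting streams with nsys/nvprof.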

Note there is a similar effort underway for PTDS support in RMM: rapidsai/rmm#633.

@pentschev (Member Author)

cc @leofang @quasiben @jakirkham

@leofang (Member) left a comment

Just a quick pass. I have two questions:

  1. Why does the CUDA Programming Guide say we need to recompile an app with --default-stream, while there exists such a simple solution?!
  2. Why does moving to PTDS bring a speedup in some of your tests? Is it because the implicit syncs are removed? If so, wouldn't using non-blocking streams (see the sketch below) also do the job for you? I am very curious! 😛
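(By non-blocking streams I mean the existing explicit-stream API; a minimal sketch of what I have in mind, untested:)

import cupy

# An explicitly created stream that does not synchronize with the legacy
# default stream:
s = cupy.cuda.Stream(non_blocking=True)
with s:
    a = cupy.random.random((5000, 5000))
    b = cupy.fft.fft(a)
s.synchronize()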

Thanks for pinging me btw!

Review thread on cupy_backends/cuda/api/runtime.pxd (outdated, resolved)
Comment on lines 335 to +336
Stream.null = Stream(null=True)
Stream.ptds = Stream(ptds=True)
Member

I think there's some logic (likely in BaseStream?) that requires Stream.null to be a singleton. If a user explicitly requests for PTDS, I suppose we should just make Stream.null default to PTDS, instead of creating another special stream?

pentschev (Member Author)

That would work too, but I'm not sure if it's the right solution for two reasons:

  1. The null stream may become ambiguous: are we really using the null stream or the per-thread one?
  2. It would no longer be possible to enforce explicit usage of the default stream when the environment variable is on. I'm not sure how useful the default stream is in that case anyway, but it's probably useful to keep that option (see the sketch below).
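(To make point 2 concrete, a rough sketch of the explicit usage I mean, using the names introduced in this PR; untested:)

import cupy
from cupy.cuda import Stream

with Stream.ptds:   # explicitly opt in to the per-thread default stream
    a = cupy.arange(10) * 2

with Stream.null:   # explicitly fall back to the (legacy) default stream
    b = cupy.arange(10) * 3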

@leofang (Member), Nov 25, 2020

Yeah it's best for the core devs to decide, as I could be overlooking some design considerations. My thoughts are:

  1. Stream.null in CuPy means the default stream, and previously it meant the legacy one.
  2. If a user explicitly asks for PTDS by setting the env var, it implicitly suggests that the legacy default stream should be avoided to the maximum possible extent.

pentschev (Member Author)

I agree that the legacy stream should be avoided when PTDS is requested, but forcing Stream.null to be the PTDS stream would leave users no way to fall back to the legacy one if necessary for some reason, so it's probably good to keep that possibility regardless.

Member

@pentschev I just noticed in CuPy's test suite it's expected that when users do not specify any stream, this is true:

assert cuda.Stream.null == cuda.get_current_stream()

See https://github.com/cupy/cupy/blob/master/tests/cupy_tests/cuda_tests/test_stream.py. I guess this assumption might have propagated to some user codes as well? At least I just recalled I got one 😅 But we don't plan to use PTDS so it's probably OK.

pentschev (Member Author)

Hmm, that's interesting, I indeed hadn't seen that before. I suppose this could pose a problem, so we should probably try to disambiguate it first. I would still prefer to keep the possibility of enforcing Stream.null even when PTDS is enabled, but I'm happy to change that if the preference is not to.

Member

Is using Stream.null when PTDS is enabled allowed?

Member

I think it is, but the semantics needs to be clearly defined (which is what we were discussing above -- should Stream.null mean stream 1 (legacy) or 2 (per-thread) when it's enabled?)

Member

We discussed and the majority was in favor of leaving Stream.null as the legacy one.
Anyway, we can just proceed as it is now and revisit it later.

pentschev (Member Author)

Thanks for discussing that!

if null:
    self.ptr = 0
    self.ptr = runtime.cudaStreamDefault
Member

Even after the long discussion we've been in (numba/numba#5162), I am still a bit uncertain whether it's better to choose runtime.cudaStreamDefault or runtime.cudaStreamLegacy here. I think the latter is the status quo?

pentschev (Member Author)

Technically speaking, cudaStreamDefault/cudaStreamNonBlocking are flags for cudaStreamCreateWithFlags, but it happens that cudaStreamDefault is also 0, which matches the default behavior when no stream is passed and is what many people (CuPy included) use. On the other hand, cudaStreamLegacy and cudaStreamPerThread are stream handles, and those are the ones we should be using when specifying a stream explicitly.
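For reference, a small illustration of the raw values behind those names, as defined by the CUDA runtime headers (the Python names below exist only for this example):

# Flags accepted by cudaStreamCreateWithFlags:
CUDA_STREAM_DEFAULT_FLAG = 0x0   # cudaStreamDefault
CUDA_STREAM_NON_BLOCKING = 0x1   # cudaStreamNonBlocking

# Special stream handles understood by the runtime API:
CUDA_STREAM_LEGACY = 1           # cudaStreamLegacy, the legacy default stream
CUDA_STREAM_PER_THREAD = 2       # cudaStreamPerThread, the per-thread default stream

# The flag cudaStreamDefault happens to share the value 0 with the "no stream"
# handle, which is why passing 0 has historically behaved like the default stream.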

@leofang (Member), Nov 25, 2020

Right, so this is exactly what I was thinking: don't we expect self.ptr to be a stream handle, and so shouldn't it be runtime.cudaStreamLegacy?

pentschev (Member Author)

Probably. The default was 0, so I thought I would play it safe and avoid replacing that by cudaStreamLegacy for now. I believe the behavior should be the same though. Maybe it's a good time to replace and test that out, unless someone has any reasons not to do so that I may be overlooking.

pentschev (Member Author)

I adjusted this now to runtime.cudaStreamLegacy, so far everything looks correct on my verification.

Comment on lines 457 to 458
elif stream_module.is_ptds_enabled():
    setStream(handle, runtime.cudaStreamPerThread)
@leofang (Member), Nov 25, 2020

I have a feeling that this extra branch (here and in several other places) is not necessary, though I didn't think too hard. The "current stream" in CuPy can simply refer to PTDS if it's enabled when launching a Python interpreter.

@leofang (Member), Nov 25, 2020

Note that you already allow get_current_stream_ptr() to return PTDS if it's enabled.

pentschev (Member Author)

Are you suggesting we replace those two lines simply by the following?

setStream(handle, stream_module.get_current_stream_ptr())

Member

Yes, but it's already done above.

pentschev (Member Author)

It isn't: note how above there's if stream_module.enable_current_stream, and enable_current_stream is only set when using an explicit stream. We could perhaps remove that conditional too, but I'm not sure it would work; my understanding is that get_current_stream_ptr returns self.current_stream, which is undefined until a stream is set. Or am I missing something else?

pentschev (Member Author)

However, I do think we should explicitly initialize it to True: note that it's a module attribute left uninitialized:

So when would it be False? If we're forcing an initialization to True, it becomes a completely unnecessary attribute, as I don't see it being set to False anywhere.

Member

I feel it is completely unnecessary to have.

pentschev (Member Author)

I removed it and adjusted the setStream functions and get_current_stream_ptr. I'm not very happy with the condition in https://github.com/cupy/cupy/pull/4322/files#diff-c7ca9ae40b6014a65e9d4420604c0e83adb8aadd77011ce04fd2026d74f43a01R59-R60 : comparing against <void*>0 doesn't look nice, and I'm not sure what the most appropriate/preferred/cleanest way to do that in Cython is.

@leofang (Member), Nov 27, 2020

current_stream could be changed from void* to intptr_t IMHO, so that we just compare with 0. I don't have a strong opinion on this, but I'd note this is done in one of the two stream modules, and it's best to unify them when they are merged one way (#3584) or another (#4322 (comment)).

pentschev (Member Author)

I did this now in ece5c55, seems that this change doesn't require modifications elsewhere.

@leofang (Member) commented Nov 25, 2020

As of now, I also don't know how to write tests to ensure that we aren't using the default stream when setting CUPY_CUDA_PER_THREAD_DEFAULT_STREAM=1, something I verified with the help of nsys and nvprof.

btw this was my impression too, but I just began to wonder if it's possible to add a callback to the stream and let it throw out something that we can capture in Python? CuPy already wraps the stream callback API:
https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html#cupy.cuda.Stream.add_callback
(Disclaimer: I've never tried this, I just read it from the Programming Guide 😅)

@kmaehashi self-assigned this on Nov 25, 2020
@kmaehashi added the cat:feature (New features/APIs) and prio:medium labels on Nov 25, 2020
@pentschev (Member Author)

  1. Why does the CUDA Programming Guide say we need to recompile an app with --default-stream, while there exists such a simple solution?!

Where does it say users need to recompile? In https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html both alternatives are presented. The advantage of compiling with --default-stream=per-thread is that you need no code changes whatsoever when you're using the default stream already, but that may be incompatible when interfacing with applications that do not compile the same way.

  1. Why does moving to PTDS bring a speedup in some of your tests? Is it because the implicit syncs are removed? If so, wouldn't using non-blocking streams also do the job for you? I am very curious! 😛

They would, but using non-blocking streams can be very challenging for RAPIDS, and for Dask more specifically; we don't want to require users to manage streams on their own, so PTDS may be a very convenient solution, although it's still unclear whether we will indeed see major improvements for the average workflow. Below is a Dask example with 1 stream (i.e., thread) vs. 4 streams:

[Profiler timelines: 1 Dask worker with 1 thread vs. 1 Dask worker with 4 threads]

@jakirkham (Member)

Thanks for sharing that Peter!

It's interesting that sometimes a gap shows up on one of the streams. Do we know why that is? No worries if not. Just curious 🙂

@pentschev (Member Author)

It's interesting that sometimes a gap shows up on one of the streams. Do we know why that is? No worries if not. Just curious 🙂

The GPU has a limited amount of resources; most likely yet another concurrent kernel wouldn't fit because one or more resources (e.g., shared memory, registers) are exhausted.

@leofang (Member) commented Nov 26, 2020

As of now, I also don't know how to write tests to ensure that we aren't using the default stream when setting CUPY_CUDA_PER_THREAD_DEFAULT_STREAM=1, something I verified with the help of nsys and nvprof.

btw this was my impression too, but I just began to wonder if it's possible to add a callback to the stream and let it throw out something that we can capture in Python? CuPy already wraps the stream callback API:
https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html#cupy.cuda.Stream.add_callback
(Disclaimer: I've never tried this, I just read it from the Programming Guide 😅)

I ended up working on #4338 out of curiosity. Note this works:

import cupy as cp


def callback(data):
    print('I am called!')
    data.append('catch me if you can')

def do_work(stream):
    data = []
    with stream:
        # do time-consuming work
        a = cp.random.random(100)
        a *= 3
    stream.launch_host_func(callback, data)
    stream.synchronize()
    assert data[0] == 'catch me if you can'
    assert len(data) == 1

s = cp.cuda.Stream()
do_work(s)
cp.cuda.Device().synchronize()
print('done')

I think the idea also applies if you change stream.launch_host_func() (to be added in #4338) to stream.add_callback(), adjusting the callback function signature accordingly. So perhaps what we could do in the tests is add callbacks to both the legacy and per-thread default streams, check the returned data, and make sure only one of them is used but not both (depending on the test condition).

EDIT: fix the snippet

@pentschev (Member Author)

@leofang thanks for checking the callback idea, I'll try to write some tests using that for this PR.

@pentschev mentioned this pull request on Nov 26, 2020
@leofang (Member) left a comment

Thanks, @pentschev! LGTM except for a few minor points. I think we are left with

  • Discussion on the Stream.null singleton
  • Tests

Review threads (outdated, resolved): cupy/cuda/stream.pyx, cupy_backends/cuda/api/runtime.pxd, cupy_backends/cuda/libs/cudnn.pyx
@pentschev (Member Author)

Certainly, thanks for the ping @leofang !

@pentschev (Member Author)

This seems to be either a bug or an unsupported use case; I can also reproduce it with C++ code. I filed an internal bug (ID 200699044) and will update here when I hear back from the developers.

In the meantime, I think we'll need to discuss the best approach here. Both cudaStreamLegacy and cudaStreamPerThread fail in that test, which means we currently can't use either of them. Given we're replacing Stream.null.ptr with cudaStreamLegacy, it seems that will break any NCCL usage, but we could revert to the null stream for the time being; any other ideas? It seems that the PTDS case won't work at all with NCCL at the moment, unfortunately.

@leofang (Member) commented Feb 4, 2021

It seems that the PTDS case won't work at all with NCCL at the moment, unfortunately.

Hi Peter, I did some manual tests and I think the "bug" is limited to when we initialize the communicator for all devices in the same process (via ncclCommInitAll). But if the initialization is done by ncclCommInitRank with multiple processes, which is the most common use case IMHO, then it happily accepts streams 1 & 2. You can try this code with MPI:

# mpirun -n 2 python test_nccl.py

from mpi4py import MPI
import cupy as cp


comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

cp.cuda.Device(rank).use()

if rank == 0:
    n_id = cp.cuda.nccl.get_unique_id()
else:
    n_id = None
n_id = comm.bcast(n_id, root=0)
n_comm = cp.cuda.nccl.NcclCommunicator(size, n_id, rank)

a = cp.arange(10)
n_comm.allReduce(a.data.ptr, a.data.ptr,
                 10*a.dtype.itemsize, cp.cuda.nccl.NCCL_FLOAT64, cp.cuda.nccl.NCCL_SUM,
                 2)  # use stream 2
print(rank, a)

Perhaps we could simply print a big warning when initAll() is called to tell users to not use stream 1 & 2 until it's fixed by upstream?
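(A rough, purely illustrative sketch of such a warning; the helper name and where it would be hooked in are assumptions, not something this PR adds:)

import warnings

def warn_special_default_streams_unsupported():
    # Hypothetical helper that could be called from NcclCommunicator.initAll()
    # until the upstream NCCL fix ships.
    warnings.warn(
        "Communicators created via ncclCommInitAll currently reject the special "
        "default-stream handles cudaStreamLegacy (1) and cudaStreamPerThread (2); "
        "pass an explicitly created stream instead.",
        RuntimeWarning)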

@leofang (Member) commented Feb 4, 2021

(btw I found this as I was wondering if the bug happens only when send/recv is called, but it's not the case)

@pentschev (Member Author)

Thanks @leofang for further investigating this.

Perhaps we could simply print a big warning when initAll() is called to tell users to not use stream 1 & 2 until it's fixed by upstream?

I would be fine with that, but the part that then becomes somewhat semantically wrong is that we moved the meaning of Stream.null to cudaStreamLegacy (stream 1), so in a sense the user would claim "I'm using the null stream" via Stream.null. Personally, I have no objections to either emitting a warning or moving back to Stream.null.ptr = 0 for the time being.

(btw I found this as I was wondering if the bug happens only when send/recv is called, but it's not the case)

Indeed, I observed the same in C++ with AllReduce.

@emcastillo (Member)

Yes, let's revert to the null stream.
Thanks!

@leofang (Member) commented Feb 5, 2021

Personally, I have no objections to either emitting a warning or moving back to Stream.null.ptr = 0 for the time being.

Indeed, it seems to be a good idea to revert to the status quo (stream 0 being null).

@pentschev (Member Author)

I reverted it back to streamDefault. I worked with the NCCL developers and the issue has already been fixed; it was simply that the special streams were not being checked for. That fix should be in the version after 2.8.3-1, so I added a TODO note in the code to change this in the future.

@pentschev (Member Author)

pfnCI, test this please

@chainer-ci (Member)

Jenkins CI test (for commit 3cc00ce, target branch master) succeeded!

# TODO(pentschev): move to streamLegacy. This wasn't possible
# because of a NCCL bug that should be fixed in the version
# following 2.8.3-1.
self.ptr = runtime.streamDefault
Member

nitpick: streamDefault is a flag, not a stream macro

Suggested change
self.ptr = runtime.streamDefault
self.ptr = 0

pentschev (Member Author)

Stop using my own arguments from #4322 (comment) against myself! 🤣

Member

Ah ha ha, forgot about that one! 😄

pentschev (Member Author)

That's why you shouldn't write things that can be used against yourself, I trapped myself with that one! 😄

@leofang (Member) commented Feb 5, 2021

Jenkins, test this please

@chainer-ci (Member)

Jenkins CI test (for commit 6ae421d, target branch master) succeeded!

@pentschev (Member Author)

I was informed that the fix for cudaStreamLegacy/cudaStreamPerThread should be in NCCL 2.9, which will be out probably in Q2 2021.

@leofang (Member) left a comment

LGTM! Thanks @pentschev for pushing it through.

@emcastillo added the st:test-and-merge (deprecated) label on Feb 8, 2021
@emcastillo added this to the v9.0.0b3 milestone on Feb 8, 2021
@mergify (bot) merged commit c3f0ecb into cupy:master on Feb 8, 2021
@pentschev (Member Author)

Thanks for help with reviews here @leofang and @emcastillo ! 😄
