CUDA 11.2: Support the built-in Stream Ordered Memory Allocator #4537

Merged
merged 25 commits into cupy:master from leofang:cuda112_mempool on Feb 5, 2021

Conversation

leofang
Member

@leofang leofang commented Jan 10, 2021

While this is working, I am marking it as Work in Progress, as there are some issues to discuss with our NVIDIA friends 🙂

May be blocked by #4443 (?)

This PR exposes CUDA's new Stream Ordered Memory Allocator, added in CUDA 11.2, to CuPy. A new memory type, MemoryAsync, is added, which is backed by cudaMallocAsync() and cudaFreeAsync().

To use this feature, one simply sets the allocator to malloc_async, similar to what's done for managed memory:

import cupy as cp

cp.cuda.set_allocator(cp.cuda.malloc_async)
# from now on, memory is allocated on the current stream by the Stream Ordered Memory Allocator

On older CUDA (<11.2) or unsupported devices, using this new allocator will raise an error at runtime.
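
For reference, a minimal standalone sketch of how support can be checked at the CUDA runtime level (illustrative only, not code from this PR); cudaDevAttrMemoryPoolsSupported is the CUDA 11.2 device attribute for this:

// check_support.cu -- build with: nvcc check_support.cu -o check_support
#include <cstdio>

int main() {
    int device = 0, supported = 0;
    cudaGetDevice(&device);
    // nonzero iff the device supports cudaMallocAsync and memory pools
    cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, device);
    printf("stream-ordered allocator %s on device %d\n",
           supported ? "supported" : "NOT supported", device);
    return 0;
}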

(I didn't add support for customized mempools via cudaMemPool*()/cudaMallocFromPoolAsync() -- which could be the next PR -- as the benefit of using non-default mempools is unclear to me. Also, note that there is no API to expose any current information about the mempool, so it wouldn't be compatible with CuPy's MemoryPool API, such as used_bytes().)
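
For context, the pool-side API mentioned above looks roughly like the sketch below (illustrative only; this PR deliberately leaves cudaMemPool_t handling to the follow-up). It tunes the default pool that cudaMallocAsync draws from:

// default_pool.cu -- build with: nvcc default_pool.cu -o default_pool
#include <cstdint>

int main() {
    // cudaMallocAsync draws from the device's default memory pool
    cudaMemPool_t pool;
    cudaDeviceGetDefaultMemPool(&pool, 0 /* device */);

    // example tunable: keep up to 64 MiB cached in the pool instead of
    // returning freed memory to the OS at stream synchronization points
    uint64_t threshold = 64ULL << 20;
    cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
    return 0;
}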

Currently observed issues

I think nothing is wrong with my implementation; most likely these come from CUDA 😁

  1. UPDATE: This is unrelated to this PR; see CUDA 11.2: Support the built-in Stream Ordered Memory Allocator #4537 (comment) and CUDA 11.2: Fix empty NVRTC program name #4538.
  2. It is unclear from the CUDA documentation whether a stream may be destroyed before all memory allocated on it is freed. It could be that the driver ref-counts internally (so we don't have to), but we need to make sure. If that's not the case, then MemoryAsync will also need to hold a reference to the stream (object), not just its pointer.
  3. nvprof python my_script.py will fail if malloc_async is used in the workload:
$ nvprof --device-buffer-size 2048 --profiling-semaphore-pool-size 128000 pytest tests/cupy_tests/fft_tests/test_fft.py -k TestFFt
========================================================================= test session starts =========================================================================
platform linux -- Python 3.7.9, pytest-6.2.1, py-1.10.0, pluggy-0.13.1
rootdir: /home/leofang/dev/cupy_cuda112, configfile: setup.cfg
collecting ... ==31333== NVPROF is profiling process 31333, command: /home/leofang/miniconda3/envs/cupy_cuda112_dev/bin/python /home/leofang/miniconda3/envs/cupy_cuda112_dev/bin/pytest tests/cupy_tests/fft_tests/test_fft.py -k TestFFt
collected 717 items / 410 deselected / 307 selected                                                                                                                   

tests/cupy_tests/fft_tests/test_fft.py ..................................................==31333== Error: Internal profiling error 3938:999.
.....................................^C======== Warning: 569 records have invalid timestamps due to insufficient device buffer space. You can configure the buffer space using the option --device-buffer-size.
======== Warning: 293 records have invalid timestamps due to insufficient semaphore pool size. You can configure the pool size using the option --profiling-semaphore-pool-size.
======== Profiling result:
...

We need to confirm whether this is nvprof's problem/limitation (very likely it is), as it could be annoying to our users.

TODO

  • Add tests
  • Add tutorial to docs/source/reference/memory.rst?
  • Fix/update docstrings
  • Mark it experimental (as I did)?

cc: @jakirkham @pentschev @maxpkatz Could you help address the three observed issues? 🙂


@leofang
Member Author

leofang commented Jan 10, 2021

invalid filename
!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus, file: !1, producer: "lgenfe: EDG 6.0", isOptimized: true, runtimeVersion: 0, emissionKind: NoDebug, enums: !2)
!1 = !DIFile(filename: "", directory: "/home/leofang/dev/cupy_cuda112")

This seems to be a CUDA 11.2 issue, unrelated to the async allocator. My best guess is that it comes from NVRTC.

Bingo. It's due to the NVRTC program name being empty. UPDATE: #4538 fixes it.

@pentschev
Member

@leofang thanks for working on this. I've raised the question about destroying the stream before all memory allocated on it is freed; I'll let you know as soon as I get an answer.

As for nvprof, could you also try Nsight Systems before we raise questions about nvprof? The former has been deprecated, so I'm not entirely sure whether new features such as the Stream Ordered Memory Allocator can be expected to work correctly on it; it's probably best to know whether this is a problem in both or in nvprof only. Sorry for not checking it myself, but I don't have access to a machine with 11.2 at the moment.

@pentschev
Member

@leofang the following is allowed:

cudaMallocAsync(&ptr, size, stream);
cudaStreamDestroy(stream);
cudaFreeAsync(ptr, stream); // or cudaFree(ptr);

I believe you're fine implementing it however fits CuPy best.

@leofang leofang changed the title [WIP] Support CUDA 11.2 Stream Ordered Memory Allocator [WIP] CUDA 11.2: Support the built-in Stream Ordered Memory Allocator Jan 11, 2021
@jakirkham
Member

cc @jrhemstad @harrism @nsakharnykh (who may have thoughts on this 🙂)

@leofang
Member Author

leofang commented Jan 11, 2021

@leofang the following is allowed:

cudaMallocAsync(&ptr, size, stream);
cudaStreamDestroy(stream);
cudaFreeAsync(ptr, stream); // or cudaFree(ptr);

@pentschev @jrhemstad I don't think this works. In Python I got cudaErrorInvalidResourceHandle, while in C I got cudaErrorContextIsDestroyed. I am still trying to figure out why the error codes are different, but this doesn't look good...

>>> import cupy as cp
>>> s = cp.cuda.Stream()
>>> s.ptr
94049875546000
>>> ptr = cp.cuda.runtime.mallocAsync(100, s.ptr)
>>> del s
>>> cp.cuda.runtime.freeAsync(ptr, 94049875546000)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "cupy_backends/cuda/api/runtime.pyx", line 638, in cupy_backends.cuda.api.runtime.freeAsync
    cpdef freeAsync(intptr_t ptr, intptr_t stream):
  File "cupy_backends/cuda/api/runtime.pyx", line 643, in cupy_backends.cuda.api.runtime.freeAsync
    check_status(status)
  File "cupy_backends/cuda/api/runtime.pyx", line 253, in cupy_backends.cuda.api.runtime.check_status
    raise CUDARuntimeError(status)
cupy_backends.cuda.api.runtime.CUDARuntimeError: cudaErrorInvalidResourceHandle: invalid resource handle

C:

// nvcc -std=c++11 -arch=sm_75 test_mallocAsync.cu -o test_mallocAsync
#include <cstdio>
#include <cstdlib>


int main() {
    double* a;
    cudaStream_t s;

    int error = cudaStreamCreate(&s);
    if (error != 0) {
        printf("failed!\n");
        exit(1);
    }
    printf("stream ok!\n");
    cudaDeviceSynchronize();

    error = cudaMallocAsync(&a, 100, s);
    if (error != 0) {
        printf("failed!\n");
        exit(1);
    }
    printf("malloc ok!\n");
    cudaDeviceSynchronize();

    error = cudaStreamDestroy(s);
    if (error != 0) {
        printf("failed!\n");
        exit(1);
    }
    printf("stream destory ok!\n");
    cudaDeviceSynchronize();

    error = cudaFreeAsync(a, s);
    if (error != 0) {
        printf("failed: %i ", error);
        printf("%s\n", cudaGetErrorString((cudaError_t)error));
        exit(1);
    }
    printf("free ok!\n");
    cudaDeviceSynchronize();

    return 0;
}

/* output: */
// stream ok!
// malloc ok!
// stream destroy ok!
// failed: 709 context is destroyed

@leofang
Member Author

leofang commented Jan 11, 2021

(If the stream is kept alive until cudaFreeAsync is called, everything works as expected.)

@jrhemstad

This isn't valid.

cudaStreamCreate(&s);
cudaMallocAsync(&p, size, s);
cudaStreamDestroy(s);
cudaFreeAsync(p, s);  // invalid: s has already been destroyed

You can't use a stream after it's been destroyed. There's nothing unique about cudaMallocAsync in this regard. The above would be the same as:

cudaStreamCreate(&s);
kernel<<<..., s>>>(...);
cudaStreamDestroy(s);
kernel<<<..., s>>>(...); // you can't do this

That's a use-after-free error. It's like trying to use a pointer after it's been freed.

@leofang
Member Author

leofang commented Jan 11, 2021

cc: @maxpkatz

@leofang
Member Author

leofang commented Jan 11, 2021

Thanks for the quick reply, @jrhemstad! So, the only legitimate way to free stream-ordered memory after its stream is destroyed (see issue no. 2 in the PR description) is to call cudaFree? If so, does it mean we'd be better off always calling cudaFree instead of cudaFreeAsync?

@pentschev
Member

pentschev commented Jan 11, 2021

@leofang sorry for my original answer in #4537 (comment); it is incomplete at best. In addition to cudaFree, you can also call cudaFreeAsync on a different stream that has been synchronized with the one initially used for the allocation, but never on a stream that has already been destroyed.

EDIT: added comment on synchronization of both streams.

@jrhemstad

jrhemstad commented Jan 11, 2021

So, the only legit way to free a stream-ordered memory after the stream is destroyed (see Issue No.2 in the PR description) is to call cudaFree?

No. You can free on another stream as long as it was somehow ordered with the original stream the pointer was allocated on. There are a number of ways to do this. Here's one example using events:

cudaMallocAsync(&p, size, s0);
kernel<<<..., s0>>>(p, ...);
cudaEventRecord(e, s0);
cudaStreamWaitEvent(s1, e);
cudaStreamDestroy(s0);
cudaFreeAsync(p, s1);

If so, does it mean we should better always just call cudaFree instead of cudaFreeAsync?

No, because you lose a lot of the benefit of using cudaMalloc/FreeAsync in the first place, e.g., avoiding the device sync implicit in cudaFree.
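
To make the trade-off concrete (a sketch under the same assumptions as the snippets above):

// stream-ordered free: enqueued on s, returns immediately; the memory can be
// reused by later cudaMallocAsync calls on s without a device-wide sync
cudaFreeAsync(p, s);

// plain cudaFree also works on async allocations, but it carries an implicit
// device synchronization, which defeats the purpose of the async allocator:
// cudaFree(p);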

@leofang
Member Author

leofang commented Jan 11, 2021

Ah...OK thanks for speedy replies @pentschev @jrhemstad!!! So to follow up further: If at some point all other non-default streams are destroyed, I can always call cudaFreeAsync on the legacy or per-thread default stream, or stream 0, right?

@jrhemstad

I can always call cudaFreeAsync on the legacy or per-thread default stream, or stream 0, right?

You technically can with the legacy default stream, as it is implicitly synchronous with other streams (except non-blocking streams!), but I don't suggest going this route. You cannot with the per-thread default stream, as it is not implicitly synchronous with other streams.
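
To illustrate the non-blocking caveat (a sketch, not code from this thread):

double* p;
cudaStream_t s;
// a non-blocking stream opts out of implicit synchronization with the
// legacy default stream
cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking);
cudaMallocAsync(&p, 1024, s);
// ... work on s ...
// unsafe: cudaStreamDestroy(s); cudaFreeAsync(p, 0);
//   (this relies on implicit ordering with stream 0 that non-blocking
//    streams do not have)
// safe: establish ordering explicitly first, e.g. cudaStreamSynchronize(s);
//   then cudaFreeAsync(p, 0); or cudaFree(p);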

@leofang
Member Author

leofang commented Jan 11, 2021

OK @jrhemstad I think this is the final question (I hope!)

You can free on another stream so long as it was somehow ordered with the original stream the pointer was allocated on.

You cannot with the per-thread default stream as it is not implicitly synchronous with other streams.

So does it mean the driver does not guarantee the correct stream order, and it's the user's responsibility to ensure it (via cudaStreamSynchronize or other means)?

For example, in the case I just described (no other streams are alive): if, before destroying the stream on which the memory was allocated, I record an event for the PTDS to wait on, I can then free the memory on the PTDS, right? Something like

cudaMallocAsync(&x, 1024, s);
// do something with x on s...
cudaEventCreate(&e);
cudaEventRecord(e, s);
cudaStreamDestroy(s);
cudaStreamWaitEvent((cudaStream_t)2, e, 0);  // wait on PTDS (2)
cudaFreeAsync(x, (cudaStream_t)2);  // free on PTDS

(UPDATE use 2 for PTDS...)

@jrhemstad

So does it mean the driver does not guarantee the correct stream order and it's users' responsibility to ensure it (via cudaStreamSynchronize or other means)?

Yes. The per-thread default stream is effectively the same as any other stream.

For example, for the case I just described (no other streams are alive), if before I destroy the stream on which the memory was allocated I add an event to wait on the PTDS, I can then free it on the PTDS right?

Yep, that works. That said, event creation and even event recording are relatively expensive, so it's not something you want to do all the time if you can avoid it.

@harrism wrote a nice benchmark of event overheads recently: https://github.com/harrism/cuda_event_benchmark
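
If events end up on a hot path for this kind of ordering, a general CUDA tip (not something measured in this thread) is to create them without timing support, which can make recording and waiting cheaper:

cudaEvent_t e;
cudaEventCreateWithFlags(&e, cudaEventDisableTiming);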

@leofang leofang marked this pull request as ready for review February 1, 2021 18:11
@leofang
Member Author

leofang commented Feb 1, 2021

@harrism wrote a nice benchmark of event overheads recently: https://github.com/harrism/cuda_event_benchmark

I recommend checking out this blog post we wrote about RAPIDS Memory Manager: https://developer.nvidia.com/blog/fast-flexible-allocation-for-cuda-with-rapids-memory-manager/

Thanks for the pointers @jrhemstad @harrism! It is very nice to see these performance benchmarks. I think we should add RMM to CuPy's interoperability list 🙂

@emcastillo I split the PR into two, so this PR only exposes the necessary APIs (async malloc & free). In the next PR (#4592) the handling of cudaMemPool_t will be added. This PR should be ready for review now.

@leofang
Member Author

leofang commented Feb 1, 2021

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit 325571a, target branch master) failed with status FAILURE.

@leofang
Member Author

leofang commented Feb 1, 2021

CI errored out due to No space left on device. cc: @kmaehashi

@leofang
Member Author

leofang commented Feb 1, 2021

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit 325571a, target branch master) failed with status FAILURE.

@leofang
Member Author

leofang commented Feb 1, 2021

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit f21fe82, target branch master) failed with status FAILURE.

@leofang
Member Author

leofang commented Feb 1, 2021

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit f2f8f4f, target branch master) failed with status FAILURE.

@leofang
Member Author

leofang commented Feb 2, 2021

Jenkins, test this please

@leofang
Member Author

leofang commented Feb 2, 2021

Jenkins never started....?

Jenkins, test this please

@leofang
Member Author

leofang commented Feb 2, 2021

pfnCI, test this please

@chainer-ci
Member

Jenkins CI test (for commit f2f8f4f, target branch master) failed with status FAILURE.

@leofang
Member Author

leofang commented Feb 3, 2021

CI failures are known and unrelated.

@leofang
Member Author

leofang commented Feb 4, 2021

Jenkins, test this please

@chainer-ci
Member

Jenkins CI test (for commit f2f8f4f, target branch master) succeeded!

@emcastillo emcastillo added the st:test-and-merge (deprecated) Ready to merge after test pass. label Feb 5, 2021
@emcastillo emcastillo added this to the v9.0.0b3 milestone Feb 5, 2021
@mergify mergify bot merged commit 5828de0 into cupy:master Feb 5, 2021
@leofang leofang deleted the cuda112_mempool branch February 5, 2021 01:32
@leofang
Member Author

leofang commented Feb 5, 2021

Thanks @emcastillo and all!
