Adding support for async memory pool allocations in the CUDA HAL. #13440
Conversation
Abbreviated Benchmark Summary @ commit 68c15298a6ca7937a146b84ae788776e53e9abd2 (no previous benchmark results to compare): no improved or regressed compilation metrics 🏖️
Hi @benvanik, could you break this into approachable steps, initially targeting an MVP of sorts as a proof of concept? If this is parcelable I may be able to staff it.
This is the first step! There are two failing tests for an unknown reason - if someone can help diagnose the internal driver error, that would unblock this landing and get us using memory pools. The stream-based execution we do seems to work, but graph-based execution fails (e.g. https://github.com/openxla/iree/actions/runs/4902664048/jobs/8754727897?pr=13440).
Sorry for the delay. I can confirm the issue persists in CUDA 12.0. Asking...
So far I have this hint:
Are we sure the sizes are not trivially zero?
Ahh, that may be it! I'll fix the code and give it a shot! Thank you!
Interesting, looks like there's a legit error happening and the device pointers we are passing in look bogus.
The 0x302000000 address shows up in all other failing tests - so something is corrupting it, or the code path deriving it is busted.
(reproduced it locally, which helps a lot!)
OK, perfect. On our side we're out of ideas currently, but if we get new signals I'll relay them.
Thanks! That tip was really helpful in tracking down the root cause - a silly mistake on my part (it was possible to end up issuing a cuMemFree after a cuMemFreeAsync, which is obviously bad :) and I'm going to rework things to prevent that. It's nice that graphs perform that validation, and I'm excited to have those on by default so we can catch this kind of error earlier!
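For readers following along, the failure mode described above can be sketched as follows. This is a hypothetical reduction, not the actual IREE code; `size` and `stream` are assumed to exist in scope.

```c
#include <cuda.h>

// Hypothetical reduction of the bug: a pointer from the stream-ordered
// allocator must only be released with the stream-ordered free.
CUdeviceptr ptr = 0;
cuMemAllocAsync(&ptr, size, stream);  // stream-ordered (pooled) allocation
// ... enqueue work that uses ptr on `stream` ...
cuMemFreeAsync(ptr, stream);          // returns the block to the pool
// BUG: a second, synchronous free of the same pointer. By now the pool may
// have recycled the address, so this corrupts the allocator's state.
cuMemFree(ptr);
```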
Uh oh, still hitting the issue. @pjannaty, could you please check my understanding that calling cuMemAllocAsync and then using the resulting pointer when constructing a graph (as an arg to cuGraphAddMemsetNode) is legal? If I use streams things work fine, but graphs explode even with the same pointers returned across process launches, the same params to the node or cuMemsetD32Async call, and the same execution order. I'd have expected the virtual address returned from cuMemAllocAsync to be legal to encode in graphs even without synchronization, but sprinkling cuStreamSynchronize calls around didn't help. Allocations returned from cuMemAlloc work fine, while async allocations both from the default pool (cuMemAllocAsync) and from my own (cuMemAllocFromPoolAsync) fail. This memset passes with these params on the stream implementation with both sync and async allocations, and on the graph implementation when using normal cuMemAlloc, but not async:

```c
CUdeviceptr device_ptr_from_async_alloc = 0;
cuMemAllocAsync(&device_ptr_from_async_alloc, 192, stream);
// cuStreamSynchronize(stream); -- doesn't help
CUDA_MEMSET_NODE_PARAMS params = {
    .dst = device_ptr_from_async_alloc,  // 0000000204E00000, 192b
    .elementSize = 4,
    .pitch = 0,        // unused with height == 1
    .width = 144 / 4,  // element count = 36
    .height = 1,
    .value = 0,
};
// returns CUDA_ERROR_INVALID_VALUE
cuGraphAddMemsetNode(NULL, graph, NULL, 0, &params, context);
```

I tried changing all the params (elementSize=1/width=1, elementSize=4/width=1, etc) and all returned the same error, so I'm pretty sure it's the pointer that's the issue. Thoughts on what to try next? Everything else except memset seems to be fine, so I'm at a loss as to why that particular call is having issues. I'm seeing this on Windows with CUDA 12.1 / driver 532.03, but we also see it on the A100 bot with CUDA 12.1 / driver 530.41.03.
We think passing phGraphNode as NULL is causing the explicit graph API to return CUDA_ERROR_INVALID_VALUE. Also, the switch between the stream call there and the explicit graph API is weird. And what do we mean by "allocations from cuMemAlloc work fine"? The invalid value should be coming from passing NULL for the return value of the graph node.
Oh whoops, sorry - that was just a typo in the example. In the real code a valid pointer is passed; I was just trying to keep it simple. So unfortunately that's not the issue ;( (And to clarify: this works just fine if passing a pointer returned from cuMemAlloc, and fails if passed a pointer from cuMemAllocAsync with no other changes - just who allocated the memory.)
@benvanik, from a cuda-graph expert of ours:
Wow, great! Thank you! I think for now, since we don't have graphs enabled by default, we could disable the tests so that we can get the stream-ordered allocation goodness. Since users may be allocating buffers from code the compiler doesn't control, being able to mix stream-ordered calls with graph calls will always be a thing, but knowing it'll work is more important than it working today. If at any point there are beta drivers on Windows, I'm happy to run more tests for verification!
It'll take some time for us to confirm which version will have the fix. In the meantime, we have identified 3 potential WARs (workarounds), one of which we are ready to share: essentially, make the asynchronous allocation part of the graph (e.g. use cuGraphAddMemAllocNode). Although this appears as a design change, it might well be a better approach anyway. Also created an internal bug; I'm recording the number here so I don't forget: 4151197.
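A minimal sketch of that workaround, assuming a `graph` and a device ordinal are in scope; names and error handling are illustrative, not the actual patch:

```c
#include <string.h>
#include <cuda.h>

// Sketch: allocate inside the graph via cuGraphAddMemAllocNode instead of
// calling cuMemAllocAsync outside of it.
CUDA_MEM_ALLOC_NODE_PARAMS alloc_params;
memset(&alloc_params, 0, sizeof(alloc_params));
alloc_params.poolProps.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
alloc_params.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
alloc_params.poolProps.location.id = /*device ordinal, assumed*/ 0;
alloc_params.bytesize = 192;
CUgraphNode alloc_node = NULL;
cuGraphAddMemAllocNode(&alloc_node, graph, /*dependencies=*/NULL, 0,
                       &alloc_params);
// alloc_params.dptr now holds the graph-owned device address; downstream
// nodes (such as the memset node) depend on alloc_node and use that pointer.
```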
```diff
@@ -24,6 +24,7 @@ typedef struct iree_hal_cuda_allocator_t {
   iree_hal_cuda_context_wrapper_t* context;
   CUdevice device;
   CUstream stream;
+  iree_hal_cuda_memory_pools_t* pools;
```
I'd assume eventually we'll allocate from these pools for this "unpooled" allocator too?
That's my thought! We can make the pools arbitrarily complex (add as many as we want, segregate resources by lifetime or by usage requirements, since export across processes requires a different type of pool, etc). I've got a few fixes incoming that switch nearly all of our allocations (besides variables/constants and whatever the user decides to do) to queue-ordered, so hopefully our usage of the base allocator will be limited to the weird stuff (import/export, etc)!
Thanks for tracking this @pjannaty! Reliably getting the allocations into the graphs isn't possible, as the code doing the allocations may not even be our own - such as a user allocating a buffer to scribble numpy arrays into at the Python level, vs. compiler-generated code we control where we can potentially get things fused together. There are also still issues with that approach around lifetime management, graph reuse, child graphs, allocations spanning multiple graphs, and multiple programs in the same process that may pass buffers between each other.

I do think there's probably some of this we could absorb, but since we can't solve all cases with it today, the workaround I went with is turning memsets into kernel launches instead of using the graph nodes. That probably leaves performance on the table (can't use the DMA engines, etc) but I haven't measured it yet. Thankfully very few of our programs need memset, so this isn't too much of a loss today, but it will be nice to roll back to using the proper CUDA primitives once an update is available! (Tracking in #13984 because I tend to forget these things - I also forget issues, but there's a better chance of someone else reminding me :)
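As an illustration of the memset-as-kernel workaround (a sketch under assumptions, not the generated IREE code), a 32-bit fill kernel that a graph kernel node could launch in place of cuGraphAddMemsetNode might look like:

```cuda
// Illustrative fill kernel standing in for a graph memset node.
extern "C" __global__ void fill_u32(unsigned int* dst, unsigned int value,
                                    unsigned int count) {
  unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < count) dst[i] = value;
}
// Host side (sketch): look the kernel up with cuModuleGetFunction and add it
// as a kernel node; kernel nodes accept the async-allocated pointer that the
// memset node rejected.
//   void* args[] = {&device_ptr, &value, &count};
//   CUDA_KERNEL_NODE_PARAMS kp = {0};
//   kp.func = fill_fn;  // CUfunction from cuModuleGetFunction (assumed)
//   kp.gridDimX = (count + 255) / 256;
//   kp.gridDimY = kp.gridDimZ = 1;
//   kp.blockDimX = 256; kp.blockDimY = kp.blockDimZ = 1;
//   kp.kernelParams = args;
//   cuGraphAddKernelNode(&node, graph, NULL, 0, &kp);
```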
These aren't actually async as the CUDA HAL is synchronous but will make use of CUDA's memory pooling features to reduce the alloc/free cost in a way more friendly to CUDA's memory management than just the normal allocator.
With this our queue-ordered allocations in CUDA now average a few microseconds (for me), and the caching allocator (or any other) is only needed for caching non-queue-ordered allocations. A few compiler tweaks to switch all allocations to queue-ordered will mean only explicitly allocated buffers (constants/variables, stuff the user does manually, etc) will not route to the pools.
It'd also be possible to explore using the same pools the queue-ordered allocations use for explicit synchronous allocations (at least the device-local pool) but it'd be nicer to get those out of the critical path and then keep the pools separate such that the transient pool isn't filled with persistent allocations.
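For illustration, creating a dedicated device-local pool and routing queue-ordered allocations through it might look like the following sketch (assuming a `stream` and a device ordinal are in scope; the release threshold value is an arbitrary example, not what the HAL uses):

```c
#include <string.h>
#include <cuda.h>

// Sketch: a dedicated device-local pool kept separate from the default one,
// so transient queue-ordered allocations don't mix with persistent buffers.
CUmemPoolProps props;
memset(&props, 0, sizeof(props));
props.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
props.location.id = /*device ordinal, assumed*/ 0;
CUmemoryPool pool = NULL;
cuMemPoolCreate(&pool, &props);

// Keep some memory cached in the pool across synchronizations instead of
// returning it all to the OS (64MiB here is an arbitrary example value).
cuuint64_t threshold = 64ull * 1024 * 1024;
cuMemPoolSetAttribute(pool, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &threshold);

// Queue-ordered allocate/free against the dedicated pool:
CUdeviceptr ptr = 0;
cuMemAllocFromPoolAsync(&ptr, 192, pool, stream);
cuMemFreeAsync(ptr, stream);
```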
Due to #13984 this relies on the `--iree-stream-emulate-memset` flag added in #13994 being set when graphs are enabled. Since this is not the default path today and there are just two test suites using it, we just flip the flag for them.