
Adding support for async memory pool allocations in the CUDA HAL. #13440

Merged
benvanik merged 1 commit into main on Jun 8, 2023

Conversation

@benvanik (Collaborator) commented May 6, 2023

These aren't actually async, as the CUDA HAL is synchronous, but they make use of CUDA's memory pooling features to reduce alloc/free cost in a way that is friendlier to CUDA's memory management than just using the normal allocator.

With this, our queue-ordered allocations in CUDA now average a few microseconds (for me), and the caching allocator (or any other) is only needed for caching non-queue-ordered allocations. A few compiler tweaks to switch all allocations to queue-ordered will mean that only explicitly allocated buffers (constants/variables, stuff the user does manually, etc.) will not route to the pools.

It'd also be possible to explore using the same pools the queue-ordered allocations use for explicit synchronous allocations (at least the device-local pool), but it'd be nicer to get those out of the critical path and keep the pools separate so that the transient pool isn't filled with persistent allocations.

Due to #13984 this relies on the --iree-stream-emulate-memset flag added in #13994 being set when graphs are enabled. Since this is not the default path today and there are just two test suites using it, we just flip the flag for them.
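For context, here is a minimal sketch of the CUDA stream-ordered pooling primitives this change builds on (illustrative only, not the IREE HAL code; device and stream are assumed to already exist and error handling is omitted):

  // Illustrative sketch (not the IREE HAL implementation) of CUDA's
  // stream-ordered allocation primitives.
  CUmemPoolProps props = {0};
  props.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
  props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  props.location.id = device;  // device ordinal

  CUmemoryPool pool;
  cuMemPoolCreate(&pool, &props);

  // Queue-ordered alloc/free: cheap, recycled within the pool, and ordered
  // with respect to other work on the stream.
  CUdeviceptr ptr;
  cuMemAllocFromPoolAsync(&ptr, 4096, pool, stream);
  // ... enqueue kernels that use ptr on the same stream ...
  cuMemFreeAsync(ptr, stream);

  cuMemPoolDestroy(pool);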

@benvanik benvanik added performance ⚡ Performance/optimization related work across the compiler and runtime hal/cuda Runtime CUDA HAL backend benchmarks:cuda Run default CUDA benchmarks labels May 6, 2023
@github-actions (bot) commented May 6, 2023

Abbreviated Benchmark Summary

@ commit 68c15298a6ca7937a146b84ae788776e53e9abd2 (no previous benchmark results to compare)

Raw Latencies

| Benchmark Name | Average Latency (ms) | Median Latency (ms) | Latency Standard Deviation (ms) |
| --- | --- | --- | --- |
| BertForMaskedLMTF(stablehlo) [cuda-sm_80-linux_gnu-cuda][default-flags] cuda(none)[full-inference,default-flags] with zeros @ a2-highgpu-1g[gpu] | 7.124 | 7.091 | 0.102 |
| BertLargeTF(stablehlo) [cuda-sm_80-linux_gnu-cuda][default-flags] cuda(none)[full-inference,default-flags] with zeros @ a2-highgpu-1g[gpu] | 10.579 | 10.579 | 0.002 |
| BertLargefp16PTBatch1(linalg) [cuda-sm_80-linux_gnu-cuda][default-flags] cuda(none)[full-inference,default-flags] with zeros @ a2-highgpu-1g[gpu] | 5.889 | 5.873 | 0.053 |

[Top 3 out of 18 results shown]

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

@pjannaty (Contributor)

Hi @benvanik, could you break this into approachable steps, initially targeting an MVP of sorts as a proof point? If this is more parcelable, I may be able to staff it.

@benvanik (Collaborator, Author)

This is the first step! There are two failing tests for an unknown reason - if someone can help diagnose the internal driver error, that would unblock this landing and get us using memory pools. The stream-based execution we do seems to work, but graph-based execution (--cuda_use_streams=false to iree-run-module & co) fails in cuGraphAddMemsetNode with CUDA_ERROR_INVALID_VALUE. Lots of tests that are likely doing the same thing do pass, though, so I'm not sure what the root cause is in these two tests given that the parameters to the memset should be the same ones we use in the stream-based implementation. Maybe graphs perform additional validation or have stricter rules?

https://github.com/openxla/iree/actions/runs/4902664048/jobs/8754727897?pr=13440

e.g.

 561/1664 Test  #559: iree/tests/e2e/xla_ops/check_cuda_graph_convolution.mlir ...................................................***Failed    3.71 sec
[==========] Running 10 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 10 tests from module
[ RUN      ] module.conv2d_nopadding
[       OK ] module.conv2d_nopadding (677 ms)
[ RUN      ] module.conv2d_nopadding_batch_feature
[       OK ] module.conv2d_nopadding_batch_feature (345 ms)
[ RUN      ] module.conv2d_reorder_input_spatial
[       OK ] module.conv2d_reorder_input_spatial (345 ms)
[ RUN      ] module.conv2d_reorder_kernel
[       OK ] module.conv2d_reorder_kernel (325 ms)
[ RUN      ] module.conv2d_reorder_output
[       OK ] module.conv2d_reorder_output (312 ms)
[ RUN      ] module.conv2d_1452x3221_same
/work/tools/iree-check-module-main.cc:68: Failure
Value of: iree_vm_invoke(context_, function_, IREE_VM_INVOCATION_FLAG_NONE, nullptr, nullptr, nullptr, iree_vm_instance_allocator(instance_))
Expected: error code OK
  Actual: 0x21ca1ad, whose error code is INTERNAL: /work/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c:406: INTERNAL; CUDA driver error 'CUDA_ERROR_INVALID_VALUE' (1): invalid argument; cuGraphAddMemsetNode; while invoking native function hal.command_buffer.fill_buffer; while calling import; 
[ 1]   native hal.command_buffer.fill_buffer:0 -
[ 0] bytecode module.conv2d_1452x3221_same:742 /work/tests/e2e/xla_ops/convolution.mlir:182:10
      at /work/tests/e2e/xla_ops/convolution.mlir:172:1
[  FAILED  ] module.conv2d_1452x3221_same, where GetParam() = 5 (265 ms)
[ RUN      ] module.conv2d_2451x2311_same
/work/tools/iree-check-module-main.cc:68: Failure
Value of: iree_vm_invoke(context_, function_, IREE_VM_INVOCATION_FLAG_NONE, nullptr, nullptr, nullptr, iree_vm_instance_allocator(instance_))
Expected: error code OK
  Actual: 0x2923a8d, whose error code is INTERNAL: /work/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c:406: INTERNAL; CUDA driver error 'CUDA_ERROR_INVALID_VALUE' (1): invalid argument; cuGraphAddMemsetNode; while invoking native function hal.command_buffer.fill_buffer; while calling import; 
[ 1]   native hal.command_buffer.fill_buffer:0 -
[ 0] bytecode module.conv2d_2451x2311_same:736 /work/tests/e2e/xla_ops/convolution.mlir:221:10
      at /work/tests/e2e/xla_ops/convolution.mlir:208:1
[  FAILED  ] module.conv2d_2451x2311_same, where GetParam() = 6 (314 ms)
[ RUN      ] module.conv2d_no_padding2
[       OK ] module.conv2d_no_padding2 (360 ms)
[ RUN      ] module.conv2d_1452x2223_dilated_valid
[       OK ] module.conv2d_1452x2223_dilated_valid (377 ms)
[ RUN      ] module.depthwise_conv_non_1_channel_multiplier
[       OK ] module.depthwise_conv_non_1_channel_multiplier (367 ms)
[----------] 10 tests from module (3689 ms total)

[----------] Global test environment tear-down
[==========] 10 tests from 1 test suite ran. (3689 ms total)
[  PASSED  ] 8 tests.
[  FAILED  ] 2 tests, listed below:
[  FAILED  ] module.conv2d_1452x3221_same, where GetParam() = 5
[  FAILED  ] module.conv2d_2451x2311_same, where GetParam() = 6

 2 FAILED TESTS
Test failed

@pjannaty (Contributor) commented Jun 5, 2023

Sorry for the delay. I can confirm the issue persists in CUDA 12.0. Asking...

@pjannaty (Contributor) commented Jun 5, 2023

So far I have this hint:

Not sure if this applies in this situation, but the cuMemset* APIs return CUDA_SUCCESS for 0-byte memsets. cuGraphAddMemsetNode will return CUDA_ERROR_INVALID_VALUE for a 0-byte memset.

Are we sure the sizes are not trivially zero?
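A guard along those lines might look like the following sketch (all names here are hypothetical, not the actual IREE code):

  // Hypothetical sketch: skip zero-length fills before adding the graph node,
  // since cuGraphAddMemsetNode rejects a 0-byte memset with
  // CUDA_ERROR_INVALID_VALUE even though the cuMemset* stream APIs accept it.
  if (length == 0) {
    return;  // nothing to fill; don't emit a 0-byte graph memset
  }
  CUDA_MEMSET_NODE_PARAMS params = {
      .dst = target_ptr,
      .elementSize = 1,
      .pitch = 0,  // unused with height == 1
      .width = length,
      .height = 1,
      .value = pattern,
  };
  cuGraphAddMemsetNode(&node, graph, deps, dep_count, &params, context);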

@benvanik (Collaborator, Author) commented Jun 5, 2023

Ahh that may be it! I'll fix the code and give it a shot! Thank you!

@benvanik (Collaborator, Author) commented Jun 5, 2023

Interesting - it looks like there's a legit error happening, and the device pointers we are passing in look bogus.

[ RUN      ] module.conv2d_2451x2311_same
copy_buffer
  srcDevice 0x7f0363200000
  srcXInBytes 0
  dstDevice 0x7f0363400000
  dstXInBytes 0
  WidthInBytes 3264
fill_buffer
  dst 0x302000000
  element size 1
  width 280
  height 1
  value 0 00000000
/work/tools/iree-check-module-main.cc:68: Failure
Value of: iree_vm_invoke(context_, function_, IREE_VM_INVOCATION_FLAG_NONE, nullptr, nullptr, nullptr, iree_vm_instance_allocator(instance_))
Expected: error code OK
  Actual: 0xe76b8d, whose error code is INTERNAL: /work/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c:417: INTERNAL; CUDA driver error 'CUDA_ERROR_INVALID_VALUE' (1): invalid argument; cuGraphAddMemsetNode; while invoking native function hal.command_buffer.fill_buffer; while calling import; 

The 0x302000000 address shows up in all the other failing tests, so something is corrupting it or the code path deriving it is busted.
This only happens on the graph path and only for a few tests. I don't think we run CUDA + ASAN, but I'll see if I can add more logging to track it down and audit the graph command buffer code.

@benvanik (Collaborator, Author) commented Jun 5, 2023

(reproduced it locally, which helps a lot!)

@pjannaty (Contributor) commented Jun 6, 2023

OK, perfect. On our side we're currently out of ideas, but if you get new signals I'll relay them.

@benvanik (Collaborator, Author) commented Jun 6, 2023

Thanks! That tip was really helpful in tracking down the root cause - a silly mistake on my part (it was possible to end up issuing a cuMemFree after a cuMemFreeAsync, which is obviously bad :), and I'm going to rework things to prevent that. It's nice that graphs perform that validation, and I'm excited to have them on by default so we can catch this kind of error earlier!
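For anyone following along, the invariant that was violated, sketched with hypothetical names:

  // Sketch (hypothetical names): a pointer from cuMemAllocAsync or
  // cuMemAllocFromPoolAsync must be released with cuMemFreeAsync, while a
  // pointer from cuMemAlloc must be released with cuMemFree; mixing the two
  // families is the bug described above.
  if (buffer->is_stream_ordered) {
    cuMemFreeAsync(buffer->device_ptr, stream);  // queue-ordered release
  } else {
    cuMemFree(buffer->device_ptr);  // synchronous release
  }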

@benvanik (Collaborator, Author) commented Jun 6, 2023

Uh oh, still hitting the issue. @pjannaty could you please check my understanding that calling cuMemAllocAsync and then using the resulting pointer when constructing a graph (as an arg to cuGraphAddMemsetNode) is legal? If I use streams things work fine, but graphs explode even with the same pointers returned across process launches, the same params to the node or cuMemsetD32Async call, and the same execution order. I'd have expected the virtual address returned from cuMemAllocAsync to be legal to encode in graphs even without synchronization, but I tried sprinkling cuStreamSynchronize calls around to no avail. Allocations returned from cuMemAlloc work fine, and both async allocations from the default pool (cuMemAllocAsync) and from my own (cuMemAllocFromPoolAsync) fail.

This memset passes with these params on the stream implementation with both sync and async allocations, and on the graph implementation when using a normal cuMemAlloc allocation, but not an async one:

  CUdeviceptr device_ptr_from_async_alloc = 0;
  cuMemAllocAsync(&device_ptr_from_async_alloc, 192, stream);
  // cuStreamSynchronize(stream); -- doesn't help
  CUDA_MEMSET_NODE_PARAMS params = {
      .dst = device_ptr_from_async_alloc,  // 0000000204E00000, 192b
      .elementSize = 4,
      .pitch = 0,  // unused with height == 1
      .width = 144 / 4,  // element count = 36
      .height = 1,
      .value = 0,
  };
  // returns CUDA_ERROR_INVALID_VALUE
  cuGraphAddMemsetNode(NULL, graph, NULL, 0, &params, context);

I tried changing all the params (elementSize=1/width=1, elementSize=4/width=1, etc.) and all returned the same error, so I'm pretty sure it's the pointer that's the issue.

Thoughts on what to try next? Everything else except memset seems to be fine, so I'm at a loss as to why that particular call is having issues. I'm seeing this on Windows with CUDA 12.1 / driver 532.03, but we also see it on the A100 bot with CUDA 12.1 / driver 530.41.03.
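For contrast, the stream-side sequence that succeeds with the same pointer and parameters is roughly this (sketch):

  // Sketch: the same async allocation with the equivalent fill issued as a
  // stream call rather than a graph node -- this path works.
  CUdeviceptr ptr = 0;
  cuMemAllocAsync(&ptr, 192, stream);
  cuMemsetD32Async(ptr, /*value=*/0, /*count=*/144 / 4, stream);  // 36 4-byte elements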

@pjannaty (Contributor) commented Jun 6, 2023

We think passing phGraphNode as NULL is causing the explicit graph API to return the invalid-value error.
Are you trying to add a new node to a captured graph? But how can a newly created node be returned if NULL is passed in?
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g89dc8fc3743392777c0daa2c4aca40d3

Also, the switch between the stream call there and the explicit graph API is odd. And what do we mean by "allocations from cuMemAlloc work fine"? The invalid-value error should be coming from passing NULL for the return value of the graph node.

@benvanik (Collaborator, Author) commented Jun 6, 2023

Oh whoops, sorry - that was just a typo in the example. In the real code a value is passed; I was just trying to keep it simple:

  CUDA_RETURN_IF_ERROR(
      command_buffer->context->syms,
      cuGraphAddMemsetNode(&command_buffer->last_node, command_buffer->graph,
                           dep, numNode, &params,
                           command_buffer->context->cu_context),
      "cuGraphAddMemsetNode");

so unfortunately that's not the issue ;(

(and to clarify: this works just fine when passing a pointer returned from cuMemAlloc and fails when passing a pointer from cuMemAllocAsync, with no other changes - the only difference is who allocated the memory)

@pjannaty (Contributor) commented Jun 7, 2023

@benvanik, from one of our CUDA graph experts:

Some validation logic inside cuGraphAddMemsetNode missed handling the non-graph async allocator case.
The logic works correctly for allocations that are part of the graph. The stream capture also doesn't hit this issue. The error only appears for asynchronous allocations made outside of the graph and then referenced by the cuGraphAddMemsetNode operation.
I've made a local reproducer & the fix should be pretty simple...

@benvanik (Collaborator, Author) commented Jun 7, 2023

Wow, great! Thank you! For now, since we don't have graphs enabled by default, I think we could disable the tests so that we can get the stream-ordered allocation goodness. Since users may be allocating buffers from code we don't control in the compiler, being able to mix stream-ordered calls with graph calls will always be a thing, but knowing it'll work is more important than it working today. If at any point there are beta drivers on Windows, I'm happy to run more tests for verification!

@benvanik benvanik force-pushed the benvanik-emulate-splat-fill branch from 1384be0 to 02c6370 Compare June 7, 2023 23:41
@benvanik benvanik requested a review from antiagainst June 7, 2023 23:45
@benvanik benvanik force-pushed the benvanik-emulate-splat-fill branch from 02c6370 to bd597a1 Compare June 8, 2023 00:20
@benvanik benvanik marked this pull request as ready for review June 8, 2023 01:40
@pjannaty (Contributor) commented Jun 8, 2023

It'll take some time for us to confirm which version will have the fix. In the meantime, we have identified three potential WARs, one of which we are ready to share:

Essentially the WAR is to make the asynchronous allocation part of the graph (e.g. use cuGraphAddMemAllocNode). Although this appears to be a design change, it might well be a better approach anyway.
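A rough sketch of what that WAR might look like (illustrative only, not the actual change; device, graph, and context are assumed to exist and error handling is omitted):

  // Sketch of the suggested WAR: make the allocation a node of the same graph
  // so the memset node references a graph-owned allocation.
  CUDA_MEM_ALLOC_NODE_PARAMS alloc_params = {0};
  alloc_params.poolProps.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
  alloc_params.poolProps.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  alloc_params.poolProps.location.id = device;  // device ordinal
  alloc_params.bytesize = 192;
  CUgraphNode alloc_node;
  cuGraphAddMemAllocNode(&alloc_node, graph, NULL, 0, &alloc_params);

  // alloc_params.dptr now holds the graph-ordered address; make the memset
  // node depend on the alloc node.
  CUDA_MEMSET_NODE_PARAMS memset_params = {
      .dst = alloc_params.dptr,
      .elementSize = 4,
      .pitch = 0,
      .width = 144 / 4,
      .height = 1,
      .value = 0,
  };
  CUgraphNode memset_node;
  cuGraphAddMemsetNode(&memset_node, graph, &alloc_node, 1, &memset_params,
                       context);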

Also created an internal bug; I'm recording the number here so I don't forget: 4151197.

@@ -24,6 +24,7 @@ typedef struct iree_hal_cuda_allocator_t {
   iree_hal_cuda_context_wrapper_t* context;
   CUdevice device;
   CUstream stream;
+  iree_hal_cuda_memory_pools_t* pools;
Contributor:

I'd assume eventually we'll allocate from these pools for this "unpooled" allocator too?

@benvanik (Collaborator, Author):

That's my thought! We can make the pools arbitrarily complex (add as many as we want, segregate resources by lifetime, split based on usage requirements since exporting across processes requires a different type of pool, etc.). I've got a few fixes incoming that switch nearly all our allocations, besides variables/constants and whatever the user decides to do manually, to queue-ordered, so hopefully our usage of the base allocator will be for weird stuff (import/export, etc.)!

@benvanik (Collaborator, Author) commented Jun 8, 2023

Thanks for tracking this @pjannaty!

Reliably getting the allocations into the graphs isn't possible, as the code doing the allocations may not even be our own - for example a user allocating a buffer to scribble numpy arrays into at the Python level, vs. the compiler-generated code we control where we can potentially get things fused together. There are still issues around lifetime management, graph reuse, child graphs, allocations spanning multiple graphs, and multiple programs in the same process that may pass buffers between each other. I do think there's probably some we could absorb, but since we can't solve all cases with that, the workaround I went with is turning memset into kernel launches instead of using the graph nodes (see the sketch below). That's probably leaving performance on the table (can't use the DMA engines, etc.), but I haven't measured it yet. Thankfully very few of our programs need memset, so this isn't too much of a loss today, but it will be nice to roll back to using the proper CUDA primitives once an update is available! (Tracking in #13984 because I tend to forget these things - I also forget issues, but there's a better chance of someone else reminding me :)
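The sketch below only illustrates the idea of emulating a fill with a kernel launch; it is not the code the compiler actually generates for --iree-stream-emulate-memset:

  // Illustrative fill kernel: a splat/memset emulated as a kernel that can be
  // recorded as a graph kernel node (or launched on a stream) instead of
  // using cuGraphAddMemsetNode.
  extern "C" __global__ void fill_u32(unsigned int* dst, unsigned int value,
                                      unsigned long long count) {
    unsigned long long i =
        blockIdx.x * (unsigned long long)blockDim.x + threadIdx.x;
    if (i < count) dst[i] = value;
  }

A launch of this kernel (via cuLaunchKernel or cuGraphAddKernelNode) then stands in for the memset node, at the cost of occupying compute resources instead of the copy/DMA path.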

Base automatically changed from benvanik-emulate-splat-fill to main June 8, 2023 05:08
@benvanik benvanik merged commit 8f9e962 into main Jun 8, 2023
49 checks passed
@benvanik benvanik deleted the benvanik-cuda-async-alloc branch June 8, 2023 15:44
NatashaKnk pushed a commit to NatashaKnk/iree that referenced this pull request Jul 6, 2023
nhasabni pushed a commit to plaidml/iree that referenced this pull request Aug 24, 2023